From ec81e8e778accff323d76cd0ab0f3f280dc05f26 Mon Sep 17 00:00:00 2001 From: Dustin Spicuzza Date: Wed, 26 Jan 2022 20:03:52 -0500 Subject: [PATCH] Propagate py::multiple_inheritance to all children (#3650) * Add tests demonstrating smart_holder issues with multiple inheritance * Propagate C++ multiple inheritance markers to all children - Makes py::multiple_inheritance only needed in base classes hidden from pybind11 --- include/pybind11/pybind11.h | 3 + tests/test_multiple_inheritance.cpp | 83 ++++++++++++++++++++ tests/test_multiple_inheritance.py | 114 ++++++++++++++++++++++++++++ 3 files changed, 200 insertions(+) diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 344a18c56..df33d5182 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1206,10 +1206,13 @@ protected: if (rec.bases.size() > 1 || rec.multiple_inheritance) { mark_parents_nonsimple(tinfo->type); tinfo->simple_ancestors = false; + tinfo->simple_type = false; } else if (rec.bases.size() == 1) { auto parent_tinfo = get_type_info((PyTypeObject *) rec.bases[0].ptr()); tinfo->simple_ancestors = parent_tinfo->simple_ancestors; + // a child of a non-simple type can never be a simple type + tinfo->simple_type = parent_tinfo->simple_type; } if (rec.module_local) { diff --git a/tests/test_multiple_inheritance.cpp b/tests/test_multiple_inheritance.cpp index 6963197a5..44b9876eb 100644 --- a/tests/test_multiple_inheritance.cpp +++ b/tests/test_multiple_inheritance.cpp @@ -230,4 +230,87 @@ TEST_SUBMODULE(multiple_inheritance, m) { .def("c1", [](C1 *self) { return self; }); py::class_(m, "D") .def(py::init<>()); + + // test_pr3635_diamond_* + // - functions are get_{base}_{var}, return {var} + struct MVB { + MVB() = default; + MVB(const MVB&) = default; + virtual ~MVB() = default; + + int b = 1; + int get_b_b() const { return b; } + }; + struct MVC : virtual MVB { + int c = 2; + int get_c_b() const { return b; } + int get_c_c() const { return c; } + }; + struct MVD0 : virtual MVC { + int d0 = 3; + int get_d0_b() const { return b; } + int get_d0_c() const { return c; } + int get_d0_d0() const { return d0; } + }; + struct MVD1 : virtual MVC { + int d1 = 4; + int get_d1_b() const { return b; } + int get_d1_c() const { return c; } + int get_d1_d1() const { return d1; } + }; + struct MVE : virtual MVD0, virtual MVD1 { + int e = 5; + int get_e_b() const { return b; } + int get_e_c() const { return c; } + int get_e_d0() const { return d0; } + int get_e_d1() const { return d1; } + int get_e_e() const { return e; } + }; + struct MVF : virtual MVE { + int f = 6; + int get_f_b() const { return b; } + int get_f_c() const { return c; } + int get_f_d0() const { return d0; } + int get_f_d1() const { return d1; } + int get_f_e() const { return e; } + int get_f_f() const { return f; } + }; + py::class_(m, "MVB") + .def(py::init<>()) + .def("get_b_b", &MVB::get_b_b) + .def_readwrite("b", &MVB::b); + py::class_(m, "MVC") + .def(py::init<>()) + .def("get_c_b", &MVC::get_c_b) + .def("get_c_c", &MVC::get_c_c) + .def_readwrite("c", &MVC::c); + py::class_(m, "MVD0") + .def(py::init<>()) + .def("get_d0_b", &MVD0::get_d0_b) + .def("get_d0_c", &MVD0::get_d0_c) + .def("get_d0_d0", &MVD0::get_d0_d0) + .def_readwrite("d0", &MVD0::d0); + py::class_(m, "MVD1") + .def(py::init<>()) + .def("get_d1_b", &MVD1::get_d1_b) + .def("get_d1_c", &MVD1::get_d1_c) + .def("get_d1_d1", &MVD1::get_d1_d1) + .def_readwrite("d1", &MVD1::d1); + py::class_(m, "MVE") + .def(py::init<>()) + .def("get_e_b", &MVE::get_e_b) + .def("get_e_c", &MVE::get_e_c) + .def("get_e_d0", &MVE::get_e_d0) + .def("get_e_d1", &MVE::get_e_d1) + .def("get_e_e", &MVE::get_e_e) + .def_readwrite("e", &MVE::e); + py::class_(m, "MVF") + .def(py::init<>()) + .def("get_f_b", &MVF::get_f_b) + .def("get_f_c", &MVF::get_f_c) + .def("get_f_d0", &MVF::get_f_d0) + .def("get_f_d1", &MVF::get_f_d1) + .def("get_f_e", &MVF::get_f_e) + .def("get_f_f", &MVF::get_f_f) + .def_readwrite("f", &MVF::f); } diff --git a/tests/test_multiple_inheritance.py b/tests/test_multiple_inheritance.py index a02c31300..71741b925 100644 --- a/tests/test_multiple_inheritance.py +++ b/tests/test_multiple_inheritance.py @@ -358,3 +358,117 @@ def test_diamond_inheritance(): assert d is d.c0().b() assert d is d.c1().b() assert d is d.c0().c1().b().c0().b() + + +def test_pr3635_diamond_b(): + o = m.MVB() + assert o.b == 1 + + assert o.get_b_b() == 1 + + +def test_pr3635_diamond_c(): + o = m.MVC() + assert o.b == 1 + assert o.c == 2 + + assert o.get_b_b() == 1 + assert o.get_c_b() == 1 + + assert o.get_c_c() == 2 + + +def test_pr3635_diamond_d0(): + o = m.MVD0() + assert o.b == 1 + assert o.c == 2 + assert o.d0 == 3 + + assert o.get_b_b() == 1 + assert o.get_c_b() == 1 + assert o.get_d0_b() == 1 + + assert o.get_c_c() == 2 + assert o.get_d0_c() == 2 + + assert o.get_d0_d0() == 3 + + +def test_pr3635_diamond_d1(): + o = m.MVD1() + assert o.b == 1 + assert o.c == 2 + assert o.d1 == 4 + + assert o.get_b_b() == 1 + assert o.get_c_b() == 1 + assert o.get_d1_b() == 1 + + assert o.get_c_c() == 2 + assert o.get_d1_c() == 2 + + assert o.get_d1_d1() == 4 + + +def test_pr3635_diamond_e(): + o = m.MVE() + assert o.b == 1 + assert o.c == 2 + assert o.d0 == 3 + assert o.d1 == 4 + assert o.e == 5 + + assert o.get_b_b() == 1 + assert o.get_c_b() == 1 + assert o.get_d0_b() == 1 + assert o.get_d1_b() == 1 + assert o.get_e_b() == 1 + + assert o.get_c_c() == 2 + assert o.get_d0_c() == 2 + assert o.get_d1_c() == 2 + assert o.get_e_c() == 2 + + assert o.get_d0_d0() == 3 + assert o.get_e_d0() == 3 + + assert o.get_d1_d1() == 4 + assert o.get_e_d1() == 4 + + assert o.get_e_e() == 5 + + +def test_pr3635_diamond_f(): + o = m.MVF() + assert o.b == 1 + assert o.c == 2 + assert o.d0 == 3 + assert o.d1 == 4 + assert o.e == 5 + assert o.f == 6 + + assert o.get_b_b() == 1 + assert o.get_c_b() == 1 + assert o.get_d0_b() == 1 + assert o.get_d1_b() == 1 + assert o.get_e_b() == 1 + assert o.get_f_b() == 1 + + assert o.get_c_c() == 2 + assert o.get_d0_c() == 2 + assert o.get_d1_c() == 2 + assert o.get_e_c() == 2 + assert o.get_f_c() == 2 + + assert o.get_d0_d0() == 3 + assert o.get_e_d0() == 3 + assert o.get_f_d0() == 3 + + assert o.get_d1_d1() == 4 + assert o.get_e_d1() == 4 + assert o.get_f_d1() == 4 + + assert o.get_e_e() == 5 + assert o.get_f_e() == 5 + + assert o.get_f_f() == 6