diff --git a/include/pybind11/attr.h b/include/pybind11/attr.h index 0676d5da6..740d3beff 100644 --- a/include/pybind11/attr.h +++ b/include/pybind11/attr.h @@ -185,6 +185,9 @@ struct type_record { /// Does the class require its own metaclass? bool metaclass : 1; + /// Is the default (unique_ptr) holder type used? + bool default_holder : 1; + PYBIND11_NOINLINE void add_base(const std::type_info *base, void *(*caster)(void *)) { auto base_info = detail::get_type_info(*base, false); if (!base_info) { @@ -194,6 +197,15 @@ struct type_record { "\" referenced unknown base type \"" + tname + "\""); } + if (default_holder != base_info->default_holder) { + std::string tname(base->name()); + detail::clean_type_id(tname); + pybind11_fail("generic_type: type \"" + std::string(name) + "\" " + + (default_holder ? "does not have" : "has") + + " a non-default holder type while its base \"" + tname + "\" " + + (base_info->default_holder ? "does not" : "does")); + } + bases.append((PyObject *) base_info->type); if (base_info->type->tp_dictoffset != 0) diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index b953cc897..9077dbba7 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -32,6 +32,8 @@ struct type_info { /** A simple type never occurs as a (direct or indirect) parent * of a class that makes use of multiple inheritance */ bool simple_type = true; + /* for base vs derived holder_type checks */ + bool default_holder = true; }; PYBIND11_NOINLINE inline internals &get_internals() { diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index addcce74b..99b1f7248 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -741,6 +741,7 @@ protected: tinfo->type_size = rec->type_size; tinfo->init_holder = rec->init_holder; tinfo->direct_conversions = &internals.direct_conversions[tindex]; + tinfo->default_holder = rec->default_holder; internals.registered_types_cpp[tindex] = tinfo; internals.registered_types_py[type] = tinfo; @@ -1006,6 +1007,7 @@ public: record.instance_size = sizeof(instance_type); record.init_holder = init_holder; record.dealloc = dealloc; + record.default_holder = std::is_same>::value; /* Register base classes specified via template arguments to class_, if any */ bool unused[] = { (add_base(record), false)..., false }; diff --git a/tests/test_inheritance.cpp b/tests/test_inheritance.cpp index 2ec0b4a7a..914b7a839 100644 --- a/tests/test_inheritance.cpp +++ b/tests/test_inheritance.cpp @@ -49,6 +49,12 @@ struct BaseClass { virtual ~BaseClass() {} }; struct DerivedClass1 : BaseClass { }; struct DerivedClass2 : BaseClass { }; +struct MismatchBase1 { }; +struct MismatchDerived1 : MismatchBase1 { }; + +struct MismatchBase2 { }; +struct MismatchDerived2 : MismatchBase2 { }; + test_initializer inheritance([](py::module &m) { py::class_ pet_class(m, "Pet"); pet_class @@ -97,4 +103,15 @@ test_initializer inheritance([](py::module &m) { py::isinstance(l[6]) ); }); + + m.def("test_mismatched_holder_type_1", []() { + auto m = py::module::import("__main__"); + py::class_>(m, "MismatchBase1"); + py::class_(m, "MismatchDerived1"); + }); + m.def("test_mismatched_holder_type_2", []() { + auto m = py::module::import("__main__"); + py::class_(m, "MismatchBase2"); + py::class_, MismatchBase2>(m, "MismatchDerived2"); + }); }); diff --git a/tests/test_inheritance.py b/tests/test_inheritance.py index 7bb52be02..e4ab20265 100644 --- a/tests/test_inheritance.py +++ b/tests/test_inheritance.py @@ -37,7 +37,8 @@ def test_automatic_upcasting(): assert type(return_class_1()).__name__ == "DerivedClass1" assert type(return_class_2()).__name__ == "DerivedClass2" assert type(return_none()).__name__ == "NoneType" - # Repeat these a few times in a random order to ensure no invalid caching is applied + # Repeat these a few times in a random order to ensure no invalid caching + # is applied assert type(return_class_n(1)).__name__ == "DerivedClass1" assert type(return_class_n(2)).__name__ == "DerivedClass2" assert type(return_class_n(0)).__name__ == "BaseClass" @@ -53,3 +54,21 @@ def test_isinstance(): objects = [tuple(), dict(), Pet("Polly", "parrot")] + [Dog("Molly")] * 4 expected = (True, True, True, True, True, False, False) assert test_isinstance(objects) == expected + + +def test_holder(): + from pybind11_tests import test_mismatched_holder_type_1, test_mismatched_holder_type_2 + + with pytest.raises(RuntimeError) as excinfo: + test_mismatched_holder_type_1() + + assert str(excinfo.value) == ("generic_type: type \"MismatchDerived1\" does not have " + "a non-default holder type while its base " + "\"MismatchBase1\" does") + + with pytest.raises(RuntimeError) as excinfo: + test_mismatched_holder_type_2() + + assert str(excinfo.value) == ("generic_type: type \"MismatchDerived2\" has a " + "non-default holder type while its base " + "\"MismatchBase2\" does not")