mirror of
https://github.com/pybind/pybind11.git
synced 2025-01-19 01:15:52 +00:00
Add check for matching holder_type when inheriting (#588)
This commit is contained in:
parent
7830e8509f
commit
cc88aaecc8
@ -185,6 +185,9 @@ struct type_record {
|
||||
/// Does the class require its own metaclass?
|
||||
bool metaclass : 1;
|
||||
|
||||
/// Is the default (unique_ptr) holder type used?
|
||||
bool default_holder : 1;
|
||||
|
||||
PYBIND11_NOINLINE void add_base(const std::type_info *base, void *(*caster)(void *)) {
|
||||
auto base_info = detail::get_type_info(*base, false);
|
||||
if (!base_info) {
|
||||
@ -194,6 +197,15 @@ struct type_record {
|
||||
"\" referenced unknown base type \"" + tname + "\"");
|
||||
}
|
||||
|
||||
if (default_holder != base_info->default_holder) {
|
||||
std::string tname(base->name());
|
||||
detail::clean_type_id(tname);
|
||||
pybind11_fail("generic_type: type \"" + std::string(name) + "\" " +
|
||||
(default_holder ? "does not have" : "has") +
|
||||
" a non-default holder type while its base \"" + tname + "\" " +
|
||||
(base_info->default_holder ? "does not" : "does"));
|
||||
}
|
||||
|
||||
bases.append((PyObject *) base_info->type);
|
||||
|
||||
if (base_info->type->tp_dictoffset != 0)
|
||||
|
@ -32,6 +32,8 @@ struct type_info {
|
||||
/** A simple type never occurs as a (direct or indirect) parent
|
||||
* of a class that makes use of multiple inheritance */
|
||||
bool simple_type = true;
|
||||
/* for base vs derived holder_type checks */
|
||||
bool default_holder = true;
|
||||
};
|
||||
|
||||
PYBIND11_NOINLINE inline internals &get_internals() {
|
||||
|
@ -741,6 +741,7 @@ protected:
|
||||
tinfo->type_size = rec->type_size;
|
||||
tinfo->init_holder = rec->init_holder;
|
||||
tinfo->direct_conversions = &internals.direct_conversions[tindex];
|
||||
tinfo->default_holder = rec->default_holder;
|
||||
internals.registered_types_cpp[tindex] = tinfo;
|
||||
internals.registered_types_py[type] = tinfo;
|
||||
|
||||
@ -1006,6 +1007,7 @@ public:
|
||||
record.instance_size = sizeof(instance_type);
|
||||
record.init_holder = init_holder;
|
||||
record.dealloc = dealloc;
|
||||
record.default_holder = std::is_same<holder_type, std::unique_ptr<type>>::value;
|
||||
|
||||
/* Register base classes specified via template arguments to class_, if any */
|
||||
bool unused[] = { (add_base<options>(record), false)..., false };
|
||||
|
@ -49,6 +49,12 @@ struct BaseClass { virtual ~BaseClass() {} };
|
||||
struct DerivedClass1 : BaseClass { };
|
||||
struct DerivedClass2 : BaseClass { };
|
||||
|
||||
struct MismatchBase1 { };
|
||||
struct MismatchDerived1 : MismatchBase1 { };
|
||||
|
||||
struct MismatchBase2 { };
|
||||
struct MismatchDerived2 : MismatchBase2 { };
|
||||
|
||||
test_initializer inheritance([](py::module &m) {
|
||||
py::class_<Pet> pet_class(m, "Pet");
|
||||
pet_class
|
||||
@ -97,4 +103,15 @@ test_initializer inheritance([](py::module &m) {
|
||||
py::isinstance<Unregistered>(l[6])
|
||||
);
|
||||
});
|
||||
|
||||
m.def("test_mismatched_holder_type_1", []() {
|
||||
auto m = py::module::import("__main__");
|
||||
py::class_<MismatchBase1, std::shared_ptr<MismatchBase1>>(m, "MismatchBase1");
|
||||
py::class_<MismatchDerived1, MismatchBase1>(m, "MismatchDerived1");
|
||||
});
|
||||
m.def("test_mismatched_holder_type_2", []() {
|
||||
auto m = py::module::import("__main__");
|
||||
py::class_<MismatchBase2>(m, "MismatchBase2");
|
||||
py::class_<MismatchDerived2, std::shared_ptr<MismatchDerived2>, MismatchBase2>(m, "MismatchDerived2");
|
||||
});
|
||||
});
|
||||
|
@ -37,7 +37,8 @@ def test_automatic_upcasting():
|
||||
assert type(return_class_1()).__name__ == "DerivedClass1"
|
||||
assert type(return_class_2()).__name__ == "DerivedClass2"
|
||||
assert type(return_none()).__name__ == "NoneType"
|
||||
# Repeat these a few times in a random order to ensure no invalid caching is applied
|
||||
# Repeat these a few times in a random order to ensure no invalid caching
|
||||
# is applied
|
||||
assert type(return_class_n(1)).__name__ == "DerivedClass1"
|
||||
assert type(return_class_n(2)).__name__ == "DerivedClass2"
|
||||
assert type(return_class_n(0)).__name__ == "BaseClass"
|
||||
@ -53,3 +54,21 @@ def test_isinstance():
|
||||
objects = [tuple(), dict(), Pet("Polly", "parrot")] + [Dog("Molly")] * 4
|
||||
expected = (True, True, True, True, True, False, False)
|
||||
assert test_isinstance(objects) == expected
|
||||
|
||||
|
||||
def test_holder():
|
||||
from pybind11_tests import test_mismatched_holder_type_1, test_mismatched_holder_type_2
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
test_mismatched_holder_type_1()
|
||||
|
||||
assert str(excinfo.value) == ("generic_type: type \"MismatchDerived1\" does not have "
|
||||
"a non-default holder type while its base "
|
||||
"\"MismatchBase1\" does")
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
test_mismatched_holder_type_2()
|
||||
|
||||
assert str(excinfo.value) == ("generic_type: type \"MismatchDerived2\" has a "
|
||||
"non-default holder type while its base "
|
||||
"\"MismatchBase2\" does not")
|
||||
|
Loading…
Reference in New Issue
Block a user