Add check for matching holder_type when inheriting (#588)

This commit is contained in:
Pim Schellart 2017-01-31 10:52:11 -05:00 committed by Wenzel Jakob
parent 7830e8509f
commit cc88aaecc8
5 changed files with 53 additions and 1 deletions

View File

@ -185,6 +185,9 @@ struct type_record {
/// Does the class require its own metaclass?
bool metaclass : 1;
/// Is the default (unique_ptr) holder type used?
bool default_holder : 1;
PYBIND11_NOINLINE void add_base(const std::type_info *base, void *(*caster)(void *)) {
auto base_info = detail::get_type_info(*base, false);
if (!base_info) {
@ -194,6 +197,15 @@ struct type_record {
"\" referenced unknown base type \"" + tname + "\"");
}
if (default_holder != base_info->default_holder) {
std::string tname(base->name());
detail::clean_type_id(tname);
pybind11_fail("generic_type: type \"" + std::string(name) + "\" " +
(default_holder ? "does not have" : "has") +
" a non-default holder type while its base \"" + tname + "\" " +
(base_info->default_holder ? "does not" : "does"));
}
bases.append((PyObject *) base_info->type);
if (base_info->type->tp_dictoffset != 0)

View File

@ -32,6 +32,8 @@ struct type_info {
/** A simple type never occurs as a (direct or indirect) parent
* of a class that makes use of multiple inheritance */
bool simple_type = true;
/* for base vs derived holder_type checks */
bool default_holder = true;
};
PYBIND11_NOINLINE inline internals &get_internals() {

View File

@ -741,6 +741,7 @@ protected:
tinfo->type_size = rec->type_size;
tinfo->init_holder = rec->init_holder;
tinfo->direct_conversions = &internals.direct_conversions[tindex];
tinfo->default_holder = rec->default_holder;
internals.registered_types_cpp[tindex] = tinfo;
internals.registered_types_py[type] = tinfo;
@ -1006,6 +1007,7 @@ public:
record.instance_size = sizeof(instance_type);
record.init_holder = init_holder;
record.dealloc = dealloc;
record.default_holder = std::is_same<holder_type, std::unique_ptr<type>>::value;
/* Register base classes specified via template arguments to class_, if any */
bool unused[] = { (add_base<options>(record), false)..., false };

View File

@ -49,6 +49,12 @@ struct BaseClass { virtual ~BaseClass() {} };
struct DerivedClass1 : BaseClass { };
struct DerivedClass2 : BaseClass { };
struct MismatchBase1 { };
struct MismatchDerived1 : MismatchBase1 { };
struct MismatchBase2 { };
struct MismatchDerived2 : MismatchBase2 { };
test_initializer inheritance([](py::module &m) {
py::class_<Pet> pet_class(m, "Pet");
pet_class
@ -97,4 +103,15 @@ test_initializer inheritance([](py::module &m) {
py::isinstance<Unregistered>(l[6])
);
});
m.def("test_mismatched_holder_type_1", []() {
auto m = py::module::import("__main__");
py::class_<MismatchBase1, std::shared_ptr<MismatchBase1>>(m, "MismatchBase1");
py::class_<MismatchDerived1, MismatchBase1>(m, "MismatchDerived1");
});
m.def("test_mismatched_holder_type_2", []() {
auto m = py::module::import("__main__");
py::class_<MismatchBase2>(m, "MismatchBase2");
py::class_<MismatchDerived2, std::shared_ptr<MismatchDerived2>, MismatchBase2>(m, "MismatchDerived2");
});
});

View File

@ -37,7 +37,8 @@ def test_automatic_upcasting():
assert type(return_class_1()).__name__ == "DerivedClass1"
assert type(return_class_2()).__name__ == "DerivedClass2"
assert type(return_none()).__name__ == "NoneType"
# Repeat these a few times in a random order to ensure no invalid caching is applied
# Repeat these a few times in a random order to ensure no invalid caching
# is applied
assert type(return_class_n(1)).__name__ == "DerivedClass1"
assert type(return_class_n(2)).__name__ == "DerivedClass2"
assert type(return_class_n(0)).__name__ == "BaseClass"
@ -53,3 +54,21 @@ def test_isinstance():
objects = [tuple(), dict(), Pet("Polly", "parrot")] + [Dog("Molly")] * 4
expected = (True, True, True, True, True, False, False)
assert test_isinstance(objects) == expected
def test_holder():
from pybind11_tests import test_mismatched_holder_type_1, test_mismatched_holder_type_2
with pytest.raises(RuntimeError) as excinfo:
test_mismatched_holder_type_1()
assert str(excinfo.value) == ("generic_type: type \"MismatchDerived1\" does not have "
"a non-default holder type while its base "
"\"MismatchBase1\" does")
with pytest.raises(RuntimeError) as excinfo:
test_mismatched_holder_type_2()
assert str(excinfo.value) == ("generic_type: type \"MismatchDerived2\" has a "
"non-default holder type while its base "
"\"MismatchBase2\" does not")