Throw exception on returning a unique_ptr or shared_ptr nullptr (or any other holder type) from py::init, rather than crashing (#2430)

This commit is contained in:
Yannick Jadoul 2020-08-25 18:51:06 +02:00 committed by GitHub
parent 5b59b7b263
commit a2bb297b32
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 3 deletions

View File

@ -132,6 +132,7 @@ void construct(value_and_holder &v_h, Alias<Class> *alias_ptr, bool) {
template <typename Class>
void construct(value_and_holder &v_h, Holder<Class> holder, bool need_alias) {
auto *ptr = holder_helper<Holder<Class>>::get(holder);
no_nullptr(ptr);
// If we need an alias, check that the held pointer is actually an alias instance
if (Class::has_alias && need_alias && !is_alias<Class>(ptr))
throw type_error("pybind11::init(): construction failed: returned holder-wrapped instance "

View File

@ -154,6 +154,8 @@ TEST_SUBMODULE(factory_constructors, m) {
MAKE_TAG_TYPE(TF4);
MAKE_TAG_TYPE(TF5);
MAKE_TAG_TYPE(null_ptr);
MAKE_TAG_TYPE(null_unique_ptr);
MAKE_TAG_TYPE(null_shared_ptr);
MAKE_TAG_TYPE(base);
MAKE_TAG_TYPE(invalid_base);
MAKE_TAG_TYPE(alias);
@ -194,6 +196,8 @@ TEST_SUBMODULE(factory_constructors, m) {
// Returns nullptr:
.def(py::init([](null_ptr_tag) { return (TestFactory3 *) nullptr; }))
.def(py::init([](null_unique_ptr_tag) { return std::unique_ptr<TestFactory3>(); }))
.def(py::init([](null_shared_ptr_tag) { return std::shared_ptr<TestFactory3>(); }))
.def_readwrite("value", &TestFactory3::value)
;

View File

@ -41,9 +41,12 @@ def test_init_factory_basic():
z3 = m.TestFactory3("bye")
assert z3.value == "bye"
with pytest.raises(TypeError) as excinfo:
m.TestFactory3(tag.null_ptr)
assert str(excinfo.value) == "pybind11::init(): factory function returned nullptr"
for null_ptr_kind in [tag.null_ptr,
tag.null_unique_ptr,
tag.null_shared_ptr]:
with pytest.raises(TypeError) as excinfo:
m.TestFactory3(null_ptr_kind)
assert str(excinfo.value) == "pybind11::init(): factory function returned nullptr"
assert [i.alive() for i in cstats] == [3, 3, 3]
assert ConstructorStats.detail_reg_inst() == n_inst + 9