diff --git a/include/pybind11/detail/smart_holder_type_casters.h b/include/pybind11/detail/smart_holder_type_casters.h index b61d0a999..d59e679c6 100644 --- a/include/pybind11/detail/smart_holder_type_casters.h +++ b/include/pybind11/detail/smart_holder_type_casters.h @@ -270,6 +270,30 @@ struct smart_holder_type_caster_class_hooks : smart_holder_type_caster_base_tag return &modified_type_caster_generic_load_impl::local_load; } + using holder_type = pybindit::memory::smart_holder; + + template + static void from_raw_pointer_take_ownership_or_shared_from_this( + holder_type *uninitialized_location, WrappedType *value_ptr, ...) { + new (uninitialized_location) + holder_type(holder_type::from_raw_ptr_take_ownership(value_ptr)); + } + + template + static void from_raw_pointer_take_ownership_or_shared_from_this( + holder_type *uninitialized_location, + WrappedType *value_ptr, + const std::enable_shared_from_this *) { + auto shd_ptr + = std::dynamic_pointer_cast(detail::try_get_shared_from_this(value_ptr)); + if (shd_ptr) { + new (uninitialized_location) holder_type(holder_type::from_shared_ptr(shd_ptr)); + } else { + new (uninitialized_location) + holder_type(holder_type::from_shared_ptr(std::shared_ptr(value_ptr))); + } + } + template static void init_instance_for_type(detail::instance *inst, const void *holder_const_void_ptr) { // Need for const_cast is a consequence of the type_info::init_instance type: @@ -281,14 +305,15 @@ struct smart_holder_type_caster_class_hooks : smart_holder_type_caster_base_tag register_instance(inst, v_h.value_ptr(), v_h.type); v_h.set_instance_registered(); } - using holder_type = pybindit::memory::smart_holder; if (holder_void_ptr) { // Note: inst->owned ignored. auto holder_ptr = static_cast(holder_void_ptr); new (std::addressof(v_h.holder())) holder_type(std::move(*holder_ptr)); } else if (inst->owned) { - new (std::addressof(v_h.holder())) holder_type( - holder_type::from_raw_ptr_take_ownership(v_h.value_ptr())); + from_raw_pointer_take_ownership_or_shared_from_this( + std::addressof(v_h.holder()), + v_h.value_ptr(), + v_h.value_ptr()); } else { new (std::addressof(v_h.holder())) holder_type(holder_type::from_raw_ptr_unowned(v_h.value_ptr())); diff --git a/tests/test_class_sh_shared_from_this.cpp b/tests/test_class_sh_shared_from_this.cpp index 84cbe1215..b81d1fae5 100644 --- a/tests/test_class_sh_shared_from_this.cpp +++ b/tests/test_class_sh_shared_from_this.cpp @@ -45,8 +45,8 @@ PYBIND11_SMART_HOLDER_TYPE_CASTERS(SharedFromThisRef) PYBIND11_SMART_HOLDER_TYPE_CASTERS(SharedFromThisVirt) TEST_SUBMODULE(class_sh_shared_from_this, m) { - // py::classh(m, "MyObject3") - // .def(py::init()); + py::classh(m, "MyObject3") + .def(py::init()); m.def("make_myobject3_1", []() { return new MyObject3(8); }); m.def("make_myobject3_2", []() { return std::make_shared(9); }); m.def("print_myobject3_1", [](const MyObject3 *obj) { py::print(obj->toString()); }); @@ -55,7 +55,7 @@ TEST_SUBMODULE(class_sh_shared_from_this, m) { // m.def("print_myobject3_4", [](const std::shared_ptr *obj) { py::print((*obj)->toString()); }); using B = SharedFromThisRef::B; - // py::classh(m, "B"); + py::classh(m, "B"); py::classh(m, "SharedFromThisRef") .def(py::init<>()) .def_readonly("bad_wp", &SharedFromThisRef::value) @@ -69,6 +69,6 @@ TEST_SUBMODULE(class_sh_shared_from_this, m) { .def("set_holder", [](SharedFromThisRef &, std::shared_ptr) { return true; }); static std::shared_ptr sft(new SharedFromThisVirt()); - // py::classh(m, "SharedFromThisVirt") - // .def_static("get", []() { return sft.get(); }, py::return_value_policy::reference); + py::classh(m, "SharedFromThisVirt") + .def_static("get", []() { return sft.get(); }, py::return_value_policy::reference); } diff --git a/tests/test_class_sh_shared_from_this.py b/tests/test_class_sh_shared_from_this.py index 8fec1e480..28badf750 100644 --- a/tests/test_class_sh_shared_from_this.py +++ b/tests/test_class_sh_shared_from_this.py @@ -1,12 +1,11 @@ # -*- coding: utf-8 -*- -import pytest +# import pytest from pybind11_tests import class_sh_shared_from_this as m from pybind11_tests import ConstructorStats def test_smart_ptr(capture): - pytest.skip("WIP") # Object3 for i, o in zip( [9, 8, 9], [m.MyObject3(9), m.make_myobject3_1(), m.make_myobject3_2()] @@ -32,7 +31,6 @@ def test_smart_ptr(capture): def test_shared_from_this_ref(): - pytest.skip("WIP") s = m.SharedFromThisRef() stats = ConstructorStats.get(m.B) assert stats.alive() == 2 @@ -48,7 +46,6 @@ def test_shared_from_this_ref(): def test_shared_from_this_bad_wp(): - pytest.skip("WIP") s = m.SharedFromThisRef() stats = ConstructorStats.get(m.B) assert stats.alive() == 2 @@ -57,7 +54,7 @@ def test_shared_from_this_bad_wp(): assert stats.alive() == 2 assert s.set_ref(bad_wp) # with pytest.raises(RuntimeError) as excinfo: - if 1: + if 1: # XXX XXX XXX assert s.set_holder(bad_wp) # assert "Unable to cast from non-held to held instance" in str(excinfo.value) del bad_wp, s @@ -65,7 +62,6 @@ def test_shared_from_this_bad_wp(): def test_shared_from_this_copy(): - pytest.skip("WIP") s = m.SharedFromThisRef() stats = ConstructorStats.get(m.B) assert stats.alive() == 2 @@ -80,7 +76,6 @@ def test_shared_from_this_copy(): def test_shared_from_this_holder_ref(): - pytest.skip("WIP") s = m.SharedFromThisRef() stats = ConstructorStats.get(m.B) assert stats.alive() == 2 @@ -96,7 +91,6 @@ def test_shared_from_this_holder_ref(): def test_shared_from_this_holder_copy(): - pytest.skip("WIP") s = m.SharedFromThisRef() stats = ConstructorStats.get(m.B) assert stats.alive() == 2