diff --git a/include/pybind11/detail/classh_type_casters.h b/include/pybind11/detail/classh_type_casters.h index 267d46c01..033a5ead1 100644 --- a/include/pybind11/detail/classh_type_casters.h +++ b/include/pybind11/detail/classh_type_casters.h @@ -231,7 +231,8 @@ struct smart_holder_type_caster_load { return true; } - T *as_raw_ptr_unowned() { + T *loaded_as_raw_ptr_unowned() { + // BYPASSES smart_holder type checking completely. if (load_impl.loaded_v_h_cpptype != nullptr) { if (load_impl.reinterpret_cast_deemed_ok) { return static_cast(loaded_smhldr_ptr->vptr.get()); @@ -241,7 +242,13 @@ struct smart_holder_type_caster_load { return static_cast(implicit_casted); } } - return loaded_smhldr_ptr->as_raw_ptr_unowned(); + return static_cast(loaded_smhldr_ptr->vptr.get()); + } + + std::shared_ptr loaded_as_shared_ptr() { + T *raw_ptr = loaded_as_raw_ptr_unowned(); + // BYPASSES smart_holder shared_ptr tracking completely. + return std::shared_ptr(loaded_smhldr_ptr->vptr, raw_ptr); } std::unique_ptr loaded_as_unique_ptr() { @@ -350,8 +357,8 @@ struct classh_type_caster : smart_holder_type_caster_load { operator T&&() && { return this->loaded_smhldr_ptr->template rvalue_ref(); } operator T const&() { return this->loaded_smhldr_ptr->template lvalue_ref(); } operator T&() { return this->loaded_smhldr_ptr->template lvalue_ref(); } - operator T const*() { return this->as_raw_ptr_unowned(); } - operator T*() { return this->as_raw_ptr_unowned(); } + operator T const*() { return this->loaded_as_raw_ptr_unowned(); } + operator T*() { return this->loaded_as_raw_ptr_unowned(); } // clang-format on @@ -487,7 +494,7 @@ struct classh_type_caster> : smart_holder_type_caster_load template using cast_op_type = std::shared_ptr; - operator std::shared_ptr() { return this->loaded_smhldr_ptr->template as_shared_ptr(); } + operator std::shared_ptr() { return this->loaded_as_shared_ptr(); } }; template @@ -505,9 +512,7 @@ struct classh_type_caster> : smart_holder_type_caster_l template using cast_op_type = std::shared_ptr; - operator std::shared_ptr() { - return this->loaded_smhldr_ptr->template as_shared_ptr(); - } + operator std::shared_ptr() { return this->loaded_as_shared_ptr(); } // Mutbl2Const }; template diff --git a/tests/test_classh_inheritance.cpp b/tests/test_classh_inheritance.cpp index ed944e521..2694d051f 100644 --- a/tests/test_classh_inheritance.cpp +++ b/tests/test_classh_inheritance.cpp @@ -2,6 +2,8 @@ #include +#include + namespace pybind11_tests { namespace classh_inheritance { @@ -25,6 +27,12 @@ inline base *rtrn_mptr_drvd_up_cast() { return new drvd; } inline int pass_cptr_base(base const *b) { return b->id() + 11; } inline int pass_cptr_drvd(drvd const *d) { return d->id() + 12; } + +inline std::shared_ptr rtrn_shmp_drvd() { return std::shared_ptr(new drvd); } +inline std::shared_ptr rtrn_shmp_drvd_up_cast() { return std::shared_ptr(new drvd); } + +inline int pass_shcp_base(std::shared_ptr b) { return b->id() + 21; } +inline int pass_shcp_drvd(std::shared_ptr d) { return d->id() + 22; } // clang-format on using base1 = base_template<110>; @@ -69,6 +77,11 @@ TEST_SUBMODULE(classh_inheritance, m) { m.def("pass_cptr_base", pass_cptr_base); m.def("pass_cptr_drvd", pass_cptr_drvd); + m.def("rtrn_shmp_drvd", rtrn_shmp_drvd); + m.def("rtrn_shmp_drvd_up_cast", rtrn_shmp_drvd_up_cast); + m.def("pass_shcp_base", pass_shcp_base); + m.def("pass_shcp_drvd", pass_shcp_drvd); + py::classh(m, "base1").def(py::init<>()); // __init__ needed for Python inheritance. py::classh(m, "base2").def(py::init<>()); py::classh(m, "drvd2"); diff --git a/tests/test_classh_inheritance.py b/tests/test_classh_inheritance.py index 8421bc304..cf72c5bc4 100644 --- a/tests/test_classh_inheritance.py +++ b/tests/test_classh_inheritance.py @@ -9,6 +9,12 @@ def test_rtrn_mptr_drvd_pass_cptr_base(): assert i == 2 * 100 + 11 +def test_rtrn_shmp_drvd_pass_shcp_base(): + d = m.rtrn_shmp_drvd() + i = m.pass_shcp_base(d) # load_impl Case 2a + assert i == 2 * 100 + 21 + + def test_rtrn_mptr_drvd_up_cast_pass_cptr_drvd(): b = m.rtrn_mptr_drvd_up_cast() # the base return is down-cast immediately. @@ -17,6 +23,14 @@ def test_rtrn_mptr_drvd_up_cast_pass_cptr_drvd(): assert i == 2 * 100 + 12 +def test_rtrn_shmp_drvd_up_cast_pass_shcp_drvd(): + b = m.rtrn_shmp_drvd_up_cast() + # the base return is down-cast immediately. + assert b.__class__.__name__ == "drvd" + i = m.pass_shcp_drvd(b) + assert i == 2 * 100 + 22 + + def test_rtrn_mptr_drvd2_pass_cptr_bases(): d = m.rtrn_mptr_drvd2() i1 = m.pass_cptr_base1(d) # load_impl Case 2c