From 7a6d30ca58154ef8c571cdec77d7303d5a523220 Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Tue, 2 Jul 2024 18:04:57 -0700 Subject: [PATCH] Fix `rtrn_shmp`, `rtrn_shmp` by transferring `smart_holder_from_shared_ptr()` functionality from smart_holder branch. --- include/pybind11/cast.h | 6 +- .../detail/smart_holder_type_caster_support.h | 65 +++++++++++++++++++ tests/test_class_sh_basic.py | 4 +- 3 files changed, 70 insertions(+), 5 deletions(-) diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 296056c89..aca6ce8b1 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -860,9 +860,9 @@ public: explicit operator std::shared_ptr *() { return std::addressof(holder); } explicit operator std::shared_ptr &() { return holder; } - static handle cast(const std::shared_ptr &src, return_value_policy, handle) { - const auto *ptr = holder_helper>::get(src); - return type_caster_base::cast_holder(ptr, &src); + static handle + cast(const std::shared_ptr &src, return_value_policy policy, handle parent) { + return smart_holder_type_caster_support::shared_ptr_to_python(src, policy, parent); } protected: diff --git a/include/pybind11/detail/smart_holder_type_caster_support.h b/include/pybind11/detail/smart_holder_type_caster_support.h index 4381a32e3..4bb16a7c0 100644 --- a/include/pybind11/detail/smart_holder_type_caster_support.h +++ b/include/pybind11/detail/smart_holder_type_caster_support.h @@ -101,6 +101,71 @@ unique_ptr_to_python(std::unique_ptr &&unq_ptr, return_value_policy policy nullptr, std::addressof(unq_ptr)); } +template +handle smart_holder_from_shared_ptr(const std::shared_ptr &src, + return_value_policy policy, + handle parent, + const std::pair &st) { + switch (policy) { + case return_value_policy::automatic: + case return_value_policy::automatic_reference: + break; + case return_value_policy::take_ownership: + throw cast_error("Invalid return_value_policy for shared_ptr (take_ownership)."); + case return_value_policy::copy: + case return_value_policy::move: + break; + case return_value_policy::reference: + throw cast_error("Invalid return_value_policy for shared_ptr (reference)."); + case return_value_policy::reference_internal: + break; + } + if (!src) { + return none().release(); + } + + auto src_raw_ptr = src.get(); + assert(st.second != nullptr); + // BAKEIN_WIP: Better Const2Mutbl + void *src_raw_void_ptr = const_cast(static_cast(src_raw_ptr)); + const detail::type_info *tinfo = st.second; + if (handle existing_inst = find_registered_python_instance(src_raw_void_ptr, tinfo)) { + // SMART_HOLDER_WIP: MISSING: Enforcement of consistency with existing smart_holder. + // SMART_HOLDER_WIP: MISSING: keep_alive. + return existing_inst; + } + + auto inst = reinterpret_steal(make_new_instance(tinfo->type)); + auto *inst_raw_ptr = reinterpret_cast(inst.ptr()); + inst_raw_ptr->owned = true; + void *&valueptr = values_and_holders(inst_raw_ptr).begin()->value_ptr(); + valueptr = src_raw_void_ptr; + + auto smhldr = pybindit::memory::smart_holder::from_shared_ptr( + std::shared_ptr(src, const_cast(st.first))); + tinfo->init_instance(inst_raw_ptr, static_cast(&smhldr)); + + if (policy == return_value_policy::reference_internal) { + keep_alive_impl(inst, parent); + } + + return inst.release(); +} + +template +handle shared_ptr_to_python(const std::shared_ptr &shd_ptr, + return_value_policy policy, + handle parent) { + const auto *ptr = shd_ptr.get(); + auto st = type_caster_base::src_and_type(ptr); + if (st.second == nullptr) { + return handle(); // no type info: error will be set already + } + if (st.second->default_holder) { + return smart_holder_from_shared_ptr(shd_ptr, policy, parent, st); + } + return type_caster_base::cast_holder(ptr, &shd_ptr); +} PYBIND11_NAMESPACE_END(smart_holder_type_caster_support) PYBIND11_NAMESPACE_END(detail) diff --git a/tests/test_class_sh_basic.py b/tests/test_class_sh_basic.py index 2610a8d94..1ad958952 100644 --- a/tests/test_class_sh_basic.py +++ b/tests/test_class_sh_basic.py @@ -26,8 +26,8 @@ def test_atyp_constructors(): (m.rtrn_mref, "rtrn_mref(_MvCtor)*_CpCtor"), (m.rtrn_cptr, "rtrn_cptr"), (m.rtrn_mptr, "rtrn_mptr"), - # BAKEIN_BREAK (m.rtrn_shmp, "rtrn_shmp"), - # BAKEIN_BREAK (m.rtrn_shcp, "rtrn_shcp"), + (m.rtrn_shmp, "rtrn_shmp"), + (m.rtrn_shcp, "rtrn_shcp"), (m.rtrn_uqmp, "rtrn_uqmp"), # BAKEIN_BREAK (m.rtrn_uqcp, "rtrn_uqcp"), (m.rtrn_udmp, "rtrn_udmp"),