From 6601ec4ea761e33babda78f5623807a43300d19b Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Wed, 13 Jan 2021 10:48:08 -0800 Subject: [PATCH] static handle cast implementations for rtrn_shmp, rtrn_shcp. --- include/pybind11/classh.h | 24 +++++------------ tests/test_classh_wip.cpp | 55 ++++++++++++++++++++++++++++++++------- tests/test_classh_wip.py | 4 +-- 3 files changed, 53 insertions(+), 30 deletions(-) diff --git a/include/pybind11/classh.h b/include/pybind11/classh.h index 61bc50d4b..77a9a9328 100644 --- a/include/pybind11/classh.h +++ b/include/pybind11/classh.h @@ -268,30 +268,18 @@ public: private: /// Initialize holder object, variant 1: object derives from enable_shared_from_this template - static void init_holder(detail::instance *inst, detail::value_and_holder &v_h, + static void init_holder(detail::value_and_holder &/*v_h*/, const holder_type * /* unused */, const std::enable_shared_from_this * /* dummy */) { - try { - auto sh = std::dynamic_pointer_cast( // Was: typename holder_type::element_type - v_h.value_ptr()->shared_from_this()); - if (sh) { - new (std::addressof(v_h.holder())) holder_type(std::move(sh)); - v_h.set_holder_constructed(); - } - } catch (const std::bad_weak_ptr &) {} - - if (!v_h.holder_constructed() && inst->owned) { - new (std::addressof(v_h.holder())) holder_type(v_h.value_ptr()); - v_h.set_holder_constructed(); - } + throw std::runtime_error("Not implemented: classh::init_holder enable_shared_from_this."); } /// Initialize holder object, variant 2: try to construct from existing holder object, if possible - static void init_holder(detail::instance *inst, detail::value_and_holder &v_h, + static void init_holder(detail::value_and_holder &v_h, const holder_type *holder_ptr, const void * /* dummy -- not enable_shared_from_this) */) { if (holder_ptr) { - new (std::addressof(v_h.holder())) holder_type(*reinterpret_cast(holder_ptr)); + new (std::addressof(v_h.holder())) holder_type(std::move(*holder_ptr)); v_h.set_holder_constructed(); - } else if (inst->owned || detail::always_construct_holder::value) { + } else { // Was: if (inst->owned || detail::always_construct_holder::value) new (std::addressof(v_h.holder())) holder_type( std::move(holder_type::from_raw_ptr_take_ownership(v_h.value_ptr()))); v_h.set_holder_constructed(); @@ -308,7 +296,7 @@ private: register_instance(inst, v_h.value_ptr(), v_h.type); v_h.set_instance_registered(); } - init_holder(inst, v_h, (const holder_type *) holder_ptr, v_h.value_ptr()); + init_holder(v_h, static_cast(holder_ptr), v_h.value_ptr()); } /// Deallocates an instance; via holder, if constructed; otherwise via operator delete. diff --git a/tests/test_classh_wip.cpp b/tests/test_classh_wip.cpp index 5c9fcdb9d..66d0ab3a4 100644 --- a/tests/test_classh_wip.cpp +++ b/tests/test_classh_wip.cpp @@ -93,7 +93,7 @@ struct type_caster : smart_holder_type_caster_load { } static handle cast(mpty &src, return_value_policy policy, handle parent) { - return cast(const_cast(src), policy, parent); // Mtbl2Const + return cast(const_cast(src), policy, parent); // Mutbl2Const } static handle cast(mpty const *src, return_value_policy policy, handle parent) { @@ -108,7 +108,7 @@ struct type_caster : smart_holder_type_caster_load { } static handle cast(mpty *src, return_value_policy policy, handle parent) { - return cast(const_cast(src), policy, parent); // Mtbl2Const + return cast(const_cast(src), policy, parent); // Mutbl2Const } template @@ -287,10 +287,43 @@ template <> struct type_caster> : smart_holder_type_caster_load { static constexpr auto name = _>(); - static handle cast(const std::shared_ptr & /*src*/, - return_value_policy /*policy*/, - handle /*parent*/) { - return str("cast_shmp").release(); + static handle + cast(const std::shared_ptr &src, return_value_policy policy, handle parent) { + if (policy != return_value_policy::automatic + && policy != return_value_policy::reference_internal) { + // IMPROVEABLE: Error message. + throw cast_error("Invalid return_value_policy for shared_ptr."); + } + + auto src_raw_ptr = src.get(); + auto st = type_caster::src_and_type(src_raw_ptr); + if (st.first == nullptr) + return none().release(); // PyErr was set already. + + void *src_raw_void_ptr = static_cast(src_raw_ptr); + const detail::type_info *tinfo = st.second; + auto it_instances = get_internals().registered_instances.equal_range(src_raw_void_ptr); + // Loop copied from type_caster_generic::cast. + for (auto it_i = it_instances.first; it_i != it_instances.second; ++it_i) { + for (auto instance_type : detail::all_type_info(Py_TYPE(it_i->second))) { + if (instance_type && same_type(*instance_type->cpptype, *tinfo->cpptype)) + // MISSING: Enforcement of consistency with existing smart_holder. + // MISSING: keep_alive. + return handle((PyObject *) it_i->second).inc_ref(); + } + } + + object inst = reinterpret_steal(make_new_instance(tinfo->type)); + instance *inst_raw_ptr = reinterpret_cast(inst.ptr()); + inst_raw_ptr->owned = false; // Not actually used. + + auto smhldr = pybindit::memory::smart_holder::from_shared_ptr(src); + tinfo->init_instance(inst_raw_ptr, static_cast(&smhldr)); + + if (policy == return_value_policy::reference_internal) + keep_alive_impl(inst, parent); + + return inst.release(); } template @@ -303,10 +336,12 @@ template <> struct type_caster> : smart_holder_type_caster_load { static constexpr auto name = _>(); - static handle cast(const std::shared_ptr & /*src*/, - return_value_policy /*policy*/, - handle /*parent*/) { - return str("cast_shcp").release(); + static handle + cast(const std::shared_ptr &src, return_value_policy policy, handle parent) { + return type_caster>::cast( + std::const_pointer_cast(src), // Const2Mutbl + policy, + parent); } template diff --git a/tests/test_classh_wip.py b/tests/test_classh_wip.py index 93aaae01d..148270d12 100644 --- a/tests/test_classh_wip.py +++ b/tests/test_classh_wip.py @@ -39,8 +39,8 @@ def test_load(): def test_cast_shared_ptr(): - assert m.rtrn_mpty_shmp() == "cast_shmp" - assert m.rtrn_mpty_shcp() == "cast_shcp" + assert m.get_mtxt(m.rtrn_mpty_shmp()) == "rtrn_shmp" + assert m.get_mtxt(m.rtrn_mpty_shcp()) == "rtrn_shcp" def test_load_shared_ptr():