diff --git a/tests/test_classh_wip.cpp b/tests/test_classh_wip.cpp index f9dad5572..58d991bd8 100644 --- a/tests/test_classh_wip.cpp +++ b/tests/test_classh_wip.cpp @@ -42,6 +42,8 @@ std::string pass_mpty_uqcp(std::unique_ptr obj) { return "pass_uqcp: // clang-format on +std::string get_mtxt(mpty const &obj) { return obj.mtxt; } + } // namespace classh_wip } // namespace pybind11_tests @@ -84,8 +86,15 @@ struct type_caster : smart_holder_type_caster_load { return str("cast_mref").release(); } - static handle cast(mpty const * /*src*/, return_value_policy /*policy*/, handle /*parent*/) { - return str("cast_cptr").release(); + static handle cast(mpty const *src, return_value_policy policy, handle parent) { + // type_caster_base BEGIN + // clang-format off + auto st = src_and_type(src); + return type_caster_generic::cast( + st.first, policy, parent, st.second, + make_copy_constructor(src), make_move_constructor(src)); + // clang-format on + // type_caster_base END } static handle cast(mpty * /*src*/, return_value_policy /*policy*/, handle /*parent*/) { @@ -116,6 +125,59 @@ struct type_caster : smart_holder_type_caster_load { operator mpty*() { return smhldr_ptr->as_raw_ptr_unowned(); } // clang-format on + + using itype = mpty; + + // type_caster_base BEGIN + // clang-format off + + // Returns a (pointer, type_info) pair taking care of necessary type lookup for a + // polymorphic type (using RTTI by default, but can be overridden by specializing + // polymorphic_type_hook). If the instance isn't derived, returns the base version. + static std::pair src_and_type(const itype *src) { + auto &cast_type = typeid(itype); + const std::type_info *instance_type = nullptr; + const void *vsrc = polymorphic_type_hook::get(src, instance_type); + if (instance_type && !same_type(cast_type, *instance_type)) { + // This is a base pointer to a derived type. If the derived type is registered + // with pybind11, we want to make the full derived object available. + // In the typical case where itype is polymorphic, we get the correct + // derived pointer (which may be != base pointer) by a dynamic_cast to + // most derived type. If itype is not polymorphic, we won't get here + // except via a user-provided specialization of polymorphic_type_hook, + // and the user has promised that no this-pointer adjustment is + // required in that case, so it's OK to use static_cast. + if (const auto *tpi = get_type_info(*instance_type)) + return {vsrc, tpi}; + } + // Otherwise we have either a nullptr, an `itype` pointer, or an unknown derived pointer, so + // don't do a cast + return type_caster_generic::src_and_type(src, cast_type, instance_type); + } + + using Constructor = void *(*)(const void *); + + /* Only enabled when the types are {copy,move}-constructible *and* when the type + does not have a private operator new implementation. */ + template ::value>> + static auto make_copy_constructor(const T *x) -> decltype(new T(*x), Constructor{}) { + return [](const void *arg) -> void * { + return new T(*reinterpret_cast(arg)); + }; + } + + template ::value>> + static auto make_move_constructor(const T *x) -> decltype(new T(std::move(*const_cast(x))), Constructor{}) { + return [](const void *arg) -> void * { + return new T(std::move(*const_cast(reinterpret_cast(arg)))); + }; + } + + static Constructor make_copy_constructor(...) { return nullptr; } + static Constructor make_move_constructor(...) { return nullptr; } + + // clang-format on + // type_caster_base END }; template <> @@ -221,6 +283,8 @@ TEST_SUBMODULE(classh_wip, m) { m.def("pass_mpty_uqmp", pass_mpty_uqmp); m.def("pass_mpty_uqcp", pass_mpty_uqcp); + + m.def("get_mtxt", get_mtxt); // Requires pass_mpty_cref to work properly. } } // namespace classh_wip diff --git a/tests/test_classh_wip.py b/tests/test_classh_wip.py index 4ecdee8b6..719b303a8 100644 --- a/tests/test_classh_wip.py +++ b/tests/test_classh_wip.py @@ -18,7 +18,7 @@ def test_cast(): assert m.rtrn_mpty_rref() == "cast_rref" assert m.rtrn_mpty_cref() == "cast_cref" assert m.rtrn_mpty_mref() == "cast_mref" - assert m.rtrn_mpty_cptr() == "cast_cptr" + assert m.get_mtxt(m.rtrn_mpty_cptr()) == "rtrn_cptr" assert m.rtrn_mpty_mptr() == "cast_mptr"