diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 0cba8a128..8ecc20f25 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -1067,8 +1067,14 @@ public: + clean_type_id(typeinfo->cpptype->name()) + ")"); } - template - using cast_op_type = std::unique_ptr; + template + using cast_op_type + = conditional_t::type, + const std::unique_ptr &>::value + || std::is_same::type, + const std::unique_ptr &>::value, + const std::unique_ptr &, + std::unique_ptr>; explicit operator std::unique_ptr() { if (typeinfo->holder_enum_v == detail::holder_enum_t::smart_holder) { @@ -1077,6 +1083,28 @@ public: pybind11_fail("Expected to be UNREACHABLE: " __FILE__ ":" PYBIND11_TOSTRING(__LINE__)); } + explicit operator const std::unique_ptr &() { + if (typeinfo->holder_enum_v == detail::holder_enum_t::smart_holder) { + // Get shared_ptr to ensure that the Python object is not disowned elsewhere. + shared_ptr_storage = sh_load_helper.load_as_shared_ptr(value); + // Build a temporary unique_ptr that is meant to never expire. + unique_ptr_storage = std::shared_ptr>( + new std::unique_ptr{ + sh_load_helper.template load_as_const_unique_ptr( + shared_ptr_storage.get())}, + [](std::unique_ptr *ptr) { + if (!ptr) { + pybind11_fail("FATAL: `const std::unique_ptr &` was disowned " + "(EXPECT UNDEFINED BEHAVIOR)."); + } + (void) ptr->release(); + delete ptr; + }); + return *unique_ptr_storage; + } + pybind11_fail("Expected to be UNREACHABLE: " __FILE__ ":" PYBIND11_TOSTRING(__LINE__)); + } + bool try_implicit_casts(handle src, bool convert) { for (auto &cast : typeinfo->implicit_casts) { move_only_holder_caster sub_caster(*cast.first); @@ -1097,6 +1125,8 @@ public: static bool try_direct_conversions(handle) { return false; } smart_holder_type_caster_support::load_helper> sh_load_helper; // Const2Mutbl + std::shared_ptr shared_ptr_storage; // Serves as a pseudo lock. + std::shared_ptr> unique_ptr_storage; }; #endif // PYBIND11_HAS_INTERNALS_WITH_SMART_HOLDER_SUPPORT diff --git a/include/pybind11/detail/struct_smart_holder.h b/include/pybind11/detail/struct_smart_holder.h index 401a6fd1b..b1e24d7bb 100644 --- a/include/pybind11/detail/struct_smart_holder.h +++ b/include/pybind11/detail/struct_smart_holder.h @@ -234,7 +234,7 @@ struct smart_holder { // Caller is responsible for precondition: ensure_compatible_rtti_uqp_del() must succeed. template std::unique_ptr extract_deleter(const char *context) const { - auto *gd = std::get_deleter(vptr); + const auto *gd = std::get_deleter(vptr); if (gd && gd->use_del_fun) { const auto &custom_deleter_ptr = gd->del_fun.template target>(); if (custom_deleter_ptr == nullptr) { @@ -242,7 +242,9 @@ struct smart_holder { std::string("smart_holder::extract_deleter() precondition failure (") + context + ")."); } - return std::unique_ptr(new D(std::move(custom_deleter_ptr->deleter))); + static_assert(std::is_copy_constructible::value, + "Required for compatibility with smart_holder functionality."); + return std::unique_ptr(new D(custom_deleter_ptr->deleter)); } return nullptr; } diff --git a/include/pybind11/detail/type_caster_base.h b/include/pybind11/detail/type_caster_base.h index 8fb0b9e3c..f1a5d0d8d 100644 --- a/include/pybind11/detail/type_caster_base.h +++ b/include/pybind11/detail/type_caster_base.h @@ -814,6 +814,19 @@ struct load_helper : value_and_holder_helper { return result; } + + // This assumes load_as_shared_ptr succeeded(), and the returned shared_ptr is still alive. + // The returned unique_ptr is meant to never expire (the behavior is undefined otherwise). + template + std::unique_ptr + load_as_const_unique_ptr(T *raw_type_ptr, const char *context = "load_as_const_unique_ptr") { + if (!have_holder()) { + return unique_with_deleter(nullptr, std::unique_ptr()); + } + holder().template ensure_compatible_rtti_uqp_del(context); + return unique_with_deleter( + raw_type_ptr, std::move(holder().template extract_deleter(context))); + } }; PYBIND11_NAMESPACE_END(smart_holder_type_caster_support) diff --git a/tests/test_class_sh_basic.cpp b/tests/test_class_sh_basic.cpp index 9602387b3..ac19a750b 100644 --- a/tests/test_class_sh_basic.cpp +++ b/tests/test_class_sh_basic.cpp @@ -120,6 +120,17 @@ std::string get_mtxt(atyp const &obj) { return obj.mtxt; } std::ptrdiff_t get_ptr(atyp const &obj) { return reinterpret_cast(&obj); } std::unique_ptr unique_ptr_roundtrip(std::unique_ptr obj) { return obj; } + +std::string pass_unique_ptr_cref(const std::unique_ptr &obj) { return obj->mtxt; } + +const std::unique_ptr &rtrn_unique_ptr_cref(const std::string &mtxt) { + static std::unique_ptr obj{new atyp{"static_ctor_arg"}}; + if (!mtxt.empty()) { + obj->mtxt = mtxt; + } + return obj; +} + const std::unique_ptr &unique_ptr_cref_roundtrip(const std::unique_ptr &obj) { return obj; } @@ -217,6 +228,9 @@ TEST_SUBMODULE(class_sh_basic, m) { m.def("get_ptr", get_ptr); // pass_cref m.def("unique_ptr_roundtrip", unique_ptr_roundtrip); // pass_uqmp, rtrn_uqmp + + m.def("pass_unique_ptr_cref", pass_unique_ptr_cref); + m.def("rtrn_unique_ptr_cref", rtrn_unique_ptr_cref); m.def("unique_ptr_cref_roundtrip", unique_ptr_cref_roundtrip); py::classh(m, "SharedPtrStash") diff --git a/tests/test_class_sh_basic.py b/tests/test_class_sh_basic.py index 87f1f8f09..7db7d31b7 100644 --- a/tests/test_class_sh_basic.py +++ b/tests/test_class_sh_basic.py @@ -151,19 +151,31 @@ def test_unique_ptr_roundtrip(num_round_trips=1000): id_orig = id_rtrn -# This currently fails, because a unique_ptr is always loaded by value -# due to pybind11/detail/smart_holder_type_casters.h:689 -# I think, we need to provide more cast operators. -@pytest.mark.skip() -def test_unique_ptr_cref_roundtrip(): - orig = m.atyp("passenger") - id_orig = id(orig) - mtxt_orig = m.get_mtxt(orig) +def test_pass_unique_ptr_cref(): + obj = m.atyp("ctor_arg") + assert re.match("ctor_arg(_MvCtor)*_MvCtor", m.get_mtxt(obj)) + assert re.match("ctor_arg(_MvCtor)*_MvCtor", m.pass_unique_ptr_cref(obj)) + assert re.match("ctor_arg(_MvCtor)*_MvCtor", m.get_mtxt(obj)) - recycled = m.unique_ptr_cref_roundtrip(orig) - assert m.get_mtxt(orig) == mtxt_orig - assert m.get_mtxt(recycled) == mtxt_orig - assert id(recycled) == id_orig + +def test_rtrn_unique_ptr_cref(): + obj0 = m.rtrn_unique_ptr_cref("") + assert m.get_mtxt(obj0) == "static_ctor_arg" + obj1 = m.rtrn_unique_ptr_cref("passed_mtxt_1") + assert m.get_mtxt(obj1) == "passed_mtxt_1" + assert m.get_mtxt(obj0) == "passed_mtxt_1" + assert obj0 is obj1 + + +def test_unique_ptr_cref_roundtrip(num_round_trips=1000): + # Multiple roundtrips to stress-test implementation. + orig = m.atyp("passenger") + mtxt_orig = m.get_mtxt(orig) + recycled = orig + for _ in range(num_round_trips): + recycled = m.unique_ptr_cref_roundtrip(recycled) + assert recycled is orig + assert m.get_mtxt(recycled) == mtxt_orig @pytest.mark.parametrize( diff --git a/tests/test_class_sh_trampoline_shared_from_this.cpp b/tests/test_class_sh_trampoline_shared_from_this.cpp index f664cac51..9c2e4ec76 100644 --- a/tests/test_class_sh_trampoline_shared_from_this.cpp +++ b/tests/test_class_sh_trampoline_shared_from_this.cpp @@ -87,7 +87,10 @@ long pass_shared_ptr(const std::shared_ptr &obj) { return sft.use_count(); } -void pass_unique_ptr_cref(const std::unique_ptr &) { +std::string pass_unique_ptr_cref(const std::unique_ptr &obj) { + return obj ? obj->history : ""; +} +void pass_unique_ptr_rref(std::unique_ptr &&) { throw std::runtime_error("Expected to not be reached."); } @@ -138,6 +141,7 @@ TEST_SUBMODULE(class_sh_trampoline_shared_from_this, m) { m.def("use_count", use_count); m.def("pass_shared_ptr", pass_shared_ptr); m.def("pass_unique_ptr_cref", pass_unique_ptr_cref); + m.def("pass_unique_ptr_rref", pass_unique_ptr_rref); m.def("make_pure_cpp_sft_raw_ptr", make_pure_cpp_sft_raw_ptr); m.def("make_pure_cpp_sft_unq_ptr", make_pure_cpp_sft_unq_ptr); m.def("make_pure_cpp_sft_shd_ptr", make_pure_cpp_sft_shd_ptr); diff --git a/tests/test_class_sh_trampoline_shared_from_this.py b/tests/test_class_sh_trampoline_shared_from_this.py index b112bb5d9..7c5ee0e8e 100644 --- a/tests/test_class_sh_trampoline_shared_from_this.py +++ b/tests/test_class_sh_trampoline_shared_from_this.py @@ -137,8 +137,10 @@ def test_pass_released_shared_ptr_as_unique_ptr(): obj = PySft("PySft") stash1 = m.SftSharedPtrStash(1) stash1.Add(obj) # Releases shared_ptr to C++. + assert m.pass_unique_ptr_cref(obj) == "PySft_Stash1Add" + assert obj.history == "PySft_Stash1Add" with pytest.raises(ValueError) as exc_info: - m.pass_unique_ptr_cref(obj) + m.pass_unique_ptr_rref(obj) assert str(exc_info.value) == ( "Python instance is currently owned by a std::shared_ptr." )