diff --git a/include/pybind11/detail/smart_holder_type_casters.h b/include/pybind11/detail/smart_holder_type_casters.h index c263013ea..f048aaf29 100644 --- a/include/pybind11/detail/smart_holder_type_casters.h +++ b/include/pybind11/detail/smart_holder_type_casters.h @@ -716,6 +716,15 @@ struct smart_holder_type_caster> : smart_holder_type_caste return inst.release(); } + static handle cast(const std::unique_ptr &src, return_value_policy policy, handle parent) { + if (!src) + return none().release(); + if (policy == return_value_policy::automatic) + policy = return_value_policy::reference_internal; + if (policy != return_value_policy::reference_internal) + throw cast_error("Invalid return_value_policy for unique_ptr&"); + return smart_holder_type_caster::cast(src.get(), policy, parent); + } template using cast_op_type = std::unique_ptr; diff --git a/tests/test_class_sh_basic.cpp b/tests/test_class_sh_basic.cpp index 9ec947baf..9bfbe081a 100644 --- a/tests/test_class_sh_basic.cpp +++ b/tests/test_class_sh_basic.cpp @@ -17,6 +17,17 @@ struct atyp { // Short for "any type". atyp(atyp &&other) { mtxt = other.mtxt + "_MvCtor"; } }; +struct uconsumer { // unique_ptr consumer + std::unique_ptr held; + bool valid() const { return static_cast(held); } + + void pass_valu(std::unique_ptr obj) { held = std::move(obj); } + void pass_rref(std::unique_ptr &&obj) { held = std::move(obj); } + std::unique_ptr rtrn_valu() { return std::move(held); } + std::unique_ptr& rtrn_lref() { return held; } + const std::unique_ptr &rtrn_cref() { return held; } +}; + // clang-format off atyp rtrn_valu() { atyp obj{"rtrn_valu"}; return obj; } @@ -57,7 +68,11 @@ std::string pass_udcp(std::unique_ptr obj) { return "pass_udcp // Helpers for testing. std::string get_mtxt(atyp const &obj) { return obj.mtxt; } +std::ptrdiff_t get_ptr(atyp const &obj) { return reinterpret_cast(&obj); } + std::unique_ptr unique_ptr_roundtrip(std::unique_ptr obj) { return obj; } +const std::unique_ptr& unique_ptr_cref_roundtrip(const std::unique_ptr& obj) { return obj; } + struct SharedPtrStash { std::vector> stash; void Add(std::shared_ptr obj) { stash.push_back(obj); } @@ -67,6 +82,7 @@ struct SharedPtrStash { } // namespace pybind11_tests PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_basic::atyp) +PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_basic::uconsumer) PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_basic::SharedPtrStash) namespace pybind11_tests { @@ -112,10 +128,23 @@ TEST_SUBMODULE(class_sh_basic, m) { m.def("pass_udmp", pass_udmp); m.def("pass_udcp", pass_udcp); + py::classh(m, "uconsumer") + .def(py::init<>()) + .def("valid", &uconsumer::valid) + .def("pass_valu", &uconsumer::pass_valu) + .def("pass_rref", &uconsumer::pass_rref) + .def("rtrn_valu", &uconsumer::rtrn_valu) + .def("rtrn_lref", &uconsumer::rtrn_lref) + .def("rtrn_cref", &uconsumer::rtrn_cref); + // Helpers for testing. // These require selected functions above to work first, as indicated: m.def("get_mtxt", get_mtxt); // pass_cref + m.def("get_ptr", get_ptr); // pass_cref + m.def("unique_ptr_roundtrip", unique_ptr_roundtrip); // pass_uqmp, rtrn_uqmp + m.def("unique_ptr_cref_roundtrip", unique_ptr_cref_roundtrip); + py::classh(m, "SharedPtrStash") .def(py::init<>()) .def("Add", &SharedPtrStash::Add, py::arg("obj")); diff --git a/tests/test_class_sh_basic.py b/tests/test_class_sh_basic.py index 9e2813a96..3727dcdb5 100644 --- a/tests/test_class_sh_basic.py +++ b/tests/test_class_sh_basic.py @@ -118,6 +118,48 @@ def test_unique_ptr_roundtrip(num_round_trips=1000): id_orig = id_rtrn +# This currently fails, because a unique_ptr is always loaded by value +# due to pybind11/detail/smart_holder_type_casters.h:689 +# I think, we need to provide more cast operators. +@pytest.mark.skip +def test_unique_ptr_cref_roundtrip(num_round_trips=1000): + orig = m.atyp("passenger") + id_orig = id(orig) + mtxt_orig = m.get_mtxt(orig) + + recycled = m.unique_ptr_cref_roundtrip(orig) + assert m.get_mtxt(orig) == mtxt_orig + assert m.get_mtxt(recycled) == mtxt_orig + assert id(recycled) == id_orig + + +@pytest.mark.parametrize( + "pass_f, rtrn_f, moved_out, moved_in", + [ + (m.uconsumer.pass_valu, m.uconsumer.rtrn_valu, True, True), + (m.uconsumer.pass_rref, m.uconsumer.rtrn_valu, True, True), + (m.uconsumer.pass_valu, m.uconsumer.rtrn_lref, True, False), + (m.uconsumer.pass_valu, m.uconsumer.rtrn_cref, True, False), + ], +) +def test_unique_ptr_consumer_roundtrip(pass_f, rtrn_f, moved_out, moved_in): + c = m.uconsumer() + assert not c.valid() + recycled = m.atyp("passenger") + mtxt_orig = m.get_mtxt(recycled) + assert re.match("passenger_(MvCtor){1,2}", mtxt_orig) + + pass_f(c, recycled) + if moved_out: + with pytest.raises(ValueError) as excinfo: + m.get_mtxt(recycled) + assert "Python instance was disowned" in str(excinfo.value) + + recycled = rtrn_f(c) + assert c.valid() != moved_in + assert m.get_mtxt(recycled) == mtxt_orig + + def test_py_type_handle_of_atyp(): obj = m.py_type_handle_of_atyp() assert obj.__class__.__name__ == "pybind11_type"