diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index f18217210..d2f577ad2 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -1376,13 +1376,24 @@ enable_if_t::value, T> cast_ref(object &&, // static_assert, even though if it's in dead code, so we provide a "trampoline" to pybind11::cast // that only does anything in cases where pybind11::cast is valid. template -enable_if_t::value, T> cast_safe(object &&) { +enable_if_t::value + && !detail::is_same_ignoring_cvref::value, + T> +cast_safe(object &&) { pybind11_fail("Internal error: cast_safe fallback invoked"); } template enable_if_t::value, void> cast_safe(object &&) {} template -enable_if_t, std::is_void>::value, T> +enable_if_t::value, PyObject *> +cast_safe(object &&o) { + return o.release().ptr(); +} +template +enable_if_t, + detail::is_same_ignoring_cvref, + std::is_void>::value, + T> cast_safe(object &&o) { return pybind11::cast(std::move(o)); } diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 99606f8d1..11c69d547 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -3128,10 +3128,14 @@ function get_override(const T *this_ptr, const char *name) { = pybind11::get_override(static_cast(this), name); \ if (override) { \ auto o = override(__VA_ARGS__); \ - if (pybind11::detail::cast_is_temporary_value_reference::value) { \ + PYBIND11_WARNING_PUSH \ + PYBIND11_WARNING_DISABLE_MSVC(4127) \ + if (pybind11::detail::cast_is_temporary_value_reference::value \ + && !pybind11::detail::is_same_ignoring_cvref::value) { \ static pybind11::detail::override_caster_t caster; \ return pybind11::detail::cast_ref(std::move(o), caster); \ } \ + PYBIND11_WARNING_POP \ return pybind11::detail::cast_safe(std::move(o)); \ } \ } while (false) diff --git a/tests/test_type_caster_pyobject_ptr.cpp b/tests/test_type_caster_pyobject_ptr.cpp index 8069f7dcd..a45c08b64 100644 --- a/tests/test_type_caster_pyobject_ptr.cpp +++ b/tests/test_type_caster_pyobject_ptr.cpp @@ -5,9 +5,10 @@ #include "pybind11_tests.h" #include +#include #include -namespace { +namespace test_type_caster_pyobject_ptr { std::vector make_vector_pyobject_ptr(const py::object &ValueHolder) { std::vector vec_obj; @@ -18,9 +19,39 @@ std::vector make_vector_pyobject_ptr(const py::object &ValueHolder) return vec_obj; } -} // namespace +struct WithPyObjectPtrReturn { +#if defined(__clang_major__) && __clang_major__ < 4 + WithPyObjectPtrReturn() = default; + WithPyObjectPtrReturn(const WithPyObjectPtrReturn &) = default; +#endif + virtual ~WithPyObjectPtrReturn() = default; + virtual PyObject *return_pyobject_ptr() const = 0; +}; + +struct WithPyObjectPtrReturnTrampoline : WithPyObjectPtrReturn { + PyObject *return_pyobject_ptr() const override { + PYBIND11_OVERRIDE_PURE(PyObject *, WithPyObjectPtrReturn, return_pyobject_ptr, + /* no arguments */); + } +}; + +std::string call_return_pyobject_ptr(const WithPyObjectPtrReturn *base_class_ptr) { + PyObject *returned_obj = base_class_ptr->return_pyobject_ptr(); +#if !defined(PYPY_VERSION) // It is not worth the trouble doing something special for PyPy. + if (Py_REFCNT(returned_obj) != 1) { + py::pybind11_fail(__FILE__ ":" PYBIND11_TOSTRING(__LINE__)); + } +#endif + auto ret_val = py::repr(returned_obj).cast(); + Py_DECREF(returned_obj); + return ret_val; +} + +} // namespace test_type_caster_pyobject_ptr TEST_SUBMODULE(type_caster_pyobject_ptr, m) { + using namespace test_type_caster_pyobject_ptr; + m.def("cast_from_pyobject_ptr", []() { PyObject *ptr = PyLong_FromLongLong(6758L); return py::cast(ptr, py::return_value_policy::take_ownership); @@ -127,4 +158,10 @@ TEST_SUBMODULE(type_caster_pyobject_ptr, m) { (void) py::cast(*ptr); } #endif + + py::class_(m, "WithPyObjectPtrReturn") + .def(py::init<>()) + .def("return_pyobject_ptr", &WithPyObjectPtrReturn::return_pyobject_ptr); + + m.def("call_return_pyobject_ptr", call_return_pyobject_ptr); } diff --git a/tests/test_type_caster_pyobject_ptr.py b/tests/test_type_caster_pyobject_ptr.py index 1f1ece2ba..f6358d011 100644 --- a/tests/test_type_caster_pyobject_ptr.py +++ b/tests/test_type_caster_pyobject_ptr.py @@ -102,3 +102,19 @@ def test_return_list_pyobject_ptr_reference(): def test_type_caster_name_via_incompatible_function_arguments_type_error(): with pytest.raises(TypeError, match=r"1\. \(arg0: object, arg1: int\) -> None"): m.pass_pyobject_ptr_and_int(ValueHolder(101), ValueHolder(202)) + + +def test_trampoline_with_pyobject_ptr_return(): + class Drvd(m.WithPyObjectPtrReturn): + def return_pyobject_ptr(self): + return ["11", "22", "33"] + + # Basic health check: First make sure this works as expected. + d = Drvd() + assert d.return_pyobject_ptr() == ["11", "22", "33"] + + while True: + # This failed before PR #5156: AddressSanitizer: heap-use-after-free ... in Py_DECREF + d_repr = m.call_return_pyobject_ptr(d) + assert d_repr == repr(["11", "22", "33"]) + break # Comment out for manual leak checking.