diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index f9625e77e..339b0961e 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -190,6 +190,11 @@ private: bool rich_compare(object_api const &other, int value) const; }; +template +using is_pyobj_ptr_or_nullptr_t = detail::any_of, + std::is_same, + std::is_same>; + PYBIND11_NAMESPACE_END(detail) #if !defined(PYBIND11_HANDLE_REF_DEBUG) && !defined(NDEBUG) @@ -211,9 +216,23 @@ class handle : public detail::object_api { public: /// The default constructor creates a handle with a ``nullptr``-valued pointer handle() = default; - /// Creates a ``handle`` from the given raw Python object pointer + + /// Enable implicit conversion from ``PyObject *`` and ``nullptr``. + /// Not using ``handle(PyObject *ptr)`` to avoid implicit conversion from ``0``. + template ::value, int> = 0> // NOLINTNEXTLINE(google-explicit-constructor) - handle(PyObject *ptr) : m_ptr(ptr) {} // Allow implicit conversion from PyObject* + handle(T ptr) : m_ptr(ptr) {} + + /// Enable implicit conversion through ``T::operator PyObject *()``. + template < + typename T, + detail::enable_if_t, + detail::is_pyobj_ptr_or_nullptr_t>, + std::is_convertible>::value, + int> = 0> + // NOLINTNEXTLINE(google-explicit-constructor) + handle(T &obj) : m_ptr(obj) {} /// Return the underlying ``PyObject *`` pointer PyObject *ptr() const { return m_ptr; } diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index cb81007c3..f532e2608 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -39,7 +39,68 @@ class float_ : public py::object { }; } // namespace external +namespace implicit_conversion_from_0_to_handle { +// Uncomment to trigger compiler error. Note: Before PR #4008 this used to compile successfully. +// void expected_to_trigger_compiler_error() { py::handle(0); } +} // namespace implicit_conversion_from_0_to_handle + +// Used to validate systematically that PR #4008 does/did NOT change the behavior. +void pure_compile_tests_for_handle_from_PyObject_pointers() { + { + PyObject *ptr = Py_None; + py::handle{ptr}; + } + { + PyObject *const ptr = Py_None; + py::handle{ptr}; + } + // Uncomment to trigger compiler errors. + // PyObject const * ptr = Py_None; py::handle{ptr}; + // PyObject const *const ptr = Py_None; py::handle{ptr}; + // PyObject volatile * ptr = Py_None; py::handle{ptr}; + // PyObject volatile *const ptr = Py_None; py::handle{ptr}; + // PyObject const volatile * ptr = Py_None; py::handle{ptr}; + // PyObject const volatile *const ptr = Py_None; py::handle{ptr}; +} + +namespace handle_from_move_only_type_with_operator_PyObject { + +// Reduced from +// https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/csrc/utils/object_ptr.h +struct operator_ncnst { + operator_ncnst() = default; + operator_ncnst(operator_ncnst &&) = default; + operator PyObject *() /* */ { return Py_None; } // NOLINT(google-explicit-constructor) +}; + +struct operator_const { + operator_const() = default; + operator_const(operator_const &&) = default; + operator PyObject *() const { return Py_None; } // NOLINT(google-explicit-constructor) +}; + +bool from_ncnst() { + operator_ncnst obj; + auto h = py::handle(obj); // Critical part of test: does this compile? + return h.ptr() == Py_None; // Just something. +} + +bool from_const() { + operator_const obj; + auto h = py::handle(obj); // Critical part of test: does this compile? + return h.ptr() == Py_None; // Just something. +} + +void m_defs(py::module_ &m) { + m.def("handle_from_move_only_type_with_operator_PyObject_ncnst", from_ncnst); + m.def("handle_from_move_only_type_with_operator_PyObject_const", from_const); +} + +} // namespace handle_from_move_only_type_with_operator_PyObject + TEST_SUBMODULE(pytypes, m) { + handle_from_move_only_type_with_operator_PyObject::m_defs(m); + // test_bool m.def("get_bool", [] { return py::bool_(false); }); // test_int diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index 3e9d51a27..7a0a8b4ab 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -9,6 +9,11 @@ from pybind11_tests import detailed_error_messages_enabled from pybind11_tests import pytypes as m +def test_handle_from_move_only_type_with_operator_PyObject(): # noqa: N802 + assert m.handle_from_move_only_type_with_operator_PyObject_ncnst() + assert m.handle_from_move_only_type_with_operator_PyObject_const() + + def test_bool(doc): assert doc(m.get_bool) == "get_bool() -> bool"