From 2fb3d7cbde264a0b3f921e802f287195387e8263 Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Tue, 27 Jun 2023 15:08:32 -0700 Subject: [PATCH] Trivial refactoring to make the capsule API more user friendly. (#4720) * Trivial refactoring to make the capsule API more user friendly. * Use new API in production code. Thanks @Lalaland for pointing this out. --- include/pybind11/pybind11.h | 2 +- include/pybind11/pytypes.h | 51 ++++++++++++++++++++++--------------- tests/test_pytypes.cpp | 9 +++++++ tests/test_pytypes.py | 13 ++++++++++ 4 files changed, 54 insertions(+), 21 deletions(-) diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 28ebc2229..3bce1a01b 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -508,8 +508,8 @@ protected: rec->def->ml_flags = METH_VARARGS | METH_KEYWORDS; capsule rec_capsule(unique_rec.release(), + detail::get_function_record_capsule_name(), [](void *ptr) { destruct((detail::function_record *) ptr); }); - rec_capsule.set_name(detail::get_function_record_capsule_name()); guarded_strdup.release(); object scope_module; diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index f5d3f34f3..c93e3d3b9 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -1925,28 +1925,13 @@ public: } } + /// Capsule name is nullptr. capsule(const void *value, void (*destructor)(void *)) { - m_ptr = PyCapsule_New(const_cast(value), nullptr, [](PyObject *o) { - // guard if destructor called while err indicator is set - error_scope error_guard; - auto destructor = reinterpret_cast(PyCapsule_GetContext(o)); - if (destructor == nullptr && PyErr_Occurred()) { - throw error_already_set(); - } - const char *name = get_name_in_error_scope(o); - void *ptr = PyCapsule_GetPointer(o, name); - if (ptr == nullptr) { - throw error_already_set(); - } + initialize_with_void_ptr_destructor(value, nullptr, destructor); + } - if (destructor != nullptr) { - destructor(ptr); - } - }); - - if (!m_ptr || PyCapsule_SetContext(m_ptr, reinterpret_cast(destructor)) != 0) { - throw error_already_set(); - } + capsule(const void *value, const char *name, void (*destructor)(void *)) { + initialize_with_void_ptr_destructor(value, name, destructor); } explicit capsule(void (*destructor)()) { @@ -2014,6 +1999,32 @@ private: return name; } + + void initialize_with_void_ptr_destructor(const void *value, + const char *name, + void (*destructor)(void *)) { + m_ptr = PyCapsule_New(const_cast(value), name, [](PyObject *o) { + // guard if destructor called while err indicator is set + error_scope error_guard; + auto destructor = reinterpret_cast(PyCapsule_GetContext(o)); + if (destructor == nullptr && PyErr_Occurred()) { + throw error_already_set(); + } + const char *name = get_name_in_error_scope(o); + void *ptr = PyCapsule_GetPointer(o, name); + if (ptr == nullptr) { + throw error_already_set(); + } + + if (destructor != nullptr) { + destructor(ptr); + } + }); + + if (!m_ptr || PyCapsule_SetContext(m_ptr, reinterpret_cast(destructor)) != 0) { + throw error_already_set(); + } + } }; class tuple : public object { diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index 1028bb58e..b4ee64289 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -260,6 +260,15 @@ TEST_SUBMODULE(pytypes, m) { }); }); + m.def("return_capsule_with_destructor_3", []() { + py::print("creating capsule"); + auto cap = py::capsule((void *) 1233, "oname", [](void *ptr) { + py::print("destructing capsule: {}"_s.format((size_t) ptr)); + }); + py::print("original name: {}"_s.format(cap.name())); + return cap; + }); + m.def("return_renamed_capsule_with_destructor_2", []() { py::print("creating capsule"); auto cap = py::capsule((void *) 1234, [](void *ptr) { diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index afb7a1ce8..eda7a20a9 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -319,6 +319,19 @@ def test_capsule(capture): """ ) + with capture: + a = m.return_capsule_with_destructor_3() + del a + pytest.gc_collect() + assert ( + capture.unordered + == """ + creating capsule + destructing capsule: 1233 + original name: oname + """ + ) + with capture: a = m.return_renamed_capsule_with_destructor_2() del a