From ee2b5226295d67b690faddd446a329bb2840a1a8 Mon Sep 17 00:00:00 2001 From: Ethan Steinberg Date: Wed, 2 Nov 2022 11:32:53 -0700 Subject: [PATCH] Fix functional.h bug + introduce test to verify that it is fixed (#4254) * Illustrate bug in functional.h * style: pre-commit fixes * Make functional casting more robust / add workaround * Make function_record* casting even more robust * See if this fixes PyPy issue * It still fails on PyPy sadly * Do not make new CTOR just yet * Fix test * Add name to ensure correctness * style: pre-commit fixes * Clean up tests + remove ifdef guards * Add comments * Improve comments, error handling, and safety * Fix compile error * Fix magic logic * Extract helper function * Fix func signature * move to local internals * style: pre-commit fixes * Switch to simpler design * style: pre-commit fixes * Move to function_record * style: pre-commit fixes * Switch to internals, update tests and docs * Fix lint * Oops, forgot to resolve last comment * Fix typo * Update in response to comments * Implement suggestion to improve test * Update comment * Simple fixes Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Aaron Gokaslan --- include/pybind11/detail/internals.h | 31 ++++++++++++++++++++ include/pybind11/functional.h | 11 ++++++-- include/pybind11/pybind11.h | 44 ++++++++++++++++++++++------- tests/test_callbacks.cpp | 37 ++++++++++++++++++++++++ tests/test_callbacks.py | 13 +++++++++ 5 files changed, 124 insertions(+), 12 deletions(-) diff --git a/include/pybind11/detail/internals.h b/include/pybind11/detail/internals.h index 7de779434..6fd61098c 100644 --- a/include/pybind11/detail/internals.h +++ b/include/pybind11/detail/internals.h @@ -43,6 +43,8 @@ using ExceptionTranslator = void (*)(std::exception_ptr); PYBIND11_NAMESPACE_BEGIN(detail) +constexpr const char *internals_function_record_capsule_name = "pybind11_function_record_capsule"; + // Forward declarations inline PyTypeObject *make_static_property_type(); inline PyTypeObject *make_default_metaclass(); @@ -182,6 +184,16 @@ struct internals { # endif // PYBIND11_INTERNALS_VERSION > 4 // Unused if PYBIND11_SIMPLE_GIL_MANAGEMENT is defined: PyInterpreterState *istate = nullptr; + +# if PYBIND11_INTERNALS_VERSION > 4 + // Note that we have to use a std::string to allocate memory to ensure a unique address + // We want unique addresses since we use pointer equality to compare function records + std::string function_record_capsule_name = internals_function_record_capsule_name; +# endif + + internals() = default; + internals(const internals &other) = delete; + internals &operator=(const internals &other) = delete; ~internals() { # if PYBIND11_INTERNALS_VERSION > 4 PYBIND11_TLS_FREE(loader_life_support_tls_key); @@ -548,6 +560,25 @@ const char *c_str(Args &&...args) { return strings.front().c_str(); } +inline const char *get_function_record_capsule_name() { +#if PYBIND11_INTERNALS_VERSION > 4 + return get_internals().function_record_capsule_name.c_str(); +#else + return nullptr; +#endif +} + +// Determine whether or not the following capsule contains a pybind11 function record. +// Note that we use `internals` to make sure that only ABI compatible records are touched. +// +// This check is currently used in two places: +// - An important optimization in functional.h to avoid overhead in C++ -> Python -> C++ +// - The sibling feature of cpp_function to allow overloads +inline bool is_function_record_capsule(const capsule &cap) { + // Pointer equality as we rely on internals() to ensure unique pointers + return cap.name() == get_function_record_capsule_name(); +} + PYBIND11_NAMESPACE_END(detail) /// Returns a named pointer that is shared among all extension modules (using the same diff --git a/include/pybind11/functional.h b/include/pybind11/functional.h index 102d1a938..87ec4d10c 100644 --- a/include/pybind11/functional.h +++ b/include/pybind11/functional.h @@ -48,9 +48,16 @@ public: */ if (auto cfunc = func.cpp_function()) { auto *cfunc_self = PyCFunction_GET_SELF(cfunc.ptr()); - if (isinstance(cfunc_self)) { + if (cfunc_self == nullptr) { + PyErr_Clear(); + } else if (isinstance(cfunc_self)) { auto c = reinterpret_borrow(cfunc_self); - auto *rec = (function_record *) c; + + function_record *rec = nullptr; + // Check that we can safely reinterpret the capsule into a function_record + if (detail::is_function_record_capsule(c)) { + rec = c.get_pointer(); + } while (rec != nullptr) { if (rec->is_stateless diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index e662236d0..76d6eadca 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -468,13 +468,20 @@ protected: if (rec->sibling) { if (PyCFunction_Check(rec->sibling.ptr())) { auto *self = PyCFunction_GET_SELF(rec->sibling.ptr()); - capsule rec_capsule = isinstance(self) ? reinterpret_borrow(self) - : capsule(self); - chain = (detail::function_record *) rec_capsule; - /* Never append a method to an overload chain of a parent class; - instead, hide the parent's overloads in this case */ - if (!chain->scope.is(rec->scope)) { + if (!isinstance(self)) { chain = nullptr; + } else { + auto rec_capsule = reinterpret_borrow(self); + if (detail::is_function_record_capsule(rec_capsule)) { + chain = rec_capsule.get_pointer(); + /* Never append a method to an overload chain of a parent class; + instead, hide the parent's overloads in this case */ + if (!chain->scope.is(rec->scope)) { + chain = nullptr; + } + } else { + chain = nullptr; + } } } // Don't trigger for things like the default __init__, which are wrapper_descriptors @@ -496,6 +503,7 @@ protected: capsule rec_capsule(unique_rec.release(), [](void *ptr) { destruct((detail::function_record *) ptr); }); + rec_capsule.set_name(detail::get_function_record_capsule_name()); guarded_strdup.release(); object scope_module; @@ -661,10 +669,13 @@ protected: /// Main dispatch logic for calls to functions bound using pybind11 static PyObject *dispatcher(PyObject *self, PyObject *args_in, PyObject *kwargs_in) { using namespace detail; + assert(isinstance(self)); /* Iterator over the list of potentially admissible overloads */ - const function_record *overloads = (function_record *) PyCapsule_GetPointer(self, nullptr), + const function_record *overloads = reinterpret_cast( + PyCapsule_GetPointer(self, get_function_record_capsule_name())), *it = overloads; + assert(overloads != nullptr); /* Need to know how many arguments + keyword arguments there are to pick the right overload */ @@ -1871,9 +1882,22 @@ private: static detail::function_record *get_function_record(handle h) { h = detail::get_function(h); - return h ? (detail::function_record *) reinterpret_borrow( - PyCFunction_GET_SELF(h.ptr())) - : nullptr; + if (!h) { + return nullptr; + } + + handle func_self = PyCFunction_GET_SELF(h.ptr()); + if (!func_self) { + throw error_already_set(); + } + if (!isinstance(func_self)) { + return nullptr; + } + auto cap = reinterpret_borrow(func_self); + if (!detail::is_function_record_capsule(cap)) { + return nullptr; + } + return cap.get_pointer(); } }; diff --git a/tests/test_callbacks.cpp b/tests/test_callbacks.cpp index 92b8053de..2fd05dec7 100644 --- a/tests/test_callbacks.cpp +++ b/tests/test_callbacks.cpp @@ -240,4 +240,41 @@ TEST_SUBMODULE(callbacks, m) { f(); } }); + + auto *custom_def = []() { + static PyMethodDef def; + def.ml_name = "example_name"; + def.ml_doc = "Example doc"; + def.ml_meth = [](PyObject *, PyObject *args) -> PyObject * { + if (PyTuple_Size(args) != 1) { + throw std::runtime_error("Invalid number of arguments for example_name"); + } + PyObject *first = PyTuple_GetItem(args, 0); + if (!PyLong_Check(first)) { + throw std::runtime_error("Invalid argument to example_name"); + } + auto result = py::cast(PyLong_AsLong(first) * 9); + return result.release().ptr(); + }; + def.ml_flags = METH_VARARGS; + return &def; + }(); + + // rec_capsule with name that has the same value (but not pointer) as our internal one + // This capsule should be detected by our code as foreign and not inspected as the pointers + // shouldn't match + constexpr const char *rec_capsule_name + = pybind11::detail::internals_function_record_capsule_name; + py::capsule rec_capsule(std::malloc(1), [](void *data) { std::free(data); }); + rec_capsule.set_name(rec_capsule_name); + m.add_object("custom_function", PyCFunction_New(custom_def, rec_capsule.ptr())); + + // This test requires a new ABI version to pass +#if PYBIND11_INTERNALS_VERSION > 4 + // rec_capsule with nullptr name + py::capsule rec_capsule2(std::malloc(1), [](void *data) { std::free(data); }); + m.add_object("custom_function2", PyCFunction_New(custom_def, rec_capsule2.ptr())); +#else + m.add_object("custom_function2", py::none()); +#endif } diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 0b1047bbf..57b659988 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -193,3 +193,16 @@ def test_callback_num_times(): if len(rates) > 1: print("Min Mean Max") print(f"{min(rates):6.3f} {sum(rates) / len(rates):6.3f} {max(rates):6.3f}") + + +def test_custom_func(): + assert m.custom_function(4) == 36 + assert m.roundtrip(m.custom_function)(4) == 36 + + +@pytest.mark.skipif( + m.custom_function2 is None, reason="Current PYBIND11_INTERNALS_VERSION too low" +) +def test_custom_func2(): + assert m.custom_function2(3) == 27 + assert m.roundtrip(m.custom_function2)(3) == 27