mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-11 08:03:55 +00:00
Fix functional.h bug + introduce test to verify that it is fixed (#4254)
* Illustrate bug in functional.h * style: pre-commit fixes * Make functional casting more robust / add workaround * Make function_record* casting even more robust * See if this fixes PyPy issue * It still fails on PyPy sadly * Do not make new CTOR just yet * Fix test * Add name to ensure correctness * style: pre-commit fixes * Clean up tests + remove ifdef guards * Add comments * Improve comments, error handling, and safety * Fix compile error * Fix magic logic * Extract helper function * Fix func signature * move to local internals * style: pre-commit fixes * Switch to simpler design * style: pre-commit fixes * Move to function_record * style: pre-commit fixes * Switch to internals, update tests and docs * Fix lint * Oops, forgot to resolve last comment * Fix typo * Update in response to comments * Implement suggestion to improve test * Update comment * Simple fixes Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
This commit is contained in:
parent
0176632e8c
commit
ee2b522629
@ -43,6 +43,8 @@ using ExceptionTranslator = void (*)(std::exception_ptr);
|
|||||||
|
|
||||||
PYBIND11_NAMESPACE_BEGIN(detail)
|
PYBIND11_NAMESPACE_BEGIN(detail)
|
||||||
|
|
||||||
|
constexpr const char *internals_function_record_capsule_name = "pybind11_function_record_capsule";
|
||||||
|
|
||||||
// Forward declarations
|
// Forward declarations
|
||||||
inline PyTypeObject *make_static_property_type();
|
inline PyTypeObject *make_static_property_type();
|
||||||
inline PyTypeObject *make_default_metaclass();
|
inline PyTypeObject *make_default_metaclass();
|
||||||
@ -182,6 +184,16 @@ struct internals {
|
|||||||
# endif // PYBIND11_INTERNALS_VERSION > 4
|
# endif // PYBIND11_INTERNALS_VERSION > 4
|
||||||
// Unused if PYBIND11_SIMPLE_GIL_MANAGEMENT is defined:
|
// Unused if PYBIND11_SIMPLE_GIL_MANAGEMENT is defined:
|
||||||
PyInterpreterState *istate = nullptr;
|
PyInterpreterState *istate = nullptr;
|
||||||
|
|
||||||
|
# if PYBIND11_INTERNALS_VERSION > 4
|
||||||
|
// Note that we have to use a std::string to allocate memory to ensure a unique address
|
||||||
|
// We want unique addresses since we use pointer equality to compare function records
|
||||||
|
std::string function_record_capsule_name = internals_function_record_capsule_name;
|
||||||
|
# endif
|
||||||
|
|
||||||
|
internals() = default;
|
||||||
|
internals(const internals &other) = delete;
|
||||||
|
internals &operator=(const internals &other) = delete;
|
||||||
~internals() {
|
~internals() {
|
||||||
# if PYBIND11_INTERNALS_VERSION > 4
|
# if PYBIND11_INTERNALS_VERSION > 4
|
||||||
PYBIND11_TLS_FREE(loader_life_support_tls_key);
|
PYBIND11_TLS_FREE(loader_life_support_tls_key);
|
||||||
@ -548,6 +560,25 @@ const char *c_str(Args &&...args) {
|
|||||||
return strings.front().c_str();
|
return strings.front().c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline const char *get_function_record_capsule_name() {
|
||||||
|
#if PYBIND11_INTERNALS_VERSION > 4
|
||||||
|
return get_internals().function_record_capsule_name.c_str();
|
||||||
|
#else
|
||||||
|
return nullptr;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine whether or not the following capsule contains a pybind11 function record.
|
||||||
|
// Note that we use `internals` to make sure that only ABI compatible records are touched.
|
||||||
|
//
|
||||||
|
// This check is currently used in two places:
|
||||||
|
// - An important optimization in functional.h to avoid overhead in C++ -> Python -> C++
|
||||||
|
// - The sibling feature of cpp_function to allow overloads
|
||||||
|
inline bool is_function_record_capsule(const capsule &cap) {
|
||||||
|
// Pointer equality as we rely on internals() to ensure unique pointers
|
||||||
|
return cap.name() == get_function_record_capsule_name();
|
||||||
|
}
|
||||||
|
|
||||||
PYBIND11_NAMESPACE_END(detail)
|
PYBIND11_NAMESPACE_END(detail)
|
||||||
|
|
||||||
/// Returns a named pointer that is shared among all extension modules (using the same
|
/// Returns a named pointer that is shared among all extension modules (using the same
|
||||||
|
@ -48,9 +48,16 @@ public:
|
|||||||
*/
|
*/
|
||||||
if (auto cfunc = func.cpp_function()) {
|
if (auto cfunc = func.cpp_function()) {
|
||||||
auto *cfunc_self = PyCFunction_GET_SELF(cfunc.ptr());
|
auto *cfunc_self = PyCFunction_GET_SELF(cfunc.ptr());
|
||||||
if (isinstance<capsule>(cfunc_self)) {
|
if (cfunc_self == nullptr) {
|
||||||
|
PyErr_Clear();
|
||||||
|
} else if (isinstance<capsule>(cfunc_self)) {
|
||||||
auto c = reinterpret_borrow<capsule>(cfunc_self);
|
auto c = reinterpret_borrow<capsule>(cfunc_self);
|
||||||
auto *rec = (function_record *) c;
|
|
||||||
|
function_record *rec = nullptr;
|
||||||
|
// Check that we can safely reinterpret the capsule into a function_record
|
||||||
|
if (detail::is_function_record_capsule(c)) {
|
||||||
|
rec = c.get_pointer<function_record>();
|
||||||
|
}
|
||||||
|
|
||||||
while (rec != nullptr) {
|
while (rec != nullptr) {
|
||||||
if (rec->is_stateless
|
if (rec->is_stateless
|
||||||
|
@ -468,14 +468,21 @@ protected:
|
|||||||
if (rec->sibling) {
|
if (rec->sibling) {
|
||||||
if (PyCFunction_Check(rec->sibling.ptr())) {
|
if (PyCFunction_Check(rec->sibling.ptr())) {
|
||||||
auto *self = PyCFunction_GET_SELF(rec->sibling.ptr());
|
auto *self = PyCFunction_GET_SELF(rec->sibling.ptr());
|
||||||
capsule rec_capsule = isinstance<capsule>(self) ? reinterpret_borrow<capsule>(self)
|
if (!isinstance<capsule>(self)) {
|
||||||
: capsule(self);
|
chain = nullptr;
|
||||||
chain = (detail::function_record *) rec_capsule;
|
} else {
|
||||||
|
auto rec_capsule = reinterpret_borrow<capsule>(self);
|
||||||
|
if (detail::is_function_record_capsule(rec_capsule)) {
|
||||||
|
chain = rec_capsule.get_pointer<detail::function_record>();
|
||||||
/* Never append a method to an overload chain of a parent class;
|
/* Never append a method to an overload chain of a parent class;
|
||||||
instead, hide the parent's overloads in this case */
|
instead, hide the parent's overloads in this case */
|
||||||
if (!chain->scope.is(rec->scope)) {
|
if (!chain->scope.is(rec->scope)) {
|
||||||
chain = nullptr;
|
chain = nullptr;
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
chain = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// Don't trigger for things like the default __init__, which are wrapper_descriptors
|
// Don't trigger for things like the default __init__, which are wrapper_descriptors
|
||||||
// that we are intentionally replacing
|
// that we are intentionally replacing
|
||||||
@ -496,6 +503,7 @@ protected:
|
|||||||
|
|
||||||
capsule rec_capsule(unique_rec.release(),
|
capsule rec_capsule(unique_rec.release(),
|
||||||
[](void *ptr) { destruct((detail::function_record *) ptr); });
|
[](void *ptr) { destruct((detail::function_record *) ptr); });
|
||||||
|
rec_capsule.set_name(detail::get_function_record_capsule_name());
|
||||||
guarded_strdup.release();
|
guarded_strdup.release();
|
||||||
|
|
||||||
object scope_module;
|
object scope_module;
|
||||||
@ -661,10 +669,13 @@ protected:
|
|||||||
/// Main dispatch logic for calls to functions bound using pybind11
|
/// Main dispatch logic for calls to functions bound using pybind11
|
||||||
static PyObject *dispatcher(PyObject *self, PyObject *args_in, PyObject *kwargs_in) {
|
static PyObject *dispatcher(PyObject *self, PyObject *args_in, PyObject *kwargs_in) {
|
||||||
using namespace detail;
|
using namespace detail;
|
||||||
|
assert(isinstance<capsule>(self));
|
||||||
|
|
||||||
/* Iterator over the list of potentially admissible overloads */
|
/* Iterator over the list of potentially admissible overloads */
|
||||||
const function_record *overloads = (function_record *) PyCapsule_GetPointer(self, nullptr),
|
const function_record *overloads = reinterpret_cast<function_record *>(
|
||||||
|
PyCapsule_GetPointer(self, get_function_record_capsule_name())),
|
||||||
*it = overloads;
|
*it = overloads;
|
||||||
|
assert(overloads != nullptr);
|
||||||
|
|
||||||
/* Need to know how many arguments + keyword arguments there are to pick the right
|
/* Need to know how many arguments + keyword arguments there are to pick the right
|
||||||
overload */
|
overload */
|
||||||
@ -1871,9 +1882,22 @@ private:
|
|||||||
|
|
||||||
static detail::function_record *get_function_record(handle h) {
|
static detail::function_record *get_function_record(handle h) {
|
||||||
h = detail::get_function(h);
|
h = detail::get_function(h);
|
||||||
return h ? (detail::function_record *) reinterpret_borrow<capsule>(
|
if (!h) {
|
||||||
PyCFunction_GET_SELF(h.ptr()))
|
return nullptr;
|
||||||
: nullptr;
|
}
|
||||||
|
|
||||||
|
handle func_self = PyCFunction_GET_SELF(h.ptr());
|
||||||
|
if (!func_self) {
|
||||||
|
throw error_already_set();
|
||||||
|
}
|
||||||
|
if (!isinstance<capsule>(func_self)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto cap = reinterpret_borrow<capsule>(func_self);
|
||||||
|
if (!detail::is_function_record_capsule(cap)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return cap.get_pointer<detail::function_record>();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -240,4 +240,41 @@ TEST_SUBMODULE(callbacks, m) {
|
|||||||
f();
|
f();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
auto *custom_def = []() {
|
||||||
|
static PyMethodDef def;
|
||||||
|
def.ml_name = "example_name";
|
||||||
|
def.ml_doc = "Example doc";
|
||||||
|
def.ml_meth = [](PyObject *, PyObject *args) -> PyObject * {
|
||||||
|
if (PyTuple_Size(args) != 1) {
|
||||||
|
throw std::runtime_error("Invalid number of arguments for example_name");
|
||||||
|
}
|
||||||
|
PyObject *first = PyTuple_GetItem(args, 0);
|
||||||
|
if (!PyLong_Check(first)) {
|
||||||
|
throw std::runtime_error("Invalid argument to example_name");
|
||||||
|
}
|
||||||
|
auto result = py::cast(PyLong_AsLong(first) * 9);
|
||||||
|
return result.release().ptr();
|
||||||
|
};
|
||||||
|
def.ml_flags = METH_VARARGS;
|
||||||
|
return &def;
|
||||||
|
}();
|
||||||
|
|
||||||
|
// rec_capsule with name that has the same value (but not pointer) as our internal one
|
||||||
|
// This capsule should be detected by our code as foreign and not inspected as the pointers
|
||||||
|
// shouldn't match
|
||||||
|
constexpr const char *rec_capsule_name
|
||||||
|
= pybind11::detail::internals_function_record_capsule_name;
|
||||||
|
py::capsule rec_capsule(std::malloc(1), [](void *data) { std::free(data); });
|
||||||
|
rec_capsule.set_name(rec_capsule_name);
|
||||||
|
m.add_object("custom_function", PyCFunction_New(custom_def, rec_capsule.ptr()));
|
||||||
|
|
||||||
|
// This test requires a new ABI version to pass
|
||||||
|
#if PYBIND11_INTERNALS_VERSION > 4
|
||||||
|
// rec_capsule with nullptr name
|
||||||
|
py::capsule rec_capsule2(std::malloc(1), [](void *data) { std::free(data); });
|
||||||
|
m.add_object("custom_function2", PyCFunction_New(custom_def, rec_capsule2.ptr()));
|
||||||
|
#else
|
||||||
|
m.add_object("custom_function2", py::none());
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
@ -193,3 +193,16 @@ def test_callback_num_times():
|
|||||||
if len(rates) > 1:
|
if len(rates) > 1:
|
||||||
print("Min Mean Max")
|
print("Min Mean Max")
|
||||||
print(f"{min(rates):6.3f} {sum(rates) / len(rates):6.3f} {max(rates):6.3f}")
|
print(f"{min(rates):6.3f} {sum(rates) / len(rates):6.3f} {max(rates):6.3f}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_func():
|
||||||
|
assert m.custom_function(4) == 36
|
||||||
|
assert m.roundtrip(m.custom_function)(4) == 36
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
m.custom_function2 is None, reason="Current PYBIND11_INTERNALS_VERSION too low"
|
||||||
|
)
|
||||||
|
def test_custom_func2():
|
||||||
|
assert m.custom_function2(3) == 27
|
||||||
|
assert m.roundtrip(m.custom_function2)(3) == 27
|
||||||
|
Loading…
Reference in New Issue
Block a user