From 0235533fdaace87738c782278b4c64cd5fe9e128 Mon Sep 17 00:00:00 2001 From: AlexanderMueller Date: Sat, 3 Aug 2024 10:06:13 +0200 Subject: [PATCH] add argument number dispatch mechanism for std::function casting --- include/pybind11/functional.h | 41 ++++++++++++++++++++++++++- tests/test_callbacks.cpp | 6 ++++ tests/test_callbacks.py | 19 ++++++++++++- tests/test_embed/test_interpreter.cpp | 16 +++++++++++ 4 files changed, 80 insertions(+), 2 deletions(-) diff --git a/include/pybind11/functional.h b/include/pybind11/functional.h index 4b3610117..fde99a4ae 100644 --- a/include/pybind11/functional.h +++ b/include/pybind11/functional.h @@ -101,8 +101,17 @@ public: if (detail::is_function_record_capsule(c)) { rec = c.get_pointer(); } - while (rec != nullptr) { + const int correctingSelfArgument = rec->is_method ? 1 : 0; + if (rec->nargs - correctingSelfArgument != sizeof...(Args)) { + rec = rec->next; + // if the overload is not feasible in terms of number of arguments, we + // continue to the next one. If there is no next one, we return false. + if (rec == nullptr) { + return false; + } + continue; + } if (rec->is_stateless && same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { @@ -118,6 +127,36 @@ public: // PYPY segfaults here when passing builtin function like sum. // Raising an fail exception here works to prevent the segfault, but only on gcc. // See PR #1413 for full details + } else { + // Check number of arguments of Python function + auto getArgCount = [&](PyObject *obj) { + // This is faster then doing import inspect and inspect.signature(obj).parameters + auto *t = PyObject_GetAttrString(obj, "__code__"); + auto *argCount = PyObject_GetAttrString(t, "co_argcount"); + return PyLong_AsLong(argCount); + }; + long argCount = -1; + + if (static_cast(PyObject_HasAttrString(src.ptr(), "__code__"))) { + argCount = getArgCount(src.ptr()); + } else { + if (static_cast(PyObject_HasAttrString(src.ptr(), "__call__"))) { + auto *t2 = PyObject_GetAttrString(src.ptr(), "__call__"); + argCount = getArgCount(t2) - 1; // we have to remove the self argument + } else { + // No __code__ or __call__ attribute, this is not a proper Python function + return false; + } + } + // if we are a method, we have to correct the argument count since we are not counting + // the self argument + const int correctingSelfArgument + = static_cast(PyMethod_Check(src.ptr())) ? 1 : 0; + + argCount -= correctingSelfArgument; + if (argCount != sizeof...(Args)) { + return false; + } } value = type_caster_std_function_specializations::func_wrapper( diff --git a/tests/test_callbacks.cpp b/tests/test_callbacks.cpp index 2fd05dec7..ed55ad7b7 100644 --- a/tests/test_callbacks.cpp +++ b/tests/test_callbacks.cpp @@ -170,6 +170,12 @@ TEST_SUBMODULE(callbacks, m) { return "argument does NOT match dummy_function. This should never happen!"; }); + // test_cpp_correct_overload_resolution + m.def("dummy_function_overloaded_std_func_arg", + [](const std::function &f) { return 3 * f(3); }); + m.def("dummy_function_overloaded_std_func_arg", + [](const std::function &f) { return 2 * f(3, 4); }); + class AbstractBase { public: // [workaround(intel)] = default does not work here diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index db6d8dece..82b03fac1 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -103,6 +103,20 @@ def test_cpp_callable_cleanup(): assert alive_counts == [0, 1, 2, 1, 2, 1, 0] +def test_cpp_correct_overload_resolution(): + def f(a): + return a + + assert m.dummy_function_overloaded_std_func_arg(f) == 9 + assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9 + + def f2(a, b): + return a + b + + assert m.dummy_function_overloaded_std_func_arg(f2) == 14 + assert m.dummy_function_overloaded_std_func_arg(lambda i, j: i + j) == 14 + + def test_cpp_function_roundtrip(): """Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer""" @@ -131,7 +145,10 @@ def test_cpp_function_roundtrip(): m.test_dummy_function(lambda x, y: x + y) assert any( s in str(excinfo.value) - for s in ("missing 1 required positional argument", "takes exactly 2 arguments") + for s in ( + "incompatible function arguments. The following argument types are", + "function test_cpp_function_roundtrip..", + ) ) diff --git a/tests/test_embed/test_interpreter.cpp b/tests/test_embed/test_interpreter.cpp index c6c8a22d9..98df9b19e 100644 --- a/tests/test_embed/test_interpreter.cpp +++ b/tests/test_embed/test_interpreter.cpp @@ -1,4 +1,5 @@ #include +#include // Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to // catch 2.0.1; this should be fixed in the next catch release after 2.0.1). @@ -78,6 +79,12 @@ PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) { d["missing"].cast(); } +PYBIND11_EMBEDDED_MODULE(func_module, m) { + m.def("funcOverload", [](const std::function &f) { + return f(2, 3); + }).def("funcOverload", [](const std::function &f) { return f(2); }); +} + TEST_CASE("PYTHONPATH is used to update sys.path") { // The setup for this TEST_CASE is in catch.cpp! auto sys_path = py::str(py::module_::import("sys").attr("path")).cast(); @@ -171,6 +178,15 @@ TEST_CASE("There can be only one interpreter") { py::initialize_interpreter(); } +TEST_CASE("Check the overload resolution from cpp_function objects to std::function") { + auto m = py::module_::import("func_module"); + auto f = std::function([](int x) { return 2 * x; }); + REQUIRE(m.attr("funcOverload")(f).template cast() == 4); + + auto f2 = std::function([](int x, int y) { return 2 * x * y; }); + REQUIRE(m.attr("funcOverload")(f2).template cast() == 12); +} + #if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX TEST_CASE("Custom PyConfig") { py::finalize_interpreter();