From d21cee39e8d4870fd4f9048145bec41f37b49fdd Mon Sep 17 00:00:00 2001 From: AlexanderMueller Date: Tue, 20 Aug 2024 11:49:16 +0200 Subject: [PATCH] changes from review --- include/pybind11/functional.h | 37 +++++++++++++++++++---------------- tests/test_callbacks.py | 10 ++++++++++ 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/include/pybind11/functional.h b/include/pybind11/functional.h index fde99a4ae..8a8c32c0e 100644 --- a/include/pybind11/functional.h +++ b/include/pybind11/functional.h @@ -102,8 +102,8 @@ public: rec = c.get_pointer(); } while (rec != nullptr) { - const int correctingSelfArgument = rec->is_method ? 1 : 0; - if (rec->nargs - correctingSelfArgument != sizeof...(Args)) { + const size_t self_offset = rec->is_method ? 1 : 0; + if (rec->nargs != sizeof...(Args) + self_offset) { rec = rec->next; // if the overload is not feasible in terms of number of arguments, we // continue to the next one. If there is no next one, we return false. @@ -129,20 +129,24 @@ public: // See PR #1413 for full details } else { // Check number of arguments of Python function - auto getArgCount = [&](PyObject *obj) { - // This is faster then doing import inspect and inspect.signature(obj).parameters - auto *t = PyObject_GetAttrString(obj, "__code__"); - auto *argCount = PyObject_GetAttrString(t, "co_argcount"); - return PyLong_AsLong(argCount); - }; - long argCount = -1; + auto argCountFromFuncCode = [&](handle &obj) { + // This is faster then doing import inspect and + // inspect.signature(obj).parameters - if (static_cast(PyObject_HasAttrString(src.ptr(), "__code__"))) { - argCount = getArgCount(src.ptr()); + object argCount = obj.attr("co_argcount"); + return argCount.template cast(); + }; + size_t argCount = 0; + + handle codeAttr = PyObject_GetAttrString(src.ptr(), "__code__"); + if (codeAttr) { + argCount = argCountFromFuncCode(codeAttr); } else { - if (static_cast(PyObject_HasAttrString(src.ptr(), "__call__"))) { - auto *t2 = PyObject_GetAttrString(src.ptr(), "__call__"); - argCount = getArgCount(t2) - 1; // we have to remove the self argument + handle callAttr = PyObject_GetAttrString(src.ptr(), "__call__"); + if (callAttr) { + handle codeAttr2 = PyObject_GetAttrString(callAttr.ptr(), "__code__"); + argCount = argCountFromFuncCode(codeAttr2) + - 1; // we have to remove the self argument } else { // No __code__ or __call__ attribute, this is not a proper Python function return false; @@ -150,10 +154,9 @@ public: } // if we are a method, we have to correct the argument count since we are not counting // the self argument - const int correctingSelfArgument - = static_cast(PyMethod_Check(src.ptr())) ? 1 : 0; + const size_t self_offset = static_cast(PyMethod_Check(src.ptr())) ? 1 : 0; - argCount -= correctingSelfArgument; + argCount -= self_offset; if (argCount != sizeof...(Args)) { return false; } diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 82b03fac1..d2afbc2ca 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -107,13 +107,23 @@ def test_cpp_correct_overload_resolution(): def f(a): return a + class A: + def __call__(self, a): + return a + assert m.dummy_function_overloaded_std_func_arg(f) == 9 + assert m.dummy_function_overloaded_std_func_arg(A()) == 9 assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9 def f2(a, b): return a + b + class B: + def __call__(self, a, b): + return a + b + assert m.dummy_function_overloaded_std_func_arg(f2) == 14 + assert m.dummy_function_overloaded_std_func_arg(B()) == 14 assert m.dummy_function_overloaded_std_func_arg(lambda i, j: i + j) == 14