From e0be5dbd48e414f9a2bfcd4a9e4729edc7af3bd9 Mon Sep 17 00:00:00 2001 From: AlexanderMueller Date: Tue, 20 Aug 2024 17:09:24 +0200 Subject: [PATCH] test fix --- include/pybind11/functional.h | 24 ++++++++++++------------ tests/test_callbacks.py | 3 ++- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/include/pybind11/functional.h b/include/pybind11/functional.h index 8a8c32c0e..4baeaa57a 100644 --- a/include/pybind11/functional.h +++ b/include/pybind11/functional.h @@ -14,6 +14,7 @@ #include "pybind11.h" #include +#include PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) PYBIND11_NAMESPACE_BEGIN(detail) @@ -129,24 +130,23 @@ public: // See PR #1413 for full details } else { // Check number of arguments of Python function - auto argCountFromFuncCode = [&](handle &obj) { - // This is faster then doing import inspect and - // inspect.signature(obj).parameters - - object argCount = obj.attr("co_argcount"); - return argCount.template cast(); + auto get_argument_count = [](const handle &obj) -> size_t { + // Faster then `import inspect` and `inspect.signature(obj).parameters` + return obj.attr("co_argcount").cast(); }; size_t argCount = 0; - handle codeAttr = PyObject_GetAttrString(src.ptr(), "__code__"); + handle empty; + object codeAttr = getattr(src, "__code__", empty); + if (codeAttr) { - argCount = argCountFromFuncCode(codeAttr); + argCount = get_argument_count(codeAttr); } else { - handle callAttr = PyObject_GetAttrString(src.ptr(), "__call__"); + object callAttr = getattr(src, "__call__", empty); + if (callAttr) { - handle codeAttr2 = PyObject_GetAttrString(callAttr.ptr(), "__code__"); - argCount = argCountFromFuncCode(codeAttr2) - - 1; // we have to remove the self argument + object codeAttr2 = getattr(callAttr, "__code__"); + argCount = get_argument_count(codeAttr2) - 1; // removing the self argument } else { // No __code__ or __call__ attribute, this is not a proper Python function return false; diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index d2afbc2ca..c81aee667 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -112,7 +112,8 @@ def test_cpp_correct_overload_resolution(): return a assert m.dummy_function_overloaded_std_func_arg(f) == 9 - assert m.dummy_function_overloaded_std_func_arg(A()) == 9 + a = A() + assert m.dummy_function_overloaded_std_func_arg(a) == 9 assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9 def f2(a, b):