This commit is contained in:
AlexanderMueller 2024-08-20 17:09:24 +02:00
parent d21cee39e8
commit e0be5dbd48
2 changed files with 14 additions and 13 deletions

View File

@ -14,6 +14,7 @@
#include "pybind11.h" #include "pybind11.h"
#include <functional> #include <functional>
#include <iostream>
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
PYBIND11_NAMESPACE_BEGIN(detail) PYBIND11_NAMESPACE_BEGIN(detail)
@ -129,24 +130,23 @@ public:
// See PR #1413 for full details // See PR #1413 for full details
} else { } else {
// Check number of arguments of Python function // Check number of arguments of Python function
auto argCountFromFuncCode = [&](handle &obj) { auto get_argument_count = [](const handle &obj) -> size_t {
// This is faster then doing import inspect and // Faster then `import inspect` and `inspect.signature(obj).parameters`
// inspect.signature(obj).parameters return obj.attr("co_argcount").cast<size_t>();
object argCount = obj.attr("co_argcount");
return argCount.template cast<size_t>();
}; };
size_t argCount = 0; size_t argCount = 0;
handle codeAttr = PyObject_GetAttrString(src.ptr(), "__code__"); handle empty;
object codeAttr = getattr(src, "__code__", empty);
if (codeAttr) { if (codeAttr) {
argCount = argCountFromFuncCode(codeAttr); argCount = get_argument_count(codeAttr);
} else { } else {
handle callAttr = PyObject_GetAttrString(src.ptr(), "__call__"); object callAttr = getattr(src, "__call__", empty);
if (callAttr) { if (callAttr) {
handle codeAttr2 = PyObject_GetAttrString(callAttr.ptr(), "__code__"); object codeAttr2 = getattr(callAttr, "__code__");
argCount = argCountFromFuncCode(codeAttr2) argCount = get_argument_count(codeAttr2) - 1; // removing the self argument
- 1; // we have to remove the self argument
} else { } else {
// No __code__ or __call__ attribute, this is not a proper Python function // No __code__ or __call__ attribute, this is not a proper Python function
return false; return false;

View File

@ -112,7 +112,8 @@ def test_cpp_correct_overload_resolution():
return a return a
assert m.dummy_function_overloaded_std_func_arg(f) == 9 assert m.dummy_function_overloaded_std_func_arg(f) == 9
assert m.dummy_function_overloaded_std_func_arg(A()) == 9 a = A()
assert m.dummy_function_overloaded_std_func_arg(a) == 9
assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9 assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9
def f2(a, b): def f2(a, b):