changes from review

This commit is contained in:
AlexanderMueller 2024-08-20 11:49:16 +02:00
parent 0235533fda
commit d21cee39e8
2 changed files with 30 additions and 17 deletions

View File

@ -102,8 +102,8 @@ public:
rec = c.get_pointer<function_record>(); rec = c.get_pointer<function_record>();
} }
while (rec != nullptr) { while (rec != nullptr) {
const int correctingSelfArgument = rec->is_method ? 1 : 0; const size_t self_offset = rec->is_method ? 1 : 0;
if (rec->nargs - correctingSelfArgument != sizeof...(Args)) { if (rec->nargs != sizeof...(Args) + self_offset) {
rec = rec->next; rec = rec->next;
// if the overload is not feasible in terms of number of arguments, we // if the overload is not feasible in terms of number of arguments, we
// continue to the next one. If there is no next one, we return false. // continue to the next one. If there is no next one, we return false.
@ -129,20 +129,24 @@ public:
// See PR #1413 for full details // See PR #1413 for full details
} else { } else {
// Check number of arguments of Python function // Check number of arguments of Python function
auto getArgCount = [&](PyObject *obj) { auto argCountFromFuncCode = [&](handle &obj) {
// This is faster then doing import inspect and inspect.signature(obj).parameters // This is faster then doing import inspect and
auto *t = PyObject_GetAttrString(obj, "__code__"); // inspect.signature(obj).parameters
auto *argCount = PyObject_GetAttrString(t, "co_argcount");
return PyLong_AsLong(argCount);
};
long argCount = -1;
if (static_cast<bool>(PyObject_HasAttrString(src.ptr(), "__code__"))) { object argCount = obj.attr("co_argcount");
argCount = getArgCount(src.ptr()); return argCount.template cast<size_t>();
};
size_t argCount = 0;
handle codeAttr = PyObject_GetAttrString(src.ptr(), "__code__");
if (codeAttr) {
argCount = argCountFromFuncCode(codeAttr);
} else { } else {
if (static_cast<bool>(PyObject_HasAttrString(src.ptr(), "__call__"))) { handle callAttr = PyObject_GetAttrString(src.ptr(), "__call__");
auto *t2 = PyObject_GetAttrString(src.ptr(), "__call__"); if (callAttr) {
argCount = getArgCount(t2) - 1; // we have to remove the self argument handle codeAttr2 = PyObject_GetAttrString(callAttr.ptr(), "__code__");
argCount = argCountFromFuncCode(codeAttr2)
- 1; // we have to remove the self argument
} else { } else {
// No __code__ or __call__ attribute, this is not a proper Python function // No __code__ or __call__ attribute, this is not a proper Python function
return false; return false;
@ -150,10 +154,9 @@ public:
} }
// if we are a method, we have to correct the argument count since we are not counting // if we are a method, we have to correct the argument count since we are not counting
// the self argument // the self argument
const int correctingSelfArgument const size_t self_offset = static_cast<bool>(PyMethod_Check(src.ptr())) ? 1 : 0;
= static_cast<bool>(PyMethod_Check(src.ptr())) ? 1 : 0;
argCount -= correctingSelfArgument; argCount -= self_offset;
if (argCount != sizeof...(Args)) { if (argCount != sizeof...(Args)) {
return false; return false;
} }

View File

@ -107,13 +107,23 @@ def test_cpp_correct_overload_resolution():
def f(a): def f(a):
return a return a
class A:
def __call__(self, a):
return a
assert m.dummy_function_overloaded_std_func_arg(f) == 9 assert m.dummy_function_overloaded_std_func_arg(f) == 9
assert m.dummy_function_overloaded_std_func_arg(A()) == 9
assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9 assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9
def f2(a, b): def f2(a, b):
return a + b return a + b
class B:
def __call__(self, a, b):
return a + b
assert m.dummy_function_overloaded_std_func_arg(f2) == 14 assert m.dummy_function_overloaded_std_func_arg(f2) == 14
assert m.dummy_function_overloaded_std_func_arg(B()) == 14
assert m.dummy_function_overloaded_std_func_arg(lambda i, j: i + j) == 14 assert m.dummy_function_overloaded_std_func_arg(lambda i, j: i + j) == 14