Allow function pointer extraction from overloaded functions (#2944)

* Add a failure test for overloaded functions

* Allow function pointer extraction from overloaded functions
This commit is contained in:
Tamaki Nishino 2021-04-14 08:53:56 +09:00 committed by GitHub
parent e0c1dadb75
commit 6709abba93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 5 deletions

View File

@ -46,11 +46,17 @@ public:
auto c = reinterpret_borrow<capsule>(PyCFunction_GET_SELF(cfunc.ptr())); auto c = reinterpret_borrow<capsule>(PyCFunction_GET_SELF(cfunc.ptr()));
auto rec = (function_record *) c; auto rec = (function_record *) c;
if (rec && rec->is_stateless && while (rec != nullptr) {
same_type(typeid(function_type), *reinterpret_cast<const std::type_info *>(rec->data[1]))) { if (rec->is_stateless
struct capture { function_type f; }; && same_type(typeid(function_type),
value = ((capture *) &rec->data)->f; *reinterpret_cast<const std::type_info *>(rec->data[1]))) {
return true; struct capture {
function_type f;
};
value = ((capture *) &rec->data)->f;
return true;
}
rec = rec->next;
} }
} }

View File

@ -97,6 +97,8 @@ TEST_SUBMODULE(callbacks, m) {
// test_cpp_function_roundtrip // test_cpp_function_roundtrip
/* Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer */ /* Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer */
m.def("dummy_function", &dummy_function); m.def("dummy_function", &dummy_function);
m.def("dummy_function_overloaded", [](int i, int j) { return i + j; });
m.def("dummy_function_overloaded", &dummy_function);
m.def("dummy_function2", [](int i, int j) { return i + j; }); m.def("dummy_function2", [](int i, int j) { return i + j; });
m.def("roundtrip", [](std::function<int(int)> f, bool expect_none = false) { m.def("roundtrip", [](std::function<int(int)> f, bool expect_none = false) {
if (expect_none && f) if (expect_none && f)

View File

@ -93,6 +93,10 @@ def test_cpp_function_roundtrip():
m.test_dummy_function(m.roundtrip(m.dummy_function)) m.test_dummy_function(m.roundtrip(m.dummy_function))
== "matches dummy_function: eval(1) = 2" == "matches dummy_function: eval(1) = 2"
) )
assert (
m.test_dummy_function(m.dummy_function_overloaded)
== "matches dummy_function: eval(1) = 2"
)
assert m.roundtrip(None, expect_none=True) is None assert m.roundtrip(None, expect_none=True) is None
assert ( assert (
m.test_dummy_function(lambda x: x + 2) m.test_dummy_function(lambda x: x + 2)