From 6709abba934e968d31c24f20c132c67a6f424a6a Mon Sep 17 00:00:00 2001 From: Tamaki Nishino Date: Wed, 14 Apr 2021 08:53:56 +0900 Subject: [PATCH] Allow function pointer extraction from overloaded functions (#2944) * Add a failure test for overloaded functions * Allow function pointer extraction from overloaded functions --- include/pybind11/functional.h | 16 +++++++++++----- tests/test_callbacks.cpp | 2 ++ tests/test_callbacks.py | 4 ++++ 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/include/pybind11/functional.h b/include/pybind11/functional.h index 92c17dc22..aee9be4e4 100644 --- a/include/pybind11/functional.h +++ b/include/pybind11/functional.h @@ -46,11 +46,17 @@ public: auto c = reinterpret_borrow(PyCFunction_GET_SELF(cfunc.ptr())); auto rec = (function_record *) c; - if (rec && rec->is_stateless && - same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { - struct capture { function_type f; }; - value = ((capture *) &rec->data)->f; - return true; + while (rec != nullptr) { + if (rec->is_stateless + && same_type(typeid(function_type), + *reinterpret_cast(rec->data[1]))) { + struct capture { + function_type f; + }; + value = ((capture *) &rec->data)->f; + return true; + } + rec = rec->next; } } diff --git a/tests/test_callbacks.cpp b/tests/test_callbacks.cpp index 61bc3a8f0..33927f5c3 100644 --- a/tests/test_callbacks.cpp +++ b/tests/test_callbacks.cpp @@ -97,6 +97,8 @@ TEST_SUBMODULE(callbacks, m) { // test_cpp_function_roundtrip /* Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer */ m.def("dummy_function", &dummy_function); + m.def("dummy_function_overloaded", [](int i, int j) { return i + j; }); + m.def("dummy_function_overloaded", &dummy_function); m.def("dummy_function2", [](int i, int j) { return i + j; }); m.def("roundtrip", [](std::function f, bool expect_none = false) { if (expect_none && f) diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index cec68bda5..352661430 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -93,6 +93,10 @@ def test_cpp_function_roundtrip(): m.test_dummy_function(m.roundtrip(m.dummy_function)) == "matches dummy_function: eval(1) = 2" ) + assert ( + m.test_dummy_function(m.dummy_function_overloaded) + == "matches dummy_function: eval(1) = 2" + ) assert m.roundtrip(None, expect_none=True) is None assert ( m.test_dummy_function(lambda x: x + 2)