From a01b6b805c97a66a796744e589b248b559086706 Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Mon, 24 Apr 2017 12:29:42 -0400 Subject: [PATCH] functional: support bound methods If a bound std::function is invoked with a bound method, the implicit bound self is lost because we use `detail::get_function` to unbox the function. This commit amends the code to use py::function and only unboxes in the special is-really-a-c-function case. This makes bound methods stay bound rather than unbinding them by forcing extraction of the c function. --- include/pybind11/functional.h | 18 +++++++++--------- include/pybind11/pytypes.h | 8 ++++++-- tests/test_callbacks.cpp | 5 +++++ tests/test_callbacks.py | 15 +++++++++++++++ 4 files changed, 35 insertions(+), 11 deletions(-) diff --git a/include/pybind11/functional.h b/include/pybind11/functional.h index a99ee737f..ab9e1c3c5 100644 --- a/include/pybind11/functional.h +++ b/include/pybind11/functional.h @@ -22,14 +22,15 @@ struct type_caster> { using function_type = Return (*) (Args...); public: - bool load(handle src_, bool) { - if (src_.is_none()) + bool load(handle src, bool) { + if (src.is_none()) return true; - src_ = detail::get_function(src_); - if (!src_ || !PyCallable_Check(src_.ptr())) + if (!isinstance(src)) return false; + auto func = reinterpret_borrow(src); + /* When passing a C++ function as an argument to another C++ function via Python, every function call would normally involve @@ -38,8 +39,8 @@ public: stateless (i.e. function pointer or lambda function without captured variables), in which case the roundtrip can be avoided. */ - if (PyCFunction_Check(src_.ptr())) { - auto c = reinterpret_borrow(PyCFunction_GET_SELF(src_.ptr())); + if (auto cfunc = func.cpp_function()) { + auto c = reinterpret_borrow(PyCFunction_GET_SELF(cfunc.ptr())); auto rec = (function_record *) c; if (rec && rec->is_stateless && rec->data[1] == &typeid(function_type)) { @@ -49,10 +50,9 @@ public: } } - auto src = reinterpret_borrow(src_); - value = [src](Args... args) -> Return { + value = [func](Args... args) -> Return { gil_scoped_acquire acq; - object retval(src(std::forward(args)...)); + object retval(func(std::forward(args)...)); /* Visual studio 2015 parser issue: need parentheses around this expression */ return (retval.template cast()); }; diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index b7317df96..26f5ae409 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -355,6 +355,7 @@ inline handle get_function(handle value) { #if PY_MAJOR_VERSION >= 3 if (PyInstanceMethod_Check(value.ptr())) value = PyInstanceMethod_GET_FUNCTION(value.ptr()); + else #endif if (PyMethod_Check(value.ptr())) value = PyMethod_GET_FUNCTION(value.ptr()); @@ -1133,10 +1134,13 @@ public: class function : public object { public: PYBIND11_OBJECT_DEFAULT(function, object, PyCallable_Check) - bool is_cpp_function() const { + handle cpp_function() const { handle fun = detail::get_function(m_ptr); - return fun && PyCFunction_Check(fun.ptr()); + if (fun && PyCFunction_Check(fun.ptr())) + return fun; + return handle(); } + bool is_cpp_function() const { return (bool) cpp_function(); } }; class buffer : public object { diff --git a/tests/test_callbacks.cpp b/tests/test_callbacks.cpp index f89cc1c79..41110087c 100644 --- a/tests/test_callbacks.cpp +++ b/tests/test_callbacks.cpp @@ -179,4 +179,9 @@ test_initializer callbacks([](py::module &m) { f(x); // lvalue reference shouldn't move out object return x.valid; // must still return `true` }); + + struct CppBoundMethodTest {}; + py::class_(m, "CppBoundMethodTest") + .def(py::init<>()) + .def("triple", [](CppBoundMethodTest &, int val) { return 3 * val; }); }); diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index f94e7b64c..a5109d03c 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -27,6 +27,21 @@ def test_callbacks(): assert f(number=43) == 44 +def test_bound_method_callback(): + from pybind11_tests import test_callback3, CppBoundMethodTest + + # Bound Python method: + class MyClass: + def double(self, val): + return 2 * val + + z = MyClass() + assert test_callback3(z.double) == "func(43) = 86" + + z = CppBoundMethodTest() + assert test_callback3(z.triple) == "func(43) = 129" + + def test_keyword_args_and_generalized_unpacking(): from pybind11_tests import (test_tuple_unpacking, test_dict_unpacking, test_keyword_args, test_unpacking_and_keywords1, test_unpacking_and_keywords2,