From a2f6fde0dcea0a4c66429e4bb262fc51079900c5 Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Thu, 1 Oct 2015 16:46:03 +0200 Subject: [PATCH] support for overriding virtual functions --- CMakeLists.txt | 1 + README.md | 1 + example/example.cpp | 2 + example/example12.cpp | 82 +++++++++++++++++++++++++++++++ example/example12.py | 31 ++++++++++++ example/example5.cpp | 35 +------------ example/example5.py | 18 ++----- include/pybind/cast.h | 3 ++ include/pybind/common.h | 34 ++++++++++--- include/pybind/functional.h | 2 - include/pybind/pybind.h | 98 +++++++++++++++++++++++++------------ include/pybind/pytypes.h | 6 ++- 12 files changed, 222 insertions(+), 91 deletions(-) create mode 100644 example/example12.cpp create mode 100644 example/example12.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 2e80bd014..96bb79598 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,6 +61,7 @@ add_library(example SHARED example/example9.cpp example/example10.cpp example/example11.cpp + example/example12.cpp ) set_target_properties(example PROPERTIES PREFIX "") diff --git a/README.md b/README.md index ead5b37d7..f802d7dea 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ The following core C++ features can be mapped to Python - STL data structures - Smart pointers with reference counting like `std::shared_ptr` - Internal references with correct reference counting +- C++ classes with virtual (and pure virtual) methods can be extended in Python ## Goodies In addition to the core functionality, pybind11 provides some extra goodies: diff --git a/example/example.cpp b/example/example.cpp index 2cc7e5056..9eea762e9 100644 --- a/example/example.cpp +++ b/example/example.cpp @@ -20,6 +20,7 @@ void init_ex8(py::module &); void init_ex9(py::module &); void init_ex10(py::module &); void init_ex11(py::module &); +void init_ex12(py::module &); PYTHON_PLUGIN(example) { py::module m("example", "pybind example plugin"); @@ -35,6 +36,7 @@ PYTHON_PLUGIN(example) { init_ex9(m); init_ex10(m); init_ex11(m); + init_ex12(m); return m.ptr(); } diff --git a/example/example12.cpp b/example/example12.cpp new file mode 100644 index 000000000..274edf832 --- /dev/null +++ b/example/example12.cpp @@ -0,0 +1,82 @@ +/* + example/example12.cpp -- overriding virtual functions from Python + + Copyright (c) 2015 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#include "example.h" +#include + +/* This is an example class that we'll want to be able to extend from Python */ +class Example12 { +public: + Example12(int state) : state(state) { + cout << "Constructing Example12.." << endl; + } + + ~Example12() { + cout << "Destructing Example12.." << endl; + } + + virtual int run(int value) { + std::cout << "Original implementation of Example12::run(state=" << state + << ", value=" << value << ")" << std::endl; + return state + value; + } + + virtual void pure_virtual() = 0; +private: + int state; +}; + +/* This is a wrapper class that must be generated */ +class PyExample12 : public Example12 { +public: + using Example12::Example12; /* Inherit constructors */ + + virtual int run(int value) { + /* Generate wrapping code that enables native function overloading */ + PYBIND_OVERLOAD( + int, /* Return type */ + Example12, /* Parent class */ + run, /* Name of function */ + value /* Argument(s) */ + ); + } + + virtual void pure_virtual() { + PYBIND_OVERLOAD_PURE( + void, /* Return type */ + Example12, /* Parent class */ + pure_virtual /* Name of function */ + /* This function has no arguments */ + ); + } +}; + +int runExample12(Example12 *ex, int value) { + return ex->run(value); +} + +void runExample12Virtual(Example12 *ex) { + ex->pure_virtual(); +} + +void init_ex12(py::module &m) { + /* Important: use the wrapper type as a template + argument to class_<>, but use the original name + to denote the type */ + py::class_(m, "Example12") + /* Declare that 'PyExample12' is really an alias for the original type 'Example12' */ + .alias() + .def(py::init()) + /* Reference original class in function definitions */ + .def("run", &Example12::run) + .def("pure_virtual", &Example12::pure_virtual); + + m.def("runExample12", &runExample12); + m.def("runExample12Virtual", &runExample12Virtual); +} diff --git a/example/example12.py b/example/example12.py new file mode 100644 index 000000000..4f785750b --- /dev/null +++ b/example/example12.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python +from __future__ import print_function +import sys +sys.path.append('.') + +from example import Example12, runExample12, runExample12Virtual + + +class ExtendedExample12(Example12): + def __init__(self, state): + super(ExtendedExample12, self).__init__(state + 1) + self.data = "Hello world" + + def run(self, value): + print('ExtendedExample12::run(%i), calling parent..' % value) + return super(ExtendedExample12, self).run(value + 1) + + def pure_virtual(self): + print('ExtendedExample12::pure_virtual(): %s' % self.data) + + +ex12 = Example12(10) +print(runExample12(ex12, 20)) +try: + runExample12Virtual(ex12) +except Exception as e: + print("Caught expected exception: " + str(e)) + +ex12p = ExtendedExample12(10) +print(runExample12(ex12p, 20)) +runExample12Virtual(ex12p) diff --git a/example/example5.cpp b/example/example5.cpp index f6de5ba20..91e0b4fde 100644 --- a/example/example5.cpp +++ b/example/example5.cpp @@ -37,28 +37,6 @@ void dog_bark(const Dog &dog) { dog.bark(); } -class Example5 { -public: - Example5(py::handle self, int state) - : self(self), state(state) { - cout << "Constructing Example5.." << endl; - } - - ~Example5() { - cout << "Destructing Example5.." << endl; - } - - void callback(int value) { - py::gil_scoped_acquire gil; - cout << "In Example5::callback() " << endl; - py::object method = self.attr("callback"); - method.call(state, value); - } -private: - py::handle self; - int state; -}; - bool test_callback1(py::object func) { func.call(); return false; @@ -69,16 +47,11 @@ int test_callback2(py::object func) { return result.cast(); } -void test_callback3(Example5 *ex, int value) { - py::gil_scoped_release gil; - ex->callback(value); -} - -void test_callback4(const std::function &func) { +void test_callback3(const std::function &func) { cout << "func(43) = " << func(43)<< std::endl; } -std::function test_callback5() { +std::function test_callback4() { return [](int i) { return i+1; }; } @@ -99,8 +72,4 @@ void init_ex5(py::module &m) { m.def("test_callback2", &test_callback2); m.def("test_callback3", &test_callback3); m.def("test_callback4", &test_callback4); - m.def("test_callback5", &test_callback5); - - py::class_(m, "Example5") - .def(py::init()); } diff --git a/example/example5.py b/example/example5.py index 4e75e1714..5aaaae7f4 100755 --- a/example/example5.py +++ b/example/example5.py @@ -24,29 +24,17 @@ from example import test_callback1 from example import test_callback2 from example import test_callback3 from example import test_callback4 -from example import test_callback5 -from example import Example5 def func1(): print('Callback function 1 called!') def func2(a, b, c, d): print('Callback function 2 called : ' + str(a) + ", " + str(b) + ", " + str(c) + ", "+ str(d)) - return c - -class MyCallback(Example5): - def __init__(self, value): - Example5.__init__(self, self, value) - - def callback(self, value1, value2): - print('got callback: %i %i' % (value1, value2)) + return d print(test_callback1(func1)) print(test_callback2(func2)) -callback = MyCallback(3) -test_callback3(callback, 4) - -test_callback4(lambda i: i+1) -f = test_callback5() +test_callback3(lambda i: i + 1) +f = test_callback4() print("func(43) = %i" % f(43)) diff --git a/include/pybind/cast.h b/include/pybind/cast.h index 2feaf84ee..b8eee92c4 100644 --- a/include/pybind/cast.h +++ b/include/pybind/cast.h @@ -601,6 +601,7 @@ template inline object cast(const T &value, return_value_policy pol } template inline T handle::cast() { return pybind::cast(m_ptr); } +template <> inline void handle::cast() { return; } template inline object handle::call(Args&&... args_) { const size_t size = sizeof...(Args); @@ -624,6 +625,8 @@ template inline object handle::call(Args&&... args_) { PyTuple_SetItem(tuple, counter++, result); PyObject *result = PyObject_CallObject(m_ptr, tuple); Py_DECREF(tuple); + if (result == nullptr && PyErr_Occurred()) + throw error_already_set(); return object(result, false); } diff --git a/include/pybind/common.h b/include/pybind/common.h index 3a95210b5..05750c653 100644 --- a/include/pybind/common.h +++ b/include/pybind/common.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -114,13 +115,6 @@ struct buffer_info { } }; -// C++ bindings of core Python exceptions -struct stop_iteration : public std::runtime_error { public: stop_iteration(const std::string &w="") : std::runtime_error(w) {} }; -struct index_error : public std::runtime_error { public: index_error(const std::string &w="") : std::runtime_error(w) {} }; -struct error_already_set : public std::exception { public: error_already_set() {} }; -/// Thrown when pybind::cast or handle::call fail due to a type casting error -struct cast_error : public std::runtime_error { public: cast_error(const std::string &w = "") : std::runtime_error(w) {} }; - NAMESPACE_BEGIN(detail) inline std::string error_string(); @@ -145,10 +139,19 @@ struct type_info { void *get_buffer_data = nullptr; }; +struct overload_hash { + inline std::size_t operator()(const std::pair& v) const { + size_t value = std::hash()(v.first); + value ^= std::hash()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2); + return value; + } +}; + /// Internal data struture used to track registered instances and types struct internals { std::unordered_map registered_types; - std::unordered_map registered_instances; + std::unordered_map registered_instances; + std::unordered_set, overload_hash> inactive_overload_cache; }; /// Return a reference to the current 'internals' information @@ -176,5 +179,20 @@ template struct decay { typedef typename deca /// Helper type to replace 'void' in some expressions struct void_type { }; +/// to_string variant which also accepts strings +template inline typename std::enable_if::value, std::string>::type +to_string(const T &value) { return std::to_string(value); } +template <> inline std::string to_string(const std::string &value) { return value; } +template inline typename std::enable_if::value, std::string>::type +to_string(T value) { return std::to_string((int) value); } + NAMESPACE_END(detail) + +// C++ bindings of core Python exceptions +struct stop_iteration : public std::runtime_error { public: stop_iteration(const std::string &w="") : std::runtime_error(w) {} }; +struct index_error : public std::runtime_error { public: index_error(const std::string &w="") : std::runtime_error(w) {} }; +struct error_already_set : public std::runtime_error { public: error_already_set() : std::runtime_error(detail::error_string()) {} }; +/// Thrown when pybind::cast or handle::call fail due to a type casting error +struct cast_error : public std::runtime_error { public: cast_error(const std::string &w = "") : std::runtime_error(w) {} }; + NAMESPACE_END(pybind) diff --git a/include/pybind/functional.h b/include/pybind/functional.h index f300d4d4d..7c3c9a006 100644 --- a/include/pybind/functional.h +++ b/include/pybind/functional.h @@ -25,8 +25,6 @@ public: object src(src_, true); value = [src](Args... args) -> Return { object retval(pybind::handle(src).call(std::move(args)...)); - if (retval.ptr() == nullptr && PyErr_Occurred()) - throw error_already_set(); /* Visual studio 2015 parser issue: need parentheses around this expression */ return (retval.template cast()); }; diff --git a/include/pybind/pybind.h b/include/pybind/pybind.h index 156e5705c..699489c32 100644 --- a/include/pybind/pybind.h +++ b/include/pybind/pybind.h @@ -24,7 +24,6 @@ #endif #include -#include NAMESPACE_BEGIN(pybind) @@ -46,12 +45,8 @@ template inline arg_t arg::operator=(const T &value) { return ar /// Annotation for methods struct is_method { -#if PY_MAJOR_VERSION < 3 PyObject *class_; is_method(object *o) : class_(o->ptr()) { } -#else - is_method(object *) { } -#endif }; /// Annotation for documentation @@ -76,9 +71,7 @@ private: short keywords = 0; return_value_policy policy = return_value_policy::automatic; std::string signature; -#if PY_MAJOR_VERSION < 3 PyObject *class_ = nullptr; -#endif PyObject *sibling = nullptr; const char *doc = nullptr; function_entry *next = nullptr; @@ -126,21 +119,18 @@ private: kw[entry->keywords++] = "self"; kw[entry->keywords++] = a.name; } + template static void process_extra(const pybind::arg_t &a, function_entry *entry, const char **kw, const char **def) { if (entry->is_method && entry->keywords == 0) kw[entry->keywords++] = "self"; kw[entry->keywords] = a.name; - def[entry->keywords++] = strdup(std::to_string(a.value).c_str()); + def[entry->keywords++] = strdup(detail::to_string(a.value).c_str()); } static void process_extra(const pybind::is_method &m, function_entry *entry, const char **, const char **) { entry->is_method = true; -#if PY_MAJOR_VERSION < 3 entry->class_ = m.class_; -#else - (void) m; -#endif } static void process_extra(const pybind::return_value_policy p, function_entry *entry, const char **, const char **) { entry->policy = p; } static void process_extra(pybind::sibling s, function_entry *entry, const char **, const char **) { entry->sibling = s.value; } @@ -366,35 +356,38 @@ private: m_entry->sibling = PyMethod_GET_FUNCTION(m_entry->sibling); #endif - function_entry *entry = m_entry; - bool overloaded = false; - if (!entry->sibling || !PyCFunction_Check(entry->sibling)) { - entry->def = new PyMethodDef(); - memset(entry->def, 0, sizeof(PyMethodDef)); - entry->def->ml_name = entry->name; - entry->def->ml_meth = reinterpret_cast(*dispatcher); - entry->def->ml_flags = METH_VARARGS | METH_KEYWORDS; - capsule entry_capsule(entry, [](PyObject *o) { destruct((function_entry *) PyCapsule_GetPointer(o, nullptr)); }); - m_ptr = PyCFunction_New(entry->def, entry_capsule.ptr()); + function_entry *s_entry = nullptr, *entry = m_entry; + if (m_entry->sibling && PyCFunction_Check(m_entry->sibling)) { + capsule entry_capsule(PyCFunction_GetSelf(m_entry->sibling), true); + s_entry = (function_entry *) entry_capsule; + if (s_entry->class_ != m_entry->class_) + s_entry = nullptr; /* Method override */ + } + + if (!s_entry) { + m_entry->def = new PyMethodDef(); + memset(m_entry->def, 0, sizeof(PyMethodDef)); + m_entry->def->ml_name = m_entry->name; + m_entry->def->ml_meth = reinterpret_cast(*dispatcher); + m_entry->def->ml_flags = METH_VARARGS | METH_KEYWORDS; + capsule entry_capsule(m_entry, [](PyObject *o) { destruct((function_entry *) PyCapsule_GetPointer(o, nullptr)); }); + m_ptr = PyCFunction_New(m_entry->def, entry_capsule.ptr()); if (!m_ptr) throw std::runtime_error("cpp_function::cpp_function(): Could not allocate function object"); } else { - m_ptr = entry->sibling; + m_ptr = m_entry->sibling; inc_ref(); - capsule entry_capsule(PyCFunction_GetSelf(m_ptr), true); - function_entry *parent = (function_entry *) entry_capsule, *backup = parent; - while (parent->next) - parent = parent->next; - parent->next = entry; - entry = backup; - overloaded = true; + entry = s_entry; + while (s_entry->next) + s_entry = s_entry->next; + s_entry->next = m_entry; } std::string signatures; int index = 0; function_entry *it = entry; while (it) { /* Create pydoc it */ - if (overloaded) + if (s_entry) signatures += std::to_string(++index) + ". "; signatures += "Signature : " + std::string(it->signature) + "\n"; if (it->doc && strlen(it->doc) > 0) @@ -783,6 +776,12 @@ public: metaclass().attr(name) = property; return *this; } + + template class_ alias() { + auto &instances = pybind::detail::get_internals().registered_types; + instances[&typeid(target)] = instances[&typeid(type)]; + return *this; + } private: static void init_holder(PyObject *inst_) { instance_type *inst = (instance_type *) inst_; @@ -882,6 +881,43 @@ public: inline ~gil_scoped_release() { PyEval_RestoreThread(state); } }; +inline function get_overload(const void *this_ptr, const char *name) { + handle py_object = detail::get_object_handle(this_ptr); + handle type = py_object.get_type(); + auto key = std::make_pair(type.ptr(), name); + + /* Cache functions that aren't overloaded in python to avoid + many costly dictionary lookups in Python */ + auto &cache = detail::get_internals().inactive_overload_cache; + if (cache.find(key) != cache.end()) + return function(); + + function overload = (function) py_object.attr(name); + if (overload.is_cpp_function()) { + cache.insert(key); + return function(); + } + PyFrameObject *frame = PyThreadState_Get()->frame; + pybind::str caller = pybind::handle(frame->f_code->co_name).str(); + if (strcmp((const char *) caller, name) == 0) + return function(); + return overload; +} + +#define PYBIND_OVERLOAD_INT(ret_type, class_name, name, ...) { \ + pybind::gil_scoped_acquire gil; \ + pybind::function overload = pybind::get_overload(this, #name); \ + if (overload) \ + return overload.call(__VA_ARGS__).cast(); } + +#define PYBIND_OVERLOAD(ret_type, class_name, name, ...) \ + PYBIND_OVERLOAD_INT(ret_type, class_name, name, __VA_ARGS__) \ + return class_name::name(__VA_ARGS__) + +#define PYBIND_OVERLOAD_PURE(ret_type, class_name, name, ...) \ + PYBIND_OVERLOAD_INT(ret_type, class_name, name, __VA_ARGS__) \ + throw std::runtime_error("Tried to call pure virtual function \"" #name "\""); + NAMESPACE_END(pybind) #if defined(_MSC_VER) diff --git a/include/pybind/pytypes.h b/include/pybind/pytypes.h index 4559ffbcb..8afc08877 100644 --- a/include/pybind/pytypes.h +++ b/include/pybind/pytypes.h @@ -331,10 +331,12 @@ public: PyObject *ptr = m_ptr; if (ptr == nullptr) return false; -#if PY_MAJOR_VERSION < 3 +#if PY_MAJOR_VERSION >= 3 + if (PyInstanceMethod_Check(ptr)) + ptr = PyInstanceMethod_GET_FUNCTION(ptr); +#endif if (PyMethod_Check(ptr)) ptr = PyMethod_GET_FUNCTION(ptr); -#endif return PyCFunction_Check(ptr); } };