mirror of
https://github.com/pybind/pybind11.git
synced 2025-01-18 17:05:53 +00:00
support for overriding virtual functions
This commit is contained in:
parent
04358b02ed
commit
a2f6fde0dc
@ -61,6 +61,7 @@ add_library(example SHARED
|
||||
example/example9.cpp
|
||||
example/example10.cpp
|
||||
example/example11.cpp
|
||||
example/example12.cpp
|
||||
)
|
||||
|
||||
set_target_properties(example PROPERTIES PREFIX "")
|
||||
|
@ -38,6 +38,7 @@ The following core C++ features can be mapped to Python
|
||||
- STL data structures
|
||||
- Smart pointers with reference counting like `std::shared_ptr`
|
||||
- Internal references with correct reference counting
|
||||
- C++ classes with virtual (and pure virtual) methods can be extended in Python
|
||||
|
||||
## Goodies
|
||||
In addition to the core functionality, pybind11 provides some extra goodies:
|
||||
|
@ -20,6 +20,7 @@ void init_ex8(py::module &);
|
||||
void init_ex9(py::module &);
|
||||
void init_ex10(py::module &);
|
||||
void init_ex11(py::module &);
|
||||
void init_ex12(py::module &);
|
||||
|
||||
PYTHON_PLUGIN(example) {
|
||||
py::module m("example", "pybind example plugin");
|
||||
@ -35,6 +36,7 @@ PYTHON_PLUGIN(example) {
|
||||
init_ex9(m);
|
||||
init_ex10(m);
|
||||
init_ex11(m);
|
||||
init_ex12(m);
|
||||
|
||||
return m.ptr();
|
||||
}
|
||||
|
82
example/example12.cpp
Normal file
82
example/example12.cpp
Normal file
@ -0,0 +1,82 @@
|
||||
/*
|
||||
example/example12.cpp -- overriding virtual functions from Python
|
||||
|
||||
Copyright (c) 2015 Wenzel Jakob <wenzel@inf.ethz.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#include "example.h"
|
||||
#include <pybind/functional.h>
|
||||
|
||||
/* This is an example class that we'll want to be able to extend from Python */
|
||||
class Example12 {
|
||||
public:
|
||||
Example12(int state) : state(state) {
|
||||
cout << "Constructing Example12.." << endl;
|
||||
}
|
||||
|
||||
~Example12() {
|
||||
cout << "Destructing Example12.." << endl;
|
||||
}
|
||||
|
||||
virtual int run(int value) {
|
||||
std::cout << "Original implementation of Example12::run(state=" << state
|
||||
<< ", value=" << value << ")" << std::endl;
|
||||
return state + value;
|
||||
}
|
||||
|
||||
virtual void pure_virtual() = 0;
|
||||
private:
|
||||
int state;
|
||||
};
|
||||
|
||||
/* This is a wrapper class that must be generated */
|
||||
class PyExample12 : public Example12 {
|
||||
public:
|
||||
using Example12::Example12; /* Inherit constructors */
|
||||
|
||||
virtual int run(int value) {
|
||||
/* Generate wrapping code that enables native function overloading */
|
||||
PYBIND_OVERLOAD(
|
||||
int, /* Return type */
|
||||
Example12, /* Parent class */
|
||||
run, /* Name of function */
|
||||
value /* Argument(s) */
|
||||
);
|
||||
}
|
||||
|
||||
virtual void pure_virtual() {
|
||||
PYBIND_OVERLOAD_PURE(
|
||||
void, /* Return type */
|
||||
Example12, /* Parent class */
|
||||
pure_virtual /* Name of function */
|
||||
/* This function has no arguments */
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
int runExample12(Example12 *ex, int value) {
|
||||
return ex->run(value);
|
||||
}
|
||||
|
||||
void runExample12Virtual(Example12 *ex) {
|
||||
ex->pure_virtual();
|
||||
}
|
||||
|
||||
void init_ex12(py::module &m) {
|
||||
/* Important: use the wrapper type as a template
|
||||
argument to class_<>, but use the original name
|
||||
to denote the type */
|
||||
py::class_<PyExample12>(m, "Example12")
|
||||
/* Declare that 'PyExample12' is really an alias for the original type 'Example12' */
|
||||
.alias<Example12>()
|
||||
.def(py::init<int>())
|
||||
/* Reference original class in function definitions */
|
||||
.def("run", &Example12::run)
|
||||
.def("pure_virtual", &Example12::pure_virtual);
|
||||
|
||||
m.def("runExample12", &runExample12);
|
||||
m.def("runExample12Virtual", &runExample12Virtual);
|
||||
}
|
31
example/example12.py
Normal file
31
example/example12.py
Normal file
@ -0,0 +1,31 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import print_function
|
||||
import sys
|
||||
sys.path.append('.')
|
||||
|
||||
from example import Example12, runExample12, runExample12Virtual
|
||||
|
||||
|
||||
class ExtendedExample12(Example12):
|
||||
def __init__(self, state):
|
||||
super(ExtendedExample12, self).__init__(state + 1)
|
||||
self.data = "Hello world"
|
||||
|
||||
def run(self, value):
|
||||
print('ExtendedExample12::run(%i), calling parent..' % value)
|
||||
return super(ExtendedExample12, self).run(value + 1)
|
||||
|
||||
def pure_virtual(self):
|
||||
print('ExtendedExample12::pure_virtual(): %s' % self.data)
|
||||
|
||||
|
||||
ex12 = Example12(10)
|
||||
print(runExample12(ex12, 20))
|
||||
try:
|
||||
runExample12Virtual(ex12)
|
||||
except Exception as e:
|
||||
print("Caught expected exception: " + str(e))
|
||||
|
||||
ex12p = ExtendedExample12(10)
|
||||
print(runExample12(ex12p, 20))
|
||||
runExample12Virtual(ex12p)
|
@ -37,28 +37,6 @@ void dog_bark(const Dog &dog) {
|
||||
dog.bark();
|
||||
}
|
||||
|
||||
class Example5 {
|
||||
public:
|
||||
Example5(py::handle self, int state)
|
||||
: self(self), state(state) {
|
||||
cout << "Constructing Example5.." << endl;
|
||||
}
|
||||
|
||||
~Example5() {
|
||||
cout << "Destructing Example5.." << endl;
|
||||
}
|
||||
|
||||
void callback(int value) {
|
||||
py::gil_scoped_acquire gil;
|
||||
cout << "In Example5::callback() " << endl;
|
||||
py::object method = self.attr("callback");
|
||||
method.call(state, value);
|
||||
}
|
||||
private:
|
||||
py::handle self;
|
||||
int state;
|
||||
};
|
||||
|
||||
bool test_callback1(py::object func) {
|
||||
func.call();
|
||||
return false;
|
||||
@ -69,16 +47,11 @@ int test_callback2(py::object func) {
|
||||
return result.cast<int>();
|
||||
}
|
||||
|
||||
void test_callback3(Example5 *ex, int value) {
|
||||
py::gil_scoped_release gil;
|
||||
ex->callback(value);
|
||||
}
|
||||
|
||||
void test_callback4(const std::function<int(int)> &func) {
|
||||
void test_callback3(const std::function<int(int)> &func) {
|
||||
cout << "func(43) = " << func(43)<< std::endl;
|
||||
}
|
||||
|
||||
std::function<int(int)> test_callback5() {
|
||||
std::function<int(int)> test_callback4() {
|
||||
return [](int i) { return i+1; };
|
||||
}
|
||||
|
||||
@ -99,8 +72,4 @@ void init_ex5(py::module &m) {
|
||||
m.def("test_callback2", &test_callback2);
|
||||
m.def("test_callback3", &test_callback3);
|
||||
m.def("test_callback4", &test_callback4);
|
||||
m.def("test_callback5", &test_callback5);
|
||||
|
||||
py::class_<Example5>(m, "Example5")
|
||||
.def(py::init<py::object, int>());
|
||||
}
|
||||
|
@ -24,29 +24,17 @@ from example import test_callback1
|
||||
from example import test_callback2
|
||||
from example import test_callback3
|
||||
from example import test_callback4
|
||||
from example import test_callback5
|
||||
from example import Example5
|
||||
|
||||
def func1():
|
||||
print('Callback function 1 called!')
|
||||
|
||||
def func2(a, b, c, d):
|
||||
print('Callback function 2 called : ' + str(a) + ", " + str(b) + ", " + str(c) + ", "+ str(d))
|
||||
return c
|
||||
|
||||
class MyCallback(Example5):
|
||||
def __init__(self, value):
|
||||
Example5.__init__(self, self, value)
|
||||
|
||||
def callback(self, value1, value2):
|
||||
print('got callback: %i %i' % (value1, value2))
|
||||
return d
|
||||
|
||||
print(test_callback1(func1))
|
||||
print(test_callback2(func2))
|
||||
|
||||
callback = MyCallback(3)
|
||||
test_callback3(callback, 4)
|
||||
|
||||
test_callback4(lambda i: i+1)
|
||||
f = test_callback5()
|
||||
test_callback3(lambda i: i + 1)
|
||||
f = test_callback4()
|
||||
print("func(43) = %i" % f(43))
|
||||
|
@ -601,6 +601,7 @@ template <typename T> inline object cast(const T &value, return_value_policy pol
|
||||
}
|
||||
|
||||
template <typename T> inline T handle::cast() { return pybind::cast<T>(m_ptr); }
|
||||
template <> inline void handle::cast() { return; }
|
||||
|
||||
template <typename... Args> inline object handle::call(Args&&... args_) {
|
||||
const size_t size = sizeof...(Args);
|
||||
@ -624,6 +625,8 @@ template <typename... Args> inline object handle::call(Args&&... args_) {
|
||||
PyTuple_SetItem(tuple, counter++, result);
|
||||
PyObject *result = PyObject_CallObject(m_ptr, tuple);
|
||||
Py_DECREF(tuple);
|
||||
if (result == nullptr && PyErr_Occurred())
|
||||
throw error_already_set();
|
||||
return object(result, false);
|
||||
}
|
||||
|
||||
|
@ -27,6 +27,7 @@
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
|
||||
@ -114,13 +115,6 @@ struct buffer_info {
|
||||
}
|
||||
};
|
||||
|
||||
// C++ bindings of core Python exceptions
|
||||
struct stop_iteration : public std::runtime_error { public: stop_iteration(const std::string &w="") : std::runtime_error(w) {} };
|
||||
struct index_error : public std::runtime_error { public: index_error(const std::string &w="") : std::runtime_error(w) {} };
|
||||
struct error_already_set : public std::exception { public: error_already_set() {} };
|
||||
/// Thrown when pybind::cast or handle::call fail due to a type casting error
|
||||
struct cast_error : public std::runtime_error { public: cast_error(const std::string &w = "") : std::runtime_error(w) {} };
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
inline std::string error_string();
|
||||
@ -145,10 +139,19 @@ struct type_info {
|
||||
void *get_buffer_data = nullptr;
|
||||
};
|
||||
|
||||
struct overload_hash {
|
||||
inline std::size_t operator()(const std::pair<const PyObject *, const char *>& v) const {
|
||||
size_t value = std::hash<const void *>()(v.first);
|
||||
value ^= std::hash<const void *>()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2);
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
/// Internal data struture used to track registered instances and types
|
||||
struct internals {
|
||||
std::unordered_map<const std::type_info *, type_info> registered_types;
|
||||
std::unordered_map<void *, PyObject *> registered_instances;
|
||||
std::unordered_map<const void *, PyObject *> registered_instances;
|
||||
std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
|
||||
};
|
||||
|
||||
/// Return a reference to the current 'internals' information
|
||||
@ -176,5 +179,20 @@ template <typename T, size_t N> struct decay<T[N]> { typedef typename deca
|
||||
/// Helper type to replace 'void' in some expressions
|
||||
struct void_type { };
|
||||
|
||||
/// to_string variant which also accepts strings
|
||||
template <typename T> inline typename std::enable_if<!std::is_enum<T>::value, std::string>::type
|
||||
to_string(const T &value) { return std::to_string(value); }
|
||||
template <> inline std::string to_string(const std::string &value) { return value; }
|
||||
template <typename T> inline typename std::enable_if<std::is_enum<T>::value, std::string>::type
|
||||
to_string(T value) { return std::to_string((int) value); }
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
// C++ bindings of core Python exceptions
|
||||
struct stop_iteration : public std::runtime_error { public: stop_iteration(const std::string &w="") : std::runtime_error(w) {} };
|
||||
struct index_error : public std::runtime_error { public: index_error(const std::string &w="") : std::runtime_error(w) {} };
|
||||
struct error_already_set : public std::runtime_error { public: error_already_set() : std::runtime_error(detail::error_string()) {} };
|
||||
/// Thrown when pybind::cast or handle::call fail due to a type casting error
|
||||
struct cast_error : public std::runtime_error { public: cast_error(const std::string &w = "") : std::runtime_error(w) {} };
|
||||
|
||||
NAMESPACE_END(pybind)
|
||||
|
@ -25,8 +25,6 @@ public:
|
||||
object src(src_, true);
|
||||
value = [src](Args... args) -> Return {
|
||||
object retval(pybind::handle(src).call(std::move(args)...));
|
||||
if (retval.ptr() == nullptr && PyErr_Occurred())
|
||||
throw error_already_set();
|
||||
/* Visual studio 2015 parser issue: need parentheses around this expression */
|
||||
return (retval.template cast<Return>());
|
||||
};
|
||||
|
@ -24,7 +24,6 @@
|
||||
#endif
|
||||
|
||||
#include <pybind/cast.h>
|
||||
#include <iostream>
|
||||
|
||||
NAMESPACE_BEGIN(pybind)
|
||||
|
||||
@ -46,12 +45,8 @@ template <typename T> inline arg_t<T> arg::operator=(const T &value) { return ar
|
||||
|
||||
/// Annotation for methods
|
||||
struct is_method {
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
PyObject *class_;
|
||||
is_method(object *o) : class_(o->ptr()) { }
|
||||
#else
|
||||
is_method(object *) { }
|
||||
#endif
|
||||
};
|
||||
|
||||
/// Annotation for documentation
|
||||
@ -76,9 +71,7 @@ private:
|
||||
short keywords = 0;
|
||||
return_value_policy policy = return_value_policy::automatic;
|
||||
std::string signature;
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
PyObject *class_ = nullptr;
|
||||
#endif
|
||||
PyObject *sibling = nullptr;
|
||||
const char *doc = nullptr;
|
||||
function_entry *next = nullptr;
|
||||
@ -126,21 +119,18 @@ private:
|
||||
kw[entry->keywords++] = "self";
|
||||
kw[entry->keywords++] = a.name;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void process_extra(const pybind::arg_t<T> &a, function_entry *entry, const char **kw, const char **def) {
|
||||
if (entry->is_method && entry->keywords == 0)
|
||||
kw[entry->keywords++] = "self";
|
||||
kw[entry->keywords] = a.name;
|
||||
def[entry->keywords++] = strdup(std::to_string(a.value).c_str());
|
||||
def[entry->keywords++] = strdup(detail::to_string(a.value).c_str());
|
||||
}
|
||||
|
||||
static void process_extra(const pybind::is_method &m, function_entry *entry, const char **, const char **) {
|
||||
entry->is_method = true;
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
entry->class_ = m.class_;
|
||||
#else
|
||||
(void) m;
|
||||
#endif
|
||||
}
|
||||
static void process_extra(const pybind::return_value_policy p, function_entry *entry, const char **, const char **) { entry->policy = p; }
|
||||
static void process_extra(pybind::sibling s, function_entry *entry, const char **, const char **) { entry->sibling = s.value; }
|
||||
@ -366,35 +356,38 @@ private:
|
||||
m_entry->sibling = PyMethod_GET_FUNCTION(m_entry->sibling);
|
||||
#endif
|
||||
|
||||
function_entry *entry = m_entry;
|
||||
bool overloaded = false;
|
||||
if (!entry->sibling || !PyCFunction_Check(entry->sibling)) {
|
||||
entry->def = new PyMethodDef();
|
||||
memset(entry->def, 0, sizeof(PyMethodDef));
|
||||
entry->def->ml_name = entry->name;
|
||||
entry->def->ml_meth = reinterpret_cast<PyCFunction>(*dispatcher);
|
||||
entry->def->ml_flags = METH_VARARGS | METH_KEYWORDS;
|
||||
capsule entry_capsule(entry, [](PyObject *o) { destruct((function_entry *) PyCapsule_GetPointer(o, nullptr)); });
|
||||
m_ptr = PyCFunction_New(entry->def, entry_capsule.ptr());
|
||||
function_entry *s_entry = nullptr, *entry = m_entry;
|
||||
if (m_entry->sibling && PyCFunction_Check(m_entry->sibling)) {
|
||||
capsule entry_capsule(PyCFunction_GetSelf(m_entry->sibling), true);
|
||||
s_entry = (function_entry *) entry_capsule;
|
||||
if (s_entry->class_ != m_entry->class_)
|
||||
s_entry = nullptr; /* Method override */
|
||||
}
|
||||
|
||||
if (!s_entry) {
|
||||
m_entry->def = new PyMethodDef();
|
||||
memset(m_entry->def, 0, sizeof(PyMethodDef));
|
||||
m_entry->def->ml_name = m_entry->name;
|
||||
m_entry->def->ml_meth = reinterpret_cast<PyCFunction>(*dispatcher);
|
||||
m_entry->def->ml_flags = METH_VARARGS | METH_KEYWORDS;
|
||||
capsule entry_capsule(m_entry, [](PyObject *o) { destruct((function_entry *) PyCapsule_GetPointer(o, nullptr)); });
|
||||
m_ptr = PyCFunction_New(m_entry->def, entry_capsule.ptr());
|
||||
if (!m_ptr)
|
||||
throw std::runtime_error("cpp_function::cpp_function(): Could not allocate function object");
|
||||
} else {
|
||||
m_ptr = entry->sibling;
|
||||
m_ptr = m_entry->sibling;
|
||||
inc_ref();
|
||||
capsule entry_capsule(PyCFunction_GetSelf(m_ptr), true);
|
||||
function_entry *parent = (function_entry *) entry_capsule, *backup = parent;
|
||||
while (parent->next)
|
||||
parent = parent->next;
|
||||
parent->next = entry;
|
||||
entry = backup;
|
||||
overloaded = true;
|
||||
entry = s_entry;
|
||||
while (s_entry->next)
|
||||
s_entry = s_entry->next;
|
||||
s_entry->next = m_entry;
|
||||
}
|
||||
|
||||
std::string signatures;
|
||||
int index = 0;
|
||||
function_entry *it = entry;
|
||||
while (it) { /* Create pydoc it */
|
||||
if (overloaded)
|
||||
if (s_entry)
|
||||
signatures += std::to_string(++index) + ". ";
|
||||
signatures += "Signature : " + std::string(it->signature) + "\n";
|
||||
if (it->doc && strlen(it->doc) > 0)
|
||||
@ -783,6 +776,12 @@ public:
|
||||
metaclass().attr(name) = property;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename target> class_ alias() {
|
||||
auto &instances = pybind::detail::get_internals().registered_types;
|
||||
instances[&typeid(target)] = instances[&typeid(type)];
|
||||
return *this;
|
||||
}
|
||||
private:
|
||||
static void init_holder(PyObject *inst_) {
|
||||
instance_type *inst = (instance_type *) inst_;
|
||||
@ -882,6 +881,43 @@ public:
|
||||
inline ~gil_scoped_release() { PyEval_RestoreThread(state); }
|
||||
};
|
||||
|
||||
inline function get_overload(const void *this_ptr, const char *name) {
|
||||
handle py_object = detail::get_object_handle(this_ptr);
|
||||
handle type = py_object.get_type();
|
||||
auto key = std::make_pair(type.ptr(), name);
|
||||
|
||||
/* Cache functions that aren't overloaded in python to avoid
|
||||
many costly dictionary lookups in Python */
|
||||
auto &cache = detail::get_internals().inactive_overload_cache;
|
||||
if (cache.find(key) != cache.end())
|
||||
return function();
|
||||
|
||||
function overload = (function) py_object.attr(name);
|
||||
if (overload.is_cpp_function()) {
|
||||
cache.insert(key);
|
||||
return function();
|
||||
}
|
||||
PyFrameObject *frame = PyThreadState_Get()->frame;
|
||||
pybind::str caller = pybind::handle(frame->f_code->co_name).str();
|
||||
if (strcmp((const char *) caller, name) == 0)
|
||||
return function();
|
||||
return overload;
|
||||
}
|
||||
|
||||
#define PYBIND_OVERLOAD_INT(ret_type, class_name, name, ...) { \
|
||||
pybind::gil_scoped_acquire gil; \
|
||||
pybind::function overload = pybind::get_overload(this, #name); \
|
||||
if (overload) \
|
||||
return overload.call(__VA_ARGS__).cast<ret_type>(); }
|
||||
|
||||
#define PYBIND_OVERLOAD(ret_type, class_name, name, ...) \
|
||||
PYBIND_OVERLOAD_INT(ret_type, class_name, name, __VA_ARGS__) \
|
||||
return class_name::name(__VA_ARGS__)
|
||||
|
||||
#define PYBIND_OVERLOAD_PURE(ret_type, class_name, name, ...) \
|
||||
PYBIND_OVERLOAD_INT(ret_type, class_name, name, __VA_ARGS__) \
|
||||
throw std::runtime_error("Tried to call pure virtual function \"" #name "\"");
|
||||
|
||||
NAMESPACE_END(pybind)
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
|
@ -331,10 +331,12 @@ public:
|
||||
PyObject *ptr = m_ptr;
|
||||
if (ptr == nullptr)
|
||||
return false;
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
if (PyInstanceMethod_Check(ptr))
|
||||
ptr = PyInstanceMethod_GET_FUNCTION(ptr);
|
||||
#endif
|
||||
if (PyMethod_Check(ptr))
|
||||
ptr = PyMethod_GET_FUNCTION(ptr);
|
||||
#endif
|
||||
return PyCFunction_Check(ptr);
|
||||
}
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user