Update OVERLOAD macros to support ref/ptr return type overloads

This adds a static local variable (in dead code unless actually needed)
in the overload code that is used for storage if the overload is for
some convert-by-value type (such as numeric values or std::string).

This has limitations (as written up in the advanced doc), but is better
than simply not being able to overload reference or pointer methods.
This commit is contained in:
Jason Rhinelander 2016-09-08 14:49:43 -04:00
parent 116d37c9ba
commit 7dfb932e70
5 changed files with 136 additions and 39 deletions

View File

@ -298,13 +298,11 @@ helper class that is defined as follows:
The macro :func:`PYBIND11_OVERLOAD_PURE` should be used for pure virtual The macro :func:`PYBIND11_OVERLOAD_PURE` should be used for pure virtual
functions, and :func:`PYBIND11_OVERLOAD` should be used for functions which have functions, and :func:`PYBIND11_OVERLOAD` should be used for functions which have
a default implementation. a default implementation. There are also two alternate macros
:func:`PYBIND11_OVERLOAD_PURE_NAME` and :func:`PYBIND11_OVERLOAD_NAME` which
There are also two alternate macros :func:`PYBIND11_OVERLOAD_PURE_NAME` and take a string-valued name argument between the *Parent class* and *Name of the
:func:`PYBIND11_OVERLOAD_NAME` which take a string-valued name argument between function* slots. This is useful when the C++ and Python versions of the
the *Parent class* and *Name of the function* slots. This is useful when the function have different names, e.g. ``operator()`` vs ``__call__``.
C++ and Python versions of the function have different names, e.g.
``operator()`` vs ``__call__``.
The binding code also needs a few minor adaptations (highlighted): The binding code also needs a few minor adaptations (highlighted):
@ -357,6 +355,25 @@ a virtual method call.
Please take a look at the :ref:`macro_notes` before using this feature. Please take a look at the :ref:`macro_notes` before using this feature.
.. note::
When the overridden type returns a reference or pointer to a type that
pybind11 converts from Python (for example, numeric values, std::string,
and other built-in value-converting types), there are some limitations to
be aware of:
- because in these cases there is no C++ variable to reference (the value
is stored in the referenced Python variable), pybind11 provides one in
the PYBIND11_OVERLOAD macros (when needed) with static storage duration.
Note that this means that invoking the overloaded method on *any*
instance will change the referenced value stored in *all* instances of
that type.
- Attempts to modify a non-const reference will not have the desired
effect: it will change only the static cache variable, but this change
will not propagate to underlying Python instance, and the change will be
replaced the next time the overload is invoked.
.. seealso:: .. seealso::
The file :file:`tests/test_virtual_functions.cpp` contains a complete The file :file:`tests/test_virtual_functions.cpp` contains a complete

View File

@ -867,14 +867,8 @@ template <typename type> using cast_is_temporary_value_reference = bool_constant
!std::is_base_of<type_caster_generic, make_caster<type>>::value !std::is_base_of<type_caster_generic, make_caster<type>>::value
>; >;
template <typename T> make_caster<T> load_type(const handle &handle) {
NAMESPACE_END(detail) make_caster<T> conv;
template <typename T> T cast(const handle &handle) {
using type_caster = detail::make_caster<T>;
static_assert(!detail::cast_is_temporary_value_reference<T>::value,
"Unable to cast type to reference: value is local to type caster");
type_caster conv;
if (!conv.load(handle, true)) { if (!conv.load(handle, true)) {
#if defined(NDEBUG) #if defined(NDEBUG)
throw cast_error("Unable to cast Python instance to C++ type (compile in debug mode for details)"); throw cast_error("Unable to cast Python instance to C++ type (compile in debug mode for details)");
@ -883,7 +877,16 @@ template <typename T> T cast(const handle &handle) {
(std::string) handle.get_type().str() + " to C++ type '" + type_id<T>() + "''"); (std::string) handle.get_type().str() + " to C++ type '" + type_id<T>() + "''");
#endif #endif
} }
return conv.operator typename type_caster::template cast_op_type<T>(); return conv;
}
NAMESPACE_END(detail)
template <typename T> T cast(const handle &handle) {
static_assert(!detail::cast_is_temporary_value_reference<T>::value,
"Unable to cast type to reference: value is local to type caster");
using type_caster = detail::make_caster<T>;
return detail::load_type<T>(handle).operator typename type_caster::template cast_op_type<T>();
} }
template <typename T> object cast(const T &value, template <typename T> object cast(const T &value,
@ -900,7 +903,7 @@ template <typename T> T handle::cast() const { return pybind11::cast<T>(*this);
template <> inline void handle::cast() const { return; } template <> inline void handle::cast() const { return; }
template <typename T> template <typename T>
typename std::enable_if<detail::move_always<T>::value || detail::move_if_unreferenced<T>::value, T>::type move(object &&obj) { detail::enable_if_t<detail::move_always<T>::value || detail::move_if_unreferenced<T>::value, T> move(object &&obj) {
if (obj.ref_count() > 1) if (obj.ref_count() > 1)
#if defined(NDEBUG) #if defined(NDEBUG)
throw cast_error("Unable to cast Python instance to C++ rvalue: instance has multiple references" throw cast_error("Unable to cast Python instance to C++ rvalue: instance has multiple references"
@ -910,18 +913,8 @@ typename std::enable_if<detail::move_always<T>::value || detail::move_if_unrefer
" instance to C++ " + type_id<T>() + " instance: instance has multiple references"); " instance to C++ " + type_id<T>() + " instance: instance has multiple references");
#endif #endif
typedef detail::type_caster<T> type_caster;
type_caster conv;
if (!conv.load(obj, true))
#if defined(NDEBUG)
throw cast_error("Unable to cast Python instance to C++ type (compile in debug mode for details)");
#else
throw cast_error("Unable to cast Python instance of type " +
(std::string) obj.get_type().str() + " to C++ type '" + type_id<T>() + "''");
#endif
// Move into a temporary and return that, because the reference may be a local value of `conv` // Move into a temporary and return that, because the reference may be a local value of `conv`
T ret = std::move(conv.operator T&()); T ret = std::move(detail::load_type<T>(obj).operator T&());
return ret; return ret;
} }
@ -930,24 +923,57 @@ typename std::enable_if<detail::move_always<T>::value || detail::move_if_unrefer
// object has multiple references, but trying to copy will fail to compile. // object has multiple references, but trying to copy will fail to compile.
// - If both movable and copyable, check ref count: if 1, move; otherwise copy // - If both movable and copyable, check ref count: if 1, move; otherwise copy
// - Otherwise (not movable), copy. // - Otherwise (not movable), copy.
template <typename T> typename std::enable_if<detail::move_always<T>::value, T>::type cast(object &&object) { template <typename T> detail::enable_if_t<detail::move_always<T>::value, T> cast(object &&object) {
return move<T>(std::move(object)); return move<T>(std::move(object));
} }
template <typename T> typename std::enable_if<detail::move_if_unreferenced<T>::value, T>::type cast(object &&object) { template <typename T> detail::enable_if_t<detail::move_if_unreferenced<T>::value, T> cast(object &&object) {
if (object.ref_count() > 1) if (object.ref_count() > 1)
return cast<T>(object); return cast<T>(object);
else else
return move<T>(std::move(object)); return move<T>(std::move(object));
} }
template <typename T> typename std::enable_if<detail::move_never<T>::value, T>::type cast(object &&object) { template <typename T> detail::enable_if_t<detail::move_never<T>::value, T> cast(object &&object) {
return cast<T>(object); return cast<T>(object);
} }
// Provide a ref_cast() with move support for objects (only participates for moveable types)
template <typename T> detail::enable_if_t<detail::move_is_plain_type<T>::value, T>
ref_cast(object &&object) { return cast<T>(std::move(object)); }
template <typename T> T object::cast() const & { return pybind11::cast<T>(*this); } template <typename T> T object::cast() const & { return pybind11::cast<T>(*this); }
template <typename T> T object::cast() && { return pybind11::cast<T>(std::move(*this)); } template <typename T> T object::cast() && { return pybind11::cast<T>(std::move(*this)); }
template <> inline void object::cast() const & { return; } template <> inline void object::cast() const & { return; }
template <> inline void object::cast() && { return; } template <> inline void object::cast() && { return; }
NAMESPACE_BEGIN(detail)
struct overload_nothing {}; // Placeholder type for the unneeded (and dead code) static variable in the OVERLOAD_INT macro
template <typename ret_type> using overload_local_t = conditional_t<
cast_is_temporary_value_reference<ret_type>::value, intrinsic_t<ret_type>, overload_nothing>;
template <typename T> enable_if_t<std::is_lvalue_reference<T>::value, T> storage_cast(intrinsic_t<T> &v) { return v; }
template <typename T> enable_if_t<std::is_pointer<T>::value, T> storage_cast(intrinsic_t<T> &v) { return &v; }
// Trampoline use: for reference/pointer types to value-converted values, we do a value cast, then
// store the result in the given variable. For other types, this is a no-op.
template <typename T> enable_if_t<cast_is_temporary_value_reference<T>::value, T> cast_ref(object &&o, intrinsic_t<T> &storage) {
using type_caster = make_caster<T>;
using itype = intrinsic_t<T>;
storage = std::move(load_type<T>(o).operator typename type_caster::template cast_op_type<itype>());
return storage_cast<T>(storage);
}
template <typename T> enable_if_t<!cast_is_temporary_value_reference<T>::value, T> cast_ref(object &&, overload_nothing &) {
pybind11_fail("Internal error: cast_ref fallback invoked"); }
// Trampoline use: Having a pybind11::cast with an invalid reference type is going to static_assert, even
// though if it's in dead code, so we provide a "trampoline" to pybind11::cast that only does anything in
// cases where pybind11::cast is valid.
template <typename T> enable_if_t<!cast_is_temporary_value_reference<T>::value, T> cast_safe(object &&o) {
return pybind11::cast<T>(std::move(o)); }
template <typename T> enable_if_t<cast_is_temporary_value_reference<T>::value, T> cast_safe(object &&) {
pybind11_fail("Internal error: cast_safe fallback invoked"); }
template <> inline void cast_safe<void>(object &&) {}
NAMESPACE_END(detail)
template <return_value_policy policy = return_value_policy::automatic_reference, template <return_value_policy policy = return_value_policy::automatic_reference,

View File

@ -1485,8 +1485,15 @@ template <class T> function get_overload(const T *this_ptr, const char *name) {
#define PYBIND11_OVERLOAD_INT(ret_type, cname, name, ...) { \ #define PYBIND11_OVERLOAD_INT(ret_type, cname, name, ...) { \
pybind11::gil_scoped_acquire gil; \ pybind11::gil_scoped_acquire gil; \
pybind11::function overload = pybind11::get_overload(static_cast<const cname *>(this), name); \ pybind11::function overload = pybind11::get_overload(static_cast<const cname *>(this), name); \
if (overload) \ if (overload) { \
return overload(__VA_ARGS__).template cast<ret_type>(); } pybind11::object o = overload(__VA_ARGS__); \
if (pybind11::detail::cast_is_temporary_value_reference<ret_type>::value) { \
static pybind11::detail::overload_local_t<ret_type> local_value; \
return pybind11::detail::cast_ref<ret_type>(std::move(o), local_value); \
} \
else return pybind11::detail::cast_safe<ret_type>(std::move(o)); \
} \
}
#define PYBIND11_OVERLOAD_NAME(ret_type, cname, name, fn, ...) \ #define PYBIND11_OVERLOAD_NAME(ret_type, cname, name, fn, ...) \
PYBIND11_OVERLOAD_INT(ret_type, cname, name, __VA_ARGS__) \ PYBIND11_OVERLOAD_INT(ret_type, cname, name, __VA_ARGS__) \

View File

@ -21,14 +21,22 @@ public:
virtual int run(int value) { virtual int run(int value) {
py::print("Original implementation of " py::print("Original implementation of "
"ExampleVirt::run(state={}, value={})"_s.format(state, value)); "ExampleVirt::run(state={}, value={}, str1={}, str2={})"_s.format(state, value, get_string1(), *get_string2()));
return state + value; return state + value;
} }
virtual bool run_bool() = 0; virtual bool run_bool() = 0;
virtual void pure_virtual() = 0; virtual void pure_virtual() = 0;
// Returning a reference/pointer to a type converted from python (numbers, strings, etc.) is a
// bit trickier, because the actual int& or std::string& or whatever only exists temporarily, so
// we have to handle it specially in the trampoline class (see below).
virtual const std::string &get_string1() { return str1; }
virtual const std::string *get_string2() { return &str2; }
private: private:
int state; int state;
const std::string str1{"default1"}, str2{"default2"};
}; };
/* This is a wrapper class that must be generated */ /* This is a wrapper class that must be generated */
@ -65,6 +73,27 @@ public:
in the previous line is needed for some compilers */ in the previous line is needed for some compilers */
); );
} }
// We can return reference types for compatibility with C++ virtual interfaces that do so, but
// note they have some significant limitations (see the documentation).
const std::string &get_string1() override {
PYBIND11_OVERLOAD(
const std::string &, /* Return type */
ExampleVirt, /* Parent class */
get_string1, /* Name of function */
/* (no arguments) */
);
}
const std::string *get_string2() override {
PYBIND11_OVERLOAD(
const std::string *, /* Return type */
ExampleVirt, /* Parent class */
get_string2, /* Name of function */
/* (no arguments) */
);
}
}; };
class NonCopyable { class NonCopyable {

View File

@ -20,13 +20,23 @@ def test_override(capture, msg):
print('ExtendedExampleVirt::run_bool()') print('ExtendedExampleVirt::run_bool()')
return False return False
def get_string1(self):
return "override1"
def pure_virtual(self): def pure_virtual(self):
print('ExtendedExampleVirt::pure_virtual(): %s' % self.data) print('ExtendedExampleVirt::pure_virtual(): %s' % self.data)
class ExtendedExampleVirt2(ExtendedExampleVirt):
def __init__(self, state):
super(ExtendedExampleVirt2, self).__init__(state + 1)
def get_string2(self):
return "override2"
ex12 = ExampleVirt(10) ex12 = ExampleVirt(10)
with capture: with capture:
assert runExampleVirt(ex12, 20) == 30 assert runExampleVirt(ex12, 20) == 30
assert capture == "Original implementation of ExampleVirt::run(state=10, value=20)" assert capture == "Original implementation of ExampleVirt::run(state=10, value=20, str1=default1, str2=default2)"
with pytest.raises(RuntimeError) as excinfo: with pytest.raises(RuntimeError) as excinfo:
runExampleVirtVirtual(ex12) runExampleVirtVirtual(ex12)
@ -37,7 +47,7 @@ def test_override(capture, msg):
assert runExampleVirt(ex12p, 20) == 32 assert runExampleVirt(ex12p, 20) == 32
assert capture == """ assert capture == """
ExtendedExampleVirt::run(20), calling parent.. ExtendedExampleVirt::run(20), calling parent..
Original implementation of ExampleVirt::run(state=11, value=21) Original implementation of ExampleVirt::run(state=11, value=21, str1=override1, str2=default2)
""" """
with capture: with capture:
assert runExampleVirtBool(ex12p) is False assert runExampleVirtBool(ex12p) is False
@ -46,11 +56,19 @@ def test_override(capture, msg):
runExampleVirtVirtual(ex12p) runExampleVirtVirtual(ex12p)
assert capture == "ExtendedExampleVirt::pure_virtual(): Hello world" assert capture == "ExtendedExampleVirt::pure_virtual(): Hello world"
ex12p2 = ExtendedExampleVirt2(15)
with capture:
assert runExampleVirt(ex12p2, 50) == 68
assert capture == """
ExtendedExampleVirt::run(50), calling parent..
Original implementation of ExampleVirt::run(state=17, value=51, str1=override1, str2=override2)
"""
cstats = ConstructorStats.get(ExampleVirt) cstats = ConstructorStats.get(ExampleVirt)
assert cstats.alive() == 2 assert cstats.alive() == 3
del ex12, ex12p del ex12, ex12p, ex12p2
assert cstats.alive() == 0 assert cstats.alive() == 0
assert cstats.values() == ['10', '11'] assert cstats.values() == ['10', '11', '17']
assert cstats.copy_constructions == 0 assert cstats.copy_constructions == 0
assert cstats.move_constructions >= 0 assert cstats.move_constructions >= 0