operators should return NotImplemented given unsupported input (fixes #393)

This commit is contained in:
Wenzel Jakob 2016-09-10 15:28:37 +09:00
parent 8d38ebed91
commit 382484ae56
6 changed files with 70 additions and 8 deletions

View File

@ -90,10 +90,13 @@ is really just short hand notation for
.def("__mul__", [](const Vector2 &a, float b) { .def("__mul__", [](const Vector2 &a, float b) {
return a * b; return a * b;
}) }, py::is_operator())
This can be useful for exposing additional operators that don't exist on the This can be useful for exposing additional operators that don't exist on the
C++ side, or to perform other types of customization. C++ side, or to perform other types of customization. The ``py::is_operator``
flag marker is needed to inform pybind11 that this is an operator, which
returns ``NotImplemented`` when invoked with incompatible arguments rather than
throwing a type error.
.. note:: .. note::

View File

@ -17,6 +17,9 @@ NAMESPACE_BEGIN(pybind11)
/// Annotation for methods /// Annotation for methods
struct is_method { handle class_; is_method(const handle &c) : class_(c) { } }; struct is_method { handle class_; is_method(const handle &c) : class_(c) { } };
/// Annotation for operators
struct is_operator { };
/// Annotation for parent scope /// Annotation for parent scope
struct scope { handle value; scope(const handle &s) : value(s) { } }; struct scope { handle value; scope(const handle &s) : value(s) { } };
@ -57,6 +60,10 @@ struct argument_record {
/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.) /// Internal data structure which holds metadata about a bound function (signature, overloads, etc.)
struct function_record { struct function_record {
function_record()
: is_constructor(false), is_stateless(false), is_operator(false),
has_args(false), has_kwargs(false) { }
/// Function name /// Function name
char *name = nullptr; /* why no C++ strings? They generate heavier code.. */ char *name = nullptr; /* why no C++ strings? They generate heavier code.. */
@ -87,6 +94,9 @@ struct function_record {
/// True if this is a stateless function pointer /// True if this is a stateless function pointer
bool is_stateless : 1; bool is_stateless : 1;
/// True if this is an operator (__add__), etc.
bool is_operator : 1;
/// True if the function has a '*args' argument /// True if the function has a '*args' argument
bool has_args : 1; bool has_args : 1;
@ -198,6 +208,10 @@ template <> struct process_attribute<scope> : process_attribute_default<scope> {
static void init(const scope &s, function_record *r) { r->scope = s.value; } static void init(const scope &s, function_record *r) { r->scope = s.value; }
}; };
/// Process an attribute which indicates that this function is an operator
template <> struct process_attribute<is_operator> : process_attribute_default<is_operator> {
static void init(const is_operator &, function_record *r) { r->is_operator = true; }
};
/// Process a keyword argument attribute (*without* a default value) /// Process a keyword argument attribute (*without* a default value)
template <> struct process_attribute<arg> : process_attribute_default<arg> { template <> struct process_attribute<arg> : process_attribute_default<arg> {

View File

@ -54,14 +54,14 @@ template <op_id id, op_type ot, typename L, typename R> struct op_ {
typedef typename std::conditional<std::is_same<L, self_t>::value, Base, L>::type L_type; typedef typename std::conditional<std::is_same<L, self_t>::value, Base, L>::type L_type;
typedef typename std::conditional<std::is_same<R, self_t>::value, Base, R>::type R_type; typedef typename std::conditional<std::is_same<R, self_t>::value, Base, R>::type R_type;
typedef op_impl<id, ot, Base, L_type, R_type> op; typedef op_impl<id, ot, Base, L_type, R_type> op;
cl.def(op::name(), &op::execute, extra...); cl.def(op::name(), &op::execute, is_operator(), extra...);
} }
template <typename Class, typename... Extra> void execute_cast(Class &cl, const Extra&... extra) const { template <typename Class, typename... Extra> void execute_cast(Class &cl, const Extra&... extra) const {
typedef typename Class::type Base; typedef typename Class::type Base;
typedef typename std::conditional<std::is_same<L, self_t>::value, Base, L>::type L_type; typedef typename std::conditional<std::is_same<L, self_t>::value, Base, L>::type L_type;
typedef typename std::conditional<std::is_same<R, self_t>::value, Base, R>::type R_type; typedef typename std::conditional<std::is_same<R, self_t>::value, Base, R>::type R_type;
typedef op_impl<id, ot, Base, L_type, R_type> op; typedef op_impl<id, ot, Base, L_type, R_type> op;
cl.def(op::name(), &op::execute_cast, extra...); cl.def(op::name(), &op::execute_cast, is_operator(), extra...);
} }
}; };

View File

@ -71,6 +71,11 @@ public:
object name() const { return attr("__name__"); } object name() const { return attr("__name__"); }
protected: protected:
/// Space optimization: don't inline this frequently instantiated fragment
PYBIND11_NOINLINE detail::function_record *make_function_record() {
return new detail::function_record();
}
/// Special internal constructor for functors, lambda functions, etc. /// Special internal constructor for functors, lambda functions, etc.
template <typename Func, typename Return, typename... Args, typename... Extra> template <typename Func, typename Return, typename... Args, typename... Extra>
void initialize(Func &&f, Return (*)(Args...), const Extra&... extra) { void initialize(Func &&f, Return (*)(Args...), const Extra&... extra) {
@ -80,7 +85,7 @@ protected:
struct capture { typename std::remove_reference<Func>::type f; }; struct capture { typename std::remove_reference<Func>::type f; };
/* Store the function including any extra state it might have (e.g. a lambda capture object) */ /* Store the function including any extra state it might have (e.g. a lambda capture object) */
auto rec = new detail::function_record(); auto rec = make_function_record();
/* Store the capture object directly in the function record if there is enough space */ /* Store the capture object directly in the function record if there is enough space */
if (sizeof(capture) <= sizeof(rec->data)) { if (sizeof(capture) <= sizeof(rec->data)) {
@ -241,9 +246,6 @@ protected:
rec->signature = strdup(signature.c_str()); rec->signature = strdup(signature.c_str());
rec->args.shrink_to_fit(); rec->args.shrink_to_fit();
rec->is_constructor = !strcmp(rec->name, "__init__") || !strcmp(rec->name, "__setstate__"); rec->is_constructor = !strcmp(rec->name, "__init__") || !strcmp(rec->name, "__setstate__");
rec->is_stateless = false;
rec->has_args = false;
rec->has_kwargs = false;
rec->nargs = (uint16_t) args; rec->nargs = (uint16_t) args;
#if PY_MAJOR_VERSION < 3 #if PY_MAJOR_VERSION < 3
@ -454,6 +456,9 @@ protected:
} }
if (result.ptr() == PYBIND11_TRY_NEXT_OVERLOAD) { if (result.ptr() == PYBIND11_TRY_NEXT_OVERLOAD) {
if (overloads->is_operator)
return handle(Py_NotImplemented).inc_ref().ptr();
std::string msg = "Incompatible " + std::string(overloads->is_constructor ? "constructor" : "function") + std::string msg = "Incompatible " + std::string(overloads->is_constructor ? "constructor" : "function") +
" arguments. The following argument types are supported:\n"; " arguments. The following argument types are supported:\n";
int ctr = 0; int ctr = 0;

View File

@ -20,6 +20,23 @@ struct NestA : NestABase { int value = 3; NestA& operator+=(int i) { value += i;
struct NestB { NestA a; int value = 4; NestB& operator-=(int i) { value -= i; return *this; } TRACKERS(NestB) }; struct NestB { NestA a; int value = 4; NestB& operator-=(int i) { value -= i; return *this; } TRACKERS(NestB) };
struct NestC { NestB b; int value = 5; NestC& operator*=(int i) { value *= i; return *this; } TRACKERS(NestC) }; struct NestC { NestB b; int value = 5; NestC& operator*=(int i) { value *= i; return *this; } TRACKERS(NestC) };
/// #393
class OpTest1 {};
class OpTest2 {};
OpTest1 operator+(const OpTest1 &, const OpTest1 &) {
py::print("Add OpTest1 with OpTest1");
return OpTest1();
}
OpTest2 operator+(const OpTest2 &, const OpTest2 &) {
py::print("Add OpTest2 with OpTest2");
return OpTest2();
}
OpTest2 operator+(const OpTest2 &, const OpTest1 &) {
py::print("Add OpTest2 with OpTest1");
return OpTest2();
}
void init_issues(py::module &m) { void init_issues(py::module &m) {
py::module m2 = m.def_submodule("issues"); py::module m2 = m.def_submodule("issues");
@ -230,6 +247,16 @@ void init_issues(py::module &m) {
.def("A_value", &OverrideTest::A_value) .def("A_value", &OverrideTest::A_value)
.def("A_ref", &OverrideTest::A_ref); .def("A_ref", &OverrideTest::A_ref);
/// Issue 393: need to return NotSupported to ensure correct arithmetic operator behavior
py::class_<OpTest1>(m2, "OpTest1")
.def(py::init<>())
.def(py::self + py::self);
py::class_<OpTest2>(m2, "OpTest2")
.def(py::init<>())
.def(py::self + py::self)
.def("__add__", [](const OpTest2& c2, const OpTest1& c1) { return c2 + c1; })
.def("__radd__", [](const OpTest2& c2, const OpTest1& c1) { return c2 + c1; });
} }
// MSVC workaround: trying to use a lambda here crashes MSCV // MSVC workaround: trying to use a lambda here crashes MSCV

View File

@ -181,3 +181,16 @@ def test_override_ref():
assert a.value == "hi" assert a.value == "hi"
a.value = "bye" a.value = "bye"
assert a.value == "bye" assert a.value == "bye"
def test_operators_notimplemented(capture):
from pybind11_tests.issues import OpTest1, OpTest2
with capture:
C1, C2 = OpTest1(), OpTest2()
C1 + C1
C2 + C2
C2 + C1
C1 + C2
assert capture == """Add OpTest1 with OpTest1
Add OpTest2 with OpTest2
Add OpTest2 with OpTest1
Add OpTest2 with OpTest1"""