diff --git a/docs/advanced.rst b/docs/advanced.rst index 748f91e2e..ff85a4fc2 100644 --- a/docs/advanced.rst +++ b/docs/advanced.rst @@ -90,10 +90,13 @@ is really just short hand notation for .def("__mul__", [](const Vector2 &a, float b) { return a * b; - }) + }, py::is_operator()) This can be useful for exposing additional operators that don't exist on the -C++ side, or to perform other types of customization. +C++ side, or to perform other types of customization. The ``py::is_operator`` +flag marker is needed to inform pybind11 that this is an operator, which +returns ``NotImplemented`` when invoked with incompatible arguments rather than +throwing a type error. .. note:: diff --git a/include/pybind11/attr.h b/include/pybind11/attr.h index 9acb3e3aa..e3434b145 100644 --- a/include/pybind11/attr.h +++ b/include/pybind11/attr.h @@ -17,6 +17,9 @@ NAMESPACE_BEGIN(pybind11) /// Annotation for methods struct is_method { handle class_; is_method(const handle &c) : class_(c) { } }; +/// Annotation for operators +struct is_operator { }; + /// Annotation for parent scope struct scope { handle value; scope(const handle &s) : value(s) { } }; @@ -57,6 +60,10 @@ struct argument_record { /// Internal data structure which holds metadata about a bound function (signature, overloads, etc.) struct function_record { + function_record() + : is_constructor(false), is_stateless(false), is_operator(false), + has_args(false), has_kwargs(false) { } + /// Function name char *name = nullptr; /* why no C++ strings? They generate heavier code.. */ @@ -87,6 +94,9 @@ struct function_record { /// True if this is a stateless function pointer bool is_stateless : 1; + /// True if this is an operator (__add__), etc. + bool is_operator : 1; + /// True if the function has a '*args' argument bool has_args : 1; @@ -198,6 +208,10 @@ template <> struct process_attribute : process_attribute_default { static void init(const scope &s, function_record *r) { r->scope = s.value; } }; +/// Process an attribute which indicates that this function is an operator +template <> struct process_attribute : process_attribute_default { + static void init(const is_operator &, function_record *r) { r->is_operator = true; } +}; /// Process a keyword argument attribute (*without* a default value) template <> struct process_attribute : process_attribute_default { diff --git a/include/pybind11/operators.h b/include/pybind11/operators.h index eda51a178..22d1859d0 100644 --- a/include/pybind11/operators.h +++ b/include/pybind11/operators.h @@ -54,14 +54,14 @@ template struct op_ { typedef typename std::conditional::value, Base, L>::type L_type; typedef typename std::conditional::value, Base, R>::type R_type; typedef op_impl op; - cl.def(op::name(), &op::execute, extra...); + cl.def(op::name(), &op::execute, is_operator(), extra...); } template void execute_cast(Class &cl, const Extra&... extra) const { typedef typename Class::type Base; typedef typename std::conditional::value, Base, L>::type L_type; typedef typename std::conditional::value, Base, R>::type R_type; typedef op_impl op; - cl.def(op::name(), &op::execute_cast, extra...); + cl.def(op::name(), &op::execute_cast, is_operator(), extra...); } }; diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 8f88d36c9..ea7acb4d5 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -71,6 +71,11 @@ public: object name() const { return attr("__name__"); } protected: + /// Space optimization: don't inline this frequently instantiated fragment + PYBIND11_NOINLINE detail::function_record *make_function_record() { + return new detail::function_record(); + } + /// Special internal constructor for functors, lambda functions, etc. template void initialize(Func &&f, Return (*)(Args...), const Extra&... extra) { @@ -80,7 +85,7 @@ protected: struct capture { typename std::remove_reference::type f; }; /* Store the function including any extra state it might have (e.g. a lambda capture object) */ - auto rec = new detail::function_record(); + auto rec = make_function_record(); /* Store the capture object directly in the function record if there is enough space */ if (sizeof(capture) <= sizeof(rec->data)) { @@ -241,9 +246,6 @@ protected: rec->signature = strdup(signature.c_str()); rec->args.shrink_to_fit(); rec->is_constructor = !strcmp(rec->name, "__init__") || !strcmp(rec->name, "__setstate__"); - rec->is_stateless = false; - rec->has_args = false; - rec->has_kwargs = false; rec->nargs = (uint16_t) args; #if PY_MAJOR_VERSION < 3 @@ -454,6 +456,9 @@ protected: } if (result.ptr() == PYBIND11_TRY_NEXT_OVERLOAD) { + if (overloads->is_operator) + return handle(Py_NotImplemented).inc_ref().ptr(); + std::string msg = "Incompatible " + std::string(overloads->is_constructor ? "constructor" : "function") + " arguments. The following argument types are supported:\n"; int ctr = 0; diff --git a/tests/test_issues.cpp b/tests/test_issues.cpp index c5314bcdc..843978eff 100644 --- a/tests/test_issues.cpp +++ b/tests/test_issues.cpp @@ -20,6 +20,23 @@ struct NestA : NestABase { int value = 3; NestA& operator+=(int i) { value += i; struct NestB { NestA a; int value = 4; NestB& operator-=(int i) { value -= i; return *this; } TRACKERS(NestB) }; struct NestC { NestB b; int value = 5; NestC& operator*=(int i) { value *= i; return *this; } TRACKERS(NestC) }; +/// #393 +class OpTest1 {}; +class OpTest2 {}; + +OpTest1 operator+(const OpTest1 &, const OpTest1 &) { + py::print("Add OpTest1 with OpTest1"); + return OpTest1(); +} +OpTest2 operator+(const OpTest2 &, const OpTest2 &) { + py::print("Add OpTest2 with OpTest2"); + return OpTest2(); +} +OpTest2 operator+(const OpTest2 &, const OpTest1 &) { + py::print("Add OpTest2 with OpTest1"); + return OpTest2(); +} + void init_issues(py::module &m) { py::module m2 = m.def_submodule("issues"); @@ -230,6 +247,16 @@ void init_issues(py::module &m) { .def("A_value", &OverrideTest::A_value) .def("A_ref", &OverrideTest::A_ref); + /// Issue 393: need to return NotSupported to ensure correct arithmetic operator behavior + py::class_(m2, "OpTest1") + .def(py::init<>()) + .def(py::self + py::self); + + py::class_(m2, "OpTest2") + .def(py::init<>()) + .def(py::self + py::self) + .def("__add__", [](const OpTest2& c2, const OpTest1& c1) { return c2 + c1; }) + .def("__radd__", [](const OpTest2& c2, const OpTest1& c1) { return c2 + c1; }); } // MSVC workaround: trying to use a lambda here crashes MSCV diff --git a/tests/test_issues.py b/tests/test_issues.py index 2af6f1ce1..a28e50902 100644 --- a/tests/test_issues.py +++ b/tests/test_issues.py @@ -181,3 +181,16 @@ def test_override_ref(): assert a.value == "hi" a.value = "bye" assert a.value == "bye" + +def test_operators_notimplemented(capture): + from pybind11_tests.issues import OpTest1, OpTest2 + with capture: + C1, C2 = OpTest1(), OpTest2() + C1 + C1 + C2 + C2 + C2 + C1 + C1 + C2 + assert capture == """Add OpTest1 with OpTest1 +Add OpTest2 with OpTest2 +Add OpTest2 with OpTest1 +Add OpTest2 with OpTest1"""