mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-22 05:05:11 +00:00
operators should return NotImplemented given unsupported input (fixes #393)
This commit is contained in:
parent
8d38ebed91
commit
382484ae56
@ -90,10 +90,13 @@ is really just short hand notation for
|
||||
|
||||
.def("__mul__", [](const Vector2 &a, float b) {
|
||||
return a * b;
|
||||
})
|
||||
}, py::is_operator())
|
||||
|
||||
This can be useful for exposing additional operators that don't exist on the
|
||||
C++ side, or to perform other types of customization.
|
||||
C++ side, or to perform other types of customization. The ``py::is_operator``
|
||||
flag marker is needed to inform pybind11 that this is an operator, which
|
||||
returns ``NotImplemented`` when invoked with incompatible arguments rather than
|
||||
throwing a type error.
|
||||
|
||||
.. note::
|
||||
|
||||
|
@ -17,6 +17,9 @@ NAMESPACE_BEGIN(pybind11)
|
||||
/// Annotation for methods
|
||||
struct is_method { handle class_; is_method(const handle &c) : class_(c) { } };
|
||||
|
||||
/// Annotation for operators
|
||||
struct is_operator { };
|
||||
|
||||
/// Annotation for parent scope
|
||||
struct scope { handle value; scope(const handle &s) : value(s) { } };
|
||||
|
||||
@ -57,6 +60,10 @@ struct argument_record {
|
||||
|
||||
/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.)
|
||||
struct function_record {
|
||||
function_record()
|
||||
: is_constructor(false), is_stateless(false), is_operator(false),
|
||||
has_args(false), has_kwargs(false) { }
|
||||
|
||||
/// Function name
|
||||
char *name = nullptr; /* why no C++ strings? They generate heavier code.. */
|
||||
|
||||
@ -87,6 +94,9 @@ struct function_record {
|
||||
/// True if this is a stateless function pointer
|
||||
bool is_stateless : 1;
|
||||
|
||||
/// True if this is an operator (__add__), etc.
|
||||
bool is_operator : 1;
|
||||
|
||||
/// True if the function has a '*args' argument
|
||||
bool has_args : 1;
|
||||
|
||||
@ -198,6 +208,10 @@ template <> struct process_attribute<scope> : process_attribute_default<scope> {
|
||||
static void init(const scope &s, function_record *r) { r->scope = s.value; }
|
||||
};
|
||||
|
||||
/// Process an attribute which indicates that this function is an operator
|
||||
template <> struct process_attribute<is_operator> : process_attribute_default<is_operator> {
|
||||
static void init(const is_operator &, function_record *r) { r->is_operator = true; }
|
||||
};
|
||||
|
||||
/// Process a keyword argument attribute (*without* a default value)
|
||||
template <> struct process_attribute<arg> : process_attribute_default<arg> {
|
||||
|
@ -54,14 +54,14 @@ template <op_id id, op_type ot, typename L, typename R> struct op_ {
|
||||
typedef typename std::conditional<std::is_same<L, self_t>::value, Base, L>::type L_type;
|
||||
typedef typename std::conditional<std::is_same<R, self_t>::value, Base, R>::type R_type;
|
||||
typedef op_impl<id, ot, Base, L_type, R_type> op;
|
||||
cl.def(op::name(), &op::execute, extra...);
|
||||
cl.def(op::name(), &op::execute, is_operator(), extra...);
|
||||
}
|
||||
template <typename Class, typename... Extra> void execute_cast(Class &cl, const Extra&... extra) const {
|
||||
typedef typename Class::type Base;
|
||||
typedef typename std::conditional<std::is_same<L, self_t>::value, Base, L>::type L_type;
|
||||
typedef typename std::conditional<std::is_same<R, self_t>::value, Base, R>::type R_type;
|
||||
typedef op_impl<id, ot, Base, L_type, R_type> op;
|
||||
cl.def(op::name(), &op::execute_cast, extra...);
|
||||
cl.def(op::name(), &op::execute_cast, is_operator(), extra...);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -71,6 +71,11 @@ public:
|
||||
object name() const { return attr("__name__"); }
|
||||
|
||||
protected:
|
||||
/// Space optimization: don't inline this frequently instantiated fragment
|
||||
PYBIND11_NOINLINE detail::function_record *make_function_record() {
|
||||
return new detail::function_record();
|
||||
}
|
||||
|
||||
/// Special internal constructor for functors, lambda functions, etc.
|
||||
template <typename Func, typename Return, typename... Args, typename... Extra>
|
||||
void initialize(Func &&f, Return (*)(Args...), const Extra&... extra) {
|
||||
@ -80,7 +85,7 @@ protected:
|
||||
struct capture { typename std::remove_reference<Func>::type f; };
|
||||
|
||||
/* Store the function including any extra state it might have (e.g. a lambda capture object) */
|
||||
auto rec = new detail::function_record();
|
||||
auto rec = make_function_record();
|
||||
|
||||
/* Store the capture object directly in the function record if there is enough space */
|
||||
if (sizeof(capture) <= sizeof(rec->data)) {
|
||||
@ -241,9 +246,6 @@ protected:
|
||||
rec->signature = strdup(signature.c_str());
|
||||
rec->args.shrink_to_fit();
|
||||
rec->is_constructor = !strcmp(rec->name, "__init__") || !strcmp(rec->name, "__setstate__");
|
||||
rec->is_stateless = false;
|
||||
rec->has_args = false;
|
||||
rec->has_kwargs = false;
|
||||
rec->nargs = (uint16_t) args;
|
||||
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
@ -454,6 +456,9 @@ protected:
|
||||
}
|
||||
|
||||
if (result.ptr() == PYBIND11_TRY_NEXT_OVERLOAD) {
|
||||
if (overloads->is_operator)
|
||||
return handle(Py_NotImplemented).inc_ref().ptr();
|
||||
|
||||
std::string msg = "Incompatible " + std::string(overloads->is_constructor ? "constructor" : "function") +
|
||||
" arguments. The following argument types are supported:\n";
|
||||
int ctr = 0;
|
||||
|
@ -20,6 +20,23 @@ struct NestA : NestABase { int value = 3; NestA& operator+=(int i) { value += i;
|
||||
struct NestB { NestA a; int value = 4; NestB& operator-=(int i) { value -= i; return *this; } TRACKERS(NestB) };
|
||||
struct NestC { NestB b; int value = 5; NestC& operator*=(int i) { value *= i; return *this; } TRACKERS(NestC) };
|
||||
|
||||
/// #393
|
||||
class OpTest1 {};
|
||||
class OpTest2 {};
|
||||
|
||||
OpTest1 operator+(const OpTest1 &, const OpTest1 &) {
|
||||
py::print("Add OpTest1 with OpTest1");
|
||||
return OpTest1();
|
||||
}
|
||||
OpTest2 operator+(const OpTest2 &, const OpTest2 &) {
|
||||
py::print("Add OpTest2 with OpTest2");
|
||||
return OpTest2();
|
||||
}
|
||||
OpTest2 operator+(const OpTest2 &, const OpTest1 &) {
|
||||
py::print("Add OpTest2 with OpTest1");
|
||||
return OpTest2();
|
||||
}
|
||||
|
||||
void init_issues(py::module &m) {
|
||||
py::module m2 = m.def_submodule("issues");
|
||||
|
||||
@ -230,6 +247,16 @@ void init_issues(py::module &m) {
|
||||
.def("A_value", &OverrideTest::A_value)
|
||||
.def("A_ref", &OverrideTest::A_ref);
|
||||
|
||||
/// Issue 393: need to return NotSupported to ensure correct arithmetic operator behavior
|
||||
py::class_<OpTest1>(m2, "OpTest1")
|
||||
.def(py::init<>())
|
||||
.def(py::self + py::self);
|
||||
|
||||
py::class_<OpTest2>(m2, "OpTest2")
|
||||
.def(py::init<>())
|
||||
.def(py::self + py::self)
|
||||
.def("__add__", [](const OpTest2& c2, const OpTest1& c1) { return c2 + c1; })
|
||||
.def("__radd__", [](const OpTest2& c2, const OpTest1& c1) { return c2 + c1; });
|
||||
}
|
||||
|
||||
// MSVC workaround: trying to use a lambda here crashes MSCV
|
||||
|
@ -181,3 +181,16 @@ def test_override_ref():
|
||||
assert a.value == "hi"
|
||||
a.value = "bye"
|
||||
assert a.value == "bye"
|
||||
|
||||
def test_operators_notimplemented(capture):
|
||||
from pybind11_tests.issues import OpTest1, OpTest2
|
||||
with capture:
|
||||
C1, C2 = OpTest1(), OpTest2()
|
||||
C1 + C1
|
||||
C2 + C2
|
||||
C2 + C1
|
||||
C1 + C2
|
||||
assert capture == """Add OpTest1 with OpTest1
|
||||
Add OpTest2 with OpTest2
|
||||
Add OpTest2 with OpTest1
|
||||
Add OpTest2 with OpTest1"""
|
||||
|
Loading…
Reference in New Issue
Block a user