diff --git a/docs/changelog.rst b/docs/changelog.rst index 9162048ac..eb25578c3 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -22,6 +22,11 @@ v2.3.0 (Not yet released) * Added support for write only properties. `#1144 `_. +* Python type wrappers (``py::handle``, ``py::object``, etc.) + now support map Python's number protocol onto C++ arithmetic + operators such as ``operator+``, ``operator/=``, etc. + `#1511 `_. + * A number of improvements related to enumerations: 1. The ``enum_`` implementation was rewritten from scratch to reduce diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 8e40c9f8e..26a033473 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1360,6 +1360,146 @@ detail::initimpl::pickle_factory pickle(GetState &&g, SetSta return {std::forward(g), std::forward(s)}; } +NAMESPACE_BEGIN(detail) +struct enum_base { + enum_base(handle base, handle parent) : m_base(base), m_parent(parent) { } + + PYBIND11_NOINLINE void init(bool is_arithmetic, bool is_convertible) { + m_base.attr("__entries") = dict(); + auto property = handle((PyObject *) &PyProperty_Type); + auto static_property = handle((PyObject *) get_internals().static_property_type); + + m_base.attr("__repr__") = cpp_function( + [](handle arg) -> str { + handle type = arg.get_type(); + object type_name = type.attr("__name__"); + dict entries = type.attr("__entries"); + for (const auto &kv : entries) { + object other = kv.second[int_(0)]; + if (other.equal(arg)) + return pybind11::str("{}.{}").format(type_name, kv.first); + } + return pybind11::str("{}.???").format(type_name); + }, is_method(m_base) + ); + + m_base.attr("name") = property(cpp_function( + [](handle arg) -> str { + dict entries = arg.get_type().attr("__entries"); + for (const auto &kv : entries) { + if (handle(kv.second[int_(0)]).equal(arg)) + return pybind11::str(kv.first); + } + return "???"; + }, is_method(m_base) + )); + + m_base.attr("__doc__") = static_property(cpp_function( + [](handle arg) -> std::string { + std::string docstring; + dict entries = arg.attr("__entries"); + if (((PyTypeObject *) arg.ptr())->tp_doc) + docstring += std::string(((PyTypeObject *) arg.ptr())->tp_doc) + "\n\n"; + docstring += "Members:"; + for (const auto &kv : entries) { + auto key = std::string(pybind11::str(kv.first)); + auto comment = kv.second[int_(1)]; + docstring += "\n\n " + key; + if (!comment.is_none()) + docstring += " : " + (std::string) pybind11::str(comment); + } + return docstring; + } + ), none(), none(), ""); + + m_base.attr("__members__") = static_property(cpp_function( + [](handle arg) -> dict { + dict entries = arg.attr("__entries"), m; + for (const auto &kv : entries) + m[kv.first] = kv.second[int_(0)]; + return m; + }), none(), none(), "" + ); + + #define PYBIND11_ENUM_OP_STRICT(op, expr) \ + m_base.attr(op) = cpp_function( \ + [](object a, object b) { \ + if (!a.get_type().is(b.get_type())) \ + throw type_error("Expected an enumeration of matching type!"); \ + return expr; \ + }, \ + is_method(m_base)) + + #define PYBIND11_ENUM_OP_CONV(op, expr) \ + m_base.attr(op) = cpp_function( \ + [](object a_, object b_) { \ + int_ a(a_), b(b_); \ + return expr; \ + }, \ + is_method(m_base)) + + if (is_convertible) { + PYBIND11_ENUM_OP_CONV("__eq__", !b.is_none() && a.equal(b)); + PYBIND11_ENUM_OP_CONV("__ne__", b.is_none() || !a.equal(b)); + + if (is_arithmetic) { + PYBIND11_ENUM_OP_CONV("__lt__", a < b); + PYBIND11_ENUM_OP_CONV("__gt__", a > b); + PYBIND11_ENUM_OP_CONV("__le__", a <= b); + PYBIND11_ENUM_OP_CONV("__ge__", a >= b); + PYBIND11_ENUM_OP_CONV("__and__", a & b); + PYBIND11_ENUM_OP_CONV("__rand__", a & b); + PYBIND11_ENUM_OP_CONV("__or__", a | b); + PYBIND11_ENUM_OP_CONV("__ror__", a | b); + PYBIND11_ENUM_OP_CONV("__xor__", a ^ b); + PYBIND11_ENUM_OP_CONV("__rxor__", a ^ b); + } + } else { + PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b))); + PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b))); + + if (is_arithmetic) { + PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b)); + PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b)); + PYBIND11_ENUM_OP_STRICT("__le__", int_(a) <= int_(b)); + PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b)); + } + } + + #undef PYBIND11_ENUM_OP_CONV + #undef PYBIND11_ENUM_OP_STRICT + + object getstate = cpp_function( + [](object arg) { return int_(arg); }, is_method(m_base)); + + m_base.attr("__getstate__") = getstate; + m_base.attr("__hash__") = getstate; + } + + PYBIND11_NOINLINE void value(char const* name_, object value, const char *doc = nullptr) { + dict entries = m_base.attr("__entries"); + str name(name_); + if (entries.contains(name)) { + std::string type_name = (std::string) str(m_base.attr("__name__")); + throw value_error(type_name + ": element \"" + std::string(name_) + "\" already exists!"); + } + + entries[name] = std::make_pair(value, doc); + m_base.attr(name) = value; + } + + PYBIND11_NOINLINE void export_values() { + dict entries = m_base.attr("__entries"); + for (const auto &kv : entries) + m_parent.attr(kv.first) = kv.second[int_(0)]; + } + + handle m_base; + handle m_parent; +}; + +NAMESPACE_END(detail) + /// Binds C++ enumerations and enumeration classes to Python template class enum_ : public class_ { public: @@ -1370,109 +1510,33 @@ public: template enum_(const handle &scope, const char *name, const Extra&... extra) - : class_(scope, name, extra...), m_entries(), m_parent(scope) { - + : class_(scope, name, extra...), m_base(*this, scope) { constexpr bool is_arithmetic = detail::any_of...>::value; + constexpr bool is_convertible = std::is_convertible::value; + m_base.init(is_arithmetic, is_convertible); - auto m_entries_ptr = m_entries.inc_ref().ptr(); - def("__repr__", [name, m_entries_ptr](Type value) -> pybind11::str { - for (const auto &kv : reinterpret_borrow(m_entries_ptr)) { - if (pybind11::cast(kv.second[int_(0)]) == value) - return pybind11::str("{}.{}").format(name, kv.first); - } - return pybind11::str("{}.???").format(name); - }); - def_property_readonly("name", [m_entries_ptr](Type value) -> pybind11::str { - for (const auto &kv : reinterpret_borrow(m_entries_ptr)) { - if (pybind11::cast(kv.second[int_(0)]) == value) - return pybind11::str(kv.first); - } - return pybind11::str("???"); - }); - def_property_readonly_static("__doc__", [m_entries_ptr](handle self_) { - std::string docstring; - const char *tp_doc = ((PyTypeObject *) self_.ptr())->tp_doc; - if (tp_doc) - docstring += std::string(tp_doc) + "\n\n"; - docstring += "Members:"; - for (const auto &kv : reinterpret_borrow(m_entries_ptr)) { - auto key = std::string(pybind11::str(kv.first)); - auto comment = kv.second[int_(1)]; - docstring += "\n\n " + key; - if (!comment.is_none()) - docstring += " : " + (std::string) pybind11::str(comment); - } - return docstring; - }); - def_property_readonly_static("__members__", [m_entries_ptr](handle /* self_ */) { - dict m; - for (const auto &kv : reinterpret_borrow(m_entries_ptr)) - m[kv.first] = kv.second[int_(0)]; - return m; - }, return_value_policy::copy); def(init([](Scalar i) { return static_cast(i); })); def("__int__", [](Type value) { return (Scalar) value; }); #if PY_MAJOR_VERSION < 3 def("__long__", [](Type value) { return (Scalar) value; }); #endif - def("__eq__", [](const Type &value, Type *value2) { return value2 && value == *value2; }); - def("__ne__", [](const Type &value, Type *value2) { return !value2 || value != *value2; }); - if (is_arithmetic) { - def("__lt__", [](const Type &value, Type *value2) { return value2 && value < *value2; }); - def("__gt__", [](const Type &value, Type *value2) { return value2 && value > *value2; }); - def("__le__", [](const Type &value, Type *value2) { return value2 && value <= *value2; }); - def("__ge__", [](const Type &value, Type *value2) { return value2 && value >= *value2; }); - } - if (std::is_convertible::value) { - // Don't provide comparison with the underlying type if the enum isn't convertible, - // i.e. if Type is a scoped enum, mirroring the C++ behaviour. (NB: we explicitly - // convert Type to Scalar below anyway because this needs to compile). - def("__eq__", [](const Type &value, Scalar value2) { return (Scalar) value == value2; }); - def("__ne__", [](const Type &value, Scalar value2) { return (Scalar) value != value2; }); - if (is_arithmetic) { - def("__lt__", [](const Type &value, Scalar value2) { return (Scalar) value < value2; }); - def("__gt__", [](const Type &value, Scalar value2) { return (Scalar) value > value2; }); - def("__le__", [](const Type &value, Scalar value2) { return (Scalar) value <= value2; }); - def("__ge__", [](const Type &value, Scalar value2) { return (Scalar) value >= value2; }); - def("__invert__", [](const Type &value) { return ~((Scalar) value); }); - def("__and__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; }); - def("__or__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; }); - def("__xor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; }); - def("__rand__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; }); - def("__ror__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; }); - def("__rxor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; }); - def("__and__", [](const Type &value, const Type &value2) { return (Scalar) value & (Scalar) value2; }); - def("__or__", [](const Type &value, const Type &value2) { return (Scalar) value | (Scalar) value2; }); - def("__xor__", [](const Type &value, const Type &value2) { return (Scalar) value ^ (Scalar) value2; }); - } - } - def("__hash__", [](const Type &value) { return (Scalar) value; }); - // Pickling and unpickling -- needed for use with the 'multiprocessing' module - def(pickle([](const Type &value) { return pybind11::make_tuple((Scalar) value); }, - [](tuple t) { return static_cast(t[0].cast()); })); + def("__setstate__", [](Type &value, Scalar arg) { value = static_cast(arg); }); } /// Export enumeration entries into the parent scope enum_& export_values() { - for (const auto &kv : m_entries) - m_parent.attr(kv.first) = kv.second[int_(0)]; + m_base.export_values(); return *this; } /// Add an enumeration entry enum_& value(char const* name, Type value, const char *doc = nullptr) { - auto v = pybind11::cast(value, return_value_policy::copy); - this->attr(name) = v; - auto name_converted = pybind11::str(name); - if (m_entries.contains(name_converted)) - throw value_error("Enum error - element with name: " + std::string(name) + " already exists"); - m_entries[name_converted] = std::make_pair(v, doc); + m_base.value(name, pybind11::cast(value, return_value_policy::copy), doc); return *this; } private: - dict m_entries; - handle m_parent; + detail::enum_base m_base; }; NAMESPACE_BEGIN(detail) diff --git a/tests/test_enum.py b/tests/test_enum.py index a031d9570..b1a5089ea 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -153,4 +153,4 @@ def test_enum_to_int(): def test_duplicate_enum_name(): with pytest.raises(ValueError) as excinfo: m.register_bad_enum() - assert str(excinfo.value) == "Enum error - element with name: ONE already exists" + assert str(excinfo.value) == 'SimpleEnum: element "ONE" already exists!' diff --git a/tests/test_pickling.py b/tests/test_pickling.py index 707d34786..5ae05aaa0 100644 --- a/tests/test_pickling.py +++ b/tests/test_pickling.py @@ -34,3 +34,9 @@ def test_roundtrip_with_dict(cls_name): assert p2.value == p.value assert p2.extra == p.extra assert p2.dynamic == p.dynamic + + +def test_enum_pickle(): + from pybind11_tests import enums as e + data = pickle.dumps(e.EOne, 2) + assert e.EOne == pickle.loads(data)