diff --git a/docs/advanced/cast/overview.rst b/docs/advanced/cast/overview.rst index e9f43be3b..84085725c 100644 --- a/docs/advanced/cast/overview.rst +++ b/docs/advanced/cast/overview.rst @@ -77,7 +77,7 @@ as arguments and return values, refer to the section on binding :ref:`classes`. +------------------------------------+---------------------------+-------------------------------+ | Data type | Description | Header file | -+=---================================+===========================+===============================+ ++====================================+===========================+===============================+ | ``int8_t``, ``uint8_t`` | 8-bit integers | :file:`pybind11/pybind11.h` | +------------------------------------+---------------------------+-------------------------------+ | ``int16_t``, ``uint16_t`` | 16-bit integers | :file:`pybind11/pybind11.h` | diff --git a/docs/classes.rst b/docs/classes.rst index 300816d41..3e8f2ee97 100644 --- a/docs/classes.rst +++ b/docs/classes.rst @@ -393,4 +393,18 @@ typed enums. 1L +.. note:: + + When the special tag ``py::arithmetic()`` is specified to the ``enum_`` + constructor, pybind11 creates an enumeration that also supports rudimentary + arithmetic and bit-level operations like comparisons, and, or, xor, negation, + etc. + + .. code-block:: cpp + + py::enum_(pet, "Kind", py::arithmetic()) + ... + + By default, these are omitted to conserve space. + .. [#f1] Stateless closures are those with an empty pair of brackets ``[]`` as the capture object. diff --git a/include/pybind11/attr.h b/include/pybind11/attr.h index d728210e0..2e6dec104 100644 --- a/include/pybind11/attr.h +++ b/include/pybind11/attr.h @@ -47,6 +47,9 @@ struct multiple_inheritance { }; /// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class struct dynamic_attr { }; +/// Annotation to mark enums as an arithmetic type +struct arithmetic { }; + NAMESPACE_BEGIN(detail) /* Forward declarations */ enum op_id : int; @@ -306,6 +309,11 @@ struct process_attribute : process_attribute_default static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; } }; + +/// Process an 'arithmetic' attribute for enums (does nothing here) +template <> +struct process_attribute : process_attribute_default {}; + /*** * Process a keep_alive call policy -- invokes keep_alive_impl during the * pre-call handler if both Nurse, Patient != 0 and use the post-call handler diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 6804fa91a..0079052f7 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1187,51 +1187,62 @@ private: template class enum_ : public class_ { public: using class_::def; - using UnderlyingType = typename std::underlying_type::type; + using Scalar = typename std::underlying_type::type; + template using arithmetic_tag = std::is_same; + template enum_(const handle &scope, const char *name, const Extra&... extra) : class_(scope, name, extra...), m_parent(scope) { - auto entries = new std::unordered_map(); + + constexpr bool is_arithmetic = + !std::is_same, + void>::value; + + auto entries = new std::unordered_map(); def("__repr__", [name, entries](Type value) -> std::string { - auto it = entries->find((UnderlyingType) value); + auto it = entries->find((Scalar) value); return std::string(name) + "." + ((it == entries->end()) ? std::string("???") : std::string(it->second)); }); - def("__init__", [](Type& value, UnderlyingType i) { value = (Type)i; }); - def("__init__", [](Type& value, UnderlyingType i) { new (&value) Type((Type) i); }); - def("__int__", [](Type value) { return (UnderlyingType) value; }); + def("__init__", [](Type& value, Scalar i) { value = (Type)i; }); + def("__init__", [](Type& value, Scalar i) { new (&value) Type((Type) i); }); + def("__int__", [](Type value) { return (Scalar) value; }); def("__eq__", [](const Type &value, Type *value2) { return value2 && value == *value2; }); def("__ne__", [](const Type &value, Type *value2) { return !value2 || value != *value2; }); - def("__lt__", [](const Type &value, Type *value2) { return value2 && value < *value2; }); - def("__gt__", [](const Type &value, Type *value2) { return value2 && value > *value2; }); - def("__le__", [](const Type &value, Type *value2) { return value2 && value <= *value2; }); - def("__ge__", [](const Type &value, Type *value2) { return value2 && value >= *value2; }); - if (std::is_convertible::value) { + if (is_arithmetic) { + def("__lt__", [](const Type &value, Type *value2) { return value2 && value < *value2; }); + def("__gt__", [](const Type &value, Type *value2) { return value2 && value > *value2; }); + def("__le__", [](const Type &value, Type *value2) { return value2 && value <= *value2; }); + def("__ge__", [](const Type &value, Type *value2) { return value2 && value >= *value2; }); + } + if (std::is_convertible::value) { // Don't provide comparison with the underlying type if the enum isn't convertible, // i.e. if Type is a scoped enum, mirroring the C++ behaviour. (NB: we explicitly - // convert Type to UnderlyingType below anyway because this needs to compile). - def("__eq__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value == value2; }); - def("__ne__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value != value2; }); - def("__lt__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value < value2; }); - def("__gt__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value > value2; }); - def("__le__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value <= value2; }); - def("__ge__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value >= value2; }); - def("__invert__", [](const Type &value) { return ~((UnderlyingType) value); }); - def("__and__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value & value2; }); - def("__or__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value | value2; }); - def("__xor__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value ^ value2; }); - def("__rand__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value & value2; }); - def("__ror__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value | value2; }); - def("__rxor__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value ^ value2; }); - def("__and__", [](const Type &value, const Type &value2) { return (UnderlyingType) value & (UnderlyingType) value2; }); - def("__or__", [](const Type &value, const Type &value2) { return (UnderlyingType) value | (UnderlyingType) value2; }); - def("__xor__", [](const Type &value, const Type &value2) { return (UnderlyingType) value ^ (UnderlyingType) value2; }); + // convert Type to Scalar below anyway because this needs to compile). + def("__eq__", [](const Type &value, Scalar value2) { return (Scalar) value == value2; }); + def("__ne__", [](const Type &value, Scalar value2) { return (Scalar) value != value2; }); + if (is_arithmetic) { + def("__lt__", [](const Type &value, Scalar value2) { return (Scalar) value < value2; }); + def("__gt__", [](const Type &value, Scalar value2) { return (Scalar) value > value2; }); + def("__le__", [](const Type &value, Scalar value2) { return (Scalar) value <= value2; }); + def("__ge__", [](const Type &value, Scalar value2) { return (Scalar) value >= value2; }); + def("__invert__", [](const Type &value) { return ~((Scalar) value); }); + def("__and__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; }); + def("__or__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; }); + def("__xor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; }); + def("__rand__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; }); + def("__ror__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; }); + def("__rxor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; }); + def("__and__", [](const Type &value, const Type &value2) { return (Scalar) value & (Scalar) value2; }); + def("__or__", [](const Type &value, const Type &value2) { return (Scalar) value | (Scalar) value2; }); + def("__xor__", [](const Type &value, const Type &value2) { return (Scalar) value ^ (Scalar) value2; }); + } } - def("__hash__", [](const Type &value) { return (UnderlyingType) value; }); + def("__hash__", [](const Type &value) { return (Scalar) value; }); // Pickling and unpickling -- needed for use with the 'multiprocessing' module - def("__getstate__", [](const Type &value) { return pybind11::make_tuple((UnderlyingType) value); }); - def("__setstate__", [](Type &p, tuple t) { new (&p) Type((Type) t[0].cast()); }); + def("__getstate__", [](const Type &value) { return pybind11::make_tuple((Scalar) value); }); + def("__setstate__", [](Type &p, tuple t) { new (&p) Type((Type) t[0].cast()); }); m_entries = entries; } @@ -1249,11 +1260,11 @@ public: /// Add an enumeration entry enum_& value(char const* name, Type value) { this->attr(name) = pybind11::cast(value, return_value_policy::copy); - (*m_entries)[(UnderlyingType) value] = name; + (*m_entries)[(Scalar) value] = name; return *this; } private: - std::unordered_map *m_entries; + std::unordered_map *m_entries; handle m_parent; }; diff --git a/tests/test_enum.cpp b/tests/test_enum.cpp index 70a694a09..09f334cdb 100644 --- a/tests/test_enum.cpp +++ b/tests/test_enum.cpp @@ -44,22 +44,20 @@ std::string test_scoped_enum(ScopedEnum z) { test_initializer enums([](py::module &m) { m.def("test_scoped_enum", &test_scoped_enum); - py::enum_(m, "UnscopedEnum") + py::enum_(m, "UnscopedEnum", py::arithmetic()) .value("EOne", EOne) .value("ETwo", ETwo) .export_values(); - py::enum_(m, "ScopedEnum") + py::enum_(m, "ScopedEnum", py::arithmetic()) .value("Two", ScopedEnum::Two) - .value("Three", ScopedEnum::Three) - ; + .value("Three", ScopedEnum::Three); - py::enum_(m, "Flags") + py::enum_(m, "Flags", py::arithmetic()) .value("Read", Flags::Read) .value("Write", Flags::Write) .value("Execute", Flags::Execute) .export_values(); - ; py::class_ exenum_class(m, "ClassWithUnscopedEnum"); exenum_class.def_static("test_function", &ClassWithUnscopedEnum::test_function);