From 90d27805b9d1be3f1a8ad4d55ed3b06bc7b3b976 Mon Sep 17 00:00:00 2001
From: Pim Schellart
Date: Wed, 16 Nov 2016 11:28:11 -0500
Subject: [PATCH] Extended enum support (#503)
* Allow enums to be ordered
* Support binary operators
---
include/pybind11/pybind11.h | 18 ++++++++++++++
tests/test_enum.cpp | 13 ++++++++++
tests/test_enum.py | 48 +++++++++++++++++++++++++++++++++++++
3 files changed, 79 insertions(+)
diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h
index 4c5c830a9..ac33ab285 100644
--- a/include/pybind11/pybind11.h
+++ b/include/pybind11/pybind11.h
@@ -1205,12 +1205,30 @@ public:
def("__int__", [](Type value) { return (UnderlyingType) value; });
def("__eq__", [](const Type &value, Type *value2) { return value2 && value == *value2; });
def("__ne__", [](const Type &value, Type *value2) { return !value2 || value != *value2; });
+ def("__lt__", [](const Type &value, Type *value2) { return value2 && value < *value2; });
+ def("__gt__", [](const Type &value, Type *value2) { return value2 && value > *value2; });
+ def("__le__", [](const Type &value, Type *value2) { return value2 && value <= *value2; });
+ def("__ge__", [](const Type &value, Type *value2) { return value2 && value >= *value2; });
if (std::is_convertible::value) {
// Don't provide comparison with the underlying type if the enum isn't convertible,
// i.e. if Type is a scoped enum, mirroring the C++ behaviour. (NB: we explicitly
// convert Type to UnderlyingType below anyway because this needs to compile).
def("__eq__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value == value2; });
def("__ne__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value != value2; });
+ def("__lt__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value < value2; });
+ def("__gt__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value > value2; });
+ def("__le__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value <= value2; });
+ def("__ge__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value >= value2; });
+ def("__invert__", [](const Type &value) { return ~((UnderlyingType) value); });
+ def("__and__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value & value2; });
+ def("__or__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value | value2; });
+ def("__xor__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value ^ value2; });
+ def("__rand__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value & value2; });
+ def("__ror__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value | value2; });
+ def("__rxor__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value ^ value2; });
+ def("__and__", [](const Type &value, const Type &value2) { return (UnderlyingType) value & (UnderlyingType) value2; });
+ def("__or__", [](const Type &value, const Type &value2) { return (UnderlyingType) value | (UnderlyingType) value2; });
+ def("__xor__", [](const Type &value, const Type &value2) { return (UnderlyingType) value ^ (UnderlyingType) value2; });
}
def("__hash__", [](const Type &value) { return (UnderlyingType) value; });
// Pickling and unpickling -- needed for use with the 'multiprocessing' module
diff --git a/tests/test_enum.cpp b/tests/test_enum.cpp
index 87cb7d0d4..70a694a09 100644
--- a/tests/test_enum.cpp
+++ b/tests/test_enum.cpp
@@ -19,6 +19,12 @@ enum class ScopedEnum {
Three
};
+enum Flags {
+ Read = 4,
+ Write = 2,
+ Execute = 1
+};
+
class ClassWithUnscopedEnum {
public:
enum EMode {
@@ -48,6 +54,13 @@ test_initializer enums([](py::module &m) {
.value("Three", ScopedEnum::Three)
;
+ py::enum_(m, "Flags")
+ .value("Read", Flags::Read)
+ .value("Write", Flags::Write)
+ .value("Execute", Flags::Execute)
+ .export_values();
+ ;
+
py::class_ exenum_class(m, "ClassWithUnscopedEnum");
exenum_class.def_static("test_function", &ClassWithUnscopedEnum::test_function);
py::enum_(exenum_class, "EMode")
diff --git a/tests/test_enum.py b/tests/test_enum.py
index efabae7f7..5dfdbbeb4 100644
--- a/tests/test_enum.py
+++ b/tests/test_enum.py
@@ -16,6 +16,24 @@ def test_unscoped_enum():
assert int(UnscopedEnum.ETwo) == 2
assert str(UnscopedEnum(2)) == "UnscopedEnum.ETwo"
+ # order
+ assert UnscopedEnum.EOne < UnscopedEnum.ETwo
+ assert UnscopedEnum.EOne < 2
+ assert UnscopedEnum.ETwo > UnscopedEnum.EOne
+ assert UnscopedEnum.ETwo > 1
+ assert UnscopedEnum.ETwo <= 2
+ assert UnscopedEnum.ETwo >= 2
+ assert UnscopedEnum.EOne <= UnscopedEnum.ETwo
+ assert UnscopedEnum.EOne <= 2
+ assert UnscopedEnum.ETwo >= UnscopedEnum.EOne
+ assert UnscopedEnum.ETwo >= 1
+ assert not (UnscopedEnum.ETwo < UnscopedEnum.EOne)
+ assert not (2 < UnscopedEnum.EOne)
+
+def test_scoped_enum():
+ from pybind11_tests import ScopedEnum, test_scoped_enum
+
+ assert test_scoped_enum(ScopedEnum.Three) == "ScopedEnum::Three"
def test_scoped_enum():
from pybind11_tests import ScopedEnum, test_scoped_enum
@@ -30,6 +48,13 @@ def test_scoped_enum():
with pytest.raises(TypeError):
assert z != 3
+ # order
+ assert ScopedEnum.Two < ScopedEnum.Three
+ assert ScopedEnum.Three > ScopedEnum.Two
+ assert ScopedEnum.Two <= ScopedEnum.Three
+ assert ScopedEnum.Two <= ScopedEnum.Two
+ assert ScopedEnum.Two >= ScopedEnum.Two
+ assert ScopedEnum.Three >= ScopedEnum.Two
def test_implicit_conversion():
from pybind11_tests import ClassWithUnscopedEnum
@@ -61,3 +86,26 @@ def test_implicit_conversion():
x[f(second)] = 4
# Hashing test
assert str(x) == "{EMode.EFirstMode: 3, EMode.ESecondMode: 4}"
+
+def test_binary_operators():
+ from pybind11_tests import Flags
+
+ assert int(Flags.Read) == 4
+ assert int(Flags.Write) == 2
+ assert int(Flags.Execute) == 1
+ assert int(Flags.Read | Flags.Write | Flags.Execute) == 7
+ assert int(Flags.Read | Flags.Write) == 6
+ assert int(Flags.Read | Flags.Execute) == 5
+ assert int(Flags.Write | Flags.Execute) == 3
+ assert int(Flags.Write | 1) == 3
+
+ state = Flags.Read | Flags.Write
+ assert (state & Flags.Read) != 0
+ assert (state & Flags.Write) != 0
+ assert (state & Flags.Execute) == 0
+ assert (state & 1) == 0
+
+ state2 = ~state
+ assert state2 == -7
+ assert int(state ^ state2) == -1
+