From 90d27805b9d1be3f1a8ad4d55ed3b06bc7b3b976 Mon Sep 17 00:00:00 2001 From: Pim Schellart Date: Wed, 16 Nov 2016 11:28:11 -0500 Subject: [PATCH] Extended enum support (#503) * Allow enums to be ordered * Support binary operators --- include/pybind11/pybind11.h | 18 ++++++++++++++ tests/test_enum.cpp | 13 ++++++++++ tests/test_enum.py | 48 +++++++++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+) diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 4c5c830a9..ac33ab285 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1205,12 +1205,30 @@ public: def("__int__", [](Type value) { return (UnderlyingType) value; }); def("__eq__", [](const Type &value, Type *value2) { return value2 && value == *value2; }); def("__ne__", [](const Type &value, Type *value2) { return !value2 || value != *value2; }); + def("__lt__", [](const Type &value, Type *value2) { return value2 && value < *value2; }); + def("__gt__", [](const Type &value, Type *value2) { return value2 && value > *value2; }); + def("__le__", [](const Type &value, Type *value2) { return value2 && value <= *value2; }); + def("__ge__", [](const Type &value, Type *value2) { return value2 && value >= *value2; }); if (std::is_convertible::value) { // Don't provide comparison with the underlying type if the enum isn't convertible, // i.e. if Type is a scoped enum, mirroring the C++ behaviour. (NB: we explicitly // convert Type to UnderlyingType below anyway because this needs to compile). def("__eq__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value == value2; }); def("__ne__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value != value2; }); + def("__lt__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value < value2; }); + def("__gt__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value > value2; }); + def("__le__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value <= value2; }); + def("__ge__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value >= value2; }); + def("__invert__", [](const Type &value) { return ~((UnderlyingType) value); }); + def("__and__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value & value2; }); + def("__or__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value | value2; }); + def("__xor__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value ^ value2; }); + def("__rand__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value & value2; }); + def("__ror__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value | value2; }); + def("__rxor__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value ^ value2; }); + def("__and__", [](const Type &value, const Type &value2) { return (UnderlyingType) value & (UnderlyingType) value2; }); + def("__or__", [](const Type &value, const Type &value2) { return (UnderlyingType) value | (UnderlyingType) value2; }); + def("__xor__", [](const Type &value, const Type &value2) { return (UnderlyingType) value ^ (UnderlyingType) value2; }); } def("__hash__", [](const Type &value) { return (UnderlyingType) value; }); // Pickling and unpickling -- needed for use with the 'multiprocessing' module diff --git a/tests/test_enum.cpp b/tests/test_enum.cpp index 87cb7d0d4..70a694a09 100644 --- a/tests/test_enum.cpp +++ b/tests/test_enum.cpp @@ -19,6 +19,12 @@ enum class ScopedEnum { Three }; +enum Flags { + Read = 4, + Write = 2, + Execute = 1 +}; + class ClassWithUnscopedEnum { public: enum EMode { @@ -48,6 +54,13 @@ test_initializer enums([](py::module &m) { .value("Three", ScopedEnum::Three) ; + py::enum_(m, "Flags") + .value("Read", Flags::Read) + .value("Write", Flags::Write) + .value("Execute", Flags::Execute) + .export_values(); + ; + py::class_ exenum_class(m, "ClassWithUnscopedEnum"); exenum_class.def_static("test_function", &ClassWithUnscopedEnum::test_function); py::enum_(exenum_class, "EMode") diff --git a/tests/test_enum.py b/tests/test_enum.py index efabae7f7..5dfdbbeb4 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -16,6 +16,24 @@ def test_unscoped_enum(): assert int(UnscopedEnum.ETwo) == 2 assert str(UnscopedEnum(2)) == "UnscopedEnum.ETwo" + # order + assert UnscopedEnum.EOne < UnscopedEnum.ETwo + assert UnscopedEnum.EOne < 2 + assert UnscopedEnum.ETwo > UnscopedEnum.EOne + assert UnscopedEnum.ETwo > 1 + assert UnscopedEnum.ETwo <= 2 + assert UnscopedEnum.ETwo >= 2 + assert UnscopedEnum.EOne <= UnscopedEnum.ETwo + assert UnscopedEnum.EOne <= 2 + assert UnscopedEnum.ETwo >= UnscopedEnum.EOne + assert UnscopedEnum.ETwo >= 1 + assert not (UnscopedEnum.ETwo < UnscopedEnum.EOne) + assert not (2 < UnscopedEnum.EOne) + +def test_scoped_enum(): + from pybind11_tests import ScopedEnum, test_scoped_enum + + assert test_scoped_enum(ScopedEnum.Three) == "ScopedEnum::Three" def test_scoped_enum(): from pybind11_tests import ScopedEnum, test_scoped_enum @@ -30,6 +48,13 @@ def test_scoped_enum(): with pytest.raises(TypeError): assert z != 3 + # order + assert ScopedEnum.Two < ScopedEnum.Three + assert ScopedEnum.Three > ScopedEnum.Two + assert ScopedEnum.Two <= ScopedEnum.Three + assert ScopedEnum.Two <= ScopedEnum.Two + assert ScopedEnum.Two >= ScopedEnum.Two + assert ScopedEnum.Three >= ScopedEnum.Two def test_implicit_conversion(): from pybind11_tests import ClassWithUnscopedEnum @@ -61,3 +86,26 @@ def test_implicit_conversion(): x[f(second)] = 4 # Hashing test assert str(x) == "{EMode.EFirstMode: 3, EMode.ESecondMode: 4}" + +def test_binary_operators(): + from pybind11_tests import Flags + + assert int(Flags.Read) == 4 + assert int(Flags.Write) == 2 + assert int(Flags.Execute) == 1 + assert int(Flags.Read | Flags.Write | Flags.Execute) == 7 + assert int(Flags.Read | Flags.Write) == 6 + assert int(Flags.Read | Flags.Execute) == 5 + assert int(Flags.Write | Flags.Execute) == 3 + assert int(Flags.Write | 1) == 3 + + state = Flags.Read | Flags.Write + assert (state & Flags.Read) != 0 + assert (state & Flags.Write) != 0 + assert (state & Flags.Execute) == 0 + assert (state & 1) == 0 + + state2 = ~state + assert state2 == -7 + assert int(state ^ state2) == -1 +