Extended enum support (#503)

* Allow enums to be ordered
* Support binary operators
This commit is contained in:
Pim Schellart 2016-11-16 11:28:11 -05:00 committed by Wenzel Jakob
parent 2e76daa53f
commit 90d27805b9
3 changed files with 79 additions and 0 deletions

View File

@ -1205,12 +1205,30 @@ public:
def("__int__", [](Type value) { return (UnderlyingType) value; });
def("__eq__", [](const Type &value, Type *value2) { return value2 && value == *value2; });
def("__ne__", [](const Type &value, Type *value2) { return !value2 || value != *value2; });
def("__lt__", [](const Type &value, Type *value2) { return value2 && value < *value2; });
def("__gt__", [](const Type &value, Type *value2) { return value2 && value > *value2; });
def("__le__", [](const Type &value, Type *value2) { return value2 && value <= *value2; });
def("__ge__", [](const Type &value, Type *value2) { return value2 && value >= *value2; });
if (std::is_convertible<Type, UnderlyingType>::value) {
// Don't provide comparison with the underlying type if the enum isn't convertible,
// i.e. if Type is a scoped enum, mirroring the C++ behaviour. (NB: we explicitly
// convert Type to UnderlyingType below anyway because this needs to compile).
def("__eq__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value == value2; });
def("__ne__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value != value2; });
def("__lt__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value < value2; });
def("__gt__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value > value2; });
def("__le__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value <= value2; });
def("__ge__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value >= value2; });
def("__invert__", [](const Type &value) { return ~((UnderlyingType) value); });
def("__and__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value & value2; });
def("__or__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value | value2; });
def("__xor__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value ^ value2; });
def("__rand__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value & value2; });
def("__ror__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value | value2; });
def("__rxor__", [](const Type &value, UnderlyingType value2) { return (UnderlyingType) value ^ value2; });
def("__and__", [](const Type &value, const Type &value2) { return (UnderlyingType) value & (UnderlyingType) value2; });
def("__or__", [](const Type &value, const Type &value2) { return (UnderlyingType) value | (UnderlyingType) value2; });
def("__xor__", [](const Type &value, const Type &value2) { return (UnderlyingType) value ^ (UnderlyingType) value2; });
}
def("__hash__", [](const Type &value) { return (UnderlyingType) value; });
// Pickling and unpickling -- needed for use with the 'multiprocessing' module

View File

@ -19,6 +19,12 @@ enum class ScopedEnum {
Three
};
enum Flags {
Read = 4,
Write = 2,
Execute = 1
};
class ClassWithUnscopedEnum {
public:
enum EMode {
@ -48,6 +54,13 @@ test_initializer enums([](py::module &m) {
.value("Three", ScopedEnum::Three)
;
py::enum_<Flags>(m, "Flags")
.value("Read", Flags::Read)
.value("Write", Flags::Write)
.value("Execute", Flags::Execute)
.export_values();
;
py::class_<ClassWithUnscopedEnum> exenum_class(m, "ClassWithUnscopedEnum");
exenum_class.def_static("test_function", &ClassWithUnscopedEnum::test_function);
py::enum_<ClassWithUnscopedEnum::EMode>(exenum_class, "EMode")

View File

@ -16,6 +16,24 @@ def test_unscoped_enum():
assert int(UnscopedEnum.ETwo) == 2
assert str(UnscopedEnum(2)) == "UnscopedEnum.ETwo"
# order
assert UnscopedEnum.EOne < UnscopedEnum.ETwo
assert UnscopedEnum.EOne < 2
assert UnscopedEnum.ETwo > UnscopedEnum.EOne
assert UnscopedEnum.ETwo > 1
assert UnscopedEnum.ETwo <= 2
assert UnscopedEnum.ETwo >= 2
assert UnscopedEnum.EOne <= UnscopedEnum.ETwo
assert UnscopedEnum.EOne <= 2
assert UnscopedEnum.ETwo >= UnscopedEnum.EOne
assert UnscopedEnum.ETwo >= 1
assert not (UnscopedEnum.ETwo < UnscopedEnum.EOne)
assert not (2 < UnscopedEnum.EOne)
def test_scoped_enum():
from pybind11_tests import ScopedEnum, test_scoped_enum
assert test_scoped_enum(ScopedEnum.Three) == "ScopedEnum::Three"
def test_scoped_enum():
from pybind11_tests import ScopedEnum, test_scoped_enum
@ -30,6 +48,13 @@ def test_scoped_enum():
with pytest.raises(TypeError):
assert z != 3
# order
assert ScopedEnum.Two < ScopedEnum.Three
assert ScopedEnum.Three > ScopedEnum.Two
assert ScopedEnum.Two <= ScopedEnum.Three
assert ScopedEnum.Two <= ScopedEnum.Two
assert ScopedEnum.Two >= ScopedEnum.Two
assert ScopedEnum.Three >= ScopedEnum.Two
def test_implicit_conversion():
from pybind11_tests import ClassWithUnscopedEnum
@ -61,3 +86,26 @@ def test_implicit_conversion():
x[f(second)] = 4
# Hashing test
assert str(x) == "{EMode.EFirstMode: 3, EMode.ESecondMode: 4}"
def test_binary_operators():
from pybind11_tests import Flags
assert int(Flags.Read) == 4
assert int(Flags.Write) == 2
assert int(Flags.Execute) == 1
assert int(Flags.Read | Flags.Write | Flags.Execute) == 7
assert int(Flags.Read | Flags.Write) == 6
assert int(Flags.Read | Flags.Execute) == 5
assert int(Flags.Write | Flags.Execute) == 3
assert int(Flags.Write | 1) == 3
state = Flags.Read | Flags.Write
assert (state & Flags.Read) != 0
assert (state & Flags.Write) != 0
assert (state & Flags.Execute) == 0
assert (state & 1) == 0
state2 = ~state
assert state2 == -7
assert int(state ^ state2) == -1