Fix scoped enums comparison for equal/not equal cases (#1339) (#1571)

This commit is contained in:
Tarcísio Fischer 2018-10-24 06:18:58 -03:00 committed by Wenzel Jakob
parent 1377fbf73c
commit 54eb8193e5
2 changed files with 25 additions and 12 deletions

View File

@ -1426,11 +1426,11 @@ struct enum_base {
}), none(), none(), "" }), none(), none(), ""
); );
#define PYBIND11_ENUM_OP_STRICT(op, expr) \ #define PYBIND11_ENUM_OP_STRICT(op, expr, strict_behavior) \
m_base.attr(op) = cpp_function( \ m_base.attr(op) = cpp_function( \
[](object a, object b) { \ [](object a, object b) { \
if (!a.get_type().is(b.get_type())) \ if (!a.get_type().is(b.get_type())) \
throw type_error("Expected an enumeration of matching type!"); \ strict_behavior; \
return expr; \ return expr; \
}, \ }, \
is_method(m_base)) is_method(m_base))
@ -1460,14 +1460,16 @@ struct enum_base {
PYBIND11_ENUM_OP_CONV("__rxor__", a ^ b); PYBIND11_ENUM_OP_CONV("__rxor__", a ^ b);
} }
} else { } else {
PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b))); PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)), return false);
PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b))); PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b)), return true);
if (is_arithmetic) { if (is_arithmetic) {
PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b)); #define THROW throw type_error("Expected an enumeration of matching type!");
PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b)); PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b), THROW);
PYBIND11_ENUM_OP_STRICT("__le__", int_(a) <= int_(b)); PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b), THROW);
PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b)); PYBIND11_ENUM_OP_STRICT("__le__", int_(a) <= int_(b), THROW);
PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b), THROW);
#undef THROW
} }
} }

View File

@ -47,10 +47,12 @@ Members:
EOne : Docstring for EOne''' EOne : Docstring for EOne'''
# no TypeError exception for unscoped enum ==/!= int comparisons # Unscoped enums will accept ==/!= int comparisons
y = m.UnscopedEnum.ETwo y = m.UnscopedEnum.ETwo
assert y == 2 assert y == 2
assert 2 == y
assert y != 3 assert y != 3
assert 3 != y
assert int(m.UnscopedEnum.ETwo) == 2 assert int(m.UnscopedEnum.ETwo) == 2
assert str(m.UnscopedEnum(2)) == "UnscopedEnum.ETwo" assert str(m.UnscopedEnum(2)) == "UnscopedEnum.ETwo"
@ -75,11 +77,20 @@ def test_scoped_enum():
z = m.ScopedEnum.Two z = m.ScopedEnum.Two
assert m.test_scoped_enum(z) == "ScopedEnum::Two" assert m.test_scoped_enum(z) == "ScopedEnum::Two"
# expected TypeError exceptions for scoped enum ==/!= int comparisons # Scoped enums will *NOT* accept ==/!= int comparisons (Will always return False)
with pytest.raises(TypeError): assert not z == 3
assert z == 2 assert not 3 == z
with pytest.raises(TypeError):
assert z != 3 assert z != 3
assert 3 != z
# Scoped enums will *NOT* accept >, <, >= and <= int comparisons (Will throw exceptions)
with pytest.raises(TypeError):
z > 3
with pytest.raises(TypeError):
z < 3
with pytest.raises(TypeError):
z >= 3
with pytest.raises(TypeError):
z <= 3
# order # order
assert m.ScopedEnum.Two < m.ScopedEnum.Three assert m.ScopedEnum.Two < m.ScopedEnum.Three