Fix scoped enums comparison for equal/not equal cases (#1339) (#1571)

This commit is contained in:
Tarcísio Fischer 2018-10-24 06:18:58 -03:00 committed by Wenzel Jakob
parent 1377fbf73c
commit 54eb8193e5
2 changed files with 25 additions and 12 deletions

View File

@ -1426,11 +1426,11 @@ struct enum_base {
}), none(), none(), ""
);
#define PYBIND11_ENUM_OP_STRICT(op, expr) \
#define PYBIND11_ENUM_OP_STRICT(op, expr, strict_behavior) \
m_base.attr(op) = cpp_function( \
[](object a, object b) { \
if (!a.get_type().is(b.get_type())) \
throw type_error("Expected an enumeration of matching type!"); \
strict_behavior; \
return expr; \
}, \
is_method(m_base))
@ -1460,14 +1460,16 @@ struct enum_base {
PYBIND11_ENUM_OP_CONV("__rxor__", a ^ b);
}
} else {
PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)));
PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b)));
PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)), return false);
PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b)), return true);
if (is_arithmetic) {
PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b));
PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b));
PYBIND11_ENUM_OP_STRICT("__le__", int_(a) <= int_(b));
PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b));
#define THROW throw type_error("Expected an enumeration of matching type!");
PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b), THROW);
PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b), THROW);
PYBIND11_ENUM_OP_STRICT("__le__", int_(a) <= int_(b), THROW);
PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b), THROW);
#undef THROW
}
}

View File

@ -47,10 +47,12 @@ Members:
EOne : Docstring for EOne'''
# no TypeError exception for unscoped enum ==/!= int comparisons
# Unscoped enums will accept ==/!= int comparisons
y = m.UnscopedEnum.ETwo
assert y == 2
assert 2 == y
assert y != 3
assert 3 != y
assert int(m.UnscopedEnum.ETwo) == 2
assert str(m.UnscopedEnum(2)) == "UnscopedEnum.ETwo"
@ -75,11 +77,20 @@ def test_scoped_enum():
z = m.ScopedEnum.Two
assert m.test_scoped_enum(z) == "ScopedEnum::Two"
# expected TypeError exceptions for scoped enum ==/!= int comparisons
with pytest.raises(TypeError):
assert z == 2
with pytest.raises(TypeError):
# Scoped enums will *NOT* accept ==/!= int comparisons (Will always return False)
assert not z == 3
assert not 3 == z
assert z != 3
assert 3 != z
# Scoped enums will *NOT* accept >, <, >= and <= int comparisons (Will throw exceptions)
with pytest.raises(TypeError):
z > 3
with pytest.raises(TypeError):
z < 3
with pytest.raises(TypeError):
z >= 3
with pytest.raises(TypeError):
z <= 3
# order
assert m.ScopedEnum.Two < m.ScopedEnum.Three