mirror of
https://github.com/pybind/pybind11.git
synced 2025-02-16 13:47:53 +00:00
Avoid conversion to int_
rhs argument of enum eq/ne (#1912)
* fix: Avoid conversion to `int_` rhs argument of enum eq/ne * test: compare unscoped enum with strings * suppress comparison to None warning * test unscoped enum arithmetic and comparision with unsupported type
This commit is contained in:
parent
f6c4c1047a
commit
09f0829401
@ -1469,9 +1469,17 @@ struct enum_base {
|
|||||||
}, \
|
}, \
|
||||||
is_method(m_base))
|
is_method(m_base))
|
||||||
|
|
||||||
|
#define PYBIND11_ENUM_OP_CONV_LHS(op, expr) \
|
||||||
|
m_base.attr(op) = cpp_function( \
|
||||||
|
[](object a_, object b) { \
|
||||||
|
int_ a(a_); \
|
||||||
|
return expr; \
|
||||||
|
}, \
|
||||||
|
is_method(m_base))
|
||||||
|
|
||||||
if (is_convertible) {
|
if (is_convertible) {
|
||||||
PYBIND11_ENUM_OP_CONV("__eq__", !b.is_none() && a.equal(b));
|
PYBIND11_ENUM_OP_CONV_LHS("__eq__", !b.is_none() && a.equal(b));
|
||||||
PYBIND11_ENUM_OP_CONV("__ne__", b.is_none() || !a.equal(b));
|
PYBIND11_ENUM_OP_CONV_LHS("__ne__", b.is_none() || !a.equal(b));
|
||||||
|
|
||||||
if (is_arithmetic) {
|
if (is_arithmetic) {
|
||||||
PYBIND11_ENUM_OP_CONV("__lt__", a < b);
|
PYBIND11_ENUM_OP_CONV("__lt__", a < b);
|
||||||
@ -1501,6 +1509,7 @@ struct enum_base {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#undef PYBIND11_ENUM_OP_CONV_LHS
|
||||||
#undef PYBIND11_ENUM_OP_CONV
|
#undef PYBIND11_ENUM_OP_CONV
|
||||||
#undef PYBIND11_ENUM_OP_STRICT
|
#undef PYBIND11_ENUM_OP_STRICT
|
||||||
|
|
||||||
|
@ -13,11 +13,13 @@ TEST_SUBMODULE(enums, m) {
|
|||||||
// test_unscoped_enum
|
// test_unscoped_enum
|
||||||
enum UnscopedEnum {
|
enum UnscopedEnum {
|
||||||
EOne = 1,
|
EOne = 1,
|
||||||
ETwo
|
ETwo,
|
||||||
|
EThree
|
||||||
};
|
};
|
||||||
py::enum_<UnscopedEnum>(m, "UnscopedEnum", py::arithmetic(), "An unscoped enumeration")
|
py::enum_<UnscopedEnum>(m, "UnscopedEnum", py::arithmetic(), "An unscoped enumeration")
|
||||||
.value("EOne", EOne, "Docstring for EOne")
|
.value("EOne", EOne, "Docstring for EOne")
|
||||||
.value("ETwo", ETwo, "Docstring for ETwo")
|
.value("ETwo", ETwo, "Docstring for ETwo")
|
||||||
|
.value("EThree", EThree, "Docstring for EThree")
|
||||||
.export_values();
|
.export_values();
|
||||||
|
|
||||||
// test_scoped_enum
|
// test_scoped_enum
|
||||||
|
@ -21,7 +21,7 @@ def test_unscoped_enum():
|
|||||||
|
|
||||||
# __members__ property
|
# __members__ property
|
||||||
assert m.UnscopedEnum.__members__ == \
|
assert m.UnscopedEnum.__members__ == \
|
||||||
{"EOne": m.UnscopedEnum.EOne, "ETwo": m.UnscopedEnum.ETwo}
|
{"EOne": m.UnscopedEnum.EOne, "ETwo": m.UnscopedEnum.ETwo, "EThree": m.UnscopedEnum.EThree}
|
||||||
# __members__ readonly
|
# __members__ readonly
|
||||||
with pytest.raises(AttributeError):
|
with pytest.raises(AttributeError):
|
||||||
m.UnscopedEnum.__members__ = {}
|
m.UnscopedEnum.__members__ = {}
|
||||||
@ -29,23 +29,18 @@ def test_unscoped_enum():
|
|||||||
foo = m.UnscopedEnum.__members__
|
foo = m.UnscopedEnum.__members__
|
||||||
foo["bar"] = "baz"
|
foo["bar"] = "baz"
|
||||||
assert m.UnscopedEnum.__members__ == \
|
assert m.UnscopedEnum.__members__ == \
|
||||||
{"EOne": m.UnscopedEnum.EOne, "ETwo": m.UnscopedEnum.ETwo}
|
{"EOne": m.UnscopedEnum.EOne, "ETwo": m.UnscopedEnum.ETwo, "EThree": m.UnscopedEnum.EThree}
|
||||||
|
|
||||||
assert m.UnscopedEnum.__doc__ == \
|
for docstring_line in '''An unscoped enumeration
|
||||||
'''An unscoped enumeration
|
|
||||||
|
|
||||||
Members:
|
Members:
|
||||||
|
|
||||||
EOne : Docstring for EOne
|
EOne : Docstring for EOne
|
||||||
|
|
||||||
ETwo : Docstring for ETwo''' or m.UnscopedEnum.__doc__ == \
|
|
||||||
'''An unscoped enumeration
|
|
||||||
|
|
||||||
Members:
|
|
||||||
|
|
||||||
ETwo : Docstring for ETwo
|
ETwo : Docstring for ETwo
|
||||||
|
|
||||||
EOne : Docstring for EOne'''
|
EThree : Docstring for EThree'''.split('\n'):
|
||||||
|
assert docstring_line in m.UnscopedEnum.__doc__
|
||||||
|
|
||||||
# Unscoped enums will accept ==/!= int comparisons
|
# Unscoped enums will accept ==/!= int comparisons
|
||||||
y = m.UnscopedEnum.ETwo
|
y = m.UnscopedEnum.ETwo
|
||||||
@ -53,6 +48,38 @@ Members:
|
|||||||
assert 2 == y
|
assert 2 == y
|
||||||
assert y != 3
|
assert y != 3
|
||||||
assert 3 != y
|
assert 3 != y
|
||||||
|
# Compare with None
|
||||||
|
assert (y != None) # noqa: E711
|
||||||
|
assert not (y == None) # noqa: E711
|
||||||
|
# Compare with an object
|
||||||
|
assert (y != object())
|
||||||
|
assert not (y == object())
|
||||||
|
# Compare with string
|
||||||
|
assert y != "2"
|
||||||
|
assert "2" != y
|
||||||
|
assert not ("2" == y)
|
||||||
|
assert not (y == "2")
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
y < object()
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
y <= object()
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
y > object()
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
y >= object()
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
y | object()
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
y & object()
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
y ^ object()
|
||||||
|
|
||||||
assert int(m.UnscopedEnum.ETwo) == 2
|
assert int(m.UnscopedEnum.ETwo) == 2
|
||||||
assert str(m.UnscopedEnum(2)) == "UnscopedEnum.ETwo"
|
assert str(m.UnscopedEnum(2)) == "UnscopedEnum.ETwo"
|
||||||
@ -71,6 +98,11 @@ Members:
|
|||||||
assert not (m.UnscopedEnum.ETwo < m.UnscopedEnum.EOne)
|
assert not (m.UnscopedEnum.ETwo < m.UnscopedEnum.EOne)
|
||||||
assert not (2 < m.UnscopedEnum.EOne)
|
assert not (2 < m.UnscopedEnum.EOne)
|
||||||
|
|
||||||
|
# arithmetic
|
||||||
|
assert m.UnscopedEnum.EOne & m.UnscopedEnum.EThree == m.UnscopedEnum.EOne
|
||||||
|
assert m.UnscopedEnum.EOne | m.UnscopedEnum.ETwo == m.UnscopedEnum.EThree
|
||||||
|
assert m.UnscopedEnum.EOne ^ m.UnscopedEnum.EThree == m.UnscopedEnum.ETwo
|
||||||
|
|
||||||
|
|
||||||
def test_scoped_enum():
|
def test_scoped_enum():
|
||||||
assert m.test_scoped_enum(m.ScopedEnum.Three) == "ScopedEnum::Three"
|
assert m.test_scoped_enum(m.ScopedEnum.Three) == "ScopedEnum::Three"
|
||||||
@ -82,6 +114,12 @@ def test_scoped_enum():
|
|||||||
assert not 3 == z
|
assert not 3 == z
|
||||||
assert z != 3
|
assert z != 3
|
||||||
assert 3 != z
|
assert 3 != z
|
||||||
|
# Compare with None
|
||||||
|
assert (z != None) # noqa: E711
|
||||||
|
assert not (z == None) # noqa: E711
|
||||||
|
# Compare with an object
|
||||||
|
assert (z != object())
|
||||||
|
assert not (z == object())
|
||||||
# Scoped enums will *NOT* accept >, <, >= and <= int comparisons (Will throw exceptions)
|
# Scoped enums will *NOT* accept >, <, >= and <= int comparisons (Will throw exceptions)
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
z > 3
|
z > 3
|
||||||
|
Loading…
Reference in New Issue
Block a user