diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index a7fe1898a..204aaa43a 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1469,9 +1469,17 @@ struct enum_base { }, \ is_method(m_base)) + #define PYBIND11_ENUM_OP_CONV_LHS(op, expr) \ + m_base.attr(op) = cpp_function( \ + [](object a_, object b) { \ + int_ a(a_); \ + return expr; \ + }, \ + is_method(m_base)) + if (is_convertible) { - PYBIND11_ENUM_OP_CONV("__eq__", !b.is_none() && a.equal(b)); - PYBIND11_ENUM_OP_CONV("__ne__", b.is_none() || !a.equal(b)); + PYBIND11_ENUM_OP_CONV_LHS("__eq__", !b.is_none() && a.equal(b)); + PYBIND11_ENUM_OP_CONV_LHS("__ne__", b.is_none() || !a.equal(b)); if (is_arithmetic) { PYBIND11_ENUM_OP_CONV("__lt__", a < b); @@ -1501,6 +1509,7 @@ struct enum_base { } } + #undef PYBIND11_ENUM_OP_CONV_LHS #undef PYBIND11_ENUM_OP_CONV #undef PYBIND11_ENUM_OP_STRICT diff --git a/tests/test_enum.cpp b/tests/test_enum.cpp index 498a00e16..315308920 100644 --- a/tests/test_enum.cpp +++ b/tests/test_enum.cpp @@ -13,11 +13,13 @@ TEST_SUBMODULE(enums, m) { // test_unscoped_enum enum UnscopedEnum { EOne = 1, - ETwo + ETwo, + EThree }; py::enum_(m, "UnscopedEnum", py::arithmetic(), "An unscoped enumeration") .value("EOne", EOne, "Docstring for EOne") .value("ETwo", ETwo, "Docstring for ETwo") + .value("EThree", EThree, "Docstring for EThree") .export_values(); // test_scoped_enum diff --git a/tests/test_enum.py b/tests/test_enum.py index 2f119a3a9..7fe9b618d 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -21,7 +21,7 @@ def test_unscoped_enum(): # __members__ property assert m.UnscopedEnum.__members__ == \ - {"EOne": m.UnscopedEnum.EOne, "ETwo": m.UnscopedEnum.ETwo} + {"EOne": m.UnscopedEnum.EOne, "ETwo": m.UnscopedEnum.ETwo, "EThree": m.UnscopedEnum.EThree} # __members__ readonly with pytest.raises(AttributeError): m.UnscopedEnum.__members__ = {} @@ -29,23 +29,18 @@ def test_unscoped_enum(): foo = m.UnscopedEnum.__members__ foo["bar"] = "baz" assert m.UnscopedEnum.__members__ == \ - {"EOne": m.UnscopedEnum.EOne, "ETwo": m.UnscopedEnum.ETwo} + {"EOne": m.UnscopedEnum.EOne, "ETwo": m.UnscopedEnum.ETwo, "EThree": m.UnscopedEnum.EThree} - assert m.UnscopedEnum.__doc__ == \ - '''An unscoped enumeration + for docstring_line in '''An unscoped enumeration Members: EOne : Docstring for EOne - ETwo : Docstring for ETwo''' or m.UnscopedEnum.__doc__ == \ - '''An unscoped enumeration - -Members: - ETwo : Docstring for ETwo - EOne : Docstring for EOne''' + EThree : Docstring for EThree'''.split('\n'): + assert docstring_line in m.UnscopedEnum.__doc__ # Unscoped enums will accept ==/!= int comparisons y = m.UnscopedEnum.ETwo @@ -53,6 +48,38 @@ Members: assert 2 == y assert y != 3 assert 3 != y + # Compare with None + assert (y != None) # noqa: E711 + assert not (y == None) # noqa: E711 + # Compare with an object + assert (y != object()) + assert not (y == object()) + # Compare with string + assert y != "2" + assert "2" != y + assert not ("2" == y) + assert not (y == "2") + + with pytest.raises(TypeError): + y < object() + + with pytest.raises(TypeError): + y <= object() + + with pytest.raises(TypeError): + y > object() + + with pytest.raises(TypeError): + y >= object() + + with pytest.raises(TypeError): + y | object() + + with pytest.raises(TypeError): + y & object() + + with pytest.raises(TypeError): + y ^ object() assert int(m.UnscopedEnum.ETwo) == 2 assert str(m.UnscopedEnum(2)) == "UnscopedEnum.ETwo" @@ -71,6 +98,11 @@ Members: assert not (m.UnscopedEnum.ETwo < m.UnscopedEnum.EOne) assert not (2 < m.UnscopedEnum.EOne) + # arithmetic + assert m.UnscopedEnum.EOne & m.UnscopedEnum.EThree == m.UnscopedEnum.EOne + assert m.UnscopedEnum.EOne | m.UnscopedEnum.ETwo == m.UnscopedEnum.EThree + assert m.UnscopedEnum.EOne ^ m.UnscopedEnum.EThree == m.UnscopedEnum.ETwo + def test_scoped_enum(): assert m.test_scoped_enum(m.ScopedEnum.Three) == "ScopedEnum::Three" @@ -82,6 +114,12 @@ def test_scoped_enum(): assert not 3 == z assert z != 3 assert 3 != z + # Compare with None + assert (z != None) # noqa: E711 + assert not (z == None) # noqa: E711 + # Compare with an object + assert (z != object()) + assert not (z == object()) # Scoped enums will *NOT* accept >, <, >= and <= int comparisons (Will throw exceptions) with pytest.raises(TypeError): z > 3