diff --git a/example/example-constants-and-functions.py b/example/example-constants-and-functions.py index 607450f92..f9292ee9e 100755 --- a/example/example-constants-and-functions.py +++ b/example/example-constants-and-functions.py @@ -46,6 +46,22 @@ print("Inequality test 2: " + str( ExampleWithEnum.test_function(ExampleWithEnum.EFirstMode) != ExampleWithEnum.test_function(ExampleWithEnum.ESecondMode))) +print("Equality test 3: " + str( + ExampleWithEnum.test_function(ExampleWithEnum.EFirstMode) == + int(ExampleWithEnum.test_function(ExampleWithEnum.EFirstMode)))) + +print("Inequality test 3: " + str( + ExampleWithEnum.test_function(ExampleWithEnum.EFirstMode) != + int(ExampleWithEnum.test_function(ExampleWithEnum.EFirstMode)))) + +print("Equality test 4: " + str( + ExampleWithEnum.test_function(ExampleWithEnum.EFirstMode) == + int(ExampleWithEnum.test_function(ExampleWithEnum.ESecondMode)))) + +print("Inequality test 4: " + str( + ExampleWithEnum.test_function(ExampleWithEnum.EFirstMode) != + int(ExampleWithEnum.test_function(ExampleWithEnum.ESecondMode)))) + x = { ExampleWithEnum.test_function(ExampleWithEnum.EFirstMode): 1, ExampleWithEnum.test_function(ExampleWithEnum.ESecondMode): 2 diff --git a/example/example-constants-and-functions.ref b/example/example-constants-and-functions.ref index d2e1731aa..1d08223f8 100644 --- a/example/example-constants-and-functions.ref +++ b/example/example-constants-and-functions.ref @@ -30,6 +30,18 @@ ExampleWithEnum::test_function(enum=1) ExampleWithEnum::test_function(enum=2) Inequality test 2: True ExampleWithEnum::test_function(enum=1) +ExampleWithEnum::test_function(enum=1) +Equality test 3: True +ExampleWithEnum::test_function(enum=1) +ExampleWithEnum::test_function(enum=1) +Inequality test 3: False +ExampleWithEnum::test_function(enum=1) +ExampleWithEnum::test_function(enum=2) +Equality test 4: False +ExampleWithEnum::test_function(enum=1) +ExampleWithEnum::test_function(enum=2) +Inequality test 4: True +ExampleWithEnum::test_function(enum=1) ExampleWithEnum::test_function(enum=2) ExampleWithEnum::test_function(enum=1) ExampleWithEnum::test_function(enum=2) diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index a6255d8d7..975cf1440 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1004,22 +1004,25 @@ private: /// Binds C++ enumerations and enumeration classes to Python template class enum_ : public class_ { public: + using UnderlyingType = typename std::underlying_type::type; template enum_(const handle &scope, const char *name, const Extra&... extra) : class_(scope, name, extra...), m_parent(scope) { - auto entries = new std::unordered_map(); + auto entries = new std::unordered_map(); this->def("__repr__", [name, entries](Type value) -> std::string { - auto it = entries->find((int) value); + auto it = entries->find((UnderlyingType) value); return std::string(name) + "." + ((it == entries->end()) ? std::string("???") : std::string(it->second)); }); - this->def("__init__", [](Type& value, int i) { value = (Type)i; }); - this->def("__init__", [](Type& value, int i) { new (&value) Type((Type) i); }); - this->def("__int__", [](Type value) { return (int) value; }); + this->def("__init__", [](Type& value, UnderlyingType i) { value = (Type)i; }); + this->def("__init__", [](Type& value, UnderlyingType i) { new (&value) Type((Type) i); }); + this->def("__int__", [](Type value) { return (UnderlyingType) value; }); this->def("__eq__", [](const Type &value, Type *value2) { return value2 && value == *value2; }); + this->def("__eq__", [](const Type &value, UnderlyingType value2) { return value2 && value == value2; }); this->def("__ne__", [](const Type &value, Type *value2) { return !value2 || value != *value2; }); - this->def("__hash__", [](const Type &value) { return (int) value; }); + this->def("__ne__", [](const Type &value, UnderlyingType value2) { return value != value2; }); + this->def("__hash__", [](const Type &value) { return (UnderlyingType) value; }); m_entries = entries; } @@ -1036,11 +1039,11 @@ public: /// Add an enumeration entry enum_& value(char const* name, Type value) { this->attr(name) = pybind11::cast(value, return_value_policy::copy); - (*m_entries)[(int) value] = name; + (*m_entries)[(UnderlyingType) value] = name; return *this; } private: - std::unordered_map *m_entries; + std::unordered_map *m_entries; handle m_parent; };