diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 1a415a405..87f4645e6 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1047,6 +1047,13 @@ inline void call_operator_delete(void *p, size_t s, size_t a) { #endif } +inline void add_class_method(object& cls, const char *name_, const cpp_function &cf) { + cls.attr(cf.name()) = cf; + if (strcmp(name_, "__eq__") == 0 && !cls.attr("__dict__").contains("__hash__")) { + cls.attr("__hash__") = none(); + } +} + PYBIND11_NAMESPACE_END(detail) /// Given a pointer to a member function, cast it to its `Derived` version. @@ -1144,7 +1151,7 @@ public: class_ &def(const char *name_, Func&& f, const Extra&... extra) { cpp_function cf(method_adaptor(std::forward(f)), name(name_), is_method(*this), sibling(getattr(*this, name_, none())), extra...); - attr(cf.name()) = cf; + add_class_method(*this, name_, cf); return *this; } diff --git a/tests/test_operator_overloading.cpp b/tests/test_operator_overloading.cpp index 52fcd3383..f3c2eaafa 100644 --- a/tests/test_operator_overloading.cpp +++ b/tests/test_operator_overloading.cpp @@ -187,6 +187,38 @@ TEST_SUBMODULE(operators, m) { .def(py::self *= int()) .def_readwrite("b", &NestC::b); m.def("get_NestC", [](const NestC &c) { return c.value; }); + + + // test_overriding_eq_reset_hash + // #2191 Overriding __eq__ should set __hash__ to None + struct Comparable { + int value; + bool operator==(const Comparable& rhs) const {return value == rhs.value;} + }; + + struct Hashable : Comparable { + explicit Hashable(int value): Comparable{value}{}; + size_t hash() const { return static_cast(value); } + }; + + struct Hashable2 : Hashable { + using Hashable::Hashable; + }; + + py::class_(m, "Comparable") + .def(py::init()) + .def(py::self == py::self); + + py::class_(m, "Hashable") + .def(py::init()) + .def(py::self == py::self) + .def("__hash__", &Hashable::hash); + + // define __hash__ before __eq__ + py::class_(m, "Hashable2") + .def("__hash__", &Hashable::hash) + .def(py::init()) + .def(py::self == py::self); } #ifndef _MSC_VER diff --git a/tests/test_operator_overloading.py b/tests/test_operator_overloading.py index 6c927228c..39e3aee27 100644 --- a/tests/test_operator_overloading.py +++ b/tests/test_operator_overloading.py @@ -127,3 +127,19 @@ def test_nested(): assert abase.value == 42 del abase, b pytest.gc_collect() + + +def test_overriding_eq_reset_hash(): + + assert m.Comparable(15) is not m.Comparable(15) + assert m.Comparable(15) == m.Comparable(15) + + with pytest.raises(TypeError): + hash(m.Comparable(15)) # TypeError: unhashable type: 'm.Comparable' + + for hashable in (m.Hashable, m.Hashable2): + assert hashable(15) is not hashable(15) + assert hashable(15) == hashable(15) + + assert hash(hashable(15)) == 15 + assert hash(hashable(15)) == hash(hashable(15))