diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index 2d573dfad..f1dd009c2 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -1224,8 +1224,9 @@ public: detail::dict_iterator begin() const { return {*this, 0}; } detail::dict_iterator end() const { return {}; } void clear() const { PyDict_Clear(ptr()); } - bool contains(handle key) const { return PyDict_Contains(ptr(), key.ptr()) == 1; } - bool contains(const char *key) const { return PyDict_Contains(ptr(), pybind11::str(key).ptr()) == 1; } + template bool contains(T &&key) const { + return PyDict_Contains(m_ptr, detail::object_or_cast(std::forward(key)).ptr()) == 1; + } private: /// Call the `dict` Python type -- always returns a new reference @@ -1276,6 +1277,9 @@ public: return PySet_Add(m_ptr, detail::object_or_cast(std::forward(val)).ptr()) == 0; } void clear() const { PySet_Clear(m_ptr); } + template bool contains(T &&val) const { + return PySet_Contains(m_ptr, detail::object_or_cast(std::forward(val)).ptr()) == 1; + } }; class function : public object { diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index e6c955ff9..a8caca45c 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -37,6 +37,12 @@ TEST_SUBMODULE(pytypes, m) { for (auto item : set) py::print("key:", item); }); + m.def("set_contains", [](py::set set, py::object key) { + return set.contains(key); + }); + m.def("set_contains", [](py::set set, const char* key) { + return set.contains(key); + }); // test_dict m.def("get_dict", []() { return py::dict("key"_a="value"); }); @@ -49,6 +55,12 @@ TEST_SUBMODULE(pytypes, m) { auto d2 = py::dict("z"_a=3, **d1); return d2; }); + m.def("dict_contains", [](py::dict dict, py::object val) { + return dict.contains(val); + }); + m.def("dict_contains", [](py::dict dict, const char* val) { + return dict.contains(val); + }); // test_str m.def("str_from_string", []() { return py::str(std::string("baz")); }); diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index 0116d4ef2..a0364d6af 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -37,6 +37,10 @@ def test_set(capture, doc): key: key4 """ + assert not m.set_contains(set([]), 42) + assert m.set_contains({42}, 42) + assert m.set_contains({"foo"}, "foo") + assert doc(m.get_list) == "get_list() -> list" assert doc(m.print_list) == "print_list(arg0: list) -> None" @@ -53,6 +57,10 @@ def test_dict(capture, doc): key: key2, value=value2 """ + assert not m.dict_contains({}, 42) + assert m.dict_contains({42: None}, 42) + assert m.dict_contains({"foo": None}, "foo") + assert doc(m.get_dict) == "get_dict() -> dict" assert doc(m.print_dict) == "print_dict(arg0: dict) -> None"