Added set::contains and generalized dict::contains (#1884)

Dynamically resolving __contains__ on each call is wasteful since set
has a public PySet_Contains function.
This commit is contained in:
Sergei Lebedev 2019-08-16 12:32:27 -07:00 committed by Wenzel Jakob
parent 5b0ea77c62
commit 08b0bda4bc
3 changed files with 26 additions and 2 deletions

View File

@ -1224,8 +1224,9 @@ public:
detail::dict_iterator begin() const { return {*this, 0}; }
detail::dict_iterator end() const { return {}; }
void clear() const { PyDict_Clear(ptr()); }
bool contains(handle key) const { return PyDict_Contains(ptr(), key.ptr()) == 1; }
bool contains(const char *key) const { return PyDict_Contains(ptr(), pybind11::str(key).ptr()) == 1; }
template <typename T> bool contains(T &&key) const {
return PyDict_Contains(m_ptr, detail::object_or_cast(std::forward<T>(key)).ptr()) == 1;
}
private:
/// Call the `dict` Python type -- always returns a new reference
@ -1276,6 +1277,9 @@ public:
return PySet_Add(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()) == 0;
}
void clear() const { PySet_Clear(m_ptr); }
template <typename T> bool contains(T &&val) const {
return PySet_Contains(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()) == 1;
}
};
class function : public object {

View File

@ -37,6 +37,12 @@ TEST_SUBMODULE(pytypes, m) {
for (auto item : set)
py::print("key:", item);
});
m.def("set_contains", [](py::set set, py::object key) {
return set.contains(key);
});
m.def("set_contains", [](py::set set, const char* key) {
return set.contains(key);
});
// test_dict
m.def("get_dict", []() { return py::dict("key"_a="value"); });
@ -49,6 +55,12 @@ TEST_SUBMODULE(pytypes, m) {
auto d2 = py::dict("z"_a=3, **d1);
return d2;
});
m.def("dict_contains", [](py::dict dict, py::object val) {
return dict.contains(val);
});
m.def("dict_contains", [](py::dict dict, const char* val) {
return dict.contains(val);
});
// test_str
m.def("str_from_string", []() { return py::str(std::string("baz")); });

View File

@ -37,6 +37,10 @@ def test_set(capture, doc):
key: key4
"""
assert not m.set_contains(set([]), 42)
assert m.set_contains({42}, 42)
assert m.set_contains({"foo"}, "foo")
assert doc(m.get_list) == "get_list() -> list"
assert doc(m.print_list) == "print_list(arg0: list) -> None"
@ -53,6 +57,10 @@ def test_dict(capture, doc):
key: key2, value=value2
"""
assert not m.dict_contains({}, 42)
assert m.dict_contains({42: None}, 42)
assert m.dict_contains({"foo": None}, "foo")
assert doc(m.get_dict) == "get_dict() -> dict"
assert doc(m.print_dict) == "print_dict(arg0: dict) -> None"