From b3573ac9615b705d6e5b9bf598a3378edf079f06 Mon Sep 17 00:00:00 2001 From: Bruce Merry <1963944+bmerry@users.noreply.github.com> Date: Fri, 1 Oct 2021 15:24:36 +0200 Subject: [PATCH] feat: add `.keys` and `.values` to bind_map (#3310) * Add `.keys` and `.values` to bind_map Both of these implement views (rather than just iterators), and `.items` is also upgraded to a view. In practical terms, this allows a view to be iterated multiple times and have its size taken, neither of which works with an iterator. The views implement `__len__`, `__iter__`, and the keys view implements `__contains__`. Testing membership also works in item and value views because Python falls back to iteration. This won't be optimal for item values since it's linear rather than O(log n) or O(1), but I didn't fancy trying to get all the corner cases to match Python behaviour (tuple of wrong types, wrong length tuple, not a tuple etc). Missing relative to Python dictionary views is `__reversed__` (only added to Python in 3.8). Implementing that could break code that binds custom map classes which don't provide `rbegin`/`rend` (at least without doing clever things with SFINAE), so I've not tried. The size increase on my system is 131072 bytes, which is rather large (5%) but also suspiciously round (2^17) and makes me suspect some quantisation effect. * bind_map: support any object in __contains__ Add extra overload of `__contains__` (for both the map itself and KeysView) which takes an arbitrary object and returns false. * Take py::object by const reference in __contains__ To keep clang-tidy happy. * Removing stray `py::` (detected via interactive testing in Google environment). Co-authored-by: Ralf W. Grosse-Kunstleve --- include/pybind11/stl_bind.h | 78 +++++++++++++++++++++++++++++++++++-- tests/test_stl_binders.py | 30 +++++++++++++- 2 files changed, 104 insertions(+), 4 deletions(-) diff --git a/include/pybind11/stl_bind.h b/include/pybind11/stl_bind.h index 82317b379..050be83cc 100644 --- a/include/pybind11/stl_bind.h +++ b/include/pybind11/stl_bind.h @@ -595,6 +595,23 @@ template auto map_if_insertion_operator(Class_ & ); } +template +struct keys_view +{ + Map ↦ +}; + +template +struct values_view +{ + Map ↦ +}; + +template +struct items_view +{ + Map ↦ +}; PYBIND11_NAMESPACE_END(detail) @@ -602,6 +619,9 @@ template , typename... class_ bind_map(handle scope, const std::string &name, Args&&... args) { using KeyType = typename Map::key_type; using MappedType = typename Map::mapped_type; + using KeysView = detail::keys_view; + using ValuesView = detail::values_view; + using ItemsView = detail::items_view; using Class_ = class_; // If either type is a non-module-local bound type then make the map binding non-local as well; @@ -615,6 +635,12 @@ class_ bind_map(handle scope, const std::string &name, Args&&. } Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward(args)...); + class_ keys_view( + scope, ("KeysView[" + name + "]").c_str(), pybind11::module_local(local)); + class_ values_view( + scope, ("ValuesView[" + name + "]").c_str(), pybind11::module_local(local)); + class_ items_view( + scope, ("ItemsView[" + name + "]").c_str(), pybind11::module_local(local)); cl.def(init<>()); @@ -628,12 +654,22 @@ class_ bind_map(handle scope, const std::string &name, Args&&. cl.def("__iter__", [](Map &m) { return make_key_iterator(m.begin(), m.end()); }, - keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ + keep_alive<0, 1>() /* Essential: keep map alive while iterator exists */ + ); + + cl.def("keys", + [](Map &m) { return KeysView{m}; }, + keep_alive<0, 1>() /* Essential: keep map alive while view exists */ + ); + + cl.def("values", + [](Map &m) { return ValuesView{m}; }, + keep_alive<0, 1>() /* Essential: keep map alive while view exists */ ); cl.def("items", - [](Map &m) { return make_iterator(m.begin(), m.end()); }, - keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ + [](Map &m) { return ItemsView{m}; }, + keep_alive<0, 1>() /* Essential: keep map alive while view exists */ ); cl.def("__getitem__", @@ -654,6 +690,8 @@ class_ bind_map(handle scope, const std::string &name, Args&&. return true; } ); + // Fallback for when the object is not of the key type + cl.def("__contains__", [](Map &, const object &) -> bool { return false; }); // Assignment provided only if the type is copyable detail::map_assignment(cl); @@ -669,6 +707,40 @@ class_ bind_map(handle scope, const std::string &name, Args&&. cl.def("__len__", &Map::size); + keys_view.def("__len__", [](KeysView &view) { return view.map.size(); }); + keys_view.def("__iter__", + [](KeysView &view) { + return make_key_iterator(view.map.begin(), view.map.end()); + }, + keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */ + ); + keys_view.def("__contains__", + [](KeysView &view, const KeyType &k) -> bool { + auto it = view.map.find(k); + if (it == view.map.end()) + return false; + return true; + } + ); + // Fallback for when the object is not of the key type + keys_view.def("__contains__", [](KeysView &, const object &) -> bool { return false; }); + + values_view.def("__len__", [](ValuesView &view) { return view.map.size(); }); + values_view.def("__iter__", + [](ValuesView &view) { + return make_value_iterator(view.map.begin(), view.map.end()); + }, + keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */ + ); + + items_view.def("__len__", [](ItemsView &view) { return view.map.size(); }); + items_view.def("__iter__", + [](ItemsView &view) { + return make_iterator(view.map.begin(), view.map.end()); + }, + keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */ + ); + return cl; } diff --git a/tests/test_stl_binders.py b/tests/test_stl_binders.py index 475a9ec40..a68dcd31d 100644 --- a/tests/test_stl_binders.py +++ b/tests/test_stl_binders.py @@ -160,15 +160,43 @@ def test_map_string_double(): mm["b"] = 2.5 assert list(mm) == ["a", "b"] - assert list(mm.items()) == [("a", 1), ("b", 2.5)] assert str(mm) == "MapStringDouble{a: 1, b: 2.5}" + assert "b" in mm + assert "c" not in mm + assert 123 not in mm + + # Check that keys, values, items are views, not merely iterable + keys = mm.keys() + values = mm.values() + items = mm.items() + assert list(keys) == ["a", "b"] + assert len(keys) == 2 + assert "a" in keys + assert "c" not in keys + assert 123 not in keys + assert list(items) == [("a", 1), ("b", 2.5)] + assert len(items) == 2 + assert ("b", 2.5) in items + assert "hello" not in items + assert ("b", 2.5, None) not in items + assert list(values) == [1, 2.5] + assert len(values) == 2 + assert 1 in values + assert 2 not in values + # Check that views update when the map is updated + mm["c"] = -1 + assert list(keys) == ["a", "b", "c"] + assert list(values) == [1, 2.5, -1] + assert list(items) == [("a", 1), ("b", 2.5), ("c", -1)] um = m.UnorderedMapStringDouble() um["ua"] = 1.1 um["ub"] = 2.6 assert sorted(list(um)) == ["ua", "ub"] + assert list(um.keys()) == list(um) assert sorted(list(um.items())) == [("ua", 1.1), ("ub", 2.6)] + assert list(zip(um.keys(), um.values())) == list(um.items()) assert "UnorderedMapStringDouble" in str(um)