Merge branch 'master' into smart_holder

This commit is contained in:
Ralf W. Grosse-Kunstleve 2021-10-01 10:30:33 -07:00
commit b56b39d3b3
2 changed files with 104 additions and 4 deletions

View File

@ -595,6 +595,23 @@ template <typename Map, typename Class_> auto map_if_insertion_operator(Class_ &
); );
} }
template<typename Map>
struct keys_view
{
Map &map;
};
template<typename Map>
struct values_view
{
Map &map;
};
template<typename Map>
struct items_view
{
Map &map;
};
PYBIND11_NAMESPACE_END(detail) PYBIND11_NAMESPACE_END(detail)
@ -602,6 +619,9 @@ template <typename Map, typename holder_type = default_holder_type<Map>, typenam
class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args&&... args) { class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args&&... args) {
using KeyType = typename Map::key_type; using KeyType = typename Map::key_type;
using MappedType = typename Map::mapped_type; using MappedType = typename Map::mapped_type;
using KeysView = detail::keys_view<Map>;
using ValuesView = detail::values_view<Map>;
using ItemsView = detail::items_view<Map>;
using Class_ = class_<Map, holder_type>; using Class_ = class_<Map, holder_type>;
// If either type is a non-module-local bound type then make the map binding non-local as well; // If either type is a non-module-local bound type then make the map binding non-local as well;
@ -615,6 +635,12 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args&&.
} }
Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...); Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
class_<KeysView> keys_view(
scope, ("KeysView[" + name + "]").c_str(), pybind11::module_local(local));
class_<ValuesView> values_view(
scope, ("ValuesView[" + name + "]").c_str(), pybind11::module_local(local));
class_<ItemsView> items_view(
scope, ("ItemsView[" + name + "]").c_str(), pybind11::module_local(local));
cl.def(init<>()); cl.def(init<>());
@ -628,12 +654,22 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args&&.
cl.def("__iter__", cl.def("__iter__",
[](Map &m) { return make_key_iterator(m.begin(), m.end()); }, [](Map &m) { return make_key_iterator(m.begin(), m.end()); },
keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ keep_alive<0, 1>() /* Essential: keep map alive while iterator exists */
);
cl.def("keys",
[](Map &m) { return KeysView{m}; },
keep_alive<0, 1>() /* Essential: keep map alive while view exists */
);
cl.def("values",
[](Map &m) { return ValuesView{m}; },
keep_alive<0, 1>() /* Essential: keep map alive while view exists */
); );
cl.def("items", cl.def("items",
[](Map &m) { return make_iterator(m.begin(), m.end()); }, [](Map &m) { return ItemsView{m}; },
keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ keep_alive<0, 1>() /* Essential: keep map alive while view exists */
); );
cl.def("__getitem__", cl.def("__getitem__",
@ -654,6 +690,8 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args&&.
return true; return true;
} }
); );
// Fallback for when the object is not of the key type
cl.def("__contains__", [](Map &, const object &) -> bool { return false; });
// Assignment provided only if the type is copyable // Assignment provided only if the type is copyable
detail::map_assignment<Map, Class_>(cl); detail::map_assignment<Map, Class_>(cl);
@ -669,6 +707,40 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args&&.
cl.def("__len__", &Map::size); cl.def("__len__", &Map::size);
keys_view.def("__len__", [](KeysView &view) { return view.map.size(); });
keys_view.def("__iter__",
[](KeysView &view) {
return make_key_iterator(view.map.begin(), view.map.end());
},
keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */
);
keys_view.def("__contains__",
[](KeysView &view, const KeyType &k) -> bool {
auto it = view.map.find(k);
if (it == view.map.end())
return false;
return true;
}
);
// Fallback for when the object is not of the key type
keys_view.def("__contains__", [](KeysView &, const object &) -> bool { return false; });
values_view.def("__len__", [](ValuesView &view) { return view.map.size(); });
values_view.def("__iter__",
[](ValuesView &view) {
return make_value_iterator(view.map.begin(), view.map.end());
},
keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */
);
items_view.def("__len__", [](ItemsView &view) { return view.map.size(); });
items_view.def("__iter__",
[](ItemsView &view) {
return make_iterator(view.map.begin(), view.map.end());
},
keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */
);
return cl; return cl;
} }

View File

@ -160,15 +160,43 @@ def test_map_string_double():
mm["b"] = 2.5 mm["b"] = 2.5
assert list(mm) == ["a", "b"] assert list(mm) == ["a", "b"]
assert list(mm.items()) == [("a", 1), ("b", 2.5)]
assert str(mm) == "MapStringDouble{a: 1, b: 2.5}" assert str(mm) == "MapStringDouble{a: 1, b: 2.5}"
assert "b" in mm
assert "c" not in mm
assert 123 not in mm
# Check that keys, values, items are views, not merely iterable
keys = mm.keys()
values = mm.values()
items = mm.items()
assert list(keys) == ["a", "b"]
assert len(keys) == 2
assert "a" in keys
assert "c" not in keys
assert 123 not in keys
assert list(items) == [("a", 1), ("b", 2.5)]
assert len(items) == 2
assert ("b", 2.5) in items
assert "hello" not in items
assert ("b", 2.5, None) not in items
assert list(values) == [1, 2.5]
assert len(values) == 2
assert 1 in values
assert 2 not in values
# Check that views update when the map is updated
mm["c"] = -1
assert list(keys) == ["a", "b", "c"]
assert list(values) == [1, 2.5, -1]
assert list(items) == [("a", 1), ("b", 2.5), ("c", -1)]
um = m.UnorderedMapStringDouble() um = m.UnorderedMapStringDouble()
um["ua"] = 1.1 um["ua"] = 1.1
um["ub"] = 2.6 um["ub"] = 2.6
assert sorted(list(um)) == ["ua", "ub"] assert sorted(list(um)) == ["ua", "ub"]
assert list(um.keys()) == list(um)
assert sorted(list(um.items())) == [("ua", 1.1), ("ub", 2.6)] assert sorted(list(um.items())) == [("ua", 1.1), ("ub", 2.6)]
assert list(zip(um.keys(), um.values())) == list(um.items())
assert "UnorderedMapStringDouble" in str(um) assert "UnorderedMapStringDouble" in str(um)