diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index 3579da155..f09b5febb 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -602,6 +602,12 @@ NAMESPACE_END(detail) \endrst */ class iterator : public object { public: + using iterator_category = std::input_iterator_tag; + using difference_type = ssize_t; + using value_type = handle; + using reference = const handle; + using pointer = const handle *; + PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check) iterator& operator++() { @@ -615,7 +621,7 @@ public: return rv; } - handle operator*() const { + reference operator*() const { if (m_ptr && !value.ptr()) { auto& self = const_cast(*this); self.advance(); @@ -623,7 +629,7 @@ public: return value; } - const handle *operator->() const { operator*(); return &value; } + pointer operator->() const { operator*(); return &value; } /** \rst The value which marks the end of the iteration. ``it == iterator::sentinel()`` diff --git a/tests/test_sequences_and_iterators.cpp b/tests/test_sequences_and_iterators.cpp index 640188240..cda0af479 100644 --- a/tests/test_sequences_and_iterators.cpp +++ b/tests/test_sequences_and_iterators.cpp @@ -290,4 +290,14 @@ test_initializer sequences_and_iterators([](py::module &pm) { } return l; }); + + // Make sure that py::iterator works with std algorithms + m.def("count_none", [](py::object o) { + return std::count_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); }); + }); + + m.def("find_none", [](py::object o) { + auto it = std::find_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); }); + return it->is_none(); + }); }); diff --git a/tests/test_sequences_and_iterators.py b/tests/test_sequences_and_iterators.py index b340451d7..306664751 100644 --- a/tests/test_sequences_and_iterators.py +++ b/tests/test_sequences_and_iterators.py @@ -113,3 +113,7 @@ def test_python_iterator_in_cpp(): with pytest.raises(RuntimeError) as excinfo: m.iterator_to_list(iter(bad_next_call, None)) assert str(excinfo.value) == "py::iterator::advance() should propagate errors" + + l = [1, None, 0, None] + assert m.count_none(l) == 2 + assert m.find_none(l) is True