diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index e44f8ac5b..3579da155 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -591,47 +591,66 @@ NAMESPACE_END(detail) /// \addtogroup pytypes /// @{ + +/** \rst + Wraps a Python iterator so that it can also be used as a C++ input iterator + + Caveat: copying an iterator does not (and cannot) clone the internal + state of the Python iterable. This also applies to the post-increment + operator. This iterator should only be used to retrieve the current + value using ``operator*()``. +\endrst */ class iterator : public object { public: - /** Caveat: copying an iterator does not (and cannot) clone the internal - state of the Python iterable */ PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check) iterator& operator++() { - if (m_ptr) - advance(); + advance(); return *this; } - /** Caveat: this postincrement operator does not (and cannot) clone the - internal state of the Python iterable. It should only be used to - retrieve the current iterate using operator*() */ iterator operator++(int) { - iterator rv(*this); - rv.value = value; - if (m_ptr) - advance(); + auto rv = *this; + advance(); return rv; } - bool operator==(const iterator &it) const { return *it == **this; } - bool operator!=(const iterator &it) const { return *it != **this; } - handle operator*() const { - if (!ready && m_ptr) { + if (m_ptr && !value.ptr()) { auto& self = const_cast(*this); self.advance(); - self.ready = true; } return value; } + const handle *operator->() const { operator*(); return &value; } + + /** \rst + The value which marks the end of the iteration. ``it == iterator::sentinel()`` + is equivalent to catching ``StopIteration`` in Python. + + .. code-block:: cpp + + void foo(py::iterator it) { + while (it != py::iterator::sentinel()) { + // use `*it` + ++it; + } + } + \endrst */ + static iterator sentinel() { return {}; } + + friend bool operator==(const iterator &a, const iterator &b) { return a->ptr() == b->ptr(); } + friend bool operator!=(const iterator &a, const iterator &b) { return a->ptr() != b->ptr(); } + private: - void advance() { value = reinterpret_steal(PyIter_Next(m_ptr)); } + void advance() { + value = reinterpret_steal(PyIter_Next(m_ptr)); + if (PyErr_Occurred()) { throw error_already_set(); } + } private: object value = {}; - bool ready = false; }; class iterable : public object { @@ -1032,15 +1051,17 @@ inline str repr(handle h) { #endif return reinterpret_steal(str_value); } + +inline iterator iter(handle obj) { + PyObject *result = PyObject_GetIter(obj.ptr()); + if (!result) { throw error_already_set(); } + return reinterpret_steal(result); +} /// @} python_builtins NAMESPACE_BEGIN(detail) -template iterator object_api::begin() const { - return reinterpret_steal(PyObject_GetIter(derived().ptr())); -} -template iterator object_api::end() const { - return {}; -} +template iterator object_api::begin() const { return iter(derived()); } +template iterator object_api::end() const { return iterator::sentinel(); } template item_accessor object_api::operator[](handle key) const { return {derived(), reinterpret_borrow(key)}; } diff --git a/tests/test_sequences_and_iterators.cpp b/tests/test_sequences_and_iterators.cpp index 323b4bf00..640188240 100644 --- a/tests/test_sequences_and_iterators.cpp +++ b/tests/test_sequences_and_iterators.cpp @@ -169,7 +169,8 @@ bool operator==(const NonZeroIterator>& it, const NonZeroSentine return !(*it).first || !(*it).second; } -test_initializer sequences_and_iterators([](py::module &m) { +test_initializer sequences_and_iterators([](py::module &pm) { + auto m = pm.def_submodule("sequences_and_iterators"); py::class_ seq(m, "Sequence"); @@ -272,4 +273,21 @@ test_initializer sequences_and_iterators([](py::module &m) { On the actual Sequence object, the iterator would be constructed as follows: .def("__iter__", [](py::object s) { return PySequenceIterator(s.cast(), s); }) #endif + + m.def("object_to_list", [](py::object o) { + auto l = py::list(); + for (auto item : o) { + l.append(item); + } + return l; + }); + + m.def("iterator_to_list", [](py::iterator it) { + auto l = py::list(); + while (it != py::iterator::sentinel()) { + l.append(*it); + ++it; + } + return l; + }); }); diff --git a/tests/test_sequences_and_iterators.py b/tests/test_sequences_and_iterators.py index 76b9f43f6..b340451d7 100644 --- a/tests/test_sequences_and_iterators.py +++ b/tests/test_sequences_and_iterators.py @@ -11,7 +11,7 @@ def allclose(a_list, b_list, rel_tol=1e-05, abs_tol=0.0): def test_generalized_iterators(): - from pybind11_tests import IntPairs + from pybind11_tests.sequences_and_iterators import IntPairs assert list(IntPairs([(1, 2), (3, 4), (0, 5)]).nonzero()) == [(1, 2), (3, 4)] assert list(IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero()) == [(1, 2)] @@ -23,7 +23,8 @@ def test_generalized_iterators(): def test_sequence(): - from pybind11_tests import Sequence, ConstructorStats + from pybind11_tests import ConstructorStats + from pybind11_tests.sequences_and_iterators import Sequence cstats = ConstructorStats.get(Sequence) @@ -71,7 +72,7 @@ def test_sequence(): def test_map_iterator(): - from pybind11_tests import StringMap + from pybind11_tests.sequences_and_iterators import StringMap m = StringMap({'hi': 'bye', 'black': 'white'}) assert m['hi'] == 'bye' @@ -88,3 +89,27 @@ def test_map_iterator(): assert m[k] == expected[k] for k, v in m.items(): assert v == expected[k] + + +def test_python_iterator_in_cpp(): + import pybind11_tests.sequences_and_iterators as m + + t = (1, 2, 3) + assert m.object_to_list(t) == [1, 2, 3] + assert m.object_to_list(iter(t)) == [1, 2, 3] + assert m.iterator_to_list(iter(t)) == [1, 2, 3] + + with pytest.raises(TypeError) as excinfo: + m.object_to_list(1) + assert "object is not iterable" in str(excinfo.value) + + with pytest.raises(TypeError) as excinfo: + m.iterator_to_list(1) + assert "incompatible function arguments" in str(excinfo.value) + + def bad_next_call(): + raise RuntimeError("py::iterator::advance() should propagate errors") + + with pytest.raises(RuntimeError) as excinfo: + m.iterator_to_list(iter(bad_next_call, None)) + assert str(excinfo.value) == "py::iterator::advance() should propagate errors"