From f2907651fadc54c89492d4da03a120faa7b9260d Mon Sep 17 00:00:00 2001 From: Boris Dalstein Date: Sat, 12 Oct 2024 05:33:13 +0200 Subject: [PATCH] Fix #5399: iterator increment operator does not skip first item (#5400) * Fix #5399: iterator increment operator does not skip first item * Fix postfix increment operator: init() must be called before copying *this --- include/pybind11/pytypes.h | 20 +++++++++++++++----- tests/test_pytypes.cpp | 12 ++++++++++++ tests/test_pytypes.py | 5 +++++ 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index 7aafab6dc..027e36098 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -1470,11 +1470,17 @@ public: PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check) iterator &operator++() { + init(); advance(); return *this; } iterator operator++(int) { + // Note: We must call init() first so that rv.value is + // the same as this->value just before calling advance(). + // Otherwise, dereferencing the returned iterator may call + // advance() again and return the 3rd item instead of the 1st. + init(); auto rv = *this; advance(); return rv; @@ -1482,15 +1488,12 @@ public: // NOLINTNEXTLINE(readability-const-return-type) // PR #3263 reference operator*() const { - if (m_ptr && !value.ptr()) { - auto &self = const_cast(*this); - self.advance(); - } + init(); return value; } pointer operator->() const { - operator*(); + init(); return &value; } @@ -1513,6 +1516,13 @@ public: friend bool operator!=(const iterator &a, const iterator &b) { return a->ptr() != b->ptr(); } private: + void init() const { + if (m_ptr && !value.ptr()) { + auto &self = const_cast(*this); + self.advance(); + } + } + void advance() { value = reinterpret_steal(PyIter_Next(m_ptr)); if (value.ptr() == nullptr && PyErr_Occurred()) { diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index 19f65ce7e..8df4cdd3f 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -150,6 +150,18 @@ TEST_SUBMODULE(pytypes, m) { m.def("get_iterator", [] { return py::iterator(); }); // test_iterable m.def("get_iterable", [] { return py::iterable(); }); + m.def("get_first_item_from_iterable", [](const py::iterable &iter) { + // This tests the postfix increment operator + py::iterator it = iter.begin(); + py::iterator it2 = it++; + return *it2; + }); + m.def("get_second_item_from_iterable", [](const py::iterable &iter) { + // This tests the prefix increment operator + py::iterator it = iter.begin(); + ++it; + return *it; + }); m.def("get_frozenset_from_iterable", [](const py::iterable &iter) { return py::frozenset(iter); }); m.def("get_list_from_iterable", [](const py::iterable &iter) { return py::list(iter); }); diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index 39d0b619b..9fd24b34f 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -52,6 +52,11 @@ def test_from_iterable(pytype, from_iter_func): def test_iterable(doc): assert doc(m.get_iterable) == "get_iterable() -> Iterable" + lins = [1, 2, 3] + i = m.get_first_item_from_iterable(lins) + assert i == 1 + i = m.get_second_item_from_iterable(lins) + assert i == 2 def test_float(doc):