Fix #5399: iterator increment operator does not skip first item (#5400)

* Fix #5399: iterator increment operator does not skip first item

* Fix postfix increment operator: init() must be called before copying *this
This commit is contained in:
Boris Dalstein 2024-10-12 05:33:13 +02:00 committed by GitHub
parent af67e87393
commit f2907651fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 32 additions and 5 deletions

View File

@ -1470,11 +1470,17 @@ public:
PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check) PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check)
iterator &operator++() { iterator &operator++() {
init();
advance(); advance();
return *this; return *this;
} }
iterator operator++(int) { iterator operator++(int) {
// Note: We must call init() first so that rv.value is
// the same as this->value just before calling advance().
// Otherwise, dereferencing the returned iterator may call
// advance() again and return the 3rd item instead of the 1st.
init();
auto rv = *this; auto rv = *this;
advance(); advance();
return rv; return rv;
@ -1482,15 +1488,12 @@ public:
// NOLINTNEXTLINE(readability-const-return-type) // PR #3263 // NOLINTNEXTLINE(readability-const-return-type) // PR #3263
reference operator*() const { reference operator*() const {
if (m_ptr && !value.ptr()) { init();
auto &self = const_cast<iterator &>(*this);
self.advance();
}
return value; return value;
} }
pointer operator->() const { pointer operator->() const {
operator*(); init();
return &value; return &value;
} }
@ -1513,6 +1516,13 @@ public:
friend bool operator!=(const iterator &a, const iterator &b) { return a->ptr() != b->ptr(); } friend bool operator!=(const iterator &a, const iterator &b) { return a->ptr() != b->ptr(); }
private: private:
void init() const {
if (m_ptr && !value.ptr()) {
auto &self = const_cast<iterator &>(*this);
self.advance();
}
}
void advance() { void advance() {
value = reinterpret_steal<object>(PyIter_Next(m_ptr)); value = reinterpret_steal<object>(PyIter_Next(m_ptr));
if (value.ptr() == nullptr && PyErr_Occurred()) { if (value.ptr() == nullptr && PyErr_Occurred()) {

View File

@ -150,6 +150,18 @@ TEST_SUBMODULE(pytypes, m) {
m.def("get_iterator", [] { return py::iterator(); }); m.def("get_iterator", [] { return py::iterator(); });
// test_iterable // test_iterable
m.def("get_iterable", [] { return py::iterable(); }); m.def("get_iterable", [] { return py::iterable(); });
m.def("get_first_item_from_iterable", [](const py::iterable &iter) {
// This tests the postfix increment operator
py::iterator it = iter.begin();
py::iterator it2 = it++;
return *it2;
});
m.def("get_second_item_from_iterable", [](const py::iterable &iter) {
// This tests the prefix increment operator
py::iterator it = iter.begin();
++it;
return *it;
});
m.def("get_frozenset_from_iterable", m.def("get_frozenset_from_iterable",
[](const py::iterable &iter) { return py::frozenset(iter); }); [](const py::iterable &iter) { return py::frozenset(iter); });
m.def("get_list_from_iterable", [](const py::iterable &iter) { return py::list(iter); }); m.def("get_list_from_iterable", [](const py::iterable &iter) { return py::list(iter); });

View File

@ -52,6 +52,11 @@ def test_from_iterable(pytype, from_iter_func):
def test_iterable(doc): def test_iterable(doc):
assert doc(m.get_iterable) == "get_iterable() -> Iterable" assert doc(m.get_iterable) == "get_iterable() -> Iterable"
lins = [1, 2, 3]
i = m.get_first_item_from_iterable(lins)
assert i == 1
i = m.get_second_item_from_iterable(lins)
assert i == 2
def test_float(doc): def test_float(doc):