diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 786e36f6a..6d25fa57d 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1353,7 +1353,7 @@ template (handle(), "iterator") .def("__iter__", [](state &s) -> state& { return s; }) .def("__next__", [](state &s) -> ValueType { - if (!s.first) + if (!s.first_or_done) ++s.it; else - s.first = false; - if (s.it == s.end) + s.first_or_done = false; + if (s.it == s.end) { + s.first_or_done = true; throw stop_iteration(); + } return *s.it; }, std::forward(extra)..., Policy); } - return (iterator) cast(state { first, last, true }); + return cast(state{first, last, true}); } /// Makes an python iterator over the keys (`.first`) of a iterator over pairs from a @@ -1401,17 +1403,19 @@ iterator make_key_iterator(Iterator first, Sentinel last, Extra &&... extra) { class_(handle(), "iterator") .def("__iter__", [](state &s) -> state& { return s; }) .def("__next__", [](state &s) -> KeyType { - if (!s.first) + if (!s.first_or_done) ++s.it; else - s.first = false; - if (s.it == s.end) + s.first_or_done = false; + if (s.it == s.end) { + s.first_or_done = true; throw stop_iteration(); + } return (*s.it).first; }, std::forward(extra)..., Policy); } - return (iterator) cast(state { first, last, true }); + return cast(state{first, last, true}); } /// Makes an iterator over values of an stl container or other container supporting diff --git a/tests/test_sequences_and_iterators.py b/tests/test_sequences_and_iterators.py index 30b6aaf4b..e04c579dd 100644 --- a/tests/test_sequences_and_iterators.py +++ b/tests/test_sequences_and_iterators.py @@ -21,6 +21,17 @@ def test_generalized_iterators(): assert list(IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero_keys()) == [1] assert list(IntPairs([(0, 3), (1, 2), (3, 4)]).nonzero_keys()) == [] + # __next__ must continue to raise StopIteration + it = IntPairs([(0, 0)]).nonzero() + for _ in range(3): + with pytest.raises(StopIteration): + next(it) + + it = IntPairs([(0, 0)]).nonzero_keys() + for _ in range(3): + with pytest.raises(StopIteration): + next(it) + def test_sequence(): from pybind11_tests import ConstructorStats @@ -45,6 +56,12 @@ def test_sequence(): rev2 = s[::-1] assert cstats.values() == ['of size', '5'] + it = iter(Sequence(0)) + for _ in range(3): # __next__ must continue to raise StopIteration + with pytest.raises(StopIteration): + next(it) + assert cstats.values() == ['of size', '0'] + expected = [0, 56.78, 0, 0, 12.34] assert allclose(rev, expected) assert allclose(rev2, expected) @@ -55,6 +72,8 @@ def test_sequence(): assert allclose(rev, [2, 56.78, 2, 0, 2]) + assert cstats.alive() == 4 + del it assert cstats.alive() == 3 del s assert cstats.alive() == 2 @@ -90,6 +109,11 @@ def test_map_iterator(): for k, v in m.items(): assert v == expected[k] + it = iter(StringMap({})) + for _ in range(3): # __next__ must continue to raise StopIteration + with pytest.raises(StopIteration): + next(it) + def test_python_iterator_in_cpp(): import pybind11_tests.sequences_and_iterators as m