Fix py::make_iterator's __next__() for past-the-end calls

Fixes #896.

From Python docs: "Once an iterator’s `__next__()` method raises
`StopIteration`, it must continue to do so on subsequent calls.
Implementations that do not obey this property are deemed broken."
This commit is contained in:
Dean Moldovan 2017-06-09 16:49:04 +02:00
parent 17cc39c818
commit caedf74a89
2 changed files with 37 additions and 9 deletions

View File

@ -1353,7 +1353,7 @@ template <typename Iterator, typename Sentinel, bool KeyIterator, return_value_p
struct iterator_state { struct iterator_state {
Iterator it; Iterator it;
Sentinel end; Sentinel end;
bool first; bool first_or_done;
}; };
NAMESPACE_END(detail) NAMESPACE_END(detail)
@ -1374,17 +1374,19 @@ iterator make_iterator(Iterator first, Sentinel last, Extra &&... extra) {
class_<state>(handle(), "iterator") class_<state>(handle(), "iterator")
.def("__iter__", [](state &s) -> state& { return s; }) .def("__iter__", [](state &s) -> state& { return s; })
.def("__next__", [](state &s) -> ValueType { .def("__next__", [](state &s) -> ValueType {
if (!s.first) if (!s.first_or_done)
++s.it; ++s.it;
else else
s.first = false; s.first_or_done = false;
if (s.it == s.end) if (s.it == s.end) {
s.first_or_done = true;
throw stop_iteration(); throw stop_iteration();
}
return *s.it; return *s.it;
}, std::forward<Extra>(extra)..., Policy); }, std::forward<Extra>(extra)..., Policy);
} }
return (iterator) cast(state { first, last, true }); return cast(state{first, last, true});
} }
/// Makes an python iterator over the keys (`.first`) of a iterator over pairs from a /// Makes an python iterator over the keys (`.first`) of a iterator over pairs from a
@ -1401,17 +1403,19 @@ iterator make_key_iterator(Iterator first, Sentinel last, Extra &&... extra) {
class_<state>(handle(), "iterator") class_<state>(handle(), "iterator")
.def("__iter__", [](state &s) -> state& { return s; }) .def("__iter__", [](state &s) -> state& { return s; })
.def("__next__", [](state &s) -> KeyType { .def("__next__", [](state &s) -> KeyType {
if (!s.first) if (!s.first_or_done)
++s.it; ++s.it;
else else
s.first = false; s.first_or_done = false;
if (s.it == s.end) if (s.it == s.end) {
s.first_or_done = true;
throw stop_iteration(); throw stop_iteration();
}
return (*s.it).first; return (*s.it).first;
}, std::forward<Extra>(extra)..., Policy); }, std::forward<Extra>(extra)..., Policy);
} }
return (iterator) cast(state { first, last, true }); return cast(state{first, last, true});
} }
/// Makes an iterator over values of an stl container or other container supporting /// Makes an iterator over values of an stl container or other container supporting

View File

@ -21,6 +21,17 @@ def test_generalized_iterators():
assert list(IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero_keys()) == [1] assert list(IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero_keys()) == [1]
assert list(IntPairs([(0, 3), (1, 2), (3, 4)]).nonzero_keys()) == [] assert list(IntPairs([(0, 3), (1, 2), (3, 4)]).nonzero_keys()) == []
# __next__ must continue to raise StopIteration
it = IntPairs([(0, 0)]).nonzero()
for _ in range(3):
with pytest.raises(StopIteration):
next(it)
it = IntPairs([(0, 0)]).nonzero_keys()
for _ in range(3):
with pytest.raises(StopIteration):
next(it)
def test_sequence(): def test_sequence():
from pybind11_tests import ConstructorStats from pybind11_tests import ConstructorStats
@ -45,6 +56,12 @@ def test_sequence():
rev2 = s[::-1] rev2 = s[::-1]
assert cstats.values() == ['of size', '5'] assert cstats.values() == ['of size', '5']
it = iter(Sequence(0))
for _ in range(3): # __next__ must continue to raise StopIteration
with pytest.raises(StopIteration):
next(it)
assert cstats.values() == ['of size', '0']
expected = [0, 56.78, 0, 0, 12.34] expected = [0, 56.78, 0, 0, 12.34]
assert allclose(rev, expected) assert allclose(rev, expected)
assert allclose(rev2, expected) assert allclose(rev2, expected)
@ -55,6 +72,8 @@ def test_sequence():
assert allclose(rev, [2, 56.78, 2, 0, 2]) assert allclose(rev, [2, 56.78, 2, 0, 2])
assert cstats.alive() == 4
del it
assert cstats.alive() == 3 assert cstats.alive() == 3
del s del s
assert cstats.alive() == 2 assert cstats.alive() == 2
@ -90,6 +109,11 @@ def test_map_iterator():
for k, v in m.items(): for k, v in m.items():
assert v == expected[k] assert v == expected[k]
it = iter(StringMap({}))
for _ in range(3): # __next__ must continue to raise StopIteration
with pytest.raises(StopIteration):
next(it)
def test_python_iterator_in_cpp(): def test_python_iterator_in_cpp():
import pybind11_tests.sequences_and_iterators as m import pybind11_tests.sequences_and_iterators as m