Fix py::make_iterator's __next__() for past-the-end calls

Fixes #896.

From Python docs: "Once an iterator’s `__next__()` method raises
`StopIteration`, it must continue to do so on subsequent calls.
Implementations that do not obey this property are deemed broken."
This commit is contained in:
Dean Moldovan 2017-06-09 16:49:04 +02:00
parent 17cc39c818
commit caedf74a89
2 changed files with 37 additions and 9 deletions

View File

@ -1353,7 +1353,7 @@ template <typename Iterator, typename Sentinel, bool KeyIterator, return_value_p
struct iterator_state {
Iterator it;
Sentinel end;
bool first;
bool first_or_done;
};
NAMESPACE_END(detail)
@ -1374,17 +1374,19 @@ iterator make_iterator(Iterator first, Sentinel last, Extra &&... extra) {
class_<state>(handle(), "iterator")
.def("__iter__", [](state &s) -> state& { return s; })
.def("__next__", [](state &s) -> ValueType {
if (!s.first)
if (!s.first_or_done)
++s.it;
else
s.first = false;
if (s.it == s.end)
s.first_or_done = false;
if (s.it == s.end) {
s.first_or_done = true;
throw stop_iteration();
}
return *s.it;
}, std::forward<Extra>(extra)..., Policy);
}
return (iterator) cast(state { first, last, true });
return cast(state{first, last, true});
}
/// Makes an python iterator over the keys (`.first`) of a iterator over pairs from a
@ -1401,17 +1403,19 @@ iterator make_key_iterator(Iterator first, Sentinel last, Extra &&... extra) {
class_<state>(handle(), "iterator")
.def("__iter__", [](state &s) -> state& { return s; })
.def("__next__", [](state &s) -> KeyType {
if (!s.first)
if (!s.first_or_done)
++s.it;
else
s.first = false;
if (s.it == s.end)
s.first_or_done = false;
if (s.it == s.end) {
s.first_or_done = true;
throw stop_iteration();
}
return (*s.it).first;
}, std::forward<Extra>(extra)..., Policy);
}
return (iterator) cast(state { first, last, true });
return cast(state{first, last, true});
}
/// Makes an iterator over values of an stl container or other container supporting

View File

@ -21,6 +21,17 @@ def test_generalized_iterators():
assert list(IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero_keys()) == [1]
assert list(IntPairs([(0, 3), (1, 2), (3, 4)]).nonzero_keys()) == []
# __next__ must continue to raise StopIteration
it = IntPairs([(0, 0)]).nonzero()
for _ in range(3):
with pytest.raises(StopIteration):
next(it)
it = IntPairs([(0, 0)]).nonzero_keys()
for _ in range(3):
with pytest.raises(StopIteration):
next(it)
def test_sequence():
from pybind11_tests import ConstructorStats
@ -45,6 +56,12 @@ def test_sequence():
rev2 = s[::-1]
assert cstats.values() == ['of size', '5']
it = iter(Sequence(0))
for _ in range(3): # __next__ must continue to raise StopIteration
with pytest.raises(StopIteration):
next(it)
assert cstats.values() == ['of size', '0']
expected = [0, 56.78, 0, 0, 12.34]
assert allclose(rev, expected)
assert allclose(rev2, expected)
@ -55,6 +72,8 @@ def test_sequence():
assert allclose(rev, [2, 56.78, 2, 0, 2])
assert cstats.alive() == 4
del it
assert cstats.alive() == 3
del s
assert cstats.alive() == 2
@ -90,6 +109,11 @@ def test_map_iterator():
for k, v in m.items():
assert v == expected[k]
it = iter(StringMap({}))
for _ in range(3): # __next__ must continue to raise StopIteration
with pytest.raises(StopIteration):
next(it)
def test_python_iterator_in_cpp():
import pybind11_tests.sequences_and_iterators as m