Handle all py::iterator errors

Before this, `py::iterator` didn't do any error handling, so code like:
```c++
for (auto item : py::int_(1)) {
    // ...
}
```
would just silently skip the loop. The above now throws `TypeError` as
expected. This is a breaking behavior change, but any code which relied
on the silent skip was probably broken anyway.

Also, errors returned by `PyIter_Next()` are now properly handled.
This commit is contained in:
Dean Moldovan 2017-02-08 14:31:49 +01:00 committed by Wenzel Jakob
parent cecb577a19
commit f7685826e2
3 changed files with 92 additions and 28 deletions

View File

@ -591,47 +591,66 @@ NAMESPACE_END(detail)
/// \addtogroup pytypes /// \addtogroup pytypes
/// @{ /// @{
/** \rst
Wraps a Python iterator so that it can also be used as a C++ input iterator
Caveat: copying an iterator does not (and cannot) clone the internal
state of the Python iterable. This also applies to the post-increment
operator. This iterator should only be used to retrieve the current
value using ``operator*()``.
\endrst */
class iterator : public object { class iterator : public object {
public: public:
/** Caveat: copying an iterator does not (and cannot) clone the internal
state of the Python iterable */
PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check) PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check)
iterator& operator++() { iterator& operator++() {
if (m_ptr) advance();
advance();
return *this; return *this;
} }
/** Caveat: this postincrement operator does not (and cannot) clone the
internal state of the Python iterable. It should only be used to
retrieve the current iterate using <tt>operator*()</tt> */
iterator operator++(int) { iterator operator++(int) {
iterator rv(*this); auto rv = *this;
rv.value = value; advance();
if (m_ptr)
advance();
return rv; return rv;
} }
bool operator==(const iterator &it) const { return *it == **this; }
bool operator!=(const iterator &it) const { return *it != **this; }
handle operator*() const { handle operator*() const {
if (!ready && m_ptr) { if (m_ptr && !value.ptr()) {
auto& self = const_cast<iterator &>(*this); auto& self = const_cast<iterator &>(*this);
self.advance(); self.advance();
self.ready = true;
} }
return value; return value;
} }
const handle *operator->() const { operator*(); return &value; }
/** \rst
The value which marks the end of the iteration. ``it == iterator::sentinel()``
is equivalent to catching ``StopIteration`` in Python.
.. code-block:: cpp
void foo(py::iterator it) {
while (it != py::iterator::sentinel()) {
// use `*it`
++it;
}
}
\endrst */
static iterator sentinel() { return {}; }
friend bool operator==(const iterator &a, const iterator &b) { return a->ptr() == b->ptr(); }
friend bool operator!=(const iterator &a, const iterator &b) { return a->ptr() != b->ptr(); }
private: private:
void advance() { value = reinterpret_steal<object>(PyIter_Next(m_ptr)); } void advance() {
value = reinterpret_steal<object>(PyIter_Next(m_ptr));
if (PyErr_Occurred()) { throw error_already_set(); }
}
private: private:
object value = {}; object value = {};
bool ready = false;
}; };
class iterable : public object { class iterable : public object {
@ -1032,15 +1051,17 @@ inline str repr(handle h) {
#endif #endif
return reinterpret_steal<str>(str_value); return reinterpret_steal<str>(str_value);
} }
inline iterator iter(handle obj) {
PyObject *result = PyObject_GetIter(obj.ptr());
if (!result) { throw error_already_set(); }
return reinterpret_steal<iterator>(result);
}
/// @} python_builtins /// @} python_builtins
NAMESPACE_BEGIN(detail) NAMESPACE_BEGIN(detail)
template <typename D> iterator object_api<D>::begin() const { template <typename D> iterator object_api<D>::begin() const { return iter(derived()); }
return reinterpret_steal<iterator>(PyObject_GetIter(derived().ptr())); template <typename D> iterator object_api<D>::end() const { return iterator::sentinel(); }
}
template <typename D> iterator object_api<D>::end() const {
return {};
}
template <typename D> item_accessor object_api<D>::operator[](handle key) const { template <typename D> item_accessor object_api<D>::operator[](handle key) const {
return {derived(), reinterpret_borrow<object>(key)}; return {derived(), reinterpret_borrow<object>(key)};
} }

View File

@ -169,7 +169,8 @@ bool operator==(const NonZeroIterator<std::pair<A, B>>& it, const NonZeroSentine
return !(*it).first || !(*it).second; return !(*it).first || !(*it).second;
} }
test_initializer sequences_and_iterators([](py::module &m) { test_initializer sequences_and_iterators([](py::module &pm) {
auto m = pm.def_submodule("sequences_and_iterators");
py::class_<Sequence> seq(m, "Sequence"); py::class_<Sequence> seq(m, "Sequence");
@ -272,4 +273,21 @@ test_initializer sequences_and_iterators([](py::module &m) {
On the actual Sequence object, the iterator would be constructed as follows: On the actual Sequence object, the iterator would be constructed as follows:
.def("__iter__", [](py::object s) { return PySequenceIterator(s.cast<const Sequence &>(), s); }) .def("__iter__", [](py::object s) { return PySequenceIterator(s.cast<const Sequence &>(), s); })
#endif #endif
m.def("object_to_list", [](py::object o) {
auto l = py::list();
for (auto item : o) {
l.append(item);
}
return l;
});
m.def("iterator_to_list", [](py::iterator it) {
auto l = py::list();
while (it != py::iterator::sentinel()) {
l.append(*it);
++it;
}
return l;
});
}); });

View File

@ -11,7 +11,7 @@ def allclose(a_list, b_list, rel_tol=1e-05, abs_tol=0.0):
def test_generalized_iterators(): def test_generalized_iterators():
from pybind11_tests import IntPairs from pybind11_tests.sequences_and_iterators import IntPairs
assert list(IntPairs([(1, 2), (3, 4), (0, 5)]).nonzero()) == [(1, 2), (3, 4)] assert list(IntPairs([(1, 2), (3, 4), (0, 5)]).nonzero()) == [(1, 2), (3, 4)]
assert list(IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero()) == [(1, 2)] assert list(IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero()) == [(1, 2)]
@ -23,7 +23,8 @@ def test_generalized_iterators():
def test_sequence(): def test_sequence():
from pybind11_tests import Sequence, ConstructorStats from pybind11_tests import ConstructorStats
from pybind11_tests.sequences_and_iterators import Sequence
cstats = ConstructorStats.get(Sequence) cstats = ConstructorStats.get(Sequence)
@ -71,7 +72,7 @@ def test_sequence():
def test_map_iterator(): def test_map_iterator():
from pybind11_tests import StringMap from pybind11_tests.sequences_and_iterators import StringMap
m = StringMap({'hi': 'bye', 'black': 'white'}) m = StringMap({'hi': 'bye', 'black': 'white'})
assert m['hi'] == 'bye' assert m['hi'] == 'bye'
@ -88,3 +89,27 @@ def test_map_iterator():
assert m[k] == expected[k] assert m[k] == expected[k]
for k, v in m.items(): for k, v in m.items():
assert v == expected[k] assert v == expected[k]
def test_python_iterator_in_cpp():
import pybind11_tests.sequences_and_iterators as m
t = (1, 2, 3)
assert m.object_to_list(t) == [1, 2, 3]
assert m.object_to_list(iter(t)) == [1, 2, 3]
assert m.iterator_to_list(iter(t)) == [1, 2, 3]
with pytest.raises(TypeError) as excinfo:
m.object_to_list(1)
assert "object is not iterable" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
m.iterator_to_list(1)
assert "incompatible function arguments" in str(excinfo.value)
def bad_next_call():
raise RuntimeError("py::iterator::advance() should propagate errors")
with pytest.raises(RuntimeError) as excinfo:
m.iterator_to_list(iter(bad_next_call, None))
assert str(excinfo.value) == "py::iterator::advance() should propagate errors"