mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-25 14:45:12 +00:00
Handle all py::iterator errors
Before this, `py::iterator` didn't do any error handling, so code like: ```c++ for (auto item : py::int_(1)) { // ... } ``` would just silently skip the loop. The above now throws `TypeError` as expected. This is a breaking behavior change, but any code which relied on the silent skip was probably broken anyway. Also, errors returned by `PyIter_Next()` are now properly handled.
This commit is contained in:
parent
cecb577a19
commit
f7685826e2
@ -591,47 +591,66 @@ NAMESPACE_END(detail)
|
||||
|
||||
/// \addtogroup pytypes
|
||||
/// @{
|
||||
|
||||
/** \rst
|
||||
Wraps a Python iterator so that it can also be used as a C++ input iterator
|
||||
|
||||
Caveat: copying an iterator does not (and cannot) clone the internal
|
||||
state of the Python iterable. This also applies to the post-increment
|
||||
operator. This iterator should only be used to retrieve the current
|
||||
value using ``operator*()``.
|
||||
\endrst */
|
||||
class iterator : public object {
|
||||
public:
|
||||
/** Caveat: copying an iterator does not (and cannot) clone the internal
|
||||
state of the Python iterable */
|
||||
PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check)
|
||||
|
||||
iterator& operator++() {
|
||||
if (m_ptr)
|
||||
advance();
|
||||
advance();
|
||||
return *this;
|
||||
}
|
||||
|
||||
/** Caveat: this postincrement operator does not (and cannot) clone the
|
||||
internal state of the Python iterable. It should only be used to
|
||||
retrieve the current iterate using <tt>operator*()</tt> */
|
||||
iterator operator++(int) {
|
||||
iterator rv(*this);
|
||||
rv.value = value;
|
||||
if (m_ptr)
|
||||
advance();
|
||||
auto rv = *this;
|
||||
advance();
|
||||
return rv;
|
||||
}
|
||||
|
||||
bool operator==(const iterator &it) const { return *it == **this; }
|
||||
bool operator!=(const iterator &it) const { return *it != **this; }
|
||||
|
||||
handle operator*() const {
|
||||
if (!ready && m_ptr) {
|
||||
if (m_ptr && !value.ptr()) {
|
||||
auto& self = const_cast<iterator &>(*this);
|
||||
self.advance();
|
||||
self.ready = true;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
const handle *operator->() const { operator*(); return &value; }
|
||||
|
||||
/** \rst
|
||||
The value which marks the end of the iteration. ``it == iterator::sentinel()``
|
||||
is equivalent to catching ``StopIteration`` in Python.
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
void foo(py::iterator it) {
|
||||
while (it != py::iterator::sentinel()) {
|
||||
// use `*it`
|
||||
++it;
|
||||
}
|
||||
}
|
||||
\endrst */
|
||||
static iterator sentinel() { return {}; }
|
||||
|
||||
friend bool operator==(const iterator &a, const iterator &b) { return a->ptr() == b->ptr(); }
|
||||
friend bool operator!=(const iterator &a, const iterator &b) { return a->ptr() != b->ptr(); }
|
||||
|
||||
private:
|
||||
void advance() { value = reinterpret_steal<object>(PyIter_Next(m_ptr)); }
|
||||
void advance() {
|
||||
value = reinterpret_steal<object>(PyIter_Next(m_ptr));
|
||||
if (PyErr_Occurred()) { throw error_already_set(); }
|
||||
}
|
||||
|
||||
private:
|
||||
object value = {};
|
||||
bool ready = false;
|
||||
};
|
||||
|
||||
class iterable : public object {
|
||||
@ -1032,15 +1051,17 @@ inline str repr(handle h) {
|
||||
#endif
|
||||
return reinterpret_steal<str>(str_value);
|
||||
}
|
||||
|
||||
inline iterator iter(handle obj) {
|
||||
PyObject *result = PyObject_GetIter(obj.ptr());
|
||||
if (!result) { throw error_already_set(); }
|
||||
return reinterpret_steal<iterator>(result);
|
||||
}
|
||||
/// @} python_builtins
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
template <typename D> iterator object_api<D>::begin() const {
|
||||
return reinterpret_steal<iterator>(PyObject_GetIter(derived().ptr()));
|
||||
}
|
||||
template <typename D> iterator object_api<D>::end() const {
|
||||
return {};
|
||||
}
|
||||
template <typename D> iterator object_api<D>::begin() const { return iter(derived()); }
|
||||
template <typename D> iterator object_api<D>::end() const { return iterator::sentinel(); }
|
||||
template <typename D> item_accessor object_api<D>::operator[](handle key) const {
|
||||
return {derived(), reinterpret_borrow<object>(key)};
|
||||
}
|
||||
|
@ -169,7 +169,8 @@ bool operator==(const NonZeroIterator<std::pair<A, B>>& it, const NonZeroSentine
|
||||
return !(*it).first || !(*it).second;
|
||||
}
|
||||
|
||||
test_initializer sequences_and_iterators([](py::module &m) {
|
||||
test_initializer sequences_and_iterators([](py::module &pm) {
|
||||
auto m = pm.def_submodule("sequences_and_iterators");
|
||||
|
||||
py::class_<Sequence> seq(m, "Sequence");
|
||||
|
||||
@ -272,4 +273,21 @@ test_initializer sequences_and_iterators([](py::module &m) {
|
||||
On the actual Sequence object, the iterator would be constructed as follows:
|
||||
.def("__iter__", [](py::object s) { return PySequenceIterator(s.cast<const Sequence &>(), s); })
|
||||
#endif
|
||||
|
||||
m.def("object_to_list", [](py::object o) {
|
||||
auto l = py::list();
|
||||
for (auto item : o) {
|
||||
l.append(item);
|
||||
}
|
||||
return l;
|
||||
});
|
||||
|
||||
m.def("iterator_to_list", [](py::iterator it) {
|
||||
auto l = py::list();
|
||||
while (it != py::iterator::sentinel()) {
|
||||
l.append(*it);
|
||||
++it;
|
||||
}
|
||||
return l;
|
||||
});
|
||||
});
|
||||
|
@ -11,7 +11,7 @@ def allclose(a_list, b_list, rel_tol=1e-05, abs_tol=0.0):
|
||||
|
||||
|
||||
def test_generalized_iterators():
|
||||
from pybind11_tests import IntPairs
|
||||
from pybind11_tests.sequences_and_iterators import IntPairs
|
||||
|
||||
assert list(IntPairs([(1, 2), (3, 4), (0, 5)]).nonzero()) == [(1, 2), (3, 4)]
|
||||
assert list(IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero()) == [(1, 2)]
|
||||
@ -23,7 +23,8 @@ def test_generalized_iterators():
|
||||
|
||||
|
||||
def test_sequence():
|
||||
from pybind11_tests import Sequence, ConstructorStats
|
||||
from pybind11_tests import ConstructorStats
|
||||
from pybind11_tests.sequences_and_iterators import Sequence
|
||||
|
||||
cstats = ConstructorStats.get(Sequence)
|
||||
|
||||
@ -71,7 +72,7 @@ def test_sequence():
|
||||
|
||||
|
||||
def test_map_iterator():
|
||||
from pybind11_tests import StringMap
|
||||
from pybind11_tests.sequences_and_iterators import StringMap
|
||||
|
||||
m = StringMap({'hi': 'bye', 'black': 'white'})
|
||||
assert m['hi'] == 'bye'
|
||||
@ -88,3 +89,27 @@ def test_map_iterator():
|
||||
assert m[k] == expected[k]
|
||||
for k, v in m.items():
|
||||
assert v == expected[k]
|
||||
|
||||
|
||||
def test_python_iterator_in_cpp():
|
||||
import pybind11_tests.sequences_and_iterators as m
|
||||
|
||||
t = (1, 2, 3)
|
||||
assert m.object_to_list(t) == [1, 2, 3]
|
||||
assert m.object_to_list(iter(t)) == [1, 2, 3]
|
||||
assert m.iterator_to_list(iter(t)) == [1, 2, 3]
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
m.object_to_list(1)
|
||||
assert "object is not iterable" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
m.iterator_to_list(1)
|
||||
assert "incompatible function arguments" in str(excinfo.value)
|
||||
|
||||
def bad_next_call():
|
||||
raise RuntimeError("py::iterator::advance() should propagate errors")
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
m.iterator_to_list(iter(bad_next_call, None))
|
||||
assert str(excinfo.value) == "py::iterator::advance() should propagate errors"
|
||||
|
Loading…
Reference in New Issue
Block a user