mirror of
https://github.com/pybind/pybind11.git
synced 2024-12-01 17:37:15 +00:00
Handle all py::iterator errors
Before this, `py::iterator` didn't do any error handling, so code like: ```c++ for (auto item : py::int_(1)) { // ... } ``` would just silently skip the loop. The above now throws `TypeError` as expected. This is a breaking behavior change, but any code which relied on the silent skip was probably broken anyway. Also, errors returned by `PyIter_Next()` are now properly handled.
This commit is contained in:
parent
cecb577a19
commit
f7685826e2
@ -591,47 +591,66 @@ NAMESPACE_END(detail)
|
|||||||
|
|
||||||
/// \addtogroup pytypes
|
/// \addtogroup pytypes
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
/** \rst
|
||||||
|
Wraps a Python iterator so that it can also be used as a C++ input iterator
|
||||||
|
|
||||||
|
Caveat: copying an iterator does not (and cannot) clone the internal
|
||||||
|
state of the Python iterable. This also applies to the post-increment
|
||||||
|
operator. This iterator should only be used to retrieve the current
|
||||||
|
value using ``operator*()``.
|
||||||
|
\endrst */
|
||||||
class iterator : public object {
|
class iterator : public object {
|
||||||
public:
|
public:
|
||||||
/** Caveat: copying an iterator does not (and cannot) clone the internal
|
|
||||||
state of the Python iterable */
|
|
||||||
PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check)
|
PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check)
|
||||||
|
|
||||||
iterator& operator++() {
|
iterator& operator++() {
|
||||||
if (m_ptr)
|
|
||||||
advance();
|
advance();
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Caveat: this postincrement operator does not (and cannot) clone the
|
|
||||||
internal state of the Python iterable. It should only be used to
|
|
||||||
retrieve the current iterate using <tt>operator*()</tt> */
|
|
||||||
iterator operator++(int) {
|
iterator operator++(int) {
|
||||||
iterator rv(*this);
|
auto rv = *this;
|
||||||
rv.value = value;
|
|
||||||
if (m_ptr)
|
|
||||||
advance();
|
advance();
|
||||||
return rv;
|
return rv;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator==(const iterator &it) const { return *it == **this; }
|
|
||||||
bool operator!=(const iterator &it) const { return *it != **this; }
|
|
||||||
|
|
||||||
handle operator*() const {
|
handle operator*() const {
|
||||||
if (!ready && m_ptr) {
|
if (m_ptr && !value.ptr()) {
|
||||||
auto& self = const_cast<iterator &>(*this);
|
auto& self = const_cast<iterator &>(*this);
|
||||||
self.advance();
|
self.advance();
|
||||||
self.ready = true;
|
|
||||||
}
|
}
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handle *operator->() const { operator*(); return &value; }
|
||||||
|
|
||||||
|
/** \rst
|
||||||
|
The value which marks the end of the iteration. ``it == iterator::sentinel()``
|
||||||
|
is equivalent to catching ``StopIteration`` in Python.
|
||||||
|
|
||||||
|
.. code-block:: cpp
|
||||||
|
|
||||||
|
void foo(py::iterator it) {
|
||||||
|
while (it != py::iterator::sentinel()) {
|
||||||
|
// use `*it`
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
\endrst */
|
||||||
|
static iterator sentinel() { return {}; }
|
||||||
|
|
||||||
|
friend bool operator==(const iterator &a, const iterator &b) { return a->ptr() == b->ptr(); }
|
||||||
|
friend bool operator!=(const iterator &a, const iterator &b) { return a->ptr() != b->ptr(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void advance() { value = reinterpret_steal<object>(PyIter_Next(m_ptr)); }
|
void advance() {
|
||||||
|
value = reinterpret_steal<object>(PyIter_Next(m_ptr));
|
||||||
|
if (PyErr_Occurred()) { throw error_already_set(); }
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
object value = {};
|
object value = {};
|
||||||
bool ready = false;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class iterable : public object {
|
class iterable : public object {
|
||||||
@ -1032,15 +1051,17 @@ inline str repr(handle h) {
|
|||||||
#endif
|
#endif
|
||||||
return reinterpret_steal<str>(str_value);
|
return reinterpret_steal<str>(str_value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline iterator iter(handle obj) {
|
||||||
|
PyObject *result = PyObject_GetIter(obj.ptr());
|
||||||
|
if (!result) { throw error_already_set(); }
|
||||||
|
return reinterpret_steal<iterator>(result);
|
||||||
|
}
|
||||||
/// @} python_builtins
|
/// @} python_builtins
|
||||||
|
|
||||||
NAMESPACE_BEGIN(detail)
|
NAMESPACE_BEGIN(detail)
|
||||||
template <typename D> iterator object_api<D>::begin() const {
|
template <typename D> iterator object_api<D>::begin() const { return iter(derived()); }
|
||||||
return reinterpret_steal<iterator>(PyObject_GetIter(derived().ptr()));
|
template <typename D> iterator object_api<D>::end() const { return iterator::sentinel(); }
|
||||||
}
|
|
||||||
template <typename D> iterator object_api<D>::end() const {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
template <typename D> item_accessor object_api<D>::operator[](handle key) const {
|
template <typename D> item_accessor object_api<D>::operator[](handle key) const {
|
||||||
return {derived(), reinterpret_borrow<object>(key)};
|
return {derived(), reinterpret_borrow<object>(key)};
|
||||||
}
|
}
|
||||||
|
@ -169,7 +169,8 @@ bool operator==(const NonZeroIterator<std::pair<A, B>>& it, const NonZeroSentine
|
|||||||
return !(*it).first || !(*it).second;
|
return !(*it).first || !(*it).second;
|
||||||
}
|
}
|
||||||
|
|
||||||
test_initializer sequences_and_iterators([](py::module &m) {
|
test_initializer sequences_and_iterators([](py::module &pm) {
|
||||||
|
auto m = pm.def_submodule("sequences_and_iterators");
|
||||||
|
|
||||||
py::class_<Sequence> seq(m, "Sequence");
|
py::class_<Sequence> seq(m, "Sequence");
|
||||||
|
|
||||||
@ -272,4 +273,21 @@ test_initializer sequences_and_iterators([](py::module &m) {
|
|||||||
On the actual Sequence object, the iterator would be constructed as follows:
|
On the actual Sequence object, the iterator would be constructed as follows:
|
||||||
.def("__iter__", [](py::object s) { return PySequenceIterator(s.cast<const Sequence &>(), s); })
|
.def("__iter__", [](py::object s) { return PySequenceIterator(s.cast<const Sequence &>(), s); })
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
m.def("object_to_list", [](py::object o) {
|
||||||
|
auto l = py::list();
|
||||||
|
for (auto item : o) {
|
||||||
|
l.append(item);
|
||||||
|
}
|
||||||
|
return l;
|
||||||
|
});
|
||||||
|
|
||||||
|
m.def("iterator_to_list", [](py::iterator it) {
|
||||||
|
auto l = py::list();
|
||||||
|
while (it != py::iterator::sentinel()) {
|
||||||
|
l.append(*it);
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
return l;
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
@ -11,7 +11,7 @@ def allclose(a_list, b_list, rel_tol=1e-05, abs_tol=0.0):
|
|||||||
|
|
||||||
|
|
||||||
def test_generalized_iterators():
|
def test_generalized_iterators():
|
||||||
from pybind11_tests import IntPairs
|
from pybind11_tests.sequences_and_iterators import IntPairs
|
||||||
|
|
||||||
assert list(IntPairs([(1, 2), (3, 4), (0, 5)]).nonzero()) == [(1, 2), (3, 4)]
|
assert list(IntPairs([(1, 2), (3, 4), (0, 5)]).nonzero()) == [(1, 2), (3, 4)]
|
||||||
assert list(IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero()) == [(1, 2)]
|
assert list(IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero()) == [(1, 2)]
|
||||||
@ -23,7 +23,8 @@ def test_generalized_iterators():
|
|||||||
|
|
||||||
|
|
||||||
def test_sequence():
|
def test_sequence():
|
||||||
from pybind11_tests import Sequence, ConstructorStats
|
from pybind11_tests import ConstructorStats
|
||||||
|
from pybind11_tests.sequences_and_iterators import Sequence
|
||||||
|
|
||||||
cstats = ConstructorStats.get(Sequence)
|
cstats = ConstructorStats.get(Sequence)
|
||||||
|
|
||||||
@ -71,7 +72,7 @@ def test_sequence():
|
|||||||
|
|
||||||
|
|
||||||
def test_map_iterator():
|
def test_map_iterator():
|
||||||
from pybind11_tests import StringMap
|
from pybind11_tests.sequences_and_iterators import StringMap
|
||||||
|
|
||||||
m = StringMap({'hi': 'bye', 'black': 'white'})
|
m = StringMap({'hi': 'bye', 'black': 'white'})
|
||||||
assert m['hi'] == 'bye'
|
assert m['hi'] == 'bye'
|
||||||
@ -88,3 +89,27 @@ def test_map_iterator():
|
|||||||
assert m[k] == expected[k]
|
assert m[k] == expected[k]
|
||||||
for k, v in m.items():
|
for k, v in m.items():
|
||||||
assert v == expected[k]
|
assert v == expected[k]
|
||||||
|
|
||||||
|
|
||||||
|
def test_python_iterator_in_cpp():
|
||||||
|
import pybind11_tests.sequences_and_iterators as m
|
||||||
|
|
||||||
|
t = (1, 2, 3)
|
||||||
|
assert m.object_to_list(t) == [1, 2, 3]
|
||||||
|
assert m.object_to_list(iter(t)) == [1, 2, 3]
|
||||||
|
assert m.iterator_to_list(iter(t)) == [1, 2, 3]
|
||||||
|
|
||||||
|
with pytest.raises(TypeError) as excinfo:
|
||||||
|
m.object_to_list(1)
|
||||||
|
assert "object is not iterable" in str(excinfo.value)
|
||||||
|
|
||||||
|
with pytest.raises(TypeError) as excinfo:
|
||||||
|
m.iterator_to_list(1)
|
||||||
|
assert "incompatible function arguments" in str(excinfo.value)
|
||||||
|
|
||||||
|
def bad_next_call():
|
||||||
|
raise RuntimeError("py::iterator::advance() should propagate errors")
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
|
m.iterator_to_list(iter(bad_next_call, None))
|
||||||
|
assert str(excinfo.value) == "py::iterator::advance() should propagate errors"
|
||||||
|
Loading…
Reference in New Issue
Block a user