diff --git a/docs/changelog.rst b/docs/changelog.rst index 875b585dc..0485e925a 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -25,7 +25,8 @@ Breaking changes queued for v2.0.0 (Not yet released) * ``std::enable_shared_from_this<>`` now also works for ``const`` values * A return value policy can now be passed to ``handle::operator()`` * ``make_iterator()`` improvements for better compatibility with various types - (now uses prefix increment operator) + (now uses prefix increment operator); it now also accepts iterators with + different begin/end types as long as they are equality comparable. * ``arg()`` now accepts a wider range of argument types for default values * Added ``repr()`` method to the ``handle`` class. * Added support for registering structured dtypes via ``PYBIND11_NUMPY_DTYPE()`` macro. diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 675827de2..f9e840f16 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1119,8 +1119,10 @@ PYBIND11_NOINLINE inline void keep_alive_impl(int Nurse, int Patient, handle arg keep_alive_impl(nurse, patient); } -template struct iterator_state { - Iterator it, end; +template +struct iterator_state { + Iterator it; + Sentinel end; bool first; }; @@ -1129,10 +1131,11 @@ NAMESPACE_END(detail) template detail::init init() { return detail::init(); } template ()), typename... Extra> -iterator make_iterator(Iterator first, Iterator last, Extra &&... extra) { - typedef detail::iterator_state state; +iterator make_iterator(Iterator first, Sentinel last, Extra &&... extra) { + typedef detail::iterator_state state; if (!detail::get_type_info(typeid(state))) { class_(handle(), "") @@ -1151,10 +1154,11 @@ iterator make_iterator(Iterator first, Iterator last, Extra &&... extra) { return (iterator) cast(state { first, last, true }); } template ()->first), + typename Sentinel, + typename KeyType = decltype((*std::declval()).first), typename... Extra> -iterator make_key_iterator(Iterator first, Iterator last, Extra &&... extra) { - typedef detail::iterator_state state; +iterator make_key_iterator(Iterator first, Sentinel last, Extra &&... extra) { + typedef detail::iterator_state state; if (!detail::get_type_info(typeid(state))) { class_(handle(), "") @@ -1166,7 +1170,7 @@ iterator make_key_iterator(Iterator first, Iterator last, Extra &&... extra) { s.first = false; if (s.it == s.end) throw stop_iteration(); - return s.it->first; + return (*s.it).first; }, return_value_policy::reference_internal, std::forward(extra)...); } diff --git a/include/pybind11/stl_bind.h b/include/pybind11/stl_bind.h index a67f3ca0e..9434510e0 100644 --- a/include/pybind11/stl_bind.h +++ b/include/pybind11/stl_bind.h @@ -245,7 +245,7 @@ pybind11::class_, holder_type> bind_vector(pybind11::m cl.def("__iter__", [](Vector &v) { - return pybind11::make_iterator(v.begin(), v.end()); + return pybind11::make_iterator(v.begin(), v.end()); }, pybind11::keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ ); diff --git a/tests/test_sequences_and_iterators.cpp b/tests/test_sequences_and_iterators.cpp index a92c6bf62..39e342ba6 100644 --- a/tests/test_sequences_and_iterators.cpp +++ b/tests/test_sequences_and_iterators.cpp @@ -116,6 +116,15 @@ private: float *m_data; }; +class IntPairs { +public: + IntPairs(std::vector> data) : data_(std::move(data)) {} + const std::pair* begin() const { return data_.data(); } + +private: + std::vector> data_; +}; + // Interface of a map-like object that isn't (directly) an unordered_map, but provides some basic // map-like functionality. class StringMap { @@ -143,8 +152,24 @@ public: decltype(map.cend()) end() const { return map.cend(); } }; +template +class NonZeroIterator { + const T* ptr_; +public: + NonZeroIterator(const T* ptr) : ptr_(ptr) {} + const T& operator*() const { return *ptr_; } + NonZeroIterator& operator++() { ++ptr_; return *this; } +}; + +class NonZeroSentinel {}; + +template +bool operator==(const NonZeroIterator>& it, const NonZeroSentinel&) { + return !(*it).first || !(*it).second; +} void init_ex_sequences_and_iterators(py::module &m) { + py::class_ seq(m, "Sequence"); seq.def(py::init()) @@ -210,6 +235,15 @@ void init_ex_sequences_and_iterators(py::module &m) { py::keep_alive<0, 1>()) ; + py::class_(m, "IntPairs") + .def(py::init>>()) + .def("nonzero", [](const IntPairs& s) { + return py::make_iterator(NonZeroIterator>(s.begin()), NonZeroSentinel()); + }, py::keep_alive<0, 1>()) + .def("nonzero_keys", [](const IntPairs& s) { + return py::make_key_iterator(NonZeroIterator>(s.begin()), NonZeroSentinel()); + }, py::keep_alive<0, 1>()); + #if 0 // Obsolete: special data structure for exposing custom iterator types to python diff --git a/tests/test_sequences_and_iterators.py b/tests/test_sequences_and_iterators.py index a35dc584b..c83c4e57c 100644 --- a/tests/test_sequences_and_iterators.py +++ b/tests/test_sequences_and_iterators.py @@ -10,6 +10,18 @@ def allclose(a_list, b_list, rel_tol=1e-05, abs_tol=0.0): return all(isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) for a, b in zip(a_list, b_list)) +def test_generalized_iterators(): + from pybind11_tests import IntPairs + + assert list(IntPairs([(1, 2), (3, 4), (0, 5)]).nonzero()) == [(1, 2), (3, 4)] + assert list(IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero()) == [(1, 2)] + assert list(IntPairs([(0, 3), (1, 2), (3, 4)]).nonzero()) == [] + + assert list(IntPairs([(1, 2), (3, 4), (0, 5)]).nonzero_keys()) == [1, 3] + assert list(IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero_keys()) == [1] + assert list(IntPairs([(0, 3), (1, 2), (3, 4)]).nonzero_keys()) == [] + + def test_sequence(): from pybind11_tests import Sequence, ConstructorStats