diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 370e52cff..285f3b18c 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1967,29 +1967,54 @@ struct iterator_state { }; // Note: these helpers take the iterator by non-const reference because some -// iterators in the wild can't be dereferenced when const. C++ needs the extra parens in decltype -// to enforce an lvalue. The & after Iterator is required for MSVC < 16.9. SFINAE cannot be -// reused for result_type due to bugs in ICC, NVCC, and PGI compilers. See PR #3293. -template ()))> +// iterators in the wild can't be dereferenced when const. The & after Iterator +// is required for MSVC < 16.9. SFINAE cannot be reused for result_type due to +// bugs in ICC, NVCC, and PGI compilers. See PR #3293. +template ())> struct iterator_access { - using result_type = decltype((*std::declval())); + using result_type = decltype(*std::declval()); // NOLINTNEXTLINE(readability-const-return-type) // PR #3263 result_type operator()(Iterator &it) const { return *it; } }; -template ()).first)) > -struct iterator_key_access { - using result_type = decltype(((*std::declval()).first)); +template ()).first) > +class iterator_key_access { +private: + using pair_type = decltype(*std::declval()); + +public: + /* If either the pair itself or the element of the pair is a reference, we + * want to return a reference, otherwise a value. When the decltype + * expression is parenthesized it is based on the value category of the + * expression; otherwise it is the declared type of the pair member. + * The use of declval in the second branch rather than directly + * using *std::declval() is a workaround for nvcc + * (it's not used in the first branch because going via decltype and back + * through declval does not perfectly preserve references). + */ + using result_type = conditional_t< + std::is_reference())>::value, + decltype(((*std::declval()).first)), + decltype(std::declval().first) + >; result_type operator()(Iterator &it) const { return (*it).first; } }; -template ()).second))> -struct iterator_value_access { - using result_type = decltype(((*std::declval()).second)); +template ()).second)> +class iterator_value_access { +private: + using pair_type = decltype(*std::declval()); + +public: + using result_type = conditional_t< + std::is_reference())>::value, + decltype(((*std::declval()).second)), + decltype(std::declval().second) + >; result_type operator()(Iterator &it) const { return (*it).second; } diff --git a/tests/test_sequences_and_iterators.cpp b/tests/test_sequences_and_iterators.cpp index 9de69338b..a378128ae 100644 --- a/tests/test_sequences_and_iterators.cpp +++ b/tests/test_sequences_and_iterators.cpp @@ -38,6 +38,17 @@ bool operator==(const NonZeroIterator>& it, const NonZeroSentine return !(*it).first || !(*it).second; } +/* Iterator where dereferencing returns prvalues instead of references. */ +template +class NonRefIterator { + const T* ptr_; +public: + explicit NonRefIterator(const T *ptr) : ptr_(ptr) {} + T operator*() const { return T(*ptr_); } + NonRefIterator& operator++() { ++ptr_; return *this; } + bool operator==(const NonRefIterator &other) const { return ptr_ == other.ptr_; } +}; + class NonCopyableInt { public: explicit NonCopyableInt(int value) : value_(value) {} @@ -331,7 +342,7 @@ TEST_SUBMODULE(sequences_and_iterators, m) { py::class_(m, "IntPairs") .def(py::init>>()) .def("nonzero", [](const IntPairs& s) { - return py::make_iterator(NonZeroIterator>(s.begin()), NonZeroSentinel()); + return py::make_iterator(NonZeroIterator>(s.begin()), NonZeroSentinel()); }, py::keep_alive<0, 1>()) .def("nonzero_keys", [](const IntPairs& s) { return py::make_key_iterator(NonZeroIterator>(s.begin()), NonZeroSentinel()); @@ -340,6 +351,20 @@ TEST_SUBMODULE(sequences_and_iterators, m) { return py::make_value_iterator(NonZeroIterator>(s.begin()), NonZeroSentinel()); }, py::keep_alive<0, 1>()) + // test iterator that returns values instead of references + .def("nonref", [](const IntPairs& s) { + return py::make_iterator(NonRefIterator>(s.begin()), + NonRefIterator>(s.end())); + }, py::keep_alive<0, 1>()) + .def("nonref_keys", [](const IntPairs& s) { + return py::make_key_iterator(NonRefIterator>(s.begin()), + NonRefIterator>(s.end())); + }, py::keep_alive<0, 1>()) + .def("nonref_values", [](const IntPairs& s) { + return py::make_value_iterator(NonRefIterator>(s.begin()), + NonRefIterator>(s.end())); + }, py::keep_alive<0, 1>()) + // test single-argument make_iterator .def("simple_iterator", [](IntPairs& self) { return py::make_iterator(self); diff --git a/tests/test_sequences_and_iterators.py b/tests/test_sequences_and_iterators.py index 38e2ab5b7..6985918a1 100644 --- a/tests/test_sequences_and_iterators.py +++ b/tests/test_sequences_and_iterators.py @@ -52,6 +52,13 @@ def test_generalized_iterators(): next(it) +def test_nonref_iterators(): + pairs = m.IntPairs([(1, 2), (3, 4), (0, 5)]) + assert list(pairs.nonref()) == [(1, 2), (3, 4), (0, 5)] + assert list(pairs.nonref_keys()) == [1, 3, 0] + assert list(pairs.nonref_values()) == [2, 4, 5] + + def test_generalized_iterators_simple(): assert list(m.IntPairs([(1, 2), (3, 4), (0, 5)]).simple_iterator()) == [ (1, 2),