From 8a7c266d26e5c9de3beb478bedddcee26f026b04 Mon Sep 17 00:00:00 2001 From: Bruce Merry <1963944+bmerry@users.noreply.github.com> Date: Mon, 11 Oct 2021 17:35:39 +0200 Subject: [PATCH] Fix make_key_iterator/make_value_iterator for prvalue iterators (#3348) * Add a test showing a flaw in make_key_iterator/make_value_iterator If the iterator dereference operator returns a value rather than a reference (and that pair also does not *contain* references), make_key_iterator and make_value_iterator will return a reference to a temporary, causing a segfault. * Fix make_key_iterator/make_value_iterator for prvalue iterators If an iterator returns a pair rather than a reference to a pair or a pair of references, make_key_iterator and make_value_iterator would return a reference to a temporary, typically leading to a segfault. This is because the value category of member access to a prvalue is an xvalue, not a prvalue, so decltype produces an rvalue reference type. Fix the type calculation to handle this case. I also removed some decltype parentheses that weren't needed, either because the expression isn't one of the special cases for decltype or because decltype was only used for SFINAE. Hopefully that makes the code a bit more readable. Closes #3347 * Attempt a workaround for nvcc --- include/pybind11/pybind11.h | 47 ++++++++++++++++++++------ tests/test_sequences_and_iterators.cpp | 27 ++++++++++++++- tests/test_sequences_and_iterators.py | 7 ++++ 3 files changed, 69 insertions(+), 12 deletions(-) diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 370e52cff..285f3b18c 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1967,29 +1967,54 @@ struct iterator_state { }; // Note: these helpers take the iterator by non-const reference because some -// iterators in the wild can't be dereferenced when const. C++ needs the extra parens in decltype -// to enforce an lvalue. The & after Iterator is required for MSVC < 16.9. SFINAE cannot be -// reused for result_type due to bugs in ICC, NVCC, and PGI compilers. See PR #3293. -template ()))> +// iterators in the wild can't be dereferenced when const. The & after Iterator +// is required for MSVC < 16.9. SFINAE cannot be reused for result_type due to +// bugs in ICC, NVCC, and PGI compilers. See PR #3293. +template ())> struct iterator_access { - using result_type = decltype((*std::declval())); + using result_type = decltype(*std::declval()); // NOLINTNEXTLINE(readability-const-return-type) // PR #3263 result_type operator()(Iterator &it) const { return *it; } }; -template ()).first)) > -struct iterator_key_access { - using result_type = decltype(((*std::declval()).first)); +template ()).first) > +class iterator_key_access { +private: + using pair_type = decltype(*std::declval()); + +public: + /* If either the pair itself or the element of the pair is a reference, we + * want to return a reference, otherwise a value. When the decltype + * expression is parenthesized it is based on the value category of the + * expression; otherwise it is the declared type of the pair member. + * The use of declval in the second branch rather than directly + * using *std::declval() is a workaround for nvcc + * (it's not used in the first branch because going via decltype and back + * through declval does not perfectly preserve references). + */ + using result_type = conditional_t< + std::is_reference())>::value, + decltype(((*std::declval()).first)), + decltype(std::declval().first) + >; result_type operator()(Iterator &it) const { return (*it).first; } }; -template ()).second))> -struct iterator_value_access { - using result_type = decltype(((*std::declval()).second)); +template ()).second)> +class iterator_value_access { +private: + using pair_type = decltype(*std::declval()); + +public: + using result_type = conditional_t< + std::is_reference())>::value, + decltype(((*std::declval()).second)), + decltype(std::declval().second) + >; result_type operator()(Iterator &it) const { return (*it).second; } diff --git a/tests/test_sequences_and_iterators.cpp b/tests/test_sequences_and_iterators.cpp index 9de69338b..a378128ae 100644 --- a/tests/test_sequences_and_iterators.cpp +++ b/tests/test_sequences_and_iterators.cpp @@ -38,6 +38,17 @@ bool operator==(const NonZeroIterator>& it, const NonZeroSentine return !(*it).first || !(*it).second; } +/* Iterator where dereferencing returns prvalues instead of references. */ +template +class NonRefIterator { + const T* ptr_; +public: + explicit NonRefIterator(const T *ptr) : ptr_(ptr) {} + T operator*() const { return T(*ptr_); } + NonRefIterator& operator++() { ++ptr_; return *this; } + bool operator==(const NonRefIterator &other) const { return ptr_ == other.ptr_; } +}; + class NonCopyableInt { public: explicit NonCopyableInt(int value) : value_(value) {} @@ -331,7 +342,7 @@ TEST_SUBMODULE(sequences_and_iterators, m) { py::class_(m, "IntPairs") .def(py::init>>()) .def("nonzero", [](const IntPairs& s) { - return py::make_iterator(NonZeroIterator>(s.begin()), NonZeroSentinel()); + return py::make_iterator(NonZeroIterator>(s.begin()), NonZeroSentinel()); }, py::keep_alive<0, 1>()) .def("nonzero_keys", [](const IntPairs& s) { return py::make_key_iterator(NonZeroIterator>(s.begin()), NonZeroSentinel()); @@ -340,6 +351,20 @@ TEST_SUBMODULE(sequences_and_iterators, m) { return py::make_value_iterator(NonZeroIterator>(s.begin()), NonZeroSentinel()); }, py::keep_alive<0, 1>()) + // test iterator that returns values instead of references + .def("nonref", [](const IntPairs& s) { + return py::make_iterator(NonRefIterator>(s.begin()), + NonRefIterator>(s.end())); + }, py::keep_alive<0, 1>()) + .def("nonref_keys", [](const IntPairs& s) { + return py::make_key_iterator(NonRefIterator>(s.begin()), + NonRefIterator>(s.end())); + }, py::keep_alive<0, 1>()) + .def("nonref_values", [](const IntPairs& s) { + return py::make_value_iterator(NonRefIterator>(s.begin()), + NonRefIterator>(s.end())); + }, py::keep_alive<0, 1>()) + // test single-argument make_iterator .def("simple_iterator", [](IntPairs& self) { return py::make_iterator(self); diff --git a/tests/test_sequences_and_iterators.py b/tests/test_sequences_and_iterators.py index 38e2ab5b7..6985918a1 100644 --- a/tests/test_sequences_and_iterators.py +++ b/tests/test_sequences_and_iterators.py @@ -52,6 +52,13 @@ def test_generalized_iterators(): next(it) +def test_nonref_iterators(): + pairs = m.IntPairs([(1, 2), (3, 4), (0, 5)]) + assert list(pairs.nonref()) == [(1, 2), (3, 4), (0, 5)] + assert list(pairs.nonref_keys()) == [1, 3, 0] + assert list(pairs.nonref_values()) == [2, 4, 5] + + def test_generalized_iterators_simple(): assert list(m.IntPairs([(1, 2), (3, 4), (0, 5)]).simple_iterator()) == [ (1, 2),