Fix make_key_iterator/make_value_iterator for prvalue iterators (#3348)

* Add a test showing a flaw in make_key_iterator/make_value_iterator

If the iterator dereference operator returns a value rather than a
reference (and that pair also does not *contain* references),
make_key_iterator and make_value_iterator will return a reference to a
temporary, causing a segfault.

* Fix make_key_iterator/make_value_iterator for prvalue iterators

If an iterator returns a pair<T1, T2> rather than a reference to a pair
or a pair of references, make_key_iterator and make_value_iterator would
return a reference to a temporary, typically leading to a segfault. This
is because the value category of member access to a prvalue is an
xvalue, not a prvalue, so decltype produces an rvalue reference type.
Fix the type calculation to handle this case.

I also removed some decltype parentheses that weren't needed, either
because the expression isn't one of the special cases for decltype or
because decltype was only used for SFINAE. Hopefully that makes the code
a bit more readable.

Closes #3347

* Attempt a workaround for nvcc
This commit is contained in:
Bruce Merry 2021-10-11 17:35:39 +02:00 committed by GitHub
parent 750e38dcfd
commit 8a7c266d26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 69 additions and 12 deletions

View File

@ -1967,29 +1967,54 @@ struct iterator_state {
};
// Note: these helpers take the iterator by non-const reference because some
// iterators in the wild can't be dereferenced when const. C++ needs the extra parens in decltype
// to enforce an lvalue. The & after Iterator is required for MSVC < 16.9. SFINAE cannot be
// reused for result_type due to bugs in ICC, NVCC, and PGI compilers. See PR #3293.
template <typename Iterator, typename SFINAE = decltype((*std::declval<Iterator &>()))>
// iterators in the wild can't be dereferenced when const. The & after Iterator
// is required for MSVC < 16.9. SFINAE cannot be reused for result_type due to
// bugs in ICC, NVCC, and PGI compilers. See PR #3293.
template <typename Iterator, typename SFINAE = decltype(*std::declval<Iterator &>())>
struct iterator_access {
using result_type = decltype((*std::declval<Iterator &>()));
using result_type = decltype(*std::declval<Iterator &>());
// NOLINTNEXTLINE(readability-const-return-type) // PR #3263
result_type operator()(Iterator &it) const {
return *it;
}
};
template <typename Iterator, typename SFINAE = decltype(((*std::declval<Iterator &>()).first)) >
struct iterator_key_access {
using result_type = decltype(((*std::declval<Iterator &>()).first));
template <typename Iterator, typename SFINAE = decltype((*std::declval<Iterator &>()).first) >
class iterator_key_access {
private:
using pair_type = decltype(*std::declval<Iterator &>());
public:
/* If either the pair itself or the element of the pair is a reference, we
* want to return a reference, otherwise a value. When the decltype
* expression is parenthesized it is based on the value category of the
* expression; otherwise it is the declared type of the pair member.
* The use of declval<pair_type> in the second branch rather than directly
* using *std::declval<Iterator &>() is a workaround for nvcc
* (it's not used in the first branch because going via decltype and back
* through declval does not perfectly preserve references).
*/
using result_type = conditional_t<
std::is_reference<decltype(*std::declval<Iterator &>())>::value,
decltype(((*std::declval<Iterator &>()).first)),
decltype(std::declval<pair_type>().first)
>;
result_type operator()(Iterator &it) const {
return (*it).first;
}
};
template <typename Iterator, typename SFINAE = decltype(((*std::declval<Iterator &>()).second))>
struct iterator_value_access {
using result_type = decltype(((*std::declval<Iterator &>()).second));
template <typename Iterator, typename SFINAE = decltype((*std::declval<Iterator &>()).second)>
class iterator_value_access {
private:
using pair_type = decltype(*std::declval<Iterator &>());
public:
using result_type = conditional_t<
std::is_reference<decltype(*std::declval<Iterator &>())>::value,
decltype(((*std::declval<Iterator &>()).second)),
decltype(std::declval<pair_type>().second)
>;
result_type operator()(Iterator &it) const {
return (*it).second;
}

View File

@ -38,6 +38,17 @@ bool operator==(const NonZeroIterator<std::pair<A, B>>& it, const NonZeroSentine
return !(*it).first || !(*it).second;
}
/* Iterator where dereferencing returns prvalues instead of references. */
template<typename T>
class NonRefIterator {
const T* ptr_;
public:
explicit NonRefIterator(const T *ptr) : ptr_(ptr) {}
T operator*() const { return T(*ptr_); }
NonRefIterator& operator++() { ++ptr_; return *this; }
bool operator==(const NonRefIterator &other) const { return ptr_ == other.ptr_; }
};
class NonCopyableInt {
public:
explicit NonCopyableInt(int value) : value_(value) {}
@ -331,7 +342,7 @@ TEST_SUBMODULE(sequences_and_iterators, m) {
py::class_<IntPairs>(m, "IntPairs")
.def(py::init<std::vector<std::pair<int, int>>>())
.def("nonzero", [](const IntPairs& s) {
return py::make_iterator(NonZeroIterator<std::pair<int, int>>(s.begin()), NonZeroSentinel());
return py::make_iterator(NonZeroIterator<std::pair<int, int>>(s.begin()), NonZeroSentinel());
}, py::keep_alive<0, 1>())
.def("nonzero_keys", [](const IntPairs& s) {
return py::make_key_iterator(NonZeroIterator<std::pair<int, int>>(s.begin()), NonZeroSentinel());
@ -340,6 +351,20 @@ TEST_SUBMODULE(sequences_and_iterators, m) {
return py::make_value_iterator(NonZeroIterator<std::pair<int, int>>(s.begin()), NonZeroSentinel());
}, py::keep_alive<0, 1>())
// test iterator that returns values instead of references
.def("nonref", [](const IntPairs& s) {
return py::make_iterator(NonRefIterator<std::pair<int, int>>(s.begin()),
NonRefIterator<std::pair<int, int>>(s.end()));
}, py::keep_alive<0, 1>())
.def("nonref_keys", [](const IntPairs& s) {
return py::make_key_iterator(NonRefIterator<std::pair<int, int>>(s.begin()),
NonRefIterator<std::pair<int, int>>(s.end()));
}, py::keep_alive<0, 1>())
.def("nonref_values", [](const IntPairs& s) {
return py::make_value_iterator(NonRefIterator<std::pair<int, int>>(s.begin()),
NonRefIterator<std::pair<int, int>>(s.end()));
}, py::keep_alive<0, 1>())
// test single-argument make_iterator
.def("simple_iterator", [](IntPairs& self) {
return py::make_iterator(self);

View File

@ -52,6 +52,13 @@ def test_generalized_iterators():
next(it)
def test_nonref_iterators():
pairs = m.IntPairs([(1, 2), (3, 4), (0, 5)])
assert list(pairs.nonref()) == [(1, 2), (3, 4), (0, 5)]
assert list(pairs.nonref_keys()) == [1, 3, 0]
assert list(pairs.nonref_values()) == [2, 4, 5]
def test_generalized_iterators_simple():
assert list(m.IntPairs([(1, 2), (3, 4), (0, 5)]).simple_iterator()) == [
(1, 2),