mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-25 14:45:12 +00:00
8a7c266d26
* Add a test showing a flaw in make_key_iterator/make_value_iterator

  If the iterator dereference operator returns a value rather than a reference (and that pair also does not *contain* references), make_key_iterator and make_value_iterator will return a reference to a temporary, causing a segfault.

* Fix make_key_iterator/make_value_iterator for prvalue iterators

  If an iterator returns a pair<T1, T2> rather than a reference to a pair or a pair of references, make_key_iterator and make_value_iterator would return a reference to a temporary, typically leading to a segfault. This is because the value category of member access to a prvalue is an xvalue, not a prvalue, so decltype produces an rvalue reference type. Fix the type calculation to handle this case.

  I also removed some decltype parentheses that weren't needed, either because the expression isn't one of the special cases for decltype or because decltype was only used for SFINAE. Hopefully that makes the code a bit more readable.

  Closes #3347

* Attempt a workaround for nvcc
254 lines
7.9 KiB
Python
254 lines
7.9 KiB
Python
# -*- coding: utf-8 -*-
|
|
import pytest
|
|
|
|
from pybind11_tests import ConstructorStats
|
|
from pybind11_tests import sequences_and_iterators as m
|
|
|
|
|
|
def isclose(a, b, rel_tol=1e-05, abs_tol=0.0):
    """Return True when ``a`` and ``b`` are equal within the given tolerances.

    Uses the same formula as ``math.isclose()`` from Python 3.5: the allowed
    difference is the larger of ``abs_tol`` and ``rel_tol`` scaled by the
    bigger magnitude of the two inputs.
    """
    difference = abs(a - b)
    largest_magnitude = max(abs(a), abs(b))
    tolerance = max(rel_tol * largest_magnitude, abs_tol)
    return difference <= tolerance
|
|
|
|
|
|
def allclose(a_list, b_list, rel_tol=1e-05, abs_tol=0.0):
    """Return True when every paired element of the two iterables is close.

    Elements are paired with ``zip``, so the comparison ends with the shorter
    input; each pair is checked with ``isclose`` using the same tolerances.
    """
    for left, right in zip(a_list, b_list):
        if not isclose(left, right, rel_tol=rel_tol, abs_tol=abs_tol):
            return False
    return True
|
|
|
|
|
|
def test_slice_constructors():
    """Slices returned by the bound module must equal native Python slices."""
    forward = m.make_forward_slice_size_t()
    assert forward == slice(0, -1, 1)

    backward = m.make_reversed_slice_object()
    assert backward == slice(None, None, -1)
|
|
|
|
|
|
@pytest.mark.skipif(not m.has_optional, reason="no <optional>")
def test_slice_constructors_explicit_optional():
    """Both optional-based slice factories must yield a fully reversed slice."""
    fully_reversed = slice(None, None, -1)

    plain = m.make_reversed_slice_size_t_optional()
    assert plain == fully_reversed

    verbose = m.make_reversed_slice_size_t_optional_verbose()
    assert verbose == fully_reversed
|
|
|
|
|
|
def test_generalized_iterators():
    """Exercise the ``nonzero*`` iterators of ``IntPairs``.

    The expectations below show each iterator yielding items only up to the
    first pair that contains a zero, and raising ``StopIteration`` on every
    further ``next()`` call once exhausted.
    """
    inputs = (
        [(1, 2), (3, 4), (0, 5)],
        [(1, 2), (2, 0), (0, 3), (4, 5)],
        [(0, 3), (1, 2), (3, 4)],
    )
    expected_pairs = ([(1, 2), (3, 4)], [(1, 2)], [])
    expected_keys = ([1, 3], [1], [])
    expected_values = ([2, 4], [2], [])

    for data, wanted in zip(inputs, expected_pairs):
        assert list(m.IntPairs(data).nonzero()) == wanted

    for data, wanted in zip(inputs, expected_keys):
        assert list(m.IntPairs(data).nonzero_keys()) == wanted

    for data, wanted in zip(inputs, expected_values):
        assert list(m.IntPairs(data).nonzero_values()) == wanted

    # __next__ must continue to raise StopIteration
    for accessor in ("nonzero", "nonzero_keys"):
        it = getattr(m.IntPairs([(0, 0)]), accessor)()
        for _ in range(3):
            with pytest.raises(StopIteration):
                next(it)
|
|
|
|
|
|
def test_nonref_iterators():
    """Check the ``nonref*`` iterators of ``IntPairs`` over a full input.

    Unlike the ``nonzero*`` variants, the expectations here include the pair
    that starts with zero, so all three items are produced.
    """
    source = [(1, 2), (3, 4), (0, 5)]
    pairs = m.IntPairs(source)

    items = list(pairs.nonref())
    assert items == [(1, 2), (3, 4), (0, 5)]

    keys = list(pairs.nonref_keys())
    assert keys == [1, 3, 0]

    values = list(pairs.nonref_values())
    assert values == [2, 4, 5]
|
|
|
|
|
|
def test_generalized_iterators_simple():
    """Check the ``simple_*`` iterators of ``IntPairs``.

    Each accessor is called on a freshly built ``IntPairs`` and is expected to
    produce every item, including the pair whose key is zero.
    """
    source = [(1, 2), (3, 4), (0, 5)]

    items = list(m.IntPairs(source).simple_iterator())
    assert items == [(1, 2), (3, 4), (0, 5)]

    keys = list(m.IntPairs(source).simple_keys())
    assert keys == [1, 3, 0]

    values = list(m.IntPairs(source).simple_values())
    assert values == [2, 4, 5]
|
|
|
|
|
|
def test_iterator_referencing():
    """Test that iterators reference rather than copy their referents."""
    container = m.VectorNonCopyableInt()
    for initial in (3, 5):
        container.append(initial)
    assert [int(item) for item in container] == [3, 5]

    # Increment everything to make sure the referents can be mutated
    for item in container:
        item.set(int(item) + 1)
    assert [int(item) for item in container] == [4, 6]

    # Same check for a vector of pairs, going through its key/value iterators.
    container = m.VectorNonCopyableIntPair()
    for initial in ([3, 4], [5, 7]):
        container.append(initial)
    assert [int(item) for item in container.keys()] == [3, 5]
    assert [int(item) for item in container.values()] == [4, 7]

    for item in container.keys():
        item.set(int(item) + 1)
    for item in container.values():
        item.set(int(item) + 10)

    assert [int(item) for item in container.keys()] == [4, 6]
    assert [int(item) for item in container.values()] == [14, 17]
|
|
|
|
|
|
def test_sliceable():
    """Check how ``Sliceable(100)`` resolves slices to ``(start, stop, step)``."""
    sliceable = m.Sliceable(100)

    # Each entry pairs the slice passed to __getitem__ with the expected
    # result; the trailing comment shows the equivalent subscript syntax.
    cases = (
        (slice(None, None, None), (0, 100, 1)),  # [::]
        (slice(10, None, None), (10, 100, 1)),  # [10::]
        (slice(None, 10, None), (0, 10, 1)),  # [:10:]
        (slice(None, None, 10), (0, 100, 10)),  # [::10]
        (slice(-10, None, None), (90, 100, 1)),  # [-10::]
        (slice(None, -10, None), (0, 90, 1)),  # [:-10:]
        (slice(None, None, -10), (99, -1, -10)),  # [::-10]
        (slice(50, 60, 1), (50, 60, 1)),  # [50:60:1]
        (slice(50, 60, -1), (50, 60, -1)),  # [50:60:-1]
    )
    for requested, resolved in cases:
        assert sliceable[requested] == resolved
|
|
|
|
|
|
def test_sequence():
    """Exercise the bound ``Sequence`` type while tracking instance lifetimes.

    Covers construction, ``repr``/``len``, item access and assignment,
    membership tests, reversal, slicing, slice assignment and iteration, and
    uses ``ConstructorStats`` to check how many instances are alive after each
    ``del``. The statement order matters: every ``cstats`` assertion depends
    on the constructions and deletions that precede it.
    """
    cstats = ConstructorStats.get(m.Sequence)

    s = m.Sequence(5)
    assert cstats.values() == ["of size", "5"]

    assert "Sequence" in repr(s)
    assert len(s) == 5
    assert s[0] == 0 and s[3] == 0
    assert 12.34 not in s
    s[0], s[3] = 12.34, 56.78
    assert 12.34 in s
    # Compared with a tolerance rather than ==; presumably the values are not
    # stored at full double precision -- TODO confirm against the binding.
    assert isclose(s[0], 12.34) and isclose(s[3], 56.78)

    # Both reversal paths are expected to report a new size-5 construction.
    rev = reversed(s)
    assert cstats.values() == ["of size", "5"]

    rev2 = s[::-1]
    assert cstats.values() == ["of size", "5"]

    # Iterating an empty Sequence: only the iterator is kept in a local.
    it = iter(m.Sequence(0))
    for _ in range(3):  # __next__ must continue to raise StopIteration
        with pytest.raises(StopIteration):
            next(it)
    assert cstats.values() == ["of size", "0"]

    # The items assigned above, in reverse order.
    expected = [0, 56.78, 0, 0, 12.34]
    assert allclose(rev, expected)
    assert allclose(rev2, expected)
    assert rev == rev2

    # Extended-slice assignment overwrites indices 0, 2 and 4 of ``rev``.
    rev[0::2] = m.Sequence([2.0, 2.0, 2.0])
    assert cstats.values() == ["of size", "3", "from std::vector"]

    assert allclose(rev, [2, 56.78, 2, 0, 2])

    # Four instances are expected alive at this point; each ``del`` below is
    # expected to drop the count by one, including ``del it`` (the iterator
    # is the only remaining handle on the empty Sequence).
    assert cstats.alive() == 4
    del it
    assert cstats.alive() == 3
    del s
    assert cstats.alive() == 2
    del rev
    assert cstats.alive() == 1
    del rev2
    assert cstats.alive() == 0

    # No unreported values remain, and only move constructions were recorded
    # among the copy/move/default operations.
    assert cstats.values() == []
    assert cstats.default_constructions == 0
    assert cstats.copy_constructions == 0
    assert cstats.move_constructions >= 1
    assert cstats.copy_assignments == 0
    assert cstats.move_assignments == 0
|
|
|
|
|
|
def test_sequence_length():
    """#2076: Exception raised by len(arg) should be propagated"""

    class BadLen(RuntimeError):
        """Marker exception raised from ``__len__`` below."""

    class SequenceLike:
        """Minimal sequence-shaped object whose length query always fails."""

        def __getitem__(self, index):
            return None

        def __len__(self):
            raise BadLen()

    # The custom exception must surface unchanged through the bound function.
    with pytest.raises(BadLen):
        m.sequence_length(SequenceLike())

    # Well-behaved sequences simply report their length.
    for well_behaved, size in (([1, 2, 3], 3), ("hello", 5)):
        assert m.sequence_length(well_behaved) == size
|
|
|
|
|
|
def test_map_iterator():
    """Check lookup, insertion and the iteration views of ``StringMap``."""
    initial = {"hi": "bye", "black": "white"}
    string_map = m.StringMap(initial)
    assert string_map["hi"] == "bye"
    assert len(string_map) == 2
    assert string_map["black"] == "white"

    # A missing key raises KeyError until it has been assigned.
    with pytest.raises(KeyError):
        assert string_map["orange"]
    string_map["orange"] = "banana"
    assert string_map["orange"] == "banana"

    expected = dict(initial)
    expected["orange"] = "banana"

    for key in string_map:
        assert string_map[key] == expected[key]

    for key, value in string_map.items():
        assert value == expected[key]

    # values() must follow the same order as iterating the keys.
    observed_values = list(string_map.values())
    assert observed_values == [expected[key] for key in string_map]

    it = iter(m.StringMap({}))
    for _ in range(3):  # __next__ must continue to raise StopIteration
        with pytest.raises(StopIteration):
            next(it)
|
|
|
|
|
|
def test_python_iterator_in_cpp():
    """Hand Python iterables and iterators to bound functions that consume them.

    Covers successful conversion, the ``TypeError`` raised for a non-iterable
    argument, propagation of an exception raised while advancing an iterator,
    and a few helpers that count or search items.
    """
    numbers = (1, 2, 3)

    from_tuple = m.object_to_list(numbers)
    assert from_tuple == [1, 2, 3]
    from_iterator = m.object_to_list(iter(numbers))
    assert from_iterator == [1, 2, 3]
    via_iterator_arg = m.iterator_to_list(iter(numbers))
    assert via_iterator_arg == [1, 2, 3]

    # A plain int is rejected, with a different message per function.
    with pytest.raises(TypeError) as excinfo:
        m.object_to_list(1)
    message = str(excinfo.value)
    assert "object is not iterable" in message

    with pytest.raises(TypeError) as excinfo:
        m.iterator_to_list(1)
    message = str(excinfo.value)
    assert "incompatible function arguments" in message

    def bad_next_call():
        raise RuntimeError("py::iterator::advance() should propagate errors")

    # iter(callable, sentinel) calls bad_next_call on every advance, so the
    # first step raises and the error must reach the caller unchanged.
    with pytest.raises(RuntimeError) as excinfo:
        m.iterator_to_list(iter(bad_next_call, None))
    message = str(excinfo.value)
    assert message == "py::iterator::advance() should propagate errors"

    with_nones = [1, None, 0, None]
    assert m.count_none(with_nones) == 2
    assert m.find_none(with_nones) is True
    assert m.count_nonzeros({"a": 0, "b": 1, "c": 2}) == 2

    first_five = range(5)
    assert all(m.tuple_iterator(tuple(first_five)))
    assert all(m.list_iterator(list(first_five)))
    assert all(m.sequence_iterator(first_five))
|
|
|
|
|
|
def test_iterator_passthrough():
    """#181: iterator passthrough did not compile"""
    from pybind11_tests.sequences_and_iterators import iterator_passthrough

    odd_values = [3, 5, 7, 9, 11, 13, 15]
    source_iterator = iter(odd_values)
    passed_through = iterator_passthrough(source_iterator)
    # Draining the returned iterator must reproduce the original items.
    assert list(passed_through) == odd_values
|
|
|
|
|
|
def test_iterator_rvp():
    """#388: Can't make iterators via make_iterator() with different r/v policies"""
    import pybind11_tests.sequences_and_iterators as m

    first_items = list(m.make_iterator_1())
    assert first_items == [1, 2, 3]

    second_items = list(m.make_iterator_2())
    assert second_items == [1, 2, 3]

    # The two factories must produce iterators of distinct Python types.
    iterator_1 = m.make_iterator_1()
    iterator_2_type = type(m.make_iterator_2())
    assert not isinstance(iterator_1, iterator_2_type)
|