Modified Vector STL bind initialization from a buffer type with optimization for simple arrays (#2298)

* Modified Vector STL bind initialization from a buffer type with optimization for simple arrays

* Add subtests to demonstrate processing Python buffer protocol objects with step > 1

* Fixed memoryview step test to only run on Python 3+

* Modified Vector constructor from buffer to return by value for readability
This commit is contained in:
marc-chiesa 2020-08-13 16:47:23 -04:00 committed by GitHub
parent 1534e17e44
commit 830adda850
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 5 deletions

View File

@ -397,14 +397,19 @@ vector_buffer(Class_& cl) {
if (!detail::compare_buffer_info<T>::compare(info) || (ssize_t) sizeof(T) != info.itemsize) if (!detail::compare_buffer_info<T>::compare(info) || (ssize_t) sizeof(T) != info.itemsize)
throw type_error("Format mismatch (Python: " + info.format + " C++: " + format_descriptor<T>::format() + ")"); throw type_error("Format mismatch (Python: " + info.format + " C++: " + format_descriptor<T>::format() + ")");
auto vec = std::unique_ptr<Vector>(new Vector());
vec->reserve((size_t) info.shape[0]);
T *p = static_cast<T*>(info.ptr); T *p = static_cast<T*>(info.ptr);
ssize_t step = info.strides[0] / static_cast<ssize_t>(sizeof(T)); ssize_t step = info.strides[0] / static_cast<ssize_t>(sizeof(T));
T *end = p + info.shape[0] * step; T *end = p + info.shape[0] * step;
for (; p != end; p += step) if (step == 1) {
vec->push_back(*p); return Vector(p, end);
return vec.release(); }
else {
Vector vec;
vec.reserve((size_t) info.shape[0]);
for (; p != end; p += step)
vec.push_back(*p);
return vec;
}
})); }));
return; return;

View File

@ -85,6 +85,11 @@ def test_vector_buffer():
mv[2] = '\x06' mv[2] = '\x06'
assert v[2] == 6 assert v[2] == 6
if sys.version_info.major > 2:
mv = memoryview(b)
v = m.VectorUChar(mv[::2])
assert v[1] == 3
with pytest.raises(RuntimeError) as excinfo: with pytest.raises(RuntimeError) as excinfo:
m.create_undeclstruct() # Undeclared struct contents, no buffer interface m.create_undeclstruct() # Undeclared struct contents, no buffer interface
assert "NumPy type info missing for " in str(excinfo.value) assert "NumPy type info missing for " in str(excinfo.value)
@ -119,6 +124,10 @@ def test_vector_buffer_numpy():
('y', 'float64'), ('z', 'bool')], align=True))) ('y', 'float64'), ('z', 'bool')], align=True)))
assert len(v) == 3 assert len(v) == 3
b = np.array([1, 2, 3, 4], dtype=np.uint8)
v = m.VectorUChar(b[::2])
assert v[1] == 3
def test_vector_bool(): def test_vector_bool():
import pybind11_cross_module_tests as cm import pybind11_cross_module_tests as cm