diff --git a/include/pybind11/stl_bind.h b/include/pybind11/stl_bind.h index d1d45e2c0..300e8af9a 100644 --- a/include/pybind11/stl_bind.h +++ b/include/pybind11/stl_bind.h @@ -326,6 +326,49 @@ template auto vector_if_insertion_operator(Cl ); } +// Provide the buffer interface for vectors if we have data() and we have a format for it +// GCC seems to have "void std::vector::data()" - doing SFINAE on the existence of data() is insufficient, we need to check it returns an appropriate pointer +template +struct vector_has_data_and_format : std::false_type {}; +template +struct vector_has_data_and_format::format(), std::declval().data()), typename Vector::value_type*>::value>> : std::true_type {}; + +// Add the buffer interface to a vector +template +enable_if_t...>::value> +vector_buffer(Class_& cl) { + using T = typename Vector::value_type; + + static_assert(vector_has_data_and_format::value, "There is not an appropriate format descriptor for this vector"); + + // numpy.h declares this for arbitrary types, but it may raise an exception and crash hard at runtime if PYBIND11_NUMPY_DTYPE hasn't been called, so check here + py::format_descriptor::format(); + + cl.def_buffer([](Vector& v) -> py::buffer_info { + return py::buffer_info(v.data(), sizeof(T), py::format_descriptor::format(), 1, {v.size()}, {sizeof(T)}); + }); + + cl.def("__init__", [](Vector& vec, py::buffer buf) { + auto info = buf.request(); + if (info.ndim != 1 || info.strides[0] <= 0 || info.strides[0] % sizeof(T)) + throw pybind11::type_error("Only valid 1D buffers can be copied to a vector"); + if (!detail::compare_buffer_info::compare(info) || sizeof(T) != info.itemsize) + throw pybind11::type_error("Format mismatch (Python: " + info.format + " C++: " + py::format_descriptor::format() + ")"); + new (&vec) Vector(); + vec.reserve(info.shape[0]); + T *p = static_cast(info.ptr); + auto step = info.strides[0] / sizeof(T); + T *end = p + info.shape[0] * step; + for (; p < end; p += step) + vec.push_back(*p); + }); + + return; +} + +template +enable_if_t...>::value> vector_buffer(Class_&) {} + NAMESPACE_END(detail) // @@ -337,6 +380,9 @@ pybind11::class_ bind_vector(pybind11::module &m, std::stri Class_ cl(m, name.c_str(), std::forward(args)...); + // Declare the buffer interface if a py::buffer_protocol() is passed in + detail::vector_buffer(cl); + cl.def(pybind11::init<>()); // Register copy constructor (if possible) diff --git a/tests/test_stl_binders.cpp b/tests/test_stl_binders.cpp index ce0b33257..f636c0b55 100644 --- a/tests/test_stl_binders.cpp +++ b/tests/test_stl_binders.cpp @@ -10,6 +10,7 @@ #include "pybind11_tests.h" #include +#include #include #include #include @@ -58,17 +59,45 @@ template Map *times_ten(int n) { return m; } +struct VStruct { + bool w; + uint32_t x; + double y; + bool z; +}; + +struct VUndeclStruct { //dtype not declared for this version + bool w; + uint32_t x; + double y; + bool z; +}; + test_initializer stl_binder_vector([](py::module &m) { py::class_(m, "El") .def(py::init()); - py::bind_vector>(m, "VectorInt"); + py::bind_vector>(m, "VectorUChar", py::buffer_protocol()); + py::bind_vector>(m, "VectorInt", py::buffer_protocol()); py::bind_vector>(m, "VectorBool"); py::bind_vector>(m, "VectorEl"); py::bind_vector>>(m, "VectorVectorEl"); + m.def("create_undeclstruct", [m] () mutable { + py::bind_vector>(m, "VectorUndeclStruct", py::buffer_protocol()); + }); + + try { + py::module::import("numpy"); + } catch (...) { + return; + } + PYBIND11_NUMPY_DTYPE(VStruct, w, x, y, z); + py::class_(m, "VStruct").def_readwrite("x", &VStruct::x); + py::bind_vector>(m, "VectorStruct", py::buffer_protocol()); + m.def("get_vectorstruct", [] {return std::vector {{0, 5, 3.0, 1}, {1, 30, -1e4, 0}};}); }); test_initializer stl_binder_map([](py::module &m) { @@ -97,4 +126,3 @@ test_initializer stl_binder_noncopyable([](py::module &m) { py::bind_map>(m, "UmapENC"); m.def("get_umnc", ×_ten>, py::return_value_policy::reference); }); - diff --git a/tests/test_stl_binders.py b/tests/test_stl_binders.py index c9bcc7935..f8f817e33 100644 --- a/tests/test_stl_binders.py +++ b/tests/test_stl_binders.py @@ -1,3 +1,10 @@ +import pytest +import sys + +with pytest.suppress(ImportError): + import numpy as np + + def test_vector_int(): from pybind11_tests import VectorInt @@ -26,6 +33,53 @@ def test_vector_int(): assert v_int2 == VectorInt([0, 99, 2, 3]) +@pytest.unsupported_on_pypy +def test_vector_buffer(): + from pybind11_tests import VectorUChar, create_undeclstruct + b = bytearray([1, 2, 3, 4]) + v = VectorUChar(b) + assert v[1] == 2 + v[2] = 5 + m = memoryview(v) # We expose the buffer interface + if sys.version_info.major > 2: + assert m[2] == 5 + m[2] = 6 + else: + assert m[2] == '\x05' + m[2] = '\x06' + assert v[2] == 6 + + with pytest.raises(RuntimeError): + create_undeclstruct() # Undeclared struct contents, no buffer interface + + +@pytest.requires_numpy +def test_vector_buffer_numpy(): + from pybind11_tests import VectorInt, get_vectorstruct + + a = np.array([1, 2, 3, 4], dtype=np.int32) + with pytest.raises(TypeError): + VectorInt(a) + + a = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.uintc) + v = VectorInt(a[0, :]) + assert len(v) == 4 + assert v[2] == 3 + m = np.asarray(v) + m[2] = 5 + assert v[2] == 5 + + v = VectorInt(a[:, 1]) + assert len(v) == 3 + assert v[2] == 10 + + v = get_vectorstruct() + assert v[0].x == 5 + m = np.asarray(v) + m[1]['x'] = 99 + assert v[1].x == 99 + + def test_vector_custom(): from pybind11_tests import El, VectorEl, VectorVectorEl