From 91b3d681ad96951bd51a719897b89487c3e99d37 Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Mon, 29 Aug 2016 02:41:05 +0100 Subject: [PATCH 1/3] Expose some dtype/array attributes via NumPy C API --- include/pybind11/eigen.h | 62 +++++++-------- include/pybind11/numpy.h | 151 ++++++++++++++++++++++++++++++++----- tests/CMakeLists.txt | 1 + tests/test_numpy_array.cpp | 45 +++++++++++ tests/test_numpy_array.py | 43 +++++++++++ 5 files changed, 246 insertions(+), 56 deletions(-) create mode 100644 tests/test_numpy_array.cpp create mode 100644 tests/test_numpy_array.py diff --git a/include/pybind11/eigen.h b/include/pybind11/eigen.h index 63a48088a..7c6869930 100644 --- a/include/pybind11/eigen.h +++ b/include/pybind11/eigen.h @@ -83,12 +83,11 @@ struct type_caster::value && static constexpr bool isVector = Type::IsVectorAtCompileTime; bool load(handle src, bool) { - array_t buffer(src, true); - if (!buffer.check()) - return false; + array_t buf(src, true); + if (!buf.check()) + return false; - auto info = buffer.request(); - if (info.ndim == 1) { + if (buf.ndim() == 1) { typedef Eigen::InnerStride<> Strides; if (!isVector && !(Type::RowsAtCompileTime == Eigen::Dynamic && @@ -96,31 +95,32 @@ struct type_caster::value && return false; if (Type::SizeAtCompileTime != Eigen::Dynamic && - info.shape[0] != (size_t) Type::SizeAtCompileTime) + buf.shape(0) != (size_t) Type::SizeAtCompileTime) return false; - auto strides = Strides(info.strides[0] / sizeof(Scalar)); - - Strides::Index n_elts = (Strides::Index) info.shape[0]; + Strides::Index n_elts = (Strides::Index) buf.shape(0); Strides::Index unity = 1; value = Eigen::Map( - (Scalar *) info.ptr, rowMajor ? unity : n_elts, rowMajor ? n_elts : unity, strides); - } else if (info.ndim == 2) { + buf.mutable_data(), + rowMajor ? unity : n_elts, + rowMajor ? n_elts : unity, + Strides(buf.strides(0) / sizeof(Scalar)) + ); + } else if (buf.ndim() == 2) { typedef Eigen::Stride Strides; - if ((Type::RowsAtCompileTime != Eigen::Dynamic && info.shape[0] != (size_t) Type::RowsAtCompileTime) || - (Type::ColsAtCompileTime != Eigen::Dynamic && info.shape[1] != (size_t) Type::ColsAtCompileTime)) + if ((Type::RowsAtCompileTime != Eigen::Dynamic && buf.shape(0) != (size_t) Type::RowsAtCompileTime) || + (Type::ColsAtCompileTime != Eigen::Dynamic && buf.shape(1) != (size_t) Type::ColsAtCompileTime)) return false; - auto strides = Strides( - info.strides[rowMajor ? 0 : 1] / sizeof(Scalar), - info.strides[rowMajor ? 1 : 0] / sizeof(Scalar)); - value = Eigen::Map( - (Scalar *) info.ptr, - typename Strides::Index(info.shape[0]), - typename Strides::Index(info.shape[1]), strides); + buf.mutable_data(), + typename Strides::Index(buf.shape(0)), + typename Strides::Index(buf.shape(1)), + Strides(buf.strides(rowMajor ? 0 : 1) / sizeof(Scalar), + buf.strides(rowMajor ? 1 : 0) / sizeof(Scalar)) + ); } else { return false; } @@ -222,28 +222,18 @@ struct type_caster::value>:: } } - auto valuesArray = array_t((object) obj.attr("data")); - auto innerIndicesArray = array_t((object) obj.attr("indices")); - auto outerIndicesArray = array_t((object) obj.attr("indptr")); + auto values = array_t((object) obj.attr("data")); + auto innerIndices = array_t((object) obj.attr("indices")); + auto outerIndices = array_t((object) obj.attr("indptr")); auto shape = pybind11::tuple((pybind11::object) obj.attr("shape")); auto nnz = obj.attr("nnz").cast(); - if (!valuesArray.check() || !innerIndicesArray.check() || - !outerIndicesArray.check()) + if (!values.check() || !innerIndices.check() || !outerIndices.check()) return false; - auto outerIndices = outerIndicesArray.request(); - auto innerIndices = innerIndicesArray.request(); - auto values = valuesArray.request(); - value = Eigen::MappedSparseMatrix( - shape[0].cast(), - shape[1].cast(), - nnz, - static_cast(outerIndices.ptr), - static_cast(innerIndices.ptr), - static_cast(values.ptr) - ); + shape[0].cast(), shape[1].cast(), nnz, + outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data()); return true; } diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 51c68ad96..6e0785e84 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -19,6 +19,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(push) @@ -30,12 +31,41 @@ namespace detail { template struct npy_format_descriptor { }; template struct is_pod_struct; +struct PyArrayDescr_Proxy { + PyObject_HEAD + PyObject *typeobj; + char kind; + char type; + char byteorder; + char flags; + int type_num; + int elsize; + int alignment; + char *subarray; + PyObject *fields; + PyObject *names; +}; + +struct PyArray_Proxy { + PyObject_HEAD + char *data; + int nd; + ssize_t *dimensions; + ssize_t *strides; + PyObject *base; + PyObject *descr; + int flags; +}; + struct npy_api { enum constants { NPY_C_CONTIGUOUS_ = 0x0001, NPY_F_CONTIGUOUS_ = 0x0002, + NPY_ARRAY_OWNDATA_ = 0x0004, NPY_ARRAY_FORCECAST_ = 0x0010, NPY_ENSURE_ARRAY_ = 0x0040, + NPY_ARRAY_ALIGNED_ = 0x0100, + NPY_ARRAY_WRITEABLE_ = 0x0400, NPY_BOOL_ = 0, NPY_BYTE_, NPY_UBYTE_, NPY_SHORT_, NPY_USHORT_, @@ -113,6 +143,11 @@ private: }; } +#define PyArray_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr) +#define PyArrayDescr_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr) +#define PyArray_CHKFLAGS_(ptr, flag) \ + (flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag)) + class dtype : public object { public: PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_); @@ -150,15 +185,15 @@ public: } size_t itemsize() const { - return attr("itemsize").cast(); + return (size_t) PyArrayDescr_GET_(m_ptr, elsize); } bool has_fields() const { - return attr("fields").cast().ptr() != Py_None; + return PyArrayDescr_GET_(m_ptr, names) != nullptr; } - std::string kind() const { - return (std::string) attr("kind").cast(); + char kind() const { + return PyArrayDescr_GET_(m_ptr, kind); } private: @@ -171,20 +206,20 @@ private: dtype strip_padding() { // Recursively strip all void fields with empty names that are generated for // padding fields (as of NumPy v1.11). - auto fields = attr("fields").cast(); - if (fields.ptr() == Py_None) + if (!has_fields()) return *this; struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; }; std::vector field_descriptors; + auto fields = attr("fields").cast(); auto items = fields.attr("items").cast(); for (auto field : items()) { auto spec = object(field, true).cast(); auto name = spec[0].cast(); auto format = spec[1].cast()[0].cast(); auto offset = spec[1].cast()[1].cast(); - if (!len(name) && format.kind() == "V") + if (!len(name) && format.kind() == 'V') continue; field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(), offset}); } @@ -244,14 +279,83 @@ public: template array(const std::vector& shape, T* ptr) : array(shape, default_strides(shape, sizeof(T)), ptr) { } - template array(size_t size, T* ptr) - : array(std::vector { size }, ptr) { } + template array(size_t count, T* ptr) + : array(std::vector { count }, ptr) { } array(const buffer_info &info) : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { } - pybind11::dtype dtype() { - return attr("dtype").cast(); + /// Array descriptor (dtype) + pybind11::dtype dtype() const { + return object(PyArray_GET_(m_ptr, descr), true); + } + + /// Total number of elements + size_t size() const { + return std::accumulate(shape(), shape() + ndim(), (size_t) 1, std::multiplies()); + } + + /// Byte size of a single element + size_t itemsize() const { + return (size_t) PyArrayDescr_GET_(PyArray_GET_(m_ptr, descr), elsize); + } + + /// Total number of bytes + size_t nbytes() const { + return size() * itemsize(); + } + + /// Number of dimensions + size_t ndim() const { + return (size_t) PyArray_GET_(m_ptr, nd); + } + + /// Dimensions of the array + const size_t* shape() const { + static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t"); + return reinterpret_cast(PyArray_GET_(m_ptr, dimensions)); + } + + /// Dimension along a given axis + size_t shape(size_t dim) const { + if (dim >= ndim()) + pybind11_fail("NumPy: attempted to index shape beyond ndim"); + return shape()[dim]; + } + + /// Strides of the array + const size_t* strides() const { + static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t"); + return reinterpret_cast(PyArray_GET_(m_ptr, strides)); + } + + /// Stride along a given axis + size_t strides(size_t dim) const { + if (dim >= ndim()) + pybind11_fail("NumPy: attempted to index strides beyond ndim"); + return strides()[dim]; + } + + /// If set, the array is writeable (otherwise the buffer is read-only) + bool writeable() const { + return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_); + } + + /// If set, the array owns the data (will be freed when the array is deleted) + bool owndata() const { + return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_); + } + + /// Direct pointer to contained buffer + const void* data() const { + return reinterpret_cast(PyArray_GET_(m_ptr, data)); + } + + /// Direct mutable pointer to contained buffer (checks writeable flag) + void* mutable_data() { + if (!writeable()) + pybind11_fail("NumPy: cannot get mutable data of a read-only array"); + return reinterpret_cast(PyArray_GET_(m_ptr, data)); } protected: @@ -284,8 +388,18 @@ public: array_t(const std::vector& shape, T* ptr = nullptr) : array(shape, ptr) { } - array_t(size_t size, T* ptr = nullptr) - : array(size, ptr) { } + array_t(size_t count, T* ptr = nullptr) + : array(count, ptr) { } + + const T* data() const { + return reinterpret_cast(PyArray_GET_(m_ptr, data)); + } + + T* mutable_data() { + if (!writeable()) + pybind11_fail("NumPy: cannot get mutable data of a read-only array"); + return reinterpret_cast(PyArray_GET_(m_ptr, data)); + } static bool is_non_null(PyObject *ptr) { return ptr != nullptr; } @@ -678,16 +792,13 @@ struct vectorize_helper { if (size == 1) return cast(f(*((Args *) buffers[Index].ptr)...)); - array result(buffer_info(nullptr, sizeof(Return), - format_descriptor::format(), - ndim, shape, strides)); - - buffer_info buf = result.request(); - Return *output = (Return *) buf.ptr; + array_t result(shape, strides); + auto buf = result.request(); + auto output = (Return *) buf.ptr; if (trivial_broadcast) { /* Call the function */ - for (size_t i=0; i + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#include "pybind11_tests.h" +#include +#include + +test_initializer numpy_array([](py::module &m) { + m.def("get_arr_ndim", [](const py::array& arr) { + return arr.ndim(); + }); + m.def("get_arr_shape", [](const py::array& arr) { + return std::vector(arr.shape(), arr.shape() + arr.ndim()); + }); + m.def("get_arr_shape", [](const py::array& arr, size_t dim) { + return arr.shape(dim); + }); + m.def("get_arr_strides", [](const py::array& arr) { + return std::vector(arr.strides(), arr.strides() + arr.ndim()); + }); + m.def("get_arr_strides", [](const py::array& arr, size_t dim) { + return arr.strides(dim); + }); + m.def("get_arr_writeable", [](const py::array& arr) { + return arr.writeable(); + }); + m.def("get_arr_size", [](const py::array& arr) { + return arr.size(); + }); + m.def("get_arr_itemsize", [](const py::array& arr) { + return arr.itemsize(); + }); + m.def("get_arr_nbytes", [](const py::array& arr) { + return arr.nbytes(); + }); + m.def("get_arr_owndata", [](const py::array& arr) { + return arr.owndata(); + }); +}); diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py new file mode 100644 index 000000000..929cdc4d9 --- /dev/null +++ b/tests/test_numpy_array.py @@ -0,0 +1,43 @@ +import pytest + +with pytest.suppress(ImportError): + import numpy as np + + +@pytest.requires_numpy +def test_array_attributes(): + from pybind11_tests import (get_arr_ndim, get_arr_shape, get_arr_strides, get_arr_writeable, + get_arr_size, get_arr_itemsize, get_arr_nbytes, get_arr_owndata) + + a = np.array(0, 'f8') + assert get_arr_ndim(a) == 0 + assert get_arr_shape(a) == [] + assert get_arr_strides(a) == [] + with pytest.raises(RuntimeError): + get_arr_shape(a, 1) + with pytest.raises(RuntimeError): + get_arr_strides(a, 0) + assert get_arr_writeable(a) + assert get_arr_size(a) == 1 + assert get_arr_itemsize(a) == 8 + assert get_arr_nbytes(a) == 8 + assert get_arr_owndata(a) + + a = np.array([[1, 2, 3], [4, 5, 6]], 'u2').view() + a.flags.writeable = False + assert get_arr_ndim(a) == 2 + assert get_arr_shape(a) == [2, 3] + assert get_arr_shape(a, 0) == 2 + assert get_arr_shape(a, 1) == 3 + assert get_arr_strides(a) == [6, 2] + assert get_arr_strides(a, 0) == 6 + assert get_arr_strides(a, 1) == 2 + with pytest.raises(RuntimeError): + get_arr_shape(a, 2) + with pytest.raises(RuntimeError): + get_arr_strides(a, 2) + assert not get_arr_writeable(a) + assert get_arr_size(a) == 6 + assert get_arr_itemsize(a) == 2 + assert get_arr_nbytes(a) == 12 + assert not get_arr_owndata(a) From f2a0ad5855d233fe4052923c2231c45cffc9b13e Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Thu, 8 Sep 2016 21:48:14 +0100 Subject: [PATCH 2/3] array: add direct data access and indexing methods --- include/pybind11/numpy.h | 135 +++++++++++++++++++++++++++++---------- 1 file changed, 102 insertions(+), 33 deletions(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 6e0785e84..0768343fd 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -26,8 +26,14 @@ #pragma warning(disable: 4127) // warning C4127: Conditional expression is constant #endif +/* This will be true on all flat address space platforms and allows us to reduce the + whole npy_intp / size_t / Py_intptr_t business down to just size_t for all size + and dimension types (e.g. shape, strides, indexing), instead of inflicting this + upon the library user. */ +static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t"); + NAMESPACE_BEGIN(pybind11) -namespace detail { +NAMESPACE_BEGIN(detail) template struct npy_format_descriptor { }; template struct is_pod_struct; @@ -141,10 +147,12 @@ private: return api; } }; -} +NAMESPACE_END(detail) -#define PyArray_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr) -#define PyArrayDescr_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr) +#define PyArray_GET_(ptr, attr) \ + (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr) +#define PyArrayDescr_GET_(ptr, attr) \ + (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr) #define PyArray_CHKFLAGS_(ptr, flag) \ (flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag)) @@ -250,7 +258,7 @@ public: }; array(const pybind11::dtype& dt, const std::vector& shape, - const std::vector& strides, void *ptr = nullptr) { + const std::vector& strides, const void *ptr = nullptr) { auto& api = detail::npy_api::get(); auto ndim = shape.size(); if (shape.size() != strides.size()) @@ -258,7 +266,7 @@ public: auto descr = dt; object tmp(api.PyArray_NewFromDescr_( api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(), - (Py_intptr_t *) strides.data(), ptr, 0, nullptr), false); + (Py_intptr_t *) strides.data(), const_cast(ptr), 0, nullptr), false); if (!tmp) pybind11_fail("NumPy: unable to create array!"); if (ptr) @@ -266,20 +274,20 @@ public: m_ptr = tmp.release().ptr(); } - array(const pybind11::dtype& dt, const std::vector& shape, void *ptr = nullptr) + array(const pybind11::dtype& dt, const std::vector& shape, const void *ptr = nullptr) : array(dt, shape, default_strides(shape, dt.itemsize()), ptr) { } - array(const pybind11::dtype& dt, size_t count, void *ptr = nullptr) + array(const pybind11::dtype& dt, size_t count, const void *ptr = nullptr) : array(dt, std::vector { count }, ptr) { } template array(const std::vector& shape, - const std::vector& strides, T* ptr) + const std::vector& strides, const T* ptr) : array(pybind11::dtype::of(), shape, strides, (void *) ptr) { } - template array(const std::vector& shape, T* ptr) + template array(const std::vector& shape, const T* ptr) : array(shape, default_strides(shape, sizeof(T)), ptr) { } - template array(size_t count, T* ptr) + template array(size_t count, const T* ptr) : array(std::vector { count }, ptr) { } array(const buffer_info &info) @@ -312,27 +320,25 @@ public: /// Dimensions of the array const size_t* shape() const { - static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t"); return reinterpret_cast(PyArray_GET_(m_ptr, dimensions)); } /// Dimension along a given axis size_t shape(size_t dim) const { if (dim >= ndim()) - pybind11_fail("NumPy: attempted to index shape beyond ndim"); + fail_dim_check(dim, "invalid axis"); return shape()[dim]; } /// Strides of the array const size_t* strides() const { - static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t"); return reinterpret_cast(PyArray_GET_(m_ptr, strides)); } /// Stride along a given axis size_t strides(size_t dim) const { if (dim >= ndim()) - pybind11_fail("NumPy: attempted to index strides beyond ndim"); + fail_dim_check(dim, "invalid axis"); return strides()[dim]; } @@ -346,20 +352,61 @@ public: return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_); } - /// Direct pointer to contained buffer - const void* data() const { - return reinterpret_cast(PyArray_GET_(m_ptr, data)); + /// Pointer to the contained data. If index is not provided, points to the + /// beginning of the buffer. May throw if the index would lead to out of bounds access. + template const void* data(Ix&&... index) const { + return static_cast(PyArray_GET_(m_ptr, data) + offset_at(index...)); } - /// Direct mutable pointer to contained buffer (checks writeable flag) - void* mutable_data() { - if (!writeable()) - pybind11_fail("NumPy: cannot get mutable data of a read-only array"); - return reinterpret_cast(PyArray_GET_(m_ptr, data)); + /// Mutable pointer to the contained data. If index is not provided, points to the + /// beginning of the buffer. May throw if the index would lead to out of bounds access. + /// May throw if the array is not writeable. + template void* mutable_data(Ix&&... index) { + check_writeable(); + return static_cast(PyArray_GET_(m_ptr, data) + offset_at(index...)); + } + + /// Byte offset from beginning of the array to a given index (full or partial). + /// May throw if the index would lead to out of bounds access. + template size_t offset_at(Ix&&... index) const { + if (sizeof...(index) > ndim()) + fail_dim_check(sizeof...(index), "too many indices for an array"); + return get_byte_offset(index...); + } + + size_t offset_at() const { return 0; } + + /// Item count from beginning of the array to a given index (full or partial). + /// May throw if the index would lead to out of bounds access. + template size_t index_at(Ix&&... index) const { + return offset_at(index...) / itemsize(); } protected: - template friend struct detail::npy_format_descriptor; + template friend struct detail::npy_format_descriptor; + + void fail_dim_check(size_t dim, const std::string& msg) const { + throw index_error(msg + ": " + std::to_string(dim) + + " (ndim = " + std::to_string(ndim()) + ")"); + } + + template size_t get_byte_offset(Ix&&... index) const { + const size_t idx[] = { (size_t) index... }; + if (!std::equal(idx + 0, idx + sizeof...(index), shape(), std::less{})) { + auto mismatch = std::mismatch(idx + 0, idx + sizeof...(index), shape(), std::less{}); + throw index_error(std::string("index ") + std::to_string(*mismatch.first) + + " is out of bounds for axis " + std::to_string(mismatch.first - idx) + + " with size " + std::to_string(*mismatch.second)); + } + return std::inner_product(idx + 0, idx + sizeof...(index), strides(), (size_t) 0); + } + + size_t get_byte_offset() const { return 0; } + + void check_writeable() const { + if (!writeable()) + throw std::runtime_error("array is not writeable"); + } static std::vector default_strides(const std::vector& shape, size_t itemsize) { auto ndim = shape.size(); @@ -382,23 +429,45 @@ public: array_t(const buffer_info& info) : array(info) { } - array_t(const std::vector& shape, const std::vector& strides, T* ptr = nullptr) + array_t(const std::vector& shape, const std::vector& strides, const T* ptr = nullptr) : array(shape, strides, ptr) { } - array_t(const std::vector& shape, T* ptr = nullptr) + array_t(const std::vector& shape, const T* ptr = nullptr) : array(shape, ptr) { } - array_t(size_t count, T* ptr = nullptr) + array_t(size_t count, const T* ptr = nullptr) : array(count, ptr) { } - const T* data() const { - return reinterpret_cast(PyArray_GET_(m_ptr, data)); + constexpr size_t itemsize() const { + return sizeof(T); } - T* mutable_data() { - if (!writeable()) - pybind11_fail("NumPy: cannot get mutable data of a read-only array"); - return reinterpret_cast(PyArray_GET_(m_ptr, data)); + template size_t index_at(Ix&... index) const { + return offset_at(index...) / itemsize(); + } + + template const T* data(Ix&&... index) const { + return static_cast(array::data(index...)); + } + + template T* mutable_data(Ix&&... index) { + return static_cast(array::mutable_data(index...)); + } + + // Reference to element at a given index + template const T& at(Ix&&... index) const { + if (sizeof...(index) != ndim()) + fail_dim_check(sizeof...(index), "index dimension mismatch"); + // not using offset_at() / index_at() here so as to avoid another dimension check + return *(static_cast(array::data()) + get_byte_offset(index...) / itemsize()); + } + + // Mutable reference to element at a given index + template T& mutable_at(Ix&&... index) { + if (sizeof...(index) != ndim()) + fail_dim_check(sizeof...(index), "index dimension mismatch"); + // not using offset_at() / index_at() here so as to avoid another dimension check + return *(static_cast(array::mutable_data()) + get_byte_offset(index...) / itemsize()); } static bool is_non_null(PyObject *ptr) { return ptr != nullptr; } From aca6bcaea5e223752b210f991076d42202be61bf Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Thu, 8 Sep 2016 23:03:35 +0100 Subject: [PATCH 3/3] Add tests for array data access /index methods --- tests/test_numpy_array.cpp | 109 +++++++++++++++++------- tests/test_numpy_array.py | 167 ++++++++++++++++++++++++++++++------- 2 files changed, 216 insertions(+), 60 deletions(-) diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index ed118a0ae..0614f5717 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -8,38 +8,87 @@ */ #include "pybind11_tests.h" + #include #include +#include +#include + +using arr = py::array; +using arr_t = py::array_t; + +template arr data(const arr& a, Ix&&... index) { + return arr(a.nbytes() - a.offset_at(index...), (const uint8_t *) a.data(index...)); +} + +template arr data_t(const arr_t& a, Ix&&... index) { + return arr(a.size() - a.index_at(index...), a.data(index...)); +} + +arr& mutate_data(arr& a) { + auto ptr = (uint8_t *) a.mutable_data(); + for (size_t i = 0; i < a.nbytes(); i++) + ptr[i] = (uint8_t) (ptr[i] * 2); + return a; +} + +arr_t& mutate_data_t(arr_t& a) { + auto ptr = a.mutable_data(); + for (size_t i = 0; i < a.size(); i++) + ptr[i]++; + return a; +} + +template arr& mutate_data(arr& a, Ix&&... index) { + auto ptr = (uint8_t *) a.mutable_data(index...); + for (size_t i = 0; i < a.nbytes() - a.offset_at(index...); i++) + ptr[i] = (uint8_t) (ptr[i] * 2); + return a; +} + +template arr_t& mutate_data_t(arr_t& a, Ix&&... index) { + auto ptr = a.mutable_data(index...); + for (size_t i = 0; i < a.size() - a.index_at(index...); i++) + ptr[i]++; + return a; +} + +template size_t index_at(const arr& a, Ix&&... idx) { return a.index_at(idx...); } +template size_t index_at_t(const arr_t& a, Ix&&... idx) { return a.index_at(idx...); } +template size_t offset_at(const arr& a, Ix&&... idx) { return a.offset_at(idx...); } +template size_t offset_at_t(const arr_t& a, Ix&&... idx) { return a.offset_at(idx...); } +template size_t at_t(const arr_t& a, Ix&&... idx) { return a.at(idx...); } +template arr_t& mutate_at_t(arr_t& a, Ix&&... idx) { a.mutable_at(idx...)++; return a; } + +#define def_index_fn(name, type) \ + sm.def(#name, [](type a) { return name(a); }); \ + sm.def(#name, [](type a, int i) { return name(a, i); }); \ + sm.def(#name, [](type a, int i, int j) { return name(a, i, j); }); \ + sm.def(#name, [](type a, int i, int j, int k) { return name(a, i, j, k); }); + test_initializer numpy_array([](py::module &m) { - m.def("get_arr_ndim", [](const py::array& arr) { - return arr.ndim(); - }); - m.def("get_arr_shape", [](const py::array& arr) { - return std::vector(arr.shape(), arr.shape() + arr.ndim()); - }); - m.def("get_arr_shape", [](const py::array& arr, size_t dim) { - return arr.shape(dim); - }); - m.def("get_arr_strides", [](const py::array& arr) { - return std::vector(arr.strides(), arr.strides() + arr.ndim()); - }); - m.def("get_arr_strides", [](const py::array& arr, size_t dim) { - return arr.strides(dim); - }); - m.def("get_arr_writeable", [](const py::array& arr) { - return arr.writeable(); - }); - m.def("get_arr_size", [](const py::array& arr) { - return arr.size(); - }); - m.def("get_arr_itemsize", [](const py::array& arr) { - return arr.itemsize(); - }); - m.def("get_arr_nbytes", [](const py::array& arr) { - return arr.nbytes(); - }); - m.def("get_arr_owndata", [](const py::array& arr) { - return arr.owndata(); - }); + auto sm = m.def_submodule("array"); + + sm.def("ndim", [](const arr& a) { return a.ndim(); }); + sm.def("shape", [](const arr& a) { return arr(a.ndim(), a.shape()); }); + sm.def("shape", [](const arr& a, size_t dim) { return a.shape(dim); }); + sm.def("strides", [](const arr& a) { return arr(a.ndim(), a.strides()); }); + sm.def("strides", [](const arr& a, size_t dim) { return a.strides(dim); }); + sm.def("writeable", [](const arr& a) { return a.writeable(); }); + sm.def("size", [](const arr& a) { return a.size(); }); + sm.def("itemsize", [](const arr& a) { return a.itemsize(); }); + sm.def("nbytes", [](const arr& a) { return a.nbytes(); }); + sm.def("owndata", [](const arr& a) { return a.owndata(); }); + + def_index_fn(data, const arr&); + def_index_fn(data_t, const arr_t&); + def_index_fn(index_at, const arr&); + def_index_fn(index_at_t, const arr_t&); + def_index_fn(offset_at, const arr&); + def_index_fn(offset_at_t, const arr_t&); + def_index_fn(mutate_data, arr&); + def_index_fn(mutate_data_t, arr_t&); + def_index_fn(at_t, const arr_t&); + def_index_fn(mutate_at_t, arr_t&); }); diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index 929cdc4d9..4a6af5ef6 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -4,40 +4,147 @@ with pytest.suppress(ImportError): import numpy as np +@pytest.fixture(scope='function') +def arr(): + return np.array([[1, 2, 3], [4, 5, 6]], '