diff --git a/include/pybind11/eigen.h b/include/pybind11/eigen.h index 63a48088a..7c6869930 100644 --- a/include/pybind11/eigen.h +++ b/include/pybind11/eigen.h @@ -83,12 +83,11 @@ struct type_caster::value && static constexpr bool isVector = Type::IsVectorAtCompileTime; bool load(handle src, bool) { - array_t buffer(src, true); - if (!buffer.check()) - return false; + array_t buf(src, true); + if (!buf.check()) + return false; - auto info = buffer.request(); - if (info.ndim == 1) { + if (buf.ndim() == 1) { typedef Eigen::InnerStride<> Strides; if (!isVector && !(Type::RowsAtCompileTime == Eigen::Dynamic && @@ -96,31 +95,32 @@ struct type_caster::value && return false; if (Type::SizeAtCompileTime != Eigen::Dynamic && - info.shape[0] != (size_t) Type::SizeAtCompileTime) + buf.shape(0) != (size_t) Type::SizeAtCompileTime) return false; - auto strides = Strides(info.strides[0] / sizeof(Scalar)); - - Strides::Index n_elts = (Strides::Index) info.shape[0]; + Strides::Index n_elts = (Strides::Index) buf.shape(0); Strides::Index unity = 1; value = Eigen::Map( - (Scalar *) info.ptr, rowMajor ? unity : n_elts, rowMajor ? n_elts : unity, strides); - } else if (info.ndim == 2) { + buf.mutable_data(), + rowMajor ? unity : n_elts, + rowMajor ? n_elts : unity, + Strides(buf.strides(0) / sizeof(Scalar)) + ); + } else if (buf.ndim() == 2) { typedef Eigen::Stride Strides; - if ((Type::RowsAtCompileTime != Eigen::Dynamic && info.shape[0] != (size_t) Type::RowsAtCompileTime) || - (Type::ColsAtCompileTime != Eigen::Dynamic && info.shape[1] != (size_t) Type::ColsAtCompileTime)) + if ((Type::RowsAtCompileTime != Eigen::Dynamic && buf.shape(0) != (size_t) Type::RowsAtCompileTime) || + (Type::ColsAtCompileTime != Eigen::Dynamic && buf.shape(1) != (size_t) Type::ColsAtCompileTime)) return false; - auto strides = Strides( - info.strides[rowMajor ? 0 : 1] / sizeof(Scalar), - info.strides[rowMajor ? 1 : 0] / sizeof(Scalar)); - value = Eigen::Map( - (Scalar *) info.ptr, - typename Strides::Index(info.shape[0]), - typename Strides::Index(info.shape[1]), strides); + buf.mutable_data(), + typename Strides::Index(buf.shape(0)), + typename Strides::Index(buf.shape(1)), + Strides(buf.strides(rowMajor ? 0 : 1) / sizeof(Scalar), + buf.strides(rowMajor ? 1 : 0) / sizeof(Scalar)) + ); } else { return false; } @@ -222,28 +222,18 @@ struct type_caster::value>:: } } - auto valuesArray = array_t((object) obj.attr("data")); - auto innerIndicesArray = array_t((object) obj.attr("indices")); - auto outerIndicesArray = array_t((object) obj.attr("indptr")); + auto values = array_t((object) obj.attr("data")); + auto innerIndices = array_t((object) obj.attr("indices")); + auto outerIndices = array_t((object) obj.attr("indptr")); auto shape = pybind11::tuple((pybind11::object) obj.attr("shape")); auto nnz = obj.attr("nnz").cast(); - if (!valuesArray.check() || !innerIndicesArray.check() || - !outerIndicesArray.check()) + if (!values.check() || !innerIndices.check() || !outerIndices.check()) return false; - auto outerIndices = outerIndicesArray.request(); - auto innerIndices = innerIndicesArray.request(); - auto values = valuesArray.request(); - value = Eigen::MappedSparseMatrix( - shape[0].cast(), - shape[1].cast(), - nnz, - static_cast(outerIndices.ptr), - static_cast(innerIndices.ptr), - static_cast(values.ptr) - ); + shape[0].cast(), shape[1].cast(), nnz, + outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data()); return true; } diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 51c68ad96..6e0785e84 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -19,6 +19,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(push) @@ -30,12 +31,41 @@ namespace detail { template struct npy_format_descriptor { }; template struct is_pod_struct; +struct PyArrayDescr_Proxy { + PyObject_HEAD + PyObject *typeobj; + char kind; + char type; + char byteorder; + char flags; + int type_num; + int elsize; + int alignment; + char *subarray; + PyObject *fields; + PyObject *names; +}; + +struct PyArray_Proxy { + PyObject_HEAD + char *data; + int nd; + ssize_t *dimensions; + ssize_t *strides; + PyObject *base; + PyObject *descr; + int flags; +}; + struct npy_api { enum constants { NPY_C_CONTIGUOUS_ = 0x0001, NPY_F_CONTIGUOUS_ = 0x0002, + NPY_ARRAY_OWNDATA_ = 0x0004, NPY_ARRAY_FORCECAST_ = 0x0010, NPY_ENSURE_ARRAY_ = 0x0040, + NPY_ARRAY_ALIGNED_ = 0x0100, + NPY_ARRAY_WRITEABLE_ = 0x0400, NPY_BOOL_ = 0, NPY_BYTE_, NPY_UBYTE_, NPY_SHORT_, NPY_USHORT_, @@ -113,6 +143,11 @@ private: }; } +#define PyArray_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr) +#define PyArrayDescr_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr) +#define PyArray_CHKFLAGS_(ptr, flag) \ + (flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag)) + class dtype : public object { public: PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_); @@ -150,15 +185,15 @@ public: } size_t itemsize() const { - return attr("itemsize").cast(); + return (size_t) PyArrayDescr_GET_(m_ptr, elsize); } bool has_fields() const { - return attr("fields").cast().ptr() != Py_None; + return PyArrayDescr_GET_(m_ptr, names) != nullptr; } - std::string kind() const { - return (std::string) attr("kind").cast(); + char kind() const { + return PyArrayDescr_GET_(m_ptr, kind); } private: @@ -171,20 +206,20 @@ private: dtype strip_padding() { // Recursively strip all void fields with empty names that are generated for // padding fields (as of NumPy v1.11). - auto fields = attr("fields").cast(); - if (fields.ptr() == Py_None) + if (!has_fields()) return *this; struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; }; std::vector field_descriptors; + auto fields = attr("fields").cast(); auto items = fields.attr("items").cast(); for (auto field : items()) { auto spec = object(field, true).cast(); auto name = spec[0].cast(); auto format = spec[1].cast()[0].cast(); auto offset = spec[1].cast()[1].cast(); - if (!len(name) && format.kind() == "V") + if (!len(name) && format.kind() == 'V') continue; field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(), offset}); } @@ -244,14 +279,83 @@ public: template array(const std::vector& shape, T* ptr) : array(shape, default_strides(shape, sizeof(T)), ptr) { } - template array(size_t size, T* ptr) - : array(std::vector { size }, ptr) { } + template array(size_t count, T* ptr) + : array(std::vector { count }, ptr) { } array(const buffer_info &info) : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { } - pybind11::dtype dtype() { - return attr("dtype").cast(); + /// Array descriptor (dtype) + pybind11::dtype dtype() const { + return object(PyArray_GET_(m_ptr, descr), true); + } + + /// Total number of elements + size_t size() const { + return std::accumulate(shape(), shape() + ndim(), (size_t) 1, std::multiplies()); + } + + /// Byte size of a single element + size_t itemsize() const { + return (size_t) PyArrayDescr_GET_(PyArray_GET_(m_ptr, descr), elsize); + } + + /// Total number of bytes + size_t nbytes() const { + return size() * itemsize(); + } + + /// Number of dimensions + size_t ndim() const { + return (size_t) PyArray_GET_(m_ptr, nd); + } + + /// Dimensions of the array + const size_t* shape() const { + static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t"); + return reinterpret_cast(PyArray_GET_(m_ptr, dimensions)); + } + + /// Dimension along a given axis + size_t shape(size_t dim) const { + if (dim >= ndim()) + pybind11_fail("NumPy: attempted to index shape beyond ndim"); + return shape()[dim]; + } + + /// Strides of the array + const size_t* strides() const { + static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t"); + return reinterpret_cast(PyArray_GET_(m_ptr, strides)); + } + + /// Stride along a given axis + size_t strides(size_t dim) const { + if (dim >= ndim()) + pybind11_fail("NumPy: attempted to index strides beyond ndim"); + return strides()[dim]; + } + + /// If set, the array is writeable (otherwise the buffer is read-only) + bool writeable() const { + return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_); + } + + /// If set, the array owns the data (will be freed when the array is deleted) + bool owndata() const { + return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_); + } + + /// Direct pointer to contained buffer + const void* data() const { + return reinterpret_cast(PyArray_GET_(m_ptr, data)); + } + + /// Direct mutable pointer to contained buffer (checks writeable flag) + void* mutable_data() { + if (!writeable()) + pybind11_fail("NumPy: cannot get mutable data of a read-only array"); + return reinterpret_cast(PyArray_GET_(m_ptr, data)); } protected: @@ -284,8 +388,18 @@ public: array_t(const std::vector& shape, T* ptr = nullptr) : array(shape, ptr) { } - array_t(size_t size, T* ptr = nullptr) - : array(size, ptr) { } + array_t(size_t count, T* ptr = nullptr) + : array(count, ptr) { } + + const T* data() const { + return reinterpret_cast(PyArray_GET_(m_ptr, data)); + } + + T* mutable_data() { + if (!writeable()) + pybind11_fail("NumPy: cannot get mutable data of a read-only array"); + return reinterpret_cast(PyArray_GET_(m_ptr, data)); + } static bool is_non_null(PyObject *ptr) { return ptr != nullptr; } @@ -678,16 +792,13 @@ struct vectorize_helper { if (size == 1) return cast(f(*((Args *) buffers[Index].ptr)...)); - array result(buffer_info(nullptr, sizeof(Return), - format_descriptor::format(), - ndim, shape, strides)); - - buffer_info buf = result.request(); - Return *output = (Return *) buf.ptr; + array_t result(shape, strides); + auto buf = result.request(); + auto output = (Return *) buf.ptr; if (trivial_broadcast) { /* Call the function */ - for (size_t i=0; i + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#include "pybind11_tests.h" +#include +#include + +test_initializer numpy_array([](py::module &m) { + m.def("get_arr_ndim", [](const py::array& arr) { + return arr.ndim(); + }); + m.def("get_arr_shape", [](const py::array& arr) { + return std::vector(arr.shape(), arr.shape() + arr.ndim()); + }); + m.def("get_arr_shape", [](const py::array& arr, size_t dim) { + return arr.shape(dim); + }); + m.def("get_arr_strides", [](const py::array& arr) { + return std::vector(arr.strides(), arr.strides() + arr.ndim()); + }); + m.def("get_arr_strides", [](const py::array& arr, size_t dim) { + return arr.strides(dim); + }); + m.def("get_arr_writeable", [](const py::array& arr) { + return arr.writeable(); + }); + m.def("get_arr_size", [](const py::array& arr) { + return arr.size(); + }); + m.def("get_arr_itemsize", [](const py::array& arr) { + return arr.itemsize(); + }); + m.def("get_arr_nbytes", [](const py::array& arr) { + return arr.nbytes(); + }); + m.def("get_arr_owndata", [](const py::array& arr) { + return arr.owndata(); + }); +}); diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py new file mode 100644 index 000000000..929cdc4d9 --- /dev/null +++ b/tests/test_numpy_array.py @@ -0,0 +1,43 @@ +import pytest + +with pytest.suppress(ImportError): + import numpy as np + + +@pytest.requires_numpy +def test_array_attributes(): + from pybind11_tests import (get_arr_ndim, get_arr_shape, get_arr_strides, get_arr_writeable, + get_arr_size, get_arr_itemsize, get_arr_nbytes, get_arr_owndata) + + a = np.array(0, 'f8') + assert get_arr_ndim(a) == 0 + assert get_arr_shape(a) == [] + assert get_arr_strides(a) == [] + with pytest.raises(RuntimeError): + get_arr_shape(a, 1) + with pytest.raises(RuntimeError): + get_arr_strides(a, 0) + assert get_arr_writeable(a) + assert get_arr_size(a) == 1 + assert get_arr_itemsize(a) == 8 + assert get_arr_nbytes(a) == 8 + assert get_arr_owndata(a) + + a = np.array([[1, 2, 3], [4, 5, 6]], 'u2').view() + a.flags.writeable = False + assert get_arr_ndim(a) == 2 + assert get_arr_shape(a) == [2, 3] + assert get_arr_shape(a, 0) == 2 + assert get_arr_shape(a, 1) == 3 + assert get_arr_strides(a) == [6, 2] + assert get_arr_strides(a, 0) == 6 + assert get_arr_strides(a, 1) == 2 + with pytest.raises(RuntimeError): + get_arr_shape(a, 2) + with pytest.raises(RuntimeError): + get_arr_strides(a, 2) + assert not get_arr_writeable(a) + assert get_arr_size(a) == 6 + assert get_arr_itemsize(a) == 2 + assert get_arr_nbytes(a) == 12 + assert not get_arr_owndata(a)