diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index d04887591..0d4aeaf63 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -211,12 +211,16 @@ public: forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_ }; - template array(size_t size, const Type *ptr) { + array(const pybind11::dtype& dt, const std::vector& shape, + void *ptr, const std::vector& strides) { auto& api = detail::npy_api::get(); - auto descr = pybind11::dtype::of().release().ptr(); - Py_intptr_t shape = (Py_intptr_t) size; - object tmp = object(api.PyArray_NewFromDescr_( - api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false); + auto ndim = shape.size(); + if (shape.size() != strides.size()) + pybind11_fail("NumPy: shape ndim doesn't match strides ndim"); + auto descr = dt; + object tmp(api.PyArray_NewFromDescr_( + api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(), + (Py_intptr_t *) strides.data(), ptr, 0, nullptr), false); if (!tmp) pybind11_fail("NumPy: unable to create array!"); if (ptr) @@ -224,18 +228,30 @@ public: m_ptr = tmp.release().ptr(); } - array(const buffer_info &info) { - auto& api = detail::npy_api::get(); - auto descr = pybind11::dtype(info).release().ptr(); - object tmp(api.PyArray_NewFromDescr_( - api.PyArray_Type_, descr, (int) info.ndim, (Py_intptr_t *) &info.shape[0], - (Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false); - if (!tmp) - pybind11_fail("NumPy: unable to create array!"); - if (info.ptr) - tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false); - m_ptr = tmp.release().ptr(); - } + array(const pybind11::dtype& dt, const std::vector& shape, void *ptr) + : array(dt, shape, ptr, default_strides(shape, dt.itemsize())) + { } + + array(const pybind11::dtype& dt, size_t size, void *ptr) + : array(dt, std::vector { size }, ptr) + { } + + template array(const std::vector& shape, + T* ptr, const std::vector& strides) + : array(pybind11::dtype::of(), shape, (void *) ptr, strides) + { } + + template array(const std::vector& shape, T* ptr) + : array(shape, ptr, default_strides(shape, sizeof(T))) + { } + + template array(size_t size, T* ptr) + : array(std::vector { size }, ptr) + { } + + array(const buffer_info &info) + : array(pybind11::dtype(info), info.shape, info.ptr, info.strides) + { } pybind11::dtype dtype() { return attr("dtype").cast(); @@ -243,6 +259,18 @@ public: protected: template friend struct detail::npy_format_descriptor; + + static std::vector default_strides(const std::vector& shape, size_t itemsize) { + auto ndim = shape.size(); + std::vector strides(ndim); + if (ndim) { + std::fill(strides.begin(), strides.end(), itemsize); + for (size_t i = 0; i < ndim - 1; i++) + for (size_t j = 0; j < ndim - 1 - i; j++) + strides[j] *= shape[ndim - 1 - i]; + } + return strides; + } }; template class array_t : public array {