Use memoryview for constructing array from buffer

This commit is contained in:
Ivan Smirnov 2016-06-19 14:50:06 +01:00
parent ea2755ccdc
commit a67c2b52e4

View File

@ -13,6 +13,7 @@
#include "complex.h"
#include <numeric>
#include <algorithm>
#include <cstdlib>
#if defined(_MSC_VER)
#pragma warning(push)
@ -31,6 +32,7 @@ public:
API_PyArray_FromAny = 69,
API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94,
API_PyArray_GetArrayParamsFromObject = 278,
NPY_C_CONTIGUOUS_ = 0x0001,
NPY_F_CONTIGUOUS_ = 0x0002,
@ -61,6 +63,7 @@ public:
DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_NewCopy);
DECL_NPY_API(PyArray_NewFromDescr);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
#undef DECL_NPY_API
return api;
}
@ -74,6 +77,8 @@ public:
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
PyTypeObject *PyArray_Type_;
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
Py_ssize_t *, PyObject **, PyObject *);
};
PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check_)
@ -100,24 +105,22 @@ public:
}
array(const buffer_info &info) {
API& api = lookup_api();
if ((info.format.size() < 1) || (info.format.size() > 2))
pybind11_fail("Unsupported buffer format!");
int fmt = (int) info.format[0];
if (info.format == "Zd") fmt = API::NPY_CDOUBLE_;
else if (info.format == "Zf") fmt = API::NPY_CFLOAT_;
PyObject *arr = nullptr, *descr = nullptr;
int ndim = 0;
Py_ssize_t dims[32];
PyObject *descr = api.PyArray_DescrFromType_(fmt);
if (descr == nullptr)
pybind11_fail("NumPy: unsupported buffer format '" + info.format + "'!");
object tmp(api.PyArray_NewFromDescr_(
api.PyArray_Type_, descr, (int) info.ndim, (Py_intptr_t *) &info.shape[0],
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
if (info.ptr && tmp)
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
if (!tmp)
pybind11_fail("NumPy: unable to create array!");
m_ptr = tmp.release().ptr();
// allocate zeroed memory if it hasn't been provided
auto buf_info = info;
if (!buf_info.ptr)
buf_info.ptr = std::calloc(info.size, info.itemsize);
auto view = py::memoryview(buf_info);
API& api = lookup_api();
auto res = api.PyArray_GetArrayParamsFromObject_(view.ptr(), nullptr, 1, &descr,
&ndim, dims, &arr, nullptr);
if (res < 0 || !arr || descr)
pybind11_fail("NumPy: unable to convert buffer to an array");
m_ptr = arr;
}
protected: