mirror of
https://github.com/pybind/pybind11.git
synced 2025-02-16 21:57:55 +00:00
Use memoryview for constructing array from buffer
This commit is contained in:
parent
ea2755ccdc
commit
a67c2b52e4
@ -13,6 +13,7 @@
|
||||
#include "complex.h"
|
||||
#include <numeric>
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(push)
|
||||
@ -31,6 +32,7 @@ public:
|
||||
API_PyArray_FromAny = 69,
|
||||
API_PyArray_NewCopy = 85,
|
||||
API_PyArray_NewFromDescr = 94,
|
||||
API_PyArray_GetArrayParamsFromObject = 278,
|
||||
|
||||
NPY_C_CONTIGUOUS_ = 0x0001,
|
||||
NPY_F_CONTIGUOUS_ = 0x0002,
|
||||
@ -61,6 +63,7 @@ public:
|
||||
DECL_NPY_API(PyArray_FromAny);
|
||||
DECL_NPY_API(PyArray_NewCopy);
|
||||
DECL_NPY_API(PyArray_NewFromDescr);
|
||||
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
|
||||
#undef DECL_NPY_API
|
||||
return api;
|
||||
}
|
||||
@ -74,6 +77,8 @@ public:
|
||||
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
|
||||
PyTypeObject *PyArray_Type_;
|
||||
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
|
||||
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
|
||||
Py_ssize_t *, PyObject **, PyObject *);
|
||||
};
|
||||
|
||||
PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check_)
|
||||
@ -100,24 +105,22 @@ public:
|
||||
}
|
||||
|
||||
array(const buffer_info &info) {
|
||||
API& api = lookup_api();
|
||||
if ((info.format.size() < 1) || (info.format.size() > 2))
|
||||
pybind11_fail("Unsupported buffer format!");
|
||||
int fmt = (int) info.format[0];
|
||||
if (info.format == "Zd") fmt = API::NPY_CDOUBLE_;
|
||||
else if (info.format == "Zf") fmt = API::NPY_CFLOAT_;
|
||||
PyObject *arr = nullptr, *descr = nullptr;
|
||||
int ndim = 0;
|
||||
Py_ssize_t dims[32];
|
||||
|
||||
PyObject *descr = api.PyArray_DescrFromType_(fmt);
|
||||
if (descr == nullptr)
|
||||
pybind11_fail("NumPy: unsupported buffer format '" + info.format + "'!");
|
||||
object tmp(api.PyArray_NewFromDescr_(
|
||||
api.PyArray_Type_, descr, (int) info.ndim, (Py_intptr_t *) &info.shape[0],
|
||||
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
|
||||
if (info.ptr && tmp)
|
||||
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
|
||||
if (!tmp)
|
||||
pybind11_fail("NumPy: unable to create array!");
|
||||
m_ptr = tmp.release().ptr();
|
||||
// allocate zeroed memory if it hasn't been provided
|
||||
auto buf_info = info;
|
||||
if (!buf_info.ptr)
|
||||
buf_info.ptr = std::calloc(info.size, info.itemsize);
|
||||
auto view = py::memoryview(buf_info);
|
||||
|
||||
API& api = lookup_api();
|
||||
auto res = api.PyArray_GetArrayParamsFromObject_(view.ptr(), nullptr, 1, &descr,
|
||||
&ndim, dims, &arr, nullptr);
|
||||
if (res < 0 || !arr || descr)
|
||||
pybind11_fail("NumPy: unable to convert buffer to an array");
|
||||
m_ptr = arr;
|
||||
}
|
||||
|
||||
protected:
|
||||
|
Loading…
Reference in New Issue
Block a user