mirror of
https://github.com/pybind/pybind11.git
synced 2025-01-31 15:20:34 +00:00
Add all possible ctors for py::array
This commit is contained in:
parent
d77bc8c343
commit
6bb0ee1186
@ -211,12 +211,16 @@ public:
|
||||
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
|
||||
};
|
||||
|
||||
template <typename Type> array(size_t size, const Type *ptr) {
|
||||
array(const pybind11::dtype& dt, const std::vector<size_t>& shape,
|
||||
void *ptr, const std::vector<size_t>& strides) {
|
||||
auto& api = detail::npy_api::get();
|
||||
auto descr = pybind11::dtype::of<Type>().release().ptr();
|
||||
Py_intptr_t shape = (Py_intptr_t) size;
|
||||
object tmp = object(api.PyArray_NewFromDescr_(
|
||||
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false);
|
||||
auto ndim = shape.size();
|
||||
if (shape.size() != strides.size())
|
||||
pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
|
||||
auto descr = dt;
|
||||
object tmp(api.PyArray_NewFromDescr_(
|
||||
api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(),
|
||||
(Py_intptr_t *) strides.data(), ptr, 0, nullptr), false);
|
||||
if (!tmp)
|
||||
pybind11_fail("NumPy: unable to create array!");
|
||||
if (ptr)
|
||||
@ -224,18 +228,30 @@ public:
|
||||
m_ptr = tmp.release().ptr();
|
||||
}
|
||||
|
||||
array(const buffer_info &info) {
|
||||
auto& api = detail::npy_api::get();
|
||||
auto descr = pybind11::dtype(info).release().ptr();
|
||||
object tmp(api.PyArray_NewFromDescr_(
|
||||
api.PyArray_Type_, descr, (int) info.ndim, (Py_intptr_t *) &info.shape[0],
|
||||
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
|
||||
if (!tmp)
|
||||
pybind11_fail("NumPy: unable to create array!");
|
||||
if (info.ptr)
|
||||
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
|
||||
m_ptr = tmp.release().ptr();
|
||||
}
|
||||
array(const pybind11::dtype& dt, const std::vector<size_t>& shape, void *ptr)
|
||||
: array(dt, shape, ptr, default_strides(shape, dt.itemsize()))
|
||||
{ }
|
||||
|
||||
array(const pybind11::dtype& dt, size_t size, void *ptr)
|
||||
: array(dt, std::vector<size_t> { size }, ptr)
|
||||
{ }
|
||||
|
||||
template<typename T> array(const std::vector<size_t>& shape,
|
||||
T* ptr, const std::vector<size_t>& strides)
|
||||
: array(pybind11::dtype::of<T>(), shape, (void *) ptr, strides)
|
||||
{ }
|
||||
|
||||
template<typename T> array(const std::vector<size_t>& shape, T* ptr)
|
||||
: array(shape, ptr, default_strides(shape, sizeof(T)))
|
||||
{ }
|
||||
|
||||
template<typename T> array(size_t size, T* ptr)
|
||||
: array(std::vector<size_t> { size }, ptr)
|
||||
{ }
|
||||
|
||||
array(const buffer_info &info)
|
||||
: array(pybind11::dtype(info), info.shape, info.ptr, info.strides)
|
||||
{ }
|
||||
|
||||
pybind11::dtype dtype() {
|
||||
return attr("dtype").cast<pybind11::dtype>();
|
||||
@ -243,6 +259,18 @@ public:
|
||||
|
||||
protected:
|
||||
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
|
||||
|
||||
static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
|
||||
auto ndim = shape.size();
|
||||
std::vector<size_t> strides(ndim);
|
||||
if (ndim) {
|
||||
std::fill(strides.begin(), strides.end(), itemsize);
|
||||
for (size_t i = 0; i < ndim - 1; i++)
|
||||
for (size_t j = 0; j < ndim - 1 - i; j++)
|
||||
strides[j] *= shape[ndim - 1 - i];
|
||||
}
|
||||
return strides;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
|
||||
|
Loading…
Reference in New Issue
Block a user