Add all possible ctors for py::array

This commit is contained in:
Ivan Smirnov 2016-07-24 18:35:14 +01:00
parent d77bc8c343
commit 6bb0ee1186

View File

@ -211,12 +211,16 @@ public:
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
};
template <typename Type> array(size_t size, const Type *ptr) {
array(const pybind11::dtype& dt, const std::vector<size_t>& shape,
void *ptr, const std::vector<size_t>& strides) {
auto& api = detail::npy_api::get();
auto descr = pybind11::dtype::of<Type>().release().ptr();
Py_intptr_t shape = (Py_intptr_t) size;
object tmp = object(api.PyArray_NewFromDescr_(
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false);
auto ndim = shape.size();
if (shape.size() != strides.size())
pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
auto descr = dt;
object tmp(api.PyArray_NewFromDescr_(
api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(),
(Py_intptr_t *) strides.data(), ptr, 0, nullptr), false);
if (!tmp)
pybind11_fail("NumPy: unable to create array!");
if (ptr)
@ -224,18 +228,30 @@ public:
m_ptr = tmp.release().ptr();
}
array(const buffer_info &info) {
auto& api = detail::npy_api::get();
auto descr = pybind11::dtype(info).release().ptr();
object tmp(api.PyArray_NewFromDescr_(
api.PyArray_Type_, descr, (int) info.ndim, (Py_intptr_t *) &info.shape[0],
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
if (!tmp)
pybind11_fail("NumPy: unable to create array!");
if (info.ptr)
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
m_ptr = tmp.release().ptr();
}
array(const pybind11::dtype& dt, const std::vector<size_t>& shape, void *ptr)
: array(dt, shape, ptr, default_strides(shape, dt.itemsize()))
{ }
array(const pybind11::dtype& dt, size_t size, void *ptr)
: array(dt, std::vector<size_t> { size }, ptr)
{ }
template<typename T> array(const std::vector<size_t>& shape,
T* ptr, const std::vector<size_t>& strides)
: array(pybind11::dtype::of<T>(), shape, (void *) ptr, strides)
{ }
template<typename T> array(const std::vector<size_t>& shape, T* ptr)
: array(shape, ptr, default_strides(shape, sizeof(T)))
{ }
template<typename T> array(size_t size, T* ptr)
: array(std::vector<size_t> { size }, ptr)
{ }
array(const buffer_info &info)
: array(pybind11::dtype(info), info.shape, info.ptr, info.strides)
{ }
pybind11::dtype dtype() {
return attr("dtype").cast<pybind11::dtype>();
@ -243,6 +259,18 @@ public:
protected:
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
auto ndim = shape.size();
std::vector<size_t> strides(ndim);
if (ndim) {
std::fill(strides.begin(), strides.end(), itemsize);
for (size_t i = 0; i < ndim - 1; i++)
for (size_t j = 0; j < ndim - 1 - i; j++)
strides[j] *= shape[ndim - 1 - i];
}
return strides;
}
};
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {