Add all possible ctors for py::array

This commit is contained in:
Ivan Smirnov 2016-07-24 18:35:14 +01:00
parent d77bc8c343
commit 6bb0ee1186

View File

@ -211,12 +211,16 @@ public:
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_ forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
}; };
template <typename Type> array(size_t size, const Type *ptr) { array(const pybind11::dtype& dt, const std::vector<size_t>& shape,
void *ptr, const std::vector<size_t>& strides) {
auto& api = detail::npy_api::get(); auto& api = detail::npy_api::get();
auto descr = pybind11::dtype::of<Type>().release().ptr(); auto ndim = shape.size();
Py_intptr_t shape = (Py_intptr_t) size; if (shape.size() != strides.size())
object tmp = object(api.PyArray_NewFromDescr_( pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false); auto descr = dt;
object tmp(api.PyArray_NewFromDescr_(
api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(),
(Py_intptr_t *) strides.data(), ptr, 0, nullptr), false);
if (!tmp) if (!tmp)
pybind11_fail("NumPy: unable to create array!"); pybind11_fail("NumPy: unable to create array!");
if (ptr) if (ptr)
@ -224,18 +228,30 @@ public:
m_ptr = tmp.release().ptr(); m_ptr = tmp.release().ptr();
} }
array(const buffer_info &info) { array(const pybind11::dtype& dt, const std::vector<size_t>& shape, void *ptr)
auto& api = detail::npy_api::get(); : array(dt, shape, ptr, default_strides(shape, dt.itemsize()))
auto descr = pybind11::dtype(info).release().ptr(); { }
object tmp(api.PyArray_NewFromDescr_(
api.PyArray_Type_, descr, (int) info.ndim, (Py_intptr_t *) &info.shape[0], array(const pybind11::dtype& dt, size_t size, void *ptr)
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false); : array(dt, std::vector<size_t> { size }, ptr)
if (!tmp) { }
pybind11_fail("NumPy: unable to create array!");
if (info.ptr) template<typename T> array(const std::vector<size_t>& shape,
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false); T* ptr, const std::vector<size_t>& strides)
m_ptr = tmp.release().ptr(); : array(pybind11::dtype::of<T>(), shape, (void *) ptr, strides)
} { }
template<typename T> array(const std::vector<size_t>& shape, T* ptr)
: array(shape, ptr, default_strides(shape, sizeof(T)))
{ }
template<typename T> array(size_t size, T* ptr)
: array(std::vector<size_t> { size }, ptr)
{ }
array(const buffer_info &info)
: array(pybind11::dtype(info), info.shape, info.ptr, info.strides)
{ }
pybind11::dtype dtype() { pybind11::dtype dtype() {
return attr("dtype").cast<pybind11::dtype>(); return attr("dtype").cast<pybind11::dtype>();
@ -243,6 +259,18 @@ public:
protected: protected:
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor; template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
auto ndim = shape.size();
std::vector<size_t> strides(ndim);
if (ndim) {
std::fill(strides.begin(), strides.end(), itemsize);
for (size_t i = 0; i < ndim - 1; i++)
for (size_t j = 0; j < ndim - 1 - i; j++)
strides[j] *= shape[ndim - 1 - i];
}
return strides;
}
}; };
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array { template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {