mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-29 16:37:13 +00:00
Expose some dtype/array attributes via NumPy C API
This commit is contained in:
parent
720136bfa7
commit
91b3d681ad
@ -83,12 +83,11 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
|
|||||||
static constexpr bool isVector = Type::IsVectorAtCompileTime;
|
static constexpr bool isVector = Type::IsVectorAtCompileTime;
|
||||||
|
|
||||||
bool load(handle src, bool) {
|
bool load(handle src, bool) {
|
||||||
array_t<Scalar> buffer(src, true);
|
array_t<Scalar> buf(src, true);
|
||||||
if (!buffer.check())
|
if (!buf.check())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
auto info = buffer.request();
|
if (buf.ndim() == 1) {
|
||||||
if (info.ndim == 1) {
|
|
||||||
typedef Eigen::InnerStride<> Strides;
|
typedef Eigen::InnerStride<> Strides;
|
||||||
if (!isVector &&
|
if (!isVector &&
|
||||||
!(Type::RowsAtCompileTime == Eigen::Dynamic &&
|
!(Type::RowsAtCompileTime == Eigen::Dynamic &&
|
||||||
@ -96,31 +95,32 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
|
|||||||
return false;
|
return false;
|
||||||
|
|
||||||
if (Type::SizeAtCompileTime != Eigen::Dynamic &&
|
if (Type::SizeAtCompileTime != Eigen::Dynamic &&
|
||||||
info.shape[0] != (size_t) Type::SizeAtCompileTime)
|
buf.shape(0) != (size_t) Type::SizeAtCompileTime)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
auto strides = Strides(info.strides[0] / sizeof(Scalar));
|
Strides::Index n_elts = (Strides::Index) buf.shape(0);
|
||||||
|
|
||||||
Strides::Index n_elts = (Strides::Index) info.shape[0];
|
|
||||||
Strides::Index unity = 1;
|
Strides::Index unity = 1;
|
||||||
|
|
||||||
value = Eigen::Map<Type, 0, Strides>(
|
value = Eigen::Map<Type, 0, Strides>(
|
||||||
(Scalar *) info.ptr, rowMajor ? unity : n_elts, rowMajor ? n_elts : unity, strides);
|
buf.mutable_data(),
|
||||||
} else if (info.ndim == 2) {
|
rowMajor ? unity : n_elts,
|
||||||
|
rowMajor ? n_elts : unity,
|
||||||
|
Strides(buf.strides(0) / sizeof(Scalar))
|
||||||
|
);
|
||||||
|
} else if (buf.ndim() == 2) {
|
||||||
typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides;
|
typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides;
|
||||||
|
|
||||||
if ((Type::RowsAtCompileTime != Eigen::Dynamic && info.shape[0] != (size_t) Type::RowsAtCompileTime) ||
|
if ((Type::RowsAtCompileTime != Eigen::Dynamic && buf.shape(0) != (size_t) Type::RowsAtCompileTime) ||
|
||||||
(Type::ColsAtCompileTime != Eigen::Dynamic && info.shape[1] != (size_t) Type::ColsAtCompileTime))
|
(Type::ColsAtCompileTime != Eigen::Dynamic && buf.shape(1) != (size_t) Type::ColsAtCompileTime))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
auto strides = Strides(
|
|
||||||
info.strides[rowMajor ? 0 : 1] / sizeof(Scalar),
|
|
||||||
info.strides[rowMajor ? 1 : 0] / sizeof(Scalar));
|
|
||||||
|
|
||||||
value = Eigen::Map<Type, 0, Strides>(
|
value = Eigen::Map<Type, 0, Strides>(
|
||||||
(Scalar *) info.ptr,
|
buf.mutable_data(),
|
||||||
typename Strides::Index(info.shape[0]),
|
typename Strides::Index(buf.shape(0)),
|
||||||
typename Strides::Index(info.shape[1]), strides);
|
typename Strides::Index(buf.shape(1)),
|
||||||
|
Strides(buf.strides(rowMajor ? 0 : 1) / sizeof(Scalar),
|
||||||
|
buf.strides(rowMajor ? 1 : 0) / sizeof(Scalar))
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -222,28 +222,18 @@ struct type_caster<Type, typename std::enable_if<is_eigen_sparse<Type>::value>::
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto valuesArray = array_t<Scalar>((object) obj.attr("data"));
|
auto values = array_t<Scalar>((object) obj.attr("data"));
|
||||||
auto innerIndicesArray = array_t<StorageIndex>((object) obj.attr("indices"));
|
auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
|
||||||
auto outerIndicesArray = array_t<StorageIndex>((object) obj.attr("indptr"));
|
auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
|
||||||
auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
|
auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
|
||||||
auto nnz = obj.attr("nnz").cast<Index>();
|
auto nnz = obj.attr("nnz").cast<Index>();
|
||||||
|
|
||||||
if (!valuesArray.check() || !innerIndicesArray.check() ||
|
if (!values.check() || !innerIndices.check() || !outerIndices.check())
|
||||||
!outerIndicesArray.check())
|
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
auto outerIndices = outerIndicesArray.request();
|
|
||||||
auto innerIndices = innerIndicesArray.request();
|
|
||||||
auto values = valuesArray.request();
|
|
||||||
|
|
||||||
value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
|
value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
|
||||||
shape[0].cast<Index>(),
|
shape[0].cast<Index>(), shape[1].cast<Index>(), nnz,
|
||||||
shape[1].cast<Index>(),
|
outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data());
|
||||||
nnz,
|
|
||||||
static_cast<StorageIndex *>(outerIndices.ptr),
|
|
||||||
static_cast<StorageIndex *>(innerIndices.ptr),
|
|
||||||
static_cast<Scalar *>(values.ptr)
|
|
||||||
);
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,7 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <initializer_list>
|
#include <initializer_list>
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
#if defined(_MSC_VER)
|
#if defined(_MSC_VER)
|
||||||
#pragma warning(push)
|
#pragma warning(push)
|
||||||
@ -30,12 +31,41 @@ namespace detail {
|
|||||||
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
|
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
|
||||||
template <typename type> struct is_pod_struct;
|
template <typename type> struct is_pod_struct;
|
||||||
|
|
||||||
|
struct PyArrayDescr_Proxy {
|
||||||
|
PyObject_HEAD
|
||||||
|
PyObject *typeobj;
|
||||||
|
char kind;
|
||||||
|
char type;
|
||||||
|
char byteorder;
|
||||||
|
char flags;
|
||||||
|
int type_num;
|
||||||
|
int elsize;
|
||||||
|
int alignment;
|
||||||
|
char *subarray;
|
||||||
|
PyObject *fields;
|
||||||
|
PyObject *names;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct PyArray_Proxy {
|
||||||
|
PyObject_HEAD
|
||||||
|
char *data;
|
||||||
|
int nd;
|
||||||
|
ssize_t *dimensions;
|
||||||
|
ssize_t *strides;
|
||||||
|
PyObject *base;
|
||||||
|
PyObject *descr;
|
||||||
|
int flags;
|
||||||
|
};
|
||||||
|
|
||||||
struct npy_api {
|
struct npy_api {
|
||||||
enum constants {
|
enum constants {
|
||||||
NPY_C_CONTIGUOUS_ = 0x0001,
|
NPY_C_CONTIGUOUS_ = 0x0001,
|
||||||
NPY_F_CONTIGUOUS_ = 0x0002,
|
NPY_F_CONTIGUOUS_ = 0x0002,
|
||||||
|
NPY_ARRAY_OWNDATA_ = 0x0004,
|
||||||
NPY_ARRAY_FORCECAST_ = 0x0010,
|
NPY_ARRAY_FORCECAST_ = 0x0010,
|
||||||
NPY_ENSURE_ARRAY_ = 0x0040,
|
NPY_ENSURE_ARRAY_ = 0x0040,
|
||||||
|
NPY_ARRAY_ALIGNED_ = 0x0100,
|
||||||
|
NPY_ARRAY_WRITEABLE_ = 0x0400,
|
||||||
NPY_BOOL_ = 0,
|
NPY_BOOL_ = 0,
|
||||||
NPY_BYTE_, NPY_UBYTE_,
|
NPY_BYTE_, NPY_UBYTE_,
|
||||||
NPY_SHORT_, NPY_USHORT_,
|
NPY_SHORT_, NPY_USHORT_,
|
||||||
@ -113,6 +143,11 @@ private:
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define PyArray_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
|
||||||
|
#define PyArrayDescr_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
|
||||||
|
#define PyArray_CHKFLAGS_(ptr, flag) \
|
||||||
|
(flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag))
|
||||||
|
|
||||||
class dtype : public object {
|
class dtype : public object {
|
||||||
public:
|
public:
|
||||||
PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
|
PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
|
||||||
@ -150,15 +185,15 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t itemsize() const {
|
size_t itemsize() const {
|
||||||
return attr("itemsize").cast<size_t>();
|
return (size_t) PyArrayDescr_GET_(m_ptr, elsize);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool has_fields() const {
|
bool has_fields() const {
|
||||||
return attr("fields").cast<object>().ptr() != Py_None;
|
return PyArrayDescr_GET_(m_ptr, names) != nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string kind() const {
|
char kind() const {
|
||||||
return (std::string) attr("kind").cast<pybind11::str>();
|
return PyArrayDescr_GET_(m_ptr, kind);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -171,20 +206,20 @@ private:
|
|||||||
dtype strip_padding() {
|
dtype strip_padding() {
|
||||||
// Recursively strip all void fields with empty names that are generated for
|
// Recursively strip all void fields with empty names that are generated for
|
||||||
// padding fields (as of NumPy v1.11).
|
// padding fields (as of NumPy v1.11).
|
||||||
auto fields = attr("fields").cast<object>();
|
if (!has_fields())
|
||||||
if (fields.ptr() == Py_None)
|
|
||||||
return *this;
|
return *this;
|
||||||
|
|
||||||
struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
|
struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
|
||||||
std::vector<field_descr> field_descriptors;
|
std::vector<field_descr> field_descriptors;
|
||||||
|
|
||||||
|
auto fields = attr("fields").cast<object>();
|
||||||
auto items = fields.attr("items").cast<object>();
|
auto items = fields.attr("items").cast<object>();
|
||||||
for (auto field : items()) {
|
for (auto field : items()) {
|
||||||
auto spec = object(field, true).cast<tuple>();
|
auto spec = object(field, true).cast<tuple>();
|
||||||
auto name = spec[0].cast<pybind11::str>();
|
auto name = spec[0].cast<pybind11::str>();
|
||||||
auto format = spec[1].cast<tuple>()[0].cast<dtype>();
|
auto format = spec[1].cast<tuple>()[0].cast<dtype>();
|
||||||
auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
|
auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
|
||||||
if (!len(name) && format.kind() == "V")
|
if (!len(name) && format.kind() == 'V')
|
||||||
continue;
|
continue;
|
||||||
field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(), offset});
|
field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(), offset});
|
||||||
}
|
}
|
||||||
@ -244,14 +279,83 @@ public:
|
|||||||
template<typename T> array(const std::vector<size_t>& shape, T* ptr)
|
template<typename T> array(const std::vector<size_t>& shape, T* ptr)
|
||||||
: array(shape, default_strides(shape, sizeof(T)), ptr) { }
|
: array(shape, default_strides(shape, sizeof(T)), ptr) { }
|
||||||
|
|
||||||
template<typename T> array(size_t size, T* ptr)
|
template<typename T> array(size_t count, T* ptr)
|
||||||
: array(std::vector<size_t> { size }, ptr) { }
|
: array(std::vector<size_t> { count }, ptr) { }
|
||||||
|
|
||||||
array(const buffer_info &info)
|
array(const buffer_info &info)
|
||||||
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
|
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
|
||||||
|
|
||||||
pybind11::dtype dtype() {
|
/// Array descriptor (dtype)
|
||||||
return attr("dtype").cast<pybind11::dtype>();
|
pybind11::dtype dtype() const {
|
||||||
|
return object(PyArray_GET_(m_ptr, descr), true);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Total number of elements
|
||||||
|
size_t size() const {
|
||||||
|
return std::accumulate(shape(), shape() + ndim(), (size_t) 1, std::multiplies<size_t>());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Byte size of a single element
|
||||||
|
size_t itemsize() const {
|
||||||
|
return (size_t) PyArrayDescr_GET_(PyArray_GET_(m_ptr, descr), elsize);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Total number of bytes
|
||||||
|
size_t nbytes() const {
|
||||||
|
return size() * itemsize();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of dimensions
|
||||||
|
size_t ndim() const {
|
||||||
|
return (size_t) PyArray_GET_(m_ptr, nd);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Dimensions of the array
|
||||||
|
const size_t* shape() const {
|
||||||
|
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
|
||||||
|
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, dimensions));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Dimension along a given axis
|
||||||
|
size_t shape(size_t dim) const {
|
||||||
|
if (dim >= ndim())
|
||||||
|
pybind11_fail("NumPy: attempted to index shape beyond ndim");
|
||||||
|
return shape()[dim];
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Strides of the array
|
||||||
|
const size_t* strides() const {
|
||||||
|
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
|
||||||
|
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, strides));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stride along a given axis
|
||||||
|
size_t strides(size_t dim) const {
|
||||||
|
if (dim >= ndim())
|
||||||
|
pybind11_fail("NumPy: attempted to index strides beyond ndim");
|
||||||
|
return strides()[dim];
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If set, the array is writeable (otherwise the buffer is read-only)
|
||||||
|
bool writeable() const {
|
||||||
|
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If set, the array owns the data (will be freed when the array is deleted)
|
||||||
|
bool owndata() const {
|
||||||
|
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Direct pointer to contained buffer
|
||||||
|
const void* data() const {
|
||||||
|
return reinterpret_cast<const void *>(PyArray_GET_(m_ptr, data));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Direct mutable pointer to contained buffer (checks writeable flag)
|
||||||
|
void* mutable_data() {
|
||||||
|
if (!writeable())
|
||||||
|
pybind11_fail("NumPy: cannot get mutable data of a read-only array");
|
||||||
|
return reinterpret_cast<void *>(PyArray_GET_(m_ptr, data));
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@ -284,8 +388,18 @@ public:
|
|||||||
array_t(const std::vector<size_t>& shape, T* ptr = nullptr)
|
array_t(const std::vector<size_t>& shape, T* ptr = nullptr)
|
||||||
: array(shape, ptr) { }
|
: array(shape, ptr) { }
|
||||||
|
|
||||||
array_t(size_t size, T* ptr = nullptr)
|
array_t(size_t count, T* ptr = nullptr)
|
||||||
: array(size, ptr) { }
|
: array(count, ptr) { }
|
||||||
|
|
||||||
|
const T* data() const {
|
||||||
|
return reinterpret_cast<const T *>(PyArray_GET_(m_ptr, data));
|
||||||
|
}
|
||||||
|
|
||||||
|
T* mutable_data() {
|
||||||
|
if (!writeable())
|
||||||
|
pybind11_fail("NumPy: cannot get mutable data of a read-only array");
|
||||||
|
return reinterpret_cast<T *>(PyArray_GET_(m_ptr, data));
|
||||||
|
}
|
||||||
|
|
||||||
static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
|
static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
|
||||||
|
|
||||||
@ -678,16 +792,13 @@ struct vectorize_helper {
|
|||||||
if (size == 1)
|
if (size == 1)
|
||||||
return cast(f(*((Args *) buffers[Index].ptr)...));
|
return cast(f(*((Args *) buffers[Index].ptr)...));
|
||||||
|
|
||||||
array result(buffer_info(nullptr, sizeof(Return),
|
array_t<Return> result(shape, strides);
|
||||||
format_descriptor<Return>::format(),
|
auto buf = result.request();
|
||||||
ndim, shape, strides));
|
auto output = (Return *) buf.ptr;
|
||||||
|
|
||||||
buffer_info buf = result.request();
|
|
||||||
Return *output = (Return *) buf.ptr;
|
|
||||||
|
|
||||||
if (trivial_broadcast) {
|
if (trivial_broadcast) {
|
||||||
/* Call the function */
|
/* Call the function */
|
||||||
for (size_t i=0; i<size; ++i) {
|
for (size_t i = 0; i < size; ++i) {
|
||||||
output[i] = f((buffers[Index].size == 1
|
output[i] = f((buffers[Index].size == 1
|
||||||
? *((Args *) buffers[Index].ptr)
|
? *((Args *) buffers[Index].ptr)
|
||||||
: ((Args *) buffers[Index].ptr)[i])...);
|
: ((Args *) buffers[Index].ptr)[i])...);
|
||||||
|
@ -19,6 +19,7 @@ set(PYBIND11_TEST_FILES
|
|||||||
test_kwargs_and_defaults.cpp
|
test_kwargs_and_defaults.cpp
|
||||||
test_methods_and_attributes.cpp
|
test_methods_and_attributes.cpp
|
||||||
test_modules.cpp
|
test_modules.cpp
|
||||||
|
test_numpy_array.cpp
|
||||||
test_numpy_dtypes.cpp
|
test_numpy_dtypes.cpp
|
||||||
test_numpy_vectorize.cpp
|
test_numpy_vectorize.cpp
|
||||||
test_opaque_types.cpp
|
test_opaque_types.cpp
|
||||||
|
45
tests/test_numpy_array.cpp
Normal file
45
tests/test_numpy_array.cpp
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
/*
|
||||||
|
tests/test_numpy_array.cpp -- test core array functionality
|
||||||
|
|
||||||
|
Copyright (c) 2016 Ivan Smirnov <i.s.smirnov@gmail.com>
|
||||||
|
|
||||||
|
All rights reserved. Use of this source code is governed by a
|
||||||
|
BSD-style license that can be found in the LICENSE file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "pybind11_tests.h"
|
||||||
|
#include <pybind11/numpy.h>
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
|
||||||
|
test_initializer numpy_array([](py::module &m) {
|
||||||
|
m.def("get_arr_ndim", [](const py::array& arr) {
|
||||||
|
return arr.ndim();
|
||||||
|
});
|
||||||
|
m.def("get_arr_shape", [](const py::array& arr) {
|
||||||
|
return std::vector<size_t>(arr.shape(), arr.shape() + arr.ndim());
|
||||||
|
});
|
||||||
|
m.def("get_arr_shape", [](const py::array& arr, size_t dim) {
|
||||||
|
return arr.shape(dim);
|
||||||
|
});
|
||||||
|
m.def("get_arr_strides", [](const py::array& arr) {
|
||||||
|
return std::vector<size_t>(arr.strides(), arr.strides() + arr.ndim());
|
||||||
|
});
|
||||||
|
m.def("get_arr_strides", [](const py::array& arr, size_t dim) {
|
||||||
|
return arr.strides(dim);
|
||||||
|
});
|
||||||
|
m.def("get_arr_writeable", [](const py::array& arr) {
|
||||||
|
return arr.writeable();
|
||||||
|
});
|
||||||
|
m.def("get_arr_size", [](const py::array& arr) {
|
||||||
|
return arr.size();
|
||||||
|
});
|
||||||
|
m.def("get_arr_itemsize", [](const py::array& arr) {
|
||||||
|
return arr.itemsize();
|
||||||
|
});
|
||||||
|
m.def("get_arr_nbytes", [](const py::array& arr) {
|
||||||
|
return arr.nbytes();
|
||||||
|
});
|
||||||
|
m.def("get_arr_owndata", [](const py::array& arr) {
|
||||||
|
return arr.owndata();
|
||||||
|
});
|
||||||
|
});
|
43
tests/test_numpy_array.py
Normal file
43
tests/test_numpy_array.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
with pytest.suppress(ImportError):
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.requires_numpy
|
||||||
|
def test_array_attributes():
|
||||||
|
from pybind11_tests import (get_arr_ndim, get_arr_shape, get_arr_strides, get_arr_writeable,
|
||||||
|
get_arr_size, get_arr_itemsize, get_arr_nbytes, get_arr_owndata)
|
||||||
|
|
||||||
|
a = np.array(0, 'f8')
|
||||||
|
assert get_arr_ndim(a) == 0
|
||||||
|
assert get_arr_shape(a) == []
|
||||||
|
assert get_arr_strides(a) == []
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
get_arr_shape(a, 1)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
get_arr_strides(a, 0)
|
||||||
|
assert get_arr_writeable(a)
|
||||||
|
assert get_arr_size(a) == 1
|
||||||
|
assert get_arr_itemsize(a) == 8
|
||||||
|
assert get_arr_nbytes(a) == 8
|
||||||
|
assert get_arr_owndata(a)
|
||||||
|
|
||||||
|
a = np.array([[1, 2, 3], [4, 5, 6]], 'u2').view()
|
||||||
|
a.flags.writeable = False
|
||||||
|
assert get_arr_ndim(a) == 2
|
||||||
|
assert get_arr_shape(a) == [2, 3]
|
||||||
|
assert get_arr_shape(a, 0) == 2
|
||||||
|
assert get_arr_shape(a, 1) == 3
|
||||||
|
assert get_arr_strides(a) == [6, 2]
|
||||||
|
assert get_arr_strides(a, 0) == 6
|
||||||
|
assert get_arr_strides(a, 1) == 2
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
get_arr_shape(a, 2)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
get_arr_strides(a, 2)
|
||||||
|
assert not get_arr_writeable(a)
|
||||||
|
assert get_arr_size(a) == 6
|
||||||
|
assert get_arr_itemsize(a) == 2
|
||||||
|
assert get_arr_nbytes(a) == 12
|
||||||
|
assert not get_arr_owndata(a)
|
Loading…
Reference in New Issue
Block a user