mirror of
https://github.com/pybind/pybind11.git
synced 2025-02-16 13:47:53 +00:00
Merge pull request #402 from aldanor/feature/numpy-c-api
Add array methods via C API
This commit is contained in:
commit
f217c04195
@ -83,12 +83,11 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
|
||||
static constexpr bool isVector = Type::IsVectorAtCompileTime;
|
||||
|
||||
bool load(handle src, bool) {
|
||||
array_t<Scalar> buffer(src, true);
|
||||
if (!buffer.check())
|
||||
return false;
|
||||
array_t<Scalar> buf(src, true);
|
||||
if (!buf.check())
|
||||
return false;
|
||||
|
||||
auto info = buffer.request();
|
||||
if (info.ndim == 1) {
|
||||
if (buf.ndim() == 1) {
|
||||
typedef Eigen::InnerStride<> Strides;
|
||||
if (!isVector &&
|
||||
!(Type::RowsAtCompileTime == Eigen::Dynamic &&
|
||||
@ -96,31 +95,32 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
|
||||
return false;
|
||||
|
||||
if (Type::SizeAtCompileTime != Eigen::Dynamic &&
|
||||
info.shape[0] != (size_t) Type::SizeAtCompileTime)
|
||||
buf.shape(0) != (size_t) Type::SizeAtCompileTime)
|
||||
return false;
|
||||
|
||||
auto strides = Strides(info.strides[0] / sizeof(Scalar));
|
||||
|
||||
Strides::Index n_elts = (Strides::Index) info.shape[0];
|
||||
Strides::Index n_elts = (Strides::Index) buf.shape(0);
|
||||
Strides::Index unity = 1;
|
||||
|
||||
value = Eigen::Map<Type, 0, Strides>(
|
||||
(Scalar *) info.ptr, rowMajor ? unity : n_elts, rowMajor ? n_elts : unity, strides);
|
||||
} else if (info.ndim == 2) {
|
||||
buf.mutable_data(),
|
||||
rowMajor ? unity : n_elts,
|
||||
rowMajor ? n_elts : unity,
|
||||
Strides(buf.strides(0) / sizeof(Scalar))
|
||||
);
|
||||
} else if (buf.ndim() == 2) {
|
||||
typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides;
|
||||
|
||||
if ((Type::RowsAtCompileTime != Eigen::Dynamic && info.shape[0] != (size_t) Type::RowsAtCompileTime) ||
|
||||
(Type::ColsAtCompileTime != Eigen::Dynamic && info.shape[1] != (size_t) Type::ColsAtCompileTime))
|
||||
if ((Type::RowsAtCompileTime != Eigen::Dynamic && buf.shape(0) != (size_t) Type::RowsAtCompileTime) ||
|
||||
(Type::ColsAtCompileTime != Eigen::Dynamic && buf.shape(1) != (size_t) Type::ColsAtCompileTime))
|
||||
return false;
|
||||
|
||||
auto strides = Strides(
|
||||
info.strides[rowMajor ? 0 : 1] / sizeof(Scalar),
|
||||
info.strides[rowMajor ? 1 : 0] / sizeof(Scalar));
|
||||
|
||||
value = Eigen::Map<Type, 0, Strides>(
|
||||
(Scalar *) info.ptr,
|
||||
typename Strides::Index(info.shape[0]),
|
||||
typename Strides::Index(info.shape[1]), strides);
|
||||
buf.mutable_data(),
|
||||
typename Strides::Index(buf.shape(0)),
|
||||
typename Strides::Index(buf.shape(1)),
|
||||
Strides(buf.strides(rowMajor ? 0 : 1) / sizeof(Scalar),
|
||||
buf.strides(rowMajor ? 1 : 0) / sizeof(Scalar))
|
||||
);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
@ -222,28 +222,18 @@ struct type_caster<Type, typename std::enable_if<is_eigen_sparse<Type>::value>::
|
||||
}
|
||||
}
|
||||
|
||||
auto valuesArray = array_t<Scalar>((object) obj.attr("data"));
|
||||
auto innerIndicesArray = array_t<StorageIndex>((object) obj.attr("indices"));
|
||||
auto outerIndicesArray = array_t<StorageIndex>((object) obj.attr("indptr"));
|
||||
auto values = array_t<Scalar>((object) obj.attr("data"));
|
||||
auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
|
||||
auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
|
||||
auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
|
||||
auto nnz = obj.attr("nnz").cast<Index>();
|
||||
|
||||
if (!valuesArray.check() || !innerIndicesArray.check() ||
|
||||
!outerIndicesArray.check())
|
||||
if (!values.check() || !innerIndices.check() || !outerIndices.check())
|
||||
return false;
|
||||
|
||||
auto outerIndices = outerIndicesArray.request();
|
||||
auto innerIndices = innerIndicesArray.request();
|
||||
auto values = valuesArray.request();
|
||||
|
||||
value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
|
||||
shape[0].cast<Index>(),
|
||||
shape[1].cast<Index>(),
|
||||
nnz,
|
||||
static_cast<StorageIndex *>(outerIndices.ptr),
|
||||
static_cast<StorageIndex *>(innerIndices.ptr),
|
||||
static_cast<Scalar *>(values.ptr)
|
||||
);
|
||||
shape[0].cast<Index>(), shape[1].cast<Index>(), nnz,
|
||||
outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -19,23 +19,59 @@
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <initializer_list>
|
||||
#include <functional>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
|
||||
#endif
|
||||
|
||||
/* This will be true on all flat address space platforms and allows us to reduce the
|
||||
whole npy_intp / size_t / Py_intptr_t business down to just size_t for all size
|
||||
and dimension types (e.g. shape, strides, indexing), instead of inflicting this
|
||||
upon the library user. */
|
||||
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
|
||||
|
||||
NAMESPACE_BEGIN(pybind11)
|
||||
namespace detail {
|
||||
NAMESPACE_BEGIN(detail)
|
||||
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
|
||||
template <typename type> struct is_pod_struct;
|
||||
|
||||
struct PyArrayDescr_Proxy {
|
||||
PyObject_HEAD
|
||||
PyObject *typeobj;
|
||||
char kind;
|
||||
char type;
|
||||
char byteorder;
|
||||
char flags;
|
||||
int type_num;
|
||||
int elsize;
|
||||
int alignment;
|
||||
char *subarray;
|
||||
PyObject *fields;
|
||||
PyObject *names;
|
||||
};
|
||||
|
||||
struct PyArray_Proxy {
|
||||
PyObject_HEAD
|
||||
char *data;
|
||||
int nd;
|
||||
ssize_t *dimensions;
|
||||
ssize_t *strides;
|
||||
PyObject *base;
|
||||
PyObject *descr;
|
||||
int flags;
|
||||
};
|
||||
|
||||
struct npy_api {
|
||||
enum constants {
|
||||
NPY_C_CONTIGUOUS_ = 0x0001,
|
||||
NPY_F_CONTIGUOUS_ = 0x0002,
|
||||
NPY_ARRAY_OWNDATA_ = 0x0004,
|
||||
NPY_ARRAY_FORCECAST_ = 0x0010,
|
||||
NPY_ENSURE_ARRAY_ = 0x0040,
|
||||
NPY_ARRAY_ALIGNED_ = 0x0100,
|
||||
NPY_ARRAY_WRITEABLE_ = 0x0400,
|
||||
NPY_BOOL_ = 0,
|
||||
NPY_BYTE_, NPY_UBYTE_,
|
||||
NPY_SHORT_, NPY_USHORT_,
|
||||
@ -111,7 +147,14 @@ private:
|
||||
return api;
|
||||
}
|
||||
};
|
||||
}
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
#define PyArray_GET_(ptr, attr) \
|
||||
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
|
||||
#define PyArrayDescr_GET_(ptr, attr) \
|
||||
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
|
||||
#define PyArray_CHKFLAGS_(ptr, flag) \
|
||||
(flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag))
|
||||
|
||||
class dtype : public object {
|
||||
public:
|
||||
@ -150,15 +193,15 @@ public:
|
||||
}
|
||||
|
||||
size_t itemsize() const {
|
||||
return attr("itemsize").cast<size_t>();
|
||||
return (size_t) PyArrayDescr_GET_(m_ptr, elsize);
|
||||
}
|
||||
|
||||
bool has_fields() const {
|
||||
return attr("fields").cast<object>().ptr() != Py_None;
|
||||
return PyArrayDescr_GET_(m_ptr, names) != nullptr;
|
||||
}
|
||||
|
||||
std::string kind() const {
|
||||
return (std::string) attr("kind").cast<pybind11::str>();
|
||||
char kind() const {
|
||||
return PyArrayDescr_GET_(m_ptr, kind);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -171,20 +214,20 @@ private:
|
||||
dtype strip_padding() {
|
||||
// Recursively strip all void fields with empty names that are generated for
|
||||
// padding fields (as of NumPy v1.11).
|
||||
auto fields = attr("fields").cast<object>();
|
||||
if (fields.ptr() == Py_None)
|
||||
if (!has_fields())
|
||||
return *this;
|
||||
|
||||
struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
|
||||
std::vector<field_descr> field_descriptors;
|
||||
|
||||
auto fields = attr("fields").cast<object>();
|
||||
auto items = fields.attr("items").cast<object>();
|
||||
for (auto field : items()) {
|
||||
auto spec = object(field, true).cast<tuple>();
|
||||
auto name = spec[0].cast<pybind11::str>();
|
||||
auto format = spec[1].cast<tuple>()[0].cast<dtype>();
|
||||
auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
|
||||
if (!len(name) && format.kind() == "V")
|
||||
if (!len(name) && format.kind() == 'V')
|
||||
continue;
|
||||
field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(), offset});
|
||||
}
|
||||
@ -215,7 +258,7 @@ public:
|
||||
};
|
||||
|
||||
array(const pybind11::dtype& dt, const std::vector<size_t>& shape,
|
||||
const std::vector<size_t>& strides, void *ptr = nullptr) {
|
||||
const std::vector<size_t>& strides, const void *ptr = nullptr) {
|
||||
auto& api = detail::npy_api::get();
|
||||
auto ndim = shape.size();
|
||||
if (shape.size() != strides.size())
|
||||
@ -223,7 +266,7 @@ public:
|
||||
auto descr = dt;
|
||||
object tmp(api.PyArray_NewFromDescr_(
|
||||
api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(),
|
||||
(Py_intptr_t *) strides.data(), ptr, 0, nullptr), false);
|
||||
(Py_intptr_t *) strides.data(), const_cast<void *>(ptr), 0, nullptr), false);
|
||||
if (!tmp)
|
||||
pybind11_fail("NumPy: unable to create array!");
|
||||
if (ptr)
|
||||
@ -231,31 +274,139 @@ public:
|
||||
m_ptr = tmp.release().ptr();
|
||||
}
|
||||
|
||||
array(const pybind11::dtype& dt, const std::vector<size_t>& shape, void *ptr = nullptr)
|
||||
array(const pybind11::dtype& dt, const std::vector<size_t>& shape, const void *ptr = nullptr)
|
||||
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr) { }
|
||||
|
||||
array(const pybind11::dtype& dt, size_t count, void *ptr = nullptr)
|
||||
array(const pybind11::dtype& dt, size_t count, const void *ptr = nullptr)
|
||||
: array(dt, std::vector<size_t> { count }, ptr) { }
|
||||
|
||||
template<typename T> array(const std::vector<size_t>& shape,
|
||||
const std::vector<size_t>& strides, T* ptr)
|
||||
const std::vector<size_t>& strides, const T* ptr)
|
||||
: array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr) { }
|
||||
|
||||
template<typename T> array(const std::vector<size_t>& shape, T* ptr)
|
||||
template<typename T> array(const std::vector<size_t>& shape, const T* ptr)
|
||||
: array(shape, default_strides(shape, sizeof(T)), ptr) { }
|
||||
|
||||
template<typename T> array(size_t size, T* ptr)
|
||||
: array(std::vector<size_t> { size }, ptr) { }
|
||||
template<typename T> array(size_t count, const T* ptr)
|
||||
: array(std::vector<size_t> { count }, ptr) { }
|
||||
|
||||
array(const buffer_info &info)
|
||||
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
|
||||
|
||||
pybind11::dtype dtype() {
|
||||
return attr("dtype").cast<pybind11::dtype>();
|
||||
/// Array descriptor (dtype)
|
||||
pybind11::dtype dtype() const {
|
||||
return object(PyArray_GET_(m_ptr, descr), true);
|
||||
}
|
||||
|
||||
/// Total number of elements
|
||||
size_t size() const {
|
||||
return std::accumulate(shape(), shape() + ndim(), (size_t) 1, std::multiplies<size_t>());
|
||||
}
|
||||
|
||||
/// Byte size of a single element
|
||||
size_t itemsize() const {
|
||||
return (size_t) PyArrayDescr_GET_(PyArray_GET_(m_ptr, descr), elsize);
|
||||
}
|
||||
|
||||
/// Total number of bytes
|
||||
size_t nbytes() const {
|
||||
return size() * itemsize();
|
||||
}
|
||||
|
||||
/// Number of dimensions
|
||||
size_t ndim() const {
|
||||
return (size_t) PyArray_GET_(m_ptr, nd);
|
||||
}
|
||||
|
||||
/// Dimensions of the array
|
||||
const size_t* shape() const {
|
||||
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, dimensions));
|
||||
}
|
||||
|
||||
/// Dimension along a given axis
|
||||
size_t shape(size_t dim) const {
|
||||
if (dim >= ndim())
|
||||
fail_dim_check(dim, "invalid axis");
|
||||
return shape()[dim];
|
||||
}
|
||||
|
||||
/// Strides of the array
|
||||
const size_t* strides() const {
|
||||
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, strides));
|
||||
}
|
||||
|
||||
/// Stride along a given axis
|
||||
size_t strides(size_t dim) const {
|
||||
if (dim >= ndim())
|
||||
fail_dim_check(dim, "invalid axis");
|
||||
return strides()[dim];
|
||||
}
|
||||
|
||||
/// If set, the array is writeable (otherwise the buffer is read-only)
|
||||
bool writeable() const {
|
||||
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
|
||||
}
|
||||
|
||||
/// If set, the array owns the data (will be freed when the array is deleted)
|
||||
bool owndata() const {
|
||||
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
|
||||
}
|
||||
|
||||
/// Pointer to the contained data. If index is not provided, points to the
|
||||
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
|
||||
template<typename... Ix> const void* data(Ix&&... index) const {
|
||||
return static_cast<const void *>(PyArray_GET_(m_ptr, data) + offset_at(index...));
|
||||
}
|
||||
|
||||
/// Mutable pointer to the contained data. If index is not provided, points to the
|
||||
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
|
||||
/// May throw if the array is not writeable.
|
||||
template<typename... Ix> void* mutable_data(Ix&&... index) {
|
||||
check_writeable();
|
||||
return static_cast<void *>(PyArray_GET_(m_ptr, data) + offset_at(index...));
|
||||
}
|
||||
|
||||
/// Byte offset from beginning of the array to a given index (full or partial).
|
||||
/// May throw if the index would lead to out of bounds access.
|
||||
template<typename... Ix> size_t offset_at(Ix&&... index) const {
|
||||
if (sizeof...(index) > ndim())
|
||||
fail_dim_check(sizeof...(index), "too many indices for an array");
|
||||
return get_byte_offset(index...);
|
||||
}
|
||||
|
||||
size_t offset_at() const { return 0; }
|
||||
|
||||
/// Item count from beginning of the array to a given index (full or partial).
|
||||
/// May throw if the index would lead to out of bounds access.
|
||||
template<typename... Ix> size_t index_at(Ix&&... index) const {
|
||||
return offset_at(index...) / itemsize();
|
||||
}
|
||||
|
||||
protected:
|
||||
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
|
||||
template<typename, typename> friend struct detail::npy_format_descriptor;
|
||||
|
||||
void fail_dim_check(size_t dim, const std::string& msg) const {
|
||||
throw index_error(msg + ": " + std::to_string(dim) +
|
||||
" (ndim = " + std::to_string(ndim()) + ")");
|
||||
}
|
||||
|
||||
template<typename... Ix> size_t get_byte_offset(Ix&&... index) const {
|
||||
const size_t idx[] = { (size_t) index... };
|
||||
if (!std::equal(idx + 0, idx + sizeof...(index), shape(), std::less<size_t>{})) {
|
||||
auto mismatch = std::mismatch(idx + 0, idx + sizeof...(index), shape(), std::less<size_t>{});
|
||||
throw index_error(std::string("index ") + std::to_string(*mismatch.first) +
|
||||
" is out of bounds for axis " + std::to_string(mismatch.first - idx) +
|
||||
" with size " + std::to_string(*mismatch.second));
|
||||
}
|
||||
return std::inner_product(idx + 0, idx + sizeof...(index), strides(), (size_t) 0);
|
||||
}
|
||||
|
||||
size_t get_byte_offset() const { return 0; }
|
||||
|
||||
void check_writeable() const {
|
||||
if (!writeable())
|
||||
throw std::runtime_error("array is not writeable");
|
||||
}
|
||||
|
||||
static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
|
||||
auto ndim = shape.size();
|
||||
@ -278,14 +429,46 @@ public:
|
||||
|
||||
array_t(const buffer_info& info) : array(info) { }
|
||||
|
||||
array_t(const std::vector<size_t>& shape, const std::vector<size_t>& strides, T* ptr = nullptr)
|
||||
array_t(const std::vector<size_t>& shape, const std::vector<size_t>& strides, const T* ptr = nullptr)
|
||||
: array(shape, strides, ptr) { }
|
||||
|
||||
array_t(const std::vector<size_t>& shape, T* ptr = nullptr)
|
||||
array_t(const std::vector<size_t>& shape, const T* ptr = nullptr)
|
||||
: array(shape, ptr) { }
|
||||
|
||||
array_t(size_t size, T* ptr = nullptr)
|
||||
: array(size, ptr) { }
|
||||
array_t(size_t count, const T* ptr = nullptr)
|
||||
: array(count, ptr) { }
|
||||
|
||||
constexpr size_t itemsize() const {
|
||||
return sizeof(T);
|
||||
}
|
||||
|
||||
template<typename... Ix> size_t index_at(Ix&... index) const {
|
||||
return offset_at(index...) / itemsize();
|
||||
}
|
||||
|
||||
template<typename... Ix> const T* data(Ix&&... index) const {
|
||||
return static_cast<const T*>(array::data(index...));
|
||||
}
|
||||
|
||||
template<typename... Ix> T* mutable_data(Ix&&... index) {
|
||||
return static_cast<T*>(array::mutable_data(index...));
|
||||
}
|
||||
|
||||
// Reference to element at a given index
|
||||
template<typename... Ix> const T& at(Ix&&... index) const {
|
||||
if (sizeof...(index) != ndim())
|
||||
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
||||
// not using offset_at() / index_at() here so as to avoid another dimension check
|
||||
return *(static_cast<const T*>(array::data()) + get_byte_offset(index...) / itemsize());
|
||||
}
|
||||
|
||||
// Mutable reference to element at a given index
|
||||
template<typename... Ix> T& mutable_at(Ix&&... index) {
|
||||
if (sizeof...(index) != ndim())
|
||||
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
||||
// not using offset_at() / index_at() here so as to avoid another dimension check
|
||||
return *(static_cast<T*>(array::mutable_data()) + get_byte_offset(index...) / itemsize());
|
||||
}
|
||||
|
||||
static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
|
||||
|
||||
@ -678,16 +861,13 @@ struct vectorize_helper {
|
||||
if (size == 1)
|
||||
return cast(f(*((Args *) buffers[Index].ptr)...));
|
||||
|
||||
array result(buffer_info(nullptr, sizeof(Return),
|
||||
format_descriptor<Return>::format(),
|
||||
ndim, shape, strides));
|
||||
|
||||
buffer_info buf = result.request();
|
||||
Return *output = (Return *) buf.ptr;
|
||||
array_t<Return> result(shape, strides);
|
||||
auto buf = result.request();
|
||||
auto output = (Return *) buf.ptr;
|
||||
|
||||
if (trivial_broadcast) {
|
||||
/* Call the function */
|
||||
for (size_t i=0; i<size; ++i) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
output[i] = f((buffers[Index].size == 1
|
||||
? *((Args *) buffers[Index].ptr)
|
||||
: ((Args *) buffers[Index].ptr)[i])...);
|
||||
|
@ -19,6 +19,7 @@ set(PYBIND11_TEST_FILES
|
||||
test_kwargs_and_defaults.cpp
|
||||
test_methods_and_attributes.cpp
|
||||
test_modules.cpp
|
||||
test_numpy_array.cpp
|
||||
test_numpy_dtypes.cpp
|
||||
test_numpy_vectorize.cpp
|
||||
test_opaque_types.cpp
|
||||
|
94
tests/test_numpy_array.cpp
Normal file
94
tests/test_numpy_array.cpp
Normal file
@ -0,0 +1,94 @@
|
||||
/*
|
||||
tests/test_numpy_array.cpp -- test core array functionality
|
||||
|
||||
Copyright (c) 2016 Ivan Smirnov <i.s.smirnov@gmail.com>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#include "pybind11_tests.h"
|
||||
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
using arr = py::array;
|
||||
using arr_t = py::array_t<uint16_t, 0>;
|
||||
|
||||
template<typename... Ix> arr data(const arr& a, Ix&&... index) {
|
||||
return arr(a.nbytes() - a.offset_at(index...), (const uint8_t *) a.data(index...));
|
||||
}
|
||||
|
||||
template<typename... Ix> arr data_t(const arr_t& a, Ix&&... index) {
|
||||
return arr(a.size() - a.index_at(index...), a.data(index...));
|
||||
}
|
||||
|
||||
arr& mutate_data(arr& a) {
|
||||
auto ptr = (uint8_t *) a.mutable_data();
|
||||
for (size_t i = 0; i < a.nbytes(); i++)
|
||||
ptr[i] = (uint8_t) (ptr[i] * 2);
|
||||
return a;
|
||||
}
|
||||
|
||||
arr_t& mutate_data_t(arr_t& a) {
|
||||
auto ptr = a.mutable_data();
|
||||
for (size_t i = 0; i < a.size(); i++)
|
||||
ptr[i]++;
|
||||
return a;
|
||||
}
|
||||
|
||||
template<typename... Ix> arr& mutate_data(arr& a, Ix&&... index) {
|
||||
auto ptr = (uint8_t *) a.mutable_data(index...);
|
||||
for (size_t i = 0; i < a.nbytes() - a.offset_at(index...); i++)
|
||||
ptr[i] = (uint8_t) (ptr[i] * 2);
|
||||
return a;
|
||||
}
|
||||
|
||||
template<typename... Ix> arr_t& mutate_data_t(arr_t& a, Ix&&... index) {
|
||||
auto ptr = a.mutable_data(index...);
|
||||
for (size_t i = 0; i < a.size() - a.index_at(index...); i++)
|
||||
ptr[i]++;
|
||||
return a;
|
||||
}
|
||||
|
||||
template<typename... Ix> size_t index_at(const arr& a, Ix&&... idx) { return a.index_at(idx...); }
|
||||
template<typename... Ix> size_t index_at_t(const arr_t& a, Ix&&... idx) { return a.index_at(idx...); }
|
||||
template<typename... Ix> size_t offset_at(const arr& a, Ix&&... idx) { return a.offset_at(idx...); }
|
||||
template<typename... Ix> size_t offset_at_t(const arr_t& a, Ix&&... idx) { return a.offset_at(idx...); }
|
||||
template<typename... Ix> size_t at_t(const arr_t& a, Ix&&... idx) { return a.at(idx...); }
|
||||
template<typename... Ix> arr_t& mutate_at_t(arr_t& a, Ix&&... idx) { a.mutable_at(idx...)++; return a; }
|
||||
|
||||
#define def_index_fn(name, type) \
|
||||
sm.def(#name, [](type a) { return name(a); }); \
|
||||
sm.def(#name, [](type a, int i) { return name(a, i); }); \
|
||||
sm.def(#name, [](type a, int i, int j) { return name(a, i, j); }); \
|
||||
sm.def(#name, [](type a, int i, int j, int k) { return name(a, i, j, k); });
|
||||
|
||||
test_initializer numpy_array([](py::module &m) {
|
||||
auto sm = m.def_submodule("array");
|
||||
|
||||
sm.def("ndim", [](const arr& a) { return a.ndim(); });
|
||||
sm.def("shape", [](const arr& a) { return arr(a.ndim(), a.shape()); });
|
||||
sm.def("shape", [](const arr& a, size_t dim) { return a.shape(dim); });
|
||||
sm.def("strides", [](const arr& a) { return arr(a.ndim(), a.strides()); });
|
||||
sm.def("strides", [](const arr& a, size_t dim) { return a.strides(dim); });
|
||||
sm.def("writeable", [](const arr& a) { return a.writeable(); });
|
||||
sm.def("size", [](const arr& a) { return a.size(); });
|
||||
sm.def("itemsize", [](const arr& a) { return a.itemsize(); });
|
||||
sm.def("nbytes", [](const arr& a) { return a.nbytes(); });
|
||||
sm.def("owndata", [](const arr& a) { return a.owndata(); });
|
||||
|
||||
def_index_fn(data, const arr&);
|
||||
def_index_fn(data_t, const arr_t&);
|
||||
def_index_fn(index_at, const arr&);
|
||||
def_index_fn(index_at_t, const arr_t&);
|
||||
def_index_fn(offset_at, const arr&);
|
||||
def_index_fn(offset_at_t, const arr_t&);
|
||||
def_index_fn(mutate_data, arr&);
|
||||
def_index_fn(mutate_data_t, arr_t&);
|
||||
def_index_fn(at_t, const arr_t&);
|
||||
def_index_fn(mutate_at_t, arr_t&);
|
||||
});
|
150
tests/test_numpy_array.py
Normal file
150
tests/test_numpy_array.py
Normal file
@ -0,0 +1,150 @@
|
||||
import pytest
|
||||
|
||||
with pytest.suppress(ImportError):
|
||||
import numpy as np
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def arr():
|
||||
return np.array([[1, 2, 3], [4, 5, 6]], '<u2')
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_array_attributes():
|
||||
from pybind11_tests.array import (
|
||||
ndim, shape, strides, writeable, size, itemsize, nbytes, owndata
|
||||
)
|
||||
|
||||
a = np.array(0, 'f8')
|
||||
assert ndim(a) == 0
|
||||
assert all(shape(a) == [])
|
||||
assert all(strides(a) == [])
|
||||
with pytest.raises(IndexError) as excinfo:
|
||||
shape(a, 0)
|
||||
assert str(excinfo.value) == 'invalid axis: 0 (ndim = 0)'
|
||||
with pytest.raises(IndexError) as excinfo:
|
||||
strides(a, 0)
|
||||
assert str(excinfo.value) == 'invalid axis: 0 (ndim = 0)'
|
||||
assert writeable(a)
|
||||
assert size(a) == 1
|
||||
assert itemsize(a) == 8
|
||||
assert nbytes(a) == 8
|
||||
assert owndata(a)
|
||||
|
||||
a = np.array([[1, 2, 3], [4, 5, 6]], 'u2').view()
|
||||
a.flags.writeable = False
|
||||
assert ndim(a) == 2
|
||||
assert all(shape(a) == [2, 3])
|
||||
assert shape(a, 0) == 2
|
||||
assert shape(a, 1) == 3
|
||||
assert all(strides(a) == [6, 2])
|
||||
assert strides(a, 0) == 6
|
||||
assert strides(a, 1) == 2
|
||||
with pytest.raises(IndexError) as excinfo:
|
||||
shape(a, 2)
|
||||
assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)'
|
||||
with pytest.raises(IndexError) as excinfo:
|
||||
strides(a, 2)
|
||||
assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)'
|
||||
assert not writeable(a)
|
||||
assert size(a) == 6
|
||||
assert itemsize(a) == 2
|
||||
assert nbytes(a) == 12
|
||||
assert not owndata(a)
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
@pytest.mark.parametrize('args, ret', [([], 0), ([0], 0), ([1], 3), ([0, 1], 1), ([1, 2], 5)])
|
||||
def test_index_offset(arr, args, ret):
|
||||
from pybind11_tests.array import index_at, index_at_t, offset_at, offset_at_t
|
||||
assert index_at(arr, *args) == ret
|
||||
assert index_at_t(arr, *args) == ret
|
||||
assert offset_at(arr, *args) == ret * arr.dtype.itemsize
|
||||
assert offset_at_t(arr, *args) == ret * arr.dtype.itemsize
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_dim_check_fail(arr):
|
||||
from pybind11_tests.array import (index_at, index_at_t, offset_at, offset_at_t, data, data_t,
|
||||
mutate_data, mutate_data_t)
|
||||
for func in (index_at, index_at_t, offset_at, offset_at_t, data, data_t,
|
||||
mutate_data, mutate_data_t):
|
||||
with pytest.raises(IndexError) as excinfo:
|
||||
func(arr, 1, 2, 3)
|
||||
assert str(excinfo.value) == 'too many indices for an array: 3 (ndim = 2)'
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
@pytest.mark.parametrize('args, ret',
|
||||
[([], [1, 2, 3, 4, 5, 6]),
|
||||
([1], [4, 5, 6]),
|
||||
([0, 1], [2, 3, 4, 5, 6]),
|
||||
([1, 2], [6])])
|
||||
def test_data(arr, args, ret):
|
||||
from pybind11_tests.array import data, data_t
|
||||
assert all(data_t(arr, *args) == ret)
|
||||
assert all(data(arr, *args)[::2] == ret)
|
||||
assert all(data(arr, *args)[1::2] == 0)
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_mutate_readonly(arr):
|
||||
from pybind11_tests.array import mutate_data, mutate_data_t, mutate_at_t
|
||||
arr.flags.writeable = False
|
||||
for func, args in (mutate_data, ()), (mutate_data_t, ()), (mutate_at_t, (0, 0)):
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
func(arr, *args)
|
||||
assert str(excinfo.value) == 'array is not writeable'
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
@pytest.mark.parametrize('dim', [0, 1, 3])
|
||||
def test_at_fail(arr, dim):
|
||||
from pybind11_tests.array import at_t, mutate_at_t
|
||||
for func in at_t, mutate_at_t:
|
||||
with pytest.raises(IndexError) as excinfo:
|
||||
func(arr, *([0] * dim))
|
||||
assert str(excinfo.value) == 'index dimension mismatch: {} (ndim = 2)'.format(dim)
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_at(arr):
|
||||
from pybind11_tests.array import at_t, mutate_at_t
|
||||
|
||||
assert at_t(arr, 0, 2) == 3
|
||||
assert at_t(arr, 1, 0) == 4
|
||||
|
||||
assert all(mutate_at_t(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6])
|
||||
assert all(mutate_at_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_mutate_data(arr):
|
||||
from pybind11_tests.array import mutate_data, mutate_data_t
|
||||
|
||||
assert all(mutate_data(arr).ravel() == [2, 4, 6, 8, 10, 12])
|
||||
assert all(mutate_data(arr).ravel() == [4, 8, 12, 16, 20, 24])
|
||||
assert all(mutate_data(arr, 1).ravel() == [4, 8, 12, 32, 40, 48])
|
||||
assert all(mutate_data(arr, 0, 1).ravel() == [4, 16, 24, 64, 80, 96])
|
||||
assert all(mutate_data(arr, 1, 2).ravel() == [4, 16, 24, 64, 80, 192])
|
||||
|
||||
assert all(mutate_data_t(arr).ravel() == [5, 17, 25, 65, 81, 193])
|
||||
assert all(mutate_data_t(arr).ravel() == [6, 18, 26, 66, 82, 194])
|
||||
assert all(mutate_data_t(arr, 1).ravel() == [6, 18, 26, 67, 83, 195])
|
||||
assert all(mutate_data_t(arr, 0, 1).ravel() == [6, 19, 27, 68, 84, 196])
|
||||
assert all(mutate_data_t(arr, 1, 2).ravel() == [6, 19, 27, 68, 84, 197])
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_bounds_check(arr):
|
||||
from pybind11_tests.array import (index_at, index_at_t, data, data_t,
|
||||
mutate_data, mutate_data_t, at_t, mutate_at_t)
|
||||
funcs = (index_at, index_at_t, data, data_t,
|
||||
mutate_data, mutate_data_t, at_t, mutate_at_t)
|
||||
for func in funcs:
|
||||
with pytest.raises(IndexError) as excinfo:
|
||||
index_at(arr, 2, 0)
|
||||
assert str(excinfo.value) == 'index 2 is out of bounds for axis 0 with size 2'
|
||||
with pytest.raises(IndexError) as excinfo:
|
||||
index_at(arr, 0, 4)
|
||||
assert str(excinfo.value) == 'index 4 is out of bounds for axis 1 with size 3'
|
Loading…
Reference in New Issue
Block a user