mirror of
https://github.com/pybind/pybind11.git
synced 2025-01-19 01:15:52 +00:00
Expose some dtype/array attributes via NumPy C API
This commit is contained in:
parent
720136bfa7
commit
91b3d681ad
@ -83,12 +83,11 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
|
||||
static constexpr bool isVector = Type::IsVectorAtCompileTime;
|
||||
|
||||
bool load(handle src, bool) {
|
||||
array_t<Scalar> buffer(src, true);
|
||||
if (!buffer.check())
|
||||
return false;
|
||||
array_t<Scalar> buf(src, true);
|
||||
if (!buf.check())
|
||||
return false;
|
||||
|
||||
auto info = buffer.request();
|
||||
if (info.ndim == 1) {
|
||||
if (buf.ndim() == 1) {
|
||||
typedef Eigen::InnerStride<> Strides;
|
||||
if (!isVector &&
|
||||
!(Type::RowsAtCompileTime == Eigen::Dynamic &&
|
||||
@ -96,31 +95,32 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
|
||||
return false;
|
||||
|
||||
if (Type::SizeAtCompileTime != Eigen::Dynamic &&
|
||||
info.shape[0] != (size_t) Type::SizeAtCompileTime)
|
||||
buf.shape(0) != (size_t) Type::SizeAtCompileTime)
|
||||
return false;
|
||||
|
||||
auto strides = Strides(info.strides[0] / sizeof(Scalar));
|
||||
|
||||
Strides::Index n_elts = (Strides::Index) info.shape[0];
|
||||
Strides::Index n_elts = (Strides::Index) buf.shape(0);
|
||||
Strides::Index unity = 1;
|
||||
|
||||
value = Eigen::Map<Type, 0, Strides>(
|
||||
(Scalar *) info.ptr, rowMajor ? unity : n_elts, rowMajor ? n_elts : unity, strides);
|
||||
} else if (info.ndim == 2) {
|
||||
buf.mutable_data(),
|
||||
rowMajor ? unity : n_elts,
|
||||
rowMajor ? n_elts : unity,
|
||||
Strides(buf.strides(0) / sizeof(Scalar))
|
||||
);
|
||||
} else if (buf.ndim() == 2) {
|
||||
typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides;
|
||||
|
||||
if ((Type::RowsAtCompileTime != Eigen::Dynamic && info.shape[0] != (size_t) Type::RowsAtCompileTime) ||
|
||||
(Type::ColsAtCompileTime != Eigen::Dynamic && info.shape[1] != (size_t) Type::ColsAtCompileTime))
|
||||
if ((Type::RowsAtCompileTime != Eigen::Dynamic && buf.shape(0) != (size_t) Type::RowsAtCompileTime) ||
|
||||
(Type::ColsAtCompileTime != Eigen::Dynamic && buf.shape(1) != (size_t) Type::ColsAtCompileTime))
|
||||
return false;
|
||||
|
||||
auto strides = Strides(
|
||||
info.strides[rowMajor ? 0 : 1] / sizeof(Scalar),
|
||||
info.strides[rowMajor ? 1 : 0] / sizeof(Scalar));
|
||||
|
||||
value = Eigen::Map<Type, 0, Strides>(
|
||||
(Scalar *) info.ptr,
|
||||
typename Strides::Index(info.shape[0]),
|
||||
typename Strides::Index(info.shape[1]), strides);
|
||||
buf.mutable_data(),
|
||||
typename Strides::Index(buf.shape(0)),
|
||||
typename Strides::Index(buf.shape(1)),
|
||||
Strides(buf.strides(rowMajor ? 0 : 1) / sizeof(Scalar),
|
||||
buf.strides(rowMajor ? 1 : 0) / sizeof(Scalar))
|
||||
);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
@ -222,28 +222,18 @@ struct type_caster<Type, typename std::enable_if<is_eigen_sparse<Type>::value>::
|
||||
}
|
||||
}
|
||||
|
||||
auto valuesArray = array_t<Scalar>((object) obj.attr("data"));
|
||||
auto innerIndicesArray = array_t<StorageIndex>((object) obj.attr("indices"));
|
||||
auto outerIndicesArray = array_t<StorageIndex>((object) obj.attr("indptr"));
|
||||
auto values = array_t<Scalar>((object) obj.attr("data"));
|
||||
auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
|
||||
auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
|
||||
auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
|
||||
auto nnz = obj.attr("nnz").cast<Index>();
|
||||
|
||||
if (!valuesArray.check() || !innerIndicesArray.check() ||
|
||||
!outerIndicesArray.check())
|
||||
if (!values.check() || !innerIndices.check() || !outerIndices.check())
|
||||
return false;
|
||||
|
||||
auto outerIndices = outerIndicesArray.request();
|
||||
auto innerIndices = innerIndicesArray.request();
|
||||
auto values = valuesArray.request();
|
||||
|
||||
value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
|
||||
shape[0].cast<Index>(),
|
||||
shape[1].cast<Index>(),
|
||||
nnz,
|
||||
static_cast<StorageIndex *>(outerIndices.ptr),
|
||||
static_cast<StorageIndex *>(innerIndices.ptr),
|
||||
static_cast<Scalar *>(values.ptr)
|
||||
);
|
||||
shape[0].cast<Index>(), shape[1].cast<Index>(), nnz,
|
||||
outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -19,6 +19,7 @@
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <initializer_list>
|
||||
#include <functional>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(push)
|
||||
@ -30,12 +31,41 @@ namespace detail {
|
||||
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
|
||||
template <typename type> struct is_pod_struct;
|
||||
|
||||
struct PyArrayDescr_Proxy {
|
||||
PyObject_HEAD
|
||||
PyObject *typeobj;
|
||||
char kind;
|
||||
char type;
|
||||
char byteorder;
|
||||
char flags;
|
||||
int type_num;
|
||||
int elsize;
|
||||
int alignment;
|
||||
char *subarray;
|
||||
PyObject *fields;
|
||||
PyObject *names;
|
||||
};
|
||||
|
||||
struct PyArray_Proxy {
|
||||
PyObject_HEAD
|
||||
char *data;
|
||||
int nd;
|
||||
ssize_t *dimensions;
|
||||
ssize_t *strides;
|
||||
PyObject *base;
|
||||
PyObject *descr;
|
||||
int flags;
|
||||
};
|
||||
|
||||
struct npy_api {
|
||||
enum constants {
|
||||
NPY_C_CONTIGUOUS_ = 0x0001,
|
||||
NPY_F_CONTIGUOUS_ = 0x0002,
|
||||
NPY_ARRAY_OWNDATA_ = 0x0004,
|
||||
NPY_ARRAY_FORCECAST_ = 0x0010,
|
||||
NPY_ENSURE_ARRAY_ = 0x0040,
|
||||
NPY_ARRAY_ALIGNED_ = 0x0100,
|
||||
NPY_ARRAY_WRITEABLE_ = 0x0400,
|
||||
NPY_BOOL_ = 0,
|
||||
NPY_BYTE_, NPY_UBYTE_,
|
||||
NPY_SHORT_, NPY_USHORT_,
|
||||
@ -113,6 +143,11 @@ private:
|
||||
};
|
||||
}
|
||||
|
||||
#define PyArray_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
|
||||
#define PyArrayDescr_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
|
||||
#define PyArray_CHKFLAGS_(ptr, flag) \
|
||||
(flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag))
|
||||
|
||||
class dtype : public object {
|
||||
public:
|
||||
PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
|
||||
@ -150,15 +185,15 @@ public:
|
||||
}
|
||||
|
||||
size_t itemsize() const {
|
||||
return attr("itemsize").cast<size_t>();
|
||||
return (size_t) PyArrayDescr_GET_(m_ptr, elsize);
|
||||
}
|
||||
|
||||
bool has_fields() const {
|
||||
return attr("fields").cast<object>().ptr() != Py_None;
|
||||
return PyArrayDescr_GET_(m_ptr, names) != nullptr;
|
||||
}
|
||||
|
||||
std::string kind() const {
|
||||
return (std::string) attr("kind").cast<pybind11::str>();
|
||||
char kind() const {
|
||||
return PyArrayDescr_GET_(m_ptr, kind);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -171,20 +206,20 @@ private:
|
||||
dtype strip_padding() {
|
||||
// Recursively strip all void fields with empty names that are generated for
|
||||
// padding fields (as of NumPy v1.11).
|
||||
auto fields = attr("fields").cast<object>();
|
||||
if (fields.ptr() == Py_None)
|
||||
if (!has_fields())
|
||||
return *this;
|
||||
|
||||
struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
|
||||
std::vector<field_descr> field_descriptors;
|
||||
|
||||
auto fields = attr("fields").cast<object>();
|
||||
auto items = fields.attr("items").cast<object>();
|
||||
for (auto field : items()) {
|
||||
auto spec = object(field, true).cast<tuple>();
|
||||
auto name = spec[0].cast<pybind11::str>();
|
||||
auto format = spec[1].cast<tuple>()[0].cast<dtype>();
|
||||
auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
|
||||
if (!len(name) && format.kind() == "V")
|
||||
if (!len(name) && format.kind() == 'V')
|
||||
continue;
|
||||
field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(), offset});
|
||||
}
|
||||
@ -244,14 +279,83 @@ public:
|
||||
template<typename T> array(const std::vector<size_t>& shape, T* ptr)
|
||||
: array(shape, default_strides(shape, sizeof(T)), ptr) { }
|
||||
|
||||
template<typename T> array(size_t size, T* ptr)
|
||||
: array(std::vector<size_t> { size }, ptr) { }
|
||||
template<typename T> array(size_t count, T* ptr)
|
||||
: array(std::vector<size_t> { count }, ptr) { }
|
||||
|
||||
array(const buffer_info &info)
|
||||
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
|
||||
|
||||
pybind11::dtype dtype() {
|
||||
return attr("dtype").cast<pybind11::dtype>();
|
||||
/// Array descriptor (dtype)
|
||||
pybind11::dtype dtype() const {
|
||||
return object(PyArray_GET_(m_ptr, descr), true);
|
||||
}
|
||||
|
||||
/// Total number of elements
|
||||
size_t size() const {
|
||||
return std::accumulate(shape(), shape() + ndim(), (size_t) 1, std::multiplies<size_t>());
|
||||
}
|
||||
|
||||
/// Byte size of a single element
|
||||
size_t itemsize() const {
|
||||
return (size_t) PyArrayDescr_GET_(PyArray_GET_(m_ptr, descr), elsize);
|
||||
}
|
||||
|
||||
/// Total number of bytes
|
||||
size_t nbytes() const {
|
||||
return size() * itemsize();
|
||||
}
|
||||
|
||||
/// Number of dimensions
|
||||
size_t ndim() const {
|
||||
return (size_t) PyArray_GET_(m_ptr, nd);
|
||||
}
|
||||
|
||||
/// Dimensions of the array
|
||||
const size_t* shape() const {
|
||||
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
|
||||
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, dimensions));
|
||||
}
|
||||
|
||||
/// Dimension along a given axis
|
||||
size_t shape(size_t dim) const {
|
||||
if (dim >= ndim())
|
||||
pybind11_fail("NumPy: attempted to index shape beyond ndim");
|
||||
return shape()[dim];
|
||||
}
|
||||
|
||||
/// Strides of the array
|
||||
const size_t* strides() const {
|
||||
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
|
||||
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, strides));
|
||||
}
|
||||
|
||||
/// Stride along a given axis
|
||||
size_t strides(size_t dim) const {
|
||||
if (dim >= ndim())
|
||||
pybind11_fail("NumPy: attempted to index strides beyond ndim");
|
||||
return strides()[dim];
|
||||
}
|
||||
|
||||
/// If set, the array is writeable (otherwise the buffer is read-only)
|
||||
bool writeable() const {
|
||||
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
|
||||
}
|
||||
|
||||
/// If set, the array owns the data (will be freed when the array is deleted)
|
||||
bool owndata() const {
|
||||
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
|
||||
}
|
||||
|
||||
/// Direct pointer to contained buffer
|
||||
const void* data() const {
|
||||
return reinterpret_cast<const void *>(PyArray_GET_(m_ptr, data));
|
||||
}
|
||||
|
||||
/// Direct mutable pointer to contained buffer (checks writeable flag)
|
||||
void* mutable_data() {
|
||||
if (!writeable())
|
||||
pybind11_fail("NumPy: cannot get mutable data of a read-only array");
|
||||
return reinterpret_cast<void *>(PyArray_GET_(m_ptr, data));
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -284,8 +388,18 @@ public:
|
||||
array_t(const std::vector<size_t>& shape, T* ptr = nullptr)
|
||||
: array(shape, ptr) { }
|
||||
|
||||
array_t(size_t size, T* ptr = nullptr)
|
||||
: array(size, ptr) { }
|
||||
array_t(size_t count, T* ptr = nullptr)
|
||||
: array(count, ptr) { }
|
||||
|
||||
const T* data() const {
|
||||
return reinterpret_cast<const T *>(PyArray_GET_(m_ptr, data));
|
||||
}
|
||||
|
||||
T* mutable_data() {
|
||||
if (!writeable())
|
||||
pybind11_fail("NumPy: cannot get mutable data of a read-only array");
|
||||
return reinterpret_cast<T *>(PyArray_GET_(m_ptr, data));
|
||||
}
|
||||
|
||||
static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
|
||||
|
||||
@ -678,16 +792,13 @@ struct vectorize_helper {
|
||||
if (size == 1)
|
||||
return cast(f(*((Args *) buffers[Index].ptr)...));
|
||||
|
||||
array result(buffer_info(nullptr, sizeof(Return),
|
||||
format_descriptor<Return>::format(),
|
||||
ndim, shape, strides));
|
||||
|
||||
buffer_info buf = result.request();
|
||||
Return *output = (Return *) buf.ptr;
|
||||
array_t<Return> result(shape, strides);
|
||||
auto buf = result.request();
|
||||
auto output = (Return *) buf.ptr;
|
||||
|
||||
if (trivial_broadcast) {
|
||||
/* Call the function */
|
||||
for (size_t i=0; i<size; ++i) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
output[i] = f((buffers[Index].size == 1
|
||||
? *((Args *) buffers[Index].ptr)
|
||||
: ((Args *) buffers[Index].ptr)[i])...);
|
||||
|
@ -19,6 +19,7 @@ set(PYBIND11_TEST_FILES
|
||||
test_kwargs_and_defaults.cpp
|
||||
test_methods_and_attributes.cpp
|
||||
test_modules.cpp
|
||||
test_numpy_array.cpp
|
||||
test_numpy_dtypes.cpp
|
||||
test_numpy_vectorize.cpp
|
||||
test_opaque_types.cpp
|
||||
|
45
tests/test_numpy_array.cpp
Normal file
45
tests/test_numpy_array.cpp
Normal file
@ -0,0 +1,45 @@
|
||||
/*
|
||||
tests/test_numpy_array.cpp -- test core array functionality
|
||||
|
||||
Copyright (c) 2016 Ivan Smirnov <i.s.smirnov@gmail.com>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#include "pybind11_tests.h"
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
test_initializer numpy_array([](py::module &m) {
|
||||
m.def("get_arr_ndim", [](const py::array& arr) {
|
||||
return arr.ndim();
|
||||
});
|
||||
m.def("get_arr_shape", [](const py::array& arr) {
|
||||
return std::vector<size_t>(arr.shape(), arr.shape() + arr.ndim());
|
||||
});
|
||||
m.def("get_arr_shape", [](const py::array& arr, size_t dim) {
|
||||
return arr.shape(dim);
|
||||
});
|
||||
m.def("get_arr_strides", [](const py::array& arr) {
|
||||
return std::vector<size_t>(arr.strides(), arr.strides() + arr.ndim());
|
||||
});
|
||||
m.def("get_arr_strides", [](const py::array& arr, size_t dim) {
|
||||
return arr.strides(dim);
|
||||
});
|
||||
m.def("get_arr_writeable", [](const py::array& arr) {
|
||||
return arr.writeable();
|
||||
});
|
||||
m.def("get_arr_size", [](const py::array& arr) {
|
||||
return arr.size();
|
||||
});
|
||||
m.def("get_arr_itemsize", [](const py::array& arr) {
|
||||
return arr.itemsize();
|
||||
});
|
||||
m.def("get_arr_nbytes", [](const py::array& arr) {
|
||||
return arr.nbytes();
|
||||
});
|
||||
m.def("get_arr_owndata", [](const py::array& arr) {
|
||||
return arr.owndata();
|
||||
});
|
||||
});
|
43
tests/test_numpy_array.py
Normal file
43
tests/test_numpy_array.py
Normal file
@ -0,0 +1,43 @@
|
||||
import pytest
|
||||
|
||||
with pytest.suppress(ImportError):
|
||||
import numpy as np
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_array_attributes():
|
||||
from pybind11_tests import (get_arr_ndim, get_arr_shape, get_arr_strides, get_arr_writeable,
|
||||
get_arr_size, get_arr_itemsize, get_arr_nbytes, get_arr_owndata)
|
||||
|
||||
a = np.array(0, 'f8')
|
||||
assert get_arr_ndim(a) == 0
|
||||
assert get_arr_shape(a) == []
|
||||
assert get_arr_strides(a) == []
|
||||
with pytest.raises(RuntimeError):
|
||||
get_arr_shape(a, 1)
|
||||
with pytest.raises(RuntimeError):
|
||||
get_arr_strides(a, 0)
|
||||
assert get_arr_writeable(a)
|
||||
assert get_arr_size(a) == 1
|
||||
assert get_arr_itemsize(a) == 8
|
||||
assert get_arr_nbytes(a) == 8
|
||||
assert get_arr_owndata(a)
|
||||
|
||||
a = np.array([[1, 2, 3], [4, 5, 6]], 'u2').view()
|
||||
a.flags.writeable = False
|
||||
assert get_arr_ndim(a) == 2
|
||||
assert get_arr_shape(a) == [2, 3]
|
||||
assert get_arr_shape(a, 0) == 2
|
||||
assert get_arr_shape(a, 1) == 3
|
||||
assert get_arr_strides(a) == [6, 2]
|
||||
assert get_arr_strides(a, 0) == 6
|
||||
assert get_arr_strides(a, 1) == 2
|
||||
with pytest.raises(RuntimeError):
|
||||
get_arr_shape(a, 2)
|
||||
with pytest.raises(RuntimeError):
|
||||
get_arr_strides(a, 2)
|
||||
assert not get_arr_writeable(a)
|
||||
assert get_arr_size(a) == 6
|
||||
assert get_arr_itemsize(a) == 2
|
||||
assert get_arr_nbytes(a) == 12
|
||||
assert not get_arr_owndata(a)
|
Loading…
Reference in New Issue
Block a user