Expose some dtype/array attributes via NumPy C API

This commit is contained in:
Ivan Smirnov 2016-08-29 02:41:05 +01:00
parent 720136bfa7
commit 91b3d681ad
5 changed files with 246 additions and 56 deletions

View File

@ -83,12 +83,11 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
static constexpr bool isVector = Type::IsVectorAtCompileTime; static constexpr bool isVector = Type::IsVectorAtCompileTime;
bool load(handle src, bool) { bool load(handle src, bool) {
array_t<Scalar> buffer(src, true); array_t<Scalar> buf(src, true);
if (!buffer.check()) if (!buf.check())
return false; return false;
auto info = buffer.request(); if (buf.ndim() == 1) {
if (info.ndim == 1) {
typedef Eigen::InnerStride<> Strides; typedef Eigen::InnerStride<> Strides;
if (!isVector && if (!isVector &&
!(Type::RowsAtCompileTime == Eigen::Dynamic && !(Type::RowsAtCompileTime == Eigen::Dynamic &&
@ -96,31 +95,32 @@ struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value &&
return false; return false;
if (Type::SizeAtCompileTime != Eigen::Dynamic && if (Type::SizeAtCompileTime != Eigen::Dynamic &&
info.shape[0] != (size_t) Type::SizeAtCompileTime) buf.shape(0) != (size_t) Type::SizeAtCompileTime)
return false; return false;
auto strides = Strides(info.strides[0] / sizeof(Scalar)); Strides::Index n_elts = (Strides::Index) buf.shape(0);
Strides::Index n_elts = (Strides::Index) info.shape[0];
Strides::Index unity = 1; Strides::Index unity = 1;
value = Eigen::Map<Type, 0, Strides>( value = Eigen::Map<Type, 0, Strides>(
(Scalar *) info.ptr, rowMajor ? unity : n_elts, rowMajor ? n_elts : unity, strides); buf.mutable_data(),
} else if (info.ndim == 2) { rowMajor ? unity : n_elts,
rowMajor ? n_elts : unity,
Strides(buf.strides(0) / sizeof(Scalar))
);
} else if (buf.ndim() == 2) {
typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides; typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides;
if ((Type::RowsAtCompileTime != Eigen::Dynamic && info.shape[0] != (size_t) Type::RowsAtCompileTime) || if ((Type::RowsAtCompileTime != Eigen::Dynamic && buf.shape(0) != (size_t) Type::RowsAtCompileTime) ||
(Type::ColsAtCompileTime != Eigen::Dynamic && info.shape[1] != (size_t) Type::ColsAtCompileTime)) (Type::ColsAtCompileTime != Eigen::Dynamic && buf.shape(1) != (size_t) Type::ColsAtCompileTime))
return false; return false;
auto strides = Strides(
info.strides[rowMajor ? 0 : 1] / sizeof(Scalar),
info.strides[rowMajor ? 1 : 0] / sizeof(Scalar));
value = Eigen::Map<Type, 0, Strides>( value = Eigen::Map<Type, 0, Strides>(
(Scalar *) info.ptr, buf.mutable_data(),
typename Strides::Index(info.shape[0]), typename Strides::Index(buf.shape(0)),
typename Strides::Index(info.shape[1]), strides); typename Strides::Index(buf.shape(1)),
Strides(buf.strides(rowMajor ? 0 : 1) / sizeof(Scalar),
buf.strides(rowMajor ? 1 : 0) / sizeof(Scalar))
);
} else { } else {
return false; return false;
} }
@ -222,28 +222,18 @@ struct type_caster<Type, typename std::enable_if<is_eigen_sparse<Type>::value>::
} }
} }
auto valuesArray = array_t<Scalar>((object) obj.attr("data")); auto values = array_t<Scalar>((object) obj.attr("data"));
auto innerIndicesArray = array_t<StorageIndex>((object) obj.attr("indices")); auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
auto outerIndicesArray = array_t<StorageIndex>((object) obj.attr("indptr")); auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
auto shape = pybind11::tuple((pybind11::object) obj.attr("shape")); auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
auto nnz = obj.attr("nnz").cast<Index>(); auto nnz = obj.attr("nnz").cast<Index>();
if (!valuesArray.check() || !innerIndicesArray.check() || if (!values.check() || !innerIndices.check() || !outerIndices.check())
!outerIndicesArray.check())
return false; return false;
auto outerIndices = outerIndicesArray.request();
auto innerIndices = innerIndicesArray.request();
auto values = valuesArray.request();
value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>( value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
shape[0].cast<Index>(), shape[0].cast<Index>(), shape[1].cast<Index>(), nnz,
shape[1].cast<Index>(), outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data());
nnz,
static_cast<StorageIndex *>(outerIndices.ptr),
static_cast<StorageIndex *>(innerIndices.ptr),
static_cast<Scalar *>(values.ptr)
);
return true; return true;
} }

View File

@ -19,6 +19,7 @@
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <initializer_list> #include <initializer_list>
#include <functional>
#if defined(_MSC_VER) #if defined(_MSC_VER)
#pragma warning(push) #pragma warning(push)
@ -30,12 +31,41 @@ namespace detail {
template <typename type, typename SFINAE = void> struct npy_format_descriptor { }; template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
template <typename type> struct is_pod_struct; template <typename type> struct is_pod_struct;
struct PyArrayDescr_Proxy {
PyObject_HEAD
PyObject *typeobj;
char kind;
char type;
char byteorder;
char flags;
int type_num;
int elsize;
int alignment;
char *subarray;
PyObject *fields;
PyObject *names;
};
struct PyArray_Proxy {
PyObject_HEAD
char *data;
int nd;
ssize_t *dimensions;
ssize_t *strides;
PyObject *base;
PyObject *descr;
int flags;
};
struct npy_api { struct npy_api {
enum constants { enum constants {
NPY_C_CONTIGUOUS_ = 0x0001, NPY_C_CONTIGUOUS_ = 0x0001,
NPY_F_CONTIGUOUS_ = 0x0002, NPY_F_CONTIGUOUS_ = 0x0002,
NPY_ARRAY_OWNDATA_ = 0x0004,
NPY_ARRAY_FORCECAST_ = 0x0010, NPY_ARRAY_FORCECAST_ = 0x0010,
NPY_ENSURE_ARRAY_ = 0x0040, NPY_ENSURE_ARRAY_ = 0x0040,
NPY_ARRAY_ALIGNED_ = 0x0100,
NPY_ARRAY_WRITEABLE_ = 0x0400,
NPY_BOOL_ = 0, NPY_BOOL_ = 0,
NPY_BYTE_, NPY_UBYTE_, NPY_BYTE_, NPY_UBYTE_,
NPY_SHORT_, NPY_USHORT_, NPY_SHORT_, NPY_USHORT_,
@ -113,6 +143,11 @@ private:
}; };
} }
#define PyArray_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
#define PyArrayDescr_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
#define PyArray_CHKFLAGS_(ptr, flag) \
(flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag))
class dtype : public object { class dtype : public object {
public: public:
PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_); PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
@ -150,15 +185,15 @@ public:
} }
size_t itemsize() const { size_t itemsize() const {
return attr("itemsize").cast<size_t>(); return (size_t) PyArrayDescr_GET_(m_ptr, elsize);
} }
bool has_fields() const { bool has_fields() const {
return attr("fields").cast<object>().ptr() != Py_None; return PyArrayDescr_GET_(m_ptr, names) != nullptr;
} }
std::string kind() const { char kind() const {
return (std::string) attr("kind").cast<pybind11::str>(); return PyArrayDescr_GET_(m_ptr, kind);
} }
private: private:
@ -171,20 +206,20 @@ private:
dtype strip_padding() { dtype strip_padding() {
// Recursively strip all void fields with empty names that are generated for // Recursively strip all void fields with empty names that are generated for
// padding fields (as of NumPy v1.11). // padding fields (as of NumPy v1.11).
auto fields = attr("fields").cast<object>(); if (!has_fields())
if (fields.ptr() == Py_None)
return *this; return *this;
struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; }; struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
std::vector<field_descr> field_descriptors; std::vector<field_descr> field_descriptors;
auto fields = attr("fields").cast<object>();
auto items = fields.attr("items").cast<object>(); auto items = fields.attr("items").cast<object>();
for (auto field : items()) { for (auto field : items()) {
auto spec = object(field, true).cast<tuple>(); auto spec = object(field, true).cast<tuple>();
auto name = spec[0].cast<pybind11::str>(); auto name = spec[0].cast<pybind11::str>();
auto format = spec[1].cast<tuple>()[0].cast<dtype>(); auto format = spec[1].cast<tuple>()[0].cast<dtype>();
auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>(); auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
if (!len(name) && format.kind() == "V") if (!len(name) && format.kind() == 'V')
continue; continue;
field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(), offset}); field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(), offset});
} }
@ -244,14 +279,83 @@ public:
template<typename T> array(const std::vector<size_t>& shape, T* ptr) template<typename T> array(const std::vector<size_t>& shape, T* ptr)
: array(shape, default_strides(shape, sizeof(T)), ptr) { } : array(shape, default_strides(shape, sizeof(T)), ptr) { }
template<typename T> array(size_t size, T* ptr) template<typename T> array(size_t count, T* ptr)
: array(std::vector<size_t> { size }, ptr) { } : array(std::vector<size_t> { count }, ptr) { }
array(const buffer_info &info) array(const buffer_info &info)
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { } : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
pybind11::dtype dtype() { /// Array descriptor (dtype)
return attr("dtype").cast<pybind11::dtype>(); pybind11::dtype dtype() const {
return object(PyArray_GET_(m_ptr, descr), true);
}
/// Total number of elements
size_t size() const {
return std::accumulate(shape(), shape() + ndim(), (size_t) 1, std::multiplies<size_t>());
}
/// Byte size of a single element
size_t itemsize() const {
return (size_t) PyArrayDescr_GET_(PyArray_GET_(m_ptr, descr), elsize);
}
/// Total number of bytes
size_t nbytes() const {
return size() * itemsize();
}
/// Number of dimensions
size_t ndim() const {
return (size_t) PyArray_GET_(m_ptr, nd);
}
/// Dimensions of the array
const size_t* shape() const {
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, dimensions));
}
/// Dimension along a given axis
size_t shape(size_t dim) const {
if (dim >= ndim())
pybind11_fail("NumPy: attempted to index shape beyond ndim");
return shape()[dim];
}
/// Strides of the array
const size_t* strides() const {
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, strides));
}
/// Stride along a given axis
size_t strides(size_t dim) const {
if (dim >= ndim())
pybind11_fail("NumPy: attempted to index strides beyond ndim");
return strides()[dim];
}
/// If set, the array is writeable (otherwise the buffer is read-only)
bool writeable() const {
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
}
/// If set, the array owns the data (will be freed when the array is deleted)
bool owndata() const {
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
}
/// Direct pointer to contained buffer
const void* data() const {
return reinterpret_cast<const void *>(PyArray_GET_(m_ptr, data));
}
/// Direct mutable pointer to contained buffer (checks writeable flag)
void* mutable_data() {
if (!writeable())
pybind11_fail("NumPy: cannot get mutable data of a read-only array");
return reinterpret_cast<void *>(PyArray_GET_(m_ptr, data));
} }
protected: protected:
@ -284,8 +388,18 @@ public:
array_t(const std::vector<size_t>& shape, T* ptr = nullptr) array_t(const std::vector<size_t>& shape, T* ptr = nullptr)
: array(shape, ptr) { } : array(shape, ptr) { }
array_t(size_t size, T* ptr = nullptr) array_t(size_t count, T* ptr = nullptr)
: array(size, ptr) { } : array(count, ptr) { }
const T* data() const {
return reinterpret_cast<const T *>(PyArray_GET_(m_ptr, data));
}
T* mutable_data() {
if (!writeable())
pybind11_fail("NumPy: cannot get mutable data of a read-only array");
return reinterpret_cast<T *>(PyArray_GET_(m_ptr, data));
}
static bool is_non_null(PyObject *ptr) { return ptr != nullptr; } static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
@ -678,12 +792,9 @@ struct vectorize_helper {
if (size == 1) if (size == 1)
return cast(f(*((Args *) buffers[Index].ptr)...)); return cast(f(*((Args *) buffers[Index].ptr)...));
array result(buffer_info(nullptr, sizeof(Return), array_t<Return> result(shape, strides);
format_descriptor<Return>::format(), auto buf = result.request();
ndim, shape, strides)); auto output = (Return *) buf.ptr;
buffer_info buf = result.request();
Return *output = (Return *) buf.ptr;
if (trivial_broadcast) { if (trivial_broadcast) {
/* Call the function */ /* Call the function */

View File

@ -19,6 +19,7 @@ set(PYBIND11_TEST_FILES
test_kwargs_and_defaults.cpp test_kwargs_and_defaults.cpp
test_methods_and_attributes.cpp test_methods_and_attributes.cpp
test_modules.cpp test_modules.cpp
test_numpy_array.cpp
test_numpy_dtypes.cpp test_numpy_dtypes.cpp
test_numpy_vectorize.cpp test_numpy_vectorize.cpp
test_opaque_types.cpp test_opaque_types.cpp

View File

@ -0,0 +1,45 @@
/*
tests/test_numpy_array.cpp -- test core array functionality
Copyright (c) 2016 Ivan Smirnov <i.s.smirnov@gmail.com>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#include "pybind11_tests.h"
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
test_initializer numpy_array([](py::module &m) {
m.def("get_arr_ndim", [](const py::array& arr) {
return arr.ndim();
});
m.def("get_arr_shape", [](const py::array& arr) {
return std::vector<size_t>(arr.shape(), arr.shape() + arr.ndim());
});
m.def("get_arr_shape", [](const py::array& arr, size_t dim) {
return arr.shape(dim);
});
m.def("get_arr_strides", [](const py::array& arr) {
return std::vector<size_t>(arr.strides(), arr.strides() + arr.ndim());
});
m.def("get_arr_strides", [](const py::array& arr, size_t dim) {
return arr.strides(dim);
});
m.def("get_arr_writeable", [](const py::array& arr) {
return arr.writeable();
});
m.def("get_arr_size", [](const py::array& arr) {
return arr.size();
});
m.def("get_arr_itemsize", [](const py::array& arr) {
return arr.itemsize();
});
m.def("get_arr_nbytes", [](const py::array& arr) {
return arr.nbytes();
});
m.def("get_arr_owndata", [](const py::array& arr) {
return arr.owndata();
});
});

43
tests/test_numpy_array.py Normal file
View File

@ -0,0 +1,43 @@
import pytest
with pytest.suppress(ImportError):
import numpy as np
@pytest.requires_numpy
def test_array_attributes():
from pybind11_tests import (get_arr_ndim, get_arr_shape, get_arr_strides, get_arr_writeable,
get_arr_size, get_arr_itemsize, get_arr_nbytes, get_arr_owndata)
a = np.array(0, 'f8')
assert get_arr_ndim(a) == 0
assert get_arr_shape(a) == []
assert get_arr_strides(a) == []
with pytest.raises(RuntimeError):
get_arr_shape(a, 1)
with pytest.raises(RuntimeError):
get_arr_strides(a, 0)
assert get_arr_writeable(a)
assert get_arr_size(a) == 1
assert get_arr_itemsize(a) == 8
assert get_arr_nbytes(a) == 8
assert get_arr_owndata(a)
a = np.array([[1, 2, 3], [4, 5, 6]], 'u2').view()
a.flags.writeable = False
assert get_arr_ndim(a) == 2
assert get_arr_shape(a) == [2, 3]
assert get_arr_shape(a, 0) == 2
assert get_arr_shape(a, 1) == 3
assert get_arr_strides(a) == [6, 2]
assert get_arr_strides(a, 0) == 6
assert get_arr_strides(a, 1) == 2
with pytest.raises(RuntimeError):
get_arr_shape(a, 2)
with pytest.raises(RuntimeError):
get_arr_strides(a, 2)
assert not get_arr_writeable(a)
assert get_arr_size(a) == 6
assert get_arr_itemsize(a) == 2
assert get_arr_nbytes(a) == 12
assert not get_arr_owndata(a)