mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-22 05:05:11 +00:00
Support arrays inside PYBIND11_NUMPY_DTYPE (#832)
Resolves #800. Both C++ arrays and std::array are supported, including mixtures like std::array<int, 2>[4]. In a multi-dimensional array of char, the last dimension is used to construct a numpy string type.
This commit is contained in:
parent
78f1dcf98f
commit
8e0d832c7d
@ -198,6 +198,10 @@ expects the type followed by field names:
|
|||||||
/* now both A and B can be used as template arguments to py::array_t */
|
/* now both A and B can be used as template arguments to py::array_t */
|
||||||
}
|
}
|
||||||
|
|
||||||
|
The structure should consist of fundamental arithmetic types, previously
|
||||||
|
registered substructures, and arrays of any of the above. Both C++ arrays and
|
||||||
|
``std::array`` are supported.
|
||||||
|
|
||||||
Vectorizing functions
|
Vectorizing functions
|
||||||
=====================
|
=====================
|
||||||
|
|
||||||
|
@ -246,6 +246,46 @@ template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::tru
|
|||||||
template <typename T> struct is_complex : std::false_type { };
|
template <typename T> struct is_complex : std::false_type { };
|
||||||
template <typename T> struct is_complex<std::complex<T>> : std::true_type { };
|
template <typename T> struct is_complex<std::complex<T>> : std::true_type { };
|
||||||
|
|
||||||
|
template <typename T> struct array_info_scalar {
|
||||||
|
typedef T type;
|
||||||
|
static constexpr bool is_array = false;
|
||||||
|
static constexpr bool is_empty = false;
|
||||||
|
static PYBIND11_DESCR extents() { return _(""); }
|
||||||
|
static void append_extents(list& /* shape */) { }
|
||||||
|
};
|
||||||
|
// Computes underlying type and a comma-separated list of extents for array
|
||||||
|
// types (any mix of std::array and built-in arrays). An array of char is
|
||||||
|
// treated as scalar because it gets special handling.
|
||||||
|
template <typename T> struct array_info : array_info_scalar<T> { };
|
||||||
|
template <typename T, size_t N> struct array_info<std::array<T, N>> {
|
||||||
|
using type = typename array_info<T>::type;
|
||||||
|
static constexpr bool is_array = true;
|
||||||
|
static constexpr bool is_empty = (N == 0) || array_info<T>::is_empty;
|
||||||
|
static constexpr size_t extent = N;
|
||||||
|
|
||||||
|
// appends the extents to shape
|
||||||
|
static void append_extents(list& shape) {
|
||||||
|
shape.append(N);
|
||||||
|
array_info<T>::append_extents(shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T2 = T, enable_if_t<!array_info<T2>::is_array, int> = 0>
|
||||||
|
static PYBIND11_DESCR extents() {
|
||||||
|
return _<N>();
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T2 = T, enable_if_t<array_info<T2>::is_array, int> = 0>
|
||||||
|
static PYBIND11_DESCR extents() {
|
||||||
|
return concat(_<N>(), array_info<T>::extents());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// For numpy we have special handling for arrays of characters, so we don't include
|
||||||
|
// the size in the array extents.
|
||||||
|
template <size_t N> struct array_info<char[N]> : array_info_scalar<char[N]> { };
|
||||||
|
template <size_t N> struct array_info<std::array<char, N>> : array_info_scalar<std::array<char, N>> { };
|
||||||
|
template <typename T, size_t N> struct array_info<T[N]> : array_info<std::array<T, N>> { };
|
||||||
|
template <typename T> using remove_all_extents_t = typename array_info<T>::type;
|
||||||
|
|
||||||
template <typename T> using is_pod_struct = all_of<
|
template <typename T> using is_pod_struct = all_of<
|
||||||
std::is_pod<T>, // since we're accessing directly in memory we need a POD type
|
std::is_pod<T>, // since we're accessing directly in memory we need a POD type
|
||||||
satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
|
satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
|
||||||
@ -745,6 +785,8 @@ protected:
|
|||||||
|
|
||||||
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
|
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
|
||||||
public:
|
public:
|
||||||
|
static_assert(!detail::array_info<T>::is_array, "Array types cannot be used with array_t");
|
||||||
|
|
||||||
using value_type = T;
|
using value_type = T;
|
||||||
|
|
||||||
array_t() : array(0, static_cast<const T *>(nullptr)) {}
|
array_t() : array(0, static_cast<const T *>(nullptr)) {}
|
||||||
@ -871,6 +913,15 @@ struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct format_descriptor<T, detail::enable_if_t<detail::array_info<T>::is_array>> {
|
||||||
|
static std::string format() {
|
||||||
|
using detail::_;
|
||||||
|
PYBIND11_DESCR extents = _("(") + detail::array_info<T>::extents() + _(")");
|
||||||
|
return extents.text() + format_descriptor<detail::remove_all_extents_t<T>>::format();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
NAMESPACE_BEGIN(detail)
|
NAMESPACE_BEGIN(detail)
|
||||||
template <typename T, int ExtraFlags>
|
template <typename T, int ExtraFlags>
|
||||||
struct pyobject_caster<array_t<T, ExtraFlags>> {
|
struct pyobject_caster<array_t<T, ExtraFlags>> {
|
||||||
@ -939,6 +990,20 @@ template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_F
|
|||||||
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
|
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
|
||||||
#undef PYBIND11_DECL_CHAR_FMT
|
#undef PYBIND11_DECL_CHAR_FMT
|
||||||
|
|
||||||
|
template<typename T> struct npy_format_descriptor<T, enable_if_t<array_info<T>::is_array>> {
|
||||||
|
private:
|
||||||
|
using base_descr = npy_format_descriptor<typename array_info<T>::type>;
|
||||||
|
public:
|
||||||
|
static_assert(!array_info<T>::is_empty, "Zero-sized arrays are not supported");
|
||||||
|
|
||||||
|
static PYBIND11_DESCR name() { return _("(") + array_info<T>::extents() + _(")") + base_descr::name(); }
|
||||||
|
static pybind11::dtype dtype() {
|
||||||
|
list shape;
|
||||||
|
array_info<T>::append_extents(shape);
|
||||||
|
return pybind11::dtype::from_args(pybind11::make_tuple(base_descr::dtype(), shape));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
|
template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
|
||||||
private:
|
private:
|
||||||
using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
|
using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
|
||||||
|
@ -70,6 +70,13 @@ struct StringStruct {
|
|||||||
std::array<char, 3> b;
|
std::array<char, 3> b;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct ArrayStruct {
|
||||||
|
char a[3][4];
|
||||||
|
int32_t b[2];
|
||||||
|
std::array<uint8_t, 3> c;
|
||||||
|
std::array<float, 2> d[4];
|
||||||
|
};
|
||||||
|
|
||||||
PYBIND11_PACKED(struct StructWithUglyNames {
|
PYBIND11_PACKED(struct StructWithUglyNames {
|
||||||
int8_t __x__;
|
int8_t __x__;
|
||||||
uint64_t __y__;
|
uint64_t __y__;
|
||||||
@ -91,6 +98,27 @@ std::ostream& operator<<(std::ostream& os, const StringStruct& v) {
|
|||||||
return os << "'";
|
return os << "'";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, const ArrayStruct& v) {
|
||||||
|
os << "a={";
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
if (i > 0)
|
||||||
|
os << ',';
|
||||||
|
os << '{';
|
||||||
|
for (int j = 0; j < 3; j++)
|
||||||
|
os << v.a[i][j] << ',';
|
||||||
|
os << v.a[i][3] << '}';
|
||||||
|
}
|
||||||
|
os << "},b={" << v.b[0] << ',' << v.b[1];
|
||||||
|
os << "},c={" << int(v.c[0]) << ',' << int(v.c[1]) << ',' << int(v.c[2]);
|
||||||
|
os << "},d={";
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
if (i > 0)
|
||||||
|
os << ',';
|
||||||
|
os << '{' << v.d[i][0] << ',' << v.d[i][1] << '}';
|
||||||
|
}
|
||||||
|
return os << '}';
|
||||||
|
}
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const EnumStruct& v) {
|
std::ostream& operator<<(std::ostream& os, const EnumStruct& v) {
|
||||||
return os << "e1=" << (v.e1 == E1::A ? "A" : "B") << ",e2=" << (v.e2 == E2::X ? "X" : "Y");
|
return os << "e1=" << (v.e1 == E1::A ? "A" : "B") << ",e2=" << (v.e2 == E2::X ? "X" : "Y");
|
||||||
}
|
}
|
||||||
@ -163,6 +191,24 @@ py::array_t<StringStruct, 0> create_string_array(bool non_empty) {
|
|||||||
return arr;
|
return arr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
py::array_t<ArrayStruct, 0> create_array_array(size_t n) {
|
||||||
|
auto arr = mkarray_via_buffer<ArrayStruct>(n);
|
||||||
|
auto ptr = (ArrayStruct *) arr.mutable_data();
|
||||||
|
for (size_t i = 0; i < n; i++) {
|
||||||
|
for (size_t j = 0; j < 3; j++)
|
||||||
|
for (size_t k = 0; k < 4; k++)
|
||||||
|
ptr[i].a[j][k] = char('A' + (i * 100 + j * 10 + k) % 26);
|
||||||
|
for (size_t j = 0; j < 2; j++)
|
||||||
|
ptr[i].b[j] = int32_t(i * 1000 + j);
|
||||||
|
for (size_t j = 0; j < 3; j++)
|
||||||
|
ptr[i].c[j] = uint8_t(i * 10 + j);
|
||||||
|
for (size_t j = 0; j < 4; j++)
|
||||||
|
for (size_t k = 0; k < 2; k++)
|
||||||
|
ptr[i].d[j][k] = float(i) * 100.0f + float(j) * 10.0f + float(k);
|
||||||
|
}
|
||||||
|
return arr;
|
||||||
|
}
|
||||||
|
|
||||||
py::array_t<EnumStruct, 0> create_enum_array(size_t n) {
|
py::array_t<EnumStruct, 0> create_enum_array(size_t n) {
|
||||||
auto arr = mkarray_via_buffer<EnumStruct>(n);
|
auto arr = mkarray_via_buffer<EnumStruct>(n);
|
||||||
auto ptr = (EnumStruct *) arr.mutable_data();
|
auto ptr = (EnumStruct *) arr.mutable_data();
|
||||||
@ -194,6 +240,7 @@ py::list print_format_descriptors() {
|
|||||||
py::format_descriptor<PartialStruct>::format(),
|
py::format_descriptor<PartialStruct>::format(),
|
||||||
py::format_descriptor<PartialNestedStruct>::format(),
|
py::format_descriptor<PartialNestedStruct>::format(),
|
||||||
py::format_descriptor<StringStruct>::format(),
|
py::format_descriptor<StringStruct>::format(),
|
||||||
|
py::format_descriptor<ArrayStruct>::format(),
|
||||||
py::format_descriptor<EnumStruct>::format()
|
py::format_descriptor<EnumStruct>::format()
|
||||||
};
|
};
|
||||||
auto l = py::list();
|
auto l = py::list();
|
||||||
@ -211,6 +258,7 @@ py::list print_dtypes() {
|
|||||||
py::str(py::dtype::of<PartialStruct>()),
|
py::str(py::dtype::of<PartialStruct>()),
|
||||||
py::str(py::dtype::of<PartialNestedStruct>()),
|
py::str(py::dtype::of<PartialNestedStruct>()),
|
||||||
py::str(py::dtype::of<StringStruct>()),
|
py::str(py::dtype::of<StringStruct>()),
|
||||||
|
py::str(py::dtype::of<ArrayStruct>()),
|
||||||
py::str(py::dtype::of<EnumStruct>()),
|
py::str(py::dtype::of<EnumStruct>()),
|
||||||
py::str(py::dtype::of<StructWithUglyNames>())
|
py::str(py::dtype::of<StructWithUglyNames>())
|
||||||
};
|
};
|
||||||
@ -351,6 +399,7 @@ test_initializer numpy_dtypes([](py::module &m) {
|
|||||||
PYBIND11_NUMPY_DTYPE(PartialStruct, bool_, uint_, float_, ldbl_);
|
PYBIND11_NUMPY_DTYPE(PartialStruct, bool_, uint_, float_, ldbl_);
|
||||||
PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a);
|
PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a);
|
||||||
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
|
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
|
||||||
|
PYBIND11_NUMPY_DTYPE(ArrayStruct, a, b, c, d);
|
||||||
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
|
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
|
||||||
PYBIND11_NUMPY_DTYPE(TrailingPaddingStruct, a, b);
|
PYBIND11_NUMPY_DTYPE(TrailingPaddingStruct, a, b);
|
||||||
PYBIND11_NUMPY_DTYPE(CompareStruct, x, y, z);
|
PYBIND11_NUMPY_DTYPE(CompareStruct, x, y, z);
|
||||||
@ -378,6 +427,8 @@ test_initializer numpy_dtypes([](py::module &m) {
|
|||||||
m.def("get_format_unbound", &get_format_unbound);
|
m.def("get_format_unbound", &get_format_unbound);
|
||||||
m.def("create_string_array", &create_string_array);
|
m.def("create_string_array", &create_string_array);
|
||||||
m.def("print_string_array", &print_recarray<StringStruct>);
|
m.def("print_string_array", &print_recarray<StringStruct>);
|
||||||
|
m.def("create_array_array", &create_array_array);
|
||||||
|
m.def("print_array_array", &print_recarray<ArrayStruct>);
|
||||||
m.def("create_enum_array", &create_enum_array);
|
m.def("create_enum_array", &create_enum_array);
|
||||||
m.def("print_enum_array", &print_recarray<EnumStruct>);
|
m.def("print_enum_array", &print_recarray<EnumStruct>);
|
||||||
m.def("test_array_ctors", &test_array_ctors);
|
m.def("test_array_ctors", &test_array_ctors);
|
||||||
|
@ -86,6 +86,7 @@ def test_format_descriptors():
|
|||||||
partial_fmt,
|
partial_fmt,
|
||||||
"T{" + nested_extra + "x" + partial_fmt + ":a:" + nested_extra + "x}",
|
"T{" + nested_extra + "x" + partial_fmt + ":a:" + nested_extra + "x}",
|
||||||
"T{3s:a:3s:b:}",
|
"T{3s:a:3s:b:}",
|
||||||
|
"T{(3)4s:a:(2)i:b:(3)B:c:1x(4, 2)f:d:}",
|
||||||
'T{q:e1:B:e2:}'
|
'T{q:e1:B:e2:}'
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -103,6 +104,9 @@ def test_dtype(simple_dtype):
|
|||||||
partial_dtype_fmt(),
|
partial_dtype_fmt(),
|
||||||
partial_nested_fmt(),
|
partial_nested_fmt(),
|
||||||
"[('a', 'S3'), ('b', 'S3')]",
|
"[('a', 'S3'), ('b', 'S3')]",
|
||||||
|
("{{'names':['a','b','c','d'], " +
|
||||||
|
"'formats':[('S4', (3,)),('<i4', (2,)),('u1', (3,)),('<f4', (4, 2))], " +
|
||||||
|
"'offsets':[0,12,20,24], 'itemsize':56}}").format(e=e),
|
||||||
"[('e1', '" + e + "i8'), ('e2', 'u1')]",
|
"[('e1', '" + e + "i8'), ('e2', 'u1')]",
|
||||||
"[('x', 'i1'), ('y', '" + e + "u8')]"
|
"[('x', 'i1'), ('y', '" + e + "u8')]"
|
||||||
]
|
]
|
||||||
@ -213,6 +217,31 @@ def test_string_array():
|
|||||||
assert dtype == arr.dtype
|
assert dtype == arr.dtype
|
||||||
|
|
||||||
|
|
||||||
|
def test_array_array():
|
||||||
|
from pybind11_tests import create_array_array, print_array_array
|
||||||
|
from sys import byteorder
|
||||||
|
e = '<' if byteorder == 'little' else '>'
|
||||||
|
|
||||||
|
arr = create_array_array(3)
|
||||||
|
assert str(arr.dtype) == (
|
||||||
|
"{{'names':['a','b','c','d'], " +
|
||||||
|
"'formats':[('S4', (3,)),('<i4', (2,)),('u1', (3,)),('{e}f4', (4, 2))], " +
|
||||||
|
"'offsets':[0,12,20,24], 'itemsize':56}}").format(e=e)
|
||||||
|
assert print_array_array(arr) == [
|
||||||
|
"a={{A,B,C,D},{K,L,M,N},{U,V,W,X}},b={0,1}," +
|
||||||
|
"c={0,1,2},d={{0,1},{10,11},{20,21},{30,31}}",
|
||||||
|
"a={{W,X,Y,Z},{G,H,I,J},{Q,R,S,T}},b={1000,1001}," +
|
||||||
|
"c={10,11,12},d={{100,101},{110,111},{120,121},{130,131}}",
|
||||||
|
"a={{S,T,U,V},{C,D,E,F},{M,N,O,P}},b={2000,2001}," +
|
||||||
|
"c={20,21,22},d={{200,201},{210,211},{220,221},{230,231}}",
|
||||||
|
]
|
||||||
|
assert arr['a'].tolist() == [[b'ABCD', b'KLMN', b'UVWX'],
|
||||||
|
[b'WXYZ', b'GHIJ', b'QRST'],
|
||||||
|
[b'STUV', b'CDEF', b'MNOP']]
|
||||||
|
assert arr['b'].tolist() == [[0, 1], [1000, 1001], [2000, 2001]]
|
||||||
|
assert create_array_array(0).dtype == arr.dtype
|
||||||
|
|
||||||
|
|
||||||
def test_enum_array():
|
def test_enum_array():
|
||||||
from pybind11_tests import create_enum_array, print_enum_array
|
from pybind11_tests import create_enum_array, print_enum_array
|
||||||
from sys import byteorder
|
from sys import byteorder
|
||||||
|
Loading…
Reference in New Issue
Block a user