diff --git a/docs/advanced/pycpp/numpy.rst b/docs/advanced/pycpp/numpy.rst index 71123631a..9157e5031 100644 --- a/docs/advanced/pycpp/numpy.rst +++ b/docs/advanced/pycpp/numpy.rst @@ -198,6 +198,10 @@ expects the type followed by field names: /* now both A and B can be used as template arguments to py::array_t */ } +The structure should consist of fundamental arithmetic types, previously +registered substructures, and arrays of any of the above. Both C++ arrays and +``std::array`` are supported. + Vectorizing functions ===================== diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 72bb35001..62901b85e 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -246,6 +246,46 @@ template struct is_std_array> : std::tru template struct is_complex : std::false_type { }; template struct is_complex> : std::true_type { }; +template struct array_info_scalar { + typedef T type; + static constexpr bool is_array = false; + static constexpr bool is_empty = false; + static PYBIND11_DESCR extents() { return _(""); } + static void append_extents(list& /* shape */) { } +}; +// Computes underlying type and a comma-separated list of extents for array +// types (any mix of std::array and built-in arrays). An array of char is +// treated as scalar because it gets special handling. +template struct array_info : array_info_scalar { }; +template struct array_info> { + using type = typename array_info::type; + static constexpr bool is_array = true; + static constexpr bool is_empty = (N == 0) || array_info::is_empty; + static constexpr size_t extent = N; + + // appends the extents to shape + static void append_extents(list& shape) { + shape.append(N); + array_info::append_extents(shape); + } + + template::is_array, int> = 0> + static PYBIND11_DESCR extents() { + return _(); + } + + template::is_array, int> = 0> + static PYBIND11_DESCR extents() { + return concat(_(), array_info::extents()); + } +}; +// For numpy we have special handling for arrays of characters, so we don't include +// the size in the array extents. +template struct array_info : array_info_scalar { }; +template struct array_info> : array_info_scalar> { }; +template struct array_info : array_info> { }; +template using remove_all_extents_t = typename array_info::type; + template using is_pod_struct = all_of< std::is_pod, // since we're accessing directly in memory we need a POD type satisfies_none_of @@ -745,6 +785,8 @@ protected: template class array_t : public array { public: + static_assert(!detail::array_info::is_array, "Array types cannot be used with array_t"); + using value_type = T; array_t() : array(0, static_cast(nullptr)) {} @@ -871,6 +913,15 @@ struct format_descriptor::value>> { } }; +template +struct format_descriptor::is_array>> { + static std::string format() { + using detail::_; + PYBIND11_DESCR extents = _("(") + detail::array_info::extents() + _(")"); + return extents.text() + format_descriptor>::format(); + } +}; + NAMESPACE_BEGIN(detail) template struct pyobject_caster> { @@ -939,6 +990,20 @@ template struct npy_format_descriptor { PYBIND11_DECL_CHAR_F template struct npy_format_descriptor> { PYBIND11_DECL_CHAR_FMT }; #undef PYBIND11_DECL_CHAR_FMT +template struct npy_format_descriptor::is_array>> { +private: + using base_descr = npy_format_descriptor::type>; +public: + static_assert(!array_info::is_empty, "Zero-sized arrays are not supported"); + + static PYBIND11_DESCR name() { return _("(") + array_info::extents() + _(")") + base_descr::name(); } + static pybind11::dtype dtype() { + list shape; + array_info::append_extents(shape); + return pybind11::dtype::from_args(pybind11::make_tuple(base_descr::dtype(), shape)); + } +}; + template struct npy_format_descriptor::value>> { private: using base_descr = npy_format_descriptor::type>; diff --git a/tests/test_numpy_dtypes.cpp b/tests/test_numpy_dtypes.cpp index a8ba3d87f..8c0a4bed3 100644 --- a/tests/test_numpy_dtypes.cpp +++ b/tests/test_numpy_dtypes.cpp @@ -70,6 +70,13 @@ struct StringStruct { std::array b; }; +struct ArrayStruct { + char a[3][4]; + int32_t b[2]; + std::array c; + std::array d[4]; +}; + PYBIND11_PACKED(struct StructWithUglyNames { int8_t __x__; uint64_t __y__; @@ -91,6 +98,27 @@ std::ostream& operator<<(std::ostream& os, const StringStruct& v) { return os << "'"; } +std::ostream& operator<<(std::ostream& os, const ArrayStruct& v) { + os << "a={"; + for (int i = 0; i < 3; i++) { + if (i > 0) + os << ','; + os << '{'; + for (int j = 0; j < 3; j++) + os << v.a[i][j] << ','; + os << v.a[i][3] << '}'; + } + os << "},b={" << v.b[0] << ',' << v.b[1]; + os << "},c={" << int(v.c[0]) << ',' << int(v.c[1]) << ',' << int(v.c[2]); + os << "},d={"; + for (int i = 0; i < 4; i++) { + if (i > 0) + os << ','; + os << '{' << v.d[i][0] << ',' << v.d[i][1] << '}'; + } + return os << '}'; +} + std::ostream& operator<<(std::ostream& os, const EnumStruct& v) { return os << "e1=" << (v.e1 == E1::A ? "A" : "B") << ",e2=" << (v.e2 == E2::X ? "X" : "Y"); } @@ -163,6 +191,24 @@ py::array_t create_string_array(bool non_empty) { return arr; } +py::array_t create_array_array(size_t n) { + auto arr = mkarray_via_buffer(n); + auto ptr = (ArrayStruct *) arr.mutable_data(); + for (size_t i = 0; i < n; i++) { + for (size_t j = 0; j < 3; j++) + for (size_t k = 0; k < 4; k++) + ptr[i].a[j][k] = char('A' + (i * 100 + j * 10 + k) % 26); + for (size_t j = 0; j < 2; j++) + ptr[i].b[j] = int32_t(i * 1000 + j); + for (size_t j = 0; j < 3; j++) + ptr[i].c[j] = uint8_t(i * 10 + j); + for (size_t j = 0; j < 4; j++) + for (size_t k = 0; k < 2; k++) + ptr[i].d[j][k] = float(i) * 100.0f + float(j) * 10.0f + float(k); + } + return arr; +} + py::array_t create_enum_array(size_t n) { auto arr = mkarray_via_buffer(n); auto ptr = (EnumStruct *) arr.mutable_data(); @@ -194,6 +240,7 @@ py::list print_format_descriptors() { py::format_descriptor::format(), py::format_descriptor::format(), py::format_descriptor::format(), + py::format_descriptor::format(), py::format_descriptor::format() }; auto l = py::list(); @@ -211,6 +258,7 @@ py::list print_dtypes() { py::str(py::dtype::of()), py::str(py::dtype::of()), py::str(py::dtype::of()), + py::str(py::dtype::of()), py::str(py::dtype::of()), py::str(py::dtype::of()) }; @@ -351,6 +399,7 @@ test_initializer numpy_dtypes([](py::module &m) { PYBIND11_NUMPY_DTYPE(PartialStruct, bool_, uint_, float_, ldbl_); PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a); PYBIND11_NUMPY_DTYPE(StringStruct, a, b); + PYBIND11_NUMPY_DTYPE(ArrayStruct, a, b, c, d); PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2); PYBIND11_NUMPY_DTYPE(TrailingPaddingStruct, a, b); PYBIND11_NUMPY_DTYPE(CompareStruct, x, y, z); @@ -378,6 +427,8 @@ test_initializer numpy_dtypes([](py::module &m) { m.def("get_format_unbound", &get_format_unbound); m.def("create_string_array", &create_string_array); m.def("print_string_array", &print_recarray); + m.def("create_array_array", &create_array_array); + m.def("print_array_array", &print_recarray); m.def("create_enum_array", &create_enum_array); m.def("print_enum_array", &print_recarray); m.def("test_array_ctors", &test_array_ctors); diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py index f63814f9d..5fe165b6b 100644 --- a/tests/test_numpy_dtypes.py +++ b/tests/test_numpy_dtypes.py @@ -86,6 +86,7 @@ def test_format_descriptors(): partial_fmt, "T{" + nested_extra + "x" + partial_fmt + ":a:" + nested_extra + "x}", "T{3s:a:3s:b:}", + "T{(3)4s:a:(2)i:b:(3)B:c:1x(4, 2)f:d:}", 'T{q:e1:B:e2:}' ] @@ -103,6 +104,9 @@ def test_dtype(simple_dtype): partial_dtype_fmt(), partial_nested_fmt(), "[('a', 'S3'), ('b', 'S3')]", + ("{{'names':['a','b','c','d'], " + + "'formats':[('S4', (3,)),('' + + arr = create_array_array(3) + assert str(arr.dtype) == ( + "{{'names':['a','b','c','d'], " + + "'formats':[('S4', (3,)),('