From b82c0f0a2daa5a672afd130c56989731b8d9dd29 Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Wed, 10 May 2017 11:36:24 +0200 Subject: [PATCH] Allow std::complex field with PYBIND11_NUMPY_DTYPE (#831) This exposed a few underlying issues: 1. is_pod_struct was too strict to allow this. I've relaxed it to require only trivially copyable and standard layout, rather than POD (which additionally requires a trivial constructor, which std::complex violates). 2. format_descriptor>::format() returned numpy format strings instead of PEP3118 format strings, but register_dtype feeds format codes of its fields to _dtype_from_pep3118. I've changed it to return PEP3118 format codes. format_descriptor is a public type, so this may be considered an incompatible change. 3. register_structured_dtype tried to be smart about whether to mark fields as unaligned (with ^). However, it's examining the C++ alignment, rather than what numpy (or possibly PEP3118) thinks the alignment should be. For complex values those are different. I've made it mark all fields as ^ unconditionally, which should always be safe even if they are aligned, because we explicitly mark the padding. --- docs/advanced/pycpp/numpy.rst | 9 ++++++--- include/pybind11/common.h | 6 +++--- include/pybind11/complex.h | 13 ++++++++++-- include/pybind11/numpy.h | 21 ++++++++++++------- tests/test_numpy_dtypes.cpp | 30 +++++++++++++++++++++++++-- tests/test_numpy_dtypes.py | 38 ++++++++++++++++++++++++++--------- 6 files changed, 91 insertions(+), 26 deletions(-) diff --git a/docs/advanced/pycpp/numpy.rst b/docs/advanced/pycpp/numpy.rst index 9157e5031..57a52f6a8 100644 --- a/docs/advanced/pycpp/numpy.rst +++ b/docs/advanced/pycpp/numpy.rst @@ -198,9 +198,12 @@ expects the type followed by field names: /* now both A and B can be used as template arguments to py::array_t */ } -The structure should consist of fundamental arithmetic types, previously -registered substructures, and arrays of any of the above. Both C++ arrays and -``std::array`` are supported. +The structure should consist of fundamental arithmetic types, ``std::complex``, +previously registered substructures, and arrays of any of the above. Both C++ +arrays and ``std::array`` are supported. While there is a static assertion to +prevent many types of unsupported structures, it is still the user's +responsibility to use only "plain" structures that can be safely manipulated as +raw memory without violating invariants. Vectorizing functions ===================== diff --git a/include/pybind11/common.h b/include/pybind11/common.h index 7c53638ca..ddf966ea7 100644 --- a/include/pybind11/common.h +++ b/include/pybind11/common.h @@ -608,14 +608,14 @@ template struct is_fmt_numeric }; NAMESPACE_END(detail) -template struct format_descriptor::value>> { - static constexpr const char c = "?bBhHiIqQfdgFDG"[detail::is_fmt_numeric::index]; +template struct format_descriptor::value>> { + static constexpr const char c = "?bBhHiIqQfdg"[detail::is_fmt_numeric::index]; static constexpr const char value[2] = { c, '\0' }; static std::string format() { return std::string(1, c); } }; template constexpr const char format_descriptor< - T, detail::enable_if_t::value>>::value[2]; + T, detail::enable_if_t::value>>::value[2]; /// RAII wrapper that temporarily clears any Python error state struct error_scope { diff --git a/include/pybind11/complex.h b/include/pybind11/complex.h index 945ca0710..7d422e209 100644 --- a/include/pybind11/complex.h +++ b/include/pybind11/complex.h @@ -18,10 +18,19 @@ #endif NAMESPACE_BEGIN(pybind11) + +template struct format_descriptor, detail::enable_if_t::value>> { + static constexpr const char c = format_descriptor::c; + static constexpr const char value[3] = { 'Z', c, '\0' }; + static std::string format() { return std::string(value); } +}; + +template constexpr const char format_descriptor< + std::complex, detail::enable_if_t::value>>::value[3]; + NAMESPACE_BEGIN(detail) -// The format codes are already in the string in common.h, we just need to provide a specialization -template struct is_fmt_numeric> { +template struct is_fmt_numeric, detail::enable_if_t::value>> { static constexpr bool value = true; static constexpr int index = is_fmt_numeric::index + 3; }; diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 62901b85e..72840c4ea 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -287,7 +287,14 @@ template struct array_info : array_info using remove_all_extents_t = typename array_info::type; template using is_pod_struct = all_of< - std::is_pod, // since we're accessing directly in memory we need a POD type + std::is_standard_layout, // since we're accessing directly in memory we need a standard layout type +#if !defined(__GNUG__) || defined(__clang__) || __GNUC__ >= 5 + std::is_trivially_copyable, +#else + // GCC 4 doesn't implement is_trivially_copyable, so approximate it + std::is_trivially_destructible, + satisfies_any_of, +#endif satisfies_none_of >; @@ -1016,7 +1023,6 @@ struct field_descriptor { const char *name; ssize_t offset; ssize_t size; - ssize_t alignment; std::string format; dtype descr; }; @@ -1053,13 +1059,15 @@ inline PYBIND11_NOINLINE void register_structured_dtype( [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; }); ssize_t offset = 0; std::ostringstream oss; - oss << "T{"; + // mark the structure as unaligned with '^', because numpy and C++ don't + // always agree about alignment (particularly for complex), and we're + // explicitly listing all our padding. This depends on none of the fields + // overriding the endianness. Putting the ^ in front of individual fields + // isn't guaranteed to work due to https://github.com/numpy/numpy/issues/9049 + oss << "^T{"; for (auto& field : ordered_fields) { if (field.offset > offset) oss << (field.offset - offset) << 'x'; - // mark unaligned fields with '^' (unaligned native type) - if (field.offset % field.alignment) - oss << '^'; oss << field.format << ':' << field.name << ':'; offset = field.offset + field.size; } @@ -1121,7 +1129,6 @@ private: #define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \ ::pybind11::detail::field_descriptor { \ Name, offsetof(T, Field), sizeof(decltype(std::declval().Field)), \ - alignof(decltype(std::declval().Field)), \ ::pybind11::format_descriptor().Field)>::format(), \ ::pybind11::detail::npy_format_descriptor().Field)>::dtype() \ } diff --git a/tests/test_numpy_dtypes.cpp b/tests/test_numpy_dtypes.cpp index 8c0a4bed3..5f987a847 100644 --- a/tests/test_numpy_dtypes.cpp +++ b/tests/test_numpy_dtypes.cpp @@ -70,6 +70,15 @@ struct StringStruct { std::array b; }; +struct ComplexStruct { + std::complex cflt; + std::complex cdbl; +}; + +std::ostream& operator<<(std::ostream& os, const ComplexStruct& v) { + return os << "c:" << v.cflt << "," << v.cdbl; +} + struct ArrayStruct { char a[3][4]; int32_t b[2]; @@ -219,6 +228,18 @@ py::array_t create_enum_array(size_t n) { return arr; } +py::array_t create_complex_array(size_t n) { + auto arr = mkarray_via_buffer(n); + auto ptr = (ComplexStruct *) arr.mutable_data(); + for (size_t i = 0; i < n; i++) { + ptr[i].cflt.real(float(i)); + ptr[i].cflt.imag(float(i) + 0.25f); + ptr[i].cdbl.real(double(i) + 0.5); + ptr[i].cdbl.imag(double(i) + 0.75); + } + return arr; +} + template py::list print_recarray(py::array_t arr) { const auto req = arr.request(); @@ -241,7 +262,8 @@ py::list print_format_descriptors() { py::format_descriptor::format(), py::format_descriptor::format(), py::format_descriptor::format(), - py::format_descriptor::format() + py::format_descriptor::format(), + py::format_descriptor::format() }; auto l = py::list(); for (const auto &fmt : fmts) { @@ -260,7 +282,8 @@ py::list print_dtypes() { py::str(py::dtype::of()), py::str(py::dtype::of()), py::str(py::dtype::of()), - py::str(py::dtype::of()) + py::str(py::dtype::of()), + py::str(py::dtype::of()) }; auto l = py::list(); for (const auto &s : dtypes) { @@ -401,6 +424,7 @@ test_initializer numpy_dtypes([](py::module &m) { PYBIND11_NUMPY_DTYPE(StringStruct, a, b); PYBIND11_NUMPY_DTYPE(ArrayStruct, a, b, c, d); PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2); + PYBIND11_NUMPY_DTYPE(ComplexStruct, cflt, cdbl); PYBIND11_NUMPY_DTYPE(TrailingPaddingStruct, a, b); PYBIND11_NUMPY_DTYPE(CompareStruct, x, y, z); @@ -431,6 +455,8 @@ test_initializer numpy_dtypes([](py::module &m) { m.def("print_array_array", &print_recarray); m.def("create_enum_array", &create_enum_array); m.def("print_enum_array", &print_recarray); + m.def("create_complex_array", &create_complex_array); + m.def("print_complex_array", &print_recarray); m.def("test_array_ctors", &test_array_ctors); m.def("test_dtype_ctors", &test_dtype_ctors); m.def("test_dtype_methods", &test_dtype_methods); diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py index 5fe165b6b..24803a97f 100644 --- a/tests/test_numpy_dtypes.py +++ b/tests/test_numpy_dtypes.py @@ -73,21 +73,22 @@ def test_format_descriptors(): ld = np.dtype('longdouble') ldbl_fmt = ('4x' if ld.alignment > 4 else '') + ld.char - ss_fmt = "T{?:bool_:3xI:uint_:f:float_:" + ldbl_fmt + ":ldbl_:}" + ss_fmt = "^T{?:bool_:3xI:uint_:f:float_:" + ldbl_fmt + ":ldbl_:}" dbl = np.dtype('double') - partial_fmt = ("T{?:bool_:3xI:uint_:f:float_:" + + partial_fmt = ("^T{?:bool_:3xI:uint_:f:float_:" + str(4 * (dbl.alignment > 4) + dbl.itemsize + 8 * (ld.alignment > 8)) + "xg:ldbl_:}") nested_extra = str(max(8, ld.alignment)) assert print_format_descriptors() == [ ss_fmt, - "T{?:bool_:^I:uint_:^f:float_:^g:ldbl_:}", - "T{" + ss_fmt + ":a:T{?:bool_:^I:uint_:^f:float_:^g:ldbl_:}:b:}", + "^T{?:bool_:I:uint_:f:float_:g:ldbl_:}", + "^T{" + ss_fmt + ":a:^T{?:bool_:I:uint_:f:float_:g:ldbl_:}:b:}", partial_fmt, - "T{" + nested_extra + "x" + partial_fmt + ":a:" + nested_extra + "x}", - "T{3s:a:3s:b:}", - "T{(3)4s:a:(2)i:b:(3)B:c:1x(4, 2)f:d:}", - 'T{q:e1:B:e2:}' + "^T{" + nested_extra + "x" + partial_fmt + ":a:" + nested_extra + "x}", + "^T{3s:a:3s:b:}", + "^T{(3)4s:a:(2)i:b:(3)B:c:1x(4, 2)f:d:}", + '^T{q:e1:B:e2:}', + '^T{Zf:cflt:Zd:cdbl:}' ] @@ -108,7 +109,8 @@ def test_dtype(simple_dtype): "'formats':[('S4', (3,)),('' + + arr = create_complex_array(3) + dtype = arr.dtype + assert dtype == np.dtype([('cflt', e + 'c8'), ('cdbl', e + 'c16')]) + assert print_complex_array(arr) == [ + "c:(0,0.25),(0.5,0.75)", + "c:(1,1.25),(1.5,1.75)", + "c:(2,2.25),(2.5,2.75)" + ] + assert arr['cflt'].tolist() == [0.0 + 0.25j, 1.0 + 1.25j, 2.0 + 2.25j] + assert arr['cdbl'].tolist() == [0.5 + 0.75j, 1.5 + 1.75j, 2.5 + 2.75j] + assert create_complex_array(0).dtype == dtype + + def test_signature(doc): from pybind11_tests import create_rec_nested