mirror of
https://github.com/pybind/pybind11.git
synced 2025-01-19 01:15:52 +00:00
Numpy: better compilation errors, long double support (#619)
* Clarify PYBIND11_NUMPY_DTYPE documentation The current documentation and example reads as though PYBIND11_NUMPY_DTYPE is a declarative macro along the same lines as PYBIND11_DECLARE_HOLDER_TYPE, but it isn't. The changes the documentation and docs example to make it clear that you need to "call" the macro. * Add satisfies_{all,any,none}_of<T, Preds> `satisfies_all_of<T, Pred1, Pred2, Pred3>` is a nice legibility-enhanced shortcut for `is_all<Pred1<T>, Pred2<T>, Pred3<T>>`. * Give better error message for non-POD dtype attempts If you try to use a non-POD data type, you get difficult-to-interpret compilation errors (about ::name() not being a member of an internal pybind11 struct, among others), for which isn't at all obvious what the problem is. This adds a static_assert for such cases. It also changes the base case from an empty struct to the is_pod_struct case by no longer using `enable_if<is_pod_struct>` but instead using a static_assert: thus specializations avoid the base class, POD types work, and non-POD types (and unimplemented POD types like std::array) get a more informative static_assert failure. * Prefix macros with PYBIND11_ numpy.h uses unprefixed macros, which seems undesirable. This prefixes them with PYBIND11_ to match all the other macros in numpy.h (and elsewhere). * Add long double support This adds long double and std::complex<long double> support for numpy arrays. This allows some simplification of the code used to generate format descriptors; the new code uses fewer macros, instead putting the code as different templated options; the template conditions end up simpler with this because we are now supporting all basic C++ arithmetic types (and so can use is_arithmetic instead of is_integral + multiple different specializations). In addition to testing that it is indeed working in the test script, it also adds various offset and size calculations there, which fixes the test failures under x86 compilations.
This commit is contained in:
parent
c2d1d95809
commit
f7f5bc8e37
@ -176,9 +176,10 @@ function overload.
|
||||
Structured types
|
||||
================
|
||||
|
||||
In order for ``py::array_t`` to work with structured (record) types, we first need
|
||||
to register the memory layout of the type. This can be done via ``PYBIND11_NUMPY_DTYPE``
|
||||
macro which expects the type followed by field names:
|
||||
In order for ``py::array_t`` to work with structured (record) types, we first
|
||||
need to register the memory layout of the type. This can be done via
|
||||
``PYBIND11_NUMPY_DTYPE`` macro, called in the plugin definition code, which
|
||||
expects the type followed by field names:
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
@ -192,10 +193,14 @@ macro which expects the type followed by field names:
|
||||
A a;
|
||||
};
|
||||
|
||||
PYBIND11_NUMPY_DTYPE(A, x, y);
|
||||
PYBIND11_NUMPY_DTYPE(B, z, a);
|
||||
// ...
|
||||
PYBIND11_PLUGIN(test) {
|
||||
// ...
|
||||
|
||||
/* now both A and B can be used as template arguments to py::array_t */
|
||||
PYBIND11_NUMPY_DTYPE(A, x, y);
|
||||
PYBIND11_NUMPY_DTYPE(B, z, a);
|
||||
/* now both A and B can be used as template arguments to py::array_t */
|
||||
}
|
||||
|
||||
Vectorizing functions
|
||||
=====================
|
||||
|
@ -1009,8 +1009,8 @@ class type_caster<T, enable_if_t<is_pyobject<T>::value>> : public pyobject_caste
|
||||
// - if the type is non-copy-constructible, the object must be the sole owner of the type (i.e. it
|
||||
// must have ref_count() == 1)h
|
||||
// If any of the above are not satisfied, we fall back to copying.
|
||||
template <typename T> using move_is_plain_type = none_of<
|
||||
std::is_void<T>, std::is_pointer<T>, std::is_reference<T>, std::is_const<T>
|
||||
template <typename T> using move_is_plain_type = satisfies_none_of<T,
|
||||
std::is_void, std::is_pointer, std::is_reference, std::is_const
|
||||
>;
|
||||
template <typename T, typename SFINAE = void> struct move_always : std::false_type {};
|
||||
template <typename T> struct move_always<T, enable_if_t<all_of<
|
||||
|
@ -419,6 +419,10 @@ template <class... Ts> using any_of = std::disjunction<Ts...>;
|
||||
#endif
|
||||
template <class... Ts> using none_of = negation<any_of<Ts...>>;
|
||||
|
||||
template <class T, template<class> class... Predicates> using satisfies_all_of = all_of<Predicates<T>...>;
|
||||
template <class T, template<class> class... Predicates> using satisfies_any_of = any_of<Predicates<T>...>;
|
||||
template <class T, template<class> class... Predicates> using satisfies_none_of = none_of<Predicates<T>...>;
|
||||
|
||||
/// Strip the class from a method type
|
||||
template <typename T> struct remove_class { };
|
||||
template <typename C, typename R, typename... A> struct remove_class<R (C::*)(A...)> { typedef R type(A...); };
|
||||
@ -567,21 +571,31 @@ PYBIND11_RUNTIME_EXCEPTION(reference_cast_error, PyExc_RuntimeError) /// Used in
|
||||
[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const char *reason) { throw std::runtime_error(reason); }
|
||||
[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const std::string &reason) { throw std::runtime_error(reason); }
|
||||
|
||||
/// Format strings for basic number types
|
||||
#define PYBIND11_DECL_FMT(t, v) template<> struct format_descriptor<t> \
|
||||
{ static constexpr const char* value = v; /* for backwards compatibility */ \
|
||||
static std::string format() { return value; } }
|
||||
|
||||
template <typename T, typename SFINAE = void> struct format_descriptor { };
|
||||
|
||||
template <typename T> struct format_descriptor<T, detail::enable_if_t<std::is_integral<T>::value>> {
|
||||
static constexpr const char c = "bBhHiIqQ"[detail::log2(sizeof(T))*2 + std::is_unsigned<T>::value];
|
||||
NAMESPACE_BEGIN(detail)
|
||||
// Returns the index of the given type in the type char array below, and in the list in numpy.h
|
||||
// The order here is: bool; 8 ints ((signed,unsigned)x(8,16,32,64)bits); float,double,long double;
|
||||
// complex float,double,long double. Note that the long double types only participate when long
|
||||
// double is actually longer than double (it isn't under MSVC).
|
||||
// NB: not only the string below but also complex.h and numpy.h rely on this order.
|
||||
template <typename T, typename SFINAE = void> struct is_fmt_numeric { static constexpr bool value = false; };
|
||||
template <typename T> struct is_fmt_numeric<T, enable_if_t<std::is_arithmetic<T>::value>> {
|
||||
static constexpr bool value = true;
|
||||
static constexpr int index = std::is_same<T, bool>::value ? 0 : 1 + (
|
||||
std::is_integral<T>::value ? detail::log2(sizeof(T))*2 + std::is_unsigned<T>::value : 8 + (
|
||||
std::is_same<T, double>::value ? 1 : std::is_same<T, long double>::value ? 2 : 0));
|
||||
};
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
template <typename T> struct format_descriptor<T, detail::enable_if_t<detail::is_fmt_numeric<T>::value>> {
|
||||
static constexpr const char c = "?bBhHiIqQfdgFDG"[detail::is_fmt_numeric<T>::index];
|
||||
static constexpr const char value[2] = { c, '\0' };
|
||||
static std::string format() { return std::string(1, c); }
|
||||
};
|
||||
|
||||
template <typename T> constexpr const char format_descriptor<
|
||||
T, detail::enable_if_t<std::is_integral<T>::value>>::value[2];
|
||||
T, detail::enable_if_t<detail::is_fmt_numeric<T>::value>>::value[2];
|
||||
|
||||
/// RAII wrapper that temporarily clears any Python error state
|
||||
struct error_scope {
|
||||
@ -590,10 +604,6 @@ struct error_scope {
|
||||
~error_scope() { PyErr_Restore(type, value, trace); }
|
||||
};
|
||||
|
||||
PYBIND11_DECL_FMT(float, "f");
|
||||
PYBIND11_DECL_FMT(double, "d");
|
||||
PYBIND11_DECL_FMT(bool, "?");
|
||||
|
||||
/// Dummy destructor wrapper that can be used to expose classes with a private destructor
|
||||
struct nodelete { template <typename T> void operator()(T*) { } };
|
||||
|
||||
|
@ -18,11 +18,14 @@
|
||||
#endif
|
||||
|
||||
NAMESPACE_BEGIN(pybind11)
|
||||
|
||||
PYBIND11_DECL_FMT(std::complex<float>, "Zf");
|
||||
PYBIND11_DECL_FMT(std::complex<double>, "Zd");
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
// The format codes are already in the string in common.h, we just need to provide a specialization
|
||||
template <typename T> struct is_fmt_numeric<std::complex<T>> {
|
||||
static constexpr bool value = true;
|
||||
static constexpr int index = is_fmt_numeric<T>::index + 3;
|
||||
};
|
||||
|
||||
template <typename T> class type_caster<std::complex<T>> {
|
||||
public:
|
||||
bool load(handle src, bool) {
|
||||
|
@ -36,8 +36,7 @@ static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
|
||||
|
||||
NAMESPACE_BEGIN(pybind11)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
|
||||
template <typename type> struct is_pod_struct;
|
||||
template <typename type, typename SFINAE = void> struct npy_format_descriptor;
|
||||
|
||||
struct PyArrayDescr_Proxy {
|
||||
PyObject_HEAD
|
||||
@ -220,6 +219,16 @@ inline bool check_flags(const void* ptr, int flag) {
|
||||
return (flag == (array_proxy(ptr)->flags & flag));
|
||||
}
|
||||
|
||||
template <typename T> struct is_std_array : std::false_type { };
|
||||
template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
|
||||
template <typename T> struct is_complex : std::false_type { };
|
||||
template <typename T> struct is_complex<std::complex<T>> : std::true_type { };
|
||||
|
||||
template <typename T> using is_pod_struct = all_of<
|
||||
std::is_pod<T>, // since we're accessing directly in memory we need a POD type
|
||||
satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
|
||||
>;
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
class dtype : public object {
|
||||
@ -685,65 +694,48 @@ struct pyobject_caster<array_t<T, ExtraFlags>> {
|
||||
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
|
||||
};
|
||||
|
||||
template <typename T> struct is_std_array : std::false_type { };
|
||||
template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
|
||||
|
||||
template <typename T>
|
||||
struct is_pod_struct {
|
||||
enum { value = std::is_pod<T>::value && // offsetof only works correctly for POD types
|
||||
!std::is_reference<T>::value &&
|
||||
!std::is_array<T>::value &&
|
||||
!is_std_array<T>::value &&
|
||||
!std::is_integral<T>::value &&
|
||||
!std::is_enum<T>::value &&
|
||||
!std::is_same<typename std::remove_cv<T>::type, float>::value &&
|
||||
!std::is_same<typename std::remove_cv<T>::type, double>::value &&
|
||||
!std::is_same<typename std::remove_cv<T>::type, bool>::value &&
|
||||
!std::is_same<typename std::remove_cv<T>::type, std::complex<float>>::value &&
|
||||
!std::is_same<typename std::remove_cv<T>::type, std::complex<double>>::value };
|
||||
};
|
||||
|
||||
template <typename T> struct npy_format_descriptor<T, enable_if_t<std::is_integral<T>::value>> {
|
||||
template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> {
|
||||
private:
|
||||
constexpr static const int values[8] = {
|
||||
npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_,
|
||||
npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_ };
|
||||
// NB: the order here must match the one in common.h
|
||||
constexpr static const int values[15] = {
|
||||
npy_api::NPY_BOOL_,
|
||||
npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_,
|
||||
npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_,
|
||||
npy_api::NPY_FLOAT_, npy_api::NPY_DOUBLE_, npy_api::NPY_LONGDOUBLE_,
|
||||
npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_
|
||||
};
|
||||
|
||||
public:
|
||||
enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)] };
|
||||
static constexpr int value = values[detail::is_fmt_numeric<T>::index];
|
||||
|
||||
static pybind11::dtype dtype() {
|
||||
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
|
||||
return reinterpret_borrow<pybind11::dtype>(ptr);
|
||||
pybind11_fail("Unsupported buffer format!");
|
||||
}
|
||||
template <typename T2 = T, enable_if_t<std::is_signed<T2>::value, int> = 0>
|
||||
static PYBIND11_DESCR name() { return _("int") + _<sizeof(T)*8>(); }
|
||||
template <typename T2 = T, enable_if_t<!std::is_signed<T2>::value, int> = 0>
|
||||
static PYBIND11_DESCR name() { return _("uint") + _<sizeof(T)*8>(); }
|
||||
template <typename T2 = T, enable_if_t<std::is_integral<T2>::value, int> = 0>
|
||||
static PYBIND11_DESCR name() {
|
||||
return _<std::is_same<T, bool>::value>(_("bool"),
|
||||
_<std::is_signed<T>::value>("int", "uint") + _<sizeof(T)*8>());
|
||||
}
|
||||
template <typename T2 = T, enable_if_t<std::is_floating_point<T2>::value, int> = 0>
|
||||
static PYBIND11_DESCR name() {
|
||||
return _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
|
||||
_("float") + _<sizeof(T)*8>(), _("longdouble"));
|
||||
}
|
||||
template <typename T2 = T, enable_if_t<is_complex<T2>::value, int> = 0>
|
||||
static PYBIND11_DESCR name() {
|
||||
return _<std::is_same<typename T2::value_type, float>::value || std::is_same<typename T2::value_type, double>::value>(
|
||||
_("complex") + _<sizeof(T2::value_type)*16>(), _("longcomplex"));
|
||||
}
|
||||
};
|
||||
template <typename T> constexpr const int npy_format_descriptor<
|
||||
T, enable_if_t<std::is_integral<T>::value>>::values[8];
|
||||
|
||||
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
|
||||
enum { value = npy_api::NumPyName }; \
|
||||
static pybind11::dtype dtype() { \
|
||||
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) \
|
||||
return reinterpret_borrow<pybind11::dtype>(ptr); \
|
||||
pybind11_fail("Unsupported buffer format!"); \
|
||||
} \
|
||||
static PYBIND11_DESCR name() { return _(Name); } }
|
||||
DECL_FMT(float, NPY_FLOAT_, "float32");
|
||||
DECL_FMT(double, NPY_DOUBLE_, "float64");
|
||||
DECL_FMT(bool, NPY_BOOL_, "bool");
|
||||
DECL_FMT(std::complex<float>, NPY_CFLOAT_, "complex64");
|
||||
DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
|
||||
#undef DECL_FMT
|
||||
|
||||
#define DECL_CHAR_FMT \
|
||||
#define PYBIND11_DECL_CHAR_FMT \
|
||||
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
|
||||
static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); }
|
||||
template <size_t N> struct npy_format_descriptor<char[N]> { DECL_CHAR_FMT };
|
||||
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { DECL_CHAR_FMT };
|
||||
#undef DECL_CHAR_FMT
|
||||
template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_FMT };
|
||||
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
|
||||
#undef PYBIND11_DECL_CHAR_FMT
|
||||
|
||||
template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
|
||||
private:
|
||||
@ -798,9 +790,9 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
|
||||
for (auto& field : ordered_fields) {
|
||||
if (field.offset > offset)
|
||||
oss << (field.offset - offset) << 'x';
|
||||
// mark unaligned fields with '='
|
||||
// mark unaligned fields with '^' (unaligned native type)
|
||||
if (field.offset % field.alignment)
|
||||
oss << '=';
|
||||
oss << '^';
|
||||
oss << field.format << ':' << field.name << ':';
|
||||
offset = field.offset + field.size;
|
||||
}
|
||||
@ -820,8 +812,9 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
|
||||
get_internals().direct_conversions[tindex].push_back(direct_converter);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
|
||||
template <typename T, typename SFINAE> struct npy_format_descriptor {
|
||||
static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");
|
||||
|
||||
static PYBIND11_DESCR name() { return _("struct"); }
|
||||
|
||||
static pybind11::dtype dtype() {
|
||||
|
@ -547,10 +547,10 @@ public:
|
||||
template <typename T> using is_keyword = std::is_base_of<arg, T>;
|
||||
template <typename T> using is_s_unpacking = std::is_same<args_proxy, T>; // * unpacking
|
||||
template <typename T> using is_ds_unpacking = std::is_same<kwargs_proxy, T>; // ** unpacking
|
||||
template <typename T> using is_positional = none_of<
|
||||
is_keyword<T>, is_s_unpacking<T>, is_ds_unpacking<T>
|
||||
template <typename T> using is_positional = satisfies_none_of<T,
|
||||
is_keyword, is_s_unpacking, is_ds_unpacking
|
||||
>;
|
||||
template <typename T> using is_keyword_or_ds = any_of<is_keyword<T>, is_ds_unpacking<T>>;
|
||||
template <typename T> using is_keyword_or_ds = satisfies_any_of<T, is_keyword, is_ds_unpacking>;
|
||||
|
||||
// Call argument collector forward declarations
|
||||
template <return_value_policy policy = return_value_policy::automatic_reference>
|
||||
|
@ -19,23 +19,25 @@
|
||||
namespace py = pybind11;
|
||||
|
||||
struct SimpleStruct {
|
||||
bool x;
|
||||
uint32_t y;
|
||||
float z;
|
||||
bool bool_;
|
||||
uint32_t uint_;
|
||||
float float_;
|
||||
long double ldbl_;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const SimpleStruct& v) {
|
||||
return os << "s:" << v.x << "," << v.y << "," << v.z;
|
||||
return os << "s:" << v.bool_ << "," << v.uint_ << "," << v.float_ << "," << v.ldbl_;
|
||||
}
|
||||
|
||||
PYBIND11_PACKED(struct PackedStruct {
|
||||
bool x;
|
||||
uint32_t y;
|
||||
float z;
|
||||
bool bool_;
|
||||
uint32_t uint_;
|
||||
float float_;
|
||||
long double ldbl_;
|
||||
});
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const PackedStruct& v) {
|
||||
return os << "p:" << v.x << "," << v.y << "," << v.z;
|
||||
return os << "p:" << v.bool_ << "," << v.uint_ << "," << v.float_ << "," << v.ldbl_;
|
||||
}
|
||||
|
||||
PYBIND11_PACKED(struct NestedStruct {
|
||||
@ -48,10 +50,11 @@ std::ostream& operator<<(std::ostream& os, const NestedStruct& v) {
|
||||
}
|
||||
|
||||
struct PartialStruct {
|
||||
bool x;
|
||||
uint32_t y;
|
||||
float z;
|
||||
bool bool_;
|
||||
uint32_t uint_;
|
||||
float float_;
|
||||
uint64_t dummy2;
|
||||
long double ldbl_;
|
||||
};
|
||||
|
||||
struct PartialNestedStruct {
|
||||
@ -99,13 +102,19 @@ py::array mkarray_via_buffer(size_t n) {
|
||||
1, { n }, { sizeof(T) }));
|
||||
}
|
||||
|
||||
#define SET_TEST_VALS(s, i) do { \
|
||||
s.bool_ = (i) % 2 != 0; \
|
||||
s.uint_ = (uint32_t) (i); \
|
||||
s.float_ = (float) (i) * 1.5f; \
|
||||
s.ldbl_ = (long double) (i) * -2.5L; } while (0)
|
||||
|
||||
template <typename S>
|
||||
py::array_t<S, 0> create_recarray(size_t n) {
|
||||
auto arr = mkarray_via_buffer<S>(n);
|
||||
auto req = arr.request();
|
||||
auto ptr = static_cast<S*>(req.ptr);
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ptr[i].x = i % 2 != 0; ptr[i].y = (uint32_t) i; ptr[i].z = (float) i * 1.5f;
|
||||
SET_TEST_VALS(ptr[i], i);
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
@ -119,8 +128,8 @@ py::array_t<NestedStruct, 0> create_nested(size_t n) {
|
||||
auto req = arr.request();
|
||||
auto ptr = static_cast<NestedStruct*>(req.ptr);
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ptr[i].a.x = i % 2 != 0; ptr[i].a.y = (uint32_t) i; ptr[i].a.z = (float) i * 1.5f;
|
||||
ptr[i].b.x = (i + 1) % 2 != 0; ptr[i].b.y = (uint32_t) (i + 1); ptr[i].b.z = (float) (i + 1) * 1.5f;
|
||||
SET_TEST_VALS(ptr[i].a, i);
|
||||
SET_TEST_VALS(ptr[i].b, i + 1);
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
@ -130,7 +139,7 @@ py::array_t<PartialNestedStruct, 0> create_partial_nested(size_t n) {
|
||||
auto req = arr.request();
|
||||
auto ptr = static_cast<PartialNestedStruct*>(req.ptr);
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ptr[i].a.x = i % 2 != 0; ptr[i].a.y = (uint32_t) i; ptr[i].a.z = (float) i * 1.5f;
|
||||
SET_TEST_VALS(ptr[i].a, i);
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
@ -320,10 +329,10 @@ test_initializer numpy_dtypes([](py::module &m) {
|
||||
// typeinfo may be registered before the dtype descriptor for scalar casts to work...
|
||||
py::class_<SimpleStruct>(m, "SimpleStruct");
|
||||
|
||||
PYBIND11_NUMPY_DTYPE(SimpleStruct, x, y, z);
|
||||
PYBIND11_NUMPY_DTYPE(PackedStruct, x, y, z);
|
||||
PYBIND11_NUMPY_DTYPE(SimpleStruct, bool_, uint_, float_, ldbl_);
|
||||
PYBIND11_NUMPY_DTYPE(PackedStruct, bool_, uint_, float_, ldbl_);
|
||||
PYBIND11_NUMPY_DTYPE(NestedStruct, a, b);
|
||||
PYBIND11_NUMPY_DTYPE(PartialStruct, x, y, z);
|
||||
PYBIND11_NUMPY_DTYPE(PartialStruct, bool_, uint_, float_, ldbl_);
|
||||
PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a);
|
||||
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
|
||||
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
|
||||
@ -334,6 +343,11 @@ test_initializer numpy_dtypes([](py::module &m) {
|
||||
|
||||
PYBIND11_NUMPY_DTYPE_EX(StructWithUglyNames, __x__, "x", __y__, "y");
|
||||
|
||||
// If uncommented, this should produce a static_assert failure telling the user that the struct
|
||||
// is not a POD type
|
||||
// struct NotPOD { std::string v; NotPOD() : v("hi") {}; };
|
||||
// PYBIND11_NUMPY_DTYPE(NotPOD, v);
|
||||
|
||||
m.def("create_rec_simple", &create_recarray<SimpleStruct>);
|
||||
m.def("create_rec_packed", &create_recarray<PackedStruct>);
|
||||
m.def("create_rec_nested", &create_nested);
|
||||
@ -354,10 +368,10 @@ test_initializer numpy_dtypes([](py::module &m) {
|
||||
m.def("test_dtype_methods", &test_dtype_methods);
|
||||
m.def("trailing_padding_dtype", &trailing_padding_dtype);
|
||||
m.def("buffer_to_dtype", &buffer_to_dtype);
|
||||
m.def("f_simple", [](SimpleStruct s) { return s.y * 10; });
|
||||
m.def("f_packed", [](PackedStruct s) { return s.y * 10; });
|
||||
m.def("f_nested", [](NestedStruct s) { return s.a.y * 10; });
|
||||
m.def("register_dtype", []() { PYBIND11_NUMPY_DTYPE(SimpleStruct, x, y, z); });
|
||||
m.def("f_simple", [](SimpleStruct s) { return s.uint_ * 10; });
|
||||
m.def("f_packed", [](PackedStruct s) { return s.uint_ * 10; });
|
||||
m.def("f_nested", [](NestedStruct s) { return s.a.uint_ * 10; });
|
||||
m.def("register_dtype", []() { PYBIND11_NUMPY_DTYPE(SimpleStruct, bool_, uint_, float_, ldbl_); });
|
||||
});
|
||||
|
||||
#undef PYBIND11_PACKED
|
||||
|
@ -7,14 +7,51 @@ with pytest.suppress(ImportError):
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def simple_dtype():
|
||||
return np.dtype({'names': ['x', 'y', 'z'],
|
||||
'formats': ['?', 'u4', 'f4'],
|
||||
'offsets': [0, 4, 8]})
|
||||
ld = np.dtype('longdouble')
|
||||
return np.dtype({'names': ['bool_', 'uint_', 'float_', 'ldbl_'],
|
||||
'formats': ['?', 'u4', 'f4', 'f{}'.format(ld.itemsize)],
|
||||
'offsets': [0, 4, 8, (16 if ld.alignment > 4 else 12)]})
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def packed_dtype():
|
||||
return np.dtype([('x', '?'), ('y', 'u4'), ('z', 'f4')])
|
||||
return np.dtype([('bool_', '?'), ('uint_', 'u4'), ('float_', 'f4'), ('ldbl_', 'g')])
|
||||
|
||||
|
||||
def dt_fmt():
|
||||
return ("{{'names':['bool_','uint_','float_','ldbl_'], 'formats':['?','<u4','<f4','<f{}'],"
|
||||
" 'offsets':[0,4,8,{}], 'itemsize':{}}}")
|
||||
|
||||
|
||||
def simple_dtype_fmt():
|
||||
ld = np.dtype('longdouble')
|
||||
simple_ld_off = 12 + 4 * (ld.alignment > 4)
|
||||
return dt_fmt().format(ld.itemsize, simple_ld_off, simple_ld_off + ld.itemsize)
|
||||
|
||||
|
||||
def packed_dtype_fmt():
|
||||
return "[('bool_', '?'), ('uint_', '<u4'), ('float_', '<f4'), ('ldbl_', '<f{}')]".format(
|
||||
np.dtype('longdouble').itemsize)
|
||||
|
||||
|
||||
def partial_ld_offset():
|
||||
return 12 + 4 * (np.dtype('uint64').alignment > 4) + 8 + 8 * (
|
||||
np.dtype('longdouble').alignment > 8)
|
||||
|
||||
|
||||
def partial_dtype_fmt():
|
||||
ld = np.dtype('longdouble')
|
||||
partial_ld_off = partial_ld_offset()
|
||||
return dt_fmt().format(ld.itemsize, partial_ld_off, partial_ld_off + ld.itemsize)
|
||||
|
||||
|
||||
def partial_nested_fmt():
|
||||
ld = np.dtype('longdouble')
|
||||
partial_nested_off = 8 + 8 * (ld.alignment > 8)
|
||||
partial_ld_off = partial_ld_offset()
|
||||
partial_nested_size = partial_nested_off * 2 + partial_ld_off + ld.itemsize
|
||||
return "{{'names':['a'], 'formats':[{}], 'offsets':[{}], 'itemsize':{}}}".format(
|
||||
partial_dtype_fmt(), partial_nested_off, partial_nested_size)
|
||||
|
||||
|
||||
def assert_equal(actual, expected_data, expected_dtype):
|
||||
@ -29,12 +66,20 @@ def test_format_descriptors():
|
||||
get_format_unbound()
|
||||
assert re.match('^NumPy type info missing for .*UnboundStruct.*$', str(excinfo.value))
|
||||
|
||||
ld = np.dtype('longdouble')
|
||||
ldbl_fmt = ('4x' if ld.alignment > 4 else '') + ld.char
|
||||
ss_fmt = "T{?:bool_:3xI:uint_:f:float_:" + ldbl_fmt + ":ldbl_:}"
|
||||
dbl = np.dtype('double')
|
||||
partial_fmt = ("T{?:bool_:3xI:uint_:f:float_:" +
|
||||
str(4 * (dbl.alignment > 4) + dbl.itemsize + 8 * (ld.alignment > 8)) +
|
||||
"xg:ldbl_:}")
|
||||
nested_extra = str(max(8, ld.alignment))
|
||||
assert print_format_descriptors() == [
|
||||
"T{?:x:3xI:y:f:z:}",
|
||||
"T{?:x:=I:y:=f:z:}",
|
||||
"T{T{?:x:3xI:y:f:z:}:a:T{?:x:=I:y:=f:z:}:b:}",
|
||||
"T{?:x:3xI:y:f:z:12x}",
|
||||
"T{8xT{?:x:3xI:y:f:z:12x}:a:8x}",
|
||||
ss_fmt,
|
||||
"T{?:bool_:^I:uint_:^f:float_:^g:ldbl_:}",
|
||||
"T{" + ss_fmt + ":a:T{?:bool_:^I:uint_:^f:float_:^g:ldbl_:}:b:}",
|
||||
partial_fmt,
|
||||
"T{" + nested_extra + "x" + partial_fmt + ":a:" + nested_extra + "x}",
|
||||
"T{3s:a:3s:b:}",
|
||||
'T{q:e1:B:e2:}'
|
||||
]
|
||||
@ -46,13 +91,11 @@ def test_dtype(simple_dtype):
|
||||
trailing_padding_dtype, buffer_to_dtype)
|
||||
|
||||
assert print_dtypes() == [
|
||||
"{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}",
|
||||
"[('x', '?'), ('y', '<u4'), ('z', '<f4')]",
|
||||
"[('a', {'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8],"
|
||||
" 'itemsize':12}), ('b', [('x', '?'), ('y', '<u4'), ('z', '<f4')])]",
|
||||
"{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}",
|
||||
"{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'],"
|
||||
" 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}",
|
||||
simple_dtype_fmt(),
|
||||
packed_dtype_fmt(),
|
||||
"[('a', {}), ('b', {})]".format(simple_dtype_fmt(), packed_dtype_fmt()),
|
||||
partial_dtype_fmt(),
|
||||
partial_nested_fmt(),
|
||||
"[('a', 'S3'), ('b', 'S3')]",
|
||||
"[('e1', '<i8'), ('e2', 'u1')]",
|
||||
"[('x', 'i1'), ('y', '<u8')]"
|
||||
@ -76,7 +119,7 @@ def test_recarray(simple_dtype, packed_dtype):
|
||||
print_rec_simple, print_rec_packed, print_rec_nested,
|
||||
create_rec_partial, create_rec_partial_nested)
|
||||
|
||||
elements = [(False, 0, 0.0), (True, 1, 1.5), (False, 2, 3.0)]
|
||||
elements = [(False, 0, 0.0, -0.0), (True, 1, 1.5, -2.5), (False, 2, 3.0, -5.0)]
|
||||
|
||||
for func, dtype in [(create_rec_simple, simple_dtype), (create_rec_packed, packed_dtype)]:
|
||||
arr = func(0)
|
||||
@ -91,15 +134,15 @@ def test_recarray(simple_dtype, packed_dtype):
|
||||
|
||||
if dtype == simple_dtype:
|
||||
assert print_rec_simple(arr) == [
|
||||
"s:0,0,0",
|
||||
"s:1,1,1.5",
|
||||
"s:0,2,3"
|
||||
"s:0,0,0,-0",
|
||||
"s:1,1,1.5,-2.5",
|
||||
"s:0,2,3,-5"
|
||||
]
|
||||
else:
|
||||
assert print_rec_packed(arr) == [
|
||||
"p:0,0,0",
|
||||
"p:1,1,1.5",
|
||||
"p:0,2,3"
|
||||
"p:0,0,0,-0",
|
||||
"p:1,1,1.5,-2.5",
|
||||
"p:0,2,3,-5"
|
||||
]
|
||||
|
||||
nested_dtype = np.dtype([('a', simple_dtype), ('b', packed_dtype)])
|
||||
@ -110,18 +153,17 @@ def test_recarray(simple_dtype, packed_dtype):
|
||||
|
||||
arr = create_rec_nested(3)
|
||||
assert arr.dtype == nested_dtype
|
||||
assert_equal(arr, [((False, 0, 0.0), (True, 1, 1.5)),
|
||||
((True, 1, 1.5), (False, 2, 3.0)),
|
||||
((False, 2, 3.0), (True, 3, 4.5))], nested_dtype)
|
||||
assert_equal(arr, [((False, 0, 0.0, -0.0), (True, 1, 1.5, -2.5)),
|
||||
((True, 1, 1.5, -2.5), (False, 2, 3.0, -5.0)),
|
||||
((False, 2, 3.0, -5.0), (True, 3, 4.5, -7.5))], nested_dtype)
|
||||
assert print_rec_nested(arr) == [
|
||||
"n:a=s:0,0,0;b=p:1,1,1.5",
|
||||
"n:a=s:1,1,1.5;b=p:0,2,3",
|
||||
"n:a=s:0,2,3;b=p:1,3,4.5"
|
||||
"n:a=s:0,0,0,-0;b=p:1,1,1.5,-2.5",
|
||||
"n:a=s:1,1,1.5,-2.5;b=p:0,2,3,-5",
|
||||
"n:a=s:0,2,3,-5;b=p:1,3,4.5,-7.5"
|
||||
]
|
||||
|
||||
arr = create_rec_partial(3)
|
||||
assert str(arr.dtype) == \
|
||||
"{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}"
|
||||
assert str(arr.dtype) == partial_dtype_fmt()
|
||||
partial_dtype = arr.dtype
|
||||
assert '' not in arr.dtype.fields
|
||||
assert partial_dtype.itemsize > simple_dtype.itemsize
|
||||
@ -129,9 +171,7 @@ def test_recarray(simple_dtype, packed_dtype):
|
||||
assert_equal(arr, elements, packed_dtype)
|
||||
|
||||
arr = create_rec_partial_nested(3)
|
||||
assert str(arr.dtype) == \
|
||||
"{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4']," \
|
||||
" 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}"
|
||||
assert str(arr.dtype) == partial_nested_fmt()
|
||||
assert '' not in arr.dtype.fields
|
||||
assert '' not in arr.dtype.fields['a'][0].fields
|
||||
assert arr.dtype.itemsize > partial_dtype.itemsize
|
||||
|
Loading…
Reference in New Issue
Block a user