From 7f913aecabed087bc368708cc7e2b23d0edb58aa Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Sun, 19 Jun 2016 16:41:15 +0100 Subject: [PATCH] Add tests for nested recarrays --- example/example20.cpp | 32 +++++++++++++++++++++++--------- example/example20.py | 23 ++++++++++++++++++----- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/example/example20.cpp b/example/example20.cpp index c22eb213e..2535b258b 100644 --- a/example/example20.cpp +++ b/example/example20.cpp @@ -30,23 +30,36 @@ struct PackedStruct { struct NestedStruct { Struct a; PackedStruct b; -}; +} __attribute__((packed)); + +template +py::array mkarray_via_buffer(size_t n) { + return py::array(py::buffer_info(nullptr, sizeof(T), + py::format_descriptor::value(), + 1, { n }, { sizeof(T) })); +} template py::array_t create_recarray(size_t n) { - auto arr = py::array(py::buffer_info(nullptr, sizeof(S), - py::format_descriptor::value(), - 1, { n }, { sizeof(S) })); - auto buf = arr.request(); - auto ptr = static_cast(buf.ptr); + auto arr = mkarray_via_buffer(n); + auto ptr = static_cast(arr.request().ptr); for (size_t i = 0; i < n; i++) { - ptr[i].x = i % 2; - ptr[i].y = i; - ptr[i].z = i * 1.5; + ptr[i].x = i % 2; ptr[i].y = (uint32_t) i; ptr[i].z = (float) i * 1.5f; } return arr; } +py::array_t create_nested(size_t n) { + auto arr = mkarray_via_buffer(n); + auto ptr = static_cast(arr.request().ptr); + for (size_t i = 0; i < n; i++) { + ptr[i].a.x = i % 2; ptr[i].a.y = (uint32_t) i; ptr[i].a.z = (float) i * 1.5f; + ptr[i].b.x = (i + 1) % 2; ptr[i].b.y = (uint32_t) (i + 1); ptr[i].b.z = (float) (i + 1) * 1.5f; + } + return arr; + +} + void init_ex20(py::module &m) { PYBIND11_DTYPE(Struct, x, y, z); PYBIND11_DTYPE(PackedStruct, x, y, z); @@ -54,4 +67,5 @@ void init_ex20(py::module &m) { m.def("create_rec_simple", &create_recarray); m.def("create_rec_packed", &create_recarray); + m.def("create_rec_nested", &create_nested); } diff --git a/example/example20.py b/example/example20.py index 0e42c68d6..83725b528 100644 --- a/example/example20.py +++ b/example/example20.py @@ -2,7 +2,7 @@ from __future__ import print_function import numpy as np -from example import create_rec_simple, create_rec_packed +from example import create_rec_simple, create_rec_packed, create_rec_nested def check_eq(arr, data, dtype): @@ -14,12 +14,25 @@ simple_dtype = np.dtype({'names': ['x', 'y', 'z'], packed_dtype = np.dtype([('x', '?'), ('y', 'u4'), ('z', 'f4')]) for func, dtype in [(create_rec_simple, simple_dtype), (create_rec_packed, packed_dtype)]: + arr = func(0) + assert arr.dtype == dtype + check_eq(arr, [], simple_dtype) + check_eq(arr, [], packed_dtype) + arr = func(3) assert arr.dtype == dtype check_eq(arr, [(False, 0, 0.0), (True, 1, 1.5), (False, 2, 3.0)], simple_dtype) check_eq(arr, [(False, 0, 0.0), (True, 1, 1.5), (False, 2, 3.0)], packed_dtype) - arr = func(0) - assert arr.dtype == dtype - check_eq(arr, [], simple_dtype) - check_eq(arr, [], packed_dtype) + +nested_dtype = np.dtype([('a', simple_dtype), ('b', packed_dtype)]) + +arr = create_rec_nested(0) +assert arr.dtype == nested_dtype +check_eq(arr, [], nested_dtype) + +arr = create_rec_nested(3) +assert arr.dtype == nested_dtype +check_eq(arr, [((False, 0, 0.0), (True, 1, 1.5)), + ((True, 1, 1.5), (False, 2, 3.0)), + ((False, 2, 3.0), (True, 3, 4.5))], nested_dtype)