diff --git a/example/example-numpy-dtypes.cpp b/example/example-numpy-dtypes.cpp index 2c7cdc07c..7a4c1f6d8 100644 --- a/example/example-numpy-dtypes.cpp +++ b/example/example-numpy-dtypes.cpp @@ -166,6 +166,40 @@ void print_dtypes() { std::cout << (std::string) py::dtype::of().str() << std::endl; } +py::array_t test_array_ctors(int i) { + using arr_t = py::array_t; + + std::vector data { 1, 2, 3, 4, 5, 6 }; + std::vector shape { 3, 2 }; + std::vector strides { 8, 4 }; + + auto ptr = data.data(); + auto vptr = (void *) ptr; + auto dtype = py::dtype("int32"); + + py::buffer_info buf_ndim1(vptr, 4, "i", 6); + py::buffer_info buf_ndim2(vptr, 4, "i", 2, shape, strides); + + switch (i) { + // shape: (3, 2) + case 0: return arr_t(shape, ptr, strides); + case 1: return py::array(shape, ptr, strides); + case 2: return py::array(dtype, shape, vptr, strides); + case 3: return arr_t(shape, ptr); + case 4: return py::array(shape, ptr); + case 5: return py::array(dtype, shape, vptr); + case 6: return arr_t(buf_ndim2); + case 7: return py::array(buf_ndim2); + // shape: (6, ) + case 8: return arr_t(6, ptr); + case 9: return py::array(6, ptr); + case 10: return py::array(dtype, 6, vptr); + case 11: return arr_t(buf_ndim1); + case 12: return py::array(buf_ndim1); + } + return arr_t(); +} + void init_ex_numpy_dtypes(py::module &m) { PYBIND11_NUMPY_DTYPE(SimpleStruct, x, y, z); PYBIND11_NUMPY_DTYPE(PackedStruct, x, y, z); @@ -187,6 +221,7 @@ void init_ex_numpy_dtypes(py::module &m) { m.def("get_format_unbound", &get_format_unbound); m.def("create_string_array", &create_string_array); m.def("print_string_array", &print_recarray); + m.def("test_array_ctors", &test_array_ctors); } #undef PYBIND11_PACKED diff --git a/example/example-numpy-dtypes.py b/example/example-numpy-dtypes.py index 68ea5c2cb..930364afb 100644 --- a/example/example-numpy-dtypes.py +++ b/example/example-numpy-dtypes.py @@ -6,7 +6,8 @@ import numpy as np from example import ( create_rec_simple, create_rec_packed, create_rec_nested, print_format_descriptors, print_rec_simple, print_rec_packed, print_rec_nested, print_dtypes, get_format_unbound, - create_rec_partial, create_rec_partial_nested, create_string_array, print_string_array + create_rec_partial, create_rec_partial_nested, create_string_array, print_string_array, + test_array_ctors ) @@ -80,3 +81,8 @@ assert arr['a'].tolist() == [b'', b'a', b'ab', b'abc'] assert arr['b'].tolist() == [b'', b'a', b'ab', b'abc'] arr = create_string_array(False) assert dtype == arr.dtype + +data = np.arange(1, 7, dtype='int32') +for i in range(13): + expected = data if i >= 8 else data.reshape((3, 2)) + np.testing.assert_array_equal(test_array_ctors(i), expected)