Add a test for buffer format of unbound struct

This commit is contained in:
Ivan Smirnov 2016-06-26 16:35:28 +01:00
parent a0e37f250e
commit d0bafd90e0
2 changed files with 12 additions and 1 deletions

View File

@ -44,6 +44,8 @@ std::ostream& operator<<(std::ostream& os, const NestedStruct& v) {
return os << "n:a=" << v.a << ";b=" << v.b; return os << "n:a=" << v.a << ";b=" << v.b;
} }
struct UnboundStruct { };
template <typename T> template <typename T>
py::array mkarray_via_buffer(size_t n) { py::array mkarray_via_buffer(size_t n) {
return py::array(py::buffer_info(nullptr, sizeof(T), return py::array(py::buffer_info(nullptr, sizeof(T),
@ -61,6 +63,10 @@ py::array_t<S> create_recarray(size_t n) {
return arr; return arr;
} }
std::string get_format_unbound() {
return py::format_descriptor<UnboundStruct>::format();
}
py::array_t<NestedStruct> create_nested(size_t n) { py::array_t<NestedStruct> create_nested(size_t n) {
auto arr = mkarray_via_buffer<NestedStruct>(n); auto arr = mkarray_via_buffer<NestedStruct>(n);
auto ptr = static_cast<NestedStruct*>(arr.request().ptr); auto ptr = static_cast<NestedStruct*>(arr.request().ptr);
@ -107,4 +113,5 @@ void init_ex20(py::module &m) {
m.def("print_rec_packed", &print_recarray<PackedStruct>); m.def("print_rec_packed", &print_recarray<PackedStruct>);
m.def("print_rec_nested", &print_recarray<NestedStruct>); m.def("print_rec_nested", &print_recarray<NestedStruct>);
m.def("print_dtypes", &print_dtypes); m.def("print_dtypes", &print_dtypes);
m.def("get_format_unbound", &get_format_unbound);
} }

View File

@ -1,16 +1,20 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import print_function from __future__ import print_function
import unittest
import numpy as np import numpy as np
from example import ( from example import (
create_rec_simple, create_rec_packed, create_rec_nested, print_format_descriptors, create_rec_simple, create_rec_packed, create_rec_nested, print_format_descriptors,
print_rec_simple, print_rec_packed, print_rec_nested, print_dtypes print_rec_simple, print_rec_packed, print_rec_nested, print_dtypes, get_format_unbound
) )
def check_eq(arr, data, dtype): def check_eq(arr, data, dtype):
np.testing.assert_equal(arr, np.array(data, dtype=dtype)) np.testing.assert_equal(arr, np.array(data, dtype=dtype))
unittest.TestCase().assertRaisesRegex(
RuntimeError, 'unsupported buffer format', get_format_unbound)
print_format_descriptors() print_format_descriptors()
print_dtypes() print_dtypes()