mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-11 08:03:55 +00:00
Add tests for numpy enum descriptors
This commit is contained in:
parent
fb74df50c9
commit
2f3f3687dc
@ -67,6 +67,14 @@ struct StringStruct {
|
||||
std::array<char, 3> b;
|
||||
};
|
||||
|
||||
enum class E1 : int64_t { A = -1, B = 1 };
|
||||
enum E2 : uint8_t { X = 1, Y = 2 };
|
||||
|
||||
PYBIND11_PACKED(struct EnumStruct {
|
||||
E1 e1;
|
||||
E2 e2;
|
||||
});
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const StringStruct& v) {
|
||||
os << "a='";
|
||||
for (size_t i = 0; i < 3 && v.a[i]; i++) os << v.a[i];
|
||||
@ -75,6 +83,10 @@ std::ostream& operator<<(std::ostream& os, const StringStruct& v) {
|
||||
return os << "'";
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const EnumStruct& v) {
|
||||
return os << "e1=" << (v.e1 == E1::A ? "A" : "B") << ",e2=" << (v.e2 == E2::X ? "X" : "Y");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
py::array mkarray_via_buffer(size_t n) {
|
||||
return py::array(py::buffer_info(nullptr, sizeof(T),
|
||||
@ -137,6 +149,16 @@ py::array_t<StringStruct, 0> create_string_array(bool non_empty) {
|
||||
return arr;
|
||||
}
|
||||
|
||||
py::array_t<EnumStruct, 0> create_enum_array(size_t n) {
|
||||
auto arr = mkarray_via_buffer<EnumStruct>(n);
|
||||
auto ptr = (EnumStruct *) arr.mutable_data();
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
ptr[i].e1 = static_cast<E1>(-1 + ((int) i % 2) * 2);
|
||||
ptr[i].e2 = static_cast<E2>(1 + (i % 2));
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
|
||||
template <typename S>
|
||||
py::list print_recarray(py::array_t<S, 0> arr) {
|
||||
const auto req = arr.request();
|
||||
@ -157,7 +179,8 @@ py::list print_format_descriptors() {
|
||||
py::format_descriptor<NestedStruct>::format(),
|
||||
py::format_descriptor<PartialStruct>::format(),
|
||||
py::format_descriptor<PartialNestedStruct>::format(),
|
||||
py::format_descriptor<StringStruct>::format()
|
||||
py::format_descriptor<StringStruct>::format(),
|
||||
py::format_descriptor<EnumStruct>::format()
|
||||
};
|
||||
auto l = py::list();
|
||||
for (const auto &fmt : fmts) {
|
||||
@ -173,7 +196,8 @@ py::list print_dtypes() {
|
||||
py::dtype::of<NestedStruct>().str(),
|
||||
py::dtype::of<PartialStruct>().str(),
|
||||
py::dtype::of<PartialNestedStruct>().str(),
|
||||
py::dtype::of<StringStruct>().str()
|
||||
py::dtype::of<StringStruct>().str(),
|
||||
py::dtype::of<EnumStruct>().str()
|
||||
};
|
||||
auto l = py::list();
|
||||
for (const auto &s : dtypes) {
|
||||
@ -280,6 +304,7 @@ test_initializer numpy_dtypes([](py::module &m) {
|
||||
PYBIND11_NUMPY_DTYPE(PartialStruct, x, y, z);
|
||||
PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a);
|
||||
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
|
||||
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
|
||||
|
||||
m.def("create_rec_simple", &create_recarray<SimpleStruct>);
|
||||
m.def("create_rec_packed", &create_recarray<PackedStruct>);
|
||||
@ -294,6 +319,8 @@ test_initializer numpy_dtypes([](py::module &m) {
|
||||
m.def("get_format_unbound", &get_format_unbound);
|
||||
m.def("create_string_array", &create_string_array);
|
||||
m.def("print_string_array", &print_recarray<StringStruct>);
|
||||
m.def("create_enum_array", &create_enum_array);
|
||||
m.def("print_enum_array", &print_recarray<EnumStruct>);
|
||||
m.def("test_array_ctors", &test_array_ctors);
|
||||
m.def("test_dtype_ctors", &test_dtype_ctors);
|
||||
m.def("test_dtype_methods", &test_dtype_methods);
|
||||
|
@ -26,7 +26,8 @@ def test_format_descriptors():
|
||||
"T{=T{=?:x:3x=I:y:=f:z:}:a:=T{=?:x:=I:y:=f:z:}:b:}",
|
||||
"T{=?:x:3x=I:y:=f:z:12x}",
|
||||
"T{8x=T{=?:x:3x=I:y:=f:z:12x}:a:8x}",
|
||||
"T{=3s:a:=3s:b:}"
|
||||
"T{=3s:a:=3s:b:}",
|
||||
'T{=q:e1:=B:e2:}'
|
||||
]
|
||||
|
||||
|
||||
@ -40,7 +41,8 @@ def test_dtype():
|
||||
"[('a', {'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}), ('b', [('x', '?'), ('y', '<u4'), ('z', '<f4')])]",
|
||||
"{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}",
|
||||
"{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}",
|
||||
"[('a', 'S3'), ('b', 'S3')]"
|
||||
"[('a', 'S3'), ('b', 'S3')]",
|
||||
"[('e1', '<i8'), ('e2', 'u1')]"
|
||||
]
|
||||
|
||||
d1 = np.dtype({'names': ['a', 'b'], 'formats': ['int32', 'float64'],
|
||||
@ -150,6 +152,23 @@ def test_string_array():
|
||||
assert dtype == arr.dtype
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_enum_array():
|
||||
from pybind11_tests import create_enum_array, print_enum_array
|
||||
|
||||
arr = create_enum_array(3)
|
||||
dtype = arr.dtype
|
||||
assert dtype == np.dtype([('e1', '<i8'), ('e2', 'u1')])
|
||||
assert print_enum_array(arr) == [
|
||||
"e1=A,e2=X",
|
||||
"e1=B,e2=Y",
|
||||
"e1=A,e2=X"
|
||||
]
|
||||
assert arr['e1'].tolist() == [-1, 1, -1]
|
||||
assert arr['e2'].tolist() == [1, 2, 1]
|
||||
assert create_enum_array(0).dtype == dtype
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_signature(doc):
|
||||
from pybind11_tests import create_rec_nested
|
||||
|
Loading…
Reference in New Issue
Block a user