Add a test for NumPy scalar conversion

This commit is contained in:
Ivan Smirnov 2016-10-20 16:47:29 +01:00
parent 85e16262d6
commit cbbb7830f2
2 changed files with 33 additions and 0 deletions

View File

@ -298,6 +298,9 @@ test_initializer numpy_dtypes([](py::module &m) {
return;
}
// typeinfo may be registered before the dtype descriptor for scalar casts to work...
py::class_<SimpleStruct>(m, "SimpleStruct");
PYBIND11_NUMPY_DTYPE(SimpleStruct, x, y, z);
PYBIND11_NUMPY_DTYPE(PackedStruct, x, y, z);
PYBIND11_NUMPY_DTYPE(NestedStruct, a, b);
@ -306,6 +309,11 @@ test_initializer numpy_dtypes([](py::module &m) {
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
// ... or after...
py::class_<PackedStruct>(m, "PackedStruct");
// ... or not at all
m.def("create_rec_simple", &create_recarray<SimpleStruct>);
m.def("create_rec_packed", &create_recarray<PackedStruct>);
m.def("create_rec_nested", &create_nested);
@ -324,6 +332,9 @@ test_initializer numpy_dtypes([](py::module &m) {
m.def("test_array_ctors", &test_array_ctors);
m.def("test_dtype_ctors", &test_dtype_ctors);
m.def("test_dtype_methods", &test_dtype_methods);
m.def("f_simple", [](SimpleStruct s) { return s.y * 10; });
m.def("f_packed", [](PackedStruct s) { return s.y * 10; });
m.def("f_nested", [](NestedStruct s) { return s.a.y * 10; });
});
#undef PYBIND11_PACKED

View File

@ -174,3 +174,25 @@ def test_signature(doc):
from pybind11_tests import create_rec_nested
assert doc(create_rec_nested) == "create_rec_nested(arg0: int) -> numpy.ndarray[NestedStruct]"
@pytest.requires_numpy
def test_scalar_conversion():
from pybind11_tests import (create_rec_simple, f_simple,
create_rec_packed, f_packed,
create_rec_nested, f_nested,
create_enum_array)
n = 3
arrays = [create_rec_simple(n), create_rec_packed(n),
create_rec_nested(n), create_enum_array(n)]
funcs = [f_simple, f_packed, f_nested]
for i, func in enumerate(funcs):
for j, arr in enumerate(arrays):
if i == j:
assert [func(arr[k]) for k in range(n)] == [k * 10 for k in range(n)]
else:
with pytest.raises(TypeError) as excinfo:
func(arr[0])
assert 'incompatible function arguments' in str(excinfo.value)