mirror of
https://github.com/pybind/pybind11.git
synced 2025-01-19 01:15:52 +00:00
Add a test for NumPy scalar conversion
This commit is contained in:
parent
85e16262d6
commit
cbbb7830f2
@ -298,6 +298,9 @@ test_initializer numpy_dtypes([](py::module &m) {
|
||||
return;
|
||||
}
|
||||
|
||||
// typeinfo may be registered before the dtype descriptor for scalar casts to work...
|
||||
py::class_<SimpleStruct>(m, "SimpleStruct");
|
||||
|
||||
PYBIND11_NUMPY_DTYPE(SimpleStruct, x, y, z);
|
||||
PYBIND11_NUMPY_DTYPE(PackedStruct, x, y, z);
|
||||
PYBIND11_NUMPY_DTYPE(NestedStruct, a, b);
|
||||
@ -306,6 +309,11 @@ test_initializer numpy_dtypes([](py::module &m) {
|
||||
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
|
||||
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
|
||||
|
||||
// ... or after...
|
||||
py::class_<PackedStruct>(m, "PackedStruct");
|
||||
|
||||
// ... or not at all
|
||||
|
||||
m.def("create_rec_simple", &create_recarray<SimpleStruct>);
|
||||
m.def("create_rec_packed", &create_recarray<PackedStruct>);
|
||||
m.def("create_rec_nested", &create_nested);
|
||||
@ -324,6 +332,9 @@ test_initializer numpy_dtypes([](py::module &m) {
|
||||
m.def("test_array_ctors", &test_array_ctors);
|
||||
m.def("test_dtype_ctors", &test_dtype_ctors);
|
||||
m.def("test_dtype_methods", &test_dtype_methods);
|
||||
m.def("f_simple", [](SimpleStruct s) { return s.y * 10; });
|
||||
m.def("f_packed", [](PackedStruct s) { return s.y * 10; });
|
||||
m.def("f_nested", [](NestedStruct s) { return s.a.y * 10; });
|
||||
});
|
||||
|
||||
#undef PYBIND11_PACKED
|
||||
|
@ -174,3 +174,25 @@ def test_signature(doc):
|
||||
from pybind11_tests import create_rec_nested
|
||||
|
||||
assert doc(create_rec_nested) == "create_rec_nested(arg0: int) -> numpy.ndarray[NestedStruct]"
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_scalar_conversion():
|
||||
from pybind11_tests import (create_rec_simple, f_simple,
|
||||
create_rec_packed, f_packed,
|
||||
create_rec_nested, f_nested,
|
||||
create_enum_array)
|
||||
|
||||
n = 3
|
||||
arrays = [create_rec_simple(n), create_rec_packed(n),
|
||||
create_rec_nested(n), create_enum_array(n)]
|
||||
funcs = [f_simple, f_packed, f_nested]
|
||||
|
||||
for i, func in enumerate(funcs):
|
||||
for j, arr in enumerate(arrays):
|
||||
if i == j:
|
||||
assert [func(arr[k]) for k in range(n)] == [k * 10 for k in range(n)]
|
||||
else:
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
func(arr[0])
|
||||
assert 'incompatible function arguments' in str(excinfo.value)
|
||||
|
Loading…
Reference in New Issue
Block a user