mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-21 20:55:11 +00:00
test_numpy_dtypes: Add test for py::vectorize() (#2260)
This commit is contained in:
parent
e3774b76ed
commit
4e7c08daee
@ -259,7 +259,26 @@ TEST_SUBMODULE(numpy_dtypes, m) {
|
||||
catch (...) { return; }
|
||||
|
||||
// typeinfo may be registered before the dtype descriptor for scalar casts to work...
|
||||
py::class_<SimpleStruct>(m, "SimpleStruct");
|
||||
py::class_<SimpleStruct>(m, "SimpleStruct")
|
||||
// Explicit construct to ensure zero-valued initialization.
|
||||
.def(py::init([]() { return SimpleStruct(); }))
|
||||
.def_readwrite("bool_", &SimpleStruct::bool_)
|
||||
.def_readwrite("uint_", &SimpleStruct::uint_)
|
||||
.def_readwrite("float_", &SimpleStruct::float_)
|
||||
.def_readwrite("ldbl_", &SimpleStruct::ldbl_)
|
||||
.def("astuple", [](const SimpleStruct& self) {
|
||||
return py::make_tuple(self.bool_, self.uint_, self.float_, self.ldbl_);
|
||||
})
|
||||
.def_static("fromtuple", [](const py::tuple tup) {
|
||||
if (py::len(tup) != 4) {
|
||||
throw py::cast_error("Invalid size");
|
||||
}
|
||||
return SimpleStruct{
|
||||
tup[0].cast<bool>(),
|
||||
tup[1].cast<uint32_t>(),
|
||||
tup[2].cast<float>(),
|
||||
tup[3].cast<long double>()};
|
||||
});
|
||||
|
||||
PYBIND11_NUMPY_DTYPE(SimpleStruct, bool_, uint_, float_, ldbl_);
|
||||
PYBIND11_NUMPY_DTYPE(SimpleStructReordered, bool_, uint_, float_, ldbl_);
|
||||
@ -462,10 +481,16 @@ TEST_SUBMODULE(numpy_dtypes, m) {
|
||||
m.def("buffer_to_dtype", [](py::buffer& buf) { return py::dtype(buf.request()); });
|
||||
|
||||
// test_scalar_conversion
|
||||
m.def("f_simple", [](SimpleStruct s) { return s.uint_ * 10; });
|
||||
auto f_simple = [](SimpleStruct s) { return s.uint_ * 10; };
|
||||
m.def("f_simple", f_simple);
|
||||
m.def("f_packed", [](PackedStruct s) { return s.uint_ * 10; });
|
||||
m.def("f_nested", [](NestedStruct s) { return s.a.uint_ * 10; });
|
||||
|
||||
// test_vectorize
|
||||
m.def("f_simple_vectorized", py::vectorize(f_simple));
|
||||
auto f_simple_pass_thru = [](SimpleStruct s) { return s; };
|
||||
m.def("f_simple_pass_thru_vectorized", py::vectorize(f_simple_pass_thru));
|
||||
|
||||
// test_register_dtype
|
||||
m.def("register_dtype", []() { PYBIND11_NUMPY_DTYPE(SimpleStruct, bool_, uint_, float_, ldbl_); });
|
||||
|
||||
|
@ -138,6 +138,10 @@ def test_recarray(simple_dtype, packed_dtype):
|
||||
assert_equal(arr, elements, simple_dtype)
|
||||
assert_equal(arr, elements, packed_dtype)
|
||||
|
||||
# Show what recarray's look like in NumPy.
|
||||
assert type(arr[0]) == np.void
|
||||
assert type(arr[0].item()) == tuple
|
||||
|
||||
if dtype == simple_dtype:
|
||||
assert m.print_rec_simple(arr) == [
|
||||
"s:0,0,0,-0",
|
||||
@ -289,6 +293,56 @@ def test_scalar_conversion():
|
||||
assert 'incompatible function arguments' in str(excinfo.value)
|
||||
|
||||
|
||||
def test_vectorize():
|
||||
n = 3
|
||||
array = m.create_rec_simple(n)
|
||||
values = m.f_simple_vectorized(array)
|
||||
np.testing.assert_array_equal(values, [0, 10, 20])
|
||||
array_2 = m.f_simple_pass_thru_vectorized(array)
|
||||
np.testing.assert_array_equal(array, array_2)
|
||||
|
||||
|
||||
def test_cls_and_dtype_conversion(simple_dtype):
|
||||
s = m.SimpleStruct()
|
||||
assert s.astuple() == (False, 0, 0., 0.)
|
||||
assert m.SimpleStruct.fromtuple(s.astuple()).astuple() == s.astuple()
|
||||
|
||||
s.uint_ = 2
|
||||
assert m.f_simple(s) == 20
|
||||
|
||||
# Try as recarray of shape==(1,).
|
||||
s_recarray = np.array([(False, 2, 0., 0.)], dtype=simple_dtype)
|
||||
# Show that this will work for vectorized case.
|
||||
np.testing.assert_array_equal(m.f_simple_vectorized(s_recarray), [20])
|
||||
|
||||
# Show as a scalar that inherits from np.generic.
|
||||
s_scalar = s_recarray[0]
|
||||
assert isinstance(s_scalar, np.void)
|
||||
assert m.f_simple(s_scalar) == 20
|
||||
|
||||
# Show that an *array* scalar (np.ndarray.shape == ()) does not convert.
|
||||
# More specifically, conversion to SimpleStruct is not implicit.
|
||||
s_recarray_scalar = s_recarray.reshape(())
|
||||
assert isinstance(s_recarray_scalar, np.ndarray)
|
||||
assert s_recarray_scalar.dtype == simple_dtype
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
m.f_simple(s_recarray_scalar)
|
||||
assert 'incompatible function arguments' in str(excinfo.value)
|
||||
# Explicitly convert to m.SimpleStruct.
|
||||
assert m.f_simple(
|
||||
m.SimpleStruct.fromtuple(s_recarray_scalar.item())) == 20
|
||||
|
||||
# Show that an array of dtype=object does *not* convert.
|
||||
s_array_object = np.array([s])
|
||||
assert s_array_object.dtype == object
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
m.f_simple_vectorized(s_array_object)
|
||||
assert 'incompatible function arguments' in str(excinfo.value)
|
||||
# Explicitly convert to `np.array(..., dtype=simple_dtype)`
|
||||
s_array = np.array([s.astuple()], dtype=simple_dtype)
|
||||
np.testing.assert_array_equal(m.f_simple_vectorized(s_array), [20])
|
||||
|
||||
|
||||
def test_register_dtype():
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
m.register_dtype()
|
||||
|
Loading…
Reference in New Issue
Block a user