test_numpy_dtypes: Add test for py::vectorize() (#2260)

This commit is contained in:
Eric Cousineau 2020-09-17 07:19:33 -04:00 committed by GitHub
parent e3774b76ed
commit 4e7c08daee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 81 additions and 2 deletions

View File

@ -259,7 +259,26 @@ TEST_SUBMODULE(numpy_dtypes, m) {
catch (...) { return; }
// typeinfo may be registered before the dtype descriptor for scalar casts to work...
py::class_<SimpleStruct>(m, "SimpleStruct");
py::class_<SimpleStruct>(m, "SimpleStruct")
// Explicit construct to ensure zero-valued initialization.
.def(py::init([]() { return SimpleStruct(); }))
.def_readwrite("bool_", &SimpleStruct::bool_)
.def_readwrite("uint_", &SimpleStruct::uint_)
.def_readwrite("float_", &SimpleStruct::float_)
.def_readwrite("ldbl_", &SimpleStruct::ldbl_)
.def("astuple", [](const SimpleStruct& self) {
return py::make_tuple(self.bool_, self.uint_, self.float_, self.ldbl_);
})
.def_static("fromtuple", [](const py::tuple tup) {
if (py::len(tup) != 4) {
throw py::cast_error("Invalid size");
}
return SimpleStruct{
tup[0].cast<bool>(),
tup[1].cast<uint32_t>(),
tup[2].cast<float>(),
tup[3].cast<long double>()};
});
PYBIND11_NUMPY_DTYPE(SimpleStruct, bool_, uint_, float_, ldbl_);
PYBIND11_NUMPY_DTYPE(SimpleStructReordered, bool_, uint_, float_, ldbl_);
@ -462,10 +481,16 @@ TEST_SUBMODULE(numpy_dtypes, m) {
m.def("buffer_to_dtype", [](py::buffer& buf) { return py::dtype(buf.request()); });
// test_scalar_conversion
m.def("f_simple", [](SimpleStruct s) { return s.uint_ * 10; });
auto f_simple = [](SimpleStruct s) { return s.uint_ * 10; };
m.def("f_simple", f_simple);
m.def("f_packed", [](PackedStruct s) { return s.uint_ * 10; });
m.def("f_nested", [](NestedStruct s) { return s.a.uint_ * 10; });
// test_vectorize
m.def("f_simple_vectorized", py::vectorize(f_simple));
auto f_simple_pass_thru = [](SimpleStruct s) { return s; };
m.def("f_simple_pass_thru_vectorized", py::vectorize(f_simple_pass_thru));
// test_register_dtype
m.def("register_dtype", []() { PYBIND11_NUMPY_DTYPE(SimpleStruct, bool_, uint_, float_, ldbl_); });

View File

@ -138,6 +138,10 @@ def test_recarray(simple_dtype, packed_dtype):
assert_equal(arr, elements, simple_dtype)
assert_equal(arr, elements, packed_dtype)
# Show what recarray's look like in NumPy.
assert type(arr[0]) == np.void
assert type(arr[0].item()) == tuple
if dtype == simple_dtype:
assert m.print_rec_simple(arr) == [
"s:0,0,0,-0",
@ -289,6 +293,56 @@ def test_scalar_conversion():
assert 'incompatible function arguments' in str(excinfo.value)
def test_vectorize():
n = 3
array = m.create_rec_simple(n)
values = m.f_simple_vectorized(array)
np.testing.assert_array_equal(values, [0, 10, 20])
array_2 = m.f_simple_pass_thru_vectorized(array)
np.testing.assert_array_equal(array, array_2)
def test_cls_and_dtype_conversion(simple_dtype):
s = m.SimpleStruct()
assert s.astuple() == (False, 0, 0., 0.)
assert m.SimpleStruct.fromtuple(s.astuple()).astuple() == s.astuple()
s.uint_ = 2
assert m.f_simple(s) == 20
# Try as recarray of shape==(1,).
s_recarray = np.array([(False, 2, 0., 0.)], dtype=simple_dtype)
# Show that this will work for vectorized case.
np.testing.assert_array_equal(m.f_simple_vectorized(s_recarray), [20])
# Show as a scalar that inherits from np.generic.
s_scalar = s_recarray[0]
assert isinstance(s_scalar, np.void)
assert m.f_simple(s_scalar) == 20
# Show that an *array* scalar (np.ndarray.shape == ()) does not convert.
# More specifically, conversion to SimpleStruct is not implicit.
s_recarray_scalar = s_recarray.reshape(())
assert isinstance(s_recarray_scalar, np.ndarray)
assert s_recarray_scalar.dtype == simple_dtype
with pytest.raises(TypeError) as excinfo:
m.f_simple(s_recarray_scalar)
assert 'incompatible function arguments' in str(excinfo.value)
# Explicitly convert to m.SimpleStruct.
assert m.f_simple(
m.SimpleStruct.fromtuple(s_recarray_scalar.item())) == 20
# Show that an array of dtype=object does *not* convert.
s_array_object = np.array([s])
assert s_array_object.dtype == object
with pytest.raises(TypeError) as excinfo:
m.f_simple_vectorized(s_array_object)
assert 'incompatible function arguments' in str(excinfo.value)
# Explicitly convert to `np.array(..., dtype=simple_dtype)`
s_array = np.array([s.astuple()], dtype=simple_dtype)
np.testing.assert_array_equal(m.f_simple_vectorized(s_array), [20])
def test_register_dtype():
with pytest.raises(RuntimeError) as excinfo:
m.register_dtype()