diff --git a/tests/test_numpy_dtypes.cpp b/tests/test_numpy_dtypes.cpp index 467e0253f..f235831b4 100644 --- a/tests/test_numpy_dtypes.cpp +++ b/tests/test_numpy_dtypes.cpp @@ -259,7 +259,26 @@ TEST_SUBMODULE(numpy_dtypes, m) { catch (...) { return; } // typeinfo may be registered before the dtype descriptor for scalar casts to work... - py::class_(m, "SimpleStruct"); + py::class_(m, "SimpleStruct") + // Explicit construct to ensure zero-valued initialization. + .def(py::init([]() { return SimpleStruct(); })) + .def_readwrite("bool_", &SimpleStruct::bool_) + .def_readwrite("uint_", &SimpleStruct::uint_) + .def_readwrite("float_", &SimpleStruct::float_) + .def_readwrite("ldbl_", &SimpleStruct::ldbl_) + .def("astuple", [](const SimpleStruct& self) { + return py::make_tuple(self.bool_, self.uint_, self.float_, self.ldbl_); + }) + .def_static("fromtuple", [](const py::tuple tup) { + if (py::len(tup) != 4) { + throw py::cast_error("Invalid size"); + } + return SimpleStruct{ + tup[0].cast(), + tup[1].cast(), + tup[2].cast(), + tup[3].cast()}; + }); PYBIND11_NUMPY_DTYPE(SimpleStruct, bool_, uint_, float_, ldbl_); PYBIND11_NUMPY_DTYPE(SimpleStructReordered, bool_, uint_, float_, ldbl_); @@ -462,10 +481,16 @@ TEST_SUBMODULE(numpy_dtypes, m) { m.def("buffer_to_dtype", [](py::buffer& buf) { return py::dtype(buf.request()); }); // test_scalar_conversion - m.def("f_simple", [](SimpleStruct s) { return s.uint_ * 10; }); + auto f_simple = [](SimpleStruct s) { return s.uint_ * 10; }; + m.def("f_simple", f_simple); m.def("f_packed", [](PackedStruct s) { return s.uint_ * 10; }); m.def("f_nested", [](NestedStruct s) { return s.a.uint_ * 10; }); + // test_vectorize + m.def("f_simple_vectorized", py::vectorize(f_simple)); + auto f_simple_pass_thru = [](SimpleStruct s) { return s; }; + m.def("f_simple_pass_thru_vectorized", py::vectorize(f_simple_pass_thru)); + // test_register_dtype m.def("register_dtype", []() { PYBIND11_NUMPY_DTYPE(SimpleStruct, bool_, uint_, float_, ldbl_); }); diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py index 417d6f1cf..b9f3bd26e 100644 --- a/tests/test_numpy_dtypes.py +++ b/tests/test_numpy_dtypes.py @@ -138,6 +138,10 @@ def test_recarray(simple_dtype, packed_dtype): assert_equal(arr, elements, simple_dtype) assert_equal(arr, elements, packed_dtype) + # Show what recarray's look like in NumPy. + assert type(arr[0]) == np.void + assert type(arr[0].item()) == tuple + if dtype == simple_dtype: assert m.print_rec_simple(arr) == [ "s:0,0,0,-0", @@ -289,6 +293,56 @@ def test_scalar_conversion(): assert 'incompatible function arguments' in str(excinfo.value) +def test_vectorize(): + n = 3 + array = m.create_rec_simple(n) + values = m.f_simple_vectorized(array) + np.testing.assert_array_equal(values, [0, 10, 20]) + array_2 = m.f_simple_pass_thru_vectorized(array) + np.testing.assert_array_equal(array, array_2) + + +def test_cls_and_dtype_conversion(simple_dtype): + s = m.SimpleStruct() + assert s.astuple() == (False, 0, 0., 0.) + assert m.SimpleStruct.fromtuple(s.astuple()).astuple() == s.astuple() + + s.uint_ = 2 + assert m.f_simple(s) == 20 + + # Try as recarray of shape==(1,). + s_recarray = np.array([(False, 2, 0., 0.)], dtype=simple_dtype) + # Show that this will work for vectorized case. + np.testing.assert_array_equal(m.f_simple_vectorized(s_recarray), [20]) + + # Show as a scalar that inherits from np.generic. + s_scalar = s_recarray[0] + assert isinstance(s_scalar, np.void) + assert m.f_simple(s_scalar) == 20 + + # Show that an *array* scalar (np.ndarray.shape == ()) does not convert. + # More specifically, conversion to SimpleStruct is not implicit. + s_recarray_scalar = s_recarray.reshape(()) + assert isinstance(s_recarray_scalar, np.ndarray) + assert s_recarray_scalar.dtype == simple_dtype + with pytest.raises(TypeError) as excinfo: + m.f_simple(s_recarray_scalar) + assert 'incompatible function arguments' in str(excinfo.value) + # Explicitly convert to m.SimpleStruct. + assert m.f_simple( + m.SimpleStruct.fromtuple(s_recarray_scalar.item())) == 20 + + # Show that an array of dtype=object does *not* convert. + s_array_object = np.array([s]) + assert s_array_object.dtype == object + with pytest.raises(TypeError) as excinfo: + m.f_simple_vectorized(s_array_object) + assert 'incompatible function arguments' in str(excinfo.value) + # Explicitly convert to `np.array(..., dtype=simple_dtype)` + s_array = np.array([s.astuple()], dtype=simple_dtype) + np.testing.assert_array_equal(m.f_simple_vectorized(s_array), [20]) + + def test_register_dtype(): with pytest.raises(RuntimeError) as excinfo: m.register_dtype()