diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 0d0cbdfa8..fa128efdd 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -199,6 +199,7 @@ struct npy_api { int (*PyArray_SetBaseObject_)(PyObject *, PyObject *); PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int); PyObject* (*PyArray_Newshape_)(PyObject*, PyArray_Dims*, int); + PyObject* (*PyArray_View_)(PyObject*, PyObject*, PyObject*); private: enum functions { @@ -216,6 +217,7 @@ private: API_PyArray_DescrNewFromType = 96, API_PyArray_Newshape = 135, API_PyArray_Squeeze = 136, + API_PyArray_View = 137, API_PyArray_DescrConverter = 174, API_PyArray_EquivTypes = 182, API_PyArray_GetArrayParamsFromObject = 278, @@ -248,6 +250,7 @@ private: DECL_NPY_API(PyArray_DescrNewFromType); DECL_NPY_API(PyArray_Newshape); DECL_NPY_API(PyArray_Squeeze); + DECL_NPY_API(PyArray_View); DECL_NPY_API(PyArray_DescrConverter); DECL_NPY_API(PyArray_EquivTypes); DECL_NPY_API(PyArray_GetArrayParamsFromObject); @@ -802,6 +805,21 @@ public: return new_array; } + /// Create a view of an array in a different data type. + /// This function may fundamentally reinterpret the data in the array. + /// It is the responsibility of the caller to ensure that this is safe. + /// Only supports the `dtype` argument, the `type` argument is omitted, + /// to be added as needed. + array view(const std::string &dtype) { + auto &api = detail::npy_api::get(); + auto new_view = reinterpret_steal(api.PyArray_View_( + m_ptr, dtype::from_args(pybind11::str(dtype)).release().ptr(), nullptr)); + if (!new_view) { + throw error_already_set(); + } + return new_view; + } + /// Ensure that the argument is a NumPy array /// In case of an error, nullptr is returned and the Python error is cleared. static array ensure(handle h, int ExtraFlags = 0) { diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index 4ccfd279b..30a71acc9 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -405,6 +405,9 @@ TEST_SUBMODULE(numpy_array, sm) { return a; }); + sm.def("array_view", + [](py::array_t a, const std::string &dtype) { return a.view(dtype); }); + sm.def("reshape_initializer_list", [](py::array_t a, size_t N, size_t M, size_t O) { return a.reshape({N, M, O}); }); diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index e96454be4..e4138f023 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -476,6 +476,21 @@ def test_array_create_and_resize(): assert np.all(a == 42.0) +def test_array_view(): + a = np.ones(100 * 4).astype("uint8") + a_float_view = m.array_view(a, "float32") + assert a_float_view.shape == (100 * 1,) # 1 / 4 bytes = 8 / 32 + + a_int16_view = m.array_view(a, "int16") # 1 / 2 bytes = 16 / 32 + assert a_int16_view.shape == (100 * 2,) + + +def test_array_view_invalid(): + a = np.ones(100 * 4).astype("uint8") + with pytest.raises(TypeError): + m.array_view(a, "deadly_dtype") + + def test_reshape_initializer_list(): a = np.arange(2 * 7 * 3) + 1 x = m.reshape_initializer_list(a, 2, 7, 3)