From 5168c135aec2668701b483789a32791947f9ac45 Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Wed, 17 May 2023 09:14:15 -0700 Subject: [PATCH] Add `npy_format_descriptor` to enable `py::array_t` to/from-python conversions. --- include/pybind11/numpy.h | 26 +++++++++++++---- tests/test_numpy_array.cpp | 26 +++++++++++++++++ tests/test_numpy_array.py | 58 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 6 deletions(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 854d6e87f..1a31fc292 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -586,6 +586,16 @@ public: return detail::npy_format_descriptor::type>::dtype(); } + /// Return dtype for the given typenum (one of the NPY_TYPES). + /// https://numpy.org/devdocs/reference/c-api/array.html#c.PyArray_DescrFromType + static dtype from_typenum(int typenum) { + auto *ptr = detail::npy_api::get().PyArray_DescrFromType_(typenum); + if (!ptr) { + throw error_already_set(); + } + return reinterpret_steal(ptr); + } + /// Size of the data type in bytes. ssize_t itemsize() const { return detail::array_descriptor_proxy(m_ptr)->elsize; } @@ -1283,12 +1293,16 @@ private: public: static constexpr int value = values[detail::is_fmt_numeric::index]; - static pybind11::dtype dtype() { - if (auto *ptr = npy_api::get().PyArray_DescrFromType_(value)) { - return reinterpret_steal(ptr); - } - pybind11_fail("Unsupported buffer format!"); - } + static pybind11::dtype dtype() { return pybind11::dtype::from_typenum(value); } +}; + +template +struct npy_format_descriptor::value>> { + static constexpr auto name = const_name("object"); + + static constexpr int value = npy_api::NPY_OBJECT_; + + static pybind11::dtype dtype() { return pybind11::dtype::from_typenum(value); } }; #define PYBIND11_DECL_CHAR_FMT \ diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index b118e2c6c..ef907dad6 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -523,4 +523,30 @@ TEST_SUBMODULE(numpy_array, sm) { sm.def("test_fmt_desc_const_double", [](const py::array_t &) {}); sm.def("round_trip_float", [](double d) { return d; }); + + sm.def("pass_array_pyobject_ptr_return_sum_str_values", + [](const py::array_t &objs) { + std::string sum_str_values; + for (auto &obj : objs) { + sum_str_values += py::str(obj.attr("value")); + } + return sum_str_values; + }); + + sm.def("pass_array_pyobject_ptr_return_as_list", + [](const py::array_t &objs) -> py::list { return objs; }); + + sm.def("return_array_pyobject_ptr_cpp_loop", [](const py::list &objs) { + py::size_t arr_size = py::len(objs); + py::array_t arr_from_list(static_cast(arr_size)); + PyObject **data = arr_from_list.mutable_data(); + for (py::size_t i = 0; i < arr_size; i++) { + assert(data[i] == nullptr); + data[i] = py::cast(objs[i].attr("value")); + } + return arr_from_list; + }); + + sm.def("return_array_pyobject_ptr_from_list", + [](const py::list &objs) -> py::array_t { return objs; }); } diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index 070813d3a..f9c4a10fc 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -595,3 +595,61 @@ def test_round_trip_float(): arr = np.zeros((), np.float64) arr[()] = 37.2 assert m.round_trip_float(arr) == 37.2 + + +# For use as a temporary user-defined object, to maximize sensitivity of the tests below. +class PyValueHolder: + def __init__(self, value): + self.value = value + + +def WrapWithPyValueHolder(*values): + return [PyValueHolder(v) for v in values] + + +def UnwrapPyValueHolder(vhs): + return [vh.value for vh in vhs] + + +def test_pass_array_pyobject_ptr_return_sum_str_values_ndarray(): + # Intentionally all temporaries, do not change. + assert ( + m.pass_array_pyobject_ptr_return_sum_str_values( + np.array(WrapWithPyValueHolder(-3, "four", 5.0), dtype=object) + ) + == "-3four5.0" + ) + + +def test_pass_array_pyobject_ptr_return_sum_str_values_list(): + # Intentionally all temporaries, do not change. + assert ( + m.pass_array_pyobject_ptr_return_sum_str_values( + WrapWithPyValueHolder(2, "three", -4.0) + ) + == "2three-4.0" + ) + + +def test_pass_array_pyobject_ptr_return_as_list(): + # Intentionally all temporaries, do not change. + assert UnwrapPyValueHolder( + m.pass_array_pyobject_ptr_return_as_list( + np.array(WrapWithPyValueHolder(-1, "two", 3.0), dtype=object) + ) + ) == [-1, "two", 3.0] + + +@pytest.mark.parametrize( + ("return_array_pyobject_ptr", "unwrap"), + [ + (m.return_array_pyobject_ptr_cpp_loop, list), + (m.return_array_pyobject_ptr_from_list, UnwrapPyValueHolder), + ], +) +def test_return_array_pyobject_ptr_cpp_loop(return_array_pyobject_ptr, unwrap): + # Intentionally all temporaries, do not change. + arr_from_list = return_array_pyobject_ptr(WrapWithPyValueHolder(6, "seven", -8.0)) + assert isinstance(arr_from_list, np.ndarray) + assert arr_from_list.dtype == np.dtype("O") + assert unwrap(arr_from_list) == [6, "seven", -8.0]