mirror of
https://github.com/pybind/pybind11.git
synced 2025-02-16 13:47:53 +00:00
Add npy_format_descriptor<PyObject *>
to enable py::array_t<PyObject *>
to/from-python conversions.
This commit is contained in:
parent
d72ffb448c
commit
5168c135ae
@ -586,6 +586,16 @@ public:
|
||||
return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
|
||||
}
|
||||
|
||||
/// Return dtype for the given typenum (one of the NPY_TYPES).
|
||||
/// https://numpy.org/devdocs/reference/c-api/array.html#c.PyArray_DescrFromType
|
||||
static dtype from_typenum(int typenum) {
|
||||
auto *ptr = detail::npy_api::get().PyArray_DescrFromType_(typenum);
|
||||
if (!ptr) {
|
||||
throw error_already_set();
|
||||
}
|
||||
return reinterpret_steal<dtype>(ptr);
|
||||
}
|
||||
|
||||
/// Size of the data type in bytes.
|
||||
ssize_t itemsize() const { return detail::array_descriptor_proxy(m_ptr)->elsize; }
|
||||
|
||||
@ -1283,12 +1293,16 @@ private:
|
||||
public:
|
||||
static constexpr int value = values[detail::is_fmt_numeric<T>::index];
|
||||
|
||||
static pybind11::dtype dtype() {
|
||||
if (auto *ptr = npy_api::get().PyArray_DescrFromType_(value)) {
|
||||
return reinterpret_steal<pybind11::dtype>(ptr);
|
||||
}
|
||||
pybind11_fail("Unsupported buffer format!");
|
||||
}
|
||||
static pybind11::dtype dtype() { return pybind11::dtype::from_typenum(value); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct npy_format_descriptor<T, enable_if_t<is_same_ignoring_cvref<T, PyObject *>::value>> {
|
||||
static constexpr auto name = const_name("object");
|
||||
|
||||
static constexpr int value = npy_api::NPY_OBJECT_;
|
||||
|
||||
static pybind11::dtype dtype() { return pybind11::dtype::from_typenum(value); }
|
||||
};
|
||||
|
||||
#define PYBIND11_DECL_CHAR_FMT \
|
||||
|
@ -523,4 +523,30 @@ TEST_SUBMODULE(numpy_array, sm) {
|
||||
sm.def("test_fmt_desc_const_double", [](const py::array_t<const double> &) {});
|
||||
|
||||
sm.def("round_trip_float", [](double d) { return d; });
|
||||
|
||||
sm.def("pass_array_pyobject_ptr_return_sum_str_values",
|
||||
[](const py::array_t<PyObject *> &objs) {
|
||||
std::string sum_str_values;
|
||||
for (auto &obj : objs) {
|
||||
sum_str_values += py::str(obj.attr("value"));
|
||||
}
|
||||
return sum_str_values;
|
||||
});
|
||||
|
||||
sm.def("pass_array_pyobject_ptr_return_as_list",
|
||||
[](const py::array_t<PyObject *> &objs) -> py::list { return objs; });
|
||||
|
||||
sm.def("return_array_pyobject_ptr_cpp_loop", [](const py::list &objs) {
|
||||
py::size_t arr_size = py::len(objs);
|
||||
py::array_t<PyObject *> arr_from_list(static_cast<py::ssize_t>(arr_size));
|
||||
PyObject **data = arr_from_list.mutable_data();
|
||||
for (py::size_t i = 0; i < arr_size; i++) {
|
||||
assert(data[i] == nullptr);
|
||||
data[i] = py::cast<PyObject *>(objs[i].attr("value"));
|
||||
}
|
||||
return arr_from_list;
|
||||
});
|
||||
|
||||
sm.def("return_array_pyobject_ptr_from_list",
|
||||
[](const py::list &objs) -> py::array_t<PyObject *> { return objs; });
|
||||
}
|
||||
|
@ -595,3 +595,61 @@ def test_round_trip_float():
|
||||
arr = np.zeros((), np.float64)
|
||||
arr[()] = 37.2
|
||||
assert m.round_trip_float(arr) == 37.2
|
||||
|
||||
|
||||
# For use as a temporary user-defined object, to maximize sensitivity of the tests below.
|
||||
class PyValueHolder:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
|
||||
def WrapWithPyValueHolder(*values):
|
||||
return [PyValueHolder(v) for v in values]
|
||||
|
||||
|
||||
def UnwrapPyValueHolder(vhs):
|
||||
return [vh.value for vh in vhs]
|
||||
|
||||
|
||||
def test_pass_array_pyobject_ptr_return_sum_str_values_ndarray():
|
||||
# Intentionally all temporaries, do not change.
|
||||
assert (
|
||||
m.pass_array_pyobject_ptr_return_sum_str_values(
|
||||
np.array(WrapWithPyValueHolder(-3, "four", 5.0), dtype=object)
|
||||
)
|
||||
== "-3four5.0"
|
||||
)
|
||||
|
||||
|
||||
def test_pass_array_pyobject_ptr_return_sum_str_values_list():
|
||||
# Intentionally all temporaries, do not change.
|
||||
assert (
|
||||
m.pass_array_pyobject_ptr_return_sum_str_values(
|
||||
WrapWithPyValueHolder(2, "three", -4.0)
|
||||
)
|
||||
== "2three-4.0"
|
||||
)
|
||||
|
||||
|
||||
def test_pass_array_pyobject_ptr_return_as_list():
|
||||
# Intentionally all temporaries, do not change.
|
||||
assert UnwrapPyValueHolder(
|
||||
m.pass_array_pyobject_ptr_return_as_list(
|
||||
np.array(WrapWithPyValueHolder(-1, "two", 3.0), dtype=object)
|
||||
)
|
||||
) == [-1, "two", 3.0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("return_array_pyobject_ptr", "unwrap"),
|
||||
[
|
||||
(m.return_array_pyobject_ptr_cpp_loop, list),
|
||||
(m.return_array_pyobject_ptr_from_list, UnwrapPyValueHolder),
|
||||
],
|
||||
)
|
||||
def test_return_array_pyobject_ptr_cpp_loop(return_array_pyobject_ptr, unwrap):
|
||||
# Intentionally all temporaries, do not change.
|
||||
arr_from_list = return_array_pyobject_ptr(WrapWithPyValueHolder(6, "seven", -8.0))
|
||||
assert isinstance(arr_from_list, np.ndarray)
|
||||
assert arr_from_list.dtype == np.dtype("O")
|
||||
assert unwrap(arr_from_list) == [6, "seven", -8.0]
|
||||
|
Loading…
Reference in New Issue
Block a user