mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-21 12:45:11 +00:00
Add support for array_t<handle> and array_t<object> (#5427)
* Add support for array_t<handle> and array_t<object> * style: pre-commit fixes * Remove loops that aren't strictly needed * Fix compiler warning * Disable GC-dependent checks when running on pypy or graalpy * style: pre-commit fixes * Remove PyValueHolder counter again * Move tests to templates to avoid code duplication * Rerun pre-commit * Restore import that was erroneously removed by pre-commit * Reduce code duplication with more template magic * Bring back `.attr("value")` in `return_array_cpp_loop()` This was meant to further stress-test correctness of refcount handling. All modified test functions were manually leak-checked (`while True`, top command, Python 3.12.3, Ubuntu 24.01, gcc 13.2.0). --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ralf W. Grosse-Kunstleve <rgrossekunst@nvidia.com>
This commit is contained in:
parent
08095d9c70
commit
b9fb3168ab
@ -1428,7 +1428,11 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct npy_format_descriptor<T, enable_if_t<is_same_ignoring_cvref<T, PyObject *>::value>> {
|
struct npy_format_descriptor<
|
||||||
|
T,
|
||||||
|
enable_if_t<is_same_ignoring_cvref<T, PyObject *>::value
|
||||||
|
|| ((std::is_same<T, handle>::value || std::is_same<T, object>::value)
|
||||||
|
&& sizeof(T) == sizeof(PyObject *))>> {
|
||||||
static constexpr auto name = const_name("object");
|
static constexpr auto name = const_name("object");
|
||||||
|
|
||||||
static constexpr int value = npy_api::NPY_OBJECT_;
|
static constexpr int value = npy_api::NPY_OBJECT_;
|
||||||
|
@ -156,6 +156,55 @@ py::handle auxiliaries(T &&r, T2 &&r2) {
|
|||||||
return l.release();
|
return l.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename PyObjectType>
|
||||||
|
PyObjectType convert_to_pyobjecttype(py::object obj);
|
||||||
|
|
||||||
|
template <>
|
||||||
|
PyObject *convert_to_pyobjecttype<PyObject *>(py::object obj) {
|
||||||
|
return obj.release().ptr();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
py::handle convert_to_pyobjecttype<py::handle>(py::object obj) {
|
||||||
|
return obj.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
py::object convert_to_pyobjecttype<py::object>(py::object obj) {
|
||||||
|
return obj;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename PyObjectType>
|
||||||
|
std::string pass_array_return_sum_str_values(const py::array_t<PyObjectType> &objs) {
|
||||||
|
std::string sum_str_values;
|
||||||
|
for (const auto &obj : objs) {
|
||||||
|
sum_str_values += py::str(obj.attr("value"));
|
||||||
|
}
|
||||||
|
return sum_str_values;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename PyObjectType>
|
||||||
|
py::list pass_array_return_as_list(const py::array_t<PyObjectType> &objs) {
|
||||||
|
return objs;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename PyObjectType>
|
||||||
|
py::array_t<PyObjectType> return_array_cpp_loop(const py::list &objs) {
|
||||||
|
py::size_t arr_size = py::len(objs);
|
||||||
|
py::array_t<PyObjectType> arr_from_list(static_cast<py::ssize_t>(arr_size));
|
||||||
|
PyObjectType *data = arr_from_list.mutable_data();
|
||||||
|
for (py::size_t i = 0; i < arr_size; i++) {
|
||||||
|
assert(!data[i]);
|
||||||
|
data[i] = convert_to_pyobjecttype<PyObjectType>(objs[i].attr("value"));
|
||||||
|
}
|
||||||
|
return arr_from_list;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename PyObjectType>
|
||||||
|
py::array_t<PyObjectType> return_array_from_list(const py::list &objs) {
|
||||||
|
return objs;
|
||||||
|
}
|
||||||
|
|
||||||
// note: declaration at local scope would create a dangling reference!
|
// note: declaration at local scope would create a dangling reference!
|
||||||
static int data_i = 42;
|
static int data_i = 42;
|
||||||
|
|
||||||
@ -520,28 +569,21 @@ TEST_SUBMODULE(numpy_array, sm) {
|
|||||||
sm.def("round_trip_float", [](double d) { return d; });
|
sm.def("round_trip_float", [](double d) { return d; });
|
||||||
|
|
||||||
sm.def("pass_array_pyobject_ptr_return_sum_str_values",
|
sm.def("pass_array_pyobject_ptr_return_sum_str_values",
|
||||||
[](const py::array_t<PyObject *> &objs) {
|
pass_array_return_sum_str_values<PyObject *>);
|
||||||
std::string sum_str_values;
|
sm.def("pass_array_handle_return_sum_str_values",
|
||||||
for (const auto &obj : objs) {
|
pass_array_return_sum_str_values<py::handle>);
|
||||||
sum_str_values += py::str(obj.attr("value"));
|
sm.def("pass_array_object_return_sum_str_values",
|
||||||
}
|
pass_array_return_sum_str_values<py::object>);
|
||||||
return sum_str_values;
|
|
||||||
});
|
|
||||||
|
|
||||||
sm.def("pass_array_pyobject_ptr_return_as_list",
|
sm.def("pass_array_pyobject_ptr_return_as_list", pass_array_return_as_list<PyObject *>);
|
||||||
[](const py::array_t<PyObject *> &objs) -> py::list { return objs; });
|
sm.def("pass_array_handle_return_as_list", pass_array_return_as_list<py::handle>);
|
||||||
|
sm.def("pass_array_object_return_as_list", pass_array_return_as_list<py::object>);
|
||||||
|
|
||||||
sm.def("return_array_pyobject_ptr_cpp_loop", [](const py::list &objs) {
|
sm.def("return_array_pyobject_ptr_cpp_loop", return_array_cpp_loop<PyObject *>);
|
||||||
py::size_t arr_size = py::len(objs);
|
sm.def("return_array_handle_cpp_loop", return_array_cpp_loop<py::handle>);
|
||||||
py::array_t<PyObject *> arr_from_list(static_cast<py::ssize_t>(arr_size));
|
sm.def("return_array_object_cpp_loop", return_array_cpp_loop<py::object>);
|
||||||
PyObject **data = arr_from_list.mutable_data();
|
|
||||||
for (py::size_t i = 0; i < arr_size; i++) {
|
|
||||||
assert(data[i] == nullptr);
|
|
||||||
data[i] = py::cast<PyObject *>(objs[i].attr("value"));
|
|
||||||
}
|
|
||||||
return arr_from_list;
|
|
||||||
});
|
|
||||||
|
|
||||||
sm.def("return_array_pyobject_ptr_from_list",
|
sm.def("return_array_pyobject_ptr_from_list", return_array_from_list<PyObject *>);
|
||||||
[](const py::list &objs) -> py::array_t<PyObject *> { return objs; });
|
sm.def("return_array_handle_from_list", return_array_from_list<py::handle>);
|
||||||
|
sm.def("return_array_object_from_list", return_array_from_list<py::object>);
|
||||||
}
|
}
|
||||||
|
@ -629,45 +629,61 @@ def UnwrapPyValueHolder(vhs):
|
|||||||
return [vh.value for vh in vhs]
|
return [vh.value for vh in vhs]
|
||||||
|
|
||||||
|
|
||||||
def test_pass_array_pyobject_ptr_return_sum_str_values_ndarray():
|
PASS_ARRAY_PYOBJECT_RETURN_SUM_STR_VALUES_FUNCTIONS = [
|
||||||
|
m.pass_array_pyobject_ptr_return_sum_str_values,
|
||||||
|
m.pass_array_handle_return_sum_str_values,
|
||||||
|
m.pass_array_object_return_sum_str_values,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"pass_array", PASS_ARRAY_PYOBJECT_RETURN_SUM_STR_VALUES_FUNCTIONS
|
||||||
|
)
|
||||||
|
def test_pass_array_object_return_sum_str_values_ndarray(pass_array):
|
||||||
# Intentionally all temporaries, do not change.
|
# Intentionally all temporaries, do not change.
|
||||||
assert (
|
assert (
|
||||||
m.pass_array_pyobject_ptr_return_sum_str_values(
|
pass_array(np.array(WrapWithPyValueHolder(-3, "four", 5.0), dtype=object))
|
||||||
np.array(WrapWithPyValueHolder(-3, "four", 5.0), dtype=object)
|
|
||||||
)
|
|
||||||
== "-3four5.0"
|
== "-3four5.0"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_pass_array_pyobject_ptr_return_sum_str_values_list():
|
@pytest.mark.parametrize(
|
||||||
|
"pass_array", PASS_ARRAY_PYOBJECT_RETURN_SUM_STR_VALUES_FUNCTIONS
|
||||||
|
)
|
||||||
|
def test_pass_array_object_return_sum_str_values_list(pass_array):
|
||||||
# Intentionally all temporaries, do not change.
|
# Intentionally all temporaries, do not change.
|
||||||
assert (
|
assert pass_array(WrapWithPyValueHolder(2, "three", -4.0)) == "2three-4.0"
|
||||||
m.pass_array_pyobject_ptr_return_sum_str_values(
|
|
||||||
WrapWithPyValueHolder(2, "three", -4.0)
|
|
||||||
)
|
|
||||||
== "2three-4.0"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_pass_array_pyobject_ptr_return_as_list():
|
@pytest.mark.parametrize(
|
||||||
|
"pass_array",
|
||||||
|
[
|
||||||
|
m.pass_array_pyobject_ptr_return_as_list,
|
||||||
|
m.pass_array_handle_return_as_list,
|
||||||
|
m.pass_array_object_return_as_list,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_pass_array_object_return_as_list(pass_array):
|
||||||
# Intentionally all temporaries, do not change.
|
# Intentionally all temporaries, do not change.
|
||||||
assert UnwrapPyValueHolder(
|
assert UnwrapPyValueHolder(
|
||||||
m.pass_array_pyobject_ptr_return_as_list(
|
pass_array(np.array(WrapWithPyValueHolder(-1, "two", 3.0), dtype=object))
|
||||||
np.array(WrapWithPyValueHolder(-1, "two", 3.0), dtype=object)
|
|
||||||
)
|
|
||||||
) == [-1, "two", 3.0]
|
) == [-1, "two", 3.0]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("return_array_pyobject_ptr", "unwrap"),
|
("return_array", "unwrap"),
|
||||||
[
|
[
|
||||||
(m.return_array_pyobject_ptr_cpp_loop, list),
|
(m.return_array_pyobject_ptr_cpp_loop, list),
|
||||||
|
(m.return_array_handle_cpp_loop, list),
|
||||||
|
(m.return_array_object_cpp_loop, list),
|
||||||
(m.return_array_pyobject_ptr_from_list, UnwrapPyValueHolder),
|
(m.return_array_pyobject_ptr_from_list, UnwrapPyValueHolder),
|
||||||
|
(m.return_array_handle_from_list, UnwrapPyValueHolder),
|
||||||
|
(m.return_array_object_from_list, UnwrapPyValueHolder),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_return_array_pyobject_ptr_cpp_loop(return_array_pyobject_ptr, unwrap):
|
def test_return_array_object_cpp_loop(return_array, unwrap):
|
||||||
# Intentionally all temporaries, do not change.
|
# Intentionally all temporaries, do not change.
|
||||||
arr_from_list = return_array_pyobject_ptr(WrapWithPyValueHolder(6, "seven", -8.0))
|
arr_from_list = return_array(WrapWithPyValueHolder(6, "seven", -8.0))
|
||||||
assert isinstance(arr_from_list, np.ndarray)
|
assert isinstance(arr_from_list, np.ndarray)
|
||||||
assert arr_from_list.dtype == np.dtype("O")
|
assert arr_from_list.dtype == np.dtype("O")
|
||||||
assert unwrap(arr_from_list) == [6, "seven", -8.0]
|
assert unwrap(arr_from_list) == [6, "seven", -8.0]
|
||||||
|
Loading…
Reference in New Issue
Block a user