feat: vectorize functions with void return type (#1969)

* Allow function/functor passed to py::vectorize to return void

* Stealing @sizmailov's test and fixing unused argument warning

* Add missing std::move()

RVO doesn't work here because function return type is different from
actual returned type

* remove extra EOL

* docs: add a few details

* chore: pre-commit autoupdate

* Remove array_iterator, array_begin, and array_end (in detail namespace)

Co-authored-by: Sergei Izmailov <sergei.a.izmailov@gmail.com>
Co-authored-by: Henry Schreiner <henryschreineriii@gmail.com>
This commit is contained in:
Yannick Jadoul 2020-10-02 21:30:34 +02:00 committed by GitHub
parent 56784c4f42
commit 9796fe98fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 90 additions and 35 deletions

View File

@ -98,6 +98,9 @@ See :ref:`upgrade-guide-2.6` for help upgrading to the new version.
``get_type_overload`` is deprecated. ``get_type_overload`` is deprecated.
`#2325 <https://github.com/pybind/pybind11/pull/2325>`_ `#2325 <https://github.com/pybind/pybind11/pull/2325>`_
* Error now thrown when ``__init__`` is forgotten on subclasses.
`#2152 <https://github.com/pybind/pybind11/pull/2152>`_
* `py::class_<union_type>` is now supported. Note that writing to one data * `py::class_<union_type>` is now supported. Note that writing to one data
member of the union and reading another (type punning) is UB in C++. Thus member of the union and reading another (type punning) is UB in C++. Thus
pybind11-bound enums should never be used for such conversion. pybind11-bound enums should never be used for such conversion.
@ -109,9 +112,6 @@ Smaller or developer focused features:
.. _pybind11-mkdoc: https://github.com/pybind/pybind11-mkdoc .. _pybind11-mkdoc: https://github.com/pybind/pybind11-mkdoc
* Error now thrown when ``__init__`` is forgotten on subclasses.
`#2152 <https://github.com/pybind/pybind11/pull/2152>`_
* If ``__eq__`` defined but not ``__hash__``, ``__hash__`` is now set to * If ``__eq__`` defined but not ``__hash__``, ``__hash__`` is now set to
``None``. ``None``.
`#2291 <https://github.com/pybind/pybind11/pull/2291>`_ `#2291 <https://github.com/pybind/pybind11/pull/2291>`_
@ -122,9 +122,6 @@ Smaller or developer focused features:
* Throw if conversion to ``str`` fails. * Throw if conversion to ``str`` fails.
`#2477 <https://github.com/pybind/pybind11/pull/2477>`_ `#2477 <https://github.com/pybind/pybind11/pull/2477>`_
* Added missing signature for ``py::array``.
`#2363 <https://github.com/pybind/pybind11/pull/2363>`_
* Pointer to ``std::tuple`` & ``std::pair`` supported in cast. * Pointer to ``std::tuple`` & ``std::pair`` supported in cast.
`#2334 <https://github.com/pybind/pybind11/pull/2334>`_ `#2334 <https://github.com/pybind/pybind11/pull/2334>`_
@ -132,7 +129,13 @@ Smaller or developer focused features:
argument type. argument type.
`#2293 <https://github.com/pybind/pybind11/pull/2293>`_ `#2293 <https://github.com/pybind/pybind11/pull/2293>`_
* Bugfixes related to more extensive testing * Added missing signature for ``py::array``.
`#2363 <https://github.com/pybind/pybind11/pull/2363>`_
* ``py::vectorize`` is now supported on functions that return void.
`#1969 <https://github.com/pybind/pybind11/pull/1969>`_
* Bugfixes related to more extensive testing.
`#2321 <https://github.com/pybind/pybind11/pull/2321>`_ `#2321 <https://github.com/pybind/pybind11/pull/2321>`_
* Bug in timezone issue in Eastern hemisphere midnight fixed. * Bug in timezone issue in Eastern hemisphere midnight fixed.

View File

@ -1274,19 +1274,6 @@ private:
#endif // __CLION_IDE__ #endif // __CLION_IDE__
template <class T>
using array_iterator = typename std::add_pointer<T>::type;
template <class T>
array_iterator<T> array_begin(const buffer_info& buffer) {
return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr));
}
template <class T>
array_iterator<T> array_end(const buffer_info& buffer) {
return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr) + buffer.size);
}
class common_iterator { class common_iterator {
public: public:
using container_type = std::vector<ssize_t>; using container_type = std::vector<ssize_t>;
@ -1486,6 +1473,56 @@ struct vectorize_arg {
using type = conditional_t<vectorize, array_t<remove_cv_t<call_type>, array::forcecast>, T>; using type = conditional_t<vectorize, array_t<remove_cv_t<call_type>, array::forcecast>, T>;
}; };
// py::vectorize when a return type is present
template <typename Func, typename Return, typename... Args>
struct vectorize_returned_array {
using Type = array_t<Return>;
static Type create(broadcast_trivial trivial, const std::vector<ssize_t> &shape) {
if (trivial == broadcast_trivial::f_trivial)
return array_t<Return, array::f_style>(shape);
else
return array_t<Return>(shape);
}
static Return *mutable_data(Type &array) {
return array.mutable_data();
}
static Return call(Func &f, Args &... args) {
return f(args...);
}
static void call(Return *out, size_t i, Func &f, Args &... args) {
out[i] = f(args...);
}
};
// py::vectorize when a return type is not present
template <typename Func, typename... Args>
struct vectorize_returned_array<Func, void, Args...> {
using Type = none;
static Type create(broadcast_trivial, const std::vector<ssize_t> &) {
return none();
}
static void *mutable_data(Type &) {
return nullptr;
}
static detail::void_type call(Func &f, Args &... args) {
f(args...);
return {};
}
static void call(void *, size_t, Func &f, Args &... args) {
f(args...);
}
};
template <typename Func, typename Return, typename... Args> template <typename Func, typename Return, typename... Args>
struct vectorize_helper { struct vectorize_helper {
@ -1520,6 +1557,8 @@ private:
using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>; using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>;
template <size_t Index> using param_n_t = typename std::tuple_element<Index, arg_call_types>::type; template <size_t Index> using param_n_t = typename std::tuple_element<Index, arg_call_types>::type;
using returned_array = vectorize_returned_array<Func, Return, Args...>;
// Runs a vectorized function given arguments tuple and three index sequences: // Runs a vectorized function given arguments tuple and three index sequences:
// - Index is the full set of 0 ... (N-1) argument indices; // - Index is the full set of 0 ... (N-1) argument indices;
// - VIndex is the subset of argument indices with vectorized parameters, letting us access // - VIndex is the subset of argument indices with vectorized parameters, letting us access
@ -1551,20 +1590,19 @@ private:
// not wrapped in an array). // not wrapped in an array).
if (size == 1 && ndim == 0) { if (size == 1 && ndim == 0) {
PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr); PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr);
return cast(f(*reinterpret_cast<param_n_t<Index> *>(params[Index])...)); return cast(returned_array::call(f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...));
} }
array_t<Return> result; auto result = returned_array::create(trivial, shape);
if (trivial == broadcast_trivial::f_trivial) result = array_t<Return, array::f_style>(shape);
else result = array_t<Return>(shape);
if (size == 0) return std::move(result); if (size == 0) return std::move(result);
/* Call the function */ /* Call the function */
auto mutable_data = returned_array::mutable_data(result);
if (trivial == broadcast_trivial::non_trivial) if (trivial == broadcast_trivial::non_trivial)
apply_broadcast(buffers, params, result, i_seq, vi_seq, bi_seq); apply_broadcast(buffers, params, mutable_data, size, shape, i_seq, vi_seq, bi_seq);
else else
apply_trivial(buffers, params, result.mutable_data(), size, i_seq, vi_seq, bi_seq); apply_trivial(buffers, params, mutable_data, size, i_seq, vi_seq, bi_seq);
return std::move(result); return std::move(result);
} }
@ -1587,7 +1625,7 @@ private:
}}; }};
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
out[i] = f(*reinterpret_cast<param_n_t<Index> *>(params[Index])...); returned_array::call(out, i, f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...);
for (auto &x : vecparams) x.first += x.second; for (auto &x : vecparams) x.first += x.second;
} }
} }
@ -1595,19 +1633,18 @@ private:
template <size_t... Index, size_t... VIndex, size_t... BIndex> template <size_t... Index, size_t... VIndex, size_t... BIndex>
void apply_broadcast(std::array<buffer_info, NVectorized> &buffers, void apply_broadcast(std::array<buffer_info, NVectorized> &buffers,
std::array<void *, N> &params, std::array<void *, N> &params,
array_t<Return> &output_array, Return *out,
size_t size,
const std::vector<ssize_t> &output_shape,
index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) { index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) {
buffer_info output = output_array.request(); multi_array_iterator<NVectorized> input_iter(buffers, output_shape);
multi_array_iterator<NVectorized> input_iter(buffers, output.shape);
for (array_iterator<Return> iter = array_begin<Return>(output), end = array_end<Return>(output); for (size_t i = 0; i < size; ++i, ++input_iter) {
iter != end;
++iter, ++input_iter) {
PYBIND11_EXPAND_SIDE_EFFECTS(( PYBIND11_EXPAND_SIDE_EFFECTS((
params[VIndex] = input_iter.template data<BIndex>() params[VIndex] = input_iter.template data<BIndex>()
)); ));
*iter = f(*reinterpret_cast<param_n_t<Index> *>(std::get<Index>(params))...); returned_array::call(out, i, f, *reinterpret_cast<param_n_t<Index> *>(std::get<Index>(params))...);
} }
} }
}; };

View File

@ -50,7 +50,9 @@ TEST_SUBMODULE(numpy_vectorize, m) {
NonPODClass(int v) : value{v} {} NonPODClass(int v) : value{v} {}
int value; int value;
}; };
py::class_<NonPODClass>(m, "NonPODClass").def(py::init<int>()); py::class_<NonPODClass>(m, "NonPODClass")
.def(py::init<int>())
.def_readwrite("value", &NonPODClass::value);
m.def("vec_passthrough", py::vectorize( m.def("vec_passthrough", py::vectorize(
[](double *a, double b, py::array_t<double> c, const int &d, int &e, NonPODClass f, const double g) { [](double *a, double b, py::array_t<double> c, const int &d, int &e, NonPODClass f, const double g) {
return *a + b + c.at(0) + d + e + f.value + g; return *a + b + c.at(0) + d + e + f.value + g;
@ -86,4 +88,6 @@ TEST_SUBMODULE(numpy_vectorize, m) {
std::array<py::buffer_info, 3> buffers {{ arg1.request(), arg2.request(), arg3.request() }}; std::array<py::buffer_info, 3> buffers {{ arg1.request(), arg2.request(), arg3.request() }};
return py::detail::broadcast(buffers, ndim, shape); return py::detail::broadcast(buffers, ndim, shape);
}); });
m.def("add_to", py::vectorize([](NonPODClass& x, int a) { x.value += a; }));
} }

View File

@ -192,3 +192,14 @@ def test_array_collapse():
z = m.vectorized_func(1, [[[2]]], 3) z = m.vectorized_func(1, [[[2]]], 3)
assert isinstance(z, np.ndarray) assert isinstance(z, np.ndarray)
assert z.shape == (1, 1, 1) assert z.shape == (1, 1, 1)
def test_vectorized_noreturn():
x = m.NonPODClass(0)
assert x.value == 0
m.add_to(x, [1, 2, 3, 4])
assert x.value == 10
m.add_to(x, 1)
assert x.value == 11
m.add_to(x, [[1, 1], [2, 3]])
assert x.value == 18