mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-25 14:45:12 +00:00
feat: vectorize functions with void return type (#1969)
* Allow function/functor passed to py::vectorize to return void * Stealing @sizmailov's test and fixing unused argument warning * Add missing std::move() RVO doesn't work here because function return type is different from actual returned type * remove extra EOL * docs: add a few details * chore: pre-commit autoupdate * Remove array_iterator, array_begin, and array_end (in detail namespace) Co-authored-by: Sergei Izmailov <sergei.a.izmailov@gmail.com> Co-authored-by: Henry Schreiner <henryschreineriii@gmail.com>
This commit is contained in:
parent
56784c4f42
commit
9796fe98fc
@ -98,6 +98,9 @@ See :ref:`upgrade-guide-2.6` for help upgrading to the new version.
|
|||||||
``get_type_overload`` is deprecated.
|
``get_type_overload`` is deprecated.
|
||||||
`#2325 <https://github.com/pybind/pybind11/pull/2325>`_
|
`#2325 <https://github.com/pybind/pybind11/pull/2325>`_
|
||||||
|
|
||||||
|
* Error now thrown when ``__init__`` is forgotten on subclasses.
|
||||||
|
`#2152 <https://github.com/pybind/pybind11/pull/2152>`_
|
||||||
|
|
||||||
* `py::class_<union_type>` is now supported. Note that writing to one data
|
* `py::class_<union_type>` is now supported. Note that writing to one data
|
||||||
member of the union and reading another (type punning) is UB in C++. Thus
|
member of the union and reading another (type punning) is UB in C++. Thus
|
||||||
pybind11-bound enums should never be used for such conversion.
|
pybind11-bound enums should never be used for such conversion.
|
||||||
@ -109,9 +112,6 @@ Smaller or developer focused features:
|
|||||||
|
|
||||||
.. _pybind11-mkdoc: https://github.com/pybind/pybind11-mkdoc
|
.. _pybind11-mkdoc: https://github.com/pybind/pybind11-mkdoc
|
||||||
|
|
||||||
* Error now thrown when ``__init__`` is forgotten on subclasses.
|
|
||||||
`#2152 <https://github.com/pybind/pybind11/pull/2152>`_
|
|
||||||
|
|
||||||
* If ``__eq__`` defined but not ``__hash__``, ``__hash__`` is now set to
|
* If ``__eq__`` defined but not ``__hash__``, ``__hash__`` is now set to
|
||||||
``None``.
|
``None``.
|
||||||
`#2291 <https://github.com/pybind/pybind11/pull/2291>`_
|
`#2291 <https://github.com/pybind/pybind11/pull/2291>`_
|
||||||
@ -122,9 +122,6 @@ Smaller or developer focused features:
|
|||||||
* Throw if conversion to ``str`` fails.
|
* Throw if conversion to ``str`` fails.
|
||||||
`#2477 <https://github.com/pybind/pybind11/pull/2477>`_
|
`#2477 <https://github.com/pybind/pybind11/pull/2477>`_
|
||||||
|
|
||||||
* Added missing signature for ``py::array``.
|
|
||||||
`#2363 <https://github.com/pybind/pybind11/pull/2363>`_
|
|
||||||
|
|
||||||
* Pointer to ``std::tuple`` & ``std::pair`` supported in cast.
|
* Pointer to ``std::tuple`` & ``std::pair`` supported in cast.
|
||||||
`#2334 <https://github.com/pybind/pybind11/pull/2334>`_
|
`#2334 <https://github.com/pybind/pybind11/pull/2334>`_
|
||||||
|
|
||||||
@ -132,7 +129,13 @@ Smaller or developer focused features:
|
|||||||
argument type.
|
argument type.
|
||||||
`#2293 <https://github.com/pybind/pybind11/pull/2293>`_
|
`#2293 <https://github.com/pybind/pybind11/pull/2293>`_
|
||||||
|
|
||||||
* Bugfixes related to more extensive testing
|
* Added missing signature for ``py::array``.
|
||||||
|
`#2363 <https://github.com/pybind/pybind11/pull/2363>`_
|
||||||
|
|
||||||
|
* ``py::vectorize`` is now supported on functions that return void.
|
||||||
|
`#1969 <https://github.com/pybind/pybind11/pull/1969>`_
|
||||||
|
|
||||||
|
* Bugfixes related to more extensive testing.
|
||||||
`#2321 <https://github.com/pybind/pybind11/pull/2321>`_
|
`#2321 <https://github.com/pybind/pybind11/pull/2321>`_
|
||||||
|
|
||||||
* Bug in timezone issue in Eastern hemisphere midnight fixed.
|
* Bug in timezone issue in Eastern hemisphere midnight fixed.
|
||||||
|
@ -1274,19 +1274,6 @@ private:
|
|||||||
|
|
||||||
#endif // __CLION_IDE__
|
#endif // __CLION_IDE__
|
||||||
|
|
||||||
template <class T>
|
|
||||||
using array_iterator = typename std::add_pointer<T>::type;
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
array_iterator<T> array_begin(const buffer_info& buffer) {
|
|
||||||
return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
array_iterator<T> array_end(const buffer_info& buffer) {
|
|
||||||
return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr) + buffer.size);
|
|
||||||
}
|
|
||||||
|
|
||||||
class common_iterator {
|
class common_iterator {
|
||||||
public:
|
public:
|
||||||
using container_type = std::vector<ssize_t>;
|
using container_type = std::vector<ssize_t>;
|
||||||
@ -1486,6 +1473,56 @@ struct vectorize_arg {
|
|||||||
using type = conditional_t<vectorize, array_t<remove_cv_t<call_type>, array::forcecast>, T>;
|
using type = conditional_t<vectorize, array_t<remove_cv_t<call_type>, array::forcecast>, T>;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// py::vectorize when a return type is present
|
||||||
|
template <typename Func, typename Return, typename... Args>
|
||||||
|
struct vectorize_returned_array {
|
||||||
|
using Type = array_t<Return>;
|
||||||
|
|
||||||
|
static Type create(broadcast_trivial trivial, const std::vector<ssize_t> &shape) {
|
||||||
|
if (trivial == broadcast_trivial::f_trivial)
|
||||||
|
return array_t<Return, array::f_style>(shape);
|
||||||
|
else
|
||||||
|
return array_t<Return>(shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Return *mutable_data(Type &array) {
|
||||||
|
return array.mutable_data();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Return call(Func &f, Args &... args) {
|
||||||
|
return f(args...);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void call(Return *out, size_t i, Func &f, Args &... args) {
|
||||||
|
out[i] = f(args...);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// py::vectorize when a return type is not present
|
||||||
|
template <typename Func, typename... Args>
|
||||||
|
struct vectorize_returned_array<Func, void, Args...> {
|
||||||
|
using Type = none;
|
||||||
|
|
||||||
|
static Type create(broadcast_trivial, const std::vector<ssize_t> &) {
|
||||||
|
return none();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void *mutable_data(Type &) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
static detail::void_type call(Func &f, Args &... args) {
|
||||||
|
f(args...);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
static void call(void *, size_t, Func &f, Args &... args) {
|
||||||
|
f(args...);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
template <typename Func, typename Return, typename... Args>
|
template <typename Func, typename Return, typename... Args>
|
||||||
struct vectorize_helper {
|
struct vectorize_helper {
|
||||||
|
|
||||||
@ -1520,6 +1557,8 @@ private:
|
|||||||
using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>;
|
using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>;
|
||||||
template <size_t Index> using param_n_t = typename std::tuple_element<Index, arg_call_types>::type;
|
template <size_t Index> using param_n_t = typename std::tuple_element<Index, arg_call_types>::type;
|
||||||
|
|
||||||
|
using returned_array = vectorize_returned_array<Func, Return, Args...>;
|
||||||
|
|
||||||
// Runs a vectorized function given arguments tuple and three index sequences:
|
// Runs a vectorized function given arguments tuple and three index sequences:
|
||||||
// - Index is the full set of 0 ... (N-1) argument indices;
|
// - Index is the full set of 0 ... (N-1) argument indices;
|
||||||
// - VIndex is the subset of argument indices with vectorized parameters, letting us access
|
// - VIndex is the subset of argument indices with vectorized parameters, letting us access
|
||||||
@ -1551,20 +1590,19 @@ private:
|
|||||||
// not wrapped in an array).
|
// not wrapped in an array).
|
||||||
if (size == 1 && ndim == 0) {
|
if (size == 1 && ndim == 0) {
|
||||||
PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr);
|
PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr);
|
||||||
return cast(f(*reinterpret_cast<param_n_t<Index> *>(params[Index])...));
|
return cast(returned_array::call(f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...));
|
||||||
}
|
}
|
||||||
|
|
||||||
array_t<Return> result;
|
auto result = returned_array::create(trivial, shape);
|
||||||
if (trivial == broadcast_trivial::f_trivial) result = array_t<Return, array::f_style>(shape);
|
|
||||||
else result = array_t<Return>(shape);
|
|
||||||
|
|
||||||
if (size == 0) return std::move(result);
|
if (size == 0) return std::move(result);
|
||||||
|
|
||||||
/* Call the function */
|
/* Call the function */
|
||||||
|
auto mutable_data = returned_array::mutable_data(result);
|
||||||
if (trivial == broadcast_trivial::non_trivial)
|
if (trivial == broadcast_trivial::non_trivial)
|
||||||
apply_broadcast(buffers, params, result, i_seq, vi_seq, bi_seq);
|
apply_broadcast(buffers, params, mutable_data, size, shape, i_seq, vi_seq, bi_seq);
|
||||||
else
|
else
|
||||||
apply_trivial(buffers, params, result.mutable_data(), size, i_seq, vi_seq, bi_seq);
|
apply_trivial(buffers, params, mutable_data, size, i_seq, vi_seq, bi_seq);
|
||||||
|
|
||||||
return std::move(result);
|
return std::move(result);
|
||||||
}
|
}
|
||||||
@ -1587,7 +1625,7 @@ private:
|
|||||||
}};
|
}};
|
||||||
|
|
||||||
for (size_t i = 0; i < size; ++i) {
|
for (size_t i = 0; i < size; ++i) {
|
||||||
out[i] = f(*reinterpret_cast<param_n_t<Index> *>(params[Index])...);
|
returned_array::call(out, i, f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...);
|
||||||
for (auto &x : vecparams) x.first += x.second;
|
for (auto &x : vecparams) x.first += x.second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1595,19 +1633,18 @@ private:
|
|||||||
template <size_t... Index, size_t... VIndex, size_t... BIndex>
|
template <size_t... Index, size_t... VIndex, size_t... BIndex>
|
||||||
void apply_broadcast(std::array<buffer_info, NVectorized> &buffers,
|
void apply_broadcast(std::array<buffer_info, NVectorized> &buffers,
|
||||||
std::array<void *, N> ¶ms,
|
std::array<void *, N> ¶ms,
|
||||||
array_t<Return> &output_array,
|
Return *out,
|
||||||
|
size_t size,
|
||||||
|
const std::vector<ssize_t> &output_shape,
|
||||||
index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) {
|
index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) {
|
||||||
|
|
||||||
buffer_info output = output_array.request();
|
multi_array_iterator<NVectorized> input_iter(buffers, output_shape);
|
||||||
multi_array_iterator<NVectorized> input_iter(buffers, output.shape);
|
|
||||||
|
|
||||||
for (array_iterator<Return> iter = array_begin<Return>(output), end = array_end<Return>(output);
|
for (size_t i = 0; i < size; ++i, ++input_iter) {
|
||||||
iter != end;
|
|
||||||
++iter, ++input_iter) {
|
|
||||||
PYBIND11_EXPAND_SIDE_EFFECTS((
|
PYBIND11_EXPAND_SIDE_EFFECTS((
|
||||||
params[VIndex] = input_iter.template data<BIndex>()
|
params[VIndex] = input_iter.template data<BIndex>()
|
||||||
));
|
));
|
||||||
*iter = f(*reinterpret_cast<param_n_t<Index> *>(std::get<Index>(params))...);
|
returned_array::call(out, i, f, *reinterpret_cast<param_n_t<Index> *>(std::get<Index>(params))...);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -50,7 +50,9 @@ TEST_SUBMODULE(numpy_vectorize, m) {
|
|||||||
NonPODClass(int v) : value{v} {}
|
NonPODClass(int v) : value{v} {}
|
||||||
int value;
|
int value;
|
||||||
};
|
};
|
||||||
py::class_<NonPODClass>(m, "NonPODClass").def(py::init<int>());
|
py::class_<NonPODClass>(m, "NonPODClass")
|
||||||
|
.def(py::init<int>())
|
||||||
|
.def_readwrite("value", &NonPODClass::value);
|
||||||
m.def("vec_passthrough", py::vectorize(
|
m.def("vec_passthrough", py::vectorize(
|
||||||
[](double *a, double b, py::array_t<double> c, const int &d, int &e, NonPODClass f, const double g) {
|
[](double *a, double b, py::array_t<double> c, const int &d, int &e, NonPODClass f, const double g) {
|
||||||
return *a + b + c.at(0) + d + e + f.value + g;
|
return *a + b + c.at(0) + d + e + f.value + g;
|
||||||
@ -86,4 +88,6 @@ TEST_SUBMODULE(numpy_vectorize, m) {
|
|||||||
std::array<py::buffer_info, 3> buffers {{ arg1.request(), arg2.request(), arg3.request() }};
|
std::array<py::buffer_info, 3> buffers {{ arg1.request(), arg2.request(), arg3.request() }};
|
||||||
return py::detail::broadcast(buffers, ndim, shape);
|
return py::detail::broadcast(buffers, ndim, shape);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
m.def("add_to", py::vectorize([](NonPODClass& x, int a) { x.value += a; }));
|
||||||
}
|
}
|
||||||
|
@ -192,3 +192,14 @@ def test_array_collapse():
|
|||||||
z = m.vectorized_func(1, [[[2]]], 3)
|
z = m.vectorized_func(1, [[[2]]], 3)
|
||||||
assert isinstance(z, np.ndarray)
|
assert isinstance(z, np.ndarray)
|
||||||
assert z.shape == (1, 1, 1)
|
assert z.shape == (1, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_vectorized_noreturn():
|
||||||
|
x = m.NonPODClass(0)
|
||||||
|
assert x.value == 0
|
||||||
|
m.add_to(x, [1, 2, 3, 4])
|
||||||
|
assert x.value == 10
|
||||||
|
m.add_to(x, 1)
|
||||||
|
assert x.value == 11
|
||||||
|
m.add_to(x, [[1, 1], [2, 3]])
|
||||||
|
assert x.value == 18
|
||||||
|
Loading…
Reference in New Issue
Block a user