From 1dc960c37f7d8c46016700992bf279592adbe6c8 Mon Sep 17 00:00:00 2001 From: Johan Mabille Date: Thu, 11 Feb 2016 10:47:11 +0100 Subject: [PATCH] NumPy-style broadcasting support in pybind11::vectorize --- example/example10.py | 5 + example/example10.ref | 40 ++++++++ include/pybind11/numpy.h | 204 +++++++++++++++++++++++++++++++++++---- setup.py | 4 +- 4 files changed, 232 insertions(+), 21 deletions(-) diff --git a/example/example10.py b/example/example10.py index 4b01f81f5..0d49fcaa7 100755 --- a/example/example10.py +++ b/example/example10.py @@ -22,3 +22,8 @@ for f in [vectorized_func, vectorized_func2]: print(f(np.array([1, 3]), np.array([2, 4]), 3)) print(f(np.array([[1, 3, 5], [7, 9, 11]]), np.array([[2, 4, 6], [8, 10, 12]]), 3)) print(np.array([[1, 3, 5], [7, 9, 11]])* np.array([[2, 4, 6], [8, 10, 12]])*3) + print(f(np.array([[1, 2, 3], [4, 5, 6]]), np.array([2, 3, 4]), 2)) + print(np.array([[1, 2, 3], [4, 5, 6]])* np.array([2, 3, 4])* 2) + print(f(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[2], [3]]), 2)) + print(np.array([[1, 2, 3], [4, 5, 6]])* np.array([[2], [3]])* 2) + diff --git a/example/example10.ref b/example/example10.ref index 517bb9084..9d48d7cfd 100644 --- a/example/example10.ref +++ b/example/example10.ref @@ -16,6 +16,26 @@ my_func(x:int=11, y:float=12, z:float=3) [ 168. 270. 396.]] [[ 6 36 90] [168 270 396]] +my_func(x:int=1, y:float=2, z:float=2) +my_func(x:int=2, y:float=3, z:float=2) +my_func(x:int=3, y:float=4, z:float=2) +my_func(x:int=4, y:float=2, z:float=2) +my_func(x:int=5, y:float=3, z:float=2) +my_func(x:int=6, y:float=4, z:float=2) +[[ 4. 12. 24.] + [ 16. 30. 48.]] +[[ 4 12 24] + [16 30 48]] +my_func(x:int=1, y:float=2, z:float=2) +my_func(x:int=2, y:float=2, z:float=2) +my_func(x:int=3, y:float=2, z:float=2) +my_func(x:int=4, y:float=3, z:float=2) +my_func(x:int=5, y:float=3, z:float=2) +my_func(x:int=6, y:float=3, z:float=2) +[[ 4. 8. 12.] + [ 24. 30. 36.]] +[[ 4 8 12] + [24 30 36]] my_func(x:int=1, y:float=2, z:float=3) 6.0 my_func(x:int=1, y:float=2, z:float=3) @@ -33,3 +53,23 @@ my_func(x:int=11, y:float=12, z:float=3) [ 168. 270. 396.]] [[ 6 36 90] [168 270 396]] +my_func(x:int=1, y:float=2, z:float=2) +my_func(x:int=2, y:float=3, z:float=2) +my_func(x:int=3, y:float=4, z:float=2) +my_func(x:int=4, y:float=2, z:float=2) +my_func(x:int=5, y:float=3, z:float=2) +my_func(x:int=6, y:float=4, z:float=2) +[[ 4. 12. 24.] + [ 16. 30. 48.]] +[[ 4 12 24] + [16 30 48]] +my_func(x:int=1, y:float=2, z:float=2) +my_func(x:int=2, y:float=2, z:float=2) +my_func(x:int=3, y:float=2, z:float=2) +my_func(x:int=4, y:float=3, z:float=2) +my_func(x:int=5, y:float=3, z:float=2) +my_func(x:int=6, y:float=3, z:float=2) +[[ 4. 8. 12.] + [ 24. 30. 36.]] +[[ 4 8 12] + [24 30 36]] diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 0aaac96ef..44a0cb1c4 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -11,6 +11,8 @@ #include "pybind11.h" #include "complex.h" +#include +#include #if defined(_MSC_VER) #pragma warning(push) @@ -146,10 +148,158 @@ DECL_FMT(std::complex, NPY_CDOUBLE_); NAMESPACE_BEGIN(detail) +template +using array_iterator = typename std::add_pointer::type; + +template +array_iterator array_begin(const buffer_info& buffer) { + return array_iterator(reinterpret_cast(buffer.ptr)); +} + +template +array_iterator array_end(const buffer_info& buffer) { + return array_iterator(reinterpret_cast(buffer.ptr) + buffer.size); +} + +class common_iterator { + +public: + + using container_type = std::vector; + using value_type = container_type::value_type; + using size_type = container_type::size_type; + + common_iterator() : p_ptr(0), m_strides() {} + common_iterator(void* ptr, const container_type& strides, const std::vector& shape) + : p_ptr(reinterpret_cast(ptr)), m_strides(strides.size()) { + m_strides.back() = static_cast(strides.back()); + for (size_type i = m_strides.size() - 1; i != 0; --i) { + size_type j = i - 1; + value_type s = static_cast(shape[i]); + m_strides[j] = strides[j] + m_strides[i] - strides[i] * s; + } + } + + void increment(size_type dim) { + p_ptr += m_strides[dim]; + } + + void* data() const { + return p_ptr; + } + +private: + + char* p_ptr; + container_type m_strides; +}; + +template +class multi_array_iterator { + +public: + + using container_type = std::vector; + + multi_array_iterator(const std::array& buffers, + const std::vector& shape) + : m_shape(shape.size()), m_index(shape.size(), 0), m_common_iterator() { + // Manual copy to avoid conversion warning if using std::copy + for (size_t i = 0; i < shape.size(); ++i) { + m_shape[i] = static_cast(shape[i]); + } + + container_type strides(shape.size()); + for (size_t i = 0; i < N; ++i) { + init_common_iterator(buffers[i], shape, m_common_iterator[i], strides); + } + } + + multi_array_iterator& operator++() { + for (size_t j = m_index.size(); j != 0; --j) { + size_t i = j - 1; + if (++m_index[i] != m_shape[i]) { + increment_common_iterator(i); + break; + } + else { + m_index[i] = 0; + } + } + return *this; + } + + template + const T& data() const { + return *reinterpret_cast(m_common_iterator[K].data()); + } + +private: + + using common_iter = common_iterator; + + void init_common_iterator(const buffer_info& buffer, const std::vector& shape, common_iter& iterator, container_type& strides) { + auto buffer_shape_iter = buffer.shape.rbegin(); + auto buffer_strides_iter = buffer.strides.rbegin(); + auto shape_iter = shape.rbegin(); + auto strides_iter = strides.rbegin(); + + while (buffer_shape_iter != buffer.shape.rend()) { + if (*shape_iter == *buffer_shape_iter) + *strides_iter = static_cast(*buffer_strides_iter); + else + *strides_iter = 0; + + ++buffer_shape_iter; + ++buffer_strides_iter; + ++shape_iter; + ++strides_iter; + } + + std::fill(strides_iter, strides.rend(), 0); + iterator = common_iter(buffer.ptr, strides, shape); + } + + void increment_common_iterator(size_t dim) { + std::for_each(m_common_iterator.begin(), m_common_iterator.end(), [=](common_iter& iter) { + iter.increment(dim); + }); + } + + container_type m_shape; + container_type m_index; + std::array m_common_iterator; +}; + template struct handle_type_name> { static PYBIND11_DESCR name() { return _("array[") + type_caster::name() + _("]"); } }; +template +bool broadcast(const std::array& buffers, int& ndim, std::vector& shape) { + ndim = std::accumulate(buffers.begin(), buffers.end(), 0, [](int res, const buffer_info& buf) { + return std::max(res, buf.ndim); + }); + + shape = std::vector(static_cast(ndim), 1); + bool trivial_broadcast = true; + for (size_t i = 0; i < N; ++i) { + auto res_iter = shape.rbegin(); + bool i_trivial_broadcast = (buffers[i].size == 1) || (buffers[i].ndim == ndim); + for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != buffers[i].shape.rend(); ++shape_iter, ++res_iter) { + if (*res_iter == 1) { + *res_iter = *shape_iter; + } + else if ((*shape_iter != 1) && (*res_iter != *shape_iter)) { + pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!"); + } + i_trivial_broadcast = i_trivial_broadcast && (*res_iter == *shape_iter); + } + trivial_broadcast = trivial_broadcast && i_trivial_broadcast; + } + return trivial_broadcast; +} + template struct vectorize_helper { typename std::remove_reference::type f; @@ -161,33 +311,28 @@ struct vectorize_helper { return run(args..., typename make_index_sequence::type()); } - template object run(array_t&... args, index_sequence) { + template object run(array_t&... args, index_sequence index) { /* Request buffers from all parameters */ const size_t N = sizeof...(Args); + std::array buffers {{ args.request()... }}; /* Determine dimensions parameters of output array */ - int ndim = 0; size_t size = 0; - std::vector shape; - for (size_t i=0; i size) { - ndim = buffers[i].ndim; - shape = buffers[i].shape; - size = buffers[i].size; - } - } + int ndim = 0; + std::vector shape(0); + bool trivial_broadcast = broadcast(buffers, ndim, shape); + + size_t size = 1; std::vector strides(ndim); if (ndim > 0) { strides[ndim-1] = sizeof(Return); - for (int i=ndim-1; i>0; --i) - strides[i-1] = strides[i] * shape[i]; + for (int i = ndim - 1; i > 0; --i) { + strides[i - 1] = strides[i] * shape[i]; + size *= shape[i]; + } + size *= shape[0]; } - /* Check if the parameters are actually compatible */ - for (size_t i=0; i(buffers, buf, index); + } return result; } + + template + void apply_broadcast(const std::array& buffers, buffer_info& output, index_sequence) { + using input_iterator = multi_array_iterator; + using output_iterator = array_iterator; + + input_iterator input_iter(buffers, output.shape); + output_iterator output_end = array_end(output); + + for (output_iterator iter = array_begin(output); iter != output_end; ++iter, ++input_iter) { + *iter = f((input_iter.template data())...); + } + } }; NAMESPACE_END(detail) diff --git a/setup.py b/setup.py index ae1ca8bf2..8027088a5 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,9 @@ setup( 'include/pybind11/functional.h', 'include/pybind11/operators.h', 'include/pybind11/pytypes.h', - 'include/pybind11/typeid.h' + 'include/pybind11/typeid.h', + 'include/pybind11/short_vector.h', + 'include/pybind11/array_iterator.h' ], classifiers=[ 'Development Status :: 5 - Production/Stable',