From f41e1125c6fc878a5974e9092a8a6d8a6c55d6cf Mon Sep 17 00:00:00 2001 From: Johan Mabille Date: Thu, 11 Feb 2016 18:33:52 +0100 Subject: [PATCH] Broadcasting like numpy --- include/pybind11/array_iterator.h | 108 ++++++++++++++++-------------- include/pybind11/numpy.h | 101 ++++++++++++++++++++++------ include/pybind11/short_vector.h | 4 +- 3 files changed, 138 insertions(+), 75 deletions(-) diff --git a/include/pybind11/array_iterator.h b/include/pybind11/array_iterator.h index bbf36910c..00a74ef60 100644 --- a/include/pybind11/array_iterator.h +++ b/include/pybind11/array_iterator.h @@ -16,29 +16,49 @@ #pragma warning(disable: 4127) // warning C4127: Conditional expression is constant #endif +/* + WARNING: These iterators are not a binding to numpy.nditer, there convenient classes for broadcasting in vectorize +*/ + NAMESPACE_BEGIN(pybind11) +template +using array_iterator = typename std::add_pointer::type; + +template +array_iterator array_begin(const buffer_info& buffer) +{ + return array_iterator(reinterpret_cast(buffer.ptr)); +} + +template +array_iterator array_end(const buffer_info& buffer) +{ + return array_iterator(reinterpret_cast(buffer.ptr) + buffer.size); +} + NAMESPACE_BEGIN(detail) -template +template class common_iterator { public: - common_iterator() : p_ptr(0), m_strides() {} - common_iterator(void* ptr, const S& strides, const std::vector& shape) - : p_ptr(reinterpret_cast(ptr), m_strides(strides.size()) - { - using value_type = typename S::value_type; - using size_type = typename S::size_type; + using container_type = C; + using value_type = typename container_type::value_type; + using size_type = typename container_type::size_type; + common_iterator() : p_ptr(0), m_strides() {} + common_iterator(void* ptr, const container_type& strides, const std::vector& shape) + : p_ptr(reinterpret_cast(ptr)), m_strides(strides.size()) + { m_strides.back() = static_cast(strides.back()); - for (size_type i = m_strides.size() - 1; --i; i != 0) + for (size_type i = m_strides.size() - 1; i != 0; --i) { size_type j = i - 1; value_type s = static_cast(shape[i]); - m_strides[j] = strides[j] - ((s - 1) * strides[i] + m_strides[i]); + m_strides[j] = strides[j] + m_strides[i] - strides[i] * s; } } @@ -55,59 +75,44 @@ public: private: char* p_ptr; - S m_strides; + container_type m_strides; }; -template -struct rebind_container_impl; - -template -struct rebind_container_impl, U> -{ - typedef std::vector type; -}; - -template -struct rebind_container_impl, U> -{ - typedef short_vector type; -}; - -template -using rebind_container = typename rebind_container_impl; - NAMESPACE_END(detail) -template +template class multi_array_iterator { - using int_container = rebind_container; - public: + using container_type = C; + multi_array_iterator(const std::array& buffers, - const std::vector& strides, const std::vector& shape) : m_shape(shape.size()), m_index(shape.size(), 0), m_common_iterator() { - std::copy(shape.begin(), shape.end(), m_shape.begin()); + // Maual copy to avoid conversion warning if using std::copy + for (size_t i = 0; i < shape.size(); ++i) + { + m_shape[i] = static_cast(shape[i]); + } - int_container new_strides(strides.size()); + container_type strides(shape.size()); for (size_t i = 0; i < N; ++i) { - init_common_iterator(buffers[i], strides, shape, m_common_iter[i], new_strides); + init_common_iterator(buffers[i], shape, m_common_iterator[i], strides); } } multi_array_iterator& operator++() { - for (size_t j = m_index.size(); --j; j != 0) + for (size_t j = m_index.size(); j != 0; --j) { size_t i = j - 1; if (++m_index[i] != m_shape[i]) { - increment_common_iterator(m_index[i]); + increment_common_iterator(i); break; } else @@ -126,42 +131,43 @@ public: private: - void init_common_iterator(const buffer_info& buffer, const std::vector& strides, const std::vector& shape, common_iter& iterator, int_container& new_strides) + using common_iter = detail::common_iterator; + + void init_common_iterator(const buffer_info& buffer, const std::vector& shape, common_iter& iterator, container_type& strides) { auto buffer_shape_iter = buffer.shape.rbegin(); - auto strides_iter = strides.rbegin(); + auto buffer_strides_iter = buffer.strides.rbegin(); auto shape_iter = shape.rbegin(); - auto new_strides_iter = new_strides.rbegin(); + auto strides_iter = strides.rbegin(); while (buffer_shape_iter != buffer.shape.rend()) { if (*shape_iter == *buffer_shape_iter) - *new_stride_iter = static_cast(*strides_iter); + *strides_iter = static_cast(*buffer_strides_iter); else - *new_strides_iter = 0; + *strides_iter = 0; ++buffer_shape_iter; - ++strides_iter; + ++buffer_strides_iter; ++shape_iter; - ++new_strides_iter; + ++strides_iter; } - std::fill(new_strides_iter, strides.rend(), 0); + std::fill(strides_iter, strides.rend(), 0); - iterator = common_iter(buffer.ptr, new_strides, shape); + iterator = common_iter(buffer.ptr, strides, shape); } - void increment_common_iterator(int dim) + void increment_common_iterator(size_t dim) { - std::for_each(m_common_iterator.begin(), m_common_iterator.end(), [=](const common_iter& iter) + std::for_each(m_common_iterator.begin(), m_common_iterator.end(), [=](common_iter& iter) { iter.increment(dim); }); } - S m_shape; - S m_index; - using common_iter = detail::common_iterator; + container_type m_shape; + container_type m_index; std::array m_common_iterator; }; diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 0aaac96ef..6ed365413 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -11,6 +11,8 @@ #include "pybind11.h" #include "complex.h" +#include "array_iterator.h" +#include #if defined(_MSC_VER) #pragma warning(push) @@ -150,6 +152,37 @@ template struct handle_type_name> { static PYBIND11_DESCR name() { return _("array[") + type_caster::name() + _("]"); } }; +template +bool broadcast(const std::array& buffers, int& ndim, std::vector& shape) +{ + ndim = std::accumulate(buffers.begin(), buffers.end(), 0, [](int res, const buffer_info& buf) + { + return std::max(res, buf.ndim); + }); + + shape = std::vector(static_cast(ndim), 1); + bool trivial_broadcast = true; + for (size_t i = 0; i < N; ++i) + { + auto res_iter = shape.rbegin(); + bool i_trivial_broadcast = (buffers[i].size == 1) || (buffers[i].ndim == ndim); + for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != buffers[i].shape.rend(); ++shape_iter, ++res_iter) + { + if (*res_iter == 1) + { + *res_iter = *shape_iter; + } + else if ((*shape_iter != 1) && (*res_iter != *shape_iter)) + { + pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!"); + } + i_trivial_broadcast = i_trivial_broadcast && (*res_iter == *shape_iter); + } + trivial_broadcast = trivial_broadcast && i_trivial_broadcast; + } + return trivial_broadcast; +} + template struct vectorize_helper { typename std::remove_reference::type f; @@ -161,32 +194,29 @@ struct vectorize_helper { return run(args..., typename make_index_sequence::type()); } - template object run(array_t&... args, index_sequence) { + template object run(array_t&... args, index_sequence index) { /* Request buffers from all parameters */ const size_t N = sizeof...(Args); + constexpr size_t SMALL_DIM = 4; + std::array buffers {{ args.request()... }}; /* Determine dimensions parameters of output array */ - int ndim = 0; size_t size = 0; - std::vector shape; - for (size_t i=0; i size) { - ndim = buffers[i].ndim; - shape = buffers[i].shape; - size = buffers[i].size; - } - } + int ndim = 0; + std::vector shape(0); + bool trivial_broadcast = broadcast(buffers, ndim, shape); + + size_t size = 1; std::vector strides(ndim); if (ndim > 0) { strides[ndim-1] = sizeof(Return); - for (int i=ndim-1; i>0; --i) - strides[i-1] = strides[i] * shape[i]; + for (int i = ndim - 1; i > 0; --i) + { + strides[i - 1] = strides[i] * shape[i]; + size *= shape[i]; + } } - - /* Check if the parameters are actually compatible */ - for (size_t i=0; i, N, Index...>(buffers, buf, index); + } + else + { + apply_broadcast, N, Index...>(buffers, buf, index); + } return result; } + + template + void apply_broadcast(const std::array& buffers, buffer_info& output, index_sequence) + { + using input_iterator = multi_array_iterator; + using output_iterator = array_iterator; + + input_iterator input_iter(buffers, output.shape); + output_iterator output_end = array_end(output); + + for (output_iterator iter = array_begin(output); iter != output_end; ++iter, ++input_iter) + { + *iter = f(input_iter.data()...); + } + + } }; NAMESPACE_END(detail) diff --git a/include/pybind11/short_vector.h b/include/pybind11/short_vector.h index 62d97e4af..2ac60d5c1 100644 --- a/include/pybind11/short_vector.h +++ b/include/pybind11/short_vector.h @@ -26,7 +26,7 @@ class short_vector public: - using std::array data_type; + using data_type = std::array; using value_type = typename data_type::value_type; using size_type = typename data_type::size_type; using difference_type = typename data_type::difference_type; @@ -47,7 +47,7 @@ public: } size_type size() const noexcept { return m_size; } - constexpr size_type max_size() noexcept { return m_data.max_size(); } + constexpr size_type max_size() const noexcept { return m_data.max_size(); } bool empty() const noexcept { return m_size == 0; } void resize(size_type size) { m_size = size; }