mirror of
https://github.com/pybind/pybind11.git
synced 2025-02-24 01:19:23 +00:00
Broadcasting like numpy
This commit is contained in:
parent
7fc85f1d7c
commit
f41e1125c6
@ -16,29 +16,49 @@
|
||||
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
|
||||
#endif
|
||||
|
||||
/*
|
||||
WARNING: These iterators are not a binding to numpy.nditer, there convenient classes for broadcasting in vectorize
|
||||
*/
|
||||
|
||||
NAMESPACE_BEGIN(pybind11)
|
||||
|
||||
template <class T>
|
||||
using array_iterator = typename std::add_pointer<T>::type;
|
||||
|
||||
template <class T>
|
||||
array_iterator<T> array_begin(const buffer_info& buffer)
|
||||
{
|
||||
return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr));
|
||||
}
|
||||
|
||||
template <class T>
|
||||
array_iterator<T> array_end(const buffer_info& buffer)
|
||||
{
|
||||
return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr) + buffer.size);
|
||||
}
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <class S>
|
||||
template <class C>
|
||||
class common_iterator
|
||||
{
|
||||
|
||||
public:
|
||||
|
||||
common_iterator() : p_ptr(0), m_strides() {}
|
||||
common_iterator(void* ptr, const S& strides, const std::vector<size_t>& shape)
|
||||
: p_ptr(reinterpret_cast<char*>(ptr), m_strides(strides.size())
|
||||
{
|
||||
using value_type = typename S::value_type;
|
||||
using size_type = typename S::size_type;
|
||||
using container_type = C;
|
||||
using value_type = typename container_type::value_type;
|
||||
using size_type = typename container_type::size_type;
|
||||
|
||||
common_iterator() : p_ptr(0), m_strides() {}
|
||||
common_iterator(void* ptr, const container_type& strides, const std::vector<size_t>& shape)
|
||||
: p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.size())
|
||||
{
|
||||
m_strides.back() = static_cast<value_type>(strides.back());
|
||||
for (size_type i = m_strides.size() - 1; --i; i != 0)
|
||||
for (size_type i = m_strides.size() - 1; i != 0; --i)
|
||||
{
|
||||
size_type j = i - 1;
|
||||
value_type s = static_cast<value_type>(shape[i]);
|
||||
m_strides[j] = strides[j] - ((s - 1) * strides[i] + m_strides[i]);
|
||||
m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
|
||||
}
|
||||
}
|
||||
|
||||
@ -55,59 +75,44 @@ public:
|
||||
private:
|
||||
|
||||
char* p_ptr;
|
||||
S m_strides;
|
||||
container_type m_strides;
|
||||
};
|
||||
|
||||
template <class C, class U>
|
||||
struct rebind_container_impl;
|
||||
|
||||
template <class T, class A, class U>
|
||||
struct rebind_container_impl<std::vector<T, A>, U>
|
||||
{
|
||||
typedef std::vector<U, A> type;
|
||||
};
|
||||
|
||||
template <class T, size_t N, class U>
|
||||
struct rebind_container_impl<short_vector<T, N>, U>
|
||||
{
|
||||
typedef short_vector<U, N> type;
|
||||
};
|
||||
|
||||
template <class C, class U>
|
||||
using rebind_container = typename rebind_container_impl<C, U>;
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
template <class S, size_t N>
|
||||
template <class C, size_t N>
|
||||
class multi_array_iterator
|
||||
{
|
||||
|
||||
using int_container = rebind_container<S, int>;
|
||||
|
||||
public:
|
||||
|
||||
using container_type = C;
|
||||
|
||||
multi_array_iterator(const std::array<buffer_info, N>& buffers,
|
||||
const std::vector<size_t>& strides,
|
||||
const std::vector<size_t>& shape)
|
||||
: m_shape(shape.size()), m_index(shape.size(), 0), m_common_iterator()
|
||||
{
|
||||
std::copy(shape.begin(), shape.end(), m_shape.begin());
|
||||
// Maual copy to avoid conversion warning if using std::copy
|
||||
for (size_t i = 0; i < shape.size(); ++i)
|
||||
{
|
||||
m_shape[i] = static_cast<typename container_type::value_type>(shape[i]);
|
||||
}
|
||||
|
||||
int_container new_strides(strides.size());
|
||||
container_type strides(shape.size());
|
||||
for (size_t i = 0; i < N; ++i)
|
||||
{
|
||||
init_common_iterator(buffers[i], strides, shape, m_common_iter[i], new_strides);
|
||||
init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
|
||||
}
|
||||
}
|
||||
|
||||
multi_array_iterator& operator++()
|
||||
{
|
||||
for (size_t j = m_index.size(); --j; j != 0)
|
||||
for (size_t j = m_index.size(); j != 0; --j)
|
||||
{
|
||||
size_t i = j - 1;
|
||||
if (++m_index[i] != m_shape[i])
|
||||
{
|
||||
increment_common_iterator(m_index[i]);
|
||||
increment_common_iterator(i);
|
||||
break;
|
||||
}
|
||||
else
|
||||
@ -126,42 +131,43 @@ public:
|
||||
|
||||
private:
|
||||
|
||||
void init_common_iterator(const buffer_info& buffer, const std::vector<size_t>& strides, const std::vector<size_t>& shape, common_iter& iterator, int_container& new_strides)
|
||||
using common_iter = detail::common_iterator<container_type>;
|
||||
|
||||
void init_common_iterator(const buffer_info& buffer, const std::vector<size_t>& shape, common_iter& iterator, container_type& strides)
|
||||
{
|
||||
auto buffer_shape_iter = buffer.shape.rbegin();
|
||||
auto strides_iter = strides.rbegin();
|
||||
auto buffer_strides_iter = buffer.strides.rbegin();
|
||||
auto shape_iter = shape.rbegin();
|
||||
auto new_strides_iter = new_strides.rbegin();
|
||||
auto strides_iter = strides.rbegin();
|
||||
|
||||
while (buffer_shape_iter != buffer.shape.rend())
|
||||
{
|
||||
if (*shape_iter == *buffer_shape_iter)
|
||||
*new_stride_iter = static_cast<int>(*strides_iter);
|
||||
*strides_iter = static_cast<int>(*buffer_strides_iter);
|
||||
else
|
||||
*new_strides_iter = 0;
|
||||
*strides_iter = 0;
|
||||
|
||||
++buffer_shape_iter;
|
||||
++strides_iter;
|
||||
++buffer_strides_iter;
|
||||
++shape_iter;
|
||||
++new_strides_iter;
|
||||
++strides_iter;
|
||||
}
|
||||
|
||||
std::fill(new_strides_iter, strides.rend(), 0);
|
||||
std::fill(strides_iter, strides.rend(), 0);
|
||||
|
||||
iterator = common_iter(buffer.ptr, new_strides, shape);
|
||||
iterator = common_iter(buffer.ptr, strides, shape);
|
||||
}
|
||||
|
||||
void increment_common_iterator(int dim)
|
||||
void increment_common_iterator(size_t dim)
|
||||
{
|
||||
std::for_each(m_common_iterator.begin(), m_common_iterator.end(), [=](const common_iter& iter)
|
||||
std::for_each(m_common_iterator.begin(), m_common_iterator.end(), [=](common_iter& iter)
|
||||
{
|
||||
iter.increment(dim);
|
||||
});
|
||||
}
|
||||
|
||||
S m_shape;
|
||||
S m_index;
|
||||
using common_iter = detail::common_iterator<int_container>;
|
||||
container_type m_shape;
|
||||
container_type m_index;
|
||||
std::array<common_iter, N> m_common_iterator;
|
||||
};
|
||||
|
||||
|
@ -11,6 +11,8 @@
|
||||
|
||||
#include "pybind11.h"
|
||||
#include "complex.h"
|
||||
#include "array_iterator.h"
|
||||
#include <numeric>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(push)
|
||||
@ -150,6 +152,37 @@ template <typename T> struct handle_type_name<array_t<T>> {
|
||||
static PYBIND11_DESCR name() { return _("array[") + type_caster<T>::name() + _("]"); }
|
||||
};
|
||||
|
||||
template <size_t N>
|
||||
bool broadcast(const std::array<buffer_info, N>& buffers, int& ndim, std::vector<size_t>& shape)
|
||||
{
|
||||
ndim = std::accumulate(buffers.begin(), buffers.end(), 0, [](int res, const buffer_info& buf)
|
||||
{
|
||||
return std::max(res, buf.ndim);
|
||||
});
|
||||
|
||||
shape = std::vector<size_t>(static_cast<size_t>(ndim), 1);
|
||||
bool trivial_broadcast = true;
|
||||
for (size_t i = 0; i < N; ++i)
|
||||
{
|
||||
auto res_iter = shape.rbegin();
|
||||
bool i_trivial_broadcast = (buffers[i].size == 1) || (buffers[i].ndim == ndim);
|
||||
for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != buffers[i].shape.rend(); ++shape_iter, ++res_iter)
|
||||
{
|
||||
if (*res_iter == 1)
|
||||
{
|
||||
*res_iter = *shape_iter;
|
||||
}
|
||||
else if ((*shape_iter != 1) && (*res_iter != *shape_iter))
|
||||
{
|
||||
pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
|
||||
}
|
||||
i_trivial_broadcast = i_trivial_broadcast && (*res_iter == *shape_iter);
|
||||
}
|
||||
trivial_broadcast = trivial_broadcast && i_trivial_broadcast;
|
||||
}
|
||||
return trivial_broadcast;
|
||||
}
|
||||
|
||||
template <typename Func, typename Return, typename... Args>
|
||||
struct vectorize_helper {
|
||||
typename std::remove_reference<Func>::type f;
|
||||
@ -161,32 +194,29 @@ struct vectorize_helper {
|
||||
return run(args..., typename make_index_sequence<sizeof...(Args)>::type());
|
||||
}
|
||||
|
||||
template <size_t ... Index> object run(array_t<Args>&... args, index_sequence<Index...>) {
|
||||
template <size_t ... Index> object run(array_t<Args>&... args, index_sequence<Index...> index) {
|
||||
/* Request buffers from all parameters */
|
||||
const size_t N = sizeof...(Args);
|
||||
constexpr size_t SMALL_DIM = 4;
|
||||
|
||||
std::array<buffer_info, N> buffers {{ args.request()... }};
|
||||
|
||||
/* Determine dimensions parameters of output array */
|
||||
int ndim = 0; size_t size = 0;
|
||||
std::vector<size_t> shape;
|
||||
for (size_t i=0; i<N; ++i) {
|
||||
if (buffers[i].size > size) {
|
||||
ndim = buffers[i].ndim;
|
||||
shape = buffers[i].shape;
|
||||
size = buffers[i].size;
|
||||
}
|
||||
}
|
||||
int ndim = 0;
|
||||
std::vector<size_t> shape(0);
|
||||
bool trivial_broadcast = broadcast(buffers, ndim, shape);
|
||||
|
||||
size_t size = 1;
|
||||
std::vector<size_t> strides(ndim);
|
||||
if (ndim > 0) {
|
||||
strides[ndim-1] = sizeof(Return);
|
||||
for (int i=ndim-1; i>0; --i)
|
||||
strides[i-1] = strides[i] * shape[i];
|
||||
for (int i = ndim - 1; i > 0; --i)
|
||||
{
|
||||
strides[i - 1] = strides[i] * shape[i];
|
||||
size *= shape[i];
|
||||
}
|
||||
|
||||
/* Check if the parameters are actually compatible */
|
||||
for (size_t i=0; i<N; ++i)
|
||||
if (buffers[i].size != 1 && (buffers[i].ndim != ndim || buffers[i].shape != shape))
|
||||
pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
|
||||
}
|
||||
size *= shape[0];
|
||||
|
||||
if (size == 1)
|
||||
return cast(f(*((Args *) buffers[Index].ptr)...));
|
||||
@ -198,14 +228,41 @@ struct vectorize_helper {
|
||||
buffer_info buf = result.request();
|
||||
Return *output = (Return *) buf.ptr;
|
||||
|
||||
if(trivial_broadcast)
|
||||
{
|
||||
/* Call the function */
|
||||
for (size_t i=0; i<size; ++i)
|
||||
output[i] = f((buffers[Index].size == 1
|
||||
? *((Args *) buffers[Index].ptr)
|
||||
: ((Args *) buffers[Index].ptr)[i])...);
|
||||
}
|
||||
else if (shape.size() < SMALL_DIM)
|
||||
{
|
||||
apply_broadcast<short_vector<int>, N, Index...>(buffers, buf, index);
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_broadcast<std::vector<int>, N, Index...>(buffers, buf, index);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class C, size_t N, size_t... Index>
|
||||
void apply_broadcast(const std::array<buffer_info, N>& buffers, buffer_info& output, index_sequence<Index...>)
|
||||
{
|
||||
using input_iterator = multi_array_iterator<C, N>;
|
||||
using output_iterator = array_iterator<Return>;
|
||||
|
||||
input_iterator input_iter(buffers, output.shape);
|
||||
output_iterator output_end = array_end<Return>(output);
|
||||
|
||||
for (output_iterator iter = array_begin<Return>(output); iter != output_end; ++iter, ++input_iter)
|
||||
{
|
||||
*iter = f(input_iter.data<Index, Args>()...);
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
@ -26,7 +26,7 @@ class short_vector
|
||||
|
||||
public:
|
||||
|
||||
using std::array<T, N> data_type;
|
||||
using data_type = std::array<T, N>;
|
||||
using value_type = typename data_type::value_type;
|
||||
using size_type = typename data_type::size_type;
|
||||
using difference_type = typename data_type::difference_type;
|
||||
@ -47,7 +47,7 @@ public:
|
||||
}
|
||||
|
||||
size_type size() const noexcept { return m_size; }
|
||||
constexpr size_type max_size() noexcept { return m_data.max_size(); }
|
||||
constexpr size_type max_size() const noexcept { return m_data.max_size(); }
|
||||
bool empty() const noexcept { return m_size == 0; }
|
||||
|
||||
void resize(size_type size) { m_size = size; }
|
||||
|
Loading…
Reference in New Issue
Block a user