NumPy-style broadcasting support in pybind11::vectorize

This commit is contained in:
Johan Mabille 2016-02-11 10:47:11 +01:00 committed by Wenzel Jakob
parent f8584b630b
commit 1dc960c37f
4 changed files with 232 additions and 21 deletions

View File

@ -22,3 +22,8 @@ for f in [vectorized_func, vectorized_func2]:
print(f(np.array([1, 3]), np.array([2, 4]), 3)) print(f(np.array([1, 3]), np.array([2, 4]), 3))
print(f(np.array([[1, 3, 5], [7, 9, 11]]), np.array([[2, 4, 6], [8, 10, 12]]), 3)) print(f(np.array([[1, 3, 5], [7, 9, 11]]), np.array([[2, 4, 6], [8, 10, 12]]), 3))
print(np.array([[1, 3, 5], [7, 9, 11]])* np.array([[2, 4, 6], [8, 10, 12]])*3) print(np.array([[1, 3, 5], [7, 9, 11]])* np.array([[2, 4, 6], [8, 10, 12]])*3)
print(f(np.array([[1, 2, 3], [4, 5, 6]]), np.array([2, 3, 4]), 2))
print(np.array([[1, 2, 3], [4, 5, 6]])* np.array([2, 3, 4])* 2)
print(f(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[2], [3]]), 2))
print(np.array([[1, 2, 3], [4, 5, 6]])* np.array([[2], [3]])* 2)

View File

@ -16,6 +16,26 @@ my_func(x:int=11, y:float=12, z:float=3)
[ 168. 270. 396.]] [ 168. 270. 396.]]
[[ 6 36 90] [[ 6 36 90]
[168 270 396]] [168 270 396]]
my_func(x:int=1, y:float=2, z:float=2)
my_func(x:int=2, y:float=3, z:float=2)
my_func(x:int=3, y:float=4, z:float=2)
my_func(x:int=4, y:float=2, z:float=2)
my_func(x:int=5, y:float=3, z:float=2)
my_func(x:int=6, y:float=4, z:float=2)
[[ 4. 12. 24.]
[ 16. 30. 48.]]
[[ 4 12 24]
[16 30 48]]
my_func(x:int=1, y:float=2, z:float=2)
my_func(x:int=2, y:float=2, z:float=2)
my_func(x:int=3, y:float=2, z:float=2)
my_func(x:int=4, y:float=3, z:float=2)
my_func(x:int=5, y:float=3, z:float=2)
my_func(x:int=6, y:float=3, z:float=2)
[[ 4. 8. 12.]
[ 24. 30. 36.]]
[[ 4 8 12]
[24 30 36]]
my_func(x:int=1, y:float=2, z:float=3) my_func(x:int=1, y:float=2, z:float=3)
6.0 6.0
my_func(x:int=1, y:float=2, z:float=3) my_func(x:int=1, y:float=2, z:float=3)
@ -33,3 +53,23 @@ my_func(x:int=11, y:float=12, z:float=3)
[ 168. 270. 396.]] [ 168. 270. 396.]]
[[ 6 36 90] [[ 6 36 90]
[168 270 396]] [168 270 396]]
my_func(x:int=1, y:float=2, z:float=2)
my_func(x:int=2, y:float=3, z:float=2)
my_func(x:int=3, y:float=4, z:float=2)
my_func(x:int=4, y:float=2, z:float=2)
my_func(x:int=5, y:float=3, z:float=2)
my_func(x:int=6, y:float=4, z:float=2)
[[ 4. 12. 24.]
[ 16. 30. 48.]]
[[ 4 12 24]
[16 30 48]]
my_func(x:int=1, y:float=2, z:float=2)
my_func(x:int=2, y:float=2, z:float=2)
my_func(x:int=3, y:float=2, z:float=2)
my_func(x:int=4, y:float=3, z:float=2)
my_func(x:int=5, y:float=3, z:float=2)
my_func(x:int=6, y:float=3, z:float=2)
[[ 4. 8. 12.]
[ 24. 30. 36.]]
[[ 4 8 12]
[24 30 36]]

View File

@ -11,6 +11,8 @@
#include "pybind11.h" #include "pybind11.h"
#include "complex.h" #include "complex.h"
#include <numeric>
#include <algorithm>
#if defined(_MSC_VER) #if defined(_MSC_VER)
#pragma warning(push) #pragma warning(push)
@ -146,10 +148,158 @@ DECL_FMT(std::complex<double>, NPY_CDOUBLE_);
NAMESPACE_BEGIN(detail) NAMESPACE_BEGIN(detail)
template <class T>
using array_iterator = typename std::add_pointer<T>::type;
template <class T>
array_iterator<T> array_begin(const buffer_info& buffer) {
return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr));
}
template <class T>
array_iterator<T> array_end(const buffer_info& buffer) {
return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr) + buffer.size);
}
class common_iterator {
public:
using container_type = std::vector<size_t>;
using value_type = container_type::value_type;
using size_type = container_type::size_type;
common_iterator() : p_ptr(0), m_strides() {}
common_iterator(void* ptr, const container_type& strides, const std::vector<size_t>& shape)
: p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.size()) {
m_strides.back() = static_cast<value_type>(strides.back());
for (size_type i = m_strides.size() - 1; i != 0; --i) {
size_type j = i - 1;
value_type s = static_cast<value_type>(shape[i]);
m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
}
}
void increment(size_type dim) {
p_ptr += m_strides[dim];
}
void* data() const {
return p_ptr;
}
private:
char* p_ptr;
container_type m_strides;
};
template <size_t N>
class multi_array_iterator {
public:
using container_type = std::vector<size_t>;
multi_array_iterator(const std::array<buffer_info, N>& buffers,
const std::vector<size_t>& shape)
: m_shape(shape.size()), m_index(shape.size(), 0), m_common_iterator() {
// Manual copy to avoid conversion warning if using std::copy
for (size_t i = 0; i < shape.size(); ++i) {
m_shape[i] = static_cast<container_type::value_type>(shape[i]);
}
container_type strides(shape.size());
for (size_t i = 0; i < N; ++i) {
init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
}
}
multi_array_iterator& operator++() {
for (size_t j = m_index.size(); j != 0; --j) {
size_t i = j - 1;
if (++m_index[i] != m_shape[i]) {
increment_common_iterator(i);
break;
}
else {
m_index[i] = 0;
}
}
return *this;
}
template <size_t K, class T>
const T& data() const {
return *reinterpret_cast<T*>(m_common_iterator[K].data());
}
private:
using common_iter = common_iterator;
void init_common_iterator(const buffer_info& buffer, const std::vector<size_t>& shape, common_iter& iterator, container_type& strides) {
auto buffer_shape_iter = buffer.shape.rbegin();
auto buffer_strides_iter = buffer.strides.rbegin();
auto shape_iter = shape.rbegin();
auto strides_iter = strides.rbegin();
while (buffer_shape_iter != buffer.shape.rend()) {
if (*shape_iter == *buffer_shape_iter)
*strides_iter = static_cast<int>(*buffer_strides_iter);
else
*strides_iter = 0;
++buffer_shape_iter;
++buffer_strides_iter;
++shape_iter;
++strides_iter;
}
std::fill(strides_iter, strides.rend(), 0);
iterator = common_iter(buffer.ptr, strides, shape);
}
void increment_common_iterator(size_t dim) {
std::for_each(m_common_iterator.begin(), m_common_iterator.end(), [=](common_iter& iter) {
iter.increment(dim);
});
}
container_type m_shape;
container_type m_index;
std::array<common_iter, N> m_common_iterator;
};
template <typename T> struct handle_type_name<array_t<T>> { template <typename T> struct handle_type_name<array_t<T>> {
static PYBIND11_DESCR name() { return _("array[") + type_caster<T>::name() + _("]"); } static PYBIND11_DESCR name() { return _("array[") + type_caster<T>::name() + _("]"); }
}; };
template <size_t N>
bool broadcast(const std::array<buffer_info, N>& buffers, int& ndim, std::vector<size_t>& shape) {
ndim = std::accumulate(buffers.begin(), buffers.end(), 0, [](int res, const buffer_info& buf) {
return std::max(res, buf.ndim);
});
shape = std::vector<size_t>(static_cast<size_t>(ndim), 1);
bool trivial_broadcast = true;
for (size_t i = 0; i < N; ++i) {
auto res_iter = shape.rbegin();
bool i_trivial_broadcast = (buffers[i].size == 1) || (buffers[i].ndim == ndim);
for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != buffers[i].shape.rend(); ++shape_iter, ++res_iter) {
if (*res_iter == 1) {
*res_iter = *shape_iter;
}
else if ((*shape_iter != 1) && (*res_iter != *shape_iter)) {
pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
}
i_trivial_broadcast = i_trivial_broadcast && (*res_iter == *shape_iter);
}
trivial_broadcast = trivial_broadcast && i_trivial_broadcast;
}
return trivial_broadcast;
}
template <typename Func, typename Return, typename... Args> template <typename Func, typename Return, typename... Args>
struct vectorize_helper { struct vectorize_helper {
typename std::remove_reference<Func>::type f; typename std::remove_reference<Func>::type f;
@ -161,32 +311,27 @@ struct vectorize_helper {
return run(args..., typename make_index_sequence<sizeof...(Args)>::type()); return run(args..., typename make_index_sequence<sizeof...(Args)>::type());
} }
template <size_t ... Index> object run(array_t<Args>&... args, index_sequence<Index...>) { template <size_t ... Index> object run(array_t<Args>&... args, index_sequence<Index...> index) {
/* Request buffers from all parameters */ /* Request buffers from all parameters */
const size_t N = sizeof...(Args); const size_t N = sizeof...(Args);
std::array<buffer_info, N> buffers {{ args.request()... }}; std::array<buffer_info, N> buffers {{ args.request()... }};
/* Determine dimensions parameters of output array */ /* Determine dimensions parameters of output array */
int ndim = 0; size_t size = 0; int ndim = 0;
std::vector<size_t> shape; std::vector<size_t> shape(0);
for (size_t i=0; i<N; ++i) { bool trivial_broadcast = broadcast(buffers, ndim, shape);
if (buffers[i].size > size) {
ndim = buffers[i].ndim; size_t size = 1;
shape = buffers[i].shape;
size = buffers[i].size;
}
}
std::vector<size_t> strides(ndim); std::vector<size_t> strides(ndim);
if (ndim > 0) { if (ndim > 0) {
strides[ndim-1] = sizeof(Return); strides[ndim-1] = sizeof(Return);
for (int i=ndim-1; i>0; --i) for (int i = ndim - 1; i > 0; --i) {
strides[i-1] = strides[i] * shape[i]; strides[i - 1] = strides[i] * shape[i];
size *= shape[i];
}
size *= shape[0];
} }
/* Check if the parameters are actually compatible */
for (size_t i=0; i<N; ++i)
if (buffers[i].size != 1 && (buffers[i].ndim != ndim || buffers[i].shape != shape))
pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
if (size == 1) if (size == 1)
return cast(f(*((Args *) buffers[Index].ptr)...)); return cast(f(*((Args *) buffers[Index].ptr)...));
@ -198,14 +343,33 @@ struct vectorize_helper {
buffer_info buf = result.request(); buffer_info buf = result.request();
Return *output = (Return *) buf.ptr; Return *output = (Return *) buf.ptr;
if(trivial_broadcast) {
/* Call the function */ /* Call the function */
for (size_t i=0; i<size; ++i) for (size_t i=0; i<size; ++i) {
output[i] = f((buffers[Index].size == 1 output[i] = f((buffers[Index].size == 1
? *((Args *) buffers[Index].ptr) ? *((Args *) buffers[Index].ptr)
: ((Args *) buffers[Index].ptr)[i])...); : ((Args *) buffers[Index].ptr)[i])...);
}
}
else {
apply_broadcast<N, Index...>(buffers, buf, index);
}
return result; return result;
} }
template <size_t N, size_t... Index>
void apply_broadcast(const std::array<buffer_info, N>& buffers, buffer_info& output, index_sequence<Index...>) {
using input_iterator = multi_array_iterator<N>;
using output_iterator = array_iterator<Return>;
input_iterator input_iter(buffers, output.shape);
output_iterator output_end = array_end<Return>(output);
for (output_iterator iter = array_begin<Return>(output); iter != output_end; ++iter, ++input_iter) {
*iter = f((input_iter.template data<Index, Args>())...);
}
}
}; };
NAMESPACE_END(detail) NAMESPACE_END(detail)

View File

@ -27,7 +27,9 @@ setup(
'include/pybind11/functional.h', 'include/pybind11/functional.h',
'include/pybind11/operators.h', 'include/pybind11/operators.h',
'include/pybind11/pytypes.h', 'include/pybind11/pytypes.h',
'include/pybind11/typeid.h' 'include/pybind11/typeid.h',
'include/pybind11/short_vector.h',
'include/pybind11/array_iterator.h'
], ],
classifiers=[ classifiers=[
'Development Status :: 5 - Production/Stable', 'Development Status :: 5 - Production/Stable',