diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 166cbd06a..0c703b8cc 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -1052,11 +1052,14 @@ private: std::array m_common_iterator; }; -// Populates the shape and number of dimensions for the set of buffers. Returns true if the -// broadcast is "trivial"--that is, has each buffer being either a singleton or a full-size, -// C-contiguous storage buffer. +enum class broadcast_trivial { non_trivial, c_trivial, f_trivial }; + +// Populates the shape and number of dimensions for the set of buffers. Returns a broadcast_trivial +// enum value indicating whether the broadcast is "trivial"--that is, has each buffer being either a +// singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous (`f_trivial`) storage +// buffer; returns `non_trivial` otherwise. template -bool broadcast(const std::array &buffers, size_t &ndim, std::vector &shape) { +broadcast_trivial broadcast(const std::array &buffers, size_t &ndim, std::vector &shape) { ndim = std::accumulate(buffers.begin(), buffers.end(), size_t(0), [](size_t res, const buffer_info& buf) { return std::max(res, buf.ndim); }); @@ -1064,14 +1067,12 @@ bool broadcast(const std::array &buffers, size_t &ndim, std::vec shape.clear(); shape.resize(ndim, 1); - bool trivial_broadcast = true; + // Figure out the output size, and make sure all input arrays conform (i.e. are either size 1 or + // the full size). for (size_t i = 0; i < N; ++i) { - trivial_broadcast = trivial_broadcast && (buffers[i].size == 1 || buffers[i].ndim == ndim); - size_t expect_stride = buffers[i].itemsize; auto res_iter = shape.rbegin(); - auto stride_iter = buffers[i].strides.rbegin(); - auto shape_iter = buffers[i].shape.rbegin(); - while (shape_iter != buffers[i].shape.rend()) { + auto end = buffers[i].shape.rend(); + for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end; ++shape_iter, ++res_iter) { const auto &dim_size_in = *shape_iter; auto &dim_size_out = *res_iter; @@ -1080,21 +1081,54 @@ bool broadcast(const std::array &buffers, size_t &ndim, std::vec dim_size_out = dim_size_in; else if (dim_size_in != 1 && dim_size_in != dim_size_out) pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!"); - - if (trivial_broadcast && buffers[i].size > 1) { - if (dim_size_in == dim_size_out && expect_stride == *stride_iter) { - expect_stride *= dim_size_in; - ++stride_iter; - } else { - trivial_broadcast = false; - } - } - - ++shape_iter; - ++res_iter; } } - return trivial_broadcast; + + bool trivial_broadcast_c = true; + bool trivial_broadcast_f = true; + for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) { + if (buffers[i].size == 1) + continue; + + // Require the same number of dimensions: + if (buffers[i].ndim != ndim) + return broadcast_trivial::non_trivial; + + // Require all dimensions be full-size: + if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin())) + return broadcast_trivial::non_trivial; + + // Check for C contiguity (but only if previous inputs were also C contiguous) + if (trivial_broadcast_c) { + size_t expect_stride = buffers[i].itemsize; + auto end = buffers[i].shape.crend(); + for (auto shape_iter = buffers[i].shape.crbegin(), stride_iter = buffers[i].strides.crbegin(); + trivial_broadcast_c && shape_iter != end; ++shape_iter, ++stride_iter) { + if (expect_stride == *stride_iter) + expect_stride *= *shape_iter; + else + trivial_broadcast_c = false; + } + } + + // Check for Fortran contiguity (if previous inputs were also F contiguous) + if (trivial_broadcast_f) { + size_t expect_stride = buffers[i].itemsize; + auto end = buffers[i].shape.cend(); + for (auto shape_iter = buffers[i].shape.cbegin(), stride_iter = buffers[i].strides.cbegin(); + trivial_broadcast_f && shape_iter != end; ++shape_iter, ++stride_iter) { + if (expect_stride == *stride_iter) + expect_stride *= *shape_iter; + else + trivial_broadcast_f = false; + } + } + } + + return + trivial_broadcast_c ? broadcast_trivial::c_trivial : + trivial_broadcast_f ? broadcast_trivial::f_trivial : + broadcast_trivial::non_trivial; } template @@ -1116,32 +1150,42 @@ struct vectorize_helper { /* Determine dimensions parameters of output array */ size_t ndim = 0; std::vector shape(0); - bool trivial_broadcast = broadcast(buffers, ndim, shape); + auto trivial = broadcast(buffers, ndim, shape); size_t size = 1; std::vector strides(ndim); if (ndim > 0) { - strides[ndim-1] = sizeof(Return); - for (size_t i = ndim - 1; i > 0; --i) { - strides[i - 1] = strides[i] * shape[i]; - size *= shape[i]; + if (trivial == broadcast_trivial::f_trivial) { + strides[0] = sizeof(Return); + for (size_t i = 1; i < ndim; ++i) { + strides[i] = strides[i - 1] * shape[i - 1]; + size *= shape[i - 1]; + } + size *= shape[ndim - 1]; + } + else { + strides[ndim-1] = sizeof(Return); + for (size_t i = ndim - 1; i > 0; --i) { + strides[i - 1] = strides[i] * shape[i]; + size *= shape[i]; + } + size *= shape[0]; } - size *= shape[0]; } if (size == 1) return cast(f(*reinterpret_cast(buffers[Index].ptr)...)); - array_t result(shape, strides); + array_t result(shape, strides); auto buf = result.request(); auto output = (Return *) buf.ptr; - if (trivial_broadcast) { - /* Call the function */ + /* Call the function */ + if (trivial == broadcast_trivial::non_trivial) { + apply_broadcast(buffers, buf, index); + } else { for (size_t i = 0; i < size; ++i) output[i] = f((reinterpret_cast(buffers[Index].ptr)[buffers[Index].size == 1 ? 0 : i])...); - } else { - apply_broadcast(buffers, buf, index); } return result; diff --git a/tests/test_numpy_vectorize.cpp b/tests/test_numpy_vectorize.cpp index e5adff800..8e951c6e1 100644 --- a/tests/test_numpy_vectorize.cpp +++ b/tests/test_numpy_vectorize.cpp @@ -41,6 +41,10 @@ test_initializer numpy_vectorize([](py::module &m) { // Internal optimization test for whether the input is trivially broadcastable: + py::enum_(m, "trivial") + .value("f_trivial", py::detail::broadcast_trivial::f_trivial) + .value("c_trivial", py::detail::broadcast_trivial::c_trivial) + .value("non_trivial", py::detail::broadcast_trivial::non_trivial); m.def("vectorized_is_trivial", []( py::array_t arg1, py::array_t arg2, diff --git a/tests/test_numpy_vectorize.py b/tests/test_numpy_vectorize.py index 9a8c6ab94..7ae777227 100644 --- a/tests/test_numpy_vectorize.py +++ b/tests/test_numpy_vectorize.py @@ -24,6 +24,20 @@ def test_vectorize(capture): my_func(x:int=1, y:float=2, z:float=3) my_func(x:int=3, y:float=4, z:float=3) """ + with capture: + a = np.array([[1, 2], [3, 4]], order='F') + b = np.array([[10, 20], [30, 40]], order='F') + c = 3 + result = f(a, b, c) + assert np.allclose(result, a * b * c) + assert result.flags.f_contiguous + # All inputs are F order and full or singletons, so we the result is in col-major order: + assert capture == """ + my_func(x:int=1, y:float=10, z:float=3) + my_func(x:int=3, y:float=30, z:float=3) + my_func(x:int=2, y:float=20, z:float=3) + my_func(x:int=4, y:float=40, z:float=3) + """ with capture: a, b, c = np.array([[1, 3, 5], [7, 9, 11]]), np.array([[2, 4, 6], [8, 10, 12]]), 3 assert np.allclose(f(a, b, c), a * b * c) @@ -105,29 +119,43 @@ def test_docs(doc): def test_trivial_broadcasting(): - from pybind11_tests import vectorized_is_trivial + from pybind11_tests import vectorized_is_trivial, trivial, vectorized_func - assert vectorized_is_trivial(1, 2, 3) - assert vectorized_is_trivial(np.array(1), np.array(2), 3) - assert vectorized_is_trivial(np.array([1, 3]), np.array([2, 4]), 3) - assert vectorized_is_trivial( + assert vectorized_is_trivial(1, 2, 3) == trivial.c_trivial + assert vectorized_is_trivial(np.array(1), np.array(2), 3) == trivial.c_trivial + assert vectorized_is_trivial(np.array([1, 3]), np.array([2, 4]), 3) == trivial.c_trivial + assert trivial.c_trivial == vectorized_is_trivial( np.array([[1, 3, 5], [7, 9, 11]]), np.array([[2, 4, 6], [8, 10, 12]]), 3) - assert not vectorized_is_trivial( - np.array([[1, 2, 3], [4, 5, 6]]), np.array([2, 3, 4]), 2) - assert not vectorized_is_trivial( - np.array([[1, 2, 3], [4, 5, 6]]), np.array([[2], [3]]), 2) + assert vectorized_is_trivial( + np.array([[1, 2, 3], [4, 5, 6]]), np.array([2, 3, 4]), 2) == trivial.non_trivial + assert vectorized_is_trivial( + np.array([[1, 2, 3], [4, 5, 6]]), np.array([[2], [3]]), 2) == trivial.non_trivial z1 = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype='int32') z2 = np.array(z1, dtype='float32') z3 = np.array(z1, dtype='float64') - assert vectorized_is_trivial(z1, z2, z3) - assert not vectorized_is_trivial(z1[::2, ::2], 1, 1) - assert vectorized_is_trivial(1, 1, z1[::2, ::2]) - assert not vectorized_is_trivial(1, 1, z3[::2, ::2]) - assert vectorized_is_trivial(z1, 1, z3[1::4, 1::4]) + assert vectorized_is_trivial(z1, z2, z3) == trivial.c_trivial + assert vectorized_is_trivial(1, z2, z3) == trivial.c_trivial + assert vectorized_is_trivial(z1, 1, z3) == trivial.c_trivial + assert vectorized_is_trivial(z1, z2, 1) == trivial.c_trivial + assert vectorized_is_trivial(z1[::2, ::2], 1, 1) == trivial.non_trivial + assert vectorized_is_trivial(1, 1, z1[::2, ::2]) == trivial.c_trivial + assert vectorized_is_trivial(1, 1, z3[::2, ::2]) == trivial.non_trivial + assert vectorized_is_trivial(z1, 1, z3[1::4, 1::4]) == trivial.c_trivial y1 = np.array(z1, order='F') y2 = np.array(y1) y3 = np.array(y1) - assert not vectorized_is_trivial(y1, y2, y3) - assert not vectorized_is_trivial(y1, z2, z3) - assert not vectorized_is_trivial(y1, 1, 1) + assert vectorized_is_trivial(y1, y2, y3) == trivial.f_trivial + assert vectorized_is_trivial(y1, 1, 1) == trivial.f_trivial + assert vectorized_is_trivial(1, y2, 1) == trivial.f_trivial + assert vectorized_is_trivial(1, 1, y3) == trivial.f_trivial + assert vectorized_is_trivial(y1, z2, 1) == trivial.non_trivial + assert vectorized_is_trivial(z1[1::4, 1::4], y2, 1) == trivial.f_trivial + assert vectorized_is_trivial(y1[1::4, 1::4], z2, 1) == trivial.c_trivial + + assert vectorized_func(z1, z2, z3).flags.c_contiguous + assert vectorized_func(y1, y2, y3).flags.f_contiguous + assert vectorized_func(z1, 1, 1).flags.c_contiguous + assert vectorized_func(1, y2, 1).flags.f_contiguous + assert vectorized_func(z1[1::4, 1::4], y2, 1).flags.f_contiguous + assert vectorized_func(y1[1::4, 1::4], z2, 1).flags.c_contiguous