diff --git a/docs/advanced/pycpp/numpy.rst b/docs/advanced/pycpp/numpy.rst index d89e4beae..f9d6acbf6 100644 --- a/docs/advanced/pycpp/numpy.rst +++ b/docs/advanced/pycpp/numpy.rst @@ -305,3 +305,48 @@ simply using ``vectorize``). The file :file:`tests/test_numpy_vectorize.cpp` contains a complete example that demonstrates using :func:`vectorize` in more detail. + +Direct access +============= + +For performance reasons, particularly when dealing with very large arrays, it +is often desirable to directly access array elements without internal checking +of dimensions and bounds on every access when indices are known to be already +valid. To avoid such checks, the ``array`` class and ``array_t`` template +class offer an unchecked proxy object that can be used for this unchecked +access through the ``unchecked`` and ``mutable_unchecked`` methods, +where ``N`` gives the required dimensionality of the array: + +.. code-block:: cpp + + m.def("sum_3d", [](py::array_t x) { + auto r = x.unchecked<3>(); // x must have ndim = 3; can be non-writeable + double sum = 0; + for (size_t i = 0; i < r.shape(0); i++) + for (size_t j = 0; j < r.shape(1); j++) + for (size_t k = 0; k < r.shape(2); k++) + sum += r(i, j, k); + return sum; + }); + m.def("increment_3d", [](py::array_t x) { + auto r = x.mutable_unchecked<3>(); // Will throw if ndim != 3 or flags.writeable is false + for (size_t i = 0; i < r.shape(0); i++) + for (size_t j = 0; j < r.shape(1); j++) + for (size_t k = 0; k < r.shape(2); k++) + r(i, j, k) += 1.0; + }, py::arg().noconvert()); + +To obtain the proxy from an ``array`` object, you must specify both the data +type and number of dimensions as template arguments, such as ``auto r = +myarray.mutable_unchecked()``. + +Note that the returned proxy object directly references the array's data, and +only reads its shape, strides, and writeable flag when constructed. You must +take care to ensure that the referenced array is not destroyed or reshaped for +the duration of the returned object, typically by limiting the scope of the +returned instance. + +.. seealso:: + + The file :file:`tests/test_numpy_array.cpp` contains additional examples + demonstrating the use of this feature. diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 0c703b8cc..a5f68cce6 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -35,6 +35,9 @@ static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t"); NAMESPACE_BEGIN(pybind11) + +class array; // Forward declaration + NAMESPACE_BEGIN(detail) template struct npy_format_descriptor; @@ -232,6 +235,78 @@ template using is_pod_struct = all_of< satisfies_none_of >; +template size_t byte_offset_unsafe(const Strides &) { return 0; } +template +size_t byte_offset_unsafe(const Strides &strides, size_t i, Ix... index) { + return i * strides[Dim] + byte_offset_unsafe(strides, index...); +} + +/** Proxy class providing unsafe, unchecked const access to array data. This is constructed through + * the `unchecked()` method of `array` or the `unchecked()` method of `array_t`. + */ +template +class unchecked_reference { +protected: + const unsigned char *data_; + // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to + // make large performance gains on big, nested loops. + std::array shape_, strides_; + + friend class pybind11::array; + unchecked_reference(const void *data, const size_t *shape, const size_t *strides) + : data_{reinterpret_cast(data)} { + for (size_t i = 0; i < Dims; i++) { + shape_[i] = shape[i]; + strides_[i] = strides[i]; + } + } + +public: + /** Unchecked const reference access to data at the given indices. Omiting trailing indices + * is equivalent to specifying them as 0. + */ + template const T& operator()(Ix... index) const { + static_assert(sizeof...(Ix) <= Dims, "Invalid number of indices for unchecked array reference"); + return *reinterpret_cast(data_ + byte_offset_unsafe(strides_, size_t{index}...)); + } + /** Unchecked const reference access to data; this operator only participates if the reference + * is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`. + */ + template > + const T &operator[](size_t index) const { return operator()(index); } + + /// Returns the shape (i.e. size) of dimension `dim` + size_t shape(size_t dim) const { return shape_[dim]; } + + /// Returns the number of dimensions of the array + constexpr static size_t ndim() { return Dims; } +}; + +template +class unchecked_mutable_reference : public unchecked_reference { + friend class pybind11::array; + using ConstBase = unchecked_reference; + using ConstBase::ConstBase; +public: + /// Mutable, unchecked access to data at the given indices. + template T& operator()(Ix... index) { + static_assert(sizeof...(Ix) == Dims, "Invalid number of indices for unchecked array reference"); + return const_cast(ConstBase::operator()(index...)); + } + /** Mutable, unchecked access data at the given index; this operator only participates if the + * reference is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`. + */ + template > + T &operator[](size_t index) { return operator()(index); } +}; + +template +struct type_caster> { + static_assert(Dim == (size_t) -1 /* always fail */, "unchecked array proxy object is not castable"); +}; +template +struct type_caster> : type_caster> {}; + NAMESPACE_END(detail) class dtype : public object { @@ -500,6 +575,31 @@ public: return offset_at(index...) / itemsize(); } + /** Returns a proxy object that provides access to the array's data without bounds or + * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with + * care: the array must not be destroyed or reshaped for the duration of the returned object, + * and the caller must take care not to access invalid dimensions or dimension indices. + */ + template detail::unchecked_mutable_reference mutable_unchecked() { + if (ndim() != Dims) + throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) + + "; expected " + std::to_string(Dims)); + return detail::unchecked_mutable_reference(mutable_data(), shape(), strides()); + } + + /** Returns a proxy object that provides const access to the array's data without bounds or + * dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the + * underlying array have the `writable` flag. Use with care: the array must not be destroyed or + * reshaped for the duration of the returned object, and the caller must take care not to access + * invalid dimensions or dimension indices. + */ + template detail::unchecked_reference unchecked() const { + if (ndim() != Dims) + throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) + + "; expected " + std::to_string(Dims)); + return detail::unchecked_reference(data(), shape(), strides()); + } + /// Return a new view with all of the dimensions of length 1 removed array squeeze() { auto& api = detail::npy_api::get(); @@ -525,15 +625,9 @@ protected: template size_t byte_offset(Ix... index) const { check_dimensions(index...); - return byte_offset_unsafe(index...); + return detail::byte_offset_unsafe(strides(), size_t{index}...); } - template size_t byte_offset_unsafe(size_t i, Ix... index) const { - return i * strides()[dim] + byte_offset_unsafe(index...); - } - - template size_t byte_offset_unsafe() const { return 0; } - void check_writeable() const { if (!writeable()) throw std::domain_error("array is not writeable"); @@ -637,6 +731,25 @@ public: return *(static_cast(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize()); } + /** Returns a proxy object that provides access to the array's data without bounds or + * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with + * care: the array must not be destroyed or reshaped for the duration of the returned object, + * and the caller must take care not to access invalid dimensions or dimension indices. + */ + template detail::unchecked_mutable_reference mutable_unchecked() { + return array::mutable_unchecked(); + } + + /** Returns a proxy object that provides const access to the array's data without bounds or + * dimensionality checking. Unlike `unchecked()`, this does not require that the underlying + * array have the `writable` flag. Use with care: the array must not be destroyed or reshaped + * for the duration of the returned object, and the caller must take care not to access invalid + * dimensions or dimension indices. + */ + template detail::unchecked_reference unchecked() const { + return array::unchecked(); + } + /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert /// it). In case of an error, nullptr is returned and the Python error is cleared. static array_t ensure(handle h) { diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index 88996443d..461c9c004 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -184,4 +184,36 @@ test_initializer numpy_array([](py::module &m) { sm.def("issue685", [](std::string) { return "string"; }); sm.def("issue685", [](py::array) { return "array"; }); sm.def("issue685", [](py::object) { return "other"; }); + + sm.def("proxy_add2", [](py::array_t a, double v) { + auto r = a.mutable_unchecked<2>(); + for (size_t i = 0; i < r.shape(0); i++) + for (size_t j = 0; j < r.shape(1); j++) + r(i, j) += v; + }, py::arg().noconvert(), py::arg()); + sm.def("proxy_init3", [](double start) { + py::array_t a({ 3, 3, 3 }); + auto r = a.mutable_unchecked<3>(); + for (size_t i = 0; i < r.shape(0); i++) + for (size_t j = 0; j < r.shape(1); j++) + for (size_t k = 0; k < r.shape(2); k++) + r(i, j, k) = start++; + return a; + }); + sm.def("proxy_init3F", [](double start) { + py::array_t a({ 3, 3, 3 }); + auto r = a.mutable_unchecked<3>(); + for (size_t k = 0; k < r.shape(2); k++) + for (size_t j = 0; j < r.shape(1); j++) + for (size_t i = 0; i < r.shape(0); i++) + r(i, j, k) = start++; + return a; + }); + sm.def("proxy_squared_L2_norm", [](py::array_t a) { + auto r = a.unchecked<1>(); + double sumsq = 0; + for (size_t i = 0; i < r.shape(0); i++) + sumsq += r[i] * r(i); // Either notation works for a 1D array + return sumsq; + }); }); diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index 7109ff386..9081f8cd0 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -339,3 +339,23 @@ def test_greedy_string_overload(): # issue 685 assert issue685("abc") == "string" assert issue685(np.array([97, 98, 99], dtype='b')) == "array" assert issue685(123) == "other" + + +def test_array_unchecked(msg): + from pybind11_tests.array import proxy_add2, proxy_init3F, proxy_init3, proxy_squared_L2_norm + + z1 = np.array([[1, 2], [3, 4]], dtype='float64') + proxy_add2(z1, 10) + assert np.all(z1 == [[11, 12], [13, 14]]) + + with pytest.raises(ValueError) as excinfo: + proxy_add2(np.array([1., 2, 3]), 5.0) + assert msg(excinfo.value) == "array has incorrect number of dimensions: 1; expected 2" + + expect_c = np.ndarray(shape=(3, 3, 3), buffer=np.array(range(3, 30)), dtype='int') + assert np.all(proxy_init3(3.0) == expect_c) + expect_f = np.transpose(expect_c) + assert np.all(proxy_init3F(3.0) == expect_f) + + assert proxy_squared_L2_norm(np.array(range(6))) == 55 + assert proxy_squared_L2_norm(np.array(range(6), dtype="float64")) == 55