mirror of
https://github.com/pybind/pybind11.git
synced 2025-01-19 01:15:52 +00:00
array: add unchecked access via proxy object
This adds bounds-unchecked access to arrays through a `a.unchecked<Type, Dimensions>()` method. (For `array_t<T>`, the `Type` template parameter is omitted). The mutable version (which requires the array have the `writeable` flag) is available as `a.mutable_unchecked<...>()`. Specifying the Dimensions as a template parameter allows storage of an std::array; having the strides and sizes stored that way (as opposed to storing a copy of the array's strides/shape pointers) allows the compiler to make significant optimizations of the shape() method that it can't make with a pointer; testing with nested loops of the form: for (size_t i0 = 0; i0 < r.shape(0); i0++) for (size_t i1 = 0; i1 < r.shape(1); i1++) ... r(i0, i1, ...) += 1; over a 10 million element array gives around a 25% speedup (versus using a pointer) for the 1D case, 33% for 2D, and runs more than twice as fast with a 5D array.
This commit is contained in:
parent
0d765f4a7c
commit
423a49b8be
@ -305,3 +305,48 @@ simply using ``vectorize``).
|
||||
|
||||
The file :file:`tests/test_numpy_vectorize.cpp` contains a complete
|
||||
example that demonstrates using :func:`vectorize` in more detail.
|
||||
|
||||
Direct access
|
||||
=============
|
||||
|
||||
For performance reasons, particularly when dealing with very large arrays, it
|
||||
is often desirable to directly access array elements without internal checking
|
||||
of dimensions and bounds on every access when indices are known to be already
|
||||
valid. To avoid such checks, the ``array`` class and ``array_t<T>`` template
|
||||
class offer an unchecked proxy object that can be used for this unchecked
|
||||
access through the ``unchecked<N>`` and ``mutable_unchecked<N>`` methods,
|
||||
where ``N`` gives the required dimensionality of the array:
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
m.def("sum_3d", [](py::array_t<double> x) {
|
||||
auto r = x.unchecked<3>(); // x must have ndim = 3; can be non-writeable
|
||||
double sum = 0;
|
||||
for (size_t i = 0; i < r.shape(0); i++)
|
||||
for (size_t j = 0; j < r.shape(1); j++)
|
||||
for (size_t k = 0; k < r.shape(2); k++)
|
||||
sum += r(i, j, k);
|
||||
return sum;
|
||||
});
|
||||
m.def("increment_3d", [](py::array_t<double> x) {
|
||||
auto r = x.mutable_unchecked<3>(); // Will throw if ndim != 3 or flags.writeable is false
|
||||
for (size_t i = 0; i < r.shape(0); i++)
|
||||
for (size_t j = 0; j < r.shape(1); j++)
|
||||
for (size_t k = 0; k < r.shape(2); k++)
|
||||
r(i, j, k) += 1.0;
|
||||
}, py::arg().noconvert());
|
||||
|
||||
To obtain the proxy from an ``array`` object, you must specify both the data
|
||||
type and number of dimensions as template arguments, such as ``auto r =
|
||||
myarray.mutable_unchecked<float, 2>()``.
|
||||
|
||||
Note that the returned proxy object directly references the array's data, and
|
||||
only reads its shape, strides, and writeable flag when constructed. You must
|
||||
take care to ensure that the referenced array is not destroyed or reshaped for
|
||||
the duration of the returned object, typically by limiting the scope of the
|
||||
returned instance.
|
||||
|
||||
.. seealso::
|
||||
|
||||
The file :file:`tests/test_numpy_array.cpp` contains additional examples
|
||||
demonstrating the use of this feature.
|
||||
|
@ -35,6 +35,9 @@
|
||||
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
|
||||
|
||||
NAMESPACE_BEGIN(pybind11)
|
||||
|
||||
class array; // Forward declaration
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
template <typename type, typename SFINAE = void> struct npy_format_descriptor;
|
||||
|
||||
@ -232,6 +235,78 @@ template <typename T> using is_pod_struct = all_of<
|
||||
satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
|
||||
>;
|
||||
|
||||
template <size_t Dim = 0, typename Strides> size_t byte_offset_unsafe(const Strides &) { return 0; }
|
||||
template <size_t Dim = 0, typename Strides, typename... Ix>
|
||||
size_t byte_offset_unsafe(const Strides &strides, size_t i, Ix... index) {
|
||||
return i * strides[Dim] + byte_offset_unsafe<Dim + 1>(strides, index...);
|
||||
}
|
||||
|
||||
/** Proxy class providing unsafe, unchecked const access to array data. This is constructed through
|
||||
* the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`.
|
||||
*/
|
||||
template <typename T, size_t Dims>
|
||||
class unchecked_reference {
|
||||
protected:
|
||||
const unsigned char *data_;
|
||||
// Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
|
||||
// make large performance gains on big, nested loops.
|
||||
std::array<size_t, Dims> shape_, strides_;
|
||||
|
||||
friend class pybind11::array;
|
||||
unchecked_reference(const void *data, const size_t *shape, const size_t *strides)
|
||||
: data_{reinterpret_cast<const unsigned char *>(data)} {
|
||||
for (size_t i = 0; i < Dims; i++) {
|
||||
shape_[i] = shape[i];
|
||||
strides_[i] = strides[i];
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
/** Unchecked const reference access to data at the given indices. Omiting trailing indices
|
||||
* is equivalent to specifying them as 0.
|
||||
*/
|
||||
template <typename... Ix> const T& operator()(Ix... index) const {
|
||||
static_assert(sizeof...(Ix) <= Dims, "Invalid number of indices for unchecked array reference");
|
||||
return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, size_t{index}...));
|
||||
}
|
||||
/** Unchecked const reference access to data; this operator only participates if the reference
|
||||
* is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
|
||||
*/
|
||||
template <size_t D = Dims, typename = enable_if_t<D == 1>>
|
||||
const T &operator[](size_t index) const { return operator()(index); }
|
||||
|
||||
/// Returns the shape (i.e. size) of dimension `dim`
|
||||
size_t shape(size_t dim) const { return shape_[dim]; }
|
||||
|
||||
/// Returns the number of dimensions of the array
|
||||
constexpr static size_t ndim() { return Dims; }
|
||||
};
|
||||
|
||||
template <typename T, size_t Dims>
|
||||
class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
|
||||
friend class pybind11::array;
|
||||
using ConstBase = unchecked_reference<T, Dims>;
|
||||
using ConstBase::ConstBase;
|
||||
public:
|
||||
/// Mutable, unchecked access to data at the given indices.
|
||||
template <typename... Ix> T& operator()(Ix... index) {
|
||||
static_assert(sizeof...(Ix) == Dims, "Invalid number of indices for unchecked array reference");
|
||||
return const_cast<T &>(ConstBase::operator()(index...));
|
||||
}
|
||||
/** Mutable, unchecked access data at the given index; this operator only participates if the
|
||||
* reference is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
|
||||
*/
|
||||
template <size_t D = Dims, typename = enable_if_t<D == 1>>
|
||||
T &operator[](size_t index) { return operator()(index); }
|
||||
};
|
||||
|
||||
template <typename T, size_t Dim>
|
||||
struct type_caster<unchecked_reference<T, Dim>> {
|
||||
static_assert(Dim == (size_t) -1 /* always fail */, "unchecked array proxy object is not castable");
|
||||
};
|
||||
template <typename T, size_t Dim>
|
||||
struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
class dtype : public object {
|
||||
@ -500,6 +575,31 @@ public:
|
||||
return offset_at(index...) / itemsize();
|
||||
}
|
||||
|
||||
/** Returns a proxy object that provides access to the array's data without bounds or
|
||||
* dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
|
||||
* care: the array must not be destroyed or reshaped for the duration of the returned object,
|
||||
* and the caller must take care not to access invalid dimensions or dimension indices.
|
||||
*/
|
||||
template <typename T, size_t Dims> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
|
||||
if (ndim() != Dims)
|
||||
throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
|
||||
"; expected " + std::to_string(Dims));
|
||||
return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides());
|
||||
}
|
||||
|
||||
/** Returns a proxy object that provides const access to the array's data without bounds or
|
||||
* dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the
|
||||
* underlying array have the `writable` flag. Use with care: the array must not be destroyed or
|
||||
* reshaped for the duration of the returned object, and the caller must take care not to access
|
||||
* invalid dimensions or dimension indices.
|
||||
*/
|
||||
template <typename T, size_t Dims> detail::unchecked_reference<T, Dims> unchecked() const {
|
||||
if (ndim() != Dims)
|
||||
throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
|
||||
"; expected " + std::to_string(Dims));
|
||||
return detail::unchecked_reference<T, Dims>(data(), shape(), strides());
|
||||
}
|
||||
|
||||
/// Return a new view with all of the dimensions of length 1 removed
|
||||
array squeeze() {
|
||||
auto& api = detail::npy_api::get();
|
||||
@ -525,15 +625,9 @@ protected:
|
||||
|
||||
template<typename... Ix> size_t byte_offset(Ix... index) const {
|
||||
check_dimensions(index...);
|
||||
return byte_offset_unsafe(index...);
|
||||
return detail::byte_offset_unsafe(strides(), size_t{index}...);
|
||||
}
|
||||
|
||||
template<size_t dim = 0, typename... Ix> size_t byte_offset_unsafe(size_t i, Ix... index) const {
|
||||
return i * strides()[dim] + byte_offset_unsafe<dim + 1>(index...);
|
||||
}
|
||||
|
||||
template<size_t dim = 0> size_t byte_offset_unsafe() const { return 0; }
|
||||
|
||||
void check_writeable() const {
|
||||
if (!writeable())
|
||||
throw std::domain_error("array is not writeable");
|
||||
@ -637,6 +731,25 @@ public:
|
||||
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
|
||||
}
|
||||
|
||||
/** Returns a proxy object that provides access to the array's data without bounds or
|
||||
* dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
|
||||
* care: the array must not be destroyed or reshaped for the duration of the returned object,
|
||||
* and the caller must take care not to access invalid dimensions or dimension indices.
|
||||
*/
|
||||
template <size_t Dims> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
|
||||
return array::mutable_unchecked<T, Dims>();
|
||||
}
|
||||
|
||||
/** Returns a proxy object that provides const access to the array's data without bounds or
|
||||
* dimensionality checking. Unlike `unchecked()`, this does not require that the underlying
|
||||
* array have the `writable` flag. Use with care: the array must not be destroyed or reshaped
|
||||
* for the duration of the returned object, and the caller must take care not to access invalid
|
||||
* dimensions or dimension indices.
|
||||
*/
|
||||
template <size_t Dims> detail::unchecked_reference<T, Dims> unchecked() const {
|
||||
return array::unchecked<T, Dims>();
|
||||
}
|
||||
|
||||
/// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
|
||||
/// it). In case of an error, nullptr is returned and the Python error is cleared.
|
||||
static array_t ensure(handle h) {
|
||||
|
@ -184,4 +184,36 @@ test_initializer numpy_array([](py::module &m) {
|
||||
sm.def("issue685", [](std::string) { return "string"; });
|
||||
sm.def("issue685", [](py::array) { return "array"; });
|
||||
sm.def("issue685", [](py::object) { return "other"; });
|
||||
|
||||
sm.def("proxy_add2", [](py::array_t<double> a, double v) {
|
||||
auto r = a.mutable_unchecked<2>();
|
||||
for (size_t i = 0; i < r.shape(0); i++)
|
||||
for (size_t j = 0; j < r.shape(1); j++)
|
||||
r(i, j) += v;
|
||||
}, py::arg().noconvert(), py::arg());
|
||||
sm.def("proxy_init3", [](double start) {
|
||||
py::array_t<double, py::array::c_style> a({ 3, 3, 3 });
|
||||
auto r = a.mutable_unchecked<3>();
|
||||
for (size_t i = 0; i < r.shape(0); i++)
|
||||
for (size_t j = 0; j < r.shape(1); j++)
|
||||
for (size_t k = 0; k < r.shape(2); k++)
|
||||
r(i, j, k) = start++;
|
||||
return a;
|
||||
});
|
||||
sm.def("proxy_init3F", [](double start) {
|
||||
py::array_t<double, py::array::f_style> a({ 3, 3, 3 });
|
||||
auto r = a.mutable_unchecked<3>();
|
||||
for (size_t k = 0; k < r.shape(2); k++)
|
||||
for (size_t j = 0; j < r.shape(1); j++)
|
||||
for (size_t i = 0; i < r.shape(0); i++)
|
||||
r(i, j, k) = start++;
|
||||
return a;
|
||||
});
|
||||
sm.def("proxy_squared_L2_norm", [](py::array_t<double> a) {
|
||||
auto r = a.unchecked<1>();
|
||||
double sumsq = 0;
|
||||
for (size_t i = 0; i < r.shape(0); i++)
|
||||
sumsq += r[i] * r(i); // Either notation works for a 1D array
|
||||
return sumsq;
|
||||
});
|
||||
});
|
||||
|
@ -339,3 +339,23 @@ def test_greedy_string_overload(): # issue 685
|
||||
assert issue685("abc") == "string"
|
||||
assert issue685(np.array([97, 98, 99], dtype='b')) == "array"
|
||||
assert issue685(123) == "other"
|
||||
|
||||
|
||||
def test_array_unchecked(msg):
|
||||
from pybind11_tests.array import proxy_add2, proxy_init3F, proxy_init3, proxy_squared_L2_norm
|
||||
|
||||
z1 = np.array([[1, 2], [3, 4]], dtype='float64')
|
||||
proxy_add2(z1, 10)
|
||||
assert np.all(z1 == [[11, 12], [13, 14]])
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
proxy_add2(np.array([1., 2, 3]), 5.0)
|
||||
assert msg(excinfo.value) == "array has incorrect number of dimensions: 1; expected 2"
|
||||
|
||||
expect_c = np.ndarray(shape=(3, 3, 3), buffer=np.array(range(3, 30)), dtype='int')
|
||||
assert np.all(proxy_init3(3.0) == expect_c)
|
||||
expect_f = np.transpose(expect_c)
|
||||
assert np.all(proxy_init3F(3.0) == expect_f)
|
||||
|
||||
assert proxy_squared_L2_norm(np.array(range(6))) == 55
|
||||
assert proxy_squared_L2_norm(np.array(range(6), dtype="float64")) == 55
|
||||
|
Loading…
Reference in New Issue
Block a user