mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-11 08:03:55 +00:00
array-unchecked: add runtime dimension support and array-compatible methods
The extends the previous unchecked support with the ability to determine the dimensions at runtime. This incurs a small performance hit when used (versus the compile-time fixed alternative), but is still considerably faster than the full checks on every call that happen with `.at()`/`.mutable_at()`.
This commit is contained in:
parent
423a49b8be
commit
773339f131
@ -340,12 +340,39 @@ To obtain the proxy from an ``array`` object, you must specify both the data
|
||||
type and number of dimensions as template arguments, such as ``auto r =
|
||||
myarray.mutable_unchecked<float, 2>()``.
|
||||
|
||||
If the number of dimensions is not known at compile time, you can omit the
|
||||
dimensions template parameter (i.e. calling ``arr_t.unchecked()`` or
|
||||
``arr.unchecked<T>()``. This will give you a proxy object that works in the
|
||||
same way, but results in less optimizable code and thus a small efficiency
|
||||
loss in tight loops.
|
||||
|
||||
Note that the returned proxy object directly references the array's data, and
|
||||
only reads its shape, strides, and writeable flag when constructed. You must
|
||||
take care to ensure that the referenced array is not destroyed or reshaped for
|
||||
the duration of the returned object, typically by limiting the scope of the
|
||||
returned instance.
|
||||
|
||||
The returned proxy object supports some of the same methods as ``py::array`` so
|
||||
that it can be used as a drop-in replacement for some existing, index-checked
|
||||
uses of ``py::array``:
|
||||
|
||||
- ``r.ndim()`` returns the number of dimensions
|
||||
|
||||
- ``r.data(1, 2, ...)`` and ``r.mutable_data(1, 2, ...)``` returns a pointer to
|
||||
the ``const T`` or ``T`` data, respectively, at the given indices. The
|
||||
latter is only available to proxies obtained via ``a.mutable_unchecked()``.
|
||||
|
||||
- ``itemsize()`` returns the size of an item in bytes, i.e. ``sizeof(T)``.
|
||||
|
||||
- ``ndim()`` returns the number of dimensions.
|
||||
|
||||
- ``shape(n)`` returns the size of dimension ``n``
|
||||
|
||||
- ``size()`` returns the total number of elements (i.e. the product of the shapes).
|
||||
|
||||
- ``nbytes()`` returns the number of bytes used by the referenced elements
|
||||
(i.e. ``itemsize()`` times ``size()``).
|
||||
|
||||
.. seealso::
|
||||
|
||||
The file :file:`tests/test_numpy_array.cpp` contains additional examples
|
||||
|
@ -242,67 +242,107 @@ size_t byte_offset_unsafe(const Strides &strides, size_t i, Ix... index) {
|
||||
}
|
||||
|
||||
/** Proxy class providing unsafe, unchecked const access to array data. This is constructed through
|
||||
* the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`.
|
||||
* the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`. `Dims`
|
||||
* will be -1 for dimensions determined at runtime.
|
||||
*/
|
||||
template <typename T, size_t Dims>
|
||||
template <typename T, ssize_t Dims>
|
||||
class unchecked_reference {
|
||||
protected:
|
||||
static constexpr bool Dynamic = Dims < 0;
|
||||
const unsigned char *data_;
|
||||
// Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
|
||||
// make large performance gains on big, nested loops.
|
||||
std::array<size_t, Dims> shape_, strides_;
|
||||
// make large performance gains on big, nested loops, but requires compile-time dimensions
|
||||
conditional_t<Dynamic, const size_t *, std::array<size_t, (size_t) Dims>>
|
||||
shape_, strides_;
|
||||
const size_t dims_;
|
||||
|
||||
friend class pybind11::array;
|
||||
unchecked_reference(const void *data, const size_t *shape, const size_t *strides)
|
||||
: data_{reinterpret_cast<const unsigned char *>(data)} {
|
||||
for (size_t i = 0; i < Dims; i++) {
|
||||
// Constructor for compile-time dimensions:
|
||||
template <bool Dyn = Dynamic>
|
||||
unchecked_reference(const void *data, const size_t *shape, const size_t *strides, enable_if_t<!Dyn, size_t>)
|
||||
: data_{reinterpret_cast<const unsigned char *>(data)}, dims_{Dims} {
|
||||
for (size_t i = 0; i < dims_; i++) {
|
||||
shape_[i] = shape[i];
|
||||
strides_[i] = strides[i];
|
||||
}
|
||||
}
|
||||
// Constructor for runtime dimensions:
|
||||
template <bool Dyn = Dynamic>
|
||||
unchecked_reference(const void *data, const size_t *shape, const size_t *strides, enable_if_t<Dyn, size_t> dims)
|
||||
: data_{reinterpret_cast<const unsigned char *>(data)}, shape_{shape}, strides_{strides}, dims_{dims} {}
|
||||
|
||||
public:
|
||||
/** Unchecked const reference access to data at the given indices. Omiting trailing indices
|
||||
* is equivalent to specifying them as 0.
|
||||
/** Unchecked const reference access to data at the given indices. For a compile-time known
|
||||
* number of dimensions, this requires the correct number of arguments; for run-time
|
||||
* dimensionality, this is not checked (and so is up to the caller to use safely).
|
||||
*/
|
||||
template <typename... Ix> const T& operator()(Ix... index) const {
|
||||
static_assert(sizeof...(Ix) <= Dims, "Invalid number of indices for unchecked array reference");
|
||||
return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, size_t{index}...));
|
||||
template <typename... Ix> const T &operator()(Ix... index) const {
|
||||
static_assert(sizeof...(Ix) == Dims || Dynamic,
|
||||
"Invalid number of indices for unchecked array reference");
|
||||
return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, size_t(index)...));
|
||||
}
|
||||
/** Unchecked const reference access to data; this operator only participates if the reference
|
||||
* is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
|
||||
*/
|
||||
template <size_t D = Dims, typename = enable_if_t<D == 1>>
|
||||
template <size_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
|
||||
const T &operator[](size_t index) const { return operator()(index); }
|
||||
|
||||
/// Pointer access to the data at the given indices.
|
||||
template <typename... Ix> const T *data(Ix... ix) const { return &operator()(size_t(ix)...); }
|
||||
|
||||
/// Returns the item size, i.e. sizeof(T)
|
||||
constexpr static size_t itemsize() { return sizeof(T); }
|
||||
|
||||
/// Returns the shape (i.e. size) of dimension `dim`
|
||||
size_t shape(size_t dim) const { return shape_[dim]; }
|
||||
|
||||
/// Returns the number of dimensions of the array
|
||||
constexpr static size_t ndim() { return Dims; }
|
||||
size_t ndim() const { return dims_; }
|
||||
|
||||
/// Returns the total number of elements in the referenced array, i.e. the product of the shapes
|
||||
template <bool Dyn = Dynamic>
|
||||
enable_if_t<!Dyn, size_t> size() const {
|
||||
return std::accumulate(shape_.begin(), shape_.end(), (size_t) 1, std::multiplies<size_t>());
|
||||
}
|
||||
template <bool Dyn = Dynamic>
|
||||
enable_if_t<Dyn, size_t> size() const {
|
||||
return std::accumulate(shape_, shape_ + ndim(), (size_t) 1, std::multiplies<size_t>());
|
||||
}
|
||||
|
||||
/// Returns the total number of bytes used by the referenced data. Note that the actual span in
|
||||
/// memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice).
|
||||
size_t nbytes() const {
|
||||
return size() * itemsize();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, size_t Dims>
|
||||
template <typename T, ssize_t Dims>
|
||||
class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
|
||||
friend class pybind11::array;
|
||||
using ConstBase = unchecked_reference<T, Dims>;
|
||||
using ConstBase::ConstBase;
|
||||
using ConstBase::Dynamic;
|
||||
public:
|
||||
/// Mutable, unchecked access to data at the given indices.
|
||||
template <typename... Ix> T& operator()(Ix... index) {
|
||||
static_assert(sizeof...(Ix) == Dims, "Invalid number of indices for unchecked array reference");
|
||||
static_assert(sizeof...(Ix) == Dims || Dynamic,
|
||||
"Invalid number of indices for unchecked array reference");
|
||||
return const_cast<T &>(ConstBase::operator()(index...));
|
||||
}
|
||||
/** Mutable, unchecked access data at the given index; this operator only participates if the
|
||||
* reference is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
|
||||
* reference is to a 1-dimensional array (or has runtime dimensions). When present, this is
|
||||
* exactly equivalent to `obj(index)`.
|
||||
*/
|
||||
template <size_t D = Dims, typename = enable_if_t<D == 1>>
|
||||
template <size_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
|
||||
T &operator[](size_t index) { return operator()(index); }
|
||||
|
||||
/// Mutable pointer access to the data at the given indices.
|
||||
template <typename... Ix> T *mutable_data(Ix... ix) { return &operator()(size_t(ix)...); }
|
||||
};
|
||||
|
||||
template <typename T, size_t Dim>
|
||||
struct type_caster<unchecked_reference<T, Dim>> {
|
||||
static_assert(Dim == (size_t) -1 /* always fail */, "unchecked array proxy object is not castable");
|
||||
static_assert(Dim == 0 && Dim > 0 /* always fail */, "unchecked array proxy object is not castable");
|
||||
};
|
||||
template <typename T, size_t Dim>
|
||||
struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {};
|
||||
@ -580,11 +620,11 @@ public:
|
||||
* care: the array must not be destroyed or reshaped for the duration of the returned object,
|
||||
* and the caller must take care not to access invalid dimensions or dimension indices.
|
||||
*/
|
||||
template <typename T, size_t Dims> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
|
||||
if (ndim() != Dims)
|
||||
template <typename T, ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
|
||||
if (Dims >= 0 && ndim() != (size_t) Dims)
|
||||
throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
|
||||
"; expected " + std::to_string(Dims));
|
||||
return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides());
|
||||
return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides(), ndim());
|
||||
}
|
||||
|
||||
/** Returns a proxy object that provides const access to the array's data without bounds or
|
||||
@ -593,11 +633,11 @@ public:
|
||||
* reshaped for the duration of the returned object, and the caller must take care not to access
|
||||
* invalid dimensions or dimension indices.
|
||||
*/
|
||||
template <typename T, size_t Dims> detail::unchecked_reference<T, Dims> unchecked() const {
|
||||
if (ndim() != Dims)
|
||||
template <typename T, ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const {
|
||||
if (Dims >= 0 && ndim() != (size_t) Dims)
|
||||
throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
|
||||
"; expected " + std::to_string(Dims));
|
||||
return detail::unchecked_reference<T, Dims>(data(), shape(), strides());
|
||||
return detail::unchecked_reference<T, Dims>(data(), shape(), strides(), ndim());
|
||||
}
|
||||
|
||||
/// Return a new view with all of the dimensions of length 1 removed
|
||||
@ -625,7 +665,7 @@ protected:
|
||||
|
||||
template<typename... Ix> size_t byte_offset(Ix... index) const {
|
||||
check_dimensions(index...);
|
||||
return detail::byte_offset_unsafe(strides(), size_t{index}...);
|
||||
return detail::byte_offset_unsafe(strides(), size_t(index)...);
|
||||
}
|
||||
|
||||
void check_writeable() const {
|
||||
@ -736,7 +776,7 @@ public:
|
||||
* care: the array must not be destroyed or reshaped for the duration of the returned object,
|
||||
* and the caller must take care not to access invalid dimensions or dimension indices.
|
||||
*/
|
||||
template <size_t Dims> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
|
||||
template <ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
|
||||
return array::mutable_unchecked<T, Dims>();
|
||||
}
|
||||
|
||||
@ -746,7 +786,7 @@ public:
|
||||
* for the duration of the returned object, and the caller must take care not to access invalid
|
||||
* dimensions or dimension indices.
|
||||
*/
|
||||
template <size_t Dims> detail::unchecked_reference<T, Dims> unchecked() const {
|
||||
template <ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const {
|
||||
return array::unchecked<T, Dims>();
|
||||
}
|
||||
|
||||
|
@ -68,6 +68,21 @@ template<typename... Ix> arr_t& mutate_at_t(arr_t& a, Ix... idx) { a.mutable_at(
|
||||
sm.def(#name, [](type a, int i, int j) { return name(a, i, j); }); \
|
||||
sm.def(#name, [](type a, int i, int j, int k) { return name(a, i, j, k); });
|
||||
|
||||
template <typename T, typename T2> py::handle auxiliaries(T &&r, T2 &&r2) {
|
||||
if (r.ndim() != 2) throw std::domain_error("error: ndim != 2");
|
||||
py::list l;
|
||||
l.append(*r.data(0, 0));
|
||||
l.append(*r2.mutable_data(0, 0));
|
||||
l.append(r.data(0, 1) == r2.mutable_data(0, 1));
|
||||
l.append(r.ndim());
|
||||
l.append(r.itemsize());
|
||||
l.append(r.shape(0));
|
||||
l.append(r.shape(1));
|
||||
l.append(r.size());
|
||||
l.append(r.nbytes());
|
||||
return l.release();
|
||||
}
|
||||
|
||||
test_initializer numpy_array([](py::module &m) {
|
||||
auto sm = m.def_submodule("array");
|
||||
|
||||
@ -191,6 +206,7 @@ test_initializer numpy_array([](py::module &m) {
|
||||
for (size_t j = 0; j < r.shape(1); j++)
|
||||
r(i, j) += v;
|
||||
}, py::arg().noconvert(), py::arg());
|
||||
|
||||
sm.def("proxy_init3", [](double start) {
|
||||
py::array_t<double, py::array::c_style> a({ 3, 3, 3 });
|
||||
auto r = a.mutable_unchecked<3>();
|
||||
@ -216,4 +232,36 @@ test_initializer numpy_array([](py::module &m) {
|
||||
sumsq += r[i] * r(i); // Either notation works for a 1D array
|
||||
return sumsq;
|
||||
});
|
||||
|
||||
sm.def("proxy_auxiliaries2", [](py::array_t<double> a) {
|
||||
auto r = a.unchecked<2>();
|
||||
auto r2 = a.mutable_unchecked<2>();
|
||||
return auxiliaries(r, r2);
|
||||
});
|
||||
|
||||
// Same as the above, but without a compile-time dimensions specification:
|
||||
sm.def("proxy_add2_dyn", [](py::array_t<double> a, double v) {
|
||||
auto r = a.mutable_unchecked();
|
||||
if (r.ndim() != 2) throw std::domain_error("error: ndim != 2");
|
||||
for (size_t i = 0; i < r.shape(0); i++)
|
||||
for (size_t j = 0; j < r.shape(1); j++)
|
||||
r(i, j) += v;
|
||||
}, py::arg().noconvert(), py::arg());
|
||||
sm.def("proxy_init3_dyn", [](double start) {
|
||||
py::array_t<double, py::array::c_style> a({ 3, 3, 3 });
|
||||
auto r = a.mutable_unchecked();
|
||||
if (r.ndim() != 3) throw std::domain_error("error: ndim != 3");
|
||||
for (size_t i = 0; i < r.shape(0); i++)
|
||||
for (size_t j = 0; j < r.shape(1); j++)
|
||||
for (size_t k = 0; k < r.shape(2); k++)
|
||||
r(i, j, k) = start++;
|
||||
return a;
|
||||
});
|
||||
sm.def("proxy_auxiliaries2_dyn", [](py::array_t<double> a) {
|
||||
return auxiliaries(a.unchecked(), a.mutable_unchecked());
|
||||
});
|
||||
|
||||
sm.def("array_auxiliaries2", [](py::array_t<double> a) {
|
||||
return auxiliaries(a, a);
|
||||
});
|
||||
});
|
||||
|
@ -341,8 +341,9 @@ def test_greedy_string_overload(): # issue 685
|
||||
assert issue685(123) == "other"
|
||||
|
||||
|
||||
def test_array_unchecked(msg):
|
||||
from pybind11_tests.array import proxy_add2, proxy_init3F, proxy_init3, proxy_squared_L2_norm
|
||||
def test_array_unchecked_fixed_dims(msg):
|
||||
from pybind11_tests.array import (proxy_add2, proxy_init3F, proxy_init3, proxy_squared_L2_norm,
|
||||
proxy_auxiliaries2, array_auxiliaries2)
|
||||
|
||||
z1 = np.array([[1, 2], [3, 4]], dtype='float64')
|
||||
proxy_add2(z1, 10)
|
||||
@ -359,3 +360,20 @@ def test_array_unchecked(msg):
|
||||
|
||||
assert proxy_squared_L2_norm(np.array(range(6))) == 55
|
||||
assert proxy_squared_L2_norm(np.array(range(6), dtype="float64")) == 55
|
||||
|
||||
assert proxy_auxiliaries2(z1) == [11, 11, True, 2, 8, 2, 2, 4, 32]
|
||||
assert proxy_auxiliaries2(z1) == array_auxiliaries2(z1)
|
||||
|
||||
|
||||
def test_array_unchecked_dyn_dims(msg):
|
||||
from pybind11_tests.array import (proxy_add2_dyn, proxy_init3_dyn, proxy_auxiliaries2_dyn,
|
||||
array_auxiliaries2)
|
||||
z1 = np.array([[1, 2], [3, 4]], dtype='float64')
|
||||
proxy_add2_dyn(z1, 10)
|
||||
assert np.all(z1 == [[11, 12], [13, 14]])
|
||||
|
||||
expect_c = np.ndarray(shape=(3, 3, 3), buffer=np.array(range(3, 30)), dtype='int')
|
||||
assert np.all(proxy_init3_dyn(3.0) == expect_c)
|
||||
|
||||
assert proxy_auxiliaries2_dyn(z1) == [11, 11, True, 2, 8, 2, 2, 4, 32]
|
||||
assert proxy_auxiliaries2_dyn(z1) == array_auxiliaries2(z1)
|
||||
|
Loading…
Reference in New Issue
Block a user