mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-22 05:05:11 +00:00
array: add unchecked access via proxy object
This adds bounds-unchecked access to arrays through a `a.unchecked<Type, Dimensions>()` method. (For `array_t<T>`, the `Type` template parameter is omitted). The mutable version (which requires the array have the `writeable` flag) is available as `a.mutable_unchecked<...>()`. Specifying the Dimensions as a template parameter allows storage of an std::array; having the strides and sizes stored that way (as opposed to storing a copy of the array's strides/shape pointers) allows the compiler to make significant optimizations of the shape() method that it can't make with a pointer; testing with nested loops of the form: for (size_t i0 = 0; i0 < r.shape(0); i0++) for (size_t i1 = 0; i1 < r.shape(1); i1++) ... r(i0, i1, ...) += 1; over a 10 million element array gives around a 25% speedup (versus using a pointer) for the 1D case, 33% for 2D, and runs more than twice as fast with a 5D array.
This commit is contained in:
parent
0d765f4a7c
commit
423a49b8be
@ -305,3 +305,48 @@ simply using ``vectorize``).
|
|||||||
|
|
||||||
The file :file:`tests/test_numpy_vectorize.cpp` contains a complete
|
The file :file:`tests/test_numpy_vectorize.cpp` contains a complete
|
||||||
example that demonstrates using :func:`vectorize` in more detail.
|
example that demonstrates using :func:`vectorize` in more detail.
|
||||||
|
|
||||||
|
Direct access
|
||||||
|
=============
|
||||||
|
|
||||||
|
For performance reasons, particularly when dealing with very large arrays, it
|
||||||
|
is often desirable to directly access array elements without internal checking
|
||||||
|
of dimensions and bounds on every access when indices are known to be already
|
||||||
|
valid. To avoid such checks, the ``array`` class and ``array_t<T>`` template
|
||||||
|
class offer an unchecked proxy object that can be used for this unchecked
|
||||||
|
access through the ``unchecked<N>`` and ``mutable_unchecked<N>`` methods,
|
||||||
|
where ``N`` gives the required dimensionality of the array:
|
||||||
|
|
||||||
|
.. code-block:: cpp
|
||||||
|
|
||||||
|
m.def("sum_3d", [](py::array_t<double> x) {
|
||||||
|
auto r = x.unchecked<3>(); // x must have ndim = 3; can be non-writeable
|
||||||
|
double sum = 0;
|
||||||
|
for (size_t i = 0; i < r.shape(0); i++)
|
||||||
|
for (size_t j = 0; j < r.shape(1); j++)
|
||||||
|
for (size_t k = 0; k < r.shape(2); k++)
|
||||||
|
sum += r(i, j, k);
|
||||||
|
return sum;
|
||||||
|
});
|
||||||
|
m.def("increment_3d", [](py::array_t<double> x) {
|
||||||
|
auto r = x.mutable_unchecked<3>(); // Will throw if ndim != 3 or flags.writeable is false
|
||||||
|
for (size_t i = 0; i < r.shape(0); i++)
|
||||||
|
for (size_t j = 0; j < r.shape(1); j++)
|
||||||
|
for (size_t k = 0; k < r.shape(2); k++)
|
||||||
|
r(i, j, k) += 1.0;
|
||||||
|
}, py::arg().noconvert());
|
||||||
|
|
||||||
|
To obtain the proxy from an ``array`` object, you must specify both the data
|
||||||
|
type and number of dimensions as template arguments, such as ``auto r =
|
||||||
|
myarray.mutable_unchecked<float, 2>()``.
|
||||||
|
|
||||||
|
Note that the returned proxy object directly references the array's data, and
|
||||||
|
only reads its shape, strides, and writeable flag when constructed. You must
|
||||||
|
take care to ensure that the referenced array is not destroyed or reshaped for
|
||||||
|
the duration of the returned object, typically by limiting the scope of the
|
||||||
|
returned instance.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
The file :file:`tests/test_numpy_array.cpp` contains additional examples
|
||||||
|
demonstrating the use of this feature.
|
||||||
|
@ -35,6 +35,9 @@
|
|||||||
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
|
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
|
||||||
|
|
||||||
NAMESPACE_BEGIN(pybind11)
|
NAMESPACE_BEGIN(pybind11)
|
||||||
|
|
||||||
|
class array; // Forward declaration
|
||||||
|
|
||||||
NAMESPACE_BEGIN(detail)
|
NAMESPACE_BEGIN(detail)
|
||||||
template <typename type, typename SFINAE = void> struct npy_format_descriptor;
|
template <typename type, typename SFINAE = void> struct npy_format_descriptor;
|
||||||
|
|
||||||
@ -232,6 +235,78 @@ template <typename T> using is_pod_struct = all_of<
|
|||||||
satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
|
satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
template <size_t Dim = 0, typename Strides> size_t byte_offset_unsafe(const Strides &) { return 0; }
|
||||||
|
template <size_t Dim = 0, typename Strides, typename... Ix>
|
||||||
|
size_t byte_offset_unsafe(const Strides &strides, size_t i, Ix... index) {
|
||||||
|
return i * strides[Dim] + byte_offset_unsafe<Dim + 1>(strides, index...);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Proxy class providing unsafe, unchecked const access to array data. This is constructed through
|
||||||
|
* the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`.
|
||||||
|
*/
|
||||||
|
template <typename T, size_t Dims>
|
||||||
|
class unchecked_reference {
|
||||||
|
protected:
|
||||||
|
const unsigned char *data_;
|
||||||
|
// Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
|
||||||
|
// make large performance gains on big, nested loops.
|
||||||
|
std::array<size_t, Dims> shape_, strides_;
|
||||||
|
|
||||||
|
friend class pybind11::array;
|
||||||
|
unchecked_reference(const void *data, const size_t *shape, const size_t *strides)
|
||||||
|
: data_{reinterpret_cast<const unsigned char *>(data)} {
|
||||||
|
for (size_t i = 0; i < Dims; i++) {
|
||||||
|
shape_[i] = shape[i];
|
||||||
|
strides_[i] = strides[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
/** Unchecked const reference access to data at the given indices. Omiting trailing indices
|
||||||
|
* is equivalent to specifying them as 0.
|
||||||
|
*/
|
||||||
|
template <typename... Ix> const T& operator()(Ix... index) const {
|
||||||
|
static_assert(sizeof...(Ix) <= Dims, "Invalid number of indices for unchecked array reference");
|
||||||
|
return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, size_t{index}...));
|
||||||
|
}
|
||||||
|
/** Unchecked const reference access to data; this operator only participates if the reference
|
||||||
|
* is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
|
||||||
|
*/
|
||||||
|
template <size_t D = Dims, typename = enable_if_t<D == 1>>
|
||||||
|
const T &operator[](size_t index) const { return operator()(index); }
|
||||||
|
|
||||||
|
/// Returns the shape (i.e. size) of dimension `dim`
|
||||||
|
size_t shape(size_t dim) const { return shape_[dim]; }
|
||||||
|
|
||||||
|
/// Returns the number of dimensions of the array
|
||||||
|
constexpr static size_t ndim() { return Dims; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, size_t Dims>
|
||||||
|
class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
|
||||||
|
friend class pybind11::array;
|
||||||
|
using ConstBase = unchecked_reference<T, Dims>;
|
||||||
|
using ConstBase::ConstBase;
|
||||||
|
public:
|
||||||
|
/// Mutable, unchecked access to data at the given indices.
|
||||||
|
template <typename... Ix> T& operator()(Ix... index) {
|
||||||
|
static_assert(sizeof...(Ix) == Dims, "Invalid number of indices for unchecked array reference");
|
||||||
|
return const_cast<T &>(ConstBase::operator()(index...));
|
||||||
|
}
|
||||||
|
/** Mutable, unchecked access data at the given index; this operator only participates if the
|
||||||
|
* reference is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
|
||||||
|
*/
|
||||||
|
template <size_t D = Dims, typename = enable_if_t<D == 1>>
|
||||||
|
T &operator[](size_t index) { return operator()(index); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, size_t Dim>
|
||||||
|
struct type_caster<unchecked_reference<T, Dim>> {
|
||||||
|
static_assert(Dim == (size_t) -1 /* always fail */, "unchecked array proxy object is not castable");
|
||||||
|
};
|
||||||
|
template <typename T, size_t Dim>
|
||||||
|
struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {};
|
||||||
|
|
||||||
NAMESPACE_END(detail)
|
NAMESPACE_END(detail)
|
||||||
|
|
||||||
class dtype : public object {
|
class dtype : public object {
|
||||||
@ -500,6 +575,31 @@ public:
|
|||||||
return offset_at(index...) / itemsize();
|
return offset_at(index...) / itemsize();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Returns a proxy object that provides access to the array's data without bounds or
|
||||||
|
* dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
|
||||||
|
* care: the array must not be destroyed or reshaped for the duration of the returned object,
|
||||||
|
* and the caller must take care not to access invalid dimensions or dimension indices.
|
||||||
|
*/
|
||||||
|
template <typename T, size_t Dims> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
|
||||||
|
if (ndim() != Dims)
|
||||||
|
throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
|
||||||
|
"; expected " + std::to_string(Dims));
|
||||||
|
return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides());
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns a proxy object that provides const access to the array's data without bounds or
|
||||||
|
* dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the
|
||||||
|
* underlying array have the `writable` flag. Use with care: the array must not be destroyed or
|
||||||
|
* reshaped for the duration of the returned object, and the caller must take care not to access
|
||||||
|
* invalid dimensions or dimension indices.
|
||||||
|
*/
|
||||||
|
template <typename T, size_t Dims> detail::unchecked_reference<T, Dims> unchecked() const {
|
||||||
|
if (ndim() != Dims)
|
||||||
|
throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
|
||||||
|
"; expected " + std::to_string(Dims));
|
||||||
|
return detail::unchecked_reference<T, Dims>(data(), shape(), strides());
|
||||||
|
}
|
||||||
|
|
||||||
/// Return a new view with all of the dimensions of length 1 removed
|
/// Return a new view with all of the dimensions of length 1 removed
|
||||||
array squeeze() {
|
array squeeze() {
|
||||||
auto& api = detail::npy_api::get();
|
auto& api = detail::npy_api::get();
|
||||||
@ -525,15 +625,9 @@ protected:
|
|||||||
|
|
||||||
template<typename... Ix> size_t byte_offset(Ix... index) const {
|
template<typename... Ix> size_t byte_offset(Ix... index) const {
|
||||||
check_dimensions(index...);
|
check_dimensions(index...);
|
||||||
return byte_offset_unsafe(index...);
|
return detail::byte_offset_unsafe(strides(), size_t{index}...);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<size_t dim = 0, typename... Ix> size_t byte_offset_unsafe(size_t i, Ix... index) const {
|
|
||||||
return i * strides()[dim] + byte_offset_unsafe<dim + 1>(index...);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<size_t dim = 0> size_t byte_offset_unsafe() const { return 0; }
|
|
||||||
|
|
||||||
void check_writeable() const {
|
void check_writeable() const {
|
||||||
if (!writeable())
|
if (!writeable())
|
||||||
throw std::domain_error("array is not writeable");
|
throw std::domain_error("array is not writeable");
|
||||||
@ -637,6 +731,25 @@ public:
|
|||||||
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
|
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Returns a proxy object that provides access to the array's data without bounds or
|
||||||
|
* dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
|
||||||
|
* care: the array must not be destroyed or reshaped for the duration of the returned object,
|
||||||
|
* and the caller must take care not to access invalid dimensions or dimension indices.
|
||||||
|
*/
|
||||||
|
template <size_t Dims> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
|
||||||
|
return array::mutable_unchecked<T, Dims>();
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns a proxy object that provides const access to the array's data without bounds or
|
||||||
|
* dimensionality checking. Unlike `unchecked()`, this does not require that the underlying
|
||||||
|
* array have the `writable` flag. Use with care: the array must not be destroyed or reshaped
|
||||||
|
* for the duration of the returned object, and the caller must take care not to access invalid
|
||||||
|
* dimensions or dimension indices.
|
||||||
|
*/
|
||||||
|
template <size_t Dims> detail::unchecked_reference<T, Dims> unchecked() const {
|
||||||
|
return array::unchecked<T, Dims>();
|
||||||
|
}
|
||||||
|
|
||||||
/// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
|
/// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
|
||||||
/// it). In case of an error, nullptr is returned and the Python error is cleared.
|
/// it). In case of an error, nullptr is returned and the Python error is cleared.
|
||||||
static array_t ensure(handle h) {
|
static array_t ensure(handle h) {
|
||||||
|
@ -184,4 +184,36 @@ test_initializer numpy_array([](py::module &m) {
|
|||||||
sm.def("issue685", [](std::string) { return "string"; });
|
sm.def("issue685", [](std::string) { return "string"; });
|
||||||
sm.def("issue685", [](py::array) { return "array"; });
|
sm.def("issue685", [](py::array) { return "array"; });
|
||||||
sm.def("issue685", [](py::object) { return "other"; });
|
sm.def("issue685", [](py::object) { return "other"; });
|
||||||
|
|
||||||
|
sm.def("proxy_add2", [](py::array_t<double> a, double v) {
|
||||||
|
auto r = a.mutable_unchecked<2>();
|
||||||
|
for (size_t i = 0; i < r.shape(0); i++)
|
||||||
|
for (size_t j = 0; j < r.shape(1); j++)
|
||||||
|
r(i, j) += v;
|
||||||
|
}, py::arg().noconvert(), py::arg());
|
||||||
|
sm.def("proxy_init3", [](double start) {
|
||||||
|
py::array_t<double, py::array::c_style> a({ 3, 3, 3 });
|
||||||
|
auto r = a.mutable_unchecked<3>();
|
||||||
|
for (size_t i = 0; i < r.shape(0); i++)
|
||||||
|
for (size_t j = 0; j < r.shape(1); j++)
|
||||||
|
for (size_t k = 0; k < r.shape(2); k++)
|
||||||
|
r(i, j, k) = start++;
|
||||||
|
return a;
|
||||||
|
});
|
||||||
|
sm.def("proxy_init3F", [](double start) {
|
||||||
|
py::array_t<double, py::array::f_style> a({ 3, 3, 3 });
|
||||||
|
auto r = a.mutable_unchecked<3>();
|
||||||
|
for (size_t k = 0; k < r.shape(2); k++)
|
||||||
|
for (size_t j = 0; j < r.shape(1); j++)
|
||||||
|
for (size_t i = 0; i < r.shape(0); i++)
|
||||||
|
r(i, j, k) = start++;
|
||||||
|
return a;
|
||||||
|
});
|
||||||
|
sm.def("proxy_squared_L2_norm", [](py::array_t<double> a) {
|
||||||
|
auto r = a.unchecked<1>();
|
||||||
|
double sumsq = 0;
|
||||||
|
for (size_t i = 0; i < r.shape(0); i++)
|
||||||
|
sumsq += r[i] * r(i); // Either notation works for a 1D array
|
||||||
|
return sumsq;
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
@ -339,3 +339,23 @@ def test_greedy_string_overload(): # issue 685
|
|||||||
assert issue685("abc") == "string"
|
assert issue685("abc") == "string"
|
||||||
assert issue685(np.array([97, 98, 99], dtype='b')) == "array"
|
assert issue685(np.array([97, 98, 99], dtype='b')) == "array"
|
||||||
assert issue685(123) == "other"
|
assert issue685(123) == "other"
|
||||||
|
|
||||||
|
|
||||||
|
def test_array_unchecked(msg):
|
||||||
|
from pybind11_tests.array import proxy_add2, proxy_init3F, proxy_init3, proxy_squared_L2_norm
|
||||||
|
|
||||||
|
z1 = np.array([[1, 2], [3, 4]], dtype='float64')
|
||||||
|
proxy_add2(z1, 10)
|
||||||
|
assert np.all(z1 == [[11, 12], [13, 14]])
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
proxy_add2(np.array([1., 2, 3]), 5.0)
|
||||||
|
assert msg(excinfo.value) == "array has incorrect number of dimensions: 1; expected 2"
|
||||||
|
|
||||||
|
expect_c = np.ndarray(shape=(3, 3, 3), buffer=np.array(range(3, 30)), dtype='int')
|
||||||
|
assert np.all(proxy_init3(3.0) == expect_c)
|
||||||
|
expect_f = np.transpose(expect_c)
|
||||||
|
assert np.all(proxy_init3F(3.0) == expect_f)
|
||||||
|
|
||||||
|
assert proxy_squared_L2_norm(np.array(range(6))) == 55
|
||||||
|
assert proxy_squared_L2_norm(np.array(range(6), dtype="float64")) == 55
|
||||||
|
Loading…
Reference in New Issue
Block a user