mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-11 08:03:55 +00:00
Switch NumPy variadic indexing to per-value arguments (#500)
* Also added unsafe version without checks
This commit is contained in:
parent
2fb5f1d0c3
commit
5027c4f95b
@ -444,31 +444,31 @@ public:
|
||||
|
||||
/// Pointer to the contained data. If index is not provided, points to the
|
||||
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
|
||||
template<typename... Ix> const void* data(Ix&&... index) const {
|
||||
template<typename... Ix> const void* data(Ix... index) const {
|
||||
return static_cast<const void *>(PyArray_GET_(m_ptr, data) + offset_at(index...));
|
||||
}
|
||||
|
||||
/// Mutable pointer to the contained data. If index is not provided, points to the
|
||||
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
|
||||
/// May throw if the array is not writeable.
|
||||
template<typename... Ix> void* mutable_data(Ix&&... index) {
|
||||
template<typename... Ix> void* mutable_data(Ix... index) {
|
||||
check_writeable();
|
||||
return static_cast<void *>(PyArray_GET_(m_ptr, data) + offset_at(index...));
|
||||
}
|
||||
|
||||
/// Byte offset from beginning of the array to a given index (full or partial).
|
||||
/// May throw if the index would lead to out of bounds access.
|
||||
template<typename... Ix> size_t offset_at(Ix&&... index) const {
|
||||
template<typename... Ix> size_t offset_at(Ix... index) const {
|
||||
if (sizeof...(index) > ndim())
|
||||
fail_dim_check(sizeof...(index), "too many indices for an array");
|
||||
return get_byte_offset(index...);
|
||||
return byte_offset(size_t(index)...);
|
||||
}
|
||||
|
||||
size_t offset_at() const { return 0; }
|
||||
|
||||
/// Item count from beginning of the array to a given index (full or partial).
|
||||
/// May throw if the index would lead to out of bounds access.
|
||||
template<typename... Ix> size_t index_at(Ix&&... index) const {
|
||||
template<typename... Ix> size_t index_at(Ix... index) const {
|
||||
return offset_at(index...) / itemsize();
|
||||
}
|
||||
|
||||
@ -493,18 +493,16 @@ protected:
|
||||
" (ndim = " + std::to_string(ndim()) + ")");
|
||||
}
|
||||
|
||||
template<typename... Ix> size_t get_byte_offset(Ix&&... index) const {
|
||||
const size_t idx[] = { (size_t) index... };
|
||||
if (!std::equal(idx + 0, idx + sizeof...(index), shape(), std::less<size_t>{})) {
|
||||
auto mismatch = std::mismatch(idx + 0, idx + sizeof...(index), shape(), std::less<size_t>{});
|
||||
throw index_error(std::string("index ") + std::to_string(*mismatch.first) +
|
||||
" is out of bounds for axis " + std::to_string(mismatch.first - idx) +
|
||||
" with size " + std::to_string(*mismatch.second));
|
||||
}
|
||||
return std::inner_product(idx + 0, idx + sizeof...(index), strides(), (size_t) 0);
|
||||
template<typename... Ix> size_t byte_offset(Ix... index) const {
|
||||
check_dimensions(index...);
|
||||
return byte_offset_unsafe(index...);
|
||||
}
|
||||
|
||||
size_t get_byte_offset() const { return 0; }
|
||||
template<size_t dim = 0, typename... Ix> size_t byte_offset_unsafe(size_t i, Ix... index) const {
|
||||
return i * strides()[dim] + byte_offset_unsafe<dim + 1>(index...);
|
||||
}
|
||||
|
||||
template<size_t dim = 0> size_t byte_offset_unsafe() const { return 0; }
|
||||
|
||||
void check_writeable() const {
|
||||
if (!writeable())
|
||||
@ -522,6 +520,23 @@ protected:
|
||||
}
|
||||
return strides;
|
||||
}
|
||||
|
||||
protected:
|
||||
|
||||
template<typename... Ix> void check_dimensions(Ix... index) const {
|
||||
check_dimensions_impl(size_t(0), shape(), size_t(index)...);
|
||||
}
|
||||
|
||||
void check_dimensions_impl(size_t, const size_t*) const { }
|
||||
|
||||
template<typename... Ix> void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const {
|
||||
if (i >= *shape) {
|
||||
throw index_error(std::string("index ") + std::to_string(i) +
|
||||
" is out of bounds for axis " + std::to_string(axis) +
|
||||
" with size " + std::to_string(*shape));
|
||||
}
|
||||
check_dimensions_impl(axis + 1, shape + 1, index...);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
|
||||
@ -548,32 +563,30 @@ public:
|
||||
return sizeof(T);
|
||||
}
|
||||
|
||||
template<typename... Ix> size_t index_at(Ix&... index) const {
|
||||
template<typename... Ix> size_t index_at(Ix... index) const {
|
||||
return offset_at(index...) / itemsize();
|
||||
}
|
||||
|
||||
template<typename... Ix> const T* data(Ix&&... index) const {
|
||||
template<typename... Ix> const T* data(Ix... index) const {
|
||||
return static_cast<const T*>(array::data(index...));
|
||||
}
|
||||
|
||||
template<typename... Ix> T* mutable_data(Ix&&... index) {
|
||||
template<typename... Ix> T* mutable_data(Ix... index) {
|
||||
return static_cast<T*>(array::mutable_data(index...));
|
||||
}
|
||||
|
||||
// Reference to element at a given index
|
||||
template<typename... Ix> const T& at(Ix&&... index) const {
|
||||
template<typename... Ix> const T& at(Ix... index) const {
|
||||
if (sizeof...(index) != ndim())
|
||||
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
||||
// not using offset_at() / index_at() here so as to avoid another dimension check
|
||||
return *(static_cast<const T*>(array::data()) + get_byte_offset(index...) / itemsize());
|
||||
return *(static_cast<const T*>(array::data()) + byte_offset(size_t(index)...) / itemsize());
|
||||
}
|
||||
|
||||
// Mutable reference to element at a given index
|
||||
template<typename... Ix> T& mutable_at(Ix&&... index) {
|
||||
template<typename... Ix> T& mutable_at(Ix... index) {
|
||||
if (sizeof...(index) != ndim())
|
||||
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
||||
// not using offset_at() / index_at() here so as to avoid another dimension check
|
||||
return *(static_cast<T*>(array::mutable_data()) + get_byte_offset(index...) / itemsize());
|
||||
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
|
||||
}
|
||||
|
||||
static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
|
||||
|
@ -18,11 +18,11 @@
|
||||
using arr = py::array;
|
||||
using arr_t = py::array_t<uint16_t, 0>;
|
||||
|
||||
template<typename... Ix> arr data(const arr& a, Ix&&... index) {
|
||||
template<typename... Ix> arr data(const arr& a, Ix... index) {
|
||||
return arr(a.nbytes() - a.offset_at(index...), (const uint8_t *) a.data(index...));
|
||||
}
|
||||
|
||||
template<typename... Ix> arr data_t(const arr_t& a, Ix&&... index) {
|
||||
template<typename... Ix> arr data_t(const arr_t& a, Ix... index) {
|
||||
return arr(a.size() - a.index_at(index...), a.data(index...));
|
||||
}
|
||||
|
||||
@ -40,26 +40,26 @@ arr_t& mutate_data_t(arr_t& a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
template<typename... Ix> arr& mutate_data(arr& a, Ix&&... index) {
|
||||
template<typename... Ix> arr& mutate_data(arr& a, Ix... index) {
|
||||
auto ptr = (uint8_t *) a.mutable_data(index...);
|
||||
for (size_t i = 0; i < a.nbytes() - a.offset_at(index...); i++)
|
||||
ptr[i] = (uint8_t) (ptr[i] * 2);
|
||||
return a;
|
||||
}
|
||||
|
||||
template<typename... Ix> arr_t& mutate_data_t(arr_t& a, Ix&&... index) {
|
||||
template<typename... Ix> arr_t& mutate_data_t(arr_t& a, Ix... index) {
|
||||
auto ptr = a.mutable_data(index...);
|
||||
for (size_t i = 0; i < a.size() - a.index_at(index...); i++)
|
||||
ptr[i]++;
|
||||
return a;
|
||||
}
|
||||
|
||||
template<typename... Ix> size_t index_at(const arr& a, Ix&&... idx) { return a.index_at(idx...); }
|
||||
template<typename... Ix> size_t index_at_t(const arr_t& a, Ix&&... idx) { return a.index_at(idx...); }
|
||||
template<typename... Ix> size_t offset_at(const arr& a, Ix&&... idx) { return a.offset_at(idx...); }
|
||||
template<typename... Ix> size_t offset_at_t(const arr_t& a, Ix&&... idx) { return a.offset_at(idx...); }
|
||||
template<typename... Ix> size_t at_t(const arr_t& a, Ix&&... idx) { return a.at(idx...); }
|
||||
template<typename... Ix> arr_t& mutate_at_t(arr_t& a, Ix&&... idx) { a.mutable_at(idx...)++; return a; }
|
||||
template<typename... Ix> size_t index_at(const arr& a, Ix... idx) { return a.index_at(idx...); }
|
||||
template<typename... Ix> size_t index_at_t(const arr_t& a, Ix... idx) { return a.index_at(idx...); }
|
||||
template<typename... Ix> size_t offset_at(const arr& a, Ix... idx) { return a.offset_at(idx...); }
|
||||
template<typename... Ix> size_t offset_at_t(const arr_t& a, Ix... idx) { return a.offset_at(idx...); }
|
||||
template<typename... Ix> size_t at_t(const arr_t& a, Ix... idx) { return a.at(idx...); }
|
||||
template<typename... Ix> arr_t& mutate_at_t(arr_t& a, Ix... idx) { a.mutable_at(idx...)++; return a; }
|
||||
|
||||
#define def_index_fn(name, type) \
|
||||
sm.def(#name, [](type a) { return name(a); }); \
|
||||
|
Loading…
Reference in New Issue
Block a user