array_t: tests, refine impl for indexing via operator()

This commit is contained in:
Francesco Rizzi 2023-06-28 17:12:01 +02:00
parent cba2b32ead
commit 87fd2c5121
3 changed files with 43 additions and 18 deletions

View File

@ -1093,35 +1093,33 @@ public:
// Reference to element at a given index
template <typename... Ix>
const T &at(Ix... index) const {
if ((ssize_t) sizeof...(index) != ndim()) {
fail_dim_check(sizeof...(index), "index dimension mismatch");
}
return *(static_cast<const T *>(array::data())
+ byte_offset(ssize_t(index)...) / itemsize());
check_access_precondition(index...);
return const_reference(index...);
}
// Mutable reference to element at a given index
template <typename... Ix>
T &mutable_at(Ix... index) {
if ((ssize_t) sizeof...(index) != ndim()) {
fail_dim_check(sizeof...(index), "index dimension mismatch");
}
return *(static_cast<T *>(array::mutable_data())
+ byte_offset(ssize_t(index)...) / itemsize());
check_access_precondition(index...);
return mutable_reference(index...);
}
// const-reference to element at a given index without bounds checking
template <typename... Ix>
const T &operator()(Ix... index) const {
return *(static_cast<const T *>(array::data())
+ byte_offset(ssize_t(index)...) / itemsize());
#if defined(NDEBUG)
check_access_precondition(index...);
#endif
return const_reference(index...);
}
// mutable reference to element at a given index without bounds checking
template <typename... Ix>
T &operator()(Ix... index) {
return *(static_cast<T *>(array::mutable_data())
+ byte_offset(ssize_t(index)...) / itemsize());
#if defined(NDEBUG)
check_access_precondition(index...);
#endif
return mutable_reference(index...);
}
/**
@ -1180,6 +1178,26 @@ protected:
| ExtraFlags,
nullptr);
}
private:
template <typename... Ix>
const T &const_reference(Ix... index) const {
return *(static_cast<const T *>(array::data())
+ byte_offset(ssize_t(index)...) / itemsize());
}
template <typename... Ix>
T &mutable_reference(Ix... index) {
return *(static_cast<T *>(array::mutable_data())
+ byte_offset(ssize_t(index)...) / itemsize());
}
template <typename... Ix>
void check_access_precondition(Ix... index) const {
if ((ssize_t) sizeof...(index) != ndim()) {
fail_dim_check(sizeof...(index), "index dimension mismatch");
}
}
};
template <typename T>

View File

@ -132,12 +132,12 @@ arr_t &mutate_at_t(arr_t &a, Ix... idx) {
return a;
}
template <typename... Ix>
arr_t &call_operator_subscript_t(arr_t &a, Ix... idx) {
arr_t &subscript_via_call_operator_t(arr_t &a, Ix... idx) {
a(idx...)++;
return a;
}
template <typename... Ix>
py::ssize_t const_call_operator_subscript_t(const arr_t &a, Ix... idx) {
py::ssize_t const_subscript_via_call_operator_t(const arr_t &a, Ix... idx) {
return a(idx...);
}
@ -219,8 +219,8 @@ TEST_SUBMODULE(numpy_array, sm) {
def_index_fn(mutate_data_t, arr_t &);
def_index_fn(at_t, const arr_t &);
def_index_fn(mutate_at_t, arr_t &);
def_index_fn(call_operator_subscript_t, arr_t &);
def_index_fn(const_call_operator_subscript_t, const arr_t &);
def_index_fn(subscript_via_call_operator_t, arr_t &);
def_index_fn(const_subscript_via_call_operator_t, const arr_t &);
// test_make_c_f_array
sm.def("make_f_array", [] { return py::array_t<float>({2, 2}, {4, 8}); });

View File

@ -125,6 +125,13 @@ def test_at(arr):
assert all(m.mutate_at_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])
def test_subscript_via_call_operator(arr):
assert m.const_subscript_via_call_operator_t(arr, 0, 2) == 3
assert m.const_subscript_via_call_operator_t(arr, 1, 0) == 4
assert all(m.subscript_via_call_operator_t(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6])
assert all(m.subscript_via_call_operator_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])
def test_mutate_readonly(arr):
arr.flags.writeable = False
for func, args in (