mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-25 22:52:01 +00:00
array_t: tests, refine impl for indexing via operator()
This commit is contained in:
parent
cba2b32ead
commit
87fd2c5121
@ -1093,35 +1093,33 @@ public:
|
||||
// Reference to element at a given index
|
||||
template <typename... Ix>
|
||||
const T &at(Ix... index) const {
|
||||
if ((ssize_t) sizeof...(index) != ndim()) {
|
||||
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
||||
}
|
||||
return *(static_cast<const T *>(array::data())
|
||||
+ byte_offset(ssize_t(index)...) / itemsize());
|
||||
check_access_precondition(index...);
|
||||
return const_reference(index...);
|
||||
}
|
||||
|
||||
// Mutable reference to element at a given index
|
||||
template <typename... Ix>
|
||||
T &mutable_at(Ix... index) {
|
||||
if ((ssize_t) sizeof...(index) != ndim()) {
|
||||
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
||||
}
|
||||
return *(static_cast<T *>(array::mutable_data())
|
||||
+ byte_offset(ssize_t(index)...) / itemsize());
|
||||
check_access_precondition(index...);
|
||||
return mutable_reference(index...);
|
||||
}
|
||||
|
||||
// const-reference to element at a given index without bounds checking
|
||||
template <typename... Ix>
|
||||
const T &operator()(Ix... index) const {
|
||||
return *(static_cast<const T *>(array::data())
|
||||
+ byte_offset(ssize_t(index)...) / itemsize());
|
||||
#if defined(NDEBUG)
|
||||
check_access_precondition(index...);
|
||||
#endif
|
||||
return const_reference(index...);
|
||||
}
|
||||
|
||||
// mutable reference to element at a given index without bounds checking
|
||||
template <typename... Ix>
|
||||
T &operator()(Ix... index) {
|
||||
return *(static_cast<T *>(array::mutable_data())
|
||||
+ byte_offset(ssize_t(index)...) / itemsize());
|
||||
#if defined(NDEBUG)
|
||||
check_access_precondition(index...);
|
||||
#endif
|
||||
return mutable_reference(index...);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -1180,6 +1178,26 @@ protected:
|
||||
| ExtraFlags,
|
||||
nullptr);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename... Ix>
|
||||
const T &const_reference(Ix... index) const {
|
||||
return *(static_cast<const T *>(array::data())
|
||||
+ byte_offset(ssize_t(index)...) / itemsize());
|
||||
}
|
||||
|
||||
template <typename... Ix>
|
||||
T &mutable_reference(Ix... index) {
|
||||
return *(static_cast<T *>(array::mutable_data())
|
||||
+ byte_offset(ssize_t(index)...) / itemsize());
|
||||
}
|
||||
|
||||
template <typename... Ix>
|
||||
void check_access_precondition(Ix... index) const {
|
||||
if ((ssize_t) sizeof...(index) != ndim()) {
|
||||
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
@ -132,12 +132,12 @@ arr_t &mutate_at_t(arr_t &a, Ix... idx) {
|
||||
return a;
|
||||
}
|
||||
template <typename... Ix>
|
||||
arr_t &call_operator_subscript_t(arr_t &a, Ix... idx) {
|
||||
arr_t &subscript_via_call_operator_t(arr_t &a, Ix... idx) {
|
||||
a(idx...)++;
|
||||
return a;
|
||||
}
|
||||
template <typename... Ix>
|
||||
py::ssize_t const_call_operator_subscript_t(const arr_t &a, Ix... idx) {
|
||||
py::ssize_t const_subscript_via_call_operator_t(const arr_t &a, Ix... idx) {
|
||||
return a(idx...);
|
||||
}
|
||||
|
||||
@ -219,8 +219,8 @@ TEST_SUBMODULE(numpy_array, sm) {
|
||||
def_index_fn(mutate_data_t, arr_t &);
|
||||
def_index_fn(at_t, const arr_t &);
|
||||
def_index_fn(mutate_at_t, arr_t &);
|
||||
def_index_fn(call_operator_subscript_t, arr_t &);
|
||||
def_index_fn(const_call_operator_subscript_t, const arr_t &);
|
||||
def_index_fn(subscript_via_call_operator_t, arr_t &);
|
||||
def_index_fn(const_subscript_via_call_operator_t, const arr_t &);
|
||||
|
||||
// test_make_c_f_array
|
||||
sm.def("make_f_array", [] { return py::array_t<float>({2, 2}, {4, 8}); });
|
||||
|
@ -125,6 +125,13 @@ def test_at(arr):
|
||||
assert all(m.mutate_at_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])
|
||||
|
||||
|
||||
def test_subscript_via_call_operator(arr):
|
||||
assert m.const_subscript_via_call_operator_t(arr, 0, 2) == 3
|
||||
assert m.const_subscript_via_call_operator_t(arr, 1, 0) == 4
|
||||
assert all(m.subscript_via_call_operator_t(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6])
|
||||
assert all(m.subscript_via_call_operator_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])
|
||||
|
||||
|
||||
def test_mutate_readonly(arr):
|
||||
arr.flags.writeable = False
|
||||
for func, args in (
|
||||
|
Loading…
Reference in New Issue
Block a user