mirror of
https://github.com/pybind/pybind11.git
synced 2025-02-16 21:57:55 +00:00
array_t: tests, refine impl for indexing via operator()
This commit is contained in:
parent
cba2b32ead
commit
87fd2c5121
@ -1093,35 +1093,33 @@ public:
|
|||||||
// Reference to element at a given index
|
// Reference to element at a given index
|
||||||
template <typename... Ix>
|
template <typename... Ix>
|
||||||
const T &at(Ix... index) const {
|
const T &at(Ix... index) const {
|
||||||
if ((ssize_t) sizeof...(index) != ndim()) {
|
check_access_precondition(index...);
|
||||||
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
return const_reference(index...);
|
||||||
}
|
|
||||||
return *(static_cast<const T *>(array::data())
|
|
||||||
+ byte_offset(ssize_t(index)...) / itemsize());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mutable reference to element at a given index
|
// Mutable reference to element at a given index
|
||||||
template <typename... Ix>
|
template <typename... Ix>
|
||||||
T &mutable_at(Ix... index) {
|
T &mutable_at(Ix... index) {
|
||||||
if ((ssize_t) sizeof...(index) != ndim()) {
|
check_access_precondition(index...);
|
||||||
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
return mutable_reference(index...);
|
||||||
}
|
|
||||||
return *(static_cast<T *>(array::mutable_data())
|
|
||||||
+ byte_offset(ssize_t(index)...) / itemsize());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// const-reference to element at a given index without bounds checking
|
// const-reference to element at a given index without bounds checking
|
||||||
template <typename... Ix>
|
template <typename... Ix>
|
||||||
const T &operator()(Ix... index) const {
|
const T &operator()(Ix... index) const {
|
||||||
return *(static_cast<const T *>(array::data())
|
#if defined(NDEBUG)
|
||||||
+ byte_offset(ssize_t(index)...) / itemsize());
|
check_access_precondition(index...);
|
||||||
|
#endif
|
||||||
|
return const_reference(index...);
|
||||||
}
|
}
|
||||||
|
|
||||||
// mutable reference to element at a given index without bounds checking
|
// mutable reference to element at a given index without bounds checking
|
||||||
template <typename... Ix>
|
template <typename... Ix>
|
||||||
T &operator()(Ix... index) {
|
T &operator()(Ix... index) {
|
||||||
return *(static_cast<T *>(array::mutable_data())
|
#if defined(NDEBUG)
|
||||||
+ byte_offset(ssize_t(index)...) / itemsize());
|
check_access_precondition(index...);
|
||||||
|
#endif
|
||||||
|
return mutable_reference(index...);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -1180,6 +1178,26 @@ protected:
|
|||||||
| ExtraFlags,
|
| ExtraFlags,
|
||||||
nullptr);
|
nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <typename... Ix>
|
||||||
|
const T &const_reference(Ix... index) const {
|
||||||
|
return *(static_cast<const T *>(array::data())
|
||||||
|
+ byte_offset(ssize_t(index)...) / itemsize());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Ix>
|
||||||
|
T &mutable_reference(Ix... index) {
|
||||||
|
return *(static_cast<T *>(array::mutable_data())
|
||||||
|
+ byte_offset(ssize_t(index)...) / itemsize());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Ix>
|
||||||
|
void check_access_precondition(Ix... index) const {
|
||||||
|
if ((ssize_t) sizeof...(index) != ndim()) {
|
||||||
|
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -132,12 +132,12 @@ arr_t &mutate_at_t(arr_t &a, Ix... idx) {
|
|||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
template <typename... Ix>
|
template <typename... Ix>
|
||||||
arr_t &call_operator_subscript_t(arr_t &a, Ix... idx) {
|
arr_t &subscript_via_call_operator_t(arr_t &a, Ix... idx) {
|
||||||
a(idx...)++;
|
a(idx...)++;
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
template <typename... Ix>
|
template <typename... Ix>
|
||||||
py::ssize_t const_call_operator_subscript_t(const arr_t &a, Ix... idx) {
|
py::ssize_t const_subscript_via_call_operator_t(const arr_t &a, Ix... idx) {
|
||||||
return a(idx...);
|
return a(idx...);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -219,8 +219,8 @@ TEST_SUBMODULE(numpy_array, sm) {
|
|||||||
def_index_fn(mutate_data_t, arr_t &);
|
def_index_fn(mutate_data_t, arr_t &);
|
||||||
def_index_fn(at_t, const arr_t &);
|
def_index_fn(at_t, const arr_t &);
|
||||||
def_index_fn(mutate_at_t, arr_t &);
|
def_index_fn(mutate_at_t, arr_t &);
|
||||||
def_index_fn(call_operator_subscript_t, arr_t &);
|
def_index_fn(subscript_via_call_operator_t, arr_t &);
|
||||||
def_index_fn(const_call_operator_subscript_t, const arr_t &);
|
def_index_fn(const_subscript_via_call_operator_t, const arr_t &);
|
||||||
|
|
||||||
// test_make_c_f_array
|
// test_make_c_f_array
|
||||||
sm.def("make_f_array", [] { return py::array_t<float>({2, 2}, {4, 8}); });
|
sm.def("make_f_array", [] { return py::array_t<float>({2, 2}, {4, 8}); });
|
||||||
|
@ -125,6 +125,13 @@ def test_at(arr):
|
|||||||
assert all(m.mutate_at_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])
|
assert all(m.mutate_at_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])
|
||||||
|
|
||||||
|
|
||||||
|
def test_subscript_via_call_operator(arr):
|
||||||
|
assert m.const_subscript_via_call_operator_t(arr, 0, 2) == 3
|
||||||
|
assert m.const_subscript_via_call_operator_t(arr, 1, 0) == 4
|
||||||
|
assert all(m.subscript_via_call_operator_t(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6])
|
||||||
|
assert all(m.subscript_via_call_operator_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])
|
||||||
|
|
||||||
|
|
||||||
def test_mutate_readonly(arr):
|
def test_mutate_readonly(arr):
|
||||||
arr.flags.writeable = False
|
arr.flags.writeable = False
|
||||||
for func, args in (
|
for func, args in (
|
||||||
|
Loading…
Reference in New Issue
Block a user