This commit is contained in:
Francesco Rizzi 2024-11-21 20:04:37 +08:00 committed by GitHub
commit 73fd5a97e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 85 additions and 25 deletions

View File

@ -1321,9 +1321,7 @@ public:
// Reference to element at a given index
template <typename... Ix>
const T &at(Ix... index) const {
if ((ssize_t) sizeof...(index) != ndim()) {
fail_dim_check(sizeof...(index), "index dimension mismatch");
}
check_rank_precondition(sizeof...(index));
return *(static_cast<const T *>(array::data())
+ byte_offset(ssize_t(index)...) / itemsize());
}
@ -1331,13 +1329,33 @@ public:
// Mutable reference to element at a given index
template <typename... Ix>
T &mutable_at(Ix... index) {
if ((ssize_t) sizeof...(index) != ndim()) {
fail_dim_check(sizeof...(index), "index dimension mismatch");
}
check_rank_precondition(sizeof...(index));
return *(static_cast<T *>(array::mutable_data())
+ byte_offset(ssize_t(index)...) / itemsize());
}
// const-reference to element at a given index without bounds checking
template <typename... Ix>
const T &operator()(Ix... index) const {
#if !defined(NDEBUG)
check_rank_precondition(sizeof...(index));
check_dimensions(index...);
#endif
return *(static_cast<const T *>(array::data())
+ detail::byte_offset_unsafe(strides(), ssize_t(index)...) / itemsize());
}
// mutable reference to element at a given index without bounds checking
template <typename... Ix>
T &operator()(Ix... index) {
#if !defined(NDEBUG)
check_rank_precondition(sizeof...(index));
check_dimensions(index...);
#endif
return *(static_cast<T *>(array::mutable_data())
+ detail::byte_offset_unsafe(strides(), ssize_t(index)...) / itemsize());
}
/**
* Returns a proxy object that provides access to the array's data without bounds or
* dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
@ -1394,6 +1412,13 @@ protected:
| ExtraFlags,
nullptr);
}
private:
void check_rank_precondition(ssize_t dim) const {
if (dim != ndim()) {
fail_dim_check(dim, "index dimension mismatch");
}
}
};
template <typename T>

View File

@ -131,6 +131,15 @@ arr_t &mutate_at_t(arr_t &a, Ix... idx) {
a.mutable_at(idx...)++;
return a;
}
template <typename... Ix>
arr_t &subscript_via_call_operator_t(arr_t &a, Ix... idx) {
a(idx...)++;
return a;
}
template <typename... Ix>
py::ssize_t const_subscript_via_call_operator_t(const arr_t &a, Ix... idx) {
return a(idx...);
}
#define def_index_fn(name, type) \
sm.def(#name, [](type a) { return name(a); }); \
@ -246,6 +255,13 @@ TEST_SUBMODULE(numpy_array, sm) {
sm.def("nbytes", [](const arr &a) { return a.nbytes(); });
sm.def("owndata", [](const arr &a) { return a.owndata(); });
sm.attr("defined_NDEBUG") =
#ifdef NDEBUG
true;
#else
false;
#endif
// test_index_offset
def_index_fn(index_at, const arr &);
def_index_fn(index_at_t, const arr_t &);
@ -259,6 +275,8 @@ TEST_SUBMODULE(numpy_array, sm) {
def_index_fn(mutate_data_t, arr_t &);
def_index_fn(at_t, const arr_t &);
def_index_fn(mutate_at_t, arr_t &);
def_index_fn(subscript_via_call_operator_t, arr_t &);
def_index_fn(const_subscript_via_call_operator_t, const arr_t &);
// test_make_c_f_array
sm.def("make_f_array", [] { return py::array_t<float>({2, 2}, {4, 8}); });

View File

@ -111,20 +111,32 @@ def test_data(arr, args, ret):
assert all(m.data(arr, *args)[(1 if byteorder == "little" else 0) :: 2] == 0)
@pytest.mark.parametrize(
"func",
[
m.at_t,
m.mutate_at_t,
m.const_subscript_via_call_operator_t,
m.subscript_via_call_operator_t,
][: 2 if m.defined_NDEBUG else 99],
)
@pytest.mark.parametrize("dim", [0, 1, 3])
def test_at_fail(arr, dim):
for func in m.at_t, m.mutate_at_t:
def test_elem_reference(arr, func, dim):
with pytest.raises(IndexError) as excinfo:
func(arr, *([0] * dim))
assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)"
def test_at(arr):
assert m.at_t(arr, 0, 2) == 3
assert m.at_t(arr, 1, 0) == 4
@pytest.mark.parametrize("func", [m.at_t, m.const_subscript_via_call_operator_t])
def test_const_elem_reference(arr, func):
assert func(arr, 0, 2) == 3
assert func(arr, 1, 0) == 4
assert all(m.mutate_at_t(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6])
assert all(m.mutate_at_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])
@pytest.mark.parametrize("func", [m.mutate_at_t, m.subscript_via_call_operator_t])
def test_mutable_elem_reference(arr, func):
assert all(func(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6])
assert all(func(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])
def test_mutate_readonly(arr):
@ -153,8 +165,9 @@ def test_mutate_data(arr):
assert all(m.mutate_data_t(arr, 1, 2).ravel() == [6, 19, 27, 68, 84, 197])
def test_bounds_check(arr):
for func in (
@pytest.mark.parametrize(
"func",
[
m.index_at,
m.index_at_t,
m.data,
@ -163,7 +176,11 @@ def test_bounds_check(arr):
m.mutate_data_t,
m.at_t,
m.mutate_at_t,
):
m.const_subscript_via_call_operator_t,
m.subscript_via_call_operator_t,
][: 8 if m.defined_NDEBUG else 99],
)
def test_bounds_check(arr, func):
with pytest.raises(IndexError) as excinfo:
func(arr, 2, 0)
assert str(excinfo.value) == "index 2 is out of bounds for axis 0 with size 2"