add bounds check and test

This commit is contained in:
Francesco Rizzi 2023-06-29 09:12:17 +02:00
parent d2ea386ef7
commit 8917a1e9b3
2 changed files with 29 additions and 38 deletions

View File

@ -1093,33 +1093,39 @@ public:
// Reference to element at a given index // Reference to element at a given index
template <typename... Ix> template <typename... Ix>
const T &at(Ix... index) const { const T &at(Ix... index) const {
check_dim_precondition(sizeof...(index)); check_rank_precondition(sizeof...(index));
return const_reference(index...); return *(static_cast<const T *>(array::data())
+ byte_offset(ssize_t(index)...) / itemsize());
} }
// Mutable reference to element at a given index // Mutable reference to element at a given index
template <typename... Ix> template <typename... Ix>
T &mutable_at(Ix... index) { T &mutable_at(Ix... index) {
check_dim_precondition(sizeof...(index)); check_rank_precondition(sizeof...(index));
return mutable_reference(index...); return *(static_cast<T *>(array::mutable_data())
+ byte_offset(ssize_t(index)...) / itemsize());
} }
// const-reference to element at a given index without bounds checking // const-reference to element at a given index without bounds checking
template <typename... Ix> template <typename... Ix>
const T &operator()(Ix... index) const { const T &operator()(Ix... index) const {
#if !defined(NDEBUG) #if !defined(NDEBUG)
check_dim_precondition(sizeof...(index)); check_rank_precondition(sizeof...(index));
check_dimensions(index...);
#endif #endif
return const_reference(index...); return *(static_cast<const T *>(array::data())
+ detail::byte_offset_unsafe(strides(), ssize_t(index)...) / itemsize());
} }
// mutable reference to element at a given index without bounds checking // mutable reference to element at a given index without bounds checking
template <typename... Ix> template <typename... Ix>
T &operator()(Ix... index) { T &operator()(Ix... index) {
#if !defined(NDEBUG) #if !defined(NDEBUG)
check_dim_precondition(sizeof...(index)); check_rank_precondition(sizeof...(index));
check_dimensions(index...);
#endif #endif
return mutable_reference(index...); return *(static_cast<T *>(array::mutable_data())
+ detail::byte_offset_unsafe(strides(), ssize_t(index)...) / itemsize());
} }
/** /**
@ -1180,19 +1186,7 @@ protected:
} }
private: private:
template <typename... Ix> void check_rank_precondition(ssize_t dim) const {
const T &const_reference(Ix... index) const {
return *(static_cast<const T *>(array::data())
+ byte_offset(ssize_t(index)...) / itemsize());
}
template <typename... Ix>
T &mutable_reference(Ix... index) {
return *(static_cast<T *>(array::mutable_data())
+ byte_offset(ssize_t(index)...) / itemsize());
}
void check_dim_precondition(ssize_t dim) const {
if (dim != ndim()) { if (dim != ndim()) {
fail_dim_check(dim, "index dimension mismatch"); fail_dim_check(dim, "index dimension mismatch");
} }

View File

@ -125,14 +125,6 @@ def test_elem_reference(arr, func, dim):
assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)" assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)"
# @pytest.mark.parametrize("dim", [0, 1, 3])
# def test_at_fail(arr, dim):
# for func in m.at_t, m.mutate_at_t:
# with pytest.raises(IndexError) as excinfo:
# func(arr, *([0] * dim))
# assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)"
@pytest.mark.parametrize("func", [m.at_t, m.const_subscript_via_call_operator_t]) @pytest.mark.parametrize("func", [m.at_t, m.const_subscript_via_call_operator_t])
def test_const_elem_reference(arr, func): def test_const_elem_reference(arr, func):
assert func(arr, 0, 2) == 3 assert func(arr, 0, 2) == 3
@ -171,8 +163,9 @@ def test_mutate_data(arr):
assert all(m.mutate_data_t(arr, 1, 2).ravel() == [6, 19, 27, 68, 84, 197]) assert all(m.mutate_data_t(arr, 1, 2).ravel() == [6, 19, 27, 68, 84, 197])
def test_bounds_check(arr): @pytest.mark.parametrize(
for func in ( "func",
[
m.index_at, m.index_at,
m.index_at_t, m.index_at_t,
m.data, m.data,
@ -181,7 +174,11 @@ def test_bounds_check(arr):
m.mutate_data_t, m.mutate_data_t,
m.at_t, m.at_t,
m.mutate_at_t, m.mutate_at_t,
): m.const_subscript_via_call_operator_t,
m.subscript_via_call_operator_t,
][: 8 if m.defined_NDEBUG else 99],
)
def test_bounds_check(arr, func):
with pytest.raises(IndexError) as excinfo: with pytest.raises(IndexError) as excinfo:
func(arr, 2, 0) func(arr, 2, 0)
assert str(excinfo.value) == "index 2 is out of bounds for axis 0 with size 2" assert str(excinfo.value) == "index 2 is out of bounds for axis 0 with size 2"