diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index f01c9d154..fe226e4df 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -1093,33 +1093,39 @@ public: // Reference to element at a given index template const T &at(Ix... index) const { - check_dim_precondition(sizeof...(index)); - return const_reference(index...); + check_rank_precondition(sizeof...(index)); + return *(static_cast(array::data()) + + byte_offset(ssize_t(index)...) / itemsize()); } // Mutable reference to element at a given index template T &mutable_at(Ix... index) { - check_dim_precondition(sizeof...(index)); - return mutable_reference(index...); + check_rank_precondition(sizeof...(index)); + return *(static_cast(array::mutable_data()) + + byte_offset(ssize_t(index)...) / itemsize()); } // const-reference to element at a given index without bounds checking template const T &operator()(Ix... index) const { #if !defined(NDEBUG) - check_dim_precondition(sizeof...(index)); + check_rank_precondition(sizeof...(index)); + check_dimensions(index...); #endif - return const_reference(index...); + return *(static_cast(array::data()) + + detail::byte_offset_unsafe(strides(), ssize_t(index)...) / itemsize()); } // mutable reference to element at a given index without bounds checking template T &operator()(Ix... index) { #if !defined(NDEBUG) - check_dim_precondition(sizeof...(index)); + check_rank_precondition(sizeof...(index)); + check_dimensions(index...); #endif - return mutable_reference(index...); + return *(static_cast(array::mutable_data()) + + detail::byte_offset_unsafe(strides(), ssize_t(index)...) / itemsize()); } /** @@ -1180,19 +1186,7 @@ protected: } private: - template - const T &const_reference(Ix... index) const { - return *(static_cast(array::data()) - + byte_offset(ssize_t(index)...) / itemsize()); - } - - template - T &mutable_reference(Ix... index) { - return *(static_cast(array::mutable_data()) - + byte_offset(ssize_t(index)...) / itemsize()); - } - - void check_dim_precondition(ssize_t dim) const { + void check_rank_precondition(ssize_t dim) const { if (dim != ndim()) { fail_dim_check(dim, "index dimension mismatch"); } diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index b2434a07e..222a23125 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -125,14 +125,6 @@ def test_elem_reference(arr, func, dim): assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)" -# @pytest.mark.parametrize("dim", [0, 1, 3]) -# def test_at_fail(arr, dim): -# for func in m.at_t, m.mutate_at_t: -# with pytest.raises(IndexError) as excinfo: -# func(arr, *([0] * dim)) -# assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)" - - @pytest.mark.parametrize("func", [m.at_t, m.const_subscript_via_call_operator_t]) def test_const_elem_reference(arr, func): assert func(arr, 0, 2) == 3 @@ -171,8 +163,9 @@ def test_mutate_data(arr): assert all(m.mutate_data_t(arr, 1, 2).ravel() == [6, 19, 27, 68, 84, 197]) -def test_bounds_check(arr): - for func in ( +@pytest.mark.parametrize( + "func", + [ m.index_at, m.index_at_t, m.data, @@ -181,13 +174,17 @@ def test_bounds_check(arr): m.mutate_data_t, m.at_t, m.mutate_at_t, - ): - with pytest.raises(IndexError) as excinfo: - func(arr, 2, 0) - assert str(excinfo.value) == "index 2 is out of bounds for axis 0 with size 2" - with pytest.raises(IndexError) as excinfo: - func(arr, 0, 4) - assert str(excinfo.value) == "index 4 is out of bounds for axis 1 with size 3" + m.const_subscript_via_call_operator_t, + m.subscript_via_call_operator_t, + ][: 8 if m.defined_NDEBUG else 99], +) +def test_bounds_check(arr, func): + with pytest.raises(IndexError) as excinfo: + func(arr, 2, 0) + assert str(excinfo.value) == "index 2 is out of bounds for axis 0 with size 2" + with pytest.raises(IndexError) as excinfo: + func(arr, 0, 4) + assert str(excinfo.value) == "index 4 is out of bounds for axis 1 with size 3" def test_make_c_f_array():