diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 09894cf74..726a20add 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -1232,9 +1232,7 @@ public: // Reference to element at a given index template const T &at(Ix... index) const { - if ((ssize_t) sizeof...(index) != ndim()) { - fail_dim_check(sizeof...(index), "index dimension mismatch"); - } + check_rank_precondition(sizeof...(index)); return *(static_cast(array::data()) + byte_offset(ssize_t(index)...) / itemsize()); } @@ -1242,13 +1240,33 @@ public: // Mutable reference to element at a given index template T &mutable_at(Ix... index) { - if ((ssize_t) sizeof...(index) != ndim()) { - fail_dim_check(sizeof...(index), "index dimension mismatch"); - } + check_rank_precondition(sizeof...(index)); return *(static_cast(array::mutable_data()) + byte_offset(ssize_t(index)...) / itemsize()); } + // const-reference to element at a given index without bounds checking + template + const T &operator()(Ix... index) const { +#if !defined(NDEBUG) + check_rank_precondition(sizeof...(index)); + check_dimensions(index...); +#endif + return *(static_cast(array::data()) + + detail::byte_offset_unsafe(strides(), ssize_t(index)...) / itemsize()); + } + + // mutable reference to element at a given index without bounds checking + template + T &operator()(Ix... index) { +#if !defined(NDEBUG) + check_rank_precondition(sizeof...(index)); + check_dimensions(index...); +#endif + return *(static_cast(array::mutable_data()) + + detail::byte_offset_unsafe(strides(), ssize_t(index)...) / itemsize()); + } + /** * Returns a proxy object that provides access to the array's data without bounds or * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with @@ -1305,6 +1323,13 @@ protected: | ExtraFlags, nullptr); } + +private: + void check_rank_precondition(ssize_t dim) const { + if (dim != ndim()) { + fail_dim_check(dim, "index dimension mismatch"); + } + } }; template diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index c2f754208..51e965d8f 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -131,6 +131,15 @@ arr_t &mutate_at_t(arr_t &a, Ix... idx) { a.mutable_at(idx...)++; return a; } +template +arr_t &subscript_via_call_operator_t(arr_t &a, Ix... idx) { + a(idx...)++; + return a; +} +template +py::ssize_t const_subscript_via_call_operator_t(const arr_t &a, Ix... idx) { + return a(idx...); +} #define def_index_fn(name, type) \ sm.def(#name, [](type a) { return name(a); }); \ @@ -197,6 +206,13 @@ TEST_SUBMODULE(numpy_array, sm) { sm.def("nbytes", [](const arr &a) { return a.nbytes(); }); sm.def("owndata", [](const arr &a) { return a.owndata(); }); + sm.attr("defined_NDEBUG") = +#ifdef NDEBUG + true; +#else + false; +#endif + // test_index_offset def_index_fn(index_at, const arr &); def_index_fn(index_at_t, const arr_t &); @@ -210,6 +226,8 @@ TEST_SUBMODULE(numpy_array, sm) { def_index_fn(mutate_data_t, arr_t &); def_index_fn(at_t, const arr_t &); def_index_fn(mutate_at_t, arr_t &); + def_index_fn(subscript_via_call_operator_t, arr_t &); + def_index_fn(const_subscript_via_call_operator_t, const arr_t &); // test_make_c_f_array sm.def("make_f_array", [] { return py::array_t({2, 2}, {4, 8}); }); diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index bc7b3d555..53b899de2 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -111,20 +111,32 @@ def test_data(arr, args, ret): assert all(m.data(arr, *args)[(1 if byteorder == "little" else 0) :: 2] == 0) +@pytest.mark.parametrize( + "func", + [ + m.at_t, + m.mutate_at_t, + m.const_subscript_via_call_operator_t, + m.subscript_via_call_operator_t, + ][: 2 if m.defined_NDEBUG else 99], +) @pytest.mark.parametrize("dim", [0, 1, 3]) -def test_at_fail(arr, dim): - for func in m.at_t, m.mutate_at_t: - with pytest.raises(IndexError) as excinfo: - func(arr, *([0] * dim)) - assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)" +def test_elem_reference(arr, func, dim): + with pytest.raises(IndexError) as excinfo: + func(arr, *([0] * dim)) + assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)" -def test_at(arr): - assert m.at_t(arr, 0, 2) == 3 - assert m.at_t(arr, 1, 0) == 4 +@pytest.mark.parametrize("func", [m.at_t, m.const_subscript_via_call_operator_t]) +def test_const_elem_reference(arr, func): + assert func(arr, 0, 2) == 3 + assert func(arr, 1, 0) == 4 - assert all(m.mutate_at_t(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6]) - assert all(m.mutate_at_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6]) + +@pytest.mark.parametrize("func", [m.mutate_at_t, m.subscript_via_call_operator_t]) +def test_mutable_elem_reference(arr, func): + assert all(func(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6]) + assert all(func(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6]) def test_mutate_readonly(arr): @@ -153,8 +165,9 @@ def test_mutate_data(arr): assert all(m.mutate_data_t(arr, 1, 2).ravel() == [6, 19, 27, 68, 84, 197]) -def test_bounds_check(arr): - for func in ( +@pytest.mark.parametrize( + "func", + [ m.index_at, m.index_at_t, m.data, @@ -163,13 +176,17 @@ def test_bounds_check(arr): m.mutate_data_t, m.at_t, m.mutate_at_t, - ): - with pytest.raises(IndexError) as excinfo: - func(arr, 2, 0) - assert str(excinfo.value) == "index 2 is out of bounds for axis 0 with size 2" - with pytest.raises(IndexError) as excinfo: - func(arr, 0, 4) - assert str(excinfo.value) == "index 4 is out of bounds for axis 1 with size 3" + m.const_subscript_via_call_operator_t, + m.subscript_via_call_operator_t, + ][: 8 if m.defined_NDEBUG else 99], +) +def test_bounds_check(arr, func): + with pytest.raises(IndexError) as excinfo: + func(arr, 2, 0) + assert str(excinfo.value) == "index 2 is out of bounds for axis 0 with size 2" + with pytest.raises(IndexError) as excinfo: + func(arr, 0, 4) + assert str(excinfo.value) == "index 4 is out of bounds for axis 1 with size 3" def test_make_c_f_array():