diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 7108e06c4..cf64fafeb 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -1093,35 +1093,33 @@ public: // Reference to element at a given index template const T &at(Ix... index) const { - if ((ssize_t) sizeof...(index) != ndim()) { - fail_dim_check(sizeof...(index), "index dimension mismatch"); - } - return *(static_cast(array::data()) - + byte_offset(ssize_t(index)...) / itemsize()); + check_access_precondition(index...); + return const_reference(index...); } // Mutable reference to element at a given index template T &mutable_at(Ix... index) { - if ((ssize_t) sizeof...(index) != ndim()) { - fail_dim_check(sizeof...(index), "index dimension mismatch"); - } - return *(static_cast(array::mutable_data()) - + byte_offset(ssize_t(index)...) / itemsize()); + check_access_precondition(index...); + return mutable_reference(index...); } // const-reference to element at a given index without bounds checking template const T &operator()(Ix... index) const { - return *(static_cast(array::data()) - + byte_offset(ssize_t(index)...) / itemsize()); +#if defined(NDEBUG) + check_access_precondition(index...); +#endif + return const_reference(index...); } // mutable reference to element at a given index without bounds checking template T &operator()(Ix... index) { - return *(static_cast(array::mutable_data()) - + byte_offset(ssize_t(index)...) / itemsize()); +#if defined(NDEBUG) + check_access_precondition(index...); +#endif + return mutable_reference(index...); } /** @@ -1180,6 +1178,26 @@ protected: | ExtraFlags, nullptr); } + +private: + template + const T &const_reference(Ix... index) const { + return *(static_cast(array::data()) + + byte_offset(ssize_t(index)...) / itemsize()); + } + + template + T &mutable_reference(Ix... index) { + return *(static_cast(array::mutable_data()) + + byte_offset(ssize_t(index)...) / itemsize()); + } + + template + void check_access_precondition(Ix... index) const { + if ((ssize_t) sizeof...(index) != ndim()) { + fail_dim_check(sizeof...(index), "index dimension mismatch"); + } + } }; template diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index 6c20dc42a..8942672e9 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -132,12 +132,12 @@ arr_t &mutate_at_t(arr_t &a, Ix... idx) { return a; } template -arr_t &call_operator_subscript_t(arr_t &a, Ix... idx) { +arr_t &subscript_via_call_operator_t(arr_t &a, Ix... idx) { a(idx...)++; return a; } template -py::ssize_t const_call_operator_subscript_t(const arr_t &a, Ix... idx) { +py::ssize_t const_subscript_via_call_operator_t(const arr_t &a, Ix... idx) { return a(idx...); } @@ -219,8 +219,8 @@ TEST_SUBMODULE(numpy_array, sm) { def_index_fn(mutate_data_t, arr_t &); def_index_fn(at_t, const arr_t &); def_index_fn(mutate_at_t, arr_t &); - def_index_fn(call_operator_subscript_t, arr_t &); - def_index_fn(const_call_operator_subscript_t, const arr_t &); + def_index_fn(subscript_via_call_operator_t, arr_t &); + def_index_fn(const_subscript_via_call_operator_t, const arr_t &); // test_make_c_f_array sm.def("make_f_array", [] { return py::array_t({2, 2}, {4, 8}); }); diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index 12e7d17d1..cfad6f09e 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -125,6 +125,13 @@ def test_at(arr): assert all(m.mutate_at_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6]) +def test_subscript_via_call_operator(arr): + assert m.const_subscript_via_call_operator_t(arr, 0, 2) == 3 + assert m.const_subscript_via_call_operator_t(arr, 1, 0) == 4 + assert all(m.subscript_via_call_operator_t(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6]) + assert all(m.subscript_via_call_operator_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6]) + + def test_mutate_readonly(arr): arr.flags.writeable = False for func, args in (