From 4d785be984b5e60f76b23ae61436b233c3991435 Mon Sep 17 00:00:00 2001 From: Francesco Rizzi Date: Wed, 28 Jun 2023 15:19:04 +0200 Subject: [PATCH 1/8] array_t: overload operator () for indexing --- include/pybind11/numpy.h | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 36077ec04..7108e06c4 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -1110,6 +1110,20 @@ public: + byte_offset(ssize_t(index)...) / itemsize()); } + // const-reference to element at a given index without bounds checking + template + const T &operator()(Ix... index) const { + return *(static_cast(array::data()) + + byte_offset(ssize_t(index)...) / itemsize()); + } + + // mutable reference to element at a given index without bounds checking + template + T &operator()(Ix... index) { + return *(static_cast(array::mutable_data()) + + byte_offset(ssize_t(index)...) / itemsize()); + } + /** * Returns a proxy object that provides access to the array's data without bounds or * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with From cba2b32eadecdd33da2e9a8aafae5986bd46c739 Mon Sep 17 00:00:00 2001 From: Francesco Rizzi Date: Wed, 28 Jun 2023 15:25:55 +0200 Subject: [PATCH 2/8] add tests --- tests/test_numpy_array.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index 8c122a865..6c20dc42a 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -131,6 +131,15 @@ arr_t &mutate_at_t(arr_t &a, Ix... idx) { a.mutable_at(idx...)++; return a; } +template +arr_t &call_operator_subscript_t(arr_t &a, Ix... idx) { + a(idx...)++; + return a; +} +template +py::ssize_t const_call_operator_subscript_t(const arr_t &a, Ix... idx) { + return a(idx...); +} #define def_index_fn(name, type) \ sm.def(#name, [](type a) { return name(a); }); \ @@ -210,6 +219,8 @@ TEST_SUBMODULE(numpy_array, sm) { def_index_fn(mutate_data_t, arr_t &); def_index_fn(at_t, const arr_t &); def_index_fn(mutate_at_t, arr_t &); + def_index_fn(call_operator_subscript_t, arr_t &); + def_index_fn(const_call_operator_subscript_t, const arr_t &); // test_make_c_f_array sm.def("make_f_array", [] { return py::array_t({2, 2}, {4, 8}); }); From 87fd2c5121c04fcf160d91c6e99d07e65f4ef464 Mon Sep 17 00:00:00 2001 From: Francesco Rizzi Date: Wed, 28 Jun 2023 17:12:01 +0200 Subject: [PATCH 3/8] array_t: tests, refine impl for indexing via operator() --- include/pybind11/numpy.h | 46 ++++++++++++++++++++++++++------------ tests/test_numpy_array.cpp | 8 +++---- tests/test_numpy_array.py | 7 ++++++ 3 files changed, 43 insertions(+), 18 deletions(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 7108e06c4..cf64fafeb 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -1093,35 +1093,33 @@ public: // Reference to element at a given index template const T &at(Ix... index) const { - if ((ssize_t) sizeof...(index) != ndim()) { - fail_dim_check(sizeof...(index), "index dimension mismatch"); - } - return *(static_cast(array::data()) - + byte_offset(ssize_t(index)...) / itemsize()); + check_access_precondition(index...); + return const_reference(index...); } // Mutable reference to element at a given index template T &mutable_at(Ix... index) { - if ((ssize_t) sizeof...(index) != ndim()) { - fail_dim_check(sizeof...(index), "index dimension mismatch"); - } - return *(static_cast(array::mutable_data()) - + byte_offset(ssize_t(index)...) / itemsize()); + check_access_precondition(index...); + return mutable_reference(index...); } // const-reference to element at a given index without bounds checking template const T &operator()(Ix... index) const { - return *(static_cast(array::data()) - + byte_offset(ssize_t(index)...) / itemsize()); +#if defined(NDEBUG) + check_access_precondition(index...); +#endif + return const_reference(index...); } // mutable reference to element at a given index without bounds checking template T &operator()(Ix... index) { - return *(static_cast(array::mutable_data()) - + byte_offset(ssize_t(index)...) / itemsize()); +#if defined(NDEBUG) + check_access_precondition(index...); +#endif + return mutable_reference(index...); } /** @@ -1180,6 +1178,26 @@ protected: | ExtraFlags, nullptr); } + +private: + template + const T &const_reference(Ix... index) const { + return *(static_cast(array::data()) + + byte_offset(ssize_t(index)...) / itemsize()); + } + + template + T &mutable_reference(Ix... index) { + return *(static_cast(array::mutable_data()) + + byte_offset(ssize_t(index)...) / itemsize()); + } + + template + void check_access_precondition(Ix... index) const { + if ((ssize_t) sizeof...(index) != ndim()) { + fail_dim_check(sizeof...(index), "index dimension mismatch"); + } + } }; template diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index 6c20dc42a..8942672e9 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -132,12 +132,12 @@ arr_t &mutate_at_t(arr_t &a, Ix... idx) { return a; } template -arr_t &call_operator_subscript_t(arr_t &a, Ix... idx) { +arr_t &subscript_via_call_operator_t(arr_t &a, Ix... idx) { a(idx...)++; return a; } template -py::ssize_t const_call_operator_subscript_t(const arr_t &a, Ix... idx) { +py::ssize_t const_subscript_via_call_operator_t(const arr_t &a, Ix... idx) { return a(idx...); } @@ -219,8 +219,8 @@ TEST_SUBMODULE(numpy_array, sm) { def_index_fn(mutate_data_t, arr_t &); def_index_fn(at_t, const arr_t &); def_index_fn(mutate_at_t, arr_t &); - def_index_fn(call_operator_subscript_t, arr_t &); - def_index_fn(const_call_operator_subscript_t, const arr_t &); + def_index_fn(subscript_via_call_operator_t, arr_t &); + def_index_fn(const_subscript_via_call_operator_t, const arr_t &); // test_make_c_f_array sm.def("make_f_array", [] { return py::array_t({2, 2}, {4, 8}); }); diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index 12e7d17d1..cfad6f09e 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -125,6 +125,13 @@ def test_at(arr): assert all(m.mutate_at_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6]) +def test_subscript_via_call_operator(arr): + assert m.const_subscript_via_call_operator_t(arr, 0, 2) == 3 + assert m.const_subscript_via_call_operator_t(arr, 1, 0) == 4 + assert all(m.subscript_via_call_operator_t(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6]) + assert all(m.subscript_via_call_operator_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6]) + + def test_mutate_readonly(arr): arr.flags.writeable = False for func, args in ( From f52a31a004929409f70fc4853daeed679bb995d6 Mon Sep 17 00:00:00 2001 From: Francesco Rizzi Date: Wed, 28 Jun 2023 18:04:40 +0200 Subject: [PATCH 4/8] fix preproc condition --- include/pybind11/numpy.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index cf64fafeb..89e88dc49 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -1107,7 +1107,7 @@ public: // const-reference to element at a given index without bounds checking template const T &operator()(Ix... index) const { -#if defined(NDEBUG) +#if !defined(NDEBUG) check_access_precondition(index...); #endif return const_reference(index...); @@ -1116,7 +1116,7 @@ public: // mutable reference to element at a given index without bounds checking template T &operator()(Ix... index) { -#if defined(NDEBUG) +#if !defined(NDEBUG) check_access_precondition(index...); #endif return mutable_reference(index...); From 17912ba667ba2d4ba1f2354950120fc21e690208 Mon Sep 17 00:00:00 2001 From: Francesco Rizzi Date: Wed, 28 Jun 2023 18:20:52 +0200 Subject: [PATCH 5/8] move check --- include/pybind11/numpy.h | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 89e88dc49..057db0243 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -1093,14 +1093,18 @@ public: // Reference to element at a given index template const T &at(Ix... index) const { - check_access_precondition(index...); + if ((ssize_t) sizeof...(index) != ndim()) { + fail_dim_check(sizeof...(index), "index dimension mismatch"); + } return const_reference(index...); } // Mutable reference to element at a given index template T &mutable_at(Ix... index) { - check_access_precondition(index...); + if ((ssize_t) sizeof...(index) != ndim()) { + fail_dim_check(sizeof...(index), "index dimension mismatch"); + } return mutable_reference(index...); } @@ -1108,7 +1112,9 @@ public: template const T &operator()(Ix... index) const { #if !defined(NDEBUG) - check_access_precondition(index...); + if ((ssize_t) sizeof...(index) != ndim()) { + fail_dim_check(sizeof...(index), "index dimension mismatch"); + } #endif return const_reference(index...); } @@ -1117,7 +1123,9 @@ public: template T &operator()(Ix... index) { #if !defined(NDEBUG) - check_access_precondition(index...); + if ((ssize_t) sizeof...(index) != ndim()) { + fail_dim_check(sizeof...(index), "index dimension mismatch"); + } #endif return mutable_reference(index...); } @@ -1191,13 +1199,6 @@ private: return *(static_cast(array::mutable_data()) + byte_offset(ssize_t(index)...) / itemsize()); } - - template - void check_access_precondition(Ix... index) const { - if ((ssize_t) sizeof...(index) != ndim()) { - fail_dim_check(sizeof...(index), "index dimension mismatch"); - } - } }; template From 136c664b5a6240778479806ed806310c2f8db786 Mon Sep 17 00:00:00 2001 From: Francesco Rizzi Date: Wed, 28 Jun 2023 18:45:13 +0200 Subject: [PATCH 6/8] reduce redundancy --- include/pybind11/numpy.h | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 057db0243..f01c9d154 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -1093,18 +1093,14 @@ public: // Reference to element at a given index template const T &at(Ix... index) const { - if ((ssize_t) sizeof...(index) != ndim()) { - fail_dim_check(sizeof...(index), "index dimension mismatch"); - } + check_dim_precondition(sizeof...(index)); return const_reference(index...); } // Mutable reference to element at a given index template T &mutable_at(Ix... index) { - if ((ssize_t) sizeof...(index) != ndim()) { - fail_dim_check(sizeof...(index), "index dimension mismatch"); - } + check_dim_precondition(sizeof...(index)); return mutable_reference(index...); } @@ -1112,9 +1108,7 @@ public: template const T &operator()(Ix... index) const { #if !defined(NDEBUG) - if ((ssize_t) sizeof...(index) != ndim()) { - fail_dim_check(sizeof...(index), "index dimension mismatch"); - } + check_dim_precondition(sizeof...(index)); #endif return const_reference(index...); } @@ -1123,9 +1117,7 @@ public: template T &operator()(Ix... index) { #if !defined(NDEBUG) - if ((ssize_t) sizeof...(index) != ndim()) { - fail_dim_check(sizeof...(index), "index dimension mismatch"); - } + check_dim_precondition(sizeof...(index)); #endif return mutable_reference(index...); } @@ -1199,6 +1191,12 @@ private: return *(static_cast(array::mutable_data()) + byte_offset(ssize_t(index)...) / itemsize()); } + + void check_dim_precondition(ssize_t dim) const { + if (dim != ndim()) { + fail_dim_check(dim, "index dimension mismatch"); + } + } }; template From d2ea386ef7b6fe974fca5faa08f7c90396d85a7f Mon Sep 17 00:00:00 2001 From: Francesco Rizzi Date: Wed, 28 Jun 2023 20:51:17 +0200 Subject: [PATCH 7/8] make changes as per review comments --- tests/test_numpy_array.cpp | 7 ++++++ tests/test_numpy_array.py | 45 ++++++++++++++++++++++++-------------- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index 8942672e9..51bf577cc 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -206,6 +206,13 @@ TEST_SUBMODULE(numpy_array, sm) { sm.def("nbytes", [](const arr &a) { return a.nbytes(); }); sm.def("owndata", [](const arr &a) { return a.owndata(); }); + sm.attr("defined_NDEBUG") = +#ifdef NDEBUG + true; +#else + false; +#endif + // test_index_offset def_index_fn(index_at, const arr &); def_index_fn(index_at_t, const arr_t &); diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index cfad6f09e..b2434a07e 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -109,27 +109,40 @@ def test_data(arr, args, ret): assert all(m.data(arr, *args)[(1 if byteorder == "little" else 0) :: 2] == 0) +@pytest.mark.parametrize( + "func", + [ + m.at_t, + m.mutate_at_t, + m.const_subscript_via_call_operator_t, + m.subscript_via_call_operator_t, + ][: 2 if m.defined_NDEBUG else 99], +) @pytest.mark.parametrize("dim", [0, 1, 3]) -def test_at_fail(arr, dim): - for func in m.at_t, m.mutate_at_t: - with pytest.raises(IndexError) as excinfo: - func(arr, *([0] * dim)) - assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)" +def test_elem_reference(arr, func, dim): + with pytest.raises(IndexError) as excinfo: + func(arr, *([0] * dim)) + assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)" -def test_at(arr): - assert m.at_t(arr, 0, 2) == 3 - assert m.at_t(arr, 1, 0) == 4 - - assert all(m.mutate_at_t(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6]) - assert all(m.mutate_at_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6]) +# @pytest.mark.parametrize("dim", [0, 1, 3]) +# def test_at_fail(arr, dim): +# for func in m.at_t, m.mutate_at_t: +# with pytest.raises(IndexError) as excinfo: +# func(arr, *([0] * dim)) +# assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)" -def test_subscript_via_call_operator(arr): - assert m.const_subscript_via_call_operator_t(arr, 0, 2) == 3 - assert m.const_subscript_via_call_operator_t(arr, 1, 0) == 4 - assert all(m.subscript_via_call_operator_t(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6]) - assert all(m.subscript_via_call_operator_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6]) +@pytest.mark.parametrize("func", [m.at_t, m.const_subscript_via_call_operator_t]) +def test_const_elem_reference(arr, func): + assert func(arr, 0, 2) == 3 + assert func(arr, 1, 0) == 4 + + +@pytest.mark.parametrize("func", [m.mutate_at_t, m.subscript_via_call_operator_t]) +def test_mutable_elem_reference(arr, func): + assert all(func(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6]) + assert all(func(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6]) def test_mutate_readonly(arr): From 8917a1e9b396c99276d32c1883fbc4a7392c1bd8 Mon Sep 17 00:00:00 2001 From: Francesco Rizzi Date: Thu, 29 Jun 2023 09:12:17 +0200 Subject: [PATCH 8/8] add bounds check and test --- include/pybind11/numpy.h | 36 +++++++++++++++--------------------- tests/test_numpy_array.py | 31 ++++++++++++++----------------- 2 files changed, 29 insertions(+), 38 deletions(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index f01c9d154..fe226e4df 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -1093,33 +1093,39 @@ public: // Reference to element at a given index template const T &at(Ix... index) const { - check_dim_precondition(sizeof...(index)); - return const_reference(index...); + check_rank_precondition(sizeof...(index)); + return *(static_cast(array::data()) + + byte_offset(ssize_t(index)...) / itemsize()); } // Mutable reference to element at a given index template T &mutable_at(Ix... index) { - check_dim_precondition(sizeof...(index)); - return mutable_reference(index...); + check_rank_precondition(sizeof...(index)); + return *(static_cast(array::mutable_data()) + + byte_offset(ssize_t(index)...) / itemsize()); } // const-reference to element at a given index without bounds checking template const T &operator()(Ix... index) const { #if !defined(NDEBUG) - check_dim_precondition(sizeof...(index)); + check_rank_precondition(sizeof...(index)); + check_dimensions(index...); #endif - return const_reference(index...); + return *(static_cast(array::data()) + + detail::byte_offset_unsafe(strides(), ssize_t(index)...) / itemsize()); } // mutable reference to element at a given index without bounds checking template T &operator()(Ix... index) { #if !defined(NDEBUG) - check_dim_precondition(sizeof...(index)); + check_rank_precondition(sizeof...(index)); + check_dimensions(index...); #endif - return mutable_reference(index...); + return *(static_cast(array::mutable_data()) + + detail::byte_offset_unsafe(strides(), ssize_t(index)...) / itemsize()); } /** @@ -1180,19 +1186,7 @@ protected: } private: - template - const T &const_reference(Ix... index) const { - return *(static_cast(array::data()) - + byte_offset(ssize_t(index)...) / itemsize()); - } - - template - T &mutable_reference(Ix... index) { - return *(static_cast(array::mutable_data()) - + byte_offset(ssize_t(index)...) / itemsize()); - } - - void check_dim_precondition(ssize_t dim) const { + void check_rank_precondition(ssize_t dim) const { if (dim != ndim()) { fail_dim_check(dim, "index dimension mismatch"); } diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index b2434a07e..222a23125 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -125,14 +125,6 @@ def test_elem_reference(arr, func, dim): assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)" -# @pytest.mark.parametrize("dim", [0, 1, 3]) -# def test_at_fail(arr, dim): -# for func in m.at_t, m.mutate_at_t: -# with pytest.raises(IndexError) as excinfo: -# func(arr, *([0] * dim)) -# assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)" - - @pytest.mark.parametrize("func", [m.at_t, m.const_subscript_via_call_operator_t]) def test_const_elem_reference(arr, func): assert func(arr, 0, 2) == 3 @@ -171,8 +163,9 @@ def test_mutate_data(arr): assert all(m.mutate_data_t(arr, 1, 2).ravel() == [6, 19, 27, 68, 84, 197]) -def test_bounds_check(arr): - for func in ( +@pytest.mark.parametrize( + "func", + [ m.index_at, m.index_at_t, m.data, @@ -181,13 +174,17 @@ def test_bounds_check(arr): m.mutate_data_t, m.at_t, m.mutate_at_t, - ): - with pytest.raises(IndexError) as excinfo: - func(arr, 2, 0) - assert str(excinfo.value) == "index 2 is out of bounds for axis 0 with size 2" - with pytest.raises(IndexError) as excinfo: - func(arr, 0, 4) - assert str(excinfo.value) == "index 4 is out of bounds for axis 1 with size 3" + m.const_subscript_via_call_operator_t, + m.subscript_via_call_operator_t, + ][: 8 if m.defined_NDEBUG else 99], +) +def test_bounds_check(arr, func): + with pytest.raises(IndexError) as excinfo: + func(arr, 2, 0) + assert str(excinfo.value) == "index 2 is out of bounds for axis 0 with size 2" + with pytest.raises(IndexError) as excinfo: + func(arr, 0, 4) + assert str(excinfo.value) == "index 4 is out of bounds for axis 1 with size 3" def test_make_c_f_array():