mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-25 22:52:01 +00:00
Compare commits
9 Commits
73fd5a97e0
...
ff0275aaf1
Author | SHA1 | Date | |
---|---|---|---|
|
ff0275aaf1 | ||
|
8917a1e9b3 | ||
|
d2ea386ef7 | ||
|
136c664b5a | ||
|
17912ba667 | ||
|
f52a31a004 | ||
|
87fd2c5121 | ||
|
cba2b32ead | ||
|
4d785be984 |
@ -1232,9 +1232,7 @@ public:
|
||||
// Reference to element at a given index
|
||||
template <typename... Ix>
|
||||
const T &at(Ix... index) const {
|
||||
if ((ssize_t) sizeof...(index) != ndim()) {
|
||||
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
||||
}
|
||||
check_rank_precondition(sizeof...(index));
|
||||
return *(static_cast<const T *>(array::data())
|
||||
+ byte_offset(ssize_t(index)...) / itemsize());
|
||||
}
|
||||
@ -1242,13 +1240,33 @@ public:
|
||||
// Mutable reference to element at a given index
|
||||
template <typename... Ix>
|
||||
T &mutable_at(Ix... index) {
|
||||
if ((ssize_t) sizeof...(index) != ndim()) {
|
||||
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
||||
}
|
||||
check_rank_precondition(sizeof...(index));
|
||||
return *(static_cast<T *>(array::mutable_data())
|
||||
+ byte_offset(ssize_t(index)...) / itemsize());
|
||||
}
|
||||
|
||||
// const-reference to element at a given index without bounds checking
|
||||
template <typename... Ix>
|
||||
const T &operator()(Ix... index) const {
|
||||
#if !defined(NDEBUG)
|
||||
check_rank_precondition(sizeof...(index));
|
||||
check_dimensions(index...);
|
||||
#endif
|
||||
return *(static_cast<const T *>(array::data())
|
||||
+ detail::byte_offset_unsafe(strides(), ssize_t(index)...) / itemsize());
|
||||
}
|
||||
|
||||
// mutable reference to element at a given index without bounds checking
|
||||
template <typename... Ix>
|
||||
T &operator()(Ix... index) {
|
||||
#if !defined(NDEBUG)
|
||||
check_rank_precondition(sizeof...(index));
|
||||
check_dimensions(index...);
|
||||
#endif
|
||||
return *(static_cast<T *>(array::mutable_data())
|
||||
+ detail::byte_offset_unsafe(strides(), ssize_t(index)...) / itemsize());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a proxy object that provides access to the array's data without bounds or
|
||||
* dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
|
||||
@ -1305,6 +1323,13 @@ protected:
|
||||
| ExtraFlags,
|
||||
nullptr);
|
||||
}
|
||||
|
||||
private:
|
||||
void check_rank_precondition(ssize_t dim) const {
|
||||
if (dim != ndim()) {
|
||||
fail_dim_check(dim, "index dimension mismatch");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
@ -131,6 +131,15 @@ arr_t &mutate_at_t(arr_t &a, Ix... idx) {
|
||||
a.mutable_at(idx...)++;
|
||||
return a;
|
||||
}
|
||||
template <typename... Ix>
|
||||
arr_t &subscript_via_call_operator_t(arr_t &a, Ix... idx) {
|
||||
a(idx...)++;
|
||||
return a;
|
||||
}
|
||||
template <typename... Ix>
|
||||
py::ssize_t const_subscript_via_call_operator_t(const arr_t &a, Ix... idx) {
|
||||
return a(idx...);
|
||||
}
|
||||
|
||||
#define def_index_fn(name, type) \
|
||||
sm.def(#name, [](type a) { return name(a); }); \
|
||||
@ -197,6 +206,13 @@ TEST_SUBMODULE(numpy_array, sm) {
|
||||
sm.def("nbytes", [](const arr &a) { return a.nbytes(); });
|
||||
sm.def("owndata", [](const arr &a) { return a.owndata(); });
|
||||
|
||||
sm.attr("defined_NDEBUG") =
|
||||
#ifdef NDEBUG
|
||||
true;
|
||||
#else
|
||||
false;
|
||||
#endif
|
||||
|
||||
// test_index_offset
|
||||
def_index_fn(index_at, const arr &);
|
||||
def_index_fn(index_at_t, const arr_t &);
|
||||
@ -210,6 +226,8 @@ TEST_SUBMODULE(numpy_array, sm) {
|
||||
def_index_fn(mutate_data_t, arr_t &);
|
||||
def_index_fn(at_t, const arr_t &);
|
||||
def_index_fn(mutate_at_t, arr_t &);
|
||||
def_index_fn(subscript_via_call_operator_t, arr_t &);
|
||||
def_index_fn(const_subscript_via_call_operator_t, const arr_t &);
|
||||
|
||||
// test_make_c_f_array
|
||||
sm.def("make_f_array", [] { return py::array_t<float>({2, 2}, {4, 8}); });
|
||||
|
@ -111,20 +111,32 @@ def test_data(arr, args, ret):
|
||||
assert all(m.data(arr, *args)[(1 if byteorder == "little" else 0) :: 2] == 0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func",
|
||||
[
|
||||
m.at_t,
|
||||
m.mutate_at_t,
|
||||
m.const_subscript_via_call_operator_t,
|
||||
m.subscript_via_call_operator_t,
|
||||
][: 2 if m.defined_NDEBUG else 99],
|
||||
)
|
||||
@pytest.mark.parametrize("dim", [0, 1, 3])
|
||||
def test_at_fail(arr, dim):
|
||||
for func in m.at_t, m.mutate_at_t:
|
||||
with pytest.raises(IndexError) as excinfo:
|
||||
func(arr, *([0] * dim))
|
||||
assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)"
|
||||
def test_elem_reference(arr, func, dim):
|
||||
with pytest.raises(IndexError) as excinfo:
|
||||
func(arr, *([0] * dim))
|
||||
assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)"
|
||||
|
||||
|
||||
def test_at(arr):
|
||||
assert m.at_t(arr, 0, 2) == 3
|
||||
assert m.at_t(arr, 1, 0) == 4
|
||||
@pytest.mark.parametrize("func", [m.at_t, m.const_subscript_via_call_operator_t])
|
||||
def test_const_elem_reference(arr, func):
|
||||
assert func(arr, 0, 2) == 3
|
||||
assert func(arr, 1, 0) == 4
|
||||
|
||||
assert all(m.mutate_at_t(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6])
|
||||
assert all(m.mutate_at_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])
|
||||
|
||||
@pytest.mark.parametrize("func", [m.mutate_at_t, m.subscript_via_call_operator_t])
|
||||
def test_mutable_elem_reference(arr, func):
|
||||
assert all(func(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6])
|
||||
assert all(func(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])
|
||||
|
||||
|
||||
def test_mutate_readonly(arr):
|
||||
@ -153,8 +165,9 @@ def test_mutate_data(arr):
|
||||
assert all(m.mutate_data_t(arr, 1, 2).ravel() == [6, 19, 27, 68, 84, 197])
|
||||
|
||||
|
||||
def test_bounds_check(arr):
|
||||
for func in (
|
||||
@pytest.mark.parametrize(
|
||||
"func",
|
||||
[
|
||||
m.index_at,
|
||||
m.index_at_t,
|
||||
m.data,
|
||||
@ -163,13 +176,17 @@ def test_bounds_check(arr):
|
||||
m.mutate_data_t,
|
||||
m.at_t,
|
||||
m.mutate_at_t,
|
||||
):
|
||||
with pytest.raises(IndexError) as excinfo:
|
||||
func(arr, 2, 0)
|
||||
assert str(excinfo.value) == "index 2 is out of bounds for axis 0 with size 2"
|
||||
with pytest.raises(IndexError) as excinfo:
|
||||
func(arr, 0, 4)
|
||||
assert str(excinfo.value) == "index 4 is out of bounds for axis 1 with size 3"
|
||||
m.const_subscript_via_call_operator_t,
|
||||
m.subscript_via_call_operator_t,
|
||||
][: 8 if m.defined_NDEBUG else 99],
|
||||
)
|
||||
def test_bounds_check(arr, func):
|
||||
with pytest.raises(IndexError) as excinfo:
|
||||
func(arr, 2, 0)
|
||||
assert str(excinfo.value) == "index 2 is out of bounds for axis 0 with size 2"
|
||||
with pytest.raises(IndexError) as excinfo:
|
||||
func(arr, 0, 4)
|
||||
assert str(excinfo.value) == "index 4 is out of bounds for axis 1 with size 3"
|
||||
|
||||
|
||||
def test_make_c_f_array():
|
||||
|
Loading…
Reference in New Issue
Block a user