mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-29 08:32:02 +00:00
Compare commits
9 Commits
73fd5a97e0
...
ff0275aaf1
Author | SHA1 | Date | |
---|---|---|---|
|
ff0275aaf1 | ||
|
8917a1e9b3 | ||
|
d2ea386ef7 | ||
|
136c664b5a | ||
|
17912ba667 | ||
|
f52a31a004 | ||
|
87fd2c5121 | ||
|
cba2b32ead | ||
|
4d785be984 |
@ -1232,9 +1232,7 @@ public:
|
|||||||
// Reference to element at a given index
|
// Reference to element at a given index
|
||||||
template <typename... Ix>
|
template <typename... Ix>
|
||||||
const T &at(Ix... index) const {
|
const T &at(Ix... index) const {
|
||||||
if ((ssize_t) sizeof...(index) != ndim()) {
|
check_rank_precondition(sizeof...(index));
|
||||||
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
|
||||||
}
|
|
||||||
return *(static_cast<const T *>(array::data())
|
return *(static_cast<const T *>(array::data())
|
||||||
+ byte_offset(ssize_t(index)...) / itemsize());
|
+ byte_offset(ssize_t(index)...) / itemsize());
|
||||||
}
|
}
|
||||||
@ -1242,13 +1240,33 @@ public:
|
|||||||
// Mutable reference to element at a given index
|
// Mutable reference to element at a given index
|
||||||
template <typename... Ix>
|
template <typename... Ix>
|
||||||
T &mutable_at(Ix... index) {
|
T &mutable_at(Ix... index) {
|
||||||
if ((ssize_t) sizeof...(index) != ndim()) {
|
check_rank_precondition(sizeof...(index));
|
||||||
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
|
||||||
}
|
|
||||||
return *(static_cast<T *>(array::mutable_data())
|
return *(static_cast<T *>(array::mutable_data())
|
||||||
+ byte_offset(ssize_t(index)...) / itemsize());
|
+ byte_offset(ssize_t(index)...) / itemsize());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// const-reference to element at a given index without bounds checking
|
||||||
|
template <typename... Ix>
|
||||||
|
const T &operator()(Ix... index) const {
|
||||||
|
#if !defined(NDEBUG)
|
||||||
|
check_rank_precondition(sizeof...(index));
|
||||||
|
check_dimensions(index...);
|
||||||
|
#endif
|
||||||
|
return *(static_cast<const T *>(array::data())
|
||||||
|
+ detail::byte_offset_unsafe(strides(), ssize_t(index)...) / itemsize());
|
||||||
|
}
|
||||||
|
|
||||||
|
// mutable reference to element at a given index without bounds checking
|
||||||
|
template <typename... Ix>
|
||||||
|
T &operator()(Ix... index) {
|
||||||
|
#if !defined(NDEBUG)
|
||||||
|
check_rank_precondition(sizeof...(index));
|
||||||
|
check_dimensions(index...);
|
||||||
|
#endif
|
||||||
|
return *(static_cast<T *>(array::mutable_data())
|
||||||
|
+ detail::byte_offset_unsafe(strides(), ssize_t(index)...) / itemsize());
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a proxy object that provides access to the array's data without bounds or
|
* Returns a proxy object that provides access to the array's data without bounds or
|
||||||
* dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
|
* dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
|
||||||
@ -1305,6 +1323,13 @@ protected:
|
|||||||
| ExtraFlags,
|
| ExtraFlags,
|
||||||
nullptr);
|
nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void check_rank_precondition(ssize_t dim) const {
|
||||||
|
if (dim != ndim()) {
|
||||||
|
fail_dim_check(dim, "index dimension mismatch");
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -131,6 +131,15 @@ arr_t &mutate_at_t(arr_t &a, Ix... idx) {
|
|||||||
a.mutable_at(idx...)++;
|
a.mutable_at(idx...)++;
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
|
template <typename... Ix>
|
||||||
|
arr_t &subscript_via_call_operator_t(arr_t &a, Ix... idx) {
|
||||||
|
a(idx...)++;
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
template <typename... Ix>
|
||||||
|
py::ssize_t const_subscript_via_call_operator_t(const arr_t &a, Ix... idx) {
|
||||||
|
return a(idx...);
|
||||||
|
}
|
||||||
|
|
||||||
#define def_index_fn(name, type) \
|
#define def_index_fn(name, type) \
|
||||||
sm.def(#name, [](type a) { return name(a); }); \
|
sm.def(#name, [](type a) { return name(a); }); \
|
||||||
@ -197,6 +206,13 @@ TEST_SUBMODULE(numpy_array, sm) {
|
|||||||
sm.def("nbytes", [](const arr &a) { return a.nbytes(); });
|
sm.def("nbytes", [](const arr &a) { return a.nbytes(); });
|
||||||
sm.def("owndata", [](const arr &a) { return a.owndata(); });
|
sm.def("owndata", [](const arr &a) { return a.owndata(); });
|
||||||
|
|
||||||
|
sm.attr("defined_NDEBUG") =
|
||||||
|
#ifdef NDEBUG
|
||||||
|
true;
|
||||||
|
#else
|
||||||
|
false;
|
||||||
|
#endif
|
||||||
|
|
||||||
// test_index_offset
|
// test_index_offset
|
||||||
def_index_fn(index_at, const arr &);
|
def_index_fn(index_at, const arr &);
|
||||||
def_index_fn(index_at_t, const arr_t &);
|
def_index_fn(index_at_t, const arr_t &);
|
||||||
@ -210,6 +226,8 @@ TEST_SUBMODULE(numpy_array, sm) {
|
|||||||
def_index_fn(mutate_data_t, arr_t &);
|
def_index_fn(mutate_data_t, arr_t &);
|
||||||
def_index_fn(at_t, const arr_t &);
|
def_index_fn(at_t, const arr_t &);
|
||||||
def_index_fn(mutate_at_t, arr_t &);
|
def_index_fn(mutate_at_t, arr_t &);
|
||||||
|
def_index_fn(subscript_via_call_operator_t, arr_t &);
|
||||||
|
def_index_fn(const_subscript_via_call_operator_t, const arr_t &);
|
||||||
|
|
||||||
// test_make_c_f_array
|
// test_make_c_f_array
|
||||||
sm.def("make_f_array", [] { return py::array_t<float>({2, 2}, {4, 8}); });
|
sm.def("make_f_array", [] { return py::array_t<float>({2, 2}, {4, 8}); });
|
||||||
|
@ -111,20 +111,32 @@ def test_data(arr, args, ret):
|
|||||||
assert all(m.data(arr, *args)[(1 if byteorder == "little" else 0) :: 2] == 0)
|
assert all(m.data(arr, *args)[(1 if byteorder == "little" else 0) :: 2] == 0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"func",
|
||||||
|
[
|
||||||
|
m.at_t,
|
||||||
|
m.mutate_at_t,
|
||||||
|
m.const_subscript_via_call_operator_t,
|
||||||
|
m.subscript_via_call_operator_t,
|
||||||
|
][: 2 if m.defined_NDEBUG else 99],
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("dim", [0, 1, 3])
|
@pytest.mark.parametrize("dim", [0, 1, 3])
|
||||||
def test_at_fail(arr, dim):
|
def test_elem_reference(arr, func, dim):
|
||||||
for func in m.at_t, m.mutate_at_t:
|
|
||||||
with pytest.raises(IndexError) as excinfo:
|
with pytest.raises(IndexError) as excinfo:
|
||||||
func(arr, *([0] * dim))
|
func(arr, *([0] * dim))
|
||||||
assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)"
|
assert str(excinfo.value) == f"index dimension mismatch: {dim} (ndim = 2)"
|
||||||
|
|
||||||
|
|
||||||
def test_at(arr):
|
@pytest.mark.parametrize("func", [m.at_t, m.const_subscript_via_call_operator_t])
|
||||||
assert m.at_t(arr, 0, 2) == 3
|
def test_const_elem_reference(arr, func):
|
||||||
assert m.at_t(arr, 1, 0) == 4
|
assert func(arr, 0, 2) == 3
|
||||||
|
assert func(arr, 1, 0) == 4
|
||||||
|
|
||||||
assert all(m.mutate_at_t(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6])
|
|
||||||
assert all(m.mutate_at_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])
|
@pytest.mark.parametrize("func", [m.mutate_at_t, m.subscript_via_call_operator_t])
|
||||||
|
def test_mutable_elem_reference(arr, func):
|
||||||
|
assert all(func(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6])
|
||||||
|
assert all(func(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6])
|
||||||
|
|
||||||
|
|
||||||
def test_mutate_readonly(arr):
|
def test_mutate_readonly(arr):
|
||||||
@ -153,8 +165,9 @@ def test_mutate_data(arr):
|
|||||||
assert all(m.mutate_data_t(arr, 1, 2).ravel() == [6, 19, 27, 68, 84, 197])
|
assert all(m.mutate_data_t(arr, 1, 2).ravel() == [6, 19, 27, 68, 84, 197])
|
||||||
|
|
||||||
|
|
||||||
def test_bounds_check(arr):
|
@pytest.mark.parametrize(
|
||||||
for func in (
|
"func",
|
||||||
|
[
|
||||||
m.index_at,
|
m.index_at,
|
||||||
m.index_at_t,
|
m.index_at_t,
|
||||||
m.data,
|
m.data,
|
||||||
@ -163,7 +176,11 @@ def test_bounds_check(arr):
|
|||||||
m.mutate_data_t,
|
m.mutate_data_t,
|
||||||
m.at_t,
|
m.at_t,
|
||||||
m.mutate_at_t,
|
m.mutate_at_t,
|
||||||
):
|
m.const_subscript_via_call_operator_t,
|
||||||
|
m.subscript_via_call_operator_t,
|
||||||
|
][: 8 if m.defined_NDEBUG else 99],
|
||||||
|
)
|
||||||
|
def test_bounds_check(arr, func):
|
||||||
with pytest.raises(IndexError) as excinfo:
|
with pytest.raises(IndexError) as excinfo:
|
||||||
func(arr, 2, 0)
|
func(arr, 2, 0)
|
||||||
assert str(excinfo.value) == "index 2 is out of bounds for axis 0 with size 2"
|
assert str(excinfo.value) == "index 2 is out of bounds for axis 0 with size 2"
|
||||||
|
Loading…
Reference in New Issue
Block a user