Add negative indexing support to stl_bind. (#1882)

This commit is contained in:
ali-beep 2019-08-15 13:41:12 -04:00 committed by Wenzel Jakob
parent b2fdfd1228
commit 5ef13eb680
2 changed files with 52 additions and 23 deletions

View File

@ -115,6 +115,14 @@ void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_t
using SizeType = typename Vector::size_type; using SizeType = typename Vector::size_type;
using DiffType = typename Vector::difference_type; using DiffType = typename Vector::difference_type;
auto wrap_i = [](DiffType i, SizeType n) {
if (i < 0)
i += n;
if (i < 0 || (SizeType)i >= n)
throw index_error();
return i;
};
cl.def("append", cl.def("append",
[](Vector &v, const T &value) { v.push_back(value); }, [](Vector &v, const T &value) { v.push_back(value); },
arg("x"), arg("x"),
@ -159,10 +167,13 @@ void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_t
); );
cl.def("insert", cl.def("insert",
[](Vector &v, SizeType i, const T &x) { [](Vector &v, DiffType i, const T &x) {
if (i > v.size()) // Can't use wrap_i; i == v.size() is OK
if (i < 0)
i += v.size();
if (i < 0 || (SizeType)i > v.size())
throw index_error(); throw index_error();
v.insert(v.begin() + (DiffType) i, x); v.insert(v.begin() + i, x);
}, },
arg("i") , arg("x"), arg("i") , arg("x"),
"Insert an item at a given position." "Insert an item at a given position."
@ -180,11 +191,10 @@ void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_t
); );
cl.def("pop", cl.def("pop",
[](Vector &v, SizeType i) { [wrap_i](Vector &v, DiffType i) {
if (i >= v.size()) i = wrap_i(i, v.size());
throw index_error(); T t = v[(SizeType) i];
T t = v[i]; v.erase(v.begin() + i);
v.erase(v.begin() + (DiffType) i);
return t; return t;
}, },
arg("i"), arg("i"),
@ -192,10 +202,9 @@ void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_t
); );
cl.def("__setitem__", cl.def("__setitem__",
[](Vector &v, SizeType i, const T &t) { [wrap_i](Vector &v, DiffType i, const T &t) {
if (i >= v.size()) i = wrap_i(i, v.size());
throw index_error(); v[(SizeType)i] = t;
v[i] = t;
} }
); );
@ -238,10 +247,9 @@ void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_t
); );
cl.def("__delitem__", cl.def("__delitem__",
[](Vector &v, SizeType i) { [wrap_i](Vector &v, DiffType i) {
if (i >= v.size()) i = wrap_i(i, v.size());
throw index_error(); v.erase(v.begin() + i);
v.erase(v.begin() + DiffType(i));
}, },
"Delete the list elements at index ``i``" "Delete the list elements at index ``i``"
); );
@ -277,13 +285,21 @@ template <typename Vector, typename Class_>
void vector_accessor(enable_if_t<!vector_needs_copy<Vector>::value, Class_> &cl) { void vector_accessor(enable_if_t<!vector_needs_copy<Vector>::value, Class_> &cl) {
using T = typename Vector::value_type; using T = typename Vector::value_type;
using SizeType = typename Vector::size_type; using SizeType = typename Vector::size_type;
using DiffType = typename Vector::difference_type;
using ItType = typename Vector::iterator; using ItType = typename Vector::iterator;
auto wrap_i = [](DiffType i, SizeType n) {
if (i < 0)
i += n;
if (i < 0 || (SizeType)i >= n)
throw index_error();
return i;
};
cl.def("__getitem__", cl.def("__getitem__",
[](Vector &v, SizeType i) -> T & { [wrap_i](Vector &v, DiffType i) -> T & {
if (i >= v.size()) i = wrap_i(i, v.size());
throw index_error(); return v[(SizeType)i];
return v[i];
}, },
return_value_policy::reference_internal // ref + keepalive return_value_policy::reference_internal // ref + keepalive
); );
@ -303,12 +319,15 @@ template <typename Vector, typename Class_>
void vector_accessor(enable_if_t<vector_needs_copy<Vector>::value, Class_> &cl) { void vector_accessor(enable_if_t<vector_needs_copy<Vector>::value, Class_> &cl) {
using T = typename Vector::value_type; using T = typename Vector::value_type;
using SizeType = typename Vector::size_type; using SizeType = typename Vector::size_type;
using DiffType = typename Vector::difference_type;
using ItType = typename Vector::iterator; using ItType = typename Vector::iterator;
cl.def("__getitem__", cl.def("__getitem__",
[](const Vector &v, SizeType i) -> T { [](const Vector &v, DiffType i) -> T {
if (i >= v.size()) if (i < 0 && (i += v.size()) < 0)
throw index_error(); throw index_error();
return v[i]; if ((SizeType)i >= v.size())
throw index_error();
return v[(SizeType)i];
} }
); );

View File

@ -53,6 +53,16 @@ def test_vector_int():
v_int2.extend(x for x in range(5)) v_int2.extend(x for x in range(5))
assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4]) assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4])
# test negative indexing
assert v_int2[-1] == 4
# insert with negative index
v_int2.insert(-1, 88)
assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 88, 4])
# delete negative index
del v_int2[-1]
assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 88])
# related to the PyPy's buffer protocol. # related to the PyPy's buffer protocol.
@pytest.unsupported_on_pypy @pytest.unsupported_on_pypy