add signed overload for py::slice::compute

This commit is contained in:
sizmailov 2018-05-11 21:30:15 +03:00 committed by Wenzel Jakob
parent 22859bb8fc
commit 21c3911bd3
3 changed files with 39 additions and 0 deletions

View File

@ -1133,6 +1133,13 @@ public:
(ssize_t *) stop, (ssize_t *) step, (ssize_t *) stop, (ssize_t *) step,
(ssize_t *) slicelength) == 0; (ssize_t *) slicelength) == 0;
} }
bool compute(ssize_t length, ssize_t *start, ssize_t *stop, ssize_t *step,
ssize_t *slicelength) const {
return PySlice_GetIndicesEx((PYBIND11_SLICE_OBJECT *) m_ptr,
length, start,
stop, step,
slicelength) == 0;
}
}; };
class capsule : public object { class capsule : public object {

View File

@ -71,6 +71,25 @@ py::list test_random_access_iterator(PythonType x) {
} }
TEST_SUBMODULE(sequences_and_iterators, m) { TEST_SUBMODULE(sequences_and_iterators, m) {
// test_sliceable
class Sliceable{
public:
Sliceable(int n): size(n) {}
int start,stop,step;
int size;
};
py::class_<Sliceable>(m,"Sliceable")
.def(py::init<int>())
.def("__getitem__",[](const Sliceable &s, py::slice slice) {
ssize_t start, stop, step, slicelength;
if (!slice.compute(s.size, &start, &stop, &step, &slicelength))
throw py::error_already_set();
int istart = static_cast<int>(start);
int istop = static_cast<int>(stop);
int istep = static_cast<int>(step);
return std::make_tuple(istart,istop,istep);
})
;
// test_sequence // test_sequence
class Sequence { class Sequence {

View File

@ -33,6 +33,19 @@ def test_generalized_iterators():
next(it) next(it)
def test_sliceable():
sliceable = m.Sliceable(100)
assert sliceable[::] == (0, 100, 1)
assert sliceable[10::] == (10, 100, 1)
assert sliceable[:10:] == (0, 10, 1)
assert sliceable[::10] == (0, 100, 10)
assert sliceable[-10::] == (90, 100, 1)
assert sliceable[:-10:] == (0, 90, 1)
assert sliceable[::-10] == (99, -1, -10)
assert sliceable[50:60:1] == (50, 60, 1)
assert sliceable[50:60:-1] == (50, 60, -1)
def test_sequence(): def test_sequence():
cstats = ConstructorStats.get(m.Sequence) cstats = ConstructorStats.get(m.Sequence)