From 21c3911bd3eef2d763783ff2122d0fddb9958c01 Mon Sep 17 00:00:00 2001 From: sizmailov Date: Fri, 11 May 2018 21:30:15 +0300 Subject: [PATCH] add signed overload for `py::slice::compute` --- include/pybind11/pytypes.h | 7 +++++++ tests/test_sequences_and_iterators.cpp | 19 +++++++++++++++++++ tests/test_sequences_and_iterators.py | 13 +++++++++++++ 3 files changed, 39 insertions(+) diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index db7dfec26..2d573dfad 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -1133,6 +1133,13 @@ public: (ssize_t *) stop, (ssize_t *) step, (ssize_t *) slicelength) == 0; } + bool compute(ssize_t length, ssize_t *start, ssize_t *stop, ssize_t *step, + ssize_t *slicelength) const { + return PySlice_GetIndicesEx((PYBIND11_SLICE_OBJECT *) m_ptr, + length, start, + stop, step, + slicelength) == 0; + } }; class capsule : public object { diff --git a/tests/test_sequences_and_iterators.cpp b/tests/test_sequences_and_iterators.cpp index a45521256..87ccf99d6 100644 --- a/tests/test_sequences_and_iterators.cpp +++ b/tests/test_sequences_and_iterators.cpp @@ -71,6 +71,25 @@ py::list test_random_access_iterator(PythonType x) { } TEST_SUBMODULE(sequences_and_iterators, m) { + // test_sliceable + class Sliceable{ + public: + Sliceable(int n): size(n) {} + int start,stop,step; + int size; + }; + py::class_(m,"Sliceable") + .def(py::init()) + .def("__getitem__",[](const Sliceable &s, py::slice slice) { + ssize_t start, stop, step, slicelength; + if (!slice.compute(s.size, &start, &stop, &step, &slicelength)) + throw py::error_already_set(); + int istart = static_cast(start); + int istop = static_cast(stop); + int istep = static_cast(step); + return std::make_tuple(istart,istop,istep); + }) + ; // test_sequence class Sequence { diff --git a/tests/test_sequences_and_iterators.py b/tests/test_sequences_and_iterators.py index f6c062094..6bd160640 100644 --- a/tests/test_sequences_and_iterators.py +++ b/tests/test_sequences_and_iterators.py @@ -33,6 +33,19 @@ def test_generalized_iterators(): next(it) +def test_sliceable(): + sliceable = m.Sliceable(100) + assert sliceable[::] == (0, 100, 1) + assert sliceable[10::] == (10, 100, 1) + assert sliceable[:10:] == (0, 10, 1) + assert sliceable[::10] == (0, 100, 10) + assert sliceable[-10::] == (90, 100, 1) + assert sliceable[:-10:] == (0, 90, 1) + assert sliceable[::-10] == (99, -1, -10) + assert sliceable[50:60:1] == (50, 60, 1) + assert sliceable[50:60:-1] == (50, 60, -1) + + def test_sequence(): cstats = ConstructorStats.get(m.Sequence)