diff --git a/docs/advanced/pycpp/numpy.rst b/docs/advanced/pycpp/numpy.rst index 71917ce6e..458f99e97 100644 --- a/docs/advanced/pycpp/numpy.rst +++ b/docs/advanced/pycpp/numpy.rst @@ -364,3 +364,23 @@ uses of ``py::array``: The file :file:`tests/test_numpy_array.cpp` contains additional examples demonstrating the use of this feature. + +Ellipsis +======== + +Python 3 provides a convenient ``...`` ellipsis notation that is often used to +slice multidimensional arrays. For instance, the following snippet extracts the +middle dimensions of a tensor with the first and last index set to zero. + +.. code-block:: python + + a = # a NumPy array + b = a[0, ..., 0] + +The function ``py::ellipsis()`` function can be used to perform the same +operation on the C++ side: + +.. code-block:: cpp + + py::array a = /* A NumPy array */; + py::array b = a[py::make_tuple(0, py::ellipsis(), 0)]; diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index bcee8b5b8..976abf86e 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -693,6 +693,9 @@ inline bool PyIterable_Check(PyObject *obj) { } inline bool PyNone_Check(PyObject *o) { return o == Py_None; } +#if PY_MAJOR_VERSION >= 3 +inline bool PyEllipsis_Check(PyObject *o) { return o == Py_Ellipsis; } +#endif inline bool PyUnicode_Check_Permissive(PyObject *o) { return PyUnicode_Check(o) || PYBIND11_BYTES_CHECK(o); } @@ -967,6 +970,14 @@ public: none() : object(Py_None, borrowed_t{}) { } }; +#if PY_MAJOR_VERSION >= 3 +class ellipsis : public object { +public: + PYBIND11_OBJECT(ellipsis, object, detail::PyEllipsis_Check) + ellipsis() : object(Py_Ellipsis, borrowed_t{}) { } +}; +#endif + class bool_ : public object { public: PYBIND11_OBJECT_CVT(bool_, object, PyBool_Check, raw_bool) diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index 79a157e60..570259493 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -295,4 +295,10 @@ TEST_SUBMODULE(numpy_array, sm) { std::fill(a.mutable_data(), a.mutable_data() + a.size(), 42.); return a; }); + +#if PY_MAJOR_VERSION >= 3 + sm.def("index_using_ellipsis", [](py::array a) { + return a[py::make_tuple(0, py::ellipsis(), 0)]; + }); +#endif } diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index 1e83135bb..8ac0e66fb 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -408,3 +408,9 @@ def test_array_create_and_resize(msg): a = m.create_and_resize(2) assert(a.size == 4) assert(np.all(a == 42.)) + + +@pytest.unsupported_on_py2 +def test_index_using_ellipsis(): + a = m.index_using_ellipsis(np.zeros((5, 6, 7))) + assert a.shape == (6,)