From 083a0219b5962a146a5e1556d7ce5e5187cb8bca Mon Sep 17 00:00:00 2001 From: uentity Date: Thu, 13 Apr 2017 21:41:55 +0500 Subject: [PATCH] array: implement array resize --- include/pybind11/numpy.h | 23 +++++++++++++++++++++++ tests/test_numpy_array.cpp | 23 ++++++++++++++++++++++- tests/test_numpy_array.py | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 1 deletion(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 44ea42721..ba9402b5c 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -129,6 +129,11 @@ struct npy_api { NPY_STRING_, NPY_UNICODE_, NPY_VOID_ }; + typedef struct { + Py_intptr_t *ptr; + int len; + } PyArray_Dims; + static npy_api& get() { static npy_api api = lookup(); return api; @@ -159,6 +164,7 @@ struct npy_api { Py_ssize_t *, PyObject **, PyObject *); PyObject *(*PyArray_Squeeze_)(PyObject *); int (*PyArray_SetBaseObject_)(PyObject *, PyObject *); + PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int); private: enum functions { API_PyArray_GetNDArrayCFeatureVersion = 211, @@ -168,6 +174,7 @@ private: API_PyArray_DescrFromType = 45, API_PyArray_DescrFromScalar = 57, API_PyArray_FromAny = 69, + API_PyArray_Resize = 80, API_PyArray_NewCopy = 85, API_PyArray_NewFromDescr = 94, API_PyArray_DescrNewFromType = 9, @@ -197,6 +204,7 @@ private: DECL_NPY_API(PyArray_DescrFromType); DECL_NPY_API(PyArray_DescrFromScalar); DECL_NPY_API(PyArray_FromAny); + DECL_NPY_API(PyArray_Resize); DECL_NPY_API(PyArray_NewCopy); DECL_NPY_API(PyArray_NewFromDescr); DECL_NPY_API(PyArray_DescrNewFromType); @@ -652,6 +660,21 @@ public: return reinterpret_steal(api.PyArray_Squeeze_(m_ptr)); } + /// Resize array to given shape + /// If refcheck is true and more that one reference exist to this array + /// then resize will succeed only if it makes a reshape, i.e. original size doesn't change + void resize(ShapeContainer new_shape, bool refcheck = true) { + detail::npy_api::PyArray_Dims d = { + new_shape->data(), int(new_shape->size()) + }; + // try to resize, set ordering param to -1 cause it's not used anyway + object new_array = reinterpret_steal( + detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1) + ); + if (!new_array) throw error_already_set(); + if (isinstance(new_array)) { *this = std::move(new_array); } + } + /// Ensure that the argument is a NumPy array /// In case of an error, nullptr is returned and the Python error is cleared. static array ensure(handle h, int ExtraFlags = 0) { diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index 269f18bbe..85c185f6b 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -273,4 +273,25 @@ test_initializer numpy_array([](py::module &m) { sm.def("array_initializer_list", []() { return py::array_t({ 1, 2 }); }); sm.def("array_initializer_list", []() { return py::array_t({ 1, 2, 3 }); }); sm.def("array_initializer_list", []() { return py::array_t({ 1, 2, 3, 4 }); }); -}); + + // reshape array to 2D without changing size + sm.def("array_reshape2", [](py::array_t a) { + const size_t dim_sz = (size_t)std::sqrt(a.size()); + if (dim_sz * dim_sz != a.size()) + throw std::domain_error("array_reshape2: input array total size is not a squared integer"); + a.resize({dim_sz, dim_sz}); + }); + + // resize to 3D array with each dimension = N + sm.def("array_resize3", [](py::array_t a, size_t N, bool refcheck) { + a.resize({N, N, N}, refcheck); + }); + + // return 2D array with Nrows = Ncols = N + sm.def("create_and_resize", [](size_t N) { + py::array_t a; + a.resize({N, N}); + std::fill(a.mutable_data(), a.mutable_data() + a.size(), 42.); + return a; + }); +}); \ No newline at end of file diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index 6281fa478..10af7486a 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -389,3 +389,38 @@ def test_array_failure(): with pytest.raises(ValueError) as excinfo: array_t_fail_test() assert str(excinfo.value) == 'cannot create a pybind11::array_t from a nullptr' + + +def test_array_resize(msg): + from pybind11_tests.array import (array_reshape2, array_resize3) + + a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='float64') + array_reshape2(a) + assert(a.size == 9) + assert(np.all(a == [[1, 2, 3], [4, 5, 6], [7, 8, 9]])) + + # total size change should succced with refcheck off + array_resize3(a, 4, False) + assert(a.size == 64) + # ... and fail with refcheck on + try: + array_resize3(a, 3, True) + except ValueError as e: + assert(str(e).startswith("cannot resize an array")) + # transposed array doesn't own data + b = a.transpose() + try: + array_resize3(b, 3, False) + except ValueError as e: + assert(str(e).startswith("cannot resize this array: it does not own its data")) + # ... but reshape should be fine + array_reshape2(b) + assert(b.shape == (8, 8)) + + +@pytest.unsupported_on_pypy +def test_array_create_and_resize(msg): + from pybind11_tests.array import create_and_resize + a = create_and_resize(2) + assert(a.size == 4) + assert(np.all(a == 42.))