reshape for numpy arrays (#984)

* reshape

* more tests

* Update numpy.h

* Update test_numpy_array.py

* Update numpy.h

* Update numpy.h

* Update test_numpy_array.cpp

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix merge bug

* Make clang-tidy happy

* Add xfail for PyPy

* Fix casting issue

* Address reviews on additional tests

* Fix ordering

* Do a little more reordering

* Fix typo

* Try improving tests

* Fix error in reshape

* Add one more reshape test

* streamlining new tests; removing a few stray msg

Co-authored-by: ncullen93 <ncullen.th@dartmouth.edu>
Co-authored-by: NC Cullen <nicholas.c.cullen.th@dartmouth.edu>
Co-authored-by: Aaron Gokaslan <skylion.aaron@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ralf Grosse-Kunstleve <rwgk@google.com>
This commit is contained in:
Nick Cullen 2021-08-26 17:12:35 +02:00 committed by GitHub
parent 031a700dfd
commit 59ad1e7d05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 54 additions and 5 deletions

View File

@ -198,6 +198,8 @@ struct npy_api {
// Unused. Not removed because that affects ABI of the class.
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
PyObject* (*PyArray_Newshape_)(PyObject*, PyArray_Dims*, int);
private:
enum functions {
API_PyArray_GetNDArrayCFeatureVersion = 211,
@ -212,10 +214,11 @@ private:
API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94,
API_PyArray_DescrNewFromType = 96,
API_PyArray_Newshape = 135,
API_PyArray_Squeeze = 136,
API_PyArray_DescrConverter = 174,
API_PyArray_EquivTypes = 182,
API_PyArray_GetArrayParamsFromObject = 278,
API_PyArray_Squeeze = 136,
API_PyArray_SetBaseObject = 282
};
@ -243,11 +246,13 @@ private:
DECL_NPY_API(PyArray_NewCopy);
DECL_NPY_API(PyArray_NewFromDescr);
DECL_NPY_API(PyArray_DescrNewFromType);
DECL_NPY_API(PyArray_Newshape);
DECL_NPY_API(PyArray_Squeeze);
DECL_NPY_API(PyArray_DescrConverter);
DECL_NPY_API(PyArray_EquivTypes);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
DECL_NPY_API(PyArray_Squeeze);
DECL_NPY_API(PyArray_SetBaseObject);
#undef DECL_NPY_API
return api;
}
@ -785,6 +790,18 @@ public:
if (isinstance<array>(new_array)) { *this = std::move(new_array); }
}
/// Optional `order` parameter omitted, to be added as needed.
array reshape(ShapeContainer new_shape) {
detail::npy_api::PyArray_Dims d
= {reinterpret_cast<Py_intptr_t *>(new_shape->data()), int(new_shape->size())};
auto new_array
= reinterpret_steal<array>(detail::npy_api::get().PyArray_Newshape_(m_ptr, &d, 0));
if (!new_array) {
throw error_already_set();
}
return new_array;
}
/// Ensure that the argument is a NumPy array
/// In case of an error, nullptr is returned and the Python error is cleared.
static array ensure(handle h, int ExtraFlags = 0) {

View File

@ -405,6 +405,13 @@ TEST_SUBMODULE(numpy_array, sm) {
return a;
});
sm.def("reshape_initializer_list", [](py::array_t<int> a, size_t N, size_t M, size_t O) {
return a.reshape({N, M, O});
});
sm.def("reshape_tuple", [](py::array_t<int> a, const std::vector<int> &new_shape) {
return a.reshape(new_shape);
});
sm.def("index_using_ellipsis",
[](const py::array &a) { return a[py::make_tuple(0, py::ellipsis(), 0)]; });

View File

@ -411,7 +411,7 @@ def test_array_unchecked_fixed_dims(msg):
assert m.proxy_auxiliaries2_const_ref(z1)
def test_array_unchecked_dyn_dims(msg):
def test_array_unchecked_dyn_dims():
z1 = np.array([[1, 2], [3, 4]], dtype="float64")
m.proxy_add2_dyn(z1, 10)
assert np.all(z1 == [[11, 12], [13, 14]])
@ -444,7 +444,7 @@ def test_initializer_list():
assert m.array_initializer_list4().shape == (1, 2, 3, 4)
def test_array_resize(msg):
def test_array_resize():
a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype="float64")
m.array_reshape2(a)
assert a.size == 9
@ -470,12 +470,37 @@ def test_array_resize(msg):
@pytest.mark.xfail("env.PYPY")
def test_array_create_and_resize(msg):
def test_array_create_and_resize():
a = m.create_and_resize(2)
assert a.size == 4
assert np.all(a == 42.0)
def test_reshape_initializer_list():
a = np.arange(2 * 7 * 3) + 1
x = m.reshape_initializer_list(a, 2, 7, 3)
assert x.shape == (2, 7, 3)
assert list(x[1][4]) == [34, 35, 36]
with pytest.raises(ValueError) as excinfo:
m.reshape_initializer_list(a, 1, 7, 3)
assert str(excinfo.value) == "cannot reshape array of size 42 into shape (1,7,3)"
def test_reshape_tuple():
a = np.arange(3 * 7 * 2) + 1
x = m.reshape_tuple(a, (3, 7, 2))
assert x.shape == (3, 7, 2)
assert list(x[1][4]) == [23, 24]
y = m.reshape_tuple(x, (x.size,))
assert y.shape == (42,)
with pytest.raises(ValueError) as excinfo:
m.reshape_tuple(a, (3, 7, 1))
assert str(excinfo.value) == "cannot reshape array of size 42 into shape (3,7,1)"
with pytest.raises(ValueError) as excinfo:
m.reshape_tuple(a, ())
assert str(excinfo.value) == "cannot reshape array of size 42 into shape ()"
def test_index_using_ellipsis():
a = m.index_using_ellipsis(np.zeros((5, 6, 7)))
assert a.shape == (6,)