mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-25 14:45:12 +00:00
reshape for numpy arrays (#984)
* reshape * more tests * Update numpy.h * Update test_numpy_array.py * Update numpy.h * Update numpy.h * Update test_numpy_array.cpp * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix merge bug * Make clang-tidy happy * Add xfail for PyPy * Fix casting issue * Address reviews on additional tests * Fix ordering * Do a little more reordering * Fix typo * Try improving tests * Fix error in reshape * Add one more reshape test * streamlining new tests; removing a few stray msg Co-authored-by: ncullen93 <ncullen.th@dartmouth.edu> Co-authored-by: NC Cullen <nicholas.c.cullen.th@dartmouth.edu> Co-authored-by: Aaron Gokaslan <skylion.aaron@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ralf Grosse-Kunstleve <rwgk@google.com>
This commit is contained in:
parent
031a700dfd
commit
59ad1e7d05
@ -198,6 +198,8 @@ struct npy_api {
|
||||
// Unused. Not removed because that affects ABI of the class.
|
||||
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
|
||||
PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
|
||||
PyObject* (*PyArray_Newshape_)(PyObject*, PyArray_Dims*, int);
|
||||
|
||||
private:
|
||||
enum functions {
|
||||
API_PyArray_GetNDArrayCFeatureVersion = 211,
|
||||
@ -212,10 +214,11 @@ private:
|
||||
API_PyArray_NewCopy = 85,
|
||||
API_PyArray_NewFromDescr = 94,
|
||||
API_PyArray_DescrNewFromType = 96,
|
||||
API_PyArray_Newshape = 135,
|
||||
API_PyArray_Squeeze = 136,
|
||||
API_PyArray_DescrConverter = 174,
|
||||
API_PyArray_EquivTypes = 182,
|
||||
API_PyArray_GetArrayParamsFromObject = 278,
|
||||
API_PyArray_Squeeze = 136,
|
||||
API_PyArray_SetBaseObject = 282
|
||||
};
|
||||
|
||||
@ -243,11 +246,13 @@ private:
|
||||
DECL_NPY_API(PyArray_NewCopy);
|
||||
DECL_NPY_API(PyArray_NewFromDescr);
|
||||
DECL_NPY_API(PyArray_DescrNewFromType);
|
||||
DECL_NPY_API(PyArray_Newshape);
|
||||
DECL_NPY_API(PyArray_Squeeze);
|
||||
DECL_NPY_API(PyArray_DescrConverter);
|
||||
DECL_NPY_API(PyArray_EquivTypes);
|
||||
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
|
||||
DECL_NPY_API(PyArray_Squeeze);
|
||||
DECL_NPY_API(PyArray_SetBaseObject);
|
||||
|
||||
#undef DECL_NPY_API
|
||||
return api;
|
||||
}
|
||||
@ -785,6 +790,18 @@ public:
|
||||
if (isinstance<array>(new_array)) { *this = std::move(new_array); }
|
||||
}
|
||||
|
||||
/// Optional `order` parameter omitted, to be added as needed.
|
||||
array reshape(ShapeContainer new_shape) {
|
||||
detail::npy_api::PyArray_Dims d
|
||||
= {reinterpret_cast<Py_intptr_t *>(new_shape->data()), int(new_shape->size())};
|
||||
auto new_array
|
||||
= reinterpret_steal<array>(detail::npy_api::get().PyArray_Newshape_(m_ptr, &d, 0));
|
||||
if (!new_array) {
|
||||
throw error_already_set();
|
||||
}
|
||||
return new_array;
|
||||
}
|
||||
|
||||
/// Ensure that the argument is a NumPy array
|
||||
/// In case of an error, nullptr is returned and the Python error is cleared.
|
||||
static array ensure(handle h, int ExtraFlags = 0) {
|
||||
|
@ -405,6 +405,13 @@ TEST_SUBMODULE(numpy_array, sm) {
|
||||
return a;
|
||||
});
|
||||
|
||||
sm.def("reshape_initializer_list", [](py::array_t<int> a, size_t N, size_t M, size_t O) {
|
||||
return a.reshape({N, M, O});
|
||||
});
|
||||
sm.def("reshape_tuple", [](py::array_t<int> a, const std::vector<int> &new_shape) {
|
||||
return a.reshape(new_shape);
|
||||
});
|
||||
|
||||
sm.def("index_using_ellipsis",
|
||||
[](const py::array &a) { return a[py::make_tuple(0, py::ellipsis(), 0)]; });
|
||||
|
||||
|
@ -411,7 +411,7 @@ def test_array_unchecked_fixed_dims(msg):
|
||||
assert m.proxy_auxiliaries2_const_ref(z1)
|
||||
|
||||
|
||||
def test_array_unchecked_dyn_dims(msg):
|
||||
def test_array_unchecked_dyn_dims():
|
||||
z1 = np.array([[1, 2], [3, 4]], dtype="float64")
|
||||
m.proxy_add2_dyn(z1, 10)
|
||||
assert np.all(z1 == [[11, 12], [13, 14]])
|
||||
@ -444,7 +444,7 @@ def test_initializer_list():
|
||||
assert m.array_initializer_list4().shape == (1, 2, 3, 4)
|
||||
|
||||
|
||||
def test_array_resize(msg):
|
||||
def test_array_resize():
|
||||
a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype="float64")
|
||||
m.array_reshape2(a)
|
||||
assert a.size == 9
|
||||
@ -470,12 +470,37 @@ def test_array_resize(msg):
|
||||
|
||||
|
||||
@pytest.mark.xfail("env.PYPY")
|
||||
def test_array_create_and_resize(msg):
|
||||
def test_array_create_and_resize():
|
||||
a = m.create_and_resize(2)
|
||||
assert a.size == 4
|
||||
assert np.all(a == 42.0)
|
||||
|
||||
|
||||
def test_reshape_initializer_list():
|
||||
a = np.arange(2 * 7 * 3) + 1
|
||||
x = m.reshape_initializer_list(a, 2, 7, 3)
|
||||
assert x.shape == (2, 7, 3)
|
||||
assert list(x[1][4]) == [34, 35, 36]
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
m.reshape_initializer_list(a, 1, 7, 3)
|
||||
assert str(excinfo.value) == "cannot reshape array of size 42 into shape (1,7,3)"
|
||||
|
||||
|
||||
def test_reshape_tuple():
|
||||
a = np.arange(3 * 7 * 2) + 1
|
||||
x = m.reshape_tuple(a, (3, 7, 2))
|
||||
assert x.shape == (3, 7, 2)
|
||||
assert list(x[1][4]) == [23, 24]
|
||||
y = m.reshape_tuple(x, (x.size,))
|
||||
assert y.shape == (42,)
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
m.reshape_tuple(a, (3, 7, 1))
|
||||
assert str(excinfo.value) == "cannot reshape array of size 42 into shape (3,7,1)"
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
m.reshape_tuple(a, ())
|
||||
assert str(excinfo.value) == "cannot reshape array of size 42 into shape ()"
|
||||
|
||||
|
||||
def test_index_using_ellipsis():
|
||||
a = m.index_using_ellipsis(np.zeros((5, 6, 7)))
|
||||
assert a.shape == (6,)
|
||||
|
Loading…
Reference in New Issue
Block a user