mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-22 13:15:12 +00:00
reshape for numpy arrays (#984)
* reshape * more tests * Update numpy.h * Update test_numpy_array.py * Update numpy.h * Update numpy.h * Update test_numpy_array.cpp * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix merge bug * Make clang-tidy happy * Add xfail for PyPy * Fix casting issue * Address reviews on additional tests * Fix ordering * Do a little more reordering * Fix typo * Try improving tests * Fix error in reshape * Add one more reshape test * streamlining new tests; removing a few stray msg Co-authored-by: ncullen93 <ncullen.th@dartmouth.edu> Co-authored-by: NC Cullen <nicholas.c.cullen.th@dartmouth.edu> Co-authored-by: Aaron Gokaslan <skylion.aaron@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ralf Grosse-Kunstleve <rwgk@google.com>
This commit is contained in:
parent
031a700dfd
commit
59ad1e7d05
@ -198,6 +198,8 @@ struct npy_api {
|
|||||||
// Unused. Not removed because that affects ABI of the class.
|
// Unused. Not removed because that affects ABI of the class.
|
||||||
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
|
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
|
||||||
PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
|
PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
|
||||||
|
PyObject* (*PyArray_Newshape_)(PyObject*, PyArray_Dims*, int);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
enum functions {
|
enum functions {
|
||||||
API_PyArray_GetNDArrayCFeatureVersion = 211,
|
API_PyArray_GetNDArrayCFeatureVersion = 211,
|
||||||
@ -212,10 +214,11 @@ private:
|
|||||||
API_PyArray_NewCopy = 85,
|
API_PyArray_NewCopy = 85,
|
||||||
API_PyArray_NewFromDescr = 94,
|
API_PyArray_NewFromDescr = 94,
|
||||||
API_PyArray_DescrNewFromType = 96,
|
API_PyArray_DescrNewFromType = 96,
|
||||||
|
API_PyArray_Newshape = 135,
|
||||||
|
API_PyArray_Squeeze = 136,
|
||||||
API_PyArray_DescrConverter = 174,
|
API_PyArray_DescrConverter = 174,
|
||||||
API_PyArray_EquivTypes = 182,
|
API_PyArray_EquivTypes = 182,
|
||||||
API_PyArray_GetArrayParamsFromObject = 278,
|
API_PyArray_GetArrayParamsFromObject = 278,
|
||||||
API_PyArray_Squeeze = 136,
|
|
||||||
API_PyArray_SetBaseObject = 282
|
API_PyArray_SetBaseObject = 282
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -243,11 +246,13 @@ private:
|
|||||||
DECL_NPY_API(PyArray_NewCopy);
|
DECL_NPY_API(PyArray_NewCopy);
|
||||||
DECL_NPY_API(PyArray_NewFromDescr);
|
DECL_NPY_API(PyArray_NewFromDescr);
|
||||||
DECL_NPY_API(PyArray_DescrNewFromType);
|
DECL_NPY_API(PyArray_DescrNewFromType);
|
||||||
|
DECL_NPY_API(PyArray_Newshape);
|
||||||
|
DECL_NPY_API(PyArray_Squeeze);
|
||||||
DECL_NPY_API(PyArray_DescrConverter);
|
DECL_NPY_API(PyArray_DescrConverter);
|
||||||
DECL_NPY_API(PyArray_EquivTypes);
|
DECL_NPY_API(PyArray_EquivTypes);
|
||||||
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
|
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
|
||||||
DECL_NPY_API(PyArray_Squeeze);
|
|
||||||
DECL_NPY_API(PyArray_SetBaseObject);
|
DECL_NPY_API(PyArray_SetBaseObject);
|
||||||
|
|
||||||
#undef DECL_NPY_API
|
#undef DECL_NPY_API
|
||||||
return api;
|
return api;
|
||||||
}
|
}
|
||||||
@ -785,6 +790,18 @@ public:
|
|||||||
if (isinstance<array>(new_array)) { *this = std::move(new_array); }
|
if (isinstance<array>(new_array)) { *this = std::move(new_array); }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Optional `order` parameter omitted, to be added as needed.
|
||||||
|
array reshape(ShapeContainer new_shape) {
|
||||||
|
detail::npy_api::PyArray_Dims d
|
||||||
|
= {reinterpret_cast<Py_intptr_t *>(new_shape->data()), int(new_shape->size())};
|
||||||
|
auto new_array
|
||||||
|
= reinterpret_steal<array>(detail::npy_api::get().PyArray_Newshape_(m_ptr, &d, 0));
|
||||||
|
if (!new_array) {
|
||||||
|
throw error_already_set();
|
||||||
|
}
|
||||||
|
return new_array;
|
||||||
|
}
|
||||||
|
|
||||||
/// Ensure that the argument is a NumPy array
|
/// Ensure that the argument is a NumPy array
|
||||||
/// In case of an error, nullptr is returned and the Python error is cleared.
|
/// In case of an error, nullptr is returned and the Python error is cleared.
|
||||||
static array ensure(handle h, int ExtraFlags = 0) {
|
static array ensure(handle h, int ExtraFlags = 0) {
|
||||||
|
@ -405,6 +405,13 @@ TEST_SUBMODULE(numpy_array, sm) {
|
|||||||
return a;
|
return a;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
sm.def("reshape_initializer_list", [](py::array_t<int> a, size_t N, size_t M, size_t O) {
|
||||||
|
return a.reshape({N, M, O});
|
||||||
|
});
|
||||||
|
sm.def("reshape_tuple", [](py::array_t<int> a, const std::vector<int> &new_shape) {
|
||||||
|
return a.reshape(new_shape);
|
||||||
|
});
|
||||||
|
|
||||||
sm.def("index_using_ellipsis",
|
sm.def("index_using_ellipsis",
|
||||||
[](const py::array &a) { return a[py::make_tuple(0, py::ellipsis(), 0)]; });
|
[](const py::array &a) { return a[py::make_tuple(0, py::ellipsis(), 0)]; });
|
||||||
|
|
||||||
|
@ -411,7 +411,7 @@ def test_array_unchecked_fixed_dims(msg):
|
|||||||
assert m.proxy_auxiliaries2_const_ref(z1)
|
assert m.proxy_auxiliaries2_const_ref(z1)
|
||||||
|
|
||||||
|
|
||||||
def test_array_unchecked_dyn_dims(msg):
|
def test_array_unchecked_dyn_dims():
|
||||||
z1 = np.array([[1, 2], [3, 4]], dtype="float64")
|
z1 = np.array([[1, 2], [3, 4]], dtype="float64")
|
||||||
m.proxy_add2_dyn(z1, 10)
|
m.proxy_add2_dyn(z1, 10)
|
||||||
assert np.all(z1 == [[11, 12], [13, 14]])
|
assert np.all(z1 == [[11, 12], [13, 14]])
|
||||||
@ -444,7 +444,7 @@ def test_initializer_list():
|
|||||||
assert m.array_initializer_list4().shape == (1, 2, 3, 4)
|
assert m.array_initializer_list4().shape == (1, 2, 3, 4)
|
||||||
|
|
||||||
|
|
||||||
def test_array_resize(msg):
|
def test_array_resize():
|
||||||
a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype="float64")
|
a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype="float64")
|
||||||
m.array_reshape2(a)
|
m.array_reshape2(a)
|
||||||
assert a.size == 9
|
assert a.size == 9
|
||||||
@ -470,12 +470,37 @@ def test_array_resize(msg):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail("env.PYPY")
|
@pytest.mark.xfail("env.PYPY")
|
||||||
def test_array_create_and_resize(msg):
|
def test_array_create_and_resize():
|
||||||
a = m.create_and_resize(2)
|
a = m.create_and_resize(2)
|
||||||
assert a.size == 4
|
assert a.size == 4
|
||||||
assert np.all(a == 42.0)
|
assert np.all(a == 42.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_reshape_initializer_list():
|
||||||
|
a = np.arange(2 * 7 * 3) + 1
|
||||||
|
x = m.reshape_initializer_list(a, 2, 7, 3)
|
||||||
|
assert x.shape == (2, 7, 3)
|
||||||
|
assert list(x[1][4]) == [34, 35, 36]
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
m.reshape_initializer_list(a, 1, 7, 3)
|
||||||
|
assert str(excinfo.value) == "cannot reshape array of size 42 into shape (1,7,3)"
|
||||||
|
|
||||||
|
|
||||||
|
def test_reshape_tuple():
|
||||||
|
a = np.arange(3 * 7 * 2) + 1
|
||||||
|
x = m.reshape_tuple(a, (3, 7, 2))
|
||||||
|
assert x.shape == (3, 7, 2)
|
||||||
|
assert list(x[1][4]) == [23, 24]
|
||||||
|
y = m.reshape_tuple(x, (x.size,))
|
||||||
|
assert y.shape == (42,)
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
m.reshape_tuple(a, (3, 7, 1))
|
||||||
|
assert str(excinfo.value) == "cannot reshape array of size 42 into shape (3,7,1)"
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
m.reshape_tuple(a, ())
|
||||||
|
assert str(excinfo.value) == "cannot reshape array of size 42 into shape ()"
|
||||||
|
|
||||||
|
|
||||||
def test_index_using_ellipsis():
|
def test_index_using_ellipsis():
|
||||||
a = m.index_using_ellipsis(np.zeros((5, 6, 7)))
|
a = m.index_using_ellipsis(np.zeros((5, 6, 7)))
|
||||||
assert a.shape == (6,)
|
assert a.shape == (6,)
|
||||||
|
Loading…
Reference in New Issue
Block a user