mirror of
https://github.com/pybind/pybind11.git
synced 2025-01-19 01:15:52 +00:00
array: implement array resize
This commit is contained in:
parent
4ffa76ec56
commit
083a0219b5
@ -129,6 +129,11 @@ struct npy_api {
|
|||||||
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
|
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
|
||||||
};
|
};
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
Py_intptr_t *ptr;
|
||||||
|
int len;
|
||||||
|
} PyArray_Dims;
|
||||||
|
|
||||||
static npy_api& get() {
|
static npy_api& get() {
|
||||||
static npy_api api = lookup();
|
static npy_api api = lookup();
|
||||||
return api;
|
return api;
|
||||||
@ -159,6 +164,7 @@ struct npy_api {
|
|||||||
Py_ssize_t *, PyObject **, PyObject *);
|
Py_ssize_t *, PyObject **, PyObject *);
|
||||||
PyObject *(*PyArray_Squeeze_)(PyObject *);
|
PyObject *(*PyArray_Squeeze_)(PyObject *);
|
||||||
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
|
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
|
||||||
|
PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
|
||||||
private:
|
private:
|
||||||
enum functions {
|
enum functions {
|
||||||
API_PyArray_GetNDArrayCFeatureVersion = 211,
|
API_PyArray_GetNDArrayCFeatureVersion = 211,
|
||||||
@ -168,6 +174,7 @@ private:
|
|||||||
API_PyArray_DescrFromType = 45,
|
API_PyArray_DescrFromType = 45,
|
||||||
API_PyArray_DescrFromScalar = 57,
|
API_PyArray_DescrFromScalar = 57,
|
||||||
API_PyArray_FromAny = 69,
|
API_PyArray_FromAny = 69,
|
||||||
|
API_PyArray_Resize = 80,
|
||||||
API_PyArray_NewCopy = 85,
|
API_PyArray_NewCopy = 85,
|
||||||
API_PyArray_NewFromDescr = 94,
|
API_PyArray_NewFromDescr = 94,
|
||||||
API_PyArray_DescrNewFromType = 9,
|
API_PyArray_DescrNewFromType = 9,
|
||||||
@ -197,6 +204,7 @@ private:
|
|||||||
DECL_NPY_API(PyArray_DescrFromType);
|
DECL_NPY_API(PyArray_DescrFromType);
|
||||||
DECL_NPY_API(PyArray_DescrFromScalar);
|
DECL_NPY_API(PyArray_DescrFromScalar);
|
||||||
DECL_NPY_API(PyArray_FromAny);
|
DECL_NPY_API(PyArray_FromAny);
|
||||||
|
DECL_NPY_API(PyArray_Resize);
|
||||||
DECL_NPY_API(PyArray_NewCopy);
|
DECL_NPY_API(PyArray_NewCopy);
|
||||||
DECL_NPY_API(PyArray_NewFromDescr);
|
DECL_NPY_API(PyArray_NewFromDescr);
|
||||||
DECL_NPY_API(PyArray_DescrNewFromType);
|
DECL_NPY_API(PyArray_DescrNewFromType);
|
||||||
@ -652,6 +660,21 @@ public:
|
|||||||
return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
|
return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Resize array to given shape
|
||||||
|
/// If refcheck is true and more that one reference exist to this array
|
||||||
|
/// then resize will succeed only if it makes a reshape, i.e. original size doesn't change
|
||||||
|
void resize(ShapeContainer new_shape, bool refcheck = true) {
|
||||||
|
detail::npy_api::PyArray_Dims d = {
|
||||||
|
new_shape->data(), int(new_shape->size())
|
||||||
|
};
|
||||||
|
// try to resize, set ordering param to -1 cause it's not used anyway
|
||||||
|
object new_array = reinterpret_steal<object>(
|
||||||
|
detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1)
|
||||||
|
);
|
||||||
|
if (!new_array) throw error_already_set();
|
||||||
|
if (isinstance<array>(new_array)) { *this = std::move(new_array); }
|
||||||
|
}
|
||||||
|
|
||||||
/// Ensure that the argument is a NumPy array
|
/// Ensure that the argument is a NumPy array
|
||||||
/// In case of an error, nullptr is returned and the Python error is cleared.
|
/// In case of an error, nullptr is returned and the Python error is cleared.
|
||||||
static array ensure(handle h, int ExtraFlags = 0) {
|
static array ensure(handle h, int ExtraFlags = 0) {
|
||||||
|
@ -273,4 +273,25 @@ test_initializer numpy_array([](py::module &m) {
|
|||||||
sm.def("array_initializer_list", []() { return py::array_t<float>({ 1, 2 }); });
|
sm.def("array_initializer_list", []() { return py::array_t<float>({ 1, 2 }); });
|
||||||
sm.def("array_initializer_list", []() { return py::array_t<float>({ 1, 2, 3 }); });
|
sm.def("array_initializer_list", []() { return py::array_t<float>({ 1, 2, 3 }); });
|
||||||
sm.def("array_initializer_list", []() { return py::array_t<float>({ 1, 2, 3, 4 }); });
|
sm.def("array_initializer_list", []() { return py::array_t<float>({ 1, 2, 3, 4 }); });
|
||||||
});
|
|
||||||
|
// reshape array to 2D without changing size
|
||||||
|
sm.def("array_reshape2", [](py::array_t<double> a) {
|
||||||
|
const size_t dim_sz = (size_t)std::sqrt(a.size());
|
||||||
|
if (dim_sz * dim_sz != a.size())
|
||||||
|
throw std::domain_error("array_reshape2: input array total size is not a squared integer");
|
||||||
|
a.resize({dim_sz, dim_sz});
|
||||||
|
});
|
||||||
|
|
||||||
|
// resize to 3D array with each dimension = N
|
||||||
|
sm.def("array_resize3", [](py::array_t<double> a, size_t N, bool refcheck) {
|
||||||
|
a.resize({N, N, N}, refcheck);
|
||||||
|
});
|
||||||
|
|
||||||
|
// return 2D array with Nrows = Ncols = N
|
||||||
|
sm.def("create_and_resize", [](size_t N) {
|
||||||
|
py::array_t<double> a;
|
||||||
|
a.resize({N, N});
|
||||||
|
std::fill(a.mutable_data(), a.mutable_data() + a.size(), 42.);
|
||||||
|
return a;
|
||||||
|
});
|
||||||
|
});
|
@ -389,3 +389,38 @@ def test_array_failure():
|
|||||||
with pytest.raises(ValueError) as excinfo:
|
with pytest.raises(ValueError) as excinfo:
|
||||||
array_t_fail_test()
|
array_t_fail_test()
|
||||||
assert str(excinfo.value) == 'cannot create a pybind11::array_t from a nullptr'
|
assert str(excinfo.value) == 'cannot create a pybind11::array_t from a nullptr'
|
||||||
|
|
||||||
|
|
||||||
|
def test_array_resize(msg):
|
||||||
|
from pybind11_tests.array import (array_reshape2, array_resize3)
|
||||||
|
|
||||||
|
a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='float64')
|
||||||
|
array_reshape2(a)
|
||||||
|
assert(a.size == 9)
|
||||||
|
assert(np.all(a == [[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
|
||||||
|
|
||||||
|
# total size change should succced with refcheck off
|
||||||
|
array_resize3(a, 4, False)
|
||||||
|
assert(a.size == 64)
|
||||||
|
# ... and fail with refcheck on
|
||||||
|
try:
|
||||||
|
array_resize3(a, 3, True)
|
||||||
|
except ValueError as e:
|
||||||
|
assert(str(e).startswith("cannot resize an array"))
|
||||||
|
# transposed array doesn't own data
|
||||||
|
b = a.transpose()
|
||||||
|
try:
|
||||||
|
array_resize3(b, 3, False)
|
||||||
|
except ValueError as e:
|
||||||
|
assert(str(e).startswith("cannot resize this array: it does not own its data"))
|
||||||
|
# ... but reshape should be fine
|
||||||
|
array_reshape2(b)
|
||||||
|
assert(b.shape == (8, 8))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.unsupported_on_pypy
|
||||||
|
def test_array_create_and_resize(msg):
|
||||||
|
from pybind11_tests.array import create_and_resize
|
||||||
|
a = create_and_resize(2)
|
||||||
|
assert(a.size == 4)
|
||||||
|
assert(np.all(a == 42.))
|
||||||
|
Loading…
Reference in New Issue
Block a user