mirror of
https://github.com/pybind/pybind11.git
synced 2025-01-18 17:05:53 +00:00
view for numpy arrays (#987)
* reshape * more tests * Update numpy.h * Update test_numpy_array.py * array view * test * Update test_numpy_array.cpp * Update numpy.h * Update numpy.h * Update test_numpy_array.cpp * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix merge bug * Make clang-tidy happy * Add xfail for PyPy * Fix casting issue * Fix formatting * Apply clang-tidy * Address reviews on additional tests * Fix ordering * Do a little more reordering * Fix typo * Try improving tests * Fix error in reshape * Add one more reshape test * Fix bugs and add test * Relax test * streamlining new tests; removing a few stray msg * Fix style revert * Fix clang-tidy * Misc tweaks: * Comment: matching style in file (///), responsibility sentence, consistent punctuation. * Replacing `unsigned char` with `uint8_t` for max consistency. * Removing `1` from `array_view1` because there is only one. * Partial clang-format-diff. Co-authored-by: ncullen93 <ncullen.th@dartmouth.edu> Co-authored-by: NC Cullen <nicholas.c.cullen.th@dartmouth.edu> Co-authored-by: Aaron Gokaslan <skylion.aaron@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ralf Grosse-Kunstleve <rwgk@google.com>
This commit is contained in:
parent
db44afa33b
commit
503ff2a6fb
@ -199,6 +199,7 @@ struct npy_api {
|
||||
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
|
||||
PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
|
||||
PyObject* (*PyArray_Newshape_)(PyObject*, PyArray_Dims*, int);
|
||||
PyObject* (*PyArray_View_)(PyObject*, PyObject*, PyObject*);
|
||||
|
||||
private:
|
||||
enum functions {
|
||||
@ -216,6 +217,7 @@ private:
|
||||
API_PyArray_DescrNewFromType = 96,
|
||||
API_PyArray_Newshape = 135,
|
||||
API_PyArray_Squeeze = 136,
|
||||
API_PyArray_View = 137,
|
||||
API_PyArray_DescrConverter = 174,
|
||||
API_PyArray_EquivTypes = 182,
|
||||
API_PyArray_GetArrayParamsFromObject = 278,
|
||||
@ -248,6 +250,7 @@ private:
|
||||
DECL_NPY_API(PyArray_DescrNewFromType);
|
||||
DECL_NPY_API(PyArray_Newshape);
|
||||
DECL_NPY_API(PyArray_Squeeze);
|
||||
DECL_NPY_API(PyArray_View);
|
||||
DECL_NPY_API(PyArray_DescrConverter);
|
||||
DECL_NPY_API(PyArray_EquivTypes);
|
||||
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
|
||||
@ -802,6 +805,21 @@ public:
|
||||
return new_array;
|
||||
}
|
||||
|
||||
/// Create a view of an array in a different data type.
|
||||
/// This function may fundamentally reinterpret the data in the array.
|
||||
/// It is the responsibility of the caller to ensure that this is safe.
|
||||
/// Only supports the `dtype` argument, the `type` argument is omitted,
|
||||
/// to be added as needed.
|
||||
array view(const std::string &dtype) {
|
||||
auto &api = detail::npy_api::get();
|
||||
auto new_view = reinterpret_steal<array>(api.PyArray_View_(
|
||||
m_ptr, dtype::from_args(pybind11::str(dtype)).release().ptr(), nullptr));
|
||||
if (!new_view) {
|
||||
throw error_already_set();
|
||||
}
|
||||
return new_view;
|
||||
}
|
||||
|
||||
/// Ensure that the argument is a NumPy array
|
||||
/// In case of an error, nullptr is returned and the Python error is cleared.
|
||||
static array ensure(handle h, int ExtraFlags = 0) {
|
||||
|
@ -405,6 +405,9 @@ TEST_SUBMODULE(numpy_array, sm) {
|
||||
return a;
|
||||
});
|
||||
|
||||
sm.def("array_view",
|
||||
[](py::array_t<uint8_t> a, const std::string &dtype) { return a.view(dtype); });
|
||||
|
||||
sm.def("reshape_initializer_list", [](py::array_t<int> a, size_t N, size_t M, size_t O) {
|
||||
return a.reshape({N, M, O});
|
||||
});
|
||||
|
@ -476,6 +476,21 @@ def test_array_create_and_resize():
|
||||
assert np.all(a == 42.0)
|
||||
|
||||
|
||||
def test_array_view():
|
||||
a = np.ones(100 * 4).astype("uint8")
|
||||
a_float_view = m.array_view(a, "float32")
|
||||
assert a_float_view.shape == (100 * 1,) # 1 / 4 bytes = 8 / 32
|
||||
|
||||
a_int16_view = m.array_view(a, "int16") # 1 / 2 bytes = 16 / 32
|
||||
assert a_int16_view.shape == (100 * 2,)
|
||||
|
||||
|
||||
def test_array_view_invalid():
|
||||
a = np.ones(100 * 4).astype("uint8")
|
||||
with pytest.raises(TypeError):
|
||||
m.array_view(a, "deadly_dtype")
|
||||
|
||||
|
||||
def test_reshape_initializer_list():
|
||||
a = np.arange(2 * 7 * 3) + 1
|
||||
x = m.reshape_initializer_list(a, 2, 7, 3)
|
||||
|
Loading…
Reference in New Issue
Block a user