numpy.h: added array::squeeze() method

This commit is contained in:
Wenzel Jakob 2016-10-07 11:19:25 +02:00
parent 68a9989298
commit ba7678016c

View File

@ -109,6 +109,7 @@ struct npy_api {
bool (*PyArray_EquivTypes_) (PyObject *, PyObject *); bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *, int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
Py_ssize_t *, PyObject **, PyObject *); Py_ssize_t *, PyObject **, PyObject *);
PyObject *(*PyArray_Squeeze_)(PyObject *);
private: private:
enum functions { enum functions {
API_PyArray_Type = 2, API_PyArray_Type = 2,
@ -121,6 +122,7 @@ private:
API_PyArray_DescrConverter = 174, API_PyArray_DescrConverter = 174,
API_PyArray_EquivTypes = 182, API_PyArray_EquivTypes = 182,
API_PyArray_GetArrayParamsFromObject = 278, API_PyArray_GetArrayParamsFromObject = 278,
API_PyArray_Squeeze = 136
}; };
static npy_api lookup() { static npy_api lookup() {
@ -143,6 +145,7 @@ private:
DECL_NPY_API(PyArray_DescrConverter); DECL_NPY_API(PyArray_DescrConverter);
DECL_NPY_API(PyArray_EquivTypes); DECL_NPY_API(PyArray_EquivTypes);
DECL_NPY_API(PyArray_GetArrayParamsFromObject); DECL_NPY_API(PyArray_GetArrayParamsFromObject);
DECL_NPY_API(PyArray_Squeeze);
#undef DECL_NPY_API #undef DECL_NPY_API
return api; return api;
} }
@ -380,6 +383,12 @@ public:
return offset_at(index...) / itemsize(); return offset_at(index...) / itemsize();
} }
/// Return a new view with all of the dimensions of length 1 removed
array squeeze() {
auto& api = detail::npy_api::get();
return array(api.PyArray_Squeeze_(m_ptr), false);
}
protected: protected:
template<typename, typename> friend struct detail::npy_format_descriptor; template<typename, typename> friend struct detail::npy_format_descriptor;
@ -601,7 +610,7 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
// strings and will just do it ourselves. // strings and will just do it ourselves.
std::vector<field_descriptor> ordered_fields(fields); std::vector<field_descriptor> ordered_fields(fields);
std::sort(ordered_fields.begin(), ordered_fields.end(), std::sort(ordered_fields.begin(), ordered_fields.end(),
[](const field_descriptor& a, const field_descriptor &b) { [](const field_descriptor &a, const field_descriptor &b) {
return a.offset < b.offset; return a.offset < b.offset;
}); });
size_t offset = 0; size_t offset = 0;