numpy.h: added array::squeeze() method

This commit is contained in:
Wenzel Jakob 2016-10-07 11:19:25 +02:00
parent 68a9989298
commit ba7678016c

View File

@ -109,6 +109,7 @@ struct npy_api {
bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
Py_ssize_t *, PyObject **, PyObject *);
PyObject *(*PyArray_Squeeze_)(PyObject *);
private:
enum functions {
API_PyArray_Type = 2,
@ -121,6 +122,7 @@ private:
API_PyArray_DescrConverter = 174,
API_PyArray_EquivTypes = 182,
API_PyArray_GetArrayParamsFromObject = 278,
API_PyArray_Squeeze = 136
};
static npy_api lookup() {
@ -143,6 +145,7 @@ private:
DECL_NPY_API(PyArray_DescrConverter);
DECL_NPY_API(PyArray_EquivTypes);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
DECL_NPY_API(PyArray_Squeeze);
#undef DECL_NPY_API
return api;
}
@ -380,6 +383,12 @@ public:
return offset_at(index...) / itemsize();
}
/// Return a new view with all of the dimensions of length 1 removed
array squeeze() {
auto& api = detail::npy_api::get();
return array(api.PyArray_Squeeze_(m_ptr), false);
}
protected:
template<typename, typename> friend struct detail::npy_format_descriptor;
@ -601,7 +610,7 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
// strings and will just do it ourselves.
std::vector<field_descriptor> ordered_fields(fields);
std::sort(ordered_fields.begin(), ordered_fields.end(),
[](const field_descriptor& a, const field_descriptor &b) {
[](const field_descriptor &a, const field_descriptor &b) {
return a.offset < b.offset;
});
size_t offset = 0;