2015-07-26 14:33:49 +00:00
|
|
|
/*
|
2016-05-05 18:33:54 +00:00
|
|
|
pybind11/numpy.h: Basic NumPy support, vectorize() wrapper
|
2015-07-26 14:33:49 +00:00
|
|
|
|
2016-04-17 18:21:41 +00:00
|
|
|
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
2015-07-26 14:33:49 +00:00
|
|
|
|
|
|
|
All rights reserved. Use of this source code is governed by a
|
|
|
|
BSD-style license that can be found in the LICENSE file.
|
|
|
|
*/
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
2015-10-15 16:13:33 +00:00
|
|
|
#include "pybind11.h"
|
|
|
|
#include "complex.h"
|
2016-02-11 09:47:11 +00:00
|
|
|
#include <numeric>
|
|
|
|
#include <algorithm>
|
2016-07-19 23:19:24 +00:00
|
|
|
#include <array>
|
2016-06-19 13:50:06 +00:00
|
|
|
#include <cstdlib>
|
2016-06-19 14:48:55 +00:00
|
|
|
#include <cstring>
|
2016-07-05 23:28:12 +00:00
|
|
|
#include <sstream>
|
2016-08-14 12:45:49 +00:00
|
|
|
#include <string>
|
2016-06-19 14:48:55 +00:00
|
|
|
#include <initializer_list>
|
2016-08-29 01:41:05 +00:00
|
|
|
#include <functional>
|
2016-11-01 13:27:35 +00:00
|
|
|
#include <utility>
|
2016-10-31 13:52:32 +00:00
|
|
|
#include <typeindex>
|
2015-07-29 15:51:54 +00:00
|
|
|
|
2015-07-26 14:33:49 +00:00
|
|
|
#if defined(_MSC_VER)
|
2016-10-12 22:57:42 +00:00
|
|
|
# pragma warning(push)
|
|
|
|
# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
|
2015-07-26 14:33:49 +00:00
|
|
|
#endif
|
|
|
|
|
2016-09-08 20:48:14 +00:00
|
|
|
/* This will be true on all flat address space platforms and allows us to reduce the
|
|
|
|
whole npy_intp / size_t / Py_intptr_t business down to just size_t for all size
|
|
|
|
and dimension types (e.g. shape, strides, indexing), instead of inflicting this
|
|
|
|
upon the library user. */
|
|
|
|
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
|
|
|
|
|
2015-10-15 16:13:33 +00:00
|
|
|
NAMESPACE_BEGIN(pybind11)
|
2016-09-08 20:48:14 +00:00
|
|
|
NAMESPACE_BEGIN(detail)
|
Numpy: better compilation errors, long double support (#619)
* Clarify PYBIND11_NUMPY_DTYPE documentation
The current documentation and example read as though
PYBIND11_NUMPY_DTYPE is a declarative macro along the same lines as
PYBIND11_DECLARE_HOLDER_TYPE, but it isn't. This changes the
documentation and docs example to make it clear that you need to "call"
the macro.
* Add satisfies_{all,any,none}_of<T, Preds>
`satisfies_all_of<T, Pred1, Pred2, Pred3>` is a nice legibility-enhanced
shortcut for `is_all<Pred1<T>, Pred2<T>, Pred3<T>>`.
* Give better error message for non-POD dtype attempts
If you try to use a non-POD data type, you get difficult-to-interpret
compilation errors (about ::name() not being a member of an internal
pybind11 struct, among others), for which it isn't at all obvious what the
problem is.
This adds a static_assert for such cases.
It also changes the base case from an empty struct to the is_pod_struct
case by no longer using `enable_if<is_pod_struct>` but instead using a
static_assert: thus specializations avoid the base class, POD types
work, and non-POD types (and unimplemented POD types like std::array)
get a more informative static_assert failure.
* Prefix macros with PYBIND11_
numpy.h uses unprefixed macros, which seems undesirable. This prefixes
them with PYBIND11_ to match all the other macros in numpy.h (and
elsewhere).
* Add long double support
This adds long double and std::complex<long double> support for numpy
arrays.
This allows some simplification of the code used to generate format
descriptors; the new code uses fewer macros, instead putting the code as
different templated options; the template conditions end up simpler with
this because we are now supporting all basic C++ arithmetic types (and
so can use is_arithmetic instead of is_integral + multiple
different specializations).
In addition to testing that it is indeed working in the test script, it
also adds various offset and size calculations there, which
fixes the test failures under x86 compilations.
2017-01-31 16:00:15 +00:00
|
|
|
/// Primary template (declaration only) mapping a C++ type to its NumPy format
/// descriptor; the second parameter is an SFINAE hook for partial specializations.
template <typename T, typename SFINAE = void> struct npy_format_descriptor;
|
2015-07-28 14:12:20 +00:00
|
|
|
|
2016-08-29 01:41:05 +00:00
|
|
|
struct PyArrayDescr_Proxy {
|
|
|
|
PyObject_HEAD
|
|
|
|
PyObject *typeobj;
|
|
|
|
char kind;
|
|
|
|
char type;
|
|
|
|
char byteorder;
|
|
|
|
char flags;
|
|
|
|
int type_num;
|
|
|
|
int elsize;
|
|
|
|
int alignment;
|
|
|
|
char *subarray;
|
|
|
|
PyObject *fields;
|
|
|
|
PyObject *names;
|
|
|
|
};
|
|
|
|
|
|
|
|
struct PyArray_Proxy {
|
|
|
|
PyObject_HEAD
|
|
|
|
char *data;
|
|
|
|
int nd;
|
|
|
|
ssize_t *dimensions;
|
|
|
|
ssize_t *strides;
|
|
|
|
PyObject *base;
|
|
|
|
PyObject *descr;
|
|
|
|
int flags;
|
|
|
|
};
|
|
|
|
|
2016-10-20 15:11:08 +00:00
|
|
|
struct PyVoidScalarObject_Proxy {
|
|
|
|
PyObject_VAR_HEAD
|
|
|
|
char *obval;
|
|
|
|
PyArrayDescr_Proxy *descr;
|
|
|
|
int flags;
|
|
|
|
PyObject *base;
|
|
|
|
};
|
|
|
|
|
2016-10-31 13:52:32 +00:00
|
|
|
struct numpy_type_info {
|
|
|
|
PyObject* dtype_ptr;
|
|
|
|
std::string format_str;
|
|
|
|
};
|
|
|
|
|
|
|
|
struct numpy_internals {
|
|
|
|
std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
|
|
|
|
|
2016-10-31 16:16:47 +00:00
|
|
|
numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
|
|
|
|
auto it = registered_dtypes.find(std::type_index(tinfo));
|
2016-10-31 13:52:32 +00:00
|
|
|
if (it != registered_dtypes.end())
|
|
|
|
return &(it->second);
|
|
|
|
if (throw_if_missing)
|
2016-10-31 16:16:47 +00:00
|
|
|
pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
|
2016-10-31 13:52:32 +00:00
|
|
|
return nullptr;
|
|
|
|
}
|
2016-10-31 16:16:47 +00:00
|
|
|
|
|
|
|
template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
|
|
|
|
return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
|
|
|
|
}
|
2016-10-31 13:52:32 +00:00
|
|
|
};
|
|
|
|
|
2016-10-31 14:11:10 +00:00
|
|
|
/// Point `ptr` at the numpy_internals instance stored in pybind11's shared data under
/// the key "_numpy_internals", creating it on first use. Kept out of line
/// (PYBIND11_NOINLINE) so the cached fast path in get_numpy_internals stays small.
inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
    numpy_internals &shared = get_or_create_shared_data<numpy_internals>("_numpy_internals");
    ptr = &shared;
}
|
|
|
|
|
|
|
|
/// Return the shared numpy_internals registry, resolving it once and caching the
/// pointer in a function-local static for subsequent calls.
inline numpy_internals& get_numpy_internals() {
    static numpy_internals* cached = nullptr;
    if (cached == nullptr)
        load_numpy_internals(cached);
    return *cached;
}
|
|
|
|
|
2016-07-19 23:54:57 +00:00
|
|
|
struct npy_api {
|
|
|
|
enum constants {
|
2017-01-17 01:15:42 +00:00
|
|
|
NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
|
|
|
|
NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
|
2016-08-29 01:41:05 +00:00
|
|
|
NPY_ARRAY_OWNDATA_ = 0x0004,
|
2016-07-19 23:54:57 +00:00
|
|
|
NPY_ARRAY_FORCECAST_ = 0x0010,
|
2017-01-17 01:15:42 +00:00
|
|
|
NPY_ARRAY_ENSUREARRAY_ = 0x0040,
|
2016-08-29 01:41:05 +00:00
|
|
|
NPY_ARRAY_ALIGNED_ = 0x0100,
|
|
|
|
NPY_ARRAY_WRITEABLE_ = 0x0400,
|
2016-07-19 23:54:57 +00:00
|
|
|
NPY_BOOL_ = 0,
|
|
|
|
NPY_BYTE_, NPY_UBYTE_,
|
|
|
|
NPY_SHORT_, NPY_USHORT_,
|
|
|
|
NPY_INT_, NPY_UINT_,
|
|
|
|
NPY_LONG_, NPY_ULONG_,
|
|
|
|
NPY_LONGLONG_, NPY_ULONGLONG_,
|
|
|
|
NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
|
|
|
|
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
|
|
|
|
NPY_OBJECT_ = 17,
|
|
|
|
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
|
|
|
|
};
|
|
|
|
|
|
|
|
static npy_api& get() {
|
|
|
|
static npy_api api = lookup();
|
|
|
|
return api;
|
|
|
|
}
|
|
|
|
|
2016-07-23 20:55:37 +00:00
|
|
|
bool PyArray_Check_(PyObject *obj) const {
|
|
|
|
return (bool) PyObject_TypeCheck(obj, PyArray_Type_);
|
|
|
|
}
|
|
|
|
bool PyArrayDescr_Check_(PyObject *obj) const {
|
|
|
|
return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
|
|
|
|
}
|
2016-07-19 23:54:57 +00:00
|
|
|
|
|
|
|
PyObject *(*PyArray_DescrFromType_)(int);
|
|
|
|
PyObject *(*PyArray_NewFromDescr_)
|
|
|
|
(PyTypeObject *, PyObject *, int, Py_intptr_t *,
|
|
|
|
Py_intptr_t *, void *, int, PyObject *);
|
|
|
|
PyObject *(*PyArray_DescrNewFromType_)(int);
|
|
|
|
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
|
|
|
|
PyTypeObject *PyArray_Type_;
|
2016-10-20 15:09:10 +00:00
|
|
|
PyTypeObject *PyVoidArrType_Type_;
|
2016-07-23 20:55:37 +00:00
|
|
|
PyTypeObject *PyArrayDescr_Type_;
|
2016-10-20 15:09:10 +00:00
|
|
|
PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
|
2016-07-19 23:54:57 +00:00
|
|
|
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
|
|
|
|
int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
|
|
|
|
bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
|
|
|
|
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
|
|
|
|
Py_ssize_t *, PyObject **, PyObject *);
|
2016-10-07 09:19:25 +00:00
|
|
|
PyObject *(*PyArray_Squeeze_)(PyObject *);
|
2017-01-17 01:22:00 +00:00
|
|
|
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
|
2016-07-19 23:54:57 +00:00
|
|
|
private:
|
|
|
|
enum functions {
|
|
|
|
API_PyArray_Type = 2,
|
2016-07-23 20:55:37 +00:00
|
|
|
API_PyArrayDescr_Type = 3,
|
2016-10-20 15:09:10 +00:00
|
|
|
API_PyVoidArrType_Type = 39,
|
2016-07-19 23:54:57 +00:00
|
|
|
API_PyArray_DescrFromType = 45,
|
2016-10-20 15:09:10 +00:00
|
|
|
API_PyArray_DescrFromScalar = 57,
|
2016-07-19 23:54:57 +00:00
|
|
|
API_PyArray_FromAny = 69,
|
|
|
|
API_PyArray_NewCopy = 85,
|
|
|
|
API_PyArray_NewFromDescr = 94,
|
|
|
|
API_PyArray_DescrNewFromType = 9,
|
|
|
|
API_PyArray_DescrConverter = 174,
|
|
|
|
API_PyArray_EquivTypes = 182,
|
|
|
|
API_PyArray_GetArrayParamsFromObject = 278,
|
2017-01-17 01:22:00 +00:00
|
|
|
API_PyArray_Squeeze = 136,
|
|
|
|
API_PyArray_SetBaseObject = 282
|
2016-07-19 23:54:57 +00:00
|
|
|
};
|
|
|
|
|
|
|
|
static npy_api lookup() {
|
|
|
|
module m = module::import("numpy.core.multiarray");
|
2016-09-08 15:02:04 +00:00
|
|
|
auto c = m.attr("_ARRAY_API");
|
2015-09-04 21:42:12 +00:00
|
|
|
#if PY_MAJOR_VERSION >= 3
|
2016-09-20 23:06:32 +00:00
|
|
|
void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL);
|
2015-09-04 21:42:12 +00:00
|
|
|
#else
|
2016-09-20 23:06:32 +00:00
|
|
|
void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr());
|
2015-09-04 21:42:12 +00:00
|
|
|
#endif
|
2016-07-19 23:54:57 +00:00
|
|
|
npy_api api;
|
2016-06-19 13:44:20 +00:00
|
|
|
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
|
2016-07-19 23:54:57 +00:00
|
|
|
DECL_NPY_API(PyArray_Type);
|
2016-10-20 15:09:10 +00:00
|
|
|
DECL_NPY_API(PyVoidArrType_Type);
|
2016-07-23 20:55:37 +00:00
|
|
|
DECL_NPY_API(PyArrayDescr_Type);
|
2016-07-19 23:54:57 +00:00
|
|
|
DECL_NPY_API(PyArray_DescrFromType);
|
2016-10-20 15:09:10 +00:00
|
|
|
DECL_NPY_API(PyArray_DescrFromScalar);
|
2016-07-19 23:54:57 +00:00
|
|
|
DECL_NPY_API(PyArray_FromAny);
|
|
|
|
DECL_NPY_API(PyArray_NewCopy);
|
|
|
|
DECL_NPY_API(PyArray_NewFromDescr);
|
|
|
|
DECL_NPY_API(PyArray_DescrNewFromType);
|
|
|
|
DECL_NPY_API(PyArray_DescrConverter);
|
|
|
|
DECL_NPY_API(PyArray_EquivTypes);
|
|
|
|
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
|
2016-10-07 09:19:25 +00:00
|
|
|
DECL_NPY_API(PyArray_Squeeze);
|
2017-01-17 01:22:00 +00:00
|
|
|
DECL_NPY_API(PyArray_SetBaseObject);
|
2016-06-19 13:44:20 +00:00
|
|
|
#undef DECL_NPY_API
|
2016-07-19 23:54:57 +00:00
|
|
|
return api;
|
|
|
|
}
|
|
|
|
};
|
2015-07-28 14:12:20 +00:00
|
|
|
|
2016-11-22 10:29:55 +00:00
|
|
|
/// View an ndarray object through the PyArray_Proxy layout (mutable access).
inline PyArray_Proxy* array_proxy(void* ptr) {
    return static_cast<PyArray_Proxy*>(ptr);
}

/// View an ndarray object through the PyArray_Proxy layout (read-only access).
inline const PyArray_Proxy* array_proxy(const void* ptr) {
    return static_cast<const PyArray_Proxy*>(ptr);
}
|
|
|
|
|
|
|
|
inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) {
|
|
|
|
return reinterpret_cast<PyArrayDescr_Proxy*>(ptr);
|
|
|
|
}
|
|
|
|
|
|
|
|
inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) {
|
|
|
|
return reinterpret_cast<const PyArrayDescr_Proxy*>(ptr);
|
|
|
|
}
|
|
|
|
|
|
|
|
inline bool check_flags(const void* ptr, int flag) {
|
|
|
|
return (flag == (array_proxy(ptr)->flags & flag));
|
|
|
|
}
|
|
|
|
|
Numpy: better compilation errors, long double support (#619)
* Clarify PYBIND11_NUMPY_DTYPE documentation
The current documentation and example read as though
PYBIND11_NUMPY_DTYPE is a declarative macro along the same lines as
PYBIND11_DECLARE_HOLDER_TYPE, but it isn't. This changes the
documentation and docs example to make it clear that you need to "call"
the macro.
* Add satisfies_{all,any,none}_of<T, Preds>
`satisfies_all_of<T, Pred1, Pred2, Pred3>` is a nice legibility-enhanced
shortcut for `is_all<Pred1<T>, Pred2<T>, Pred3<T>>`.
* Give better error message for non-POD dtype attempts
If you try to use a non-POD data type, you get difficult-to-interpret
compilation errors (about ::name() not being a member of an internal
pybind11 struct, among others), for which it isn't at all obvious what the
problem is.
This adds a static_assert for such cases.
It also changes the base case from an empty struct to the is_pod_struct
case by no longer using `enable_if<is_pod_struct>` but instead using a
static_assert: thus specializations avoid the base class, POD types
work, and non-POD types (and unimplemented POD types like std::array)
get a more informative static_assert failure.
* Prefix macros with PYBIND11_
numpy.h uses unprefixed macros, which seems undesirable. This prefixes
them with PYBIND11_ to match all the other macros in numpy.h (and
elsewhere).
* Add long double support
This adds long double and std::complex<long double> support for numpy
arrays.
This allows some simplification of the code used to generate format
descriptors; the new code uses fewer macros, instead putting the code as
different templated options; the template conditions end up simpler with
this because we are now supporting all basic C++ arithmetic types (and
so can use is_arithmetic instead of is_integral + multiple
different specializations).
In addition to testing that it is indeed working in the test script, it
also adds various offset and size calculations there, which
fixes the test failures under x86 compilations.
2017-01-31 16:00:15 +00:00
|
|
|
/// Trait: true exactly for std::array<E, N> (no cv/ref stripping).
template <typename Ty> struct is_std_array : std::integral_constant<bool, false> { };
template <typename Elem, size_t Count>
struct is_std_array<std::array<Elem, Count>> : std::integral_constant<bool, true> { };

/// Trait: true exactly for std::complex<V> (no cv/ref stripping).
template <typename Ty> struct is_complex : std::integral_constant<bool, false> { };
template <typename Value>
struct is_complex<std::complex<Value>> : std::integral_constant<bool, true> { };
|
|
|
|
|
|
|
|
template <typename T> using is_pod_struct = all_of<
|
|
|
|
std::is_pod<T>, // since we're accessing directly in memory we need a POD type
|
|
|
|
satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
|
|
|
|
>;
|
|
|
|
|
2016-11-22 10:29:55 +00:00
|
|
|
NAMESPACE_END(detail)
|
2016-08-29 01:41:05 +00:00
|
|
|
|
2016-07-23 20:55:37 +00:00
|
|
|
class dtype : public object {
|
2016-07-19 23:54:57 +00:00
|
|
|
public:
|
2016-07-23 20:55:37 +00:00
|
|
|
PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
|
2015-07-26 14:33:49 +00:00
|
|
|
|
2016-10-16 20:27:42 +00:00
|
|
|
explicit dtype(const buffer_info &info) {
|
2016-08-13 11:42:02 +00:00
|
|
|
dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format)));
|
2016-11-22 14:56:52 +00:00
|
|
|
// If info.itemsize == 0, use the value calculated from the format string
|
|
|
|
m_ptr = descr.strip_padding(info.itemsize ? info.itemsize : descr.itemsize()).release().ptr();
|
2016-07-23 20:55:37 +00:00
|
|
|
}
|
2016-05-05 08:00:00 +00:00
|
|
|
|
2016-10-22 09:50:39 +00:00
|
|
|
explicit dtype(const std::string &format) {
|
2016-07-23 20:55:37 +00:00
|
|
|
m_ptr = from_args(pybind11::str(format)).release().ptr();
|
2015-07-26 14:33:49 +00:00
|
|
|
}
|
|
|
|
|
2016-10-22 09:51:19 +00:00
|
|
|
dtype(const char *format) : dtype(std::string(format)) { }
|
2016-07-24 23:58:17 +00:00
|
|
|
|
2016-07-24 16:51:35 +00:00
|
|
|
dtype(list names, list formats, list offsets, size_t itemsize) {
|
|
|
|
dict args;
|
|
|
|
args["names"] = names;
|
|
|
|
args["formats"] = formats;
|
|
|
|
args["offsets"] = offsets;
|
2016-08-25 00:16:47 +00:00
|
|
|
args["itemsize"] = pybind11::int_(itemsize);
|
2016-07-24 16:51:35 +00:00
|
|
|
m_ptr = from_args(args).release().ptr();
|
|
|
|
}
|
|
|
|
|
2016-10-22 09:51:04 +00:00
|
|
|
/// This is essentially the same as calling numpy.dtype(args) in Python.
|
2016-07-23 20:55:37 +00:00
|
|
|
static dtype from_args(object args) {
|
|
|
|
PyObject *ptr = nullptr;
|
|
|
|
if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr)
|
2016-10-22 09:52:05 +00:00
|
|
|
throw error_already_set();
|
2016-10-28 01:08:15 +00:00
|
|
|
return reinterpret_steal<dtype>(ptr);
|
2016-07-23 20:55:37 +00:00
|
|
|
}
|
2016-07-05 23:28:12 +00:00
|
|
|
|
2016-10-22 09:51:04 +00:00
|
|
|
/// Return dtype associated with a C++ type.
|
2016-07-23 20:55:37 +00:00
|
|
|
template <typename T> static dtype of() {
|
2016-08-15 00:24:28 +00:00
|
|
|
return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
|
2016-07-23 20:55:37 +00:00
|
|
|
}
|
2016-07-05 23:28:12 +00:00
|
|
|
|
2016-10-22 09:51:04 +00:00
|
|
|
/// Size of the data type in bytes.
|
2016-07-23 20:55:37 +00:00
|
|
|
size_t itemsize() const {
|
2016-11-22 10:29:55 +00:00
|
|
|
return (size_t) detail::array_descriptor_proxy(m_ptr)->elsize;
|
2015-07-26 14:33:49 +00:00
|
|
|
}
|
|
|
|
|
2016-10-22 09:51:04 +00:00
|
|
|
/// Returns true for structured data types.
|
2016-07-23 20:55:37 +00:00
|
|
|
bool has_fields() const {
|
2016-11-22 10:29:55 +00:00
|
|
|
return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
|
2016-07-23 20:55:37 +00:00
|
|
|
}
|
|
|
|
|
2016-10-22 09:51:04 +00:00
|
|
|
/// Single-character type code.
|
2016-08-29 01:41:05 +00:00
|
|
|
char kind() const {
|
2016-11-22 10:29:55 +00:00
|
|
|
return detail::array_descriptor_proxy(m_ptr)->kind;
|
2016-07-23 20:55:37 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
private:
|
2016-07-24 16:34:53 +00:00
|
|
|
static object _dtype_from_pep3118() {
|
|
|
|
static PyObject *obj = module::import("numpy.core._internal")
|
|
|
|
.attr("_dtype_from_pep3118").cast<object>().release().ptr();
|
2016-10-28 01:08:15 +00:00
|
|
|
return reinterpret_borrow<object>(obj);
|
2016-07-23 20:55:37 +00:00
|
|
|
}
|
2016-07-05 23:28:12 +00:00
|
|
|
|
2016-11-22 14:56:52 +00:00
|
|
|
dtype strip_padding(size_t itemsize) {
|
2016-07-05 23:28:12 +00:00
|
|
|
// Recursively strip all void fields with empty names that are generated for
|
|
|
|
// padding fields (as of NumPy v1.11).
|
2016-08-29 01:41:05 +00:00
|
|
|
if (!has_fields())
|
2016-07-23 20:55:37 +00:00
|
|
|
return *this;
|
2016-07-05 23:28:12 +00:00
|
|
|
|
2016-08-25 00:16:47 +00:00
|
|
|
struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
|
2016-07-05 23:28:12 +00:00
|
|
|
std::vector<field_descr> field_descriptors;
|
|
|
|
|
2016-09-08 15:02:04 +00:00
|
|
|
for (auto field : attr("fields").attr("items")()) {
|
2016-10-28 01:08:15 +00:00
|
|
|
auto spec = field.cast<tuple>();
|
2016-07-05 23:28:12 +00:00
|
|
|
auto name = spec[0].cast<pybind11::str>();
|
2016-07-23 20:55:37 +00:00
|
|
|
auto format = spec[1].cast<tuple>()[0].cast<dtype>();
|
2016-08-25 00:16:47 +00:00
|
|
|
auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
|
2016-08-29 01:41:05 +00:00
|
|
|
if (!len(name) && format.kind() == 'V')
|
2016-07-18 21:37:42 +00:00
|
|
|
continue;
|
2016-11-22 14:56:52 +00:00
|
|
|
field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset});
|
2016-07-05 23:28:12 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
std::sort(field_descriptors.begin(), field_descriptors.end(),
|
|
|
|
[](const field_descr& a, const field_descr& b) {
|
2016-08-25 20:52:52 +00:00
|
|
|
return a.offset.cast<int>() < b.offset.cast<int>();
|
2016-07-05 23:28:12 +00:00
|
|
|
});
|
|
|
|
|
|
|
|
list names, formats, offsets;
|
|
|
|
for (auto& descr : field_descriptors) {
|
2016-07-24 16:51:35 +00:00
|
|
|
names.append(descr.name);
|
|
|
|
formats.append(descr.format);
|
|
|
|
offsets.append(descr.offset);
|
2016-07-05 23:28:12 +00:00
|
|
|
}
|
2016-11-22 14:56:52 +00:00
|
|
|
return dtype(names, formats, offsets, itemsize);
|
2016-07-23 20:55:37 +00:00
|
|
|
}
|
|
|
|
};
|
2016-07-05 23:28:12 +00:00
|
|
|
|
2016-07-23 20:55:37 +00:00
|
|
|
class array : public buffer {
|
|
|
|
public:
|
2016-11-16 00:35:22 +00:00
|
|
|
PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
|
2016-07-23 20:55:37 +00:00
|
|
|
|
|
|
|
enum {
|
2017-01-17 01:15:42 +00:00
|
|
|
c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_,
|
|
|
|
f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_,
|
2016-07-23 20:55:37 +00:00
|
|
|
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
|
|
|
|
};
|
|
|
|
|
2016-11-16 00:35:22 +00:00
|
|
|
array() : array(0, static_cast<const double *>(nullptr)) {}
|
|
|
|
|
2017-02-14 10:25:47 +00:00
|
|
|
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
|
|
|
|
const std::vector<size_t> &strides, const void *ptr = nullptr,
|
2016-10-12 22:57:42 +00:00
|
|
|
handle base = handle()) {
|
2016-07-23 20:55:37 +00:00
|
|
|
auto& api = detail::npy_api::get();
|
2016-07-24 17:35:14 +00:00
|
|
|
auto ndim = shape.size();
|
|
|
|
if (shape.size() != strides.size())
|
|
|
|
pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
|
|
|
|
auto descr = dt;
|
2016-10-12 22:57:42 +00:00
|
|
|
|
|
|
|
int flags = 0;
|
|
|
|
if (base && ptr) {
|
2016-10-23 12:50:08 +00:00
|
|
|
if (isinstance<array>(base))
|
2017-02-06 23:06:04 +00:00
|
|
|
/* Copy flags from base (except ownership bit) */
|
2016-10-28 01:08:15 +00:00
|
|
|
flags = reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
|
2016-10-12 22:57:42 +00:00
|
|
|
else
|
|
|
|
/* Writable by default, easy to downgrade later on if needed */
|
|
|
|
flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
|
|
|
|
}
|
|
|
|
|
2016-10-28 01:08:15 +00:00
|
|
|
auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
|
2017-02-08 22:43:08 +00:00
|
|
|
api.PyArray_Type_, descr.release().ptr(), (int) ndim,
|
|
|
|
reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(shape.data())),
|
|
|
|
reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(strides.data())),
|
|
|
|
const_cast<void *>(ptr), flags, nullptr));
|
2016-07-23 20:55:37 +00:00
|
|
|
if (!tmp)
|
|
|
|
pybind11_fail("NumPy: unable to create array!");
|
2016-10-12 22:57:42 +00:00
|
|
|
if (ptr) {
|
|
|
|
if (base) {
|
2017-01-17 01:22:00 +00:00
|
|
|
api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr());
|
2016-10-12 22:57:42 +00:00
|
|
|
} else {
|
2016-10-28 01:08:15 +00:00
|
|
|
tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
|
2016-10-12 22:57:42 +00:00
|
|
|
}
|
|
|
|
}
|
2016-07-23 20:55:37 +00:00
|
|
|
m_ptr = tmp.release().ptr();
|
|
|
|
}
|
|
|
|
|
2016-10-12 22:57:42 +00:00
|
|
|
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
|
|
|
|
const void *ptr = nullptr, handle base = handle())
|
|
|
|
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
|
2016-07-24 17:35:14 +00:00
|
|
|
|
2016-10-12 22:57:42 +00:00
|
|
|
array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
|
|
|
|
handle base = handle())
|
|
|
|
: array(dt, std::vector<size_t>{ count }, ptr, base) { }
|
2016-07-24 17:35:14 +00:00
|
|
|
|
|
|
|
template<typename T> array(const std::vector<size_t>& shape,
|
2016-10-12 22:57:42 +00:00
|
|
|
const std::vector<size_t>& strides,
|
|
|
|
const T* ptr, handle base = handle())
|
2017-02-08 22:43:08 +00:00
|
|
|
: array(pybind11::dtype::of<T>(), shape, strides, (const void *) ptr, base) { }
|
2016-07-24 17:35:14 +00:00
|
|
|
|
2016-10-12 22:57:42 +00:00
|
|
|
template <typename T>
|
|
|
|
array(const std::vector<size_t> &shape, const T *ptr,
|
|
|
|
handle base = handle())
|
|
|
|
: array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
|
2016-07-24 17:35:14 +00:00
|
|
|
|
2016-10-12 22:57:42 +00:00
|
|
|
template <typename T>
|
|
|
|
array(size_t count, const T *ptr, handle base = handle())
|
|
|
|
: array(std::vector<size_t>{ count }, ptr, base) { }
|
2016-07-24 17:35:14 +00:00
|
|
|
|
2016-10-16 20:27:42 +00:00
|
|
|
explicit array(const buffer_info &info)
|
2016-07-24 23:46:39 +00:00
|
|
|
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
|
2016-07-23 20:55:37 +00:00
|
|
|
|
2016-08-29 01:41:05 +00:00
|
|
|
/// Array descriptor (dtype)
|
|
|
|
pybind11::dtype dtype() const {
|
2016-11-22 10:29:55 +00:00
|
|
|
return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
|
2016-08-29 01:41:05 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Total number of elements
|
|
|
|
size_t size() const {
|
|
|
|
return std::accumulate(shape(), shape() + ndim(), (size_t) 1, std::multiplies<size_t>());
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Byte size of a single element
|
|
|
|
size_t itemsize() const {
|
2016-11-22 10:29:55 +00:00
|
|
|
return (size_t) detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
|
2016-08-29 01:41:05 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Total number of bytes
|
|
|
|
size_t nbytes() const {
|
|
|
|
return size() * itemsize();
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Number of dimensions
|
|
|
|
size_t ndim() const {
|
2016-11-22 10:29:55 +00:00
|
|
|
return (size_t) detail::array_proxy(m_ptr)->nd;
|
2016-08-29 01:41:05 +00:00
|
|
|
}
|
|
|
|
|
2016-10-12 22:57:42 +00:00
|
|
|
/// Base object
|
|
|
|
object base() const {
|
2016-11-22 10:29:55 +00:00
|
|
|
return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base);
|
2016-10-12 22:57:42 +00:00
|
|
|
}
|
|
|
|
|
2016-08-29 01:41:05 +00:00
|
|
|
/// Dimensions of the array
|
|
|
|
const size_t* shape() const {
|
2016-11-22 10:29:55 +00:00
|
|
|
return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->dimensions);
|
2016-08-29 01:41:05 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Dimension along a given axis
|
|
|
|
size_t shape(size_t dim) const {
|
|
|
|
if (dim >= ndim())
|
2016-09-08 20:48:14 +00:00
|
|
|
fail_dim_check(dim, "invalid axis");
|
2016-08-29 01:41:05 +00:00
|
|
|
return shape()[dim];
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Strides of the array
|
|
|
|
const size_t* strides() const {
|
2016-11-22 10:29:55 +00:00
|
|
|
return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->strides);
|
2016-08-29 01:41:05 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Stride along a given axis
|
|
|
|
size_t strides(size_t dim) const {
|
|
|
|
if (dim >= ndim())
|
2016-09-08 20:48:14 +00:00
|
|
|
fail_dim_check(dim, "invalid axis");
|
2016-08-29 01:41:05 +00:00
|
|
|
return strides()[dim];
|
|
|
|
}
|
|
|
|
|
2016-10-12 22:57:42 +00:00
|
|
|
/// Return the NumPy array flags
|
|
|
|
int flags() const {
|
2016-11-22 10:29:55 +00:00
|
|
|
return detail::array_proxy(m_ptr)->flags;
|
2016-10-12 22:57:42 +00:00
|
|
|
}
|
|
|
|
|
2016-08-29 01:41:05 +00:00
|
|
|
/// If set, the array is writeable (otherwise the buffer is read-only)
|
|
|
|
bool writeable() const {
|
2016-11-22 10:29:55 +00:00
|
|
|
return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
|
2016-08-29 01:41:05 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
/// If set, the array owns the data (will be freed when the array is deleted)
|
|
|
|
bool owndata() const {
|
2016-11-22 10:29:55 +00:00
|
|
|
return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
|
2016-08-29 01:41:05 +00:00
|
|
|
}
|
|
|
|
|
2016-09-08 20:48:14 +00:00
|
|
|
/// Pointer to the contained data. If index is not provided, points to the
|
|
|
|
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
|
2016-11-16 16:53:37 +00:00
|
|
|
template<typename... Ix> const void* data(Ix... index) const {
|
2016-11-22 10:29:55 +00:00
|
|
|
return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
|
2016-08-29 01:41:05 +00:00
|
|
|
}
|
|
|
|
|
2016-09-08 20:48:14 +00:00
|
|
|
/// Mutable pointer to the contained data. If index is not provided, points to the
|
|
|
|
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
|
|
|
|
/// May throw if the array is not writeable.
|
2016-11-16 16:53:37 +00:00
|
|
|
template<typename... Ix> void* mutable_data(Ix... index) {
|
2016-09-08 20:48:14 +00:00
|
|
|
check_writeable();
|
2016-11-22 10:29:55 +00:00
|
|
|
return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
|
2016-09-08 20:48:14 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Byte offset from beginning of the array to a given index (full or partial).
|
|
|
|
/// May throw if the index would lead to out of bounds access.
|
2016-11-16 16:53:37 +00:00
|
|
|
template<typename... Ix> size_t offset_at(Ix... index) const {
|
2016-09-08 20:48:14 +00:00
|
|
|
if (sizeof...(index) > ndim())
|
|
|
|
fail_dim_check(sizeof...(index), "too many indices for an array");
|
2016-11-16 16:53:37 +00:00
|
|
|
return byte_offset(size_t(index)...);
|
2016-09-08 20:48:14 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
size_t offset_at() const { return 0; }
|
|
|
|
|
|
|
|
/// Item count from beginning of the array to a given index (full or partial).
|
|
|
|
/// May throw if the index would lead to out of bounds access.
|
2016-11-16 16:53:37 +00:00
|
|
|
template<typename... Ix> size_t index_at(Ix... index) const {
|
2016-09-08 20:48:14 +00:00
|
|
|
return offset_at(index...) / itemsize();
|
2016-07-23 20:55:37 +00:00
|
|
|
}
|
|
|
|
|
2016-10-07 09:19:25 +00:00
|
|
|
/// Return a new view with all of the dimensions of length 1 removed
|
|
|
|
array squeeze() {
|
|
|
|
auto& api = detail::npy_api::get();
|
2016-10-28 01:08:15 +00:00
|
|
|
return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
|
2016-10-07 09:19:25 +00:00
|
|
|
}
|
|
|
|
|
2016-10-13 23:08:03 +00:00
|
|
|
/// Ensure that the argument is a NumPy array
|
2016-11-16 00:35:22 +00:00
|
|
|
/// In case of an error, nullptr is returned and the Python error is cleared.
|
|
|
|
static array ensure(handle h, int ExtraFlags = 0) {
|
|
|
|
auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
|
|
|
|
if (!result)
|
|
|
|
PyErr_Clear();
|
|
|
|
return result;
|
2016-10-13 23:08:03 +00:00
|
|
|
}
|
|
|
|
|
2016-07-23 20:55:37 +00:00
|
|
|
protected:
|
2016-09-08 20:48:14 +00:00
|
|
|
template<typename, typename> friend struct detail::npy_format_descriptor;
|
|
|
|
|
|
|
|
void fail_dim_check(size_t dim, const std::string& msg) const {
|
|
|
|
throw index_error(msg + ": " + std::to_string(dim) +
|
|
|
|
" (ndim = " + std::to_string(ndim()) + ")");
|
|
|
|
}
|
|
|
|
|
2016-11-16 16:53:37 +00:00
|
|
|
template<typename... Ix> size_t byte_offset(Ix... index) const {
|
|
|
|
check_dimensions(index...);
|
|
|
|
return byte_offset_unsafe(index...);
|
|
|
|
}
|
|
|
|
|
|
|
|
template<size_t dim = 0, typename... Ix> size_t byte_offset_unsafe(size_t i, Ix... index) const {
|
|
|
|
return i * strides()[dim] + byte_offset_unsafe<dim + 1>(index...);
|
2016-09-08 20:48:14 +00:00
|
|
|
}
|
|
|
|
|
2016-11-16 16:53:37 +00:00
|
|
|
template<size_t dim = 0> size_t byte_offset_unsafe() const { return 0; }
|
2016-09-08 20:48:14 +00:00
|
|
|
|
|
|
|
void check_writeable() const {
|
|
|
|
if (!writeable())
|
|
|
|
throw std::runtime_error("array is not writeable");
|
|
|
|
}
|
2016-07-24 17:35:14 +00:00
|
|
|
|
2017-02-14 10:25:47 +00:00
|
|
|
static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
|
2016-07-24 17:35:14 +00:00
|
|
|
auto ndim = shape.size();
|
|
|
|
std::vector<size_t> strides(ndim);
|
|
|
|
if (ndim) {
|
|
|
|
std::fill(strides.begin(), strides.end(), itemsize);
|
|
|
|
for (size_t i = 0; i < ndim - 1; i++)
|
|
|
|
for (size_t j = 0; j < ndim - 1 - i; j++)
|
|
|
|
strides[j] *= shape[ndim - 1 - i];
|
|
|
|
}
|
|
|
|
return strides;
|
|
|
|
}
|
2016-11-16 16:53:37 +00:00
|
|
|
|
|
|
|
template<typename... Ix> void check_dimensions(Ix... index) const {
|
|
|
|
check_dimensions_impl(size_t(0), shape(), size_t(index)...);
|
|
|
|
}
|
|
|
|
|
|
|
|
void check_dimensions_impl(size_t, const size_t*) const { }
|
|
|
|
|
|
|
|
template<typename... Ix> void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const {
|
|
|
|
if (i >= *shape) {
|
|
|
|
throw index_error(std::string("index ") + std::to_string(i) +
|
|
|
|
" is out of bounds for axis " + std::to_string(axis) +
|
|
|
|
" with size " + std::to_string(*shape));
|
|
|
|
}
|
|
|
|
check_dimensions_impl(axis + 1, shape + 1, index...);
|
|
|
|
}
|
2016-11-16 00:35:22 +00:00
|
|
|
|
|
|
|
/// Create array from any object -- always returns a new reference
|
|
|
|
static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
|
|
|
|
if (ptr == nullptr)
|
|
|
|
return nullptr;
|
|
|
|
return detail::npy_api::get().PyArray_FromAny_(
|
2017-01-17 01:15:42 +00:00
|
|
|
ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
|
2016-11-16 00:35:22 +00:00
|
|
|
}
|
2015-07-26 14:33:49 +00:00
|
|
|
};
|
|
|
|
|
2016-05-19 14:02:09 +00:00
|
|
|
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
|
2015-07-26 14:33:49 +00:00
|
|
|
public:
|
2016-11-16 00:35:22 +00:00
|
|
|
array_t() : array(0, static_cast<const T *>(nullptr)) {}
|
|
|
|
array_t(handle h, borrowed_t) : array(h, borrowed) { }
|
|
|
|
array_t(handle h, stolen_t) : array(h, stolen) { }
|
2016-07-24 17:54:53 +00:00
|
|
|
|
2016-11-16 00:35:22 +00:00
|
|
|
PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
|
|
|
|
array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen) {
|
|
|
|
if (!m_ptr) PyErr_Clear();
|
|
|
|
if (!is_borrowed) Py_XDECREF(h.ptr());
|
|
|
|
}
|
2016-10-25 20:12:39 +00:00
|
|
|
|
2016-11-16 00:35:22 +00:00
|
|
|
array_t(const object &o) : array(raw_array_t(o.ptr()), stolen) {
|
|
|
|
if (!m_ptr) throw error_already_set();
|
|
|
|
}
|
2016-10-25 20:12:39 +00:00
|
|
|
|
2016-10-16 20:27:42 +00:00
|
|
|
explicit array_t(const buffer_info& info) : array(info) { }
|
2016-07-24 17:54:53 +00:00
|
|
|
|
2016-10-12 22:57:42 +00:00
|
|
|
array_t(const std::vector<size_t> &shape,
|
|
|
|
const std::vector<size_t> &strides, const T *ptr = nullptr,
|
|
|
|
handle base = handle())
|
|
|
|
: array(shape, strides, ptr, base) { }
|
2016-07-24 17:54:53 +00:00
|
|
|
|
2016-10-16 20:27:42 +00:00
|
|
|
explicit array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
|
2016-10-12 22:57:42 +00:00
|
|
|
handle base = handle())
|
|
|
|
: array(shape, ptr, base) { }
|
2016-07-24 17:54:53 +00:00
|
|
|
|
2016-10-16 20:27:42 +00:00
|
|
|
explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
|
2016-10-12 22:57:42 +00:00
|
|
|
: array(count, ptr, base) { }
|
2016-08-29 01:41:05 +00:00
|
|
|
|
2016-09-08 20:48:14 +00:00
|
|
|
constexpr size_t itemsize() const {
|
|
|
|
return sizeof(T);
|
2016-08-29 01:41:05 +00:00
|
|
|
}
|
|
|
|
|
2016-11-16 16:53:37 +00:00
|
|
|
template<typename... Ix> size_t index_at(Ix... index) const {
|
2016-09-08 20:48:14 +00:00
|
|
|
return offset_at(index...) / itemsize();
|
|
|
|
}
|
|
|
|
|
2016-11-16 16:53:37 +00:00
|
|
|
template<typename... Ix> const T* data(Ix... index) const {
|
2016-09-08 20:48:14 +00:00
|
|
|
return static_cast<const T*>(array::data(index...));
|
|
|
|
}
|
|
|
|
|
2016-11-16 16:53:37 +00:00
|
|
|
template<typename... Ix> T* mutable_data(Ix... index) {
|
2016-09-08 20:48:14 +00:00
|
|
|
return static_cast<T*>(array::mutable_data(index...));
|
|
|
|
}
|
|
|
|
|
|
|
|
// Reference to element at a given index
|
2016-11-16 16:53:37 +00:00
|
|
|
template<typename... Ix> const T& at(Ix... index) const {
|
2016-09-08 20:48:14 +00:00
|
|
|
if (sizeof...(index) != ndim())
|
|
|
|
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
2016-11-16 16:53:37 +00:00
|
|
|
return *(static_cast<const T*>(array::data()) + byte_offset(size_t(index)...) / itemsize());
|
2016-09-08 20:48:14 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// Mutable reference to element at a given index
|
2016-11-16 16:53:37 +00:00
|
|
|
template<typename... Ix> T& mutable_at(Ix... index) {
|
2016-09-08 20:48:14 +00:00
|
|
|
if (sizeof...(index) != ndim())
|
|
|
|
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
2016-11-16 16:53:37 +00:00
|
|
|
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
|
2016-08-29 01:41:05 +00:00
|
|
|
}
|
2016-07-24 17:54:53 +00:00
|
|
|
|
2017-01-17 01:22:00 +00:00
|
|
|
/// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
|
|
|
|
/// it). In case of an error, nullptr is returned and the Python error is cleared.
|
2016-11-16 00:35:22 +00:00
|
|
|
static array_t ensure(handle h) {
|
|
|
|
auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
|
2016-05-19 14:02:09 +00:00
|
|
|
if (!result)
|
|
|
|
PyErr_Clear();
|
2016-01-17 21:36:41 +00:00
|
|
|
return result;
|
2015-07-26 14:33:49 +00:00
|
|
|
}
|
2016-11-16 00:35:22 +00:00
|
|
|
|
2017-02-06 23:06:04 +00:00
|
|
|
static bool check_(handle h) {
|
2016-11-16 00:35:22 +00:00
|
|
|
const auto &api = detail::npy_api::get();
|
|
|
|
return api.PyArray_Check_(h.ptr())
|
2016-11-22 10:29:55 +00:00
|
|
|
&& api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
|
2016-11-16 00:35:22 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
protected:
|
|
|
|
/// Create array from any object -- always returns a new reference
|
|
|
|
static PyObject *raw_array_t(PyObject *ptr) {
|
|
|
|
if (ptr == nullptr)
|
|
|
|
return nullptr;
|
|
|
|
return detail::npy_api::get().PyArray_FromAny_(
|
|
|
|
ptr, dtype::of<T>().release().ptr(), 0, 0,
|
2017-01-17 01:15:42 +00:00
|
|
|
detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
|
2016-11-16 00:35:22 +00:00
|
|
|
}
|
2015-07-26 14:33:49 +00:00
|
|
|
};
|
|
|
|
|
2016-06-26 15:46:40 +00:00
|
|
|
template <typename T>
|
2016-09-12 15:36:43 +00:00
|
|
|
struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
|
2016-08-15 00:24:28 +00:00
|
|
|
static std::string format() {
|
|
|
|
return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
|
|
|
|
}
|
2016-07-19 23:19:24 +00:00
|
|
|
};
|
|
|
|
|
|
|
|
/// Fixed-width char buffers use the "<N>s" buffer format.
template <size_t N> struct format_descriptor<char[N]> {
    static std::string format() {
        std::string fmt = std::to_string(N);
        fmt += 's';
        return fmt;
    }
};
|
|
|
|
/// std::array<char, N> shares the "<N>s" buffer format used for char[N].
template <size_t N> struct format_descriptor<std::array<char, N>> {
    static std::string format() {
        std::string fmt = std::to_string(N);
        fmt += 's';
        return fmt;
    }
};
|
|
|
|
|
2016-10-20 11:28:08 +00:00
|
|
|
template <typename T>
|
|
|
|
struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
|
|
|
|
static std::string format() {
|
|
|
|
return format_descriptor<
|
|
|
|
typename std::remove_cv<typename std::underlying_type<T>::type>::type>::format();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2016-05-05 18:33:54 +00:00
|
|
|
NAMESPACE_BEGIN(detail)
|
2016-10-23 12:50:08 +00:00
|
|
|
template <typename T, int ExtraFlags>
|
|
|
|
struct pyobject_caster<array_t<T, ExtraFlags>> {
|
|
|
|
using type = array_t<T, ExtraFlags>;
|
|
|
|
|
|
|
|
bool load(handle src, bool /* convert */) {
|
2016-11-16 00:35:22 +00:00
|
|
|
value = type::ensure(src);
|
2016-10-23 12:50:08 +00:00
|
|
|
return static_cast<bool>(value);
|
|
|
|
}
|
|
|
|
|
|
|
|
static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
|
|
|
|
return src.inc_ref();
|
|
|
|
}
|
|
|
|
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
|
|
|
|
};
|
|
|
|
|
Numpy: better compilation errors, long double support (#619)
* Clarify PYBIND11_NUMPY_DTYPE documentation
The current documentation and example read as though
PYBIND11_NUMPY_DTYPE is a declarative macro along the same lines as
PYBIND11_DECLARE_HOLDER_TYPE, but it isn't. This changes the
documentation and docs example to make it clear that you need to "call"
the macro.
* Add satisfies_{all,any,none}_of<T, Preds>
`satisfies_all_of<T, Pred1, Pred2, Pred3>` is a nice legibility-enhanced
shortcut for `is_all<Pred1<T>, Pred2<T>, Pred3<T>>`.
* Give better error message for non-POD dtype attempts
If you try to use a non-POD data type, you get difficult-to-interpret
compilation errors (about ::name() not being a member of an internal
pybind11 struct, among others), for which it isn't at all obvious what the
problem is.
This adds a static_assert for such cases.
It also changes the base case from an empty struct to the is_pod_struct
case by no longer using `enable_if<is_pod_struct>` but instead using a
static_assert: thus specializations avoid the base class, POD types
work, and non-POD types (and unimplemented POD types like std::array)
get a more informative static_assert failure.
* Prefix macros with PYBIND11_
numpy.h uses unprefixed macros, which seems undesirable. This prefixes
them with PYBIND11_ to match all the other macros in numpy.h (and
elsewhere).
* Add long double support
This adds long double and std::complex<long double> support for numpy
arrays.
This allows some simplification of the code used to generate format
descriptors; the new code uses fewer macros, instead putting the code as
different templated options; the template conditions end up simpler with
this because we are now supporting all basic C++ arithmetic types (and
so can use is_arithmetic instead of is_integral + multiple
different specializations).
In addition to testing that it is indeed working in the test script, it
also adds various offset and size calculations there, which
fixes the test failures under x86 compilations.
2017-01-31 16:00:15 +00:00
|
|
|
/// NumPy dtype descriptor for built-in arithmetic types and std::complex specializations.
template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> {
private:
    // NB: the order here must match the one in common.h
    constexpr static const int values[15] = {
        npy_api::NPY_BOOL_,
        npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_,
        npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_,
        npy_api::NPY_FLOAT_, npy_api::NPY_DOUBLE_, npy_api::NPY_LONGDOUBLE_,
        npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_
    };

public:
    /// NumPy type number for T, selected by T's position in is_fmt_numeric.
    static constexpr int value = values[detail::is_fmt_numeric<T>::index];

    /// dtype for `value`; fails (pybind11_fail) if NumPy cannot produce a descriptor.
    static pybind11::dtype dtype() {
        if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
            return reinterpret_borrow<pybind11::dtype>(ptr);
        pybind11_fail("Unsupported buffer format!");
    }

    /// Integral types: "bool" for bool, otherwise "int"/"uint" followed by the bit width.
    template <typename T2 = T, enable_if_t<std::is_integral<T2>::value, int> = 0>
    static PYBIND11_DESCR name() {
        return _<std::is_same<T, bool>::value>(_("bool"),
            _<std::is_signed<T>::value>("int", "uint") + _<sizeof(T)*8>());
    }

    /// Floating point: "float32"/"float64" for float/double, "longdouble" otherwise.
    template <typename T2 = T, enable_if_t<std::is_floating_point<T2>::value, int> = 0>
    static PYBIND11_DESCR name() {
        return _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
            _("float") + _<sizeof(T)*8>(), _("longdouble"));
    }

    /// Complex: "complex64"/"complex128" for float/double components, "longcomplex" otherwise.
    /// `typename` is required inside sizeof: without it the dependent name is parsed as a
    /// non-type and instantiation fails on conforming compilers.
    template <typename T2 = T, enable_if_t<is_complex<T2>::value, int> = 0>
    static PYBIND11_DESCR name() {
        return _<std::is_same<typename T2::value_type, float>::value || std::is_same<typename T2::value_type, double>::value>(
            _("complex") + _<sizeof(typename T2::value_type)*16>(), _("longcomplex"));
    }
};
|
Numpy: better compilation errors, long double support (#619)
* Clarify PYBIND11_NUMPY_DTYPE documentation
The current documentation and example read as though
PYBIND11_NUMPY_DTYPE is a declarative macro along the same lines as
PYBIND11_DECLARE_HOLDER_TYPE, but it isn't. This changes the
documentation and docs example to make it clear that you need to "call"
the macro.
* Add satisfies_{all,any,none}_of<T, Preds>
`satisfies_all_of<T, Pred1, Pred2, Pred3>` is a nice legibility-enhanced
shortcut for `is_all<Pred1<T>, Pred2<T>, Pred3<T>>`.
* Give better error message for non-POD dtype attempts
If you try to use a non-POD data type, you get difficult-to-interpret
compilation errors (about ::name() not being a member of an internal
pybind11 struct, among others), for which it isn't at all obvious what the
problem is.
This adds a static_assert for such cases.
It also changes the base case from an empty struct to the is_pod_struct
case by no longer using `enable_if<is_pod_struct>` but instead using a
static_assert: thus specializations avoid the base class, POD types
work, and non-POD types (and unimplemented POD types like std::array)
get a more informative static_assert failure.
* Prefix macros with PYBIND11_
numpy.h uses unprefixed macros, which seems undesirable. This prefixes
them with PYBIND11_ to match all the other macros in numpy.h (and
elsewhere).
* Add long double support
This adds long double and std::complex<long double> support for numpy
arrays.
This allows some simplification of the code used to generate format
descriptors; the new code uses fewer macros, instead putting the code as
different templated options; the template conditions end up simpler with
this because we are now supporting all basic C++ arithmetic types (and
so can use is_arithmetic instead of is_integral + multiple
different specializations).
In addition to testing that it is indeed working in the test script, it
also adds various offset and size calculations there, which
fixes the test failures under x86 compilations.
2017-01-31 16:00:15 +00:00
|
|
|
|
|
|
|
#define PYBIND11_DECL_CHAR_FMT \
|
2016-07-19 23:19:24 +00:00
|
|
|
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
|
2016-10-16 20:27:42 +00:00
|
|
|
static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); }
|
Numpy: better compilation errors, long double support (#619)
* Clarify PYBIND11_NUMPY_DTYPE documentation
The current documentation and example reads as though
PYBIND11_NUMPY_DTYPE is a declarative macro along the same lines as
PYBIND11_DECLARE_HOLDER_TYPE, but it isn't. The changes the
documentation and docs example to make it clear that you need to "call"
the macro.
* Add satisfies_{all,any,none}_of<T, Preds>
`satisfies_all_of<T, Pred1, Pred2, Pred3>` is a nice legibility-enhanced
shortcut for `is_all<Pred1<T>, Pred2<T>, Pred3<T>>`.
* Give better error message for non-POD dtype attempts
If you try to use a non-POD data type, you get difficult-to-interpret
compilation errors (about ::name() not being a member of an internal
pybind11 struct, among others), for which isn't at all obvious what the
problem is.
This adds a static_assert for such cases.
It also changes the base case from an empty struct to the is_pod_struct
case by no longer using `enable_if<is_pod_struct>` but instead using a
static_assert: thus specializations avoid the base class, POD types
work, and non-POD types (and unimplemented POD types like std::array)
get a more informative static_assert failure.
* Prefix macros with PYBIND11_
numpy.h uses unprefixed macros, which seems undesirable. This prefixes
them with PYBIND11_ to match all the other macros in numpy.h (and
elsewhere).
* Add long double support
This adds long double and std::complex<long double> support for numpy
arrays.
This allows some simplification of the code used to generate format
descriptors; the new code uses fewer macros, instead putting the code as
different templated options; the template conditions end up simpler with
this because we are now supporting all basic C++ arithmetic types (and
so can use is_arithmetic instead of is_integral + multiple
different specializations).
In addition to testing that it is indeed working in the test script, it
also adds various offset and size calculations there, which
fixes the test failures under x86 compilations.
2017-01-31 16:00:15 +00:00
|
|
|
template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_FMT };
|
|
|
|
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
|
|
|
|
#undef PYBIND11_DECL_CHAR_FMT
|
2016-07-19 23:19:24 +00:00
|
|
|
|
2016-10-20 11:28:08 +00:00
|
|
|
/// Enums delegate both name() and dtype() to their underlying type's descriptor.
template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
private:
    using underlying_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;

public:
    static PYBIND11_DESCR name() {
        return underlying_descr::name();
    }
    static pybind11::dtype dtype() {
        return underlying_descr::dtype();
    }
};
|
|
|
|
|
2016-06-19 14:48:55 +00:00
|
|
|
struct field_descriptor {
|
|
|
|
const char *name;
|
2016-07-03 09:22:10 +00:00
|
|
|
size_t offset;
|
2016-07-05 23:28:12 +00:00
|
|
|
size_t size;
|
2016-11-22 11:17:07 +00:00
|
|
|
size_t alignment;
|
2016-08-14 12:45:49 +00:00
|
|
|
std::string format;
|
2016-07-23 20:55:37 +00:00
|
|
|
dtype descr;
|
2016-06-19 14:48:55 +00:00
|
|
|
};
|
|
|
|
|
2016-10-31 16:16:47 +00:00
|
|
|
/// Registers a NumPy structured dtype, its buffer format string and a direct converter
/// for the C++ type `tinfo`. Fails via pybind11_fail in three cases: the type is already
/// registered, a field has a null dtype, or NumPy does not treat the generated format
/// string as equivalent to the dtype.
inline PYBIND11_NOINLINE void register_structured_dtype(
    const std::initializer_list<field_descriptor>& fields,
    const std::type_info& tinfo, size_t itemsize,
    bool (*direct_converter)(PyObject *, void *&)) {

    auto& numpy_internals = get_numpy_internals();
    if (numpy_internals.get_type_info(tinfo, false))
        pybind11_fail("NumPy: dtype is already registered");

    list names, formats, offsets;
    // Bind by const reference: each field_descriptor owns a std::string and a dtype, so a
    // by-value loop variable would allocate and touch Python refcounts on every iteration.
    for (const auto &field : fields) {
        if (!field.descr)
            pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
                          field.name + "` @ " + tinfo.name());
        names.append(PYBIND11_STR_TYPE(field.name));
        formats.append(field.descr);
        offsets.append(pybind11::int_(field.offset));
    }
    // Keep the dtype owned until registration succeeds; releasing it early would leak the
    // reference if the sanity check below throws.
    auto dtype_obj = pybind11::dtype(names, formats, offsets, itemsize);

    // There is an existing bug in NumPy (as of v1.11): trailing bytes are
    // not encoded explicitly into the format string. This will supposedly
    // get fixed in v1.12; for further details, see these:
    // - https://github.com/numpy/numpy/issues/7797
    // - https://github.com/numpy/numpy/pull/7798
    // Because of this, we won't use numpy's logic to generate buffer format
    // strings and will just do it ourselves.
    std::vector<field_descriptor> ordered_fields(fields);
    std::sort(ordered_fields.begin(), ordered_fields.end(),
              [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
    size_t offset = 0;
    std::ostringstream oss;
    oss << "T{";
    for (const auto &field : ordered_fields) {
        // explicit padding bytes before this field
        if (field.offset > offset)
            oss << (field.offset - offset) << 'x';
        // mark unaligned fields with '^' (unaligned native type)
        if (field.offset % field.alignment)
            oss << '^';
        oss << field.format << ':' << field.name << ':';
        offset = field.offset + field.size;
    }
    // trailing padding up to the full item size
    if (itemsize > offset)
        oss << (itemsize - offset) << 'x';
    oss << '}';
    auto format_str = oss.str();

    // Sanity check: verify that NumPy properly parses our buffer format string
    auto& api = npy_api::get();
    auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
    if (!api.PyArray_EquivTypes_(dtype_obj.ptr(), arr.dtype().ptr()))
        pybind11_fail("NumPy: invalid buffer descriptor!");

    auto tindex = std::type_index(tinfo);
    // The registry takes over the dtype reference.
    numpy_internals.registered_dtypes[tindex] = { dtype_obj.release().ptr(), format_str };
    get_internals().direct_conversions[tindex].push_back(direct_converter);
}
|
|
|
|
|
Numpy: better compilation errors, long double support (#619)
* Clarify PYBIND11_NUMPY_DTYPE documentation
The current documentation and example read as though
PYBIND11_NUMPY_DTYPE is a declarative macro along the same lines as
PYBIND11_DECLARE_HOLDER_TYPE, but it isn't. This changes the
documentation and docs example to make it clear that you need to "call"
the macro.
* Add satisfies_{all,any,none}_of<T, Preds>
`satisfies_all_of<T, Pred1, Pred2, Pred3>` is a nice legibility-enhanced
shortcut for `is_all<Pred1<T>, Pred2<T>, Pred3<T>>`.
* Give better error message for non-POD dtype attempts
If you try to use a non-POD data type, you get difficult-to-interpret
compilation errors (about ::name() not being a member of an internal
pybind11 struct, among others), for which it isn't at all obvious what the
problem is.
This adds a static_assert for such cases.
It also changes the base case from an empty struct to the is_pod_struct
case by no longer using `enable_if<is_pod_struct>` but instead using a
static_assert: thus specializations avoid the base class, POD types
work, and non-POD types (and unimplemented POD types like std::array)
get a more informative static_assert failure.
* Prefix macros with PYBIND11_
numpy.h uses unprefixed macros, which seems undesirable. This prefixes
them with PYBIND11_ to match all the other macros in numpy.h (and
elsewhere).
* Add long double support
This adds long double and std::complex<long double> support for numpy
arrays.
This allows some simplification of the code used to generate format
descriptors; the new code uses fewer macros, instead putting the code as
different templated options; the template conditions end up simpler with
this because we are now supporting all basic C++ arithmetic types (and
so can use is_arithmetic instead of is_integral + multiple
different specializations).
In addition to testing that it is indeed working in the test script, it
also adds various offset and size calculations there, which
fixes the test failures under x86 compilations.
2017-01-31 16:00:15 +00:00
|
|
|
template <typename T, typename SFINAE> struct npy_format_descriptor {
|
|
|
|
static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");
|
|
|
|
|
2016-07-18 18:58:20 +00:00
|
|
|
static PYBIND11_DESCR name() { return _("struct"); }
|
2016-06-19 14:48:55 +00:00
|
|
|
|
2016-07-23 20:55:37 +00:00
|
|
|
static pybind11::dtype dtype() {
|
2016-10-28 01:08:15 +00:00
|
|
|
return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
|
2016-06-19 14:48:55 +00:00
|
|
|
}
|
|
|
|
|
2016-08-14 12:45:49 +00:00
|
|
|
static std::string format() {
|
2016-10-31 13:52:32 +00:00
|
|
|
static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
|
2016-08-14 12:45:49 +00:00
|
|
|
return format_str;
|
2016-06-19 14:48:55 +00:00
|
|
|
}
|
|
|
|
|
2016-10-31 16:16:47 +00:00
|
|
|
static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
|
|
|
|
register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
|
|
|
|
sizeof(T), &direct_converter);
|
2016-06-19 14:48:55 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
private:
|
2016-10-31 13:52:32 +00:00
|
|
|
static PyObject* dtype_ptr() {
|
|
|
|
static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
|
|
|
|
return ptr;
|
|
|
|
}
|
2016-10-20 15:11:08 +00:00
|
|
|
|
2016-10-23 14:27:13 +00:00
|
|
|
static bool direct_converter(PyObject *obj, void*& value) {
|
|
|
|
auto& api = npy_api::get();
|
|
|
|
if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
|
2016-10-20 15:11:08 +00:00
|
|
|
return false;
|
2016-10-28 01:08:15 +00:00
|
|
|
if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(obj))) {
|
2016-10-31 13:52:32 +00:00
|
|
|
if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
|
2016-10-23 14:27:13 +00:00
|
|
|
value = ((PyVoidScalarObject_Proxy *) obj)->obval;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return false;
|
|
|
|
}
|
2016-06-19 14:48:55 +00:00
|
|
|
};
|
|
|
|
|
2016-11-01 13:27:35 +00:00
|
|
|
// Build a ::pybind11::detail::field_descriptor for member `Field` of struct `T`,
// exposed in the dtype under the field name `Name`. The descriptor records, in
// order: the name, offsetof(T, Field), the member's size and alignment, its
// buffer-protocol format string and its numpy dtype.
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
    ::pybind11::detail::field_descriptor { \
        Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
        alignof(decltype(std::declval<T>().Field)), \
        ::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(), \
        ::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \
    }

// Extract name, offset, size, alignment, format descriptor and dtype for a struct field,
// using the member's own (stringified) identifier as the dtype field name
#define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field)
|
|
|
|
|
2016-06-19 14:48:55 +00:00
|
|
|
// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
// (C) William Swanson, Paul Fultz

// PYBIND11_EVAL forces its argument to be rescanned through 3^5 = 243 nested
// levels of PYBIND11_EVAL0. Each rescan lets one more step of the
// PYBIND11_MAP_LIST0 / PYBIND11_MAP_LIST1 ping-pong below expand, which is how
// the otherwise non-recursive preprocessor "iterates" over a list.
#define PYBIND11_EVAL0(...) __VA_ARGS__
#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__)))
#define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__)))
#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__)))
#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__)))
#define PYBIND11_EVAL(...) PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__)))
// Terminator: swallows whatever argument list follows it
#define PYBIND11_MAP_END(...)
// Empty macro used to delay expansion of `next` by one scan
#define PYBIND11_MAP_OUT
#define PYBIND11_MAP_COMMA ,
// When `test` is the end sentinel `()`, "PYBIND11_MAP_GET_END test" becomes
// "0, PYBIND11_MAP_END", shifting PYBIND11_MAP_END into the `next` slot of
// PYBIND11_MAP_NEXT0; for any other `test` the caller-supplied `next` is kept.
#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0)
#define PYBIND11_MAP_NEXT(test, next) PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next)
// Same selection as PYBIND11_MAP_NEXT, but a comma is emitted before `next`
// when the list continues (and nothing is emitted at the end).
#ifdef _MSC_VER // MSVC is not as eager to expand macros, hence this workaround
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP_LIST_NEXT(test, next) \
    PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
// Apply f(t, x) to the head, then hand the tail to the other LIST macro
// (alternating between LIST0 and LIST1 avoids direct self-reference).
#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
// The appended "(), 0" provides the end sentinel `()` plus one extra argument
// so that `peek` always has something to bind to.
#define PYBIND11_MAP_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0))
|
2016-06-19 14:48:55 +00:00
|
|
|
|
2016-07-02 15:18:42 +00:00
|
|
|
// Register the POD struct `Type` as a numpy structured dtype made of the listed
// members, each exposed under its own identifier, e.g.
//     PYBIND11_NUMPY_DTYPE(MyStruct, a, b, c);
// Note: this expands to a call to npy_format_descriptor<Type>::register_dtype,
// so it must be *invoked* as a statement at runtime; it is not a declaration.
#define PYBIND11_NUMPY_DTYPE(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        ({PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
|
2016-06-19 14:48:55 +00:00
|
|
|
|
2016-11-01 13:27:35 +00:00
|
|
|
// Pairwise variant of PYBIND11_MAP_LIST: consumes the argument list two
// elements at a time and applies f(t, x1, x2) to each pair, comma-separated.
// The list must therefore contain an even number of elements.
#ifdef _MSC_VER
// Extra PYBIND11_EVAL0 for MSVC, which does not expand macros as eagerly
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
// Selects ", next" while the list continues, or the terminator at the `()` sentinel
#define PYBIND11_MAP2_LIST_NEXT(test, next) \
    PYBIND11_MAP2_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
// Apply f to the leading pair, then continue with the other LIST macro
#define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ...
// "(), 0" appends the end sentinel plus a filler argument for `peek`.
#define PYBIND11_MAP2_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP2_LIST1 (f, t, __VA_ARGS__, (), 0))
|
|
|
|
|
|
|
|
// Like PYBIND11_NUMPY_DTYPE, but each member is followed by the dtype field
// name to expose it under, i.e. the arguments come in (member, name) pairs:
//     PYBIND11_NUMPY_DTYPE_EX(MyStruct, a, "a_name", b, "b_name");
// Also expands to a runtime register_dtype call, so it must be invoked.
#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        ({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
|
|
|
|
|
2016-02-11 09:47:11 +00:00
|
|
|
template <class T>
|
|
|
|
using array_iterator = typename std::add_pointer<T>::type;
|
|
|
|
|
|
|
|
template <class T>
|
|
|
|
array_iterator<T> array_begin(const buffer_info& buffer) {
|
|
|
|
return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr));
|
|
|
|
}
|
|
|
|
|
|
|
|
template <class T>
|
|
|
|
array_iterator<T> array_end(const buffer_info& buffer) {
|
|
|
|
return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr) + buffer.size);
|
|
|
|
}
|
|
|
|
|
|
|
|
/* Strided byte cursor over a single buffer. It is driven by
   multi_array_iterator, which calls increment(dim) on the dimension that
   advances after all later dimensions have wrapped back to index 0
   (C / row-major traversal). */
class common_iterator {
public:
    using container_type = std::vector<size_t>;
    using value_type = container_type::value_type;
    using size_type = container_type::size_type;

    common_iterator() : p_ptr(nullptr), m_strides() {}

    /* ptr:     address of the first element
       strides: per-dimension byte strides (0 for a broadcast dimension)
       shape:   extent of each dimension

       m_strides[j] stores the *net* byte offset to apply when dimension j
       advances by one while every later dimension wraps from its last index
       back to 0. That offset can be negative, e.g. with a 0 stride on a
       broadcast dimension. It is then stored as a wrapped-around unsigned
       value, and `p_ptr +=` relies on modular address arithmetic. */
    common_iterator(void* ptr, const container_type& strides, const std::vector<size_t>& shape)
        : p_ptr(static_cast<char*>(ptr)), m_strides(strides.size()) {
        // 0-d case: nothing to precompute. Without this guard, back() and the
        // `size() - 1` loop start below would be undefined behavior.
        if (m_strides.empty())
            return;

        m_strides.back() = static_cast<value_type>(strides.back());
        for (size_type i = m_strides.size() - 1; i != 0; --i) {
            size_type j = i - 1;
            value_type s = static_cast<value_type>(shape[i]);
            m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
        }
    }

    // Advance the cursor as if dimension `dim` had been incremented by one
    void increment(size_type dim) {
        p_ptr += m_strides[dim];
    }

    // Current element address
    void* data() const {
        return p_ptr;
    }

private:
    char* p_ptr;
    container_type m_strides;
};
|
|
|
|
|
2016-02-20 11:17:17 +00:00
|
|
|
template <size_t N> class multi_array_iterator {
|
2016-02-11 09:47:11 +00:00
|
|
|
public:
|
|
|
|
using container_type = std::vector<size_t>;
|
|
|
|
|
2016-02-20 11:17:17 +00:00
|
|
|
multi_array_iterator(const std::array<buffer_info, N> &buffers,
|
|
|
|
const std::vector<size_t> &shape)
|
|
|
|
: m_shape(shape.size()), m_index(shape.size(), 0),
|
|
|
|
m_common_iterator() {
|
|
|
|
|
2016-02-11 09:47:11 +00:00
|
|
|
// Manual copy to avoid conversion warning if using std::copy
|
2016-02-20 11:17:17 +00:00
|
|
|
for (size_t i = 0; i < shape.size(); ++i)
|
2016-02-11 09:47:11 +00:00
|
|
|
m_shape[i] = static_cast<container_type::value_type>(shape[i]);
|
|
|
|
|
|
|
|
container_type strides(shape.size());
|
2016-02-20 11:17:17 +00:00
|
|
|
for (size_t i = 0; i < N; ++i)
|
2016-02-11 09:47:11 +00:00
|
|
|
init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
|
|
|
|
}
|
|
|
|
|
|
|
|
multi_array_iterator& operator++() {
|
|
|
|
for (size_t j = m_index.size(); j != 0; --j) {
|
|
|
|
size_t i = j - 1;
|
|
|
|
if (++m_index[i] != m_shape[i]) {
|
|
|
|
increment_common_iterator(i);
|
|
|
|
break;
|
2016-02-20 11:17:17 +00:00
|
|
|
} else {
|
2016-02-11 09:47:11 +00:00
|
|
|
m_index[i] = 0;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return *this;
|
|
|
|
}
|
|
|
|
|
2016-02-20 11:17:17 +00:00
|
|
|
template <size_t K, class T> const T& data() const {
|
2016-02-11 09:47:11 +00:00
|
|
|
return *reinterpret_cast<T*>(m_common_iterator[K].data());
|
|
|
|
}
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
|
|
|
using common_iter = common_iterator;
|
|
|
|
|
2016-02-20 11:17:17 +00:00
|
|
|
void init_common_iterator(const buffer_info &buffer,
|
|
|
|
const std::vector<size_t> &shape,
|
|
|
|
common_iter &iterator, container_type &strides) {
|
2016-02-11 09:47:11 +00:00
|
|
|
auto buffer_shape_iter = buffer.shape.rbegin();
|
|
|
|
auto buffer_strides_iter = buffer.strides.rbegin();
|
|
|
|
auto shape_iter = shape.rbegin();
|
|
|
|
auto strides_iter = strides.rbegin();
|
|
|
|
|
|
|
|
while (buffer_shape_iter != buffer.shape.rend()) {
|
|
|
|
if (*shape_iter == *buffer_shape_iter)
|
2016-05-29 11:40:40 +00:00
|
|
|
*strides_iter = static_cast<size_t>(*buffer_strides_iter);
|
2016-02-11 09:47:11 +00:00
|
|
|
else
|
|
|
|
*strides_iter = 0;
|
|
|
|
|
|
|
|
++buffer_shape_iter;
|
|
|
|
++buffer_strides_iter;
|
|
|
|
++shape_iter;
|
|
|
|
++strides_iter;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::fill(strides_iter, strides.rend(), 0);
|
|
|
|
iterator = common_iter(buffer.ptr, strides, shape);
|
|
|
|
}
|
|
|
|
|
|
|
|
void increment_common_iterator(size_t dim) {
|
2016-02-20 11:17:17 +00:00
|
|
|
for (auto &iter : m_common_iterator)
|
2016-02-11 09:47:11 +00:00
|
|
|
iter.increment(dim);
|
|
|
|
}
|
|
|
|
|
|
|
|
container_type m_shape;
|
|
|
|
container_type m_index;
|
|
|
|
std::array<common_iter, N> m_common_iterator;
|
|
|
|
};
|
|
|
|
|
|
|
|
template <size_t N>
|
2016-05-29 11:40:40 +00:00
|
|
|
bool broadcast(const std::array<buffer_info, N>& buffers, size_t& ndim, std::vector<size_t>& shape) {
|
|
|
|
ndim = std::accumulate(buffers.begin(), buffers.end(), size_t(0), [](size_t res, const buffer_info& buf) {
|
2016-02-11 09:47:11 +00:00
|
|
|
return std::max(res, buf.ndim);
|
|
|
|
});
|
|
|
|
|
2016-05-29 11:40:40 +00:00
|
|
|
shape = std::vector<size_t>(ndim, 1);
|
2016-02-11 09:47:11 +00:00
|
|
|
bool trivial_broadcast = true;
|
|
|
|
for (size_t i = 0; i < N; ++i) {
|
|
|
|
auto res_iter = shape.rbegin();
|
|
|
|
bool i_trivial_broadcast = (buffers[i].size == 1) || (buffers[i].ndim == ndim);
|
2016-02-20 11:17:17 +00:00
|
|
|
for (auto shape_iter = buffers[i].shape.rbegin();
|
|
|
|
shape_iter != buffers[i].shape.rend(); ++shape_iter, ++res_iter) {
|
|
|
|
|
|
|
|
if (*res_iter == 1)
|
2016-02-11 09:47:11 +00:00
|
|
|
*res_iter = *shape_iter;
|
2016-02-20 11:17:17 +00:00
|
|
|
else if ((*shape_iter != 1) && (*res_iter != *shape_iter))
|
2016-02-11 09:47:11 +00:00
|
|
|
pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
|
2016-02-20 11:17:17 +00:00
|
|
|
|
2016-02-11 09:47:11 +00:00
|
|
|
i_trivial_broadcast = i_trivial_broadcast && (*res_iter == *shape_iter);
|
|
|
|
}
|
|
|
|
trivial_broadcast = trivial_broadcast && i_trivial_broadcast;
|
|
|
|
}
|
|
|
|
return trivial_broadcast;
|
|
|
|
}
|
|
|
|
|
2015-07-29 15:51:54 +00:00
|
|
|
template <typename Func, typename Return, typename... Args>
|
|
|
|
struct vectorize_helper {
|
|
|
|
typename std::remove_reference<Func>::type f;
|
|
|
|
|
2015-07-30 13:29:00 +00:00
|
|
|
template <typename T>
|
2016-10-16 20:27:42 +00:00
|
|
|
explicit vectorize_helper(T&&f) : f(std::forward<T>(f)) { }
|
2015-07-28 14:12:20 +00:00
|
|
|
|
2016-05-19 14:02:09 +00:00
|
|
|
object operator()(array_t<Args, array::c_style | array::forcecast>... args) {
|
2016-11-27 19:56:04 +00:00
|
|
|
return run(args..., make_index_sequence<sizeof...(Args)>());
|
2015-07-29 15:51:54 +00:00
|
|
|
}
|
2015-07-26 14:33:49 +00:00
|
|
|
|
2016-05-19 14:02:09 +00:00
|
|
|
template <size_t ... Index> object run(array_t<Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) {
|
2015-07-26 14:33:49 +00:00
|
|
|
/* Request buffers from all parameters */
|
2015-07-29 15:51:54 +00:00
|
|
|
const size_t N = sizeof...(Args);
|
2016-02-11 09:47:11 +00:00
|
|
|
|
2015-07-26 14:33:49 +00:00
|
|
|
std::array<buffer_info, N> buffers {{ args.request()... }};
|
|
|
|
|
|
|
|
/* Determine dimensions parameters of output array */
|
2016-05-29 11:40:40 +00:00
|
|
|
size_t ndim = 0;
|
2016-02-11 09:47:11 +00:00
|
|
|
std::vector<size_t> shape(0);
|
|
|
|
bool trivial_broadcast = broadcast(buffers, ndim, shape);
|
2016-05-19 14:02:09 +00:00
|
|
|
|
2016-02-11 09:47:11 +00:00
|
|
|
size_t size = 1;
|
2015-07-26 14:33:49 +00:00
|
|
|
std::vector<size_t> strides(ndim);
|
|
|
|
if (ndim > 0) {
|
2015-07-29 15:51:54 +00:00
|
|
|
strides[ndim-1] = sizeof(Return);
|
2016-05-29 11:40:40 +00:00
|
|
|
for (size_t i = ndim - 1; i > 0; --i) {
|
2016-02-11 09:47:11 +00:00
|
|
|
strides[i - 1] = strides[i] * shape[i];
|
|
|
|
size *= shape[i];
|
|
|
|
}
|
|
|
|
size *= shape[0];
|
2015-07-26 14:33:49 +00:00
|
|
|
}
|
|
|
|
|
2016-01-17 21:36:40 +00:00
|
|
|
if (size == 1)
|
2016-01-17 21:36:39 +00:00
|
|
|
return cast(f(*((Args *) buffers[Index].ptr)...));
|
2015-07-26 14:33:49 +00:00
|
|
|
|
2016-08-29 01:41:05 +00:00
|
|
|
array_t<Return> result(shape, strides);
|
|
|
|
auto buf = result.request();
|
|
|
|
auto output = (Return *) buf.ptr;
|
2016-01-17 21:36:39 +00:00
|
|
|
|
2016-05-05 18:33:54 +00:00
|
|
|
if (trivial_broadcast) {
|
2016-02-11 09:47:11 +00:00
|
|
|
/* Call the function */
|
2016-08-29 01:41:05 +00:00
|
|
|
for (size_t i = 0; i < size; ++i) {
|
2016-02-11 09:47:11 +00:00
|
|
|
output[i] = f((buffers[Index].size == 1
|
2016-01-17 21:36:39 +00:00
|
|
|
? *((Args *) buffers[Index].ptr)
|
|
|
|
: ((Args *) buffers[Index].ptr)[i])...);
|
2016-02-11 09:47:11 +00:00
|
|
|
}
|
2016-02-20 11:17:17 +00:00
|
|
|
} else {
|
2016-02-11 09:47:11 +00:00
|
|
|
apply_broadcast<N, Index...>(buffers, buf, index);
|
|
|
|
}
|
2016-01-17 21:36:39 +00:00
|
|
|
|
|
|
|
return result;
|
2015-07-29 15:51:54 +00:00
|
|
|
}
|
2016-02-11 09:47:11 +00:00
|
|
|
|
|
|
|
template <size_t N, size_t... Index>
|
2016-02-20 11:17:17 +00:00
|
|
|
void apply_broadcast(const std::array<buffer_info, N> &buffers,
|
|
|
|
buffer_info &output, index_sequence<Index...>) {
|
2016-02-11 09:47:11 +00:00
|
|
|
using input_iterator = multi_array_iterator<N>;
|
|
|
|
using output_iterator = array_iterator<Return>;
|
|
|
|
|
|
|
|
input_iterator input_iter(buffers, output.shape);
|
|
|
|
output_iterator output_end = array_end<Return>(output);
|
|
|
|
|
2016-02-20 11:17:17 +00:00
|
|
|
for (output_iterator iter = array_begin<Return>(output);
|
|
|
|
iter != output_end; ++iter, ++input_iter) {
|
2016-02-11 09:47:11 +00:00
|
|
|
*iter = f((input_iter.template data<Index, Args>())...);
|
|
|
|
}
|
|
|
|
}
|
2015-07-29 15:51:54 +00:00
|
|
|
};
|
|
|
|
|
2016-05-19 14:02:09 +00:00
|
|
|
/* Signature name shown for array_t<T, Flags> parameters/returns:
   "numpy.ndarray[<element type name>]". */
template <typename T, int Flags>
struct handle_type_name<array_t<T, Flags>> {
    static PYBIND11_DESCR name() {
        return _("numpy.ndarray[")
             + make_caster<T>::name()
             + _("]");
    }
};
|
|
|
|
|
2015-07-29 15:51:54 +00:00
|
|
|
NAMESPACE_END(detail)
|
2015-07-26 14:33:49 +00:00
|
|
|
|
2017-02-17 11:56:41 +00:00
|
|
|
template <typename Func, typename Return, typename... Args>
|
2016-12-14 01:06:41 +00:00
|
|
|
detail::vectorize_helper<Func, Return, Args...>
|
2017-02-17 11:56:41 +00:00
|
|
|
vectorize(const Func &f, Return (*) (Args ...)) {
|
2015-07-29 15:51:54 +00:00
|
|
|
return detail::vectorize_helper<Func, Return, Args...>(f);
|
2015-07-26 14:33:49 +00:00
|
|
|
}
|
|
|
|
|
2017-02-17 11:56:41 +00:00
|
|
|
template <typename Return, typename... Args>
|
|
|
|
detail::vectorize_helper<Return (*) (Args ...), Return, Args...>
|
|
|
|
vectorize(Return (*f) (Args ...)) {
|
2015-07-29 15:51:54 +00:00
|
|
|
return vectorize<Return (*) (Args ...), Return, Args...>(f, f);
|
2015-07-26 14:33:49 +00:00
|
|
|
}
|
|
|
|
|
2017-02-17 11:56:41 +00:00
|
|
|
template <typename Func, typename FuncType = typename detail::remove_class<decltype(&std::remove_reference<Func>::type::operator())>::type>
|
2016-09-22 21:44:11 +00:00
|
|
|
auto vectorize(Func &&f) -> decltype(
|
2017-02-17 11:56:41 +00:00
|
|
|
vectorize(std::forward<Func>(f), (FuncType *) nullptr)) {
|
|
|
|
return vectorize(std::forward<Func>(f), (FuncType *) nullptr);
|
2015-07-26 14:33:49 +00:00
|
|
|
}
|
|
|
|
|
2015-10-15 16:13:33 +00:00
|
|
|
NAMESPACE_END(pybind11)
|
2015-07-26 14:33:49 +00:00
|
|
|
|
|
|
|
#if defined(_MSC_VER)
|
|
|
|
#pragma warning(pop)
|
|
|
|
#endif
|