mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-11 08:03:55 +00:00
Cleanup: move numpy API bindings out of py::array
This commit is contained in:
parent
afb07e7e92
commit
05cb58ade2
@ -28,87 +28,94 @@ NAMESPACE_BEGIN(pybind11)
|
||||
namespace detail {
|
||||
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
|
||||
template <typename type> struct is_pod_struct;
|
||||
|
||||
struct npy_api {
|
||||
enum constants {
|
||||
NPY_C_CONTIGUOUS_ = 0x0001,
|
||||
NPY_F_CONTIGUOUS_ = 0x0002,
|
||||
NPY_ARRAY_FORCECAST_ = 0x0010,
|
||||
NPY_ENSURE_ARRAY_ = 0x0040,
|
||||
NPY_BOOL_ = 0,
|
||||
NPY_BYTE_, NPY_UBYTE_,
|
||||
NPY_SHORT_, NPY_USHORT_,
|
||||
NPY_INT_, NPY_UINT_,
|
||||
NPY_LONG_, NPY_ULONG_,
|
||||
NPY_LONGLONG_, NPY_ULONGLONG_,
|
||||
NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
|
||||
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
|
||||
NPY_OBJECT_ = 17,
|
||||
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
|
||||
};
|
||||
|
||||
static npy_api& get() {
|
||||
static npy_api api = lookup();
|
||||
return api;
|
||||
}
|
||||
|
||||
bool PyArray_Check_(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type_); }
|
||||
|
||||
PyObject *(*PyArray_DescrFromType_)(int);
|
||||
PyObject *(*PyArray_NewFromDescr_)
|
||||
(PyTypeObject *, PyObject *, int, Py_intptr_t *,
|
||||
Py_intptr_t *, void *, int, PyObject *);
|
||||
PyObject *(*PyArray_DescrNewFromType_)(int);
|
||||
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
|
||||
PyTypeObject *PyArray_Type_;
|
||||
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
|
||||
int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
|
||||
bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
|
||||
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
|
||||
Py_ssize_t *, PyObject **, PyObject *);
|
||||
private:
|
||||
enum functions {
|
||||
API_PyArray_Type = 2,
|
||||
API_PyArray_DescrFromType = 45,
|
||||
API_PyArray_FromAny = 69,
|
||||
API_PyArray_NewCopy = 85,
|
||||
API_PyArray_NewFromDescr = 94,
|
||||
API_PyArray_DescrNewFromType = 9,
|
||||
API_PyArray_DescrConverter = 174,
|
||||
API_PyArray_EquivTypes = 182,
|
||||
API_PyArray_GetArrayParamsFromObject = 278,
|
||||
};
|
||||
|
||||
static npy_api lookup() {
|
||||
module m = module::import("numpy.core.multiarray");
|
||||
object c = (object) m.attr("_ARRAY_API");
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
void **api_ptr = (void **) (c ? PyCapsule_GetPointer(c.ptr(), NULL) : nullptr);
|
||||
#else
|
||||
void **api_ptr = (void **) (c ? PyCObject_AsVoidPtr(c.ptr()) : nullptr);
|
||||
#endif
|
||||
npy_api api;
|
||||
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
|
||||
DECL_NPY_API(PyArray_Type);
|
||||
DECL_NPY_API(PyArray_DescrFromType);
|
||||
DECL_NPY_API(PyArray_FromAny);
|
||||
DECL_NPY_API(PyArray_NewCopy);
|
||||
DECL_NPY_API(PyArray_NewFromDescr);
|
||||
DECL_NPY_API(PyArray_DescrNewFromType);
|
||||
DECL_NPY_API(PyArray_DescrConverter);
|
||||
DECL_NPY_API(PyArray_EquivTypes);
|
||||
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
|
||||
#undef DECL_NPY_API
|
||||
return api;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
class array : public buffer {
|
||||
public:
|
||||
struct API {
|
||||
enum Entries {
|
||||
API_PyArray_Type = 2,
|
||||
API_PyArray_DescrFromType = 45,
|
||||
API_PyArray_FromAny = 69,
|
||||
API_PyArray_NewCopy = 85,
|
||||
API_PyArray_NewFromDescr = 94,
|
||||
API_PyArray_DescrNewFromType = 9,
|
||||
API_PyArray_DescrConverter = 174,
|
||||
API_PyArray_EquivTypes = 182,
|
||||
API_PyArray_GetArrayParamsFromObject = 278,
|
||||
|
||||
NPY_C_CONTIGUOUS_ = 0x0001,
|
||||
NPY_F_CONTIGUOUS_ = 0x0002,
|
||||
NPY_ARRAY_FORCECAST_ = 0x0010,
|
||||
NPY_ENSURE_ARRAY_ = 0x0040,
|
||||
NPY_BOOL_ = 0,
|
||||
NPY_BYTE_, NPY_UBYTE_,
|
||||
NPY_SHORT_, NPY_USHORT_,
|
||||
NPY_INT_, NPY_UINT_,
|
||||
NPY_LONG_, NPY_ULONG_,
|
||||
NPY_LONGLONG_, NPY_ULONGLONG_,
|
||||
NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
|
||||
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
|
||||
NPY_OBJECT_ = 17,
|
||||
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
|
||||
};
|
||||
|
||||
static API lookup() {
|
||||
module m = module::import("numpy.core.multiarray");
|
||||
object c = (object) m.attr("_ARRAY_API");
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
void **api_ptr = (void **) (c ? PyCapsule_GetPointer(c.ptr(), NULL) : nullptr);
|
||||
#else
|
||||
void **api_ptr = (void **) (c ? PyCObject_AsVoidPtr(c.ptr()) : nullptr);
|
||||
#endif
|
||||
API api;
|
||||
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
|
||||
DECL_NPY_API(PyArray_Type);
|
||||
DECL_NPY_API(PyArray_DescrFromType);
|
||||
DECL_NPY_API(PyArray_FromAny);
|
||||
DECL_NPY_API(PyArray_NewCopy);
|
||||
DECL_NPY_API(PyArray_NewFromDescr);
|
||||
DECL_NPY_API(PyArray_DescrNewFromType);
|
||||
DECL_NPY_API(PyArray_DescrConverter);
|
||||
DECL_NPY_API(PyArray_EquivTypes);
|
||||
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
|
||||
#undef DECL_NPY_API
|
||||
return api;
|
||||
}
|
||||
|
||||
bool PyArray_Check_(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type_); }
|
||||
|
||||
PyObject *(*PyArray_DescrFromType_)(int);
|
||||
PyObject *(*PyArray_NewFromDescr_)
|
||||
(PyTypeObject *, PyObject *, int, Py_intptr_t *,
|
||||
Py_intptr_t *, void *, int, PyObject *);
|
||||
PyObject *(*PyArray_DescrNewFromType_)(int);
|
||||
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
|
||||
PyTypeObject *PyArray_Type_;
|
||||
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
|
||||
int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
|
||||
bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
|
||||
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
|
||||
Py_ssize_t *, PyObject **, PyObject *);
|
||||
};
|
||||
|
||||
PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check_)
|
||||
PYBIND11_OBJECT_DEFAULT(array, buffer, detail::npy_api::get().PyArray_Check_)
|
||||
|
||||
enum {
|
||||
c_style = API::NPY_C_CONTIGUOUS_,
|
||||
f_style = API::NPY_F_CONTIGUOUS_,
|
||||
forcecast = API::NPY_ARRAY_FORCECAST_
|
||||
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
|
||||
f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
|
||||
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
|
||||
};
|
||||
|
||||
template <typename Type> array(size_t size, const Type *ptr) {
|
||||
API& api = lookup_api();
|
||||
auto& api = detail::npy_api::get();
|
||||
PyObject *descr = detail::npy_format_descriptor<Type>::dtype().release().ptr();
|
||||
Py_intptr_t shape = (Py_intptr_t) size;
|
||||
object tmp = object(api.PyArray_NewFromDescr_(
|
||||
@ -121,7 +128,7 @@ public:
|
||||
}
|
||||
|
||||
array(const buffer_info &info) {
|
||||
auto& api = lookup_api();
|
||||
auto& api = detail::npy_api::get();
|
||||
|
||||
// _dtype_from_pep3118 returns dtypes with padding fields in, so we need to strip them
|
||||
auto numpy_internal = module::import("numpy.core._internal");
|
||||
@ -139,11 +146,6 @@ public:
|
||||
}
|
||||
|
||||
protected:
|
||||
static API &lookup_api() {
|
||||
static API api = API::lookup();
|
||||
return api;
|
||||
}
|
||||
|
||||
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
|
||||
|
||||
static object strip_padding_fields(object dtype) {
|
||||
@ -183,7 +185,7 @@ protected:
|
||||
args["itemsize"] = dtype.attr("itemsize").cast<int_>();
|
||||
|
||||
PyObject *descr = nullptr;
|
||||
if (!lookup_api().PyArray_DescrConverter_(args.release().ptr(), &descr) || !descr)
|
||||
if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &descr) || !descr)
|
||||
pybind11_fail("NumPy: failed to create structured dtype");
|
||||
return object(descr, false);
|
||||
}
|
||||
@ -198,10 +200,10 @@ public:
|
||||
static PyObject *ensure(PyObject *ptr) {
|
||||
if (ptr == nullptr)
|
||||
return nullptr;
|
||||
API &api = lookup_api();
|
||||
auto& api = detail::npy_api::get();
|
||||
PyObject *descr = detail::npy_format_descriptor<T>::dtype().release().ptr();
|
||||
PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0,
|
||||
API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
|
||||
detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
|
||||
if (!result)
|
||||
PyErr_Clear();
|
||||
Py_DECREF(ptr);
|
||||
@ -246,12 +248,12 @@ struct is_pod_struct {
|
||||
template <typename T> struct npy_format_descriptor<T, typename std::enable_if<std::is_integral<T>::value>::type> {
|
||||
private:
|
||||
constexpr static const int values[8] = {
|
||||
array::API::NPY_BYTE_, array::API::NPY_UBYTE_, array::API::NPY_SHORT_, array::API::NPY_USHORT_,
|
||||
array::API::NPY_INT_, array::API::NPY_UINT_, array::API::NPY_LONGLONG_, array::API::NPY_ULONGLONG_ };
|
||||
npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_,
|
||||
npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_ };
|
||||
public:
|
||||
enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)] };
|
||||
static object dtype() {
|
||||
if (auto ptr = array::lookup_api().PyArray_DescrFromType_(value))
|
||||
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
|
||||
return object(ptr, true);
|
||||
pybind11_fail("Unsupported buffer format!");
|
||||
}
|
||||
@ -264,9 +266,9 @@ template <typename T> constexpr const int npy_format_descriptor<
|
||||
T, typename std::enable_if<std::is_integral<T>::value>::type>::values[8];
|
||||
|
||||
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
|
||||
enum { value = array::API::NumPyName }; \
|
||||
enum { value = npy_api::NumPyName }; \
|
||||
static object dtype() { \
|
||||
if (auto ptr = array::lookup_api().PyArray_DescrFromType_(value)) \
|
||||
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) \
|
||||
return object(ptr, true); \
|
||||
pybind11_fail("Unsupported buffer format!"); \
|
||||
} \
|
||||
@ -281,7 +283,7 @@ DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
|
||||
#define DECL_CHAR_FMT \
|
||||
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
|
||||
static object dtype() { \
|
||||
auto& api = array::lookup_api(); \
|
||||
auto& api = npy_api::get(); \
|
||||
PyObject *descr = nullptr; \
|
||||
PYBIND11_DESCR fmt = _("S") + _<N>(); \
|
||||
pybind11::str py_fmt(fmt.text()); \
|
||||
@ -319,7 +321,7 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
|
||||
}
|
||||
|
||||
static void register_dtype(std::initializer_list<field_descriptor> fields) {
|
||||
auto& api = array::lookup_api();
|
||||
auto& api = npy_api::get();
|
||||
auto args = dict();
|
||||
list names { }, offsets { }, formats { };
|
||||
for (auto field : fields) {
|
||||
|
Loading…
Reference in New Issue
Block a user