mirror of https://github.com/pybind/pybind11.git
Cleanup: move numpy API bindings out of py::array
This commit is contained in:
parent
afb07e7e92
commit
05cb58ade2
|
@ -28,22 +28,9 @@ NAMESPACE_BEGIN(pybind11)
|
||||||
namespace detail {
|
namespace detail {
|
||||||
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
|
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
|
||||||
template <typename type> struct is_pod_struct;
|
template <typename type> struct is_pod_struct;
|
||||||
}
|
|
||||||
|
|
||||||
class array : public buffer {
|
|
||||||
public:
|
|
||||||
struct API {
|
|
||||||
enum Entries {
|
|
||||||
API_PyArray_Type = 2,
|
|
||||||
API_PyArray_DescrFromType = 45,
|
|
||||||
API_PyArray_FromAny = 69,
|
|
||||||
API_PyArray_NewCopy = 85,
|
|
||||||
API_PyArray_NewFromDescr = 94,
|
|
||||||
API_PyArray_DescrNewFromType = 9,
|
|
||||||
API_PyArray_DescrConverter = 174,
|
|
||||||
API_PyArray_EquivTypes = 182,
|
|
||||||
API_PyArray_GetArrayParamsFromObject = 278,
|
|
||||||
|
|
||||||
|
struct npy_api {
|
||||||
|
enum constants {
|
||||||
NPY_C_CONTIGUOUS_ = 0x0001,
|
NPY_C_CONTIGUOUS_ = 0x0001,
|
||||||
NPY_F_CONTIGUOUS_ = 0x0002,
|
NPY_F_CONTIGUOUS_ = 0x0002,
|
||||||
NPY_ARRAY_FORCECAST_ = 0x0010,
|
NPY_ARRAY_FORCECAST_ = 0x0010,
|
||||||
|
@ -60,26 +47,8 @@ public:
|
||||||
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
|
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
|
||||||
};
|
};
|
||||||
|
|
||||||
static API lookup() {
|
static npy_api& get() {
|
||||||
module m = module::import("numpy.core.multiarray");
|
static npy_api api = lookup();
|
||||||
object c = (object) m.attr("_ARRAY_API");
|
|
||||||
#if PY_MAJOR_VERSION >= 3
|
|
||||||
void **api_ptr = (void **) (c ? PyCapsule_GetPointer(c.ptr(), NULL) : nullptr);
|
|
||||||
#else
|
|
||||||
void **api_ptr = (void **) (c ? PyCObject_AsVoidPtr(c.ptr()) : nullptr);
|
|
||||||
#endif
|
|
||||||
API api;
|
|
||||||
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
|
|
||||||
DECL_NPY_API(PyArray_Type);
|
|
||||||
DECL_NPY_API(PyArray_DescrFromType);
|
|
||||||
DECL_NPY_API(PyArray_FromAny);
|
|
||||||
DECL_NPY_API(PyArray_NewCopy);
|
|
||||||
DECL_NPY_API(PyArray_NewFromDescr);
|
|
||||||
DECL_NPY_API(PyArray_DescrNewFromType);
|
|
||||||
DECL_NPY_API(PyArray_DescrConverter);
|
|
||||||
DECL_NPY_API(PyArray_EquivTypes);
|
|
||||||
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
|
|
||||||
#undef DECL_NPY_API
|
|
||||||
return api;
|
return api;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,18 +66,56 @@ public:
|
||||||
bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
|
bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
|
||||||
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
|
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
|
||||||
Py_ssize_t *, PyObject **, PyObject *);
|
Py_ssize_t *, PyObject **, PyObject *);
|
||||||
|
private:
|
||||||
|
enum functions {
|
||||||
|
API_PyArray_Type = 2,
|
||||||
|
API_PyArray_DescrFromType = 45,
|
||||||
|
API_PyArray_FromAny = 69,
|
||||||
|
API_PyArray_NewCopy = 85,
|
||||||
|
API_PyArray_NewFromDescr = 94,
|
||||||
|
API_PyArray_DescrNewFromType = 9,
|
||||||
|
API_PyArray_DescrConverter = 174,
|
||||||
|
API_PyArray_EquivTypes = 182,
|
||||||
|
API_PyArray_GetArrayParamsFromObject = 278,
|
||||||
};
|
};
|
||||||
|
|
||||||
PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check_)
|
static npy_api lookup() {
|
||||||
|
module m = module::import("numpy.core.multiarray");
|
||||||
|
object c = (object) m.attr("_ARRAY_API");
|
||||||
|
#if PY_MAJOR_VERSION >= 3
|
||||||
|
void **api_ptr = (void **) (c ? PyCapsule_GetPointer(c.ptr(), NULL) : nullptr);
|
||||||
|
#else
|
||||||
|
void **api_ptr = (void **) (c ? PyCObject_AsVoidPtr(c.ptr()) : nullptr);
|
||||||
|
#endif
|
||||||
|
npy_api api;
|
||||||
|
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
|
||||||
|
DECL_NPY_API(PyArray_Type);
|
||||||
|
DECL_NPY_API(PyArray_DescrFromType);
|
||||||
|
DECL_NPY_API(PyArray_FromAny);
|
||||||
|
DECL_NPY_API(PyArray_NewCopy);
|
||||||
|
DECL_NPY_API(PyArray_NewFromDescr);
|
||||||
|
DECL_NPY_API(PyArray_DescrNewFromType);
|
||||||
|
DECL_NPY_API(PyArray_DescrConverter);
|
||||||
|
DECL_NPY_API(PyArray_EquivTypes);
|
||||||
|
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
|
||||||
|
#undef DECL_NPY_API
|
||||||
|
return api;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
class array : public buffer {
|
||||||
|
public:
|
||||||
|
PYBIND11_OBJECT_DEFAULT(array, buffer, detail::npy_api::get().PyArray_Check_)
|
||||||
|
|
||||||
enum {
|
enum {
|
||||||
c_style = API::NPY_C_CONTIGUOUS_,
|
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
|
||||||
f_style = API::NPY_F_CONTIGUOUS_,
|
f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
|
||||||
forcecast = API::NPY_ARRAY_FORCECAST_
|
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename Type> array(size_t size, const Type *ptr) {
|
template <typename Type> array(size_t size, const Type *ptr) {
|
||||||
API& api = lookup_api();
|
auto& api = detail::npy_api::get();
|
||||||
PyObject *descr = detail::npy_format_descriptor<Type>::dtype().release().ptr();
|
PyObject *descr = detail::npy_format_descriptor<Type>::dtype().release().ptr();
|
||||||
Py_intptr_t shape = (Py_intptr_t) size;
|
Py_intptr_t shape = (Py_intptr_t) size;
|
||||||
object tmp = object(api.PyArray_NewFromDescr_(
|
object tmp = object(api.PyArray_NewFromDescr_(
|
||||||
|
@ -121,7 +128,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
array(const buffer_info &info) {
|
array(const buffer_info &info) {
|
||||||
auto& api = lookup_api();
|
auto& api = detail::npy_api::get();
|
||||||
|
|
||||||
// _dtype_from_pep3118 returns dtypes with padding fields in, so we need to strip them
|
// _dtype_from_pep3118 returns dtypes with padding fields in, so we need to strip them
|
||||||
auto numpy_internal = module::import("numpy.core._internal");
|
auto numpy_internal = module::import("numpy.core._internal");
|
||||||
|
@ -139,11 +146,6 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
static API &lookup_api() {
|
|
||||||
static API api = API::lookup();
|
|
||||||
return api;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
|
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
|
||||||
|
|
||||||
static object strip_padding_fields(object dtype) {
|
static object strip_padding_fields(object dtype) {
|
||||||
|
@ -183,7 +185,7 @@ protected:
|
||||||
args["itemsize"] = dtype.attr("itemsize").cast<int_>();
|
args["itemsize"] = dtype.attr("itemsize").cast<int_>();
|
||||||
|
|
||||||
PyObject *descr = nullptr;
|
PyObject *descr = nullptr;
|
||||||
if (!lookup_api().PyArray_DescrConverter_(args.release().ptr(), &descr) || !descr)
|
if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &descr) || !descr)
|
||||||
pybind11_fail("NumPy: failed to create structured dtype");
|
pybind11_fail("NumPy: failed to create structured dtype");
|
||||||
return object(descr, false);
|
return object(descr, false);
|
||||||
}
|
}
|
||||||
|
@ -198,10 +200,10 @@ public:
|
||||||
static PyObject *ensure(PyObject *ptr) {
|
static PyObject *ensure(PyObject *ptr) {
|
||||||
if (ptr == nullptr)
|
if (ptr == nullptr)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
API &api = lookup_api();
|
auto& api = detail::npy_api::get();
|
||||||
PyObject *descr = detail::npy_format_descriptor<T>::dtype().release().ptr();
|
PyObject *descr = detail::npy_format_descriptor<T>::dtype().release().ptr();
|
||||||
PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0,
|
PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0,
|
||||||
API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
|
detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
|
||||||
if (!result)
|
if (!result)
|
||||||
PyErr_Clear();
|
PyErr_Clear();
|
||||||
Py_DECREF(ptr);
|
Py_DECREF(ptr);
|
||||||
|
@ -246,12 +248,12 @@ struct is_pod_struct {
|
||||||
template <typename T> struct npy_format_descriptor<T, typename std::enable_if<std::is_integral<T>::value>::type> {
|
template <typename T> struct npy_format_descriptor<T, typename std::enable_if<std::is_integral<T>::value>::type> {
|
||||||
private:
|
private:
|
||||||
constexpr static const int values[8] = {
|
constexpr static const int values[8] = {
|
||||||
array::API::NPY_BYTE_, array::API::NPY_UBYTE_, array::API::NPY_SHORT_, array::API::NPY_USHORT_,
|
npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_,
|
||||||
array::API::NPY_INT_, array::API::NPY_UINT_, array::API::NPY_LONGLONG_, array::API::NPY_ULONGLONG_ };
|
npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_ };
|
||||||
public:
|
public:
|
||||||
enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)] };
|
enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)] };
|
||||||
static object dtype() {
|
static object dtype() {
|
||||||
if (auto ptr = array::lookup_api().PyArray_DescrFromType_(value))
|
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
|
||||||
return object(ptr, true);
|
return object(ptr, true);
|
||||||
pybind11_fail("Unsupported buffer format!");
|
pybind11_fail("Unsupported buffer format!");
|
||||||
}
|
}
|
||||||
|
@ -264,9 +266,9 @@ template <typename T> constexpr const int npy_format_descriptor<
|
||||||
T, typename std::enable_if<std::is_integral<T>::value>::type>::values[8];
|
T, typename std::enable_if<std::is_integral<T>::value>::type>::values[8];
|
||||||
|
|
||||||
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
|
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
|
||||||
enum { value = array::API::NumPyName }; \
|
enum { value = npy_api::NumPyName }; \
|
||||||
static object dtype() { \
|
static object dtype() { \
|
||||||
if (auto ptr = array::lookup_api().PyArray_DescrFromType_(value)) \
|
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) \
|
||||||
return object(ptr, true); \
|
return object(ptr, true); \
|
||||||
pybind11_fail("Unsupported buffer format!"); \
|
pybind11_fail("Unsupported buffer format!"); \
|
||||||
} \
|
} \
|
||||||
|
@ -281,7 +283,7 @@ DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
|
||||||
#define DECL_CHAR_FMT \
|
#define DECL_CHAR_FMT \
|
||||||
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
|
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
|
||||||
static object dtype() { \
|
static object dtype() { \
|
||||||
auto& api = array::lookup_api(); \
|
auto& api = npy_api::get(); \
|
||||||
PyObject *descr = nullptr; \
|
PyObject *descr = nullptr; \
|
||||||
PYBIND11_DESCR fmt = _("S") + _<N>(); \
|
PYBIND11_DESCR fmt = _("S") + _<N>(); \
|
||||||
pybind11::str py_fmt(fmt.text()); \
|
pybind11::str py_fmt(fmt.text()); \
|
||||||
|
@ -319,7 +321,7 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
|
||||||
}
|
}
|
||||||
|
|
||||||
static void register_dtype(std::initializer_list<field_descriptor> fields) {
|
static void register_dtype(std::initializer_list<field_descriptor> fields) {
|
||||||
auto& api = array::lookup_api();
|
auto& api = npy_api::get();
|
||||||
auto args = dict();
|
auto args = dict();
|
||||||
list names { }, offsets { }, formats { };
|
list names { }, offsets { }, formats { };
|
||||||
for (auto field : fields) {
|
for (auto field : fields) {
|
||||||
|
|
Loading…
Reference in New Issue