Cleanup: move numpy API bindings out of py::array

This commit is contained in:
Ivan Smirnov 2016-07-20 00:54:57 +01:00
parent afb07e7e92
commit 05cb58ade2
1 changed files with 90 additions and 88 deletions

View File

@ -28,87 +28,94 @@ NAMESPACE_BEGIN(pybind11)
namespace detail {
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
template <typename type> struct is_pod_struct;
struct npy_api {
enum constants {
NPY_C_CONTIGUOUS_ = 0x0001,
NPY_F_CONTIGUOUS_ = 0x0002,
NPY_ARRAY_FORCECAST_ = 0x0010,
NPY_ENSURE_ARRAY_ = 0x0040,
NPY_BOOL_ = 0,
NPY_BYTE_, NPY_UBYTE_,
NPY_SHORT_, NPY_USHORT_,
NPY_INT_, NPY_UINT_,
NPY_LONG_, NPY_ULONG_,
NPY_LONGLONG_, NPY_ULONGLONG_,
NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
NPY_OBJECT_ = 17,
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
};
static npy_api& get() {
static npy_api api = lookup();
return api;
}
bool PyArray_Check_(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type_); }
PyObject *(*PyArray_DescrFromType_)(int);
PyObject *(*PyArray_NewFromDescr_)
(PyTypeObject *, PyObject *, int, Py_intptr_t *,
Py_intptr_t *, void *, int, PyObject *);
PyObject *(*PyArray_DescrNewFromType_)(int);
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
PyTypeObject *PyArray_Type_;
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
Py_ssize_t *, PyObject **, PyObject *);
private:
enum functions {
API_PyArray_Type = 2,
API_PyArray_DescrFromType = 45,
API_PyArray_FromAny = 69,
API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94,
API_PyArray_DescrNewFromType = 9,
API_PyArray_DescrConverter = 174,
API_PyArray_EquivTypes = 182,
API_PyArray_GetArrayParamsFromObject = 278,
};
static npy_api lookup() {
module m = module::import("numpy.core.multiarray");
object c = (object) m.attr("_ARRAY_API");
#if PY_MAJOR_VERSION >= 3
void **api_ptr = (void **) (c ? PyCapsule_GetPointer(c.ptr(), NULL) : nullptr);
#else
void **api_ptr = (void **) (c ? PyCObject_AsVoidPtr(c.ptr()) : nullptr);
#endif
npy_api api;
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
DECL_NPY_API(PyArray_Type);
DECL_NPY_API(PyArray_DescrFromType);
DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_NewCopy);
DECL_NPY_API(PyArray_NewFromDescr);
DECL_NPY_API(PyArray_DescrNewFromType);
DECL_NPY_API(PyArray_DescrConverter);
DECL_NPY_API(PyArray_EquivTypes);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
#undef DECL_NPY_API
return api;
}
};
}
class array : public buffer {
public:
struct API {
enum Entries {
API_PyArray_Type = 2,
API_PyArray_DescrFromType = 45,
API_PyArray_FromAny = 69,
API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94,
API_PyArray_DescrNewFromType = 9,
API_PyArray_DescrConverter = 174,
API_PyArray_EquivTypes = 182,
API_PyArray_GetArrayParamsFromObject = 278,
NPY_C_CONTIGUOUS_ = 0x0001,
NPY_F_CONTIGUOUS_ = 0x0002,
NPY_ARRAY_FORCECAST_ = 0x0010,
NPY_ENSURE_ARRAY_ = 0x0040,
NPY_BOOL_ = 0,
NPY_BYTE_, NPY_UBYTE_,
NPY_SHORT_, NPY_USHORT_,
NPY_INT_, NPY_UINT_,
NPY_LONG_, NPY_ULONG_,
NPY_LONGLONG_, NPY_ULONGLONG_,
NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
NPY_OBJECT_ = 17,
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
};
static API lookup() {
module m = module::import("numpy.core.multiarray");
object c = (object) m.attr("_ARRAY_API");
#if PY_MAJOR_VERSION >= 3
void **api_ptr = (void **) (c ? PyCapsule_GetPointer(c.ptr(), NULL) : nullptr);
#else
void **api_ptr = (void **) (c ? PyCObject_AsVoidPtr(c.ptr()) : nullptr);
#endif
API api;
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
DECL_NPY_API(PyArray_Type);
DECL_NPY_API(PyArray_DescrFromType);
DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_NewCopy);
DECL_NPY_API(PyArray_NewFromDescr);
DECL_NPY_API(PyArray_DescrNewFromType);
DECL_NPY_API(PyArray_DescrConverter);
DECL_NPY_API(PyArray_EquivTypes);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
#undef DECL_NPY_API
return api;
}
bool PyArray_Check_(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type_); }
PyObject *(*PyArray_DescrFromType_)(int);
PyObject *(*PyArray_NewFromDescr_)
(PyTypeObject *, PyObject *, int, Py_intptr_t *,
Py_intptr_t *, void *, int, PyObject *);
PyObject *(*PyArray_DescrNewFromType_)(int);
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
PyTypeObject *PyArray_Type_;
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
Py_ssize_t *, PyObject **, PyObject *);
};
PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check_)
PYBIND11_OBJECT_DEFAULT(array, buffer, detail::npy_api::get().PyArray_Check_)
enum {
c_style = API::NPY_C_CONTIGUOUS_,
f_style = API::NPY_F_CONTIGUOUS_,
forcecast = API::NPY_ARRAY_FORCECAST_
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
};
template <typename Type> array(size_t size, const Type *ptr) {
API& api = lookup_api();
auto& api = detail::npy_api::get();
PyObject *descr = detail::npy_format_descriptor<Type>::dtype().release().ptr();
Py_intptr_t shape = (Py_intptr_t) size;
object tmp = object(api.PyArray_NewFromDescr_(
@ -121,7 +128,7 @@ public:
}
array(const buffer_info &info) {
auto& api = lookup_api();
auto& api = detail::npy_api::get();
// _dtype_from_pep3118 returns dtypes with padding fields in, so we need to strip them
auto numpy_internal = module::import("numpy.core._internal");
@ -139,11 +146,6 @@ public:
}
protected:
static API &lookup_api() {
static API api = API::lookup();
return api;
}
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
static object strip_padding_fields(object dtype) {
@ -183,7 +185,7 @@ protected:
args["itemsize"] = dtype.attr("itemsize").cast<int_>();
PyObject *descr = nullptr;
if (!lookup_api().PyArray_DescrConverter_(args.release().ptr(), &descr) || !descr)
if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &descr) || !descr)
pybind11_fail("NumPy: failed to create structured dtype");
return object(descr, false);
}
@ -198,10 +200,10 @@ public:
static PyObject *ensure(PyObject *ptr) {
if (ptr == nullptr)
return nullptr;
API &api = lookup_api();
auto& api = detail::npy_api::get();
PyObject *descr = detail::npy_format_descriptor<T>::dtype().release().ptr();
PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0,
API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
if (!result)
PyErr_Clear();
Py_DECREF(ptr);
@ -246,12 +248,12 @@ struct is_pod_struct {
template <typename T> struct npy_format_descriptor<T, typename std::enable_if<std::is_integral<T>::value>::type> {
private:
constexpr static const int values[8] = {
array::API::NPY_BYTE_, array::API::NPY_UBYTE_, array::API::NPY_SHORT_, array::API::NPY_USHORT_,
array::API::NPY_INT_, array::API::NPY_UINT_, array::API::NPY_LONGLONG_, array::API::NPY_ULONGLONG_ };
npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_,
npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_ };
public:
enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)] };
static object dtype() {
if (auto ptr = array::lookup_api().PyArray_DescrFromType_(value))
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
return object(ptr, true);
pybind11_fail("Unsupported buffer format!");
}
@ -264,9 +266,9 @@ template <typename T> constexpr const int npy_format_descriptor<
T, typename std::enable_if<std::is_integral<T>::value>::type>::values[8];
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
enum { value = array::API::NumPyName }; \
enum { value = npy_api::NumPyName }; \
static object dtype() { \
if (auto ptr = array::lookup_api().PyArray_DescrFromType_(value)) \
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) \
return object(ptr, true); \
pybind11_fail("Unsupported buffer format!"); \
} \
@ -281,7 +283,7 @@ DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
#define DECL_CHAR_FMT \
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
static object dtype() { \
auto& api = array::lookup_api(); \
auto& api = npy_api::get(); \
PyObject *descr = nullptr; \
PYBIND11_DESCR fmt = _("S") + _<N>(); \
pybind11::str py_fmt(fmt.text()); \
@ -319,7 +321,7 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
}
static void register_dtype(std::initializer_list<field_descriptor> fields) {
auto& api = array::lookup_api();
auto& api = npy_api::get();
auto args = dict();
list names { }, offsets { }, formats { };
for (auto field : fields) {