Cleanup: move numpy API bindings out of py::array

This commit is contained in:
Ivan Smirnov 2016-07-20 00:54:57 +01:00
parent afb07e7e92
commit 05cb58ade2
1 changed files with 90 additions and 88 deletions

View File

@ -28,22 +28,9 @@ NAMESPACE_BEGIN(pybind11)
namespace detail { namespace detail {
template <typename type, typename SFINAE = void> struct npy_format_descriptor { }; template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
template <typename type> struct is_pod_struct; template <typename type> struct is_pod_struct;
}
class array : public buffer {
public:
struct API {
enum Entries {
API_PyArray_Type = 2,
API_PyArray_DescrFromType = 45,
API_PyArray_FromAny = 69,
API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94,
API_PyArray_DescrNewFromType = 9,
API_PyArray_DescrConverter = 174,
API_PyArray_EquivTypes = 182,
API_PyArray_GetArrayParamsFromObject = 278,
struct npy_api {
enum constants {
NPY_C_CONTIGUOUS_ = 0x0001, NPY_C_CONTIGUOUS_ = 0x0001,
NPY_F_CONTIGUOUS_ = 0x0002, NPY_F_CONTIGUOUS_ = 0x0002,
NPY_ARRAY_FORCECAST_ = 0x0010, NPY_ARRAY_FORCECAST_ = 0x0010,
@ -60,26 +47,8 @@ public:
NPY_STRING_, NPY_UNICODE_, NPY_VOID_ NPY_STRING_, NPY_UNICODE_, NPY_VOID_
}; };
static API lookup() { static npy_api& get() {
module m = module::import("numpy.core.multiarray"); static npy_api api = lookup();
object c = (object) m.attr("_ARRAY_API");
#if PY_MAJOR_VERSION >= 3
void **api_ptr = (void **) (c ? PyCapsule_GetPointer(c.ptr(), NULL) : nullptr);
#else
void **api_ptr = (void **) (c ? PyCObject_AsVoidPtr(c.ptr()) : nullptr);
#endif
API api;
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
DECL_NPY_API(PyArray_Type);
DECL_NPY_API(PyArray_DescrFromType);
DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_NewCopy);
DECL_NPY_API(PyArray_NewFromDescr);
DECL_NPY_API(PyArray_DescrNewFromType);
DECL_NPY_API(PyArray_DescrConverter);
DECL_NPY_API(PyArray_EquivTypes);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
#undef DECL_NPY_API
return api; return api;
} }
@ -97,18 +66,56 @@ public:
bool (*PyArray_EquivTypes_) (PyObject *, PyObject *); bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *, int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
Py_ssize_t *, PyObject **, PyObject *); Py_ssize_t *, PyObject **, PyObject *);
private:
enum functions {
API_PyArray_Type = 2,
API_PyArray_DescrFromType = 45,
API_PyArray_FromAny = 69,
API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94,
API_PyArray_DescrNewFromType = 9,
API_PyArray_DescrConverter = 174,
API_PyArray_EquivTypes = 182,
API_PyArray_GetArrayParamsFromObject = 278,
}; };
PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check_) static npy_api lookup() {
module m = module::import("numpy.core.multiarray");
object c = (object) m.attr("_ARRAY_API");
#if PY_MAJOR_VERSION >= 3
void **api_ptr = (void **) (c ? PyCapsule_GetPointer(c.ptr(), NULL) : nullptr);
#else
void **api_ptr = (void **) (c ? PyCObject_AsVoidPtr(c.ptr()) : nullptr);
#endif
npy_api api;
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
DECL_NPY_API(PyArray_Type);
DECL_NPY_API(PyArray_DescrFromType);
DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_NewCopy);
DECL_NPY_API(PyArray_NewFromDescr);
DECL_NPY_API(PyArray_DescrNewFromType);
DECL_NPY_API(PyArray_DescrConverter);
DECL_NPY_API(PyArray_EquivTypes);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
#undef DECL_NPY_API
return api;
}
};
}
class array : public buffer {
public:
PYBIND11_OBJECT_DEFAULT(array, buffer, detail::npy_api::get().PyArray_Check_)
enum { enum {
c_style = API::NPY_C_CONTIGUOUS_, c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
f_style = API::NPY_F_CONTIGUOUS_, f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
forcecast = API::NPY_ARRAY_FORCECAST_ forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
}; };
template <typename Type> array(size_t size, const Type *ptr) { template <typename Type> array(size_t size, const Type *ptr) {
API& api = lookup_api(); auto& api = detail::npy_api::get();
PyObject *descr = detail::npy_format_descriptor<Type>::dtype().release().ptr(); PyObject *descr = detail::npy_format_descriptor<Type>::dtype().release().ptr();
Py_intptr_t shape = (Py_intptr_t) size; Py_intptr_t shape = (Py_intptr_t) size;
object tmp = object(api.PyArray_NewFromDescr_( object tmp = object(api.PyArray_NewFromDescr_(
@ -121,7 +128,7 @@ public:
} }
array(const buffer_info &info) { array(const buffer_info &info) {
auto& api = lookup_api(); auto& api = detail::npy_api::get();
// _dtype_from_pep3118 returns dtypes with padding fields in, so we need to strip them // _dtype_from_pep3118 returns dtypes with padding fields in, so we need to strip them
auto numpy_internal = module::import("numpy.core._internal"); auto numpy_internal = module::import("numpy.core._internal");
@ -139,11 +146,6 @@ public:
} }
protected: protected:
static API &lookup_api() {
static API api = API::lookup();
return api;
}
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor; template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
static object strip_padding_fields(object dtype) { static object strip_padding_fields(object dtype) {
@ -183,7 +185,7 @@ protected:
args["itemsize"] = dtype.attr("itemsize").cast<int_>(); args["itemsize"] = dtype.attr("itemsize").cast<int_>();
PyObject *descr = nullptr; PyObject *descr = nullptr;
if (!lookup_api().PyArray_DescrConverter_(args.release().ptr(), &descr) || !descr) if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &descr) || !descr)
pybind11_fail("NumPy: failed to create structured dtype"); pybind11_fail("NumPy: failed to create structured dtype");
return object(descr, false); return object(descr, false);
} }
@ -198,10 +200,10 @@ public:
static PyObject *ensure(PyObject *ptr) { static PyObject *ensure(PyObject *ptr) {
if (ptr == nullptr) if (ptr == nullptr)
return nullptr; return nullptr;
API &api = lookup_api(); auto& api = detail::npy_api::get();
PyObject *descr = detail::npy_format_descriptor<T>::dtype().release().ptr(); PyObject *descr = detail::npy_format_descriptor<T>::dtype().release().ptr();
PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0, PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0,
API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr); detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
if (!result) if (!result)
PyErr_Clear(); PyErr_Clear();
Py_DECREF(ptr); Py_DECREF(ptr);
@ -246,12 +248,12 @@ struct is_pod_struct {
template <typename T> struct npy_format_descriptor<T, typename std::enable_if<std::is_integral<T>::value>::type> { template <typename T> struct npy_format_descriptor<T, typename std::enable_if<std::is_integral<T>::value>::type> {
private: private:
constexpr static const int values[8] = { constexpr static const int values[8] = {
array::API::NPY_BYTE_, array::API::NPY_UBYTE_, array::API::NPY_SHORT_, array::API::NPY_USHORT_, npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_,
array::API::NPY_INT_, array::API::NPY_UINT_, array::API::NPY_LONGLONG_, array::API::NPY_ULONGLONG_ }; npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_ };
public: public:
enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)] }; enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)] };
static object dtype() { static object dtype() {
if (auto ptr = array::lookup_api().PyArray_DescrFromType_(value)) if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
return object(ptr, true); return object(ptr, true);
pybind11_fail("Unsupported buffer format!"); pybind11_fail("Unsupported buffer format!");
} }
@ -264,9 +266,9 @@ template <typename T> constexpr const int npy_format_descriptor<
T, typename std::enable_if<std::is_integral<T>::value>::type>::values[8]; T, typename std::enable_if<std::is_integral<T>::value>::type>::values[8];
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \ #define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
enum { value = array::API::NumPyName }; \ enum { value = npy_api::NumPyName }; \
static object dtype() { \ static object dtype() { \
if (auto ptr = array::lookup_api().PyArray_DescrFromType_(value)) \ if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) \
return object(ptr, true); \ return object(ptr, true); \
pybind11_fail("Unsupported buffer format!"); \ pybind11_fail("Unsupported buffer format!"); \
} \ } \
@ -281,7 +283,7 @@ DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
#define DECL_CHAR_FMT \ #define DECL_CHAR_FMT \
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \ static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
static object dtype() { \ static object dtype() { \
auto& api = array::lookup_api(); \ auto& api = npy_api::get(); \
PyObject *descr = nullptr; \ PyObject *descr = nullptr; \
PYBIND11_DESCR fmt = _("S") + _<N>(); \ PYBIND11_DESCR fmt = _("S") + _<N>(); \
pybind11::str py_fmt(fmt.text()); \ pybind11::str py_fmt(fmt.text()); \
@ -319,7 +321,7 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
} }
static void register_dtype(std::initializer_list<field_descriptor> fields) { static void register_dtype(std::initializer_list<field_descriptor> fields) {
auto& api = array::lookup_api(); auto& api = npy_api::get();
auto args = dict(); auto args = dict();
list names { }, offsets { }, formats { }; list names { }, offsets { }, formats { };
for (auto field : fields) { for (auto field : fields) {