diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index d0cf1f6b7..77f7e6f72 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -28,87 +28,94 @@ NAMESPACE_BEGIN(pybind11) namespace detail { template struct npy_format_descriptor { }; template struct is_pod_struct; + +struct npy_api { + enum constants { + NPY_C_CONTIGUOUS_ = 0x0001, + NPY_F_CONTIGUOUS_ = 0x0002, + NPY_ARRAY_FORCECAST_ = 0x0010, + NPY_ENSURE_ARRAY_ = 0x0040, + NPY_BOOL_ = 0, + NPY_BYTE_, NPY_UBYTE_, + NPY_SHORT_, NPY_USHORT_, + NPY_INT_, NPY_UINT_, + NPY_LONG_, NPY_ULONG_, + NPY_LONGLONG_, NPY_ULONGLONG_, + NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_, + NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_, + NPY_OBJECT_ = 17, + NPY_STRING_, NPY_UNICODE_, NPY_VOID_ + }; + + static npy_api& get() { + static npy_api api = lookup(); + return api; + } + + bool PyArray_Check_(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type_); } + + PyObject *(*PyArray_DescrFromType_)(int); + PyObject *(*PyArray_NewFromDescr_) + (PyTypeObject *, PyObject *, int, Py_intptr_t *, + Py_intptr_t *, void *, int, PyObject *); + PyObject *(*PyArray_DescrNewFromType_)(int); + PyObject *(*PyArray_NewCopy_)(PyObject *, int); + PyTypeObject *PyArray_Type_; + PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *); + int (*PyArray_DescrConverter_) (PyObject *, PyObject **); + bool (*PyArray_EquivTypes_) (PyObject *, PyObject *); + int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *, + Py_ssize_t *, PyObject **, PyObject *); +private: + enum functions { + API_PyArray_Type = 2, + API_PyArray_DescrFromType = 45, + API_PyArray_FromAny = 69, + API_PyArray_NewCopy = 85, + API_PyArray_NewFromDescr = 94, + API_PyArray_DescrNewFromType = 9, + API_PyArray_DescrConverter = 174, + API_PyArray_EquivTypes = 182, + API_PyArray_GetArrayParamsFromObject = 278, + }; + + static npy_api lookup() { + module m = module::import("numpy.core.multiarray"); + object c = (object) m.attr("_ARRAY_API"); +#if PY_MAJOR_VERSION >= 3 + void **api_ptr = (void **) (c ? PyCapsule_GetPointer(c.ptr(), NULL) : nullptr); +#else + void **api_ptr = (void **) (c ? PyCObject_AsVoidPtr(c.ptr()) : nullptr); +#endif + npy_api api; +#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func]; + DECL_NPY_API(PyArray_Type); + DECL_NPY_API(PyArray_DescrFromType); + DECL_NPY_API(PyArray_FromAny); + DECL_NPY_API(PyArray_NewCopy); + DECL_NPY_API(PyArray_NewFromDescr); + DECL_NPY_API(PyArray_DescrNewFromType); + DECL_NPY_API(PyArray_DescrConverter); + DECL_NPY_API(PyArray_EquivTypes); + DECL_NPY_API(PyArray_GetArrayParamsFromObject); +#undef DECL_NPY_API + return api; + } +}; } class array : public buffer { public: - struct API { - enum Entries { - API_PyArray_Type = 2, - API_PyArray_DescrFromType = 45, - API_PyArray_FromAny = 69, - API_PyArray_NewCopy = 85, - API_PyArray_NewFromDescr = 94, - API_PyArray_DescrNewFromType = 9, - API_PyArray_DescrConverter = 174, - API_PyArray_EquivTypes = 182, - API_PyArray_GetArrayParamsFromObject = 278, - - NPY_C_CONTIGUOUS_ = 0x0001, - NPY_F_CONTIGUOUS_ = 0x0002, - NPY_ARRAY_FORCECAST_ = 0x0010, - NPY_ENSURE_ARRAY_ = 0x0040, - NPY_BOOL_ = 0, - NPY_BYTE_, NPY_UBYTE_, - NPY_SHORT_, NPY_USHORT_, - NPY_INT_, NPY_UINT_, - NPY_LONG_, NPY_ULONG_, - NPY_LONGLONG_, NPY_ULONGLONG_, - NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_, - NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_, - NPY_OBJECT_ = 17, - NPY_STRING_, NPY_UNICODE_, NPY_VOID_ - }; - - static API lookup() { - module m = module::import("numpy.core.multiarray"); - object c = (object) m.attr("_ARRAY_API"); -#if PY_MAJOR_VERSION >= 3 - void **api_ptr = (void **) (c ? PyCapsule_GetPointer(c.ptr(), NULL) : nullptr); -#else - void **api_ptr = (void **) (c ? PyCObject_AsVoidPtr(c.ptr()) : nullptr); -#endif - API api; -#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func]; - DECL_NPY_API(PyArray_Type); - DECL_NPY_API(PyArray_DescrFromType); - DECL_NPY_API(PyArray_FromAny); - DECL_NPY_API(PyArray_NewCopy); - DECL_NPY_API(PyArray_NewFromDescr); - DECL_NPY_API(PyArray_DescrNewFromType); - DECL_NPY_API(PyArray_DescrConverter); - DECL_NPY_API(PyArray_EquivTypes); - DECL_NPY_API(PyArray_GetArrayParamsFromObject); -#undef DECL_NPY_API - return api; - } - - bool PyArray_Check_(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type_); } - - PyObject *(*PyArray_DescrFromType_)(int); - PyObject *(*PyArray_NewFromDescr_) - (PyTypeObject *, PyObject *, int, Py_intptr_t *, - Py_intptr_t *, void *, int, PyObject *); - PyObject *(*PyArray_DescrNewFromType_)(int); - PyObject *(*PyArray_NewCopy_)(PyObject *, int); - PyTypeObject *PyArray_Type_; - PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *); - int (*PyArray_DescrConverter_) (PyObject *, PyObject **); - bool (*PyArray_EquivTypes_) (PyObject *, PyObject *); - int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *, - Py_ssize_t *, PyObject **, PyObject *); - }; - - PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check_) + PYBIND11_OBJECT_DEFAULT(array, buffer, detail::npy_api::get().PyArray_Check_) enum { - c_style = API::NPY_C_CONTIGUOUS_, - f_style = API::NPY_F_CONTIGUOUS_, - forcecast = API::NPY_ARRAY_FORCECAST_ + c_style = detail::npy_api::NPY_C_CONTIGUOUS_, + f_style = detail::npy_api::NPY_F_CONTIGUOUS_, + forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_ }; template array(size_t size, const Type *ptr) { - API& api = lookup_api(); + auto& api = detail::npy_api::get(); PyObject *descr = detail::npy_format_descriptor::dtype().release().ptr(); Py_intptr_t shape = (Py_intptr_t) size; object tmp = object(api.PyArray_NewFromDescr_( @@ -121,7 +128,7 @@ public: } array(const buffer_info &info) { - auto& api = lookup_api(); + auto& api = detail::npy_api::get(); // _dtype_from_pep3118 returns dtypes with padding fields in, so we need to strip them auto numpy_internal = module::import("numpy.core._internal"); @@ -139,11 +146,6 @@ public: } protected: - static API &lookup_api() { - static API api = API::lookup(); - return api; - } - template friend struct detail::npy_format_descriptor; static object strip_padding_fields(object dtype) { @@ -183,7 +185,7 @@ protected: args["itemsize"] = dtype.attr("itemsize").cast(); PyObject *descr = nullptr; - if (!lookup_api().PyArray_DescrConverter_(args.release().ptr(), &descr) || !descr) + if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &descr) || !descr) pybind11_fail("NumPy: failed to create structured dtype"); return object(descr, false); } @@ -198,10 +200,10 @@ public: static PyObject *ensure(PyObject *ptr) { if (ptr == nullptr) return nullptr; - API &api = lookup_api(); + auto& api = detail::npy_api::get(); PyObject *descr = detail::npy_format_descriptor::dtype().release().ptr(); PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0, - API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr); + detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr); if (!result) PyErr_Clear(); Py_DECREF(ptr); @@ -246,12 +248,12 @@ struct is_pod_struct { template struct npy_format_descriptor::value>::type> { private: constexpr static const int values[8] = { - array::API::NPY_BYTE_, array::API::NPY_UBYTE_, array::API::NPY_SHORT_, array::API::NPY_USHORT_, - array::API::NPY_INT_, array::API::NPY_UINT_, array::API::NPY_LONGLONG_, array::API::NPY_ULONGLONG_ }; + npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_, + npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_ }; public: enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned::value ? 1 : 0)] }; static object dtype() { - if (auto ptr = array::lookup_api().PyArray_DescrFromType_(value)) + if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) return object(ptr, true); pybind11_fail("Unsupported buffer format!"); } @@ -264,9 +266,9 @@ template constexpr const int npy_format_descriptor< T, typename std::enable_if::value>::type>::values[8]; #define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor { \ - enum { value = array::API::NumPyName }; \ + enum { value = npy_api::NumPyName }; \ static object dtype() { \ - if (auto ptr = array::lookup_api().PyArray_DescrFromType_(value)) \ + if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) \ return object(ptr, true); \ pybind11_fail("Unsupported buffer format!"); \ } \ @@ -281,7 +283,7 @@ DECL_FMT(std::complex, NPY_CDOUBLE_, "complex128"); #define DECL_CHAR_FMT \ static PYBIND11_DESCR name() { return _("S") + _(); } \ static object dtype() { \ - auto& api = array::lookup_api(); \ + auto& api = npy_api::get(); \ PyObject *descr = nullptr; \ PYBIND11_DESCR fmt = _("S") + _(); \ pybind11::str py_fmt(fmt.text()); \ @@ -319,7 +321,7 @@ struct npy_format_descriptor::value> } static void register_dtype(std::initializer_list fields) { - auto& api = array::lookup_api(); + auto& api = npy_api::get(); auto args = dict(); list names { }, offsets { }, formats { }; for (auto field : fields) {