diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 4b105323a..27ebe7d69 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -30,65 +30,62 @@ public: API_PyArray_FromAny = 69, API_PyArray_NewCopy = 85, API_PyArray_NewFromDescr = 94, - NPY_C_CONTIGUOUS = 0x0001, - NPY_F_CONTIGUOUS = 0x0002, - NPY_ARRAY_FORCECAST = 0x0010, - NPY_ENSURE_ARRAY = 0x0040, - NPY_BOOL = 0, - NPY_BYTE, NPY_UBYTE, - NPY_SHORT, NPY_USHORT, - NPY_INT, NPY_UINT, - NPY_LONG, NPY_ULONG, - NPY_LONGLONG, NPY_ULONGLONG, - NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE, - NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE + + NPY_C_CONTIGUOUS_ = 0x0001, + NPY_F_CONTIGUOUS_ = 0x0002, + NPY_ARRAY_FORCECAST_ = 0x0010, + NPY_ENSURE_ARRAY_ = 0x0040, + NPY_BOOL_ = 0, + NPY_BYTE_, NPY_UBYTE_, + NPY_SHORT_, NPY_USHORT_, + NPY_INT_, NPY_UINT_, + NPY_LONG_, NPY_ULONG_, + NPY_LONGLONG_, NPY_ULONGLONG_, + NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_, + NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_ }; static API lookup() { - PyObject *numpy = PyImport_ImportModule("numpy.core.multiarray"); - PyObject *capsule = numpy ? PyObject_GetAttrString(numpy, "_ARRAY_API") : nullptr; + module m = module::import("numpy.core.multiarray"); + object c = (object) m.attr("_ARRAY_API"); #if PY_MAJOR_VERSION >= 3 - void **api_ptr = (void **) (capsule ? PyCapsule_GetPointer(capsule, NULL) : nullptr); + void **api_ptr = (void **) (c ? PyCapsule_GetPointer(c.ptr(), NULL) : nullptr); #else - void **api_ptr = (void **) (capsule ? PyCObject_AsVoidPtr(capsule) : nullptr); + void **api_ptr = (void **) (c ? PyCObject_AsVoidPtr(c.ptr()) : nullptr); #endif - Py_XDECREF(capsule); - Py_XDECREF(numpy); - if (api_ptr == nullptr) - throw std::runtime_error("Could not acquire pointer to NumPy API!"); API api; - api.PyArray_Type = (decltype(api.PyArray_Type)) api_ptr[API_PyArray_Type]; - api.PyArray_DescrFromType = (decltype(api.PyArray_DescrFromType)) api_ptr[API_PyArray_DescrFromType]; - api.PyArray_FromAny = (decltype(api.PyArray_FromAny)) api_ptr[API_PyArray_FromAny]; - api.PyArray_NewCopy = (decltype(api.PyArray_NewCopy)) api_ptr[API_PyArray_NewCopy]; - api.PyArray_NewFromDescr = (decltype(api.PyArray_NewFromDescr)) api_ptr[API_PyArray_NewFromDescr]; + api.PyArray_Type_ = (decltype(api.PyArray_Type_)) api_ptr[API_PyArray_Type]; + api.PyArray_DescrFromType_ = (decltype(api.PyArray_DescrFromType_)) api_ptr[API_PyArray_DescrFromType]; + api.PyArray_FromAny_ = (decltype(api.PyArray_FromAny_)) api_ptr[API_PyArray_FromAny]; + api.PyArray_NewCopy_ = (decltype(api.PyArray_NewCopy_)) api_ptr[API_PyArray_NewCopy]; + api.PyArray_NewFromDescr_ = (decltype(api.PyArray_NewFromDescr_)) api_ptr[API_PyArray_NewFromDescr]; return api; } - bool PyArray_Check(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type); } + bool PyArray_Check_(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type_); } - PyObject *(*PyArray_DescrFromType)(int); - PyObject *(*PyArray_NewFromDescr) + PyObject *(*PyArray_DescrFromType_)(int); + PyObject *(*PyArray_NewFromDescr_) (PyTypeObject *, PyObject *, int, Py_intptr_t *, Py_intptr_t *, void *, int, PyObject *); - PyObject *(*PyArray_NewCopy)(PyObject *, int); - PyTypeObject *PyArray_Type; - PyObject *(*PyArray_FromAny) (PyObject *, PyObject *, int, int, int, PyObject *); + PyObject *(*PyArray_NewCopy_)(PyObject *, int); + PyTypeObject *PyArray_Type_; + PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *); }; - PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check) + PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check_) template array(size_t size, const Type *ptr) { API& api = lookup_api(); - PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor::value); + PyObject *descr = api.PyArray_DescrFromType_(npy_format_descriptor::value); if (descr == nullptr) throw std::runtime_error("NumPy: unsupported buffer format!"); Py_intptr_t shape = (Py_intptr_t) size; - PyObject *tmp = api.PyArray_NewFromDescr( - api.PyArray_Type, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr); + PyObject *tmp = api.PyArray_NewFromDescr_( + api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr); if (tmp == nullptr) throw std::runtime_error("NumPy: unable to create array!"); - m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */); + m_ptr = api.PyArray_NewCopy_(tmp, -1 /* any order */); Py_DECREF(tmp); if (m_ptr == nullptr) throw std::runtime_error("NumPy: unable to copy array!"); @@ -99,19 +96,18 @@ public: if ((info.format.size() < 1) || (info.format.size() > 2)) throw std::runtime_error("Unsupported buffer format!"); int fmt = (int) info.format[0]; - if (info.format == "Zd") - fmt = API::NPY_CDOUBLE; - else if (info.format == "Zf") - fmt = API::NPY_CFLOAT; - PyObject *descr = api.PyArray_DescrFromType(fmt); + if (info.format == "Zd") fmt = API::NPY_CDOUBLE_; + else if (info.format == "Zf") fmt = API::NPY_CFLOAT_; + + PyObject *descr = api.PyArray_DescrFromType_(fmt); if (descr == nullptr) throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!"); - PyObject *tmp = api.PyArray_NewFromDescr( - api.PyArray_Type, descr, info.ndim, (Py_intptr_t *) &info.shape[0], + PyObject *tmp = api.PyArray_NewFromDescr_( + api.PyArray_Type_, descr, info.ndim, (Py_intptr_t *) &info.shape[0], (Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr); if (tmp == nullptr) throw std::runtime_error("NumPy: unable to create array!"); - m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */); + m_ptr = api.PyArray_NewCopy_(tmp, -1 /* any order */); Py_DECREF(tmp); if (m_ptr == nullptr) throw std::runtime_error("NumPy: unable to copy array!"); @@ -133,19 +129,19 @@ public: if (ptr == nullptr) return nullptr; API &api = lookup_api(); - PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor::value); - return api.PyArray_FromAny(ptr, descr, 0, 0, - API::NPY_C_CONTIGUOUS | API::NPY_ENSURE_ARRAY | - API::NPY_ARRAY_FORCECAST, nullptr); + PyObject *descr = api.PyArray_DescrFromType_(npy_format_descriptor::value); + return api.PyArray_FromAny_(ptr, descr, 0, 0, + API::NPY_C_CONTIGUOUS_ | API::NPY_ENSURE_ARRAY_ | + API::NPY_ARRAY_FORCECAST_, nullptr); } }; #define DECL_FMT(t, n) template<> struct npy_format_descriptor { enum { value = array::API::n }; } -DECL_FMT(int8_t, NPY_BYTE); DECL_FMT(uint8_t, NPY_UBYTE); DECL_FMT(int16_t, NPY_SHORT); -DECL_FMT(uint16_t, NPY_USHORT); DECL_FMT(int32_t, NPY_INT); DECL_FMT(uint32_t, NPY_UINT); -DECL_FMT(int64_t, NPY_LONGLONG); DECL_FMT(uint64_t, NPY_ULONGLONG); DECL_FMT(float, NPY_FLOAT); -DECL_FMT(double, NPY_DOUBLE); DECL_FMT(bool, NPY_BOOL); DECL_FMT(std::complex, NPY_CFLOAT); -DECL_FMT(std::complex, NPY_CDOUBLE); +DECL_FMT(int8_t, NPY_BYTE_); DECL_FMT(uint8_t, NPY_UBYTE_); DECL_FMT(int16_t, NPY_SHORT_); +DECL_FMT(uint16_t, NPY_USHORT_); DECL_FMT(int32_t, NPY_INT_); DECL_FMT(uint32_t, NPY_UINT_); +DECL_FMT(int64_t, NPY_LONGLONG_); DECL_FMT(uint64_t, NPY_ULONGLONG_); DECL_FMT(float, NPY_FLOAT_); +DECL_FMT(double, NPY_DOUBLE_); DECL_FMT(bool, NPY_BOOL_); DECL_FMT(std::complex, NPY_CFLOAT_); +DECL_FMT(std::complex, NPY_CDOUBLE_); #undef DECL_FMT NAMESPACE_BEGIN(detail)