mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-11 08:03:55 +00:00
avoid naming clashes with numpy (fixes #36)
This commit is contained in:
parent
4177ed4336
commit
87dfad6544
@ -30,65 +30,62 @@ public:
|
||||
API_PyArray_FromAny = 69,
|
||||
API_PyArray_NewCopy = 85,
|
||||
API_PyArray_NewFromDescr = 94,
|
||||
NPY_C_CONTIGUOUS = 0x0001,
|
||||
NPY_F_CONTIGUOUS = 0x0002,
|
||||
NPY_ARRAY_FORCECAST = 0x0010,
|
||||
NPY_ENSURE_ARRAY = 0x0040,
|
||||
NPY_BOOL = 0,
|
||||
NPY_BYTE, NPY_UBYTE,
|
||||
NPY_SHORT, NPY_USHORT,
|
||||
NPY_INT, NPY_UINT,
|
||||
NPY_LONG, NPY_ULONG,
|
||||
NPY_LONGLONG, NPY_ULONGLONG,
|
||||
NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE,
|
||||
NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE
|
||||
|
||||
NPY_C_CONTIGUOUS_ = 0x0001,
|
||||
NPY_F_CONTIGUOUS_ = 0x0002,
|
||||
NPY_ARRAY_FORCECAST_ = 0x0010,
|
||||
NPY_ENSURE_ARRAY_ = 0x0040,
|
||||
NPY_BOOL_ = 0,
|
||||
NPY_BYTE_, NPY_UBYTE_,
|
||||
NPY_SHORT_, NPY_USHORT_,
|
||||
NPY_INT_, NPY_UINT_,
|
||||
NPY_LONG_, NPY_ULONG_,
|
||||
NPY_LONGLONG_, NPY_ULONGLONG_,
|
||||
NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
|
||||
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_
|
||||
};
|
||||
|
||||
static API lookup() {
|
||||
PyObject *numpy = PyImport_ImportModule("numpy.core.multiarray");
|
||||
PyObject *capsule = numpy ? PyObject_GetAttrString(numpy, "_ARRAY_API") : nullptr;
|
||||
module m = module::import("numpy.core.multiarray");
|
||||
object c = (object) m.attr("_ARRAY_API");
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
void **api_ptr = (void **) (capsule ? PyCapsule_GetPointer(capsule, NULL) : nullptr);
|
||||
void **api_ptr = (void **) (c ? PyCapsule_GetPointer(c.ptr(), NULL) : nullptr);
|
||||
#else
|
||||
void **api_ptr = (void **) (capsule ? PyCObject_AsVoidPtr(capsule) : nullptr);
|
||||
void **api_ptr = (void **) (c ? PyCObject_AsVoidPtr(c.ptr()) : nullptr);
|
||||
#endif
|
||||
Py_XDECREF(capsule);
|
||||
Py_XDECREF(numpy);
|
||||
if (api_ptr == nullptr)
|
||||
throw std::runtime_error("Could not acquire pointer to NumPy API!");
|
||||
API api;
|
||||
api.PyArray_Type = (decltype(api.PyArray_Type)) api_ptr[API_PyArray_Type];
|
||||
api.PyArray_DescrFromType = (decltype(api.PyArray_DescrFromType)) api_ptr[API_PyArray_DescrFromType];
|
||||
api.PyArray_FromAny = (decltype(api.PyArray_FromAny)) api_ptr[API_PyArray_FromAny];
|
||||
api.PyArray_NewCopy = (decltype(api.PyArray_NewCopy)) api_ptr[API_PyArray_NewCopy];
|
||||
api.PyArray_NewFromDescr = (decltype(api.PyArray_NewFromDescr)) api_ptr[API_PyArray_NewFromDescr];
|
||||
api.PyArray_Type_ = (decltype(api.PyArray_Type_)) api_ptr[API_PyArray_Type];
|
||||
api.PyArray_DescrFromType_ = (decltype(api.PyArray_DescrFromType_)) api_ptr[API_PyArray_DescrFromType];
|
||||
api.PyArray_FromAny_ = (decltype(api.PyArray_FromAny_)) api_ptr[API_PyArray_FromAny];
|
||||
api.PyArray_NewCopy_ = (decltype(api.PyArray_NewCopy_)) api_ptr[API_PyArray_NewCopy];
|
||||
api.PyArray_NewFromDescr_ = (decltype(api.PyArray_NewFromDescr_)) api_ptr[API_PyArray_NewFromDescr];
|
||||
return api;
|
||||
}
|
||||
|
||||
bool PyArray_Check(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type); }
|
||||
bool PyArray_Check_(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type_); }
|
||||
|
||||
PyObject *(*PyArray_DescrFromType)(int);
|
||||
PyObject *(*PyArray_NewFromDescr)
|
||||
PyObject *(*PyArray_DescrFromType_)(int);
|
||||
PyObject *(*PyArray_NewFromDescr_)
|
||||
(PyTypeObject *, PyObject *, int, Py_intptr_t *,
|
||||
Py_intptr_t *, void *, int, PyObject *);
|
||||
PyObject *(*PyArray_NewCopy)(PyObject *, int);
|
||||
PyTypeObject *PyArray_Type;
|
||||
PyObject *(*PyArray_FromAny) (PyObject *, PyObject *, int, int, int, PyObject *);
|
||||
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
|
||||
PyTypeObject *PyArray_Type_;
|
||||
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
|
||||
};
|
||||
|
||||
PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check)
|
||||
PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check_)
|
||||
|
||||
template <typename Type> array(size_t size, const Type *ptr) {
|
||||
API& api = lookup_api();
|
||||
PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor<Type>::value);
|
||||
PyObject *descr = api.PyArray_DescrFromType_(npy_format_descriptor<Type>::value);
|
||||
if (descr == nullptr)
|
||||
throw std::runtime_error("NumPy: unsupported buffer format!");
|
||||
Py_intptr_t shape = (Py_intptr_t) size;
|
||||
PyObject *tmp = api.PyArray_NewFromDescr(
|
||||
api.PyArray_Type, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr);
|
||||
PyObject *tmp = api.PyArray_NewFromDescr_(
|
||||
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr);
|
||||
if (tmp == nullptr)
|
||||
throw std::runtime_error("NumPy: unable to create array!");
|
||||
m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */);
|
||||
m_ptr = api.PyArray_NewCopy_(tmp, -1 /* any order */);
|
||||
Py_DECREF(tmp);
|
||||
if (m_ptr == nullptr)
|
||||
throw std::runtime_error("NumPy: unable to copy array!");
|
||||
@ -99,19 +96,18 @@ public:
|
||||
if ((info.format.size() < 1) || (info.format.size() > 2))
|
||||
throw std::runtime_error("Unsupported buffer format!");
|
||||
int fmt = (int) info.format[0];
|
||||
if (info.format == "Zd")
|
||||
fmt = API::NPY_CDOUBLE;
|
||||
else if (info.format == "Zf")
|
||||
fmt = API::NPY_CFLOAT;
|
||||
PyObject *descr = api.PyArray_DescrFromType(fmt);
|
||||
if (info.format == "Zd") fmt = API::NPY_CDOUBLE_;
|
||||
else if (info.format == "Zf") fmt = API::NPY_CFLOAT_;
|
||||
|
||||
PyObject *descr = api.PyArray_DescrFromType_(fmt);
|
||||
if (descr == nullptr)
|
||||
throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!");
|
||||
PyObject *tmp = api.PyArray_NewFromDescr(
|
||||
api.PyArray_Type, descr, info.ndim, (Py_intptr_t *) &info.shape[0],
|
||||
PyObject *tmp = api.PyArray_NewFromDescr_(
|
||||
api.PyArray_Type_, descr, info.ndim, (Py_intptr_t *) &info.shape[0],
|
||||
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr);
|
||||
if (tmp == nullptr)
|
||||
throw std::runtime_error("NumPy: unable to create array!");
|
||||
m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */);
|
||||
m_ptr = api.PyArray_NewCopy_(tmp, -1 /* any order */);
|
||||
Py_DECREF(tmp);
|
||||
if (m_ptr == nullptr)
|
||||
throw std::runtime_error("NumPy: unable to copy array!");
|
||||
@ -133,19 +129,19 @@ public:
|
||||
if (ptr == nullptr)
|
||||
return nullptr;
|
||||
API &api = lookup_api();
|
||||
PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor<T>::value);
|
||||
return api.PyArray_FromAny(ptr, descr, 0, 0,
|
||||
API::NPY_C_CONTIGUOUS | API::NPY_ENSURE_ARRAY |
|
||||
API::NPY_ARRAY_FORCECAST, nullptr);
|
||||
PyObject *descr = api.PyArray_DescrFromType_(npy_format_descriptor<T>::value);
|
||||
return api.PyArray_FromAny_(ptr, descr, 0, 0,
|
||||
API::NPY_C_CONTIGUOUS_ | API::NPY_ENSURE_ARRAY_ |
|
||||
API::NPY_ARRAY_FORCECAST_, nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
#define DECL_FMT(t, n) template<> struct npy_format_descriptor<t> { enum { value = array::API::n }; }
|
||||
DECL_FMT(int8_t, NPY_BYTE); DECL_FMT(uint8_t, NPY_UBYTE); DECL_FMT(int16_t, NPY_SHORT);
|
||||
DECL_FMT(uint16_t, NPY_USHORT); DECL_FMT(int32_t, NPY_INT); DECL_FMT(uint32_t, NPY_UINT);
|
||||
DECL_FMT(int64_t, NPY_LONGLONG); DECL_FMT(uint64_t, NPY_ULONGLONG); DECL_FMT(float, NPY_FLOAT);
|
||||
DECL_FMT(double, NPY_DOUBLE); DECL_FMT(bool, NPY_BOOL); DECL_FMT(std::complex<float>, NPY_CFLOAT);
|
||||
DECL_FMT(std::complex<double>, NPY_CDOUBLE);
|
||||
DECL_FMT(int8_t, NPY_BYTE_); DECL_FMT(uint8_t, NPY_UBYTE_); DECL_FMT(int16_t, NPY_SHORT_);
|
||||
DECL_FMT(uint16_t, NPY_USHORT_); DECL_FMT(int32_t, NPY_INT_); DECL_FMT(uint32_t, NPY_UINT_);
|
||||
DECL_FMT(int64_t, NPY_LONGLONG_); DECL_FMT(uint64_t, NPY_ULONGLONG_); DECL_FMT(float, NPY_FLOAT_);
|
||||
DECL_FMT(double, NPY_DOUBLE_); DECL_FMT(bool, NPY_BOOL_); DECL_FMT(std::complex<float>, NPY_CFLOAT_);
|
||||
DECL_FMT(std::complex<double>, NPY_CDOUBLE_);
|
||||
#undef DECL_FMT
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
Loading…
Reference in New Issue
Block a user