numpy.h replace macros with functions (#514)

This commit is contained in:
Sylvain Corlay 2016-11-22 01:29:55 -09:00 committed by Wenzel Jakob
parent 7146d6299c
commit b14f065fa9

View File

@ -199,16 +199,28 @@ private:
return api;
}
};
NAMESPACE_END(detail)
#define PyArray_GET_(ptr, attr) \
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
#define PyArrayDescr_GET_(ptr, attr) \
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
#define PyArray_FLAGS_(ptr) \
PyArray_GET_(ptr, flags)
#define PyArray_CHKFLAGS_(ptr, flag) \
(flag == (PyArray_FLAGS_(ptr) & flag))
inline PyArray_Proxy* array_proxy(void* ptr) {
return reinterpret_cast<PyArray_Proxy*>(ptr);
}
inline const PyArray_Proxy* array_proxy(const void* ptr) {
return reinterpret_cast<const PyArray_Proxy*>(ptr);
}
inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) {
return reinterpret_cast<PyArrayDescr_Proxy*>(ptr);
}
inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) {
return reinterpret_cast<const PyArrayDescr_Proxy*>(ptr);
}
inline bool check_flags(const void* ptr, int flag) {
return (flag == (array_proxy(ptr)->flags & flag));
}
NAMESPACE_END(detail)
class dtype : public object {
public:
@ -249,17 +261,17 @@ public:
/// Size of the data type in bytes.
size_t itemsize() const {
return (size_t) PyArrayDescr_GET_(m_ptr, elsize);
return (size_t) detail::array_descriptor_proxy(m_ptr)->elsize;
}
/// Returns true for structured data types.
bool has_fields() const {
return PyArrayDescr_GET_(m_ptr, names) != nullptr;
return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
}
/// Single-character type code.
char kind() const {
return PyArrayDescr_GET_(m_ptr, kind);
return detail::array_descriptor_proxy(m_ptr)->kind;
}
private:
@ -341,7 +353,7 @@ public:
pybind11_fail("NumPy: unable to create array!");
if (ptr) {
if (base) {
PyArray_GET_(tmp.ptr(), base) = base.inc_ref().ptr();
detail::array_proxy(tmp.ptr())->base = base.inc_ref().ptr();
} else {
tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
}
@ -376,7 +388,7 @@ public:
/// Array descriptor (dtype)
pybind11::dtype dtype() const {
return reinterpret_borrow<pybind11::dtype>(PyArray_GET_(m_ptr, descr));
return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
}
/// Total number of elements
@ -386,7 +398,7 @@ public:
/// Byte size of a single element
size_t itemsize() const {
return (size_t) PyArrayDescr_GET_(PyArray_GET_(m_ptr, descr), elsize);
return (size_t) detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
}
/// Total number of bytes
@ -396,17 +408,17 @@ public:
/// Number of dimensions
size_t ndim() const {
return (size_t) PyArray_GET_(m_ptr, nd);
return (size_t) detail::array_proxy(m_ptr)->nd;
}
/// Base object
object base() const {
return reinterpret_borrow<object>(PyArray_GET_(m_ptr, base));
return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base);
}
/// Dimensions of the array
const size_t* shape() const {
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, dimensions));
return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->dimensions);
}
/// Dimension along a given axis
@ -418,7 +430,7 @@ public:
/// Strides of the array
const size_t* strides() const {
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, strides));
return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->strides);
}
/// Stride along a given axis
@ -430,23 +442,23 @@ public:
/// Return the NumPy array flags
int flags() const {
return PyArray_FLAGS_(m_ptr);
return detail::array_proxy(m_ptr)->flags;
}
/// If set, the array is writeable (otherwise the buffer is read-only)
bool writeable() const {
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
}
/// If set, the array owns the data (will be freed when the array is deleted)
bool owndata() const {
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
}
/// Pointer to the contained data. If index is not provided, points to the
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
template<typename... Ix> const void* data(Ix... index) const {
return static_cast<const void *>(PyArray_GET_(m_ptr, data) + offset_at(index...));
return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
}
/// Mutable pointer to the contained data. If index is not provided, points to the
@ -454,7 +466,7 @@ public:
/// May throw if the array is not writeable.
template<typename... Ix> void* mutable_data(Ix... index) {
check_writeable();
return static_cast<void *>(PyArray_GET_(m_ptr, data) + offset_at(index...));
return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
}
/// Byte offset from beginning of the array to a given index (full or partial).
@ -620,7 +632,7 @@ public:
static bool _check(handle h) {
const auto &api = detail::npy_api::get();
return api.PyArray_Check_(h.ptr())
&& api.PyArray_EquivTypes_(PyArray_GET_(h.ptr(), descr), dtype::of<T>().ptr());
&& api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
}
protected: