numpy.h replace macros with functions (#514)

This commit is contained in:
Sylvain Corlay 2016-11-22 01:29:55 -09:00 committed by Wenzel Jakob
parent 7146d6299c
commit b14f065fa9

View File

@ -199,16 +199,28 @@ private:
return api; return api;
} }
}; };
NAMESPACE_END(detail)
#define PyArray_GET_(ptr, attr) \ inline PyArray_Proxy* array_proxy(void* ptr) {
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr) return reinterpret_cast<PyArray_Proxy*>(ptr);
#define PyArrayDescr_GET_(ptr, attr) \ }
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
#define PyArray_FLAGS_(ptr) \ inline const PyArray_Proxy* array_proxy(const void* ptr) {
PyArray_GET_(ptr, flags) return reinterpret_cast<const PyArray_Proxy*>(ptr);
#define PyArray_CHKFLAGS_(ptr, flag) \ }
(flag == (PyArray_FLAGS_(ptr) & flag))
inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) {
return reinterpret_cast<PyArrayDescr_Proxy*>(ptr);
}
inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) {
return reinterpret_cast<const PyArrayDescr_Proxy*>(ptr);
}
inline bool check_flags(const void* ptr, int flag) {
return (flag == (array_proxy(ptr)->flags & flag));
}
NAMESPACE_END(detail)
class dtype : public object { class dtype : public object {
public: public:
@ -249,17 +261,17 @@ public:
/// Size of the data type in bytes. /// Size of the data type in bytes.
size_t itemsize() const { size_t itemsize() const {
return (size_t) PyArrayDescr_GET_(m_ptr, elsize); return (size_t) detail::array_descriptor_proxy(m_ptr)->elsize;
} }
/// Returns true for structured data types. /// Returns true for structured data types.
bool has_fields() const { bool has_fields() const {
return PyArrayDescr_GET_(m_ptr, names) != nullptr; return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
} }
/// Single-character type code. /// Single-character type code.
char kind() const { char kind() const {
return PyArrayDescr_GET_(m_ptr, kind); return detail::array_descriptor_proxy(m_ptr)->kind;
} }
private: private:
@ -341,7 +353,7 @@ public:
pybind11_fail("NumPy: unable to create array!"); pybind11_fail("NumPy: unable to create array!");
if (ptr) { if (ptr) {
if (base) { if (base) {
PyArray_GET_(tmp.ptr(), base) = base.inc_ref().ptr(); detail::array_proxy(tmp.ptr())->base = base.inc_ref().ptr();
} else { } else {
tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */)); tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
} }
@ -376,7 +388,7 @@ public:
/// Array descriptor (dtype) /// Array descriptor (dtype)
pybind11::dtype dtype() const { pybind11::dtype dtype() const {
return reinterpret_borrow<pybind11::dtype>(PyArray_GET_(m_ptr, descr)); return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
} }
/// Total number of elements /// Total number of elements
@ -386,7 +398,7 @@ public:
/// Byte size of a single element /// Byte size of a single element
size_t itemsize() const { size_t itemsize() const {
return (size_t) PyArrayDescr_GET_(PyArray_GET_(m_ptr, descr), elsize); return (size_t) detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
} }
/// Total number of bytes /// Total number of bytes
@ -396,17 +408,17 @@ public:
/// Number of dimensions /// Number of dimensions
size_t ndim() const { size_t ndim() const {
return (size_t) PyArray_GET_(m_ptr, nd); return (size_t) detail::array_proxy(m_ptr)->nd;
} }
/// Base object /// Base object
object base() const { object base() const {
return reinterpret_borrow<object>(PyArray_GET_(m_ptr, base)); return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base);
} }
/// Dimensions of the array /// Dimensions of the array
const size_t* shape() const { const size_t* shape() const {
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, dimensions)); return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->dimensions);
} }
/// Dimension along a given axis /// Dimension along a given axis
@ -418,7 +430,7 @@ public:
/// Strides of the array /// Strides of the array
const size_t* strides() const { const size_t* strides() const {
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, strides)); return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->strides);
} }
/// Stride along a given axis /// Stride along a given axis
@ -430,23 +442,23 @@ public:
/// Return the NumPy array flags /// Return the NumPy array flags
int flags() const { int flags() const {
return PyArray_FLAGS_(m_ptr); return detail::array_proxy(m_ptr)->flags;
} }
/// If set, the array is writeable (otherwise the buffer is read-only) /// If set, the array is writeable (otherwise the buffer is read-only)
bool writeable() const { bool writeable() const {
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_); return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
} }
/// If set, the array owns the data (will be freed when the array is deleted) /// If set, the array owns the data (will be freed when the array is deleted)
bool owndata() const { bool owndata() const {
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_); return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
} }
/// Pointer to the contained data. If index is not provided, points to the /// Pointer to the contained data. If index is not provided, points to the
/// beginning of the buffer. May throw if the index would lead to out of bounds access. /// beginning of the buffer. May throw if the index would lead to out of bounds access.
template<typename... Ix> const void* data(Ix... index) const { template<typename... Ix> const void* data(Ix... index) const {
return static_cast<const void *>(PyArray_GET_(m_ptr, data) + offset_at(index...)); return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
} }
/// Mutable pointer to the contained data. If index is not provided, points to the /// Mutable pointer to the contained data. If index is not provided, points to the
@ -454,7 +466,7 @@ public:
/// May throw if the array is not writeable. /// May throw if the array is not writeable.
template<typename... Ix> void* mutable_data(Ix... index) { template<typename... Ix> void* mutable_data(Ix... index) {
check_writeable(); check_writeable();
return static_cast<void *>(PyArray_GET_(m_ptr, data) + offset_at(index...)); return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
} }
/// Byte offset from beginning of the array to a given index (full or partial). /// Byte offset from beginning of the array to a given index (full or partial).
@ -620,7 +632,7 @@ public:
static bool _check(handle h) { static bool _check(handle h) {
const auto &api = detail::npy_api::get(); const auto &api = detail::npy_api::get();
return api.PyArray_Check_(h.ptr()) return api.PyArray_Check_(h.ptr())
&& api.PyArray_EquivTypes_(PyArray_GET_(h.ptr(), descr), dtype::of<T>().ptr()); && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
} }
protected: protected: