diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 77006c887..72dd4b371 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -199,16 +199,28 @@ private: return api; } }; -NAMESPACE_END(detail) -#define PyArray_GET_(ptr, attr) \ - (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr) -#define PyArrayDescr_GET_(ptr, attr) \ - (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr) -#define PyArray_FLAGS_(ptr) \ - PyArray_GET_(ptr, flags) -#define PyArray_CHKFLAGS_(ptr, flag) \ - (flag == (PyArray_FLAGS_(ptr) & flag)) +inline PyArray_Proxy* array_proxy(void* ptr) { + return reinterpret_cast(ptr); +} + +inline const PyArray_Proxy* array_proxy(const void* ptr) { + return reinterpret_cast(ptr); +} + +inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) { + return reinterpret_cast(ptr); +} + +inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) { + return reinterpret_cast(ptr); +} + +inline bool check_flags(const void* ptr, int flag) { + return (flag == (array_proxy(ptr)->flags & flag)); +} + +NAMESPACE_END(detail) class dtype : public object { public: @@ -249,17 +261,17 @@ public: /// Size of the data type in bytes. size_t itemsize() const { - return (size_t) PyArrayDescr_GET_(m_ptr, elsize); + return (size_t) detail::array_descriptor_proxy(m_ptr)->elsize; } /// Returns true for structured data types. bool has_fields() const { - return PyArrayDescr_GET_(m_ptr, names) != nullptr; + return detail::array_descriptor_proxy(m_ptr)->names != nullptr; } /// Single-character type code. char kind() const { - return PyArrayDescr_GET_(m_ptr, kind); + return detail::array_descriptor_proxy(m_ptr)->kind; } private: @@ -341,7 +353,7 @@ public: pybind11_fail("NumPy: unable to create array!"); if (ptr) { if (base) { - PyArray_GET_(tmp.ptr(), base) = base.inc_ref().ptr(); + detail::array_proxy(tmp.ptr())->base = base.inc_ref().ptr(); } else { tmp = reinterpret_steal(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */)); } @@ -376,7 +388,7 @@ public: /// Array descriptor (dtype) pybind11::dtype dtype() const { - return reinterpret_borrow(PyArray_GET_(m_ptr, descr)); + return reinterpret_borrow(detail::array_proxy(m_ptr)->descr); } /// Total number of elements @@ -386,7 +398,7 @@ public: /// Byte size of a single element size_t itemsize() const { - return (size_t) PyArrayDescr_GET_(PyArray_GET_(m_ptr, descr), elsize); + return (size_t) detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize; } /// Total number of bytes @@ -396,17 +408,17 @@ public: /// Number of dimensions size_t ndim() const { - return (size_t) PyArray_GET_(m_ptr, nd); + return (size_t) detail::array_proxy(m_ptr)->nd; } /// Base object object base() const { - return reinterpret_borrow(PyArray_GET_(m_ptr, base)); + return reinterpret_borrow(detail::array_proxy(m_ptr)->base); } /// Dimensions of the array const size_t* shape() const { - return reinterpret_cast(PyArray_GET_(m_ptr, dimensions)); + return reinterpret_cast(detail::array_proxy(m_ptr)->dimensions); } /// Dimension along a given axis @@ -418,7 +430,7 @@ public: /// Strides of the array const size_t* strides() const { - return reinterpret_cast(PyArray_GET_(m_ptr, strides)); + return reinterpret_cast(detail::array_proxy(m_ptr)->strides); } /// Stride along a given axis @@ -430,23 +442,23 @@ public: /// Return the NumPy array flags int flags() const { - return PyArray_FLAGS_(m_ptr); + return detail::array_proxy(m_ptr)->flags; } /// If set, the array is writeable (otherwise the buffer is read-only) bool writeable() const { - return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_); + return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_); } /// If set, the array owns the data (will be freed when the array is deleted) bool owndata() const { - return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_); + return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_); } /// Pointer to the contained data. If index is not provided, points to the /// beginning of the buffer. May throw if the index would lead to out of bounds access. template const void* data(Ix... index) const { - return static_cast(PyArray_GET_(m_ptr, data) + offset_at(index...)); + return static_cast(detail::array_proxy(m_ptr)->data + offset_at(index...)); } /// Mutable pointer to the contained data. If index is not provided, points to the @@ -454,7 +466,7 @@ public: /// May throw if the array is not writeable. template void* mutable_data(Ix... index) { check_writeable(); - return static_cast(PyArray_GET_(m_ptr, data) + offset_at(index...)); + return static_cast(detail::array_proxy(m_ptr)->data + offset_at(index...)); } /// Byte offset from beginning of the array to a given index (full or partial). @@ -620,7 +632,7 @@ public: static bool _check(handle h) { const auto &api = detail::npy_api::get(); return api.PyArray_Check_(h.ptr()) - && api.PyArray_EquivTypes_(PyArray_GET_(h.ptr(), descr), dtype::of().ptr()); + && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of().ptr()); } protected: