mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-26 07:02:11 +00:00
numpy.h replace macros with functions (#514)
This commit is contained in:
parent
7146d6299c
commit
b14f065fa9
@ -199,16 +199,28 @@ private:
|
||||
return api;
|
||||
}
|
||||
};
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
#define PyArray_GET_(ptr, attr) \
|
||||
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
|
||||
#define PyArrayDescr_GET_(ptr, attr) \
|
||||
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
|
||||
#define PyArray_FLAGS_(ptr) \
|
||||
PyArray_GET_(ptr, flags)
|
||||
#define PyArray_CHKFLAGS_(ptr, flag) \
|
||||
(flag == (PyArray_FLAGS_(ptr) & flag))
|
||||
inline PyArray_Proxy* array_proxy(void* ptr) {
|
||||
return reinterpret_cast<PyArray_Proxy*>(ptr);
|
||||
}
|
||||
|
||||
inline const PyArray_Proxy* array_proxy(const void* ptr) {
|
||||
return reinterpret_cast<const PyArray_Proxy*>(ptr);
|
||||
}
|
||||
|
||||
inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) {
|
||||
return reinterpret_cast<PyArrayDescr_Proxy*>(ptr);
|
||||
}
|
||||
|
||||
inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) {
|
||||
return reinterpret_cast<const PyArrayDescr_Proxy*>(ptr);
|
||||
}
|
||||
|
||||
inline bool check_flags(const void* ptr, int flag) {
|
||||
return (flag == (array_proxy(ptr)->flags & flag));
|
||||
}
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
class dtype : public object {
|
||||
public:
|
||||
@ -249,17 +261,17 @@ public:
|
||||
|
||||
/// Size of the data type in bytes.
|
||||
size_t itemsize() const {
|
||||
return (size_t) PyArrayDescr_GET_(m_ptr, elsize);
|
||||
return (size_t) detail::array_descriptor_proxy(m_ptr)->elsize;
|
||||
}
|
||||
|
||||
/// Returns true for structured data types.
|
||||
bool has_fields() const {
|
||||
return PyArrayDescr_GET_(m_ptr, names) != nullptr;
|
||||
return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
|
||||
}
|
||||
|
||||
/// Single-character type code.
|
||||
char kind() const {
|
||||
return PyArrayDescr_GET_(m_ptr, kind);
|
||||
return detail::array_descriptor_proxy(m_ptr)->kind;
|
||||
}
|
||||
|
||||
private:
|
||||
@ -341,7 +353,7 @@ public:
|
||||
pybind11_fail("NumPy: unable to create array!");
|
||||
if (ptr) {
|
||||
if (base) {
|
||||
PyArray_GET_(tmp.ptr(), base) = base.inc_ref().ptr();
|
||||
detail::array_proxy(tmp.ptr())->base = base.inc_ref().ptr();
|
||||
} else {
|
||||
tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
|
||||
}
|
||||
@ -376,7 +388,7 @@ public:
|
||||
|
||||
/// Array descriptor (dtype)
|
||||
pybind11::dtype dtype() const {
|
||||
return reinterpret_borrow<pybind11::dtype>(PyArray_GET_(m_ptr, descr));
|
||||
return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
|
||||
}
|
||||
|
||||
/// Total number of elements
|
||||
@ -386,7 +398,7 @@ public:
|
||||
|
||||
/// Byte size of a single element
|
||||
size_t itemsize() const {
|
||||
return (size_t) PyArrayDescr_GET_(PyArray_GET_(m_ptr, descr), elsize);
|
||||
return (size_t) detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
|
||||
}
|
||||
|
||||
/// Total number of bytes
|
||||
@ -396,17 +408,17 @@ public:
|
||||
|
||||
/// Number of dimensions
|
||||
size_t ndim() const {
|
||||
return (size_t) PyArray_GET_(m_ptr, nd);
|
||||
return (size_t) detail::array_proxy(m_ptr)->nd;
|
||||
}
|
||||
|
||||
/// Base object
|
||||
object base() const {
|
||||
return reinterpret_borrow<object>(PyArray_GET_(m_ptr, base));
|
||||
return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base);
|
||||
}
|
||||
|
||||
/// Dimensions of the array
|
||||
const size_t* shape() const {
|
||||
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, dimensions));
|
||||
return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->dimensions);
|
||||
}
|
||||
|
||||
/// Dimension along a given axis
|
||||
@ -418,7 +430,7 @@ public:
|
||||
|
||||
/// Strides of the array
|
||||
const size_t* strides() const {
|
||||
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, strides));
|
||||
return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->strides);
|
||||
}
|
||||
|
||||
/// Stride along a given axis
|
||||
@ -430,23 +442,23 @@ public:
|
||||
|
||||
/// Return the NumPy array flags
|
||||
int flags() const {
|
||||
return PyArray_FLAGS_(m_ptr);
|
||||
return detail::array_proxy(m_ptr)->flags;
|
||||
}
|
||||
|
||||
/// If set, the array is writeable (otherwise the buffer is read-only)
|
||||
bool writeable() const {
|
||||
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
|
||||
return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
|
||||
}
|
||||
|
||||
/// If set, the array owns the data (will be freed when the array is deleted)
|
||||
bool owndata() const {
|
||||
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
|
||||
return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
|
||||
}
|
||||
|
||||
/// Pointer to the contained data. If index is not provided, points to the
|
||||
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
|
||||
template<typename... Ix> const void* data(Ix... index) const {
|
||||
return static_cast<const void *>(PyArray_GET_(m_ptr, data) + offset_at(index...));
|
||||
return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
|
||||
}
|
||||
|
||||
/// Mutable pointer to the contained data. If index is not provided, points to the
|
||||
@ -454,7 +466,7 @@ public:
|
||||
/// May throw if the array is not writeable.
|
||||
template<typename... Ix> void* mutable_data(Ix... index) {
|
||||
check_writeable();
|
||||
return static_cast<void *>(PyArray_GET_(m_ptr, data) + offset_at(index...));
|
||||
return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
|
||||
}
|
||||
|
||||
/// Byte offset from beginning of the array to a given index (full or partial).
|
||||
@ -620,7 +632,7 @@ public:
|
||||
static bool _check(handle h) {
|
||||
const auto &api = detail::npy_api::get();
|
||||
return api.PyArray_Check_(h.ptr())
|
||||
&& api.PyArray_EquivTypes_(PyArray_GET_(h.ptr(), descr), dtype::of<T>().ptr());
|
||||
&& api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
|
||||
}
|
||||
|
||||
protected:
|
||||
|
Loading…
Reference in New Issue
Block a user