mirror of
https://github.com/pybind/pybind11.git
synced 2024-12-01 17:37:15 +00:00
numpy.h replace macros with functions (#514)
This commit is contained in:
parent
7146d6299c
commit
b14f065fa9
@ -199,16 +199,28 @@ private:
|
|||||||
return api;
|
return api;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
NAMESPACE_END(detail)
|
|
||||||
|
|
||||||
#define PyArray_GET_(ptr, attr) \
|
inline PyArray_Proxy* array_proxy(void* ptr) {
|
||||||
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
|
return reinterpret_cast<PyArray_Proxy*>(ptr);
|
||||||
#define PyArrayDescr_GET_(ptr, attr) \
|
}
|
||||||
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
|
|
||||||
#define PyArray_FLAGS_(ptr) \
|
inline const PyArray_Proxy* array_proxy(const void* ptr) {
|
||||||
PyArray_GET_(ptr, flags)
|
return reinterpret_cast<const PyArray_Proxy*>(ptr);
|
||||||
#define PyArray_CHKFLAGS_(ptr, flag) \
|
}
|
||||||
(flag == (PyArray_FLAGS_(ptr) & flag))
|
|
||||||
|
inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) {
|
||||||
|
return reinterpret_cast<PyArrayDescr_Proxy*>(ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) {
|
||||||
|
return reinterpret_cast<const PyArrayDescr_Proxy*>(ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool check_flags(const void* ptr, int flag) {
|
||||||
|
return (flag == (array_proxy(ptr)->flags & flag));
|
||||||
|
}
|
||||||
|
|
||||||
|
NAMESPACE_END(detail)
|
||||||
|
|
||||||
class dtype : public object {
|
class dtype : public object {
|
||||||
public:
|
public:
|
||||||
@ -249,17 +261,17 @@ public:
|
|||||||
|
|
||||||
/// Size of the data type in bytes.
|
/// Size of the data type in bytes.
|
||||||
size_t itemsize() const {
|
size_t itemsize() const {
|
||||||
return (size_t) PyArrayDescr_GET_(m_ptr, elsize);
|
return (size_t) detail::array_descriptor_proxy(m_ptr)->elsize;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns true for structured data types.
|
/// Returns true for structured data types.
|
||||||
bool has_fields() const {
|
bool has_fields() const {
|
||||||
return PyArrayDescr_GET_(m_ptr, names) != nullptr;
|
return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Single-character type code.
|
/// Single-character type code.
|
||||||
char kind() const {
|
char kind() const {
|
||||||
return PyArrayDescr_GET_(m_ptr, kind);
|
return detail::array_descriptor_proxy(m_ptr)->kind;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -341,7 +353,7 @@ public:
|
|||||||
pybind11_fail("NumPy: unable to create array!");
|
pybind11_fail("NumPy: unable to create array!");
|
||||||
if (ptr) {
|
if (ptr) {
|
||||||
if (base) {
|
if (base) {
|
||||||
PyArray_GET_(tmp.ptr(), base) = base.inc_ref().ptr();
|
detail::array_proxy(tmp.ptr())->base = base.inc_ref().ptr();
|
||||||
} else {
|
} else {
|
||||||
tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
|
tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
|
||||||
}
|
}
|
||||||
@ -376,7 +388,7 @@ public:
|
|||||||
|
|
||||||
/// Array descriptor (dtype)
|
/// Array descriptor (dtype)
|
||||||
pybind11::dtype dtype() const {
|
pybind11::dtype dtype() const {
|
||||||
return reinterpret_borrow<pybind11::dtype>(PyArray_GET_(m_ptr, descr));
|
return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Total number of elements
|
/// Total number of elements
|
||||||
@ -386,7 +398,7 @@ public:
|
|||||||
|
|
||||||
/// Byte size of a single element
|
/// Byte size of a single element
|
||||||
size_t itemsize() const {
|
size_t itemsize() const {
|
||||||
return (size_t) PyArrayDescr_GET_(PyArray_GET_(m_ptr, descr), elsize);
|
return (size_t) detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Total number of bytes
|
/// Total number of bytes
|
||||||
@ -396,17 +408,17 @@ public:
|
|||||||
|
|
||||||
/// Number of dimensions
|
/// Number of dimensions
|
||||||
size_t ndim() const {
|
size_t ndim() const {
|
||||||
return (size_t) PyArray_GET_(m_ptr, nd);
|
return (size_t) detail::array_proxy(m_ptr)->nd;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Base object
|
/// Base object
|
||||||
object base() const {
|
object base() const {
|
||||||
return reinterpret_borrow<object>(PyArray_GET_(m_ptr, base));
|
return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Dimensions of the array
|
/// Dimensions of the array
|
||||||
const size_t* shape() const {
|
const size_t* shape() const {
|
||||||
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, dimensions));
|
return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Dimension along a given axis
|
/// Dimension along a given axis
|
||||||
@ -418,7 +430,7 @@ public:
|
|||||||
|
|
||||||
/// Strides of the array
|
/// Strides of the array
|
||||||
const size_t* strides() const {
|
const size_t* strides() const {
|
||||||
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, strides));
|
return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stride along a given axis
|
/// Stride along a given axis
|
||||||
@ -430,23 +442,23 @@ public:
|
|||||||
|
|
||||||
/// Return the NumPy array flags
|
/// Return the NumPy array flags
|
||||||
int flags() const {
|
int flags() const {
|
||||||
return PyArray_FLAGS_(m_ptr);
|
return detail::array_proxy(m_ptr)->flags;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If set, the array is writeable (otherwise the buffer is read-only)
|
/// If set, the array is writeable (otherwise the buffer is read-only)
|
||||||
bool writeable() const {
|
bool writeable() const {
|
||||||
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
|
return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If set, the array owns the data (will be freed when the array is deleted)
|
/// If set, the array owns the data (will be freed when the array is deleted)
|
||||||
bool owndata() const {
|
bool owndata() const {
|
||||||
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
|
return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Pointer to the contained data. If index is not provided, points to the
|
/// Pointer to the contained data. If index is not provided, points to the
|
||||||
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
|
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
|
||||||
template<typename... Ix> const void* data(Ix... index) const {
|
template<typename... Ix> const void* data(Ix... index) const {
|
||||||
return static_cast<const void *>(PyArray_GET_(m_ptr, data) + offset_at(index...));
|
return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mutable pointer to the contained data. If index is not provided, points to the
|
/// Mutable pointer to the contained data. If index is not provided, points to the
|
||||||
@ -454,7 +466,7 @@ public:
|
|||||||
/// May throw if the array is not writeable.
|
/// May throw if the array is not writeable.
|
||||||
template<typename... Ix> void* mutable_data(Ix... index) {
|
template<typename... Ix> void* mutable_data(Ix... index) {
|
||||||
check_writeable();
|
check_writeable();
|
||||||
return static_cast<void *>(PyArray_GET_(m_ptr, data) + offset_at(index...));
|
return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Byte offset from beginning of the array to a given index (full or partial).
|
/// Byte offset from beginning of the array to a given index (full or partial).
|
||||||
@ -620,7 +632,7 @@ public:
|
|||||||
static bool _check(handle h) {
|
static bool _check(handle h) {
|
||||||
const auto &api = detail::npy_api::get();
|
const auto &api = detail::npy_api::get();
|
||||||
return api.PyArray_Check_(h.ptr())
|
return api.PyArray_Check_(h.ptr())
|
||||||
&& api.PyArray_EquivTypes_(PyArray_GET_(h.ptr(), descr), dtype::of<T>().ptr());
|
&& api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
Loading…
Reference in New Issue
Block a user