mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-26 15:12:01 +00:00
Add PYBIND11_DTYPE macro for registering dtypes
This commit is contained in:
parent
fab02efb10
commit
2488b32066
@ -14,6 +14,8 @@
|
||||
#include <numeric>
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <initializer_list>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(push)
|
||||
@ -32,6 +34,7 @@ public:
|
||||
API_PyArray_FromAny = 69,
|
||||
API_PyArray_NewCopy = 85,
|
||||
API_PyArray_NewFromDescr = 94,
|
||||
API_PyArray_DescrConverter = 174,
|
||||
API_PyArray_GetArrayParamsFromObject = 278,
|
||||
|
||||
NPY_C_CONTIGUOUS_ = 0x0001,
|
||||
@ -63,6 +66,7 @@ public:
|
||||
DECL_NPY_API(PyArray_FromAny);
|
||||
DECL_NPY_API(PyArray_NewCopy);
|
||||
DECL_NPY_API(PyArray_NewFromDescr);
|
||||
DECL_NPY_API(PyArray_DescrConverter);
|
||||
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
|
||||
#undef DECL_NPY_API
|
||||
return api;
|
||||
@ -77,6 +81,7 @@ public:
|
||||
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
|
||||
PyTypeObject *PyArray_Type_;
|
||||
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
|
||||
int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
|
||||
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
|
||||
Py_ssize_t *, PyObject **, PyObject *);
|
||||
};
|
||||
@ -149,6 +154,19 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> struct format_descriptor
|
||||
<T, typename std::enable_if<std::is_pod<T>::value &&
|
||||
!std::is_integral<T>::value &&
|
||||
!std::is_same<T, float>::value &&
|
||||
!std::is_same<T, bool>::value &&
|
||||
!std::is_same<T, std::complex<float>>::value &&
|
||||
!std::is_same<T, std::complex<double>>::value>::type>
|
||||
{
|
||||
static const char *value() {
|
||||
return detail::npy_format_descriptor<T>::format_str();
|
||||
}
|
||||
};
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <typename T> struct npy_format_descriptor<T, typename std::enable_if<std::is_integral<T>::value>::type> {
|
||||
@ -184,6 +202,95 @@ DECL_FMT(std::complex<float>, NPY_CFLOAT_, "complex64");
|
||||
DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
|
||||
#undef DECL_FMT
|
||||
|
||||
struct field_descriptor {
|
||||
const char *name;
|
||||
int offset;
|
||||
PyObject *descr;
|
||||
};
|
||||
|
||||
template <typename T> struct npy_format_descriptor
|
||||
<T, typename std::enable_if<std::is_pod<T>::value && // offsetof only works correctly for POD types
|
||||
!std::is_integral<T>::value &&
|
||||
!std::is_same<T, float>::value &&
|
||||
!std::is_same<T, bool>::value &&
|
||||
!std::is_same<T, std::complex<float>>::value &&
|
||||
!std::is_same<T, std::complex<double>>::value>::type>
|
||||
{
|
||||
static PYBIND11_DESCR name() { return _("user-defined"); }
|
||||
|
||||
static PyObject* descr() {
|
||||
if (!descr_())
|
||||
pybind11_fail("NumPy: unsupported buffer format!");
|
||||
return descr_();
|
||||
}
|
||||
|
||||
static const char* format_str() {
|
||||
return format_str_();
|
||||
}
|
||||
|
||||
static void register_dtype(std::initializer_list<field_descriptor> fields) {
|
||||
array::API& api = array::lookup_api();
|
||||
auto args = py::dict();
|
||||
py::list names { }, offsets { }, formats { };
|
||||
std::vector<py::object> dtypes;
|
||||
for (auto field : fields) {
|
||||
names.append(py::str(field.name));
|
||||
offsets.append(py::int_(field.offset));
|
||||
if (!field.descr)
|
||||
pybind11_fail("NumPy: unsupported field dtype");
|
||||
dtypes.emplace_back(field.descr, false);
|
||||
formats.append(dtypes.back());
|
||||
}
|
||||
args["names"] = names;
|
||||
args["offsets"] = offsets;
|
||||
args["formats"] = formats;
|
||||
if (!api.PyArray_DescrConverter_(args.ptr(), &descr_()) || !descr_())
|
||||
pybind11_fail("NumPy: failed to create structured dtype");
|
||||
auto np = module::import("numpy");
|
||||
auto empty = (object) np.attr("empty");
|
||||
if (auto arr = (object) empty(py::int_(0), object(descr(), true)))
|
||||
if (auto view = PyMemoryView_FromObject(arr.ptr()))
|
||||
if (auto info = PyMemoryView_GET_BUFFER(view)) {
|
||||
std::strncpy(format_str_(), info->format, 4096);
|
||||
return;
|
||||
}
|
||||
pybind11_fail("NumPy: failed to extract buffer format");
|
||||
}
|
||||
|
||||
private:
|
||||
static inline PyObject*& descr_() { static PyObject *ptr = nullptr; return ptr; }
|
||||
static inline char* format_str_() { static char s[4096]; return s; }
|
||||
};
|
||||
|
||||
#define FIELD_DESCRIPTOR(Type, Field) \
|
||||
::pybind11::detail::field_descriptor { \
|
||||
#Field, offsetof(Type, Field), \
|
||||
::pybind11::detail::npy_format_descriptor<decltype(static_cast<Type*>(0)->Field)>::descr() }
|
||||
|
||||
// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
|
||||
// (C) William Swanson, Paul Fultz
|
||||
#define EVAL0(...) __VA_ARGS__
|
||||
#define EVAL1(...) EVAL0 (EVAL0 (EVAL0 (__VA_ARGS__)))
|
||||
#define EVAL2(...) EVAL1 (EVAL1 (EVAL1 (__VA_ARGS__)))
|
||||
#define EVAL3(...) EVAL2 (EVAL2 (EVAL2 (__VA_ARGS__)))
|
||||
#define EVAL4(...) EVAL3 (EVAL3 (EVAL3 (__VA_ARGS__)))
|
||||
#define EVAL(...) EVAL4 (EVAL4 (EVAL4 (__VA_ARGS__)))
|
||||
#define MAP_END(...)
|
||||
#define MAP_OUT
|
||||
#define MAP_COMMA ,
|
||||
#define MAP_GET_END() 0, MAP_END
|
||||
#define MAP_NEXT0(test, next, ...) next MAP_OUT
|
||||
#define MAP_NEXT1(test, next) MAP_NEXT0 (test, next, 0)
|
||||
#define MAP_NEXT(test, next) MAP_NEXT1 (MAP_GET_END test, next)
|
||||
#define MAP_LIST_NEXT1(test, next) MAP_NEXT0 (test, MAP_COMMA next, 0)
|
||||
#define MAP_LIST_NEXT(test, next) MAP_LIST_NEXT1 (MAP_GET_END test, next)
|
||||
#define MAP_LIST0(f, t, x, peek, ...) f(t, x) MAP_LIST_NEXT (peek, MAP_LIST1) (f, t, peek, __VA_ARGS__)
|
||||
#define MAP_LIST1(f, t, x, peek, ...) f(t, x) MAP_LIST_NEXT (peek, MAP_LIST0) (f, t, peek, __VA_ARGS__)
|
||||
#define MAP_LIST(f, t, ...) EVAL (MAP_LIST1 (f, t, __VA_ARGS__, (), 0))
|
||||
|
||||
#define PYBIND11_DTYPE(Type, ...) \
|
||||
::pybind11::detail::npy_format_descriptor<Type>::register_dtype({MAP_LIST(FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
|
||||
|
||||
template <class T>
|
||||
using array_iterator = typename std::add_pointer<T>::type;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user