Add PYBIND11_DTYPE macro for registering dtypes

This commit is contained in:
Ivan Smirnov 2016-06-19 15:48:55 +01:00
parent fab02efb10
commit 2488b32066

View File

@ -14,6 +14,8 @@
#include <numeric>
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <initializer_list>
#if defined(_MSC_VER)
#pragma warning(push)
@ -32,6 +34,7 @@ public:
API_PyArray_FromAny = 69,
API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94,
API_PyArray_DescrConverter = 174,
API_PyArray_GetArrayParamsFromObject = 278,
NPY_C_CONTIGUOUS_ = 0x0001,
@ -63,6 +66,7 @@ public:
DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_NewCopy);
DECL_NPY_API(PyArray_NewFromDescr);
DECL_NPY_API(PyArray_DescrConverter);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
#undef DECL_NPY_API
return api;
@ -77,6 +81,7 @@ public:
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
PyTypeObject *PyArray_Type_;
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
Py_ssize_t *, PyObject **, PyObject *);
};
@ -149,6 +154,19 @@ public:
}
};
template <typename T> struct format_descriptor
<T, typename std::enable_if<std::is_pod<T>::value &&
!std::is_integral<T>::value &&
!std::is_same<T, float>::value &&
!std::is_same<T, bool>::value &&
!std::is_same<T, std::complex<float>>::value &&
!std::is_same<T, std::complex<double>>::value>::type>
{
static const char *value() {
return detail::npy_format_descriptor<T>::format_str();
}
};
NAMESPACE_BEGIN(detail)
template <typename T> struct npy_format_descriptor<T, typename std::enable_if<std::is_integral<T>::value>::type> {
@ -184,6 +202,95 @@ DECL_FMT(std::complex<float>, NPY_CFLOAT_, "complex64");
DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
#undef DECL_FMT
struct field_descriptor {
const char *name;
int offset;
PyObject *descr;
};
template <typename T> struct npy_format_descriptor
<T, typename std::enable_if<std::is_pod<T>::value && // offsetof only works correctly for POD types
!std::is_integral<T>::value &&
!std::is_same<T, float>::value &&
!std::is_same<T, bool>::value &&
!std::is_same<T, std::complex<float>>::value &&
!std::is_same<T, std::complex<double>>::value>::type>
{
static PYBIND11_DESCR name() { return _("user-defined"); }
static PyObject* descr() {
if (!descr_())
pybind11_fail("NumPy: unsupported buffer format!");
return descr_();
}
static const char* format_str() {
return format_str_();
}
static void register_dtype(std::initializer_list<field_descriptor> fields) {
array::API& api = array::lookup_api();
auto args = py::dict();
py::list names { }, offsets { }, formats { };
std::vector<py::object> dtypes;
for (auto field : fields) {
names.append(py::str(field.name));
offsets.append(py::int_(field.offset));
if (!field.descr)
pybind11_fail("NumPy: unsupported field dtype");
dtypes.emplace_back(field.descr, false);
formats.append(dtypes.back());
}
args["names"] = names;
args["offsets"] = offsets;
args["formats"] = formats;
if (!api.PyArray_DescrConverter_(args.ptr(), &descr_()) || !descr_())
pybind11_fail("NumPy: failed to create structured dtype");
auto np = module::import("numpy");
auto empty = (object) np.attr("empty");
if (auto arr = (object) empty(py::int_(0), object(descr(), true)))
if (auto view = PyMemoryView_FromObject(arr.ptr()))
if (auto info = PyMemoryView_GET_BUFFER(view)) {
std::strncpy(format_str_(), info->format, 4096);
return;
}
pybind11_fail("NumPy: failed to extract buffer format");
}
private:
static inline PyObject*& descr_() { static PyObject *ptr = nullptr; return ptr; }
static inline char* format_str_() { static char s[4096]; return s; }
};
#define FIELD_DESCRIPTOR(Type, Field) \
::pybind11::detail::field_descriptor { \
#Field, offsetof(Type, Field), \
::pybind11::detail::npy_format_descriptor<decltype(static_cast<Type*>(0)->Field)>::descr() }
// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
// (C) William Swanson, Paul Fultz
#define EVAL0(...) __VA_ARGS__
#define EVAL1(...) EVAL0 (EVAL0 (EVAL0 (__VA_ARGS__)))
#define EVAL2(...) EVAL1 (EVAL1 (EVAL1 (__VA_ARGS__)))
#define EVAL3(...) EVAL2 (EVAL2 (EVAL2 (__VA_ARGS__)))
#define EVAL4(...) EVAL3 (EVAL3 (EVAL3 (__VA_ARGS__)))
#define EVAL(...) EVAL4 (EVAL4 (EVAL4 (__VA_ARGS__)))
#define MAP_END(...)
#define MAP_OUT
#define MAP_COMMA ,
#define MAP_GET_END() 0, MAP_END
#define MAP_NEXT0(test, next, ...) next MAP_OUT
#define MAP_NEXT1(test, next) MAP_NEXT0 (test, next, 0)
#define MAP_NEXT(test, next) MAP_NEXT1 (MAP_GET_END test, next)
#define MAP_LIST_NEXT1(test, next) MAP_NEXT0 (test, MAP_COMMA next, 0)
#define MAP_LIST_NEXT(test, next) MAP_LIST_NEXT1 (MAP_GET_END test, next)
#define MAP_LIST0(f, t, x, peek, ...) f(t, x) MAP_LIST_NEXT (peek, MAP_LIST1) (f, t, peek, __VA_ARGS__)
#define MAP_LIST1(f, t, x, peek, ...) f(t, x) MAP_LIST_NEXT (peek, MAP_LIST0) (f, t, peek, __VA_ARGS__)
#define MAP_LIST(f, t, ...) EVAL (MAP_LIST1 (f, t, __VA_ARGS__, (), 0))
#define PYBIND11_DTYPE(Type, ...) \
::pybind11::detail::npy_format_descriptor<Type>::register_dtype({MAP_LIST(FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
template <class T>
using array_iterator = typename std::add_pointer<T>::type;