From 2488b32066e6a7199719bd1f056037bbffc39b52 Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Sun, 19 Jun 2016 15:48:55 +0100 Subject: [PATCH] Add PYBIND11_DTYPE macro for registering dtypes --- include/pybind11/numpy.h | 107 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 49489ceb1..186b82f95 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -14,6 +14,8 @@ #include #include #include +#include +#include #if defined(_MSC_VER) #pragma warning(push) @@ -32,6 +34,7 @@ public: API_PyArray_FromAny = 69, API_PyArray_NewCopy = 85, API_PyArray_NewFromDescr = 94, + API_PyArray_DescrConverter = 174, API_PyArray_GetArrayParamsFromObject = 278, NPY_C_CONTIGUOUS_ = 0x0001, @@ -63,6 +66,7 @@ public: DECL_NPY_API(PyArray_FromAny); DECL_NPY_API(PyArray_NewCopy); DECL_NPY_API(PyArray_NewFromDescr); + DECL_NPY_API(PyArray_DescrConverter); DECL_NPY_API(PyArray_GetArrayParamsFromObject); #undef DECL_NPY_API return api; @@ -77,6 +81,7 @@ public: PyObject *(*PyArray_NewCopy_)(PyObject *, int); PyTypeObject *PyArray_Type_; PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *); + int (*PyArray_DescrConverter_) (PyObject *, PyObject **); int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *, Py_ssize_t *, PyObject **, PyObject *); }; @@ -149,6 +154,19 @@ public: } }; +template struct format_descriptor +::value && + !std::is_integral::value && + !std::is_same::value && + !std::is_same::value && + !std::is_same>::value && + !std::is_same>::value>::type> +{ + static const char *value() { + return detail::npy_format_descriptor::format_str(); + } +}; + NAMESPACE_BEGIN(detail) template struct npy_format_descriptor::value>::type> { @@ -184,6 +202,95 @@ DECL_FMT(std::complex, NPY_CFLOAT_, "complex64"); DECL_FMT(std::complex, NPY_CDOUBLE_, "complex128"); #undef DECL_FMT +struct field_descriptor { + const char *name; + int offset; + PyObject *descr; +}; + +template struct npy_format_descriptor +::value && // offsetof only works correctly for POD types + !std::is_integral::value && + !std::is_same::value && + !std::is_same::value && + !std::is_same>::value && + !std::is_same>::value>::type> +{ + static PYBIND11_DESCR name() { return _("user-defined"); } + + static PyObject* descr() { + if (!descr_()) + pybind11_fail("NumPy: unsupported buffer format!"); + return descr_(); + } + + static const char* format_str() { + return format_str_(); + } + + static void register_dtype(std::initializer_list fields) { + array::API& api = array::lookup_api(); + auto args = py::dict(); + py::list names { }, offsets { }, formats { }; + std::vector dtypes; + for (auto field : fields) { + names.append(py::str(field.name)); + offsets.append(py::int_(field.offset)); + if (!field.descr) + pybind11_fail("NumPy: unsupported field dtype"); + dtypes.emplace_back(field.descr, false); + formats.append(dtypes.back()); + } + args["names"] = names; + args["offsets"] = offsets; + args["formats"] = formats; + if (!api.PyArray_DescrConverter_(args.ptr(), &descr_()) || !descr_()) + pybind11_fail("NumPy: failed to create structured dtype"); + auto np = module::import("numpy"); + auto empty = (object) np.attr("empty"); + if (auto arr = (object) empty(py::int_(0), object(descr(), true))) + if (auto view = PyMemoryView_FromObject(arr.ptr())) + if (auto info = PyMemoryView_GET_BUFFER(view)) { + std::strncpy(format_str_(), info->format, 4096); + return; + } + pybind11_fail("NumPy: failed to extract buffer format"); + } + +private: + static inline PyObject*& descr_() { static PyObject *ptr = nullptr; return ptr; } + static inline char* format_str_() { static char s[4096]; return s; } +}; + +#define FIELD_DESCRIPTOR(Type, Field) \ + ::pybind11::detail::field_descriptor { \ + #Field, offsetof(Type, Field), \ + ::pybind11::detail::npy_format_descriptor(0)->Field)>::descr() } + +// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro +// (C) William Swanson, Paul Fultz +#define EVAL0(...) __VA_ARGS__ +#define EVAL1(...) EVAL0 (EVAL0 (EVAL0 (__VA_ARGS__))) +#define EVAL2(...) EVAL1 (EVAL1 (EVAL1 (__VA_ARGS__))) +#define EVAL3(...) EVAL2 (EVAL2 (EVAL2 (__VA_ARGS__))) +#define EVAL4(...) EVAL3 (EVAL3 (EVAL3 (__VA_ARGS__))) +#define EVAL(...) EVAL4 (EVAL4 (EVAL4 (__VA_ARGS__))) +#define MAP_END(...) +#define MAP_OUT +#define MAP_COMMA , +#define MAP_GET_END() 0, MAP_END +#define MAP_NEXT0(test, next, ...) next MAP_OUT +#define MAP_NEXT1(test, next) MAP_NEXT0 (test, next, 0) +#define MAP_NEXT(test, next) MAP_NEXT1 (MAP_GET_END test, next) +#define MAP_LIST_NEXT1(test, next) MAP_NEXT0 (test, MAP_COMMA next, 0) +#define MAP_LIST_NEXT(test, next) MAP_LIST_NEXT1 (MAP_GET_END test, next) +#define MAP_LIST0(f, t, x, peek, ...) f(t, x) MAP_LIST_NEXT (peek, MAP_LIST1) (f, t, peek, __VA_ARGS__) +#define MAP_LIST1(f, t, x, peek, ...) f(t, x) MAP_LIST_NEXT (peek, MAP_LIST0) (f, t, peek, __VA_ARGS__) +#define MAP_LIST(f, t, ...) EVAL (MAP_LIST1 (f, t, __VA_ARGS__, (), 0)) + +#define PYBIND11_DTYPE(Type, ...) \ + ::pybind11::detail::npy_format_descriptor::register_dtype({MAP_LIST(FIELD_DESCRIPTOR, Type, __VA_ARGS__)}) + template using array_iterator = typename std::add_pointer::type;