Add a direct converter for numpy scalars

This commit is contained in:
Ivan Smirnov 2016-10-20 16:11:08 +01:00
parent c275ee6b46
commit 7bf90e8008

View File

@ -63,6 +63,14 @@ struct PyArray_Proxy {
int flags;
};
struct PyVoidScalarObject_Proxy {
PyObject_VAR_HEAD
char *obval;
PyArrayDescr_Proxy *descr;
int flags;
PyObject *base;
};
struct npy_api {
enum constants {
NPY_C_CONTIGUOUS_ = 0x0001,
@ -702,11 +710,29 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1));
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
pybind11_fail("NumPy: invalid buffer descriptor!");
register_direct_converter();
}
private:
static std::string format_str;
static PyObject* dtype_ptr;
static void register_direct_converter() {
auto converter = [=](PyObject *obj, void*& value) {
auto& api = npy_api::get();
if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
return false;
if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) {
if (api.PyArray_EquivTypes_(dtype_ptr, descr.ptr())) {
value = ((PyVoidScalarObject_Proxy *) obj)->obval;
return true;
}
}
return false;
};
get_internals().direct_conversions[std::type_index(typeid(T))].push_back(converter);
}
};
template <typename T>