From 7bf90e8008fcaba6b03dfbc999610928197174fa Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Thu, 20 Oct 2016 16:11:08 +0100 Subject: [PATCH] Add a direct converter for numpy scalars --- include/pybind11/numpy.h | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index aa93c5500..dba8b7a2b 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -63,6 +63,14 @@ struct PyArray_Proxy { int flags; }; +struct PyVoidScalarObject_Proxy { + PyObject_VAR_HEAD + char *obval; + PyArrayDescr_Proxy *descr; + int flags; + PyObject *base; +}; + struct npy_api { enum constants { NPY_C_CONTIGUOUS_ = 0x0001, @@ -702,11 +710,29 @@ struct npy_format_descriptor::value>> { auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1)); if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) pybind11_fail("NumPy: invalid buffer descriptor!"); + + register_direct_converter(); } private: static std::string format_str; static PyObject* dtype_ptr; + + static void register_direct_converter() { + auto converter = [=](PyObject *obj, void*& value) { + auto& api = npy_api::get(); + if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_)) + return false; + if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) { + if (api.PyArray_EquivTypes_(dtype_ptr, descr.ptr())) { + value = ((PyVoidScalarObject_Proxy *) obj)->obval; + return true; + } + } + return false; + }; + get_internals().direct_conversions[std::type_index(typeid(T))].push_back(converter); + } }; template