From 43398a854859ba9df3bc7cd1d7c135a18e066d75 Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Tue, 28 Jul 2015 16:12:20 +0200 Subject: [PATCH] complex number support --- example/example10.cpp | 7 +++++- example/example10.py | 3 +++ include/pybind/cast.h | 17 +++++++++++++ include/pybind/common.h | 10 ++++---- include/pybind/numpy.h | 53 ++++++++++++++++++++++++++++++----------- include/pybind/pybind.h | 26 +++++++++++++------- 6 files changed, 88 insertions(+), 28 deletions(-) diff --git a/example/example10.cpp b/example/example10.cpp index b53c053ed..360066ae1 100644 --- a/example/example10.cpp +++ b/example/example10.cpp @@ -15,14 +15,19 @@ double my_func(int x, float y, double z) { return x*y*z; } +std::complex my_func3(std::complex c) { + return c * std::complex(2.f); +} + void init_ex10(py::module &m) { // Vectorize all arguments (though non-vector arguments are also allowed) m.def("vectorized_func", py::vectorize(my_func)); - // Vectorize a lambda function with a capture object (e.g. to exclude some arguments from the vectorization) m.def("vectorized_func2", [](py::array_dtype x, py::array_dtype y, float z) { return py::vectorize([z](int x, float y) { return my_func(x, y, z); })(x, y); } ); + // Vectorize all arguments (complex numbers) + m.def("vectorized_func3", py::vectorize(my_func3)); } diff --git a/example/example10.py b/example/example10.py index 337292670..401c5ccc7 100644 --- a/example/example10.py +++ b/example/example10.py @@ -7,6 +7,9 @@ import numpy as np from example import vectorized_func from example import vectorized_func2 +from example import vectorized_func3 + +print(vectorized_func3(np.array(3+7j))) for f in [vectorized_func, vectorized_func2]: print(f(1, 2, 3)) diff --git a/include/pybind/cast.h b/include/pybind/cast.h index 6c26f41ba..9da3fcaac 100644 --- a/include/pybind/cast.h +++ b/include/pybind/cast.h @@ -192,6 +192,23 @@ public: PYBIND_TYPE_CASTER(bool, "bool"); }; +template class type_caster> { +public: + bool load(PyObject *src, bool) { + Py_complex result = PyComplex_AsCComplex(src); + if (result.real == -1.0 && PyErr_Occurred()) { + PyErr_Clear(); + return false; + } + value = std::complex((T) result.real, (T) result.imag); + return true; + } + static PyObject *cast(const std::complex &src, return_value_policy /* policy */, PyObject * /* parent */) { + return PyComplex_FromDoubles((double) src.real(), (double) src.imag()); + } + PYBIND_TYPE_CASTER(std::complex, "complex"); +}; + template <> class type_caster { public: bool load(PyObject *src, bool) { diff --git a/include/pybind/common.h b/include/pybind/common.h index d1a6b47e7..e4a7d34ac 100644 --- a/include/pybind/common.h +++ b/include/pybind/common.h @@ -33,7 +33,7 @@ #include #include #include -#include +#include /// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode #if defined(_MSC_VER) @@ -82,7 +82,8 @@ template struct format_descriptor { }; #define DECL_FMT(t, n) template<> struct format_descriptor { static std::string value() { return n; }; }; DECL_FMT(int8_t, "b"); DECL_FMT(uint8_t, "B"); DECL_FMT(int16_t, "h"); DECL_FMT(uint16_t, "H"); DECL_FMT(int32_t, "i"); DECL_FMT(uint32_t, "I"); DECL_FMT(int64_t, "q"); DECL_FMT(uint64_t, "Q"); -DECL_FMT(float , "f"); DECL_FMT(double, "d"); +DECL_FMT(float, "f"); DECL_FMT(double, "d"); DECL_FMT(bool, "?"); +DECL_FMT(std::complex, "Zf"); DECL_FMT(std::complex, "Zd"); #undef DECL_FMT /// Information record describing a Python buffer object @@ -126,11 +127,12 @@ struct type_info { PyTypeObject *type; size_t type_size; void (*init_holder)(PyObject *); - std::function get_buffer; std::vector implicit_conversions; + buffer_info *(*get_buffer)(PyObject *, void *) = nullptr; + void *get_buffer_data = nullptr; }; -/// Internal data struture used to track registered instances and types +/// Internal data struture used to track registered instances and types struct internals { std::unordered_map registered_types; std::unordered_map registered_instances; diff --git a/include/pybind/numpy.h b/include/pybind/numpy.h index f4a4a74e7..033679470 100644 --- a/include/pybind/numpy.h +++ b/include/pybind/numpy.h @@ -17,8 +17,10 @@ NAMESPACE_BEGIN(pybind) +template struct npy_format_descriptor { }; + class array : public buffer { -protected: +public: struct API { enum Entries { API_PyArray_Type = 2, @@ -26,10 +28,18 @@ protected: API_PyArray_FromAny = 69, API_PyArray_NewCopy = 85, API_PyArray_NewFromDescr = 94, - API_NPY_C_CONTIGUOUS = 0x0001, - API_NPY_F_CONTIGUOUS = 0x0002, - API_NPY_NPY_ARRAY_FORCECAST = 0x0010, - API_NPY_ENSURE_ARRAY = 0x0040 + NPY_C_CONTIGUOUS = 0x0001, + NPY_F_CONTIGUOUS = 0x0002, + NPY_NPY_ARRAY_FORCECAST = 0x0010, + NPY_ENSURE_ARRAY = 0x0040, + NPY_BOOL=0, + NPY_BYTE, NPY_UBYTE, + NPY_SHORT, NPY_USHORT, + NPY_INT, NPY_UINT, + NPY_LONG, NPY_ULONG, + NPY_LONGLONG, NPY_ULONGLONG, + NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE, + NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE }; static API lookup() { @@ -59,13 +69,12 @@ protected: PyTypeObject *PyArray_Type; PyObject *(*PyArray_FromAny) (PyObject *, PyObject *, int, int, int, PyObject *); }; -public: + PYBIND_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check) template array(size_t size, const Type *ptr) { API& api = lookup_api(); - PyObject *descr = api.PyArray_DescrFromType( - (int) format_descriptor::value()[0]); + PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor::value); if (descr == nullptr) throw std::runtime_error("NumPy: unsupported buffer format!"); Py_intptr_t shape = (Py_intptr_t) size; @@ -83,7 +92,12 @@ public: API& api = lookup_api(); if (info.format.size() != 1) throw std::runtime_error("Unsupported buffer format!"); - PyObject *descr = api.PyArray_DescrFromType(info.format[0]); + int fmt = (int) info.format[0]; + if (info.format == "Zd") + fmt = API::NPY_CDOUBLE; + else if (info.format == "Zf") + fmt = API::NPY_CFLOAT; + PyObject *descr = api.PyArray_DescrFromType(fmt); if (descr == nullptr) throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!"); PyObject *tmp = api.PyArray_NewFromDescr( @@ -109,12 +123,12 @@ public: PYBIND_OBJECT_CVT(array_dtype, array, is_non_null, m_ptr = ensure(m_ptr)); array_dtype() : array() { } static bool is_non_null(PyObject *ptr) { return ptr != nullptr; } - static PyObject *ensure(PyObject *ptr) { + PyObject *ensure(PyObject *ptr) { API &api = lookup_api(); - PyObject *descr = api.PyArray_DescrFromType(format_descriptor::value()[0]); + PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor::value); return api.PyArray_FromAny(ptr, descr, 0, 0, - API::API_NPY_C_CONTIGUOUS | API::API_NPY_ENSURE_ARRAY | - API::API_NPY_NPY_ARRAY_FORCECAST, nullptr); + API::NPY_C_CONTIGUOUS | API::NPY_ENSURE_ARRAY | + API::NPY_NPY_ARRAY_FORCECAST, nullptr); } }; @@ -125,8 +139,19 @@ PYBIND_TYPE_CASTER_PYTYPE(array_dtype) PYBIND_TYPE_CASTER_PYTYPE(array_ PYBIND_TYPE_CASTER_PYTYPE(array_dtype) PYBIND_TYPE_CASTER_PYTYPE(array_dtype) PYBIND_TYPE_CASTER_PYTYPE(array_dtype) PYBIND_TYPE_CASTER_PYTYPE(array_dtype) PYBIND_TYPE_CASTER_PYTYPE(array_dtype) PYBIND_TYPE_CASTER_PYTYPE(array_dtype) +PYBIND_TYPE_CASTER_PYTYPE(array_dtype>) +PYBIND_TYPE_CASTER_PYTYPE(array_dtype>) +PYBIND_TYPE_CASTER_PYTYPE(array_dtype) NAMESPACE_END(detail) +#define DECL_FMT(t, n) template<> struct npy_format_descriptor { enum { value = array::API::n }; } +DECL_FMT(int8_t, NPY_BYTE); DECL_FMT(uint8_t, NPY_UBYTE); DECL_FMT(int16_t, NPY_SHORT); +DECL_FMT(uint16_t, NPY_USHORT); DECL_FMT(int32_t, NPY_INT); DECL_FMT(uint32_t, NPY_UINT); +DECL_FMT(int64_t, NPY_LONGLONG); DECL_FMT(uint64_t, NPY_ULONGLONG); DECL_FMT(float, NPY_FLOAT); +DECL_FMT(double, NPY_DOUBLE); DECL_FMT(bool, NPY_BOOL); DECL_FMT(std::complex, NPY_CFLOAT); +DECL_FMT(std::complex, NPY_CDOUBLE); +#undef DECL_FMT + template std::function...)> vectorize(func_type &&f, return_type (*) (args_type ...), @@ -171,7 +196,7 @@ template ::value(), ndim, shape, strides)); }; diff --git a/include/pybind/pybind.h b/include/pybind/pybind.h index e8c4ff5d7..570d1939c 100644 --- a/include/pybind/pybind.h +++ b/include/pybind/pybind.h @@ -393,22 +393,27 @@ protected: Py_TYPE(self)->tp_free((PyObject*) self); } - void install_buffer_funcs(const std::function &func) { + void install_buffer_funcs( + buffer_info *(*get_buffer)(PyObject *, void *), + void *get_buffer_data) { PyHeapTypeObject *type = (PyHeapTypeObject*) m_ptr; type->ht_type.tp_as_buffer = &type->as_buffer; type->as_buffer.bf_getbuffer = getbuffer; type->as_buffer.bf_releasebuffer = releasebuffer; - ((detail::type_info *) capsule(attr("__pybind__")))->get_buffer = func; + auto info = ((detail::type_info *) capsule(attr("__pybind__"))); + info->get_buffer = get_buffer; + info->get_buffer_data = get_buffer_data; } static int getbuffer(PyObject *obj, Py_buffer *view, int flags) { - auto const &info_func = ((detail::type_info *) capsule(handle(obj).attr("__pybind__")))->get_buffer; - if (view == nullptr || obj == nullptr || !info_func) { + auto const &typeinfo = ((detail::type_info *) capsule(handle(obj).attr("__pybind__"))); + + if (view == nullptr || obj == nullptr || !typeinfo || !typeinfo->get_buffer) { PyErr_SetString(PyExc_BufferError, "Internal error"); return -1; } memset(view, 0, sizeof(Py_buffer)); - buffer_info *info = info_func(obj); + buffer_info *info = typeinfo->get_buffer(obj, typeinfo->get_buffer_data); view->obj = obj; view->ndim = 1; view->internal = info; @@ -483,13 +488,16 @@ public: return *this; } - class_& def_buffer(const std::function &func) { - install_buffer_funcs([func](PyObject *obj) -> buffer_info* { + template + class_& def_buffer(Func &&func) { + struct capture { Func func; }; + capture *ptr = new capture { std::forward(func) }; + install_buffer_funcs([](PyObject *obj, void *ptr) -> buffer_info* { detail::type_caster caster; if (!caster.load(obj, false)) return nullptr; - return new buffer_info(func(caster)); - }); + return new buffer_info(((capture *) ptr)->func(caster)); + }, ptr); return *this; }