mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-25 06:35:12 +00:00
complex number support
This commit is contained in:
parent
d4258bafef
commit
43398a8548
@ -15,14 +15,19 @@ double my_func(int x, float y, double z) {
|
|||||||
return x*y*z;
|
return x*y*z;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::complex<double> my_func3(std::complex<double> c) {
|
||||||
|
return c * std::complex<double>(2.f);
|
||||||
|
}
|
||||||
|
|
||||||
void init_ex10(py::module &m) {
|
void init_ex10(py::module &m) {
|
||||||
// Vectorize all arguments (though non-vector arguments are also allowed)
|
// Vectorize all arguments (though non-vector arguments are also allowed)
|
||||||
m.def("vectorized_func", py::vectorize(my_func));
|
m.def("vectorized_func", py::vectorize(my_func));
|
||||||
|
|
||||||
// Vectorize a lambda function with a capture object (e.g. to exclude some arguments from the vectorization)
|
// Vectorize a lambda function with a capture object (e.g. to exclude some arguments from the vectorization)
|
||||||
m.def("vectorized_func2",
|
m.def("vectorized_func2",
|
||||||
[](py::array_dtype<int> x, py::array_dtype<float> y, float z) {
|
[](py::array_dtype<int> x, py::array_dtype<float> y, float z) {
|
||||||
return py::vectorize([z](int x, float y) { return my_func(x, y, z); })(x, y);
|
return py::vectorize([z](int x, float y) { return my_func(x, y, z); })(x, y);
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
// Vectorize all arguments (complex numbers)
|
||||||
|
m.def("vectorized_func3", py::vectorize(my_func3));
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,9 @@ import numpy as np
|
|||||||
|
|
||||||
from example import vectorized_func
|
from example import vectorized_func
|
||||||
from example import vectorized_func2
|
from example import vectorized_func2
|
||||||
|
from example import vectorized_func3
|
||||||
|
|
||||||
|
print(vectorized_func3(np.array(3+7j)))
|
||||||
|
|
||||||
for f in [vectorized_func, vectorized_func2]:
|
for f in [vectorized_func, vectorized_func2]:
|
||||||
print(f(1, 2, 3))
|
print(f(1, 2, 3))
|
||||||
|
@ -192,6 +192,23 @@ public:
|
|||||||
PYBIND_TYPE_CASTER(bool, "bool");
|
PYBIND_TYPE_CASTER(bool, "bool");
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T> class type_caster<std::complex<T>> {
|
||||||
|
public:
|
||||||
|
bool load(PyObject *src, bool) {
|
||||||
|
Py_complex result = PyComplex_AsCComplex(src);
|
||||||
|
if (result.real == -1.0 && PyErr_Occurred()) {
|
||||||
|
PyErr_Clear();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
value = std::complex<T>((T) result.real, (T) result.imag);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
static PyObject *cast(const std::complex<T> &src, return_value_policy /* policy */, PyObject * /* parent */) {
|
||||||
|
return PyComplex_FromDoubles((double) src.real(), (double) src.imag());
|
||||||
|
}
|
||||||
|
PYBIND_TYPE_CASTER(std::complex<T>, "complex");
|
||||||
|
};
|
||||||
|
|
||||||
template <> class type_caster<std::string> {
|
template <> class type_caster<std::string> {
|
||||||
public:
|
public:
|
||||||
bool load(PyObject *src, bool) {
|
bool load(PyObject *src, bool) {
|
||||||
|
@ -33,7 +33,7 @@
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <functional>
|
#include <complex>
|
||||||
|
|
||||||
/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode
|
/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode
|
||||||
#if defined(_MSC_VER)
|
#if defined(_MSC_VER)
|
||||||
@ -82,7 +82,8 @@ template <typename type> struct format_descriptor { };
|
|||||||
#define DECL_FMT(t, n) template<> struct format_descriptor<t> { static std::string value() { return n; }; };
|
#define DECL_FMT(t, n) template<> struct format_descriptor<t> { static std::string value() { return n; }; };
|
||||||
DECL_FMT(int8_t, "b"); DECL_FMT(uint8_t, "B"); DECL_FMT(int16_t, "h"); DECL_FMT(uint16_t, "H");
|
DECL_FMT(int8_t, "b"); DECL_FMT(uint8_t, "B"); DECL_FMT(int16_t, "h"); DECL_FMT(uint16_t, "H");
|
||||||
DECL_FMT(int32_t, "i"); DECL_FMT(uint32_t, "I"); DECL_FMT(int64_t, "q"); DECL_FMT(uint64_t, "Q");
|
DECL_FMT(int32_t, "i"); DECL_FMT(uint32_t, "I"); DECL_FMT(int64_t, "q"); DECL_FMT(uint64_t, "Q");
|
||||||
DECL_FMT(float , "f"); DECL_FMT(double, "d");
|
DECL_FMT(float, "f"); DECL_FMT(double, "d"); DECL_FMT(bool, "?");
|
||||||
|
DECL_FMT(std::complex<float>, "Zf"); DECL_FMT(std::complex<double>, "Zd");
|
||||||
#undef DECL_FMT
|
#undef DECL_FMT
|
||||||
|
|
||||||
/// Information record describing a Python buffer object
|
/// Information record describing a Python buffer object
|
||||||
@ -126,11 +127,12 @@ struct type_info {
|
|||||||
PyTypeObject *type;
|
PyTypeObject *type;
|
||||||
size_t type_size;
|
size_t type_size;
|
||||||
void (*init_holder)(PyObject *);
|
void (*init_holder)(PyObject *);
|
||||||
std::function<buffer_info *(PyObject *)> get_buffer;
|
|
||||||
std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
|
std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
|
||||||
|
buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
|
||||||
|
void *get_buffer_data = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Internal data struture used to track registered instances and types
|
/// Internal data struture used to track registered instances and types
|
||||||
struct internals {
|
struct internals {
|
||||||
std::unordered_map<std::string, type_info> registered_types;
|
std::unordered_map<std::string, type_info> registered_types;
|
||||||
std::unordered_map<void *, PyObject *> registered_instances;
|
std::unordered_map<void *, PyObject *> registered_instances;
|
||||||
|
@ -17,8 +17,10 @@
|
|||||||
|
|
||||||
NAMESPACE_BEGIN(pybind)
|
NAMESPACE_BEGIN(pybind)
|
||||||
|
|
||||||
|
template <typename type> struct npy_format_descriptor { };
|
||||||
|
|
||||||
class array : public buffer {
|
class array : public buffer {
|
||||||
protected:
|
public:
|
||||||
struct API {
|
struct API {
|
||||||
enum Entries {
|
enum Entries {
|
||||||
API_PyArray_Type = 2,
|
API_PyArray_Type = 2,
|
||||||
@ -26,10 +28,18 @@ protected:
|
|||||||
API_PyArray_FromAny = 69,
|
API_PyArray_FromAny = 69,
|
||||||
API_PyArray_NewCopy = 85,
|
API_PyArray_NewCopy = 85,
|
||||||
API_PyArray_NewFromDescr = 94,
|
API_PyArray_NewFromDescr = 94,
|
||||||
API_NPY_C_CONTIGUOUS = 0x0001,
|
NPY_C_CONTIGUOUS = 0x0001,
|
||||||
API_NPY_F_CONTIGUOUS = 0x0002,
|
NPY_F_CONTIGUOUS = 0x0002,
|
||||||
API_NPY_NPY_ARRAY_FORCECAST = 0x0010,
|
NPY_NPY_ARRAY_FORCECAST = 0x0010,
|
||||||
API_NPY_ENSURE_ARRAY = 0x0040
|
NPY_ENSURE_ARRAY = 0x0040,
|
||||||
|
NPY_BOOL=0,
|
||||||
|
NPY_BYTE, NPY_UBYTE,
|
||||||
|
NPY_SHORT, NPY_USHORT,
|
||||||
|
NPY_INT, NPY_UINT,
|
||||||
|
NPY_LONG, NPY_ULONG,
|
||||||
|
NPY_LONGLONG, NPY_ULONGLONG,
|
||||||
|
NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE,
|
||||||
|
NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE
|
||||||
};
|
};
|
||||||
|
|
||||||
static API lookup() {
|
static API lookup() {
|
||||||
@ -59,13 +69,12 @@ protected:
|
|||||||
PyTypeObject *PyArray_Type;
|
PyTypeObject *PyArray_Type;
|
||||||
PyObject *(*PyArray_FromAny) (PyObject *, PyObject *, int, int, int, PyObject *);
|
PyObject *(*PyArray_FromAny) (PyObject *, PyObject *, int, int, int, PyObject *);
|
||||||
};
|
};
|
||||||
public:
|
|
||||||
PYBIND_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check)
|
PYBIND_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check)
|
||||||
|
|
||||||
template <typename Type> array(size_t size, const Type *ptr) {
|
template <typename Type> array(size_t size, const Type *ptr) {
|
||||||
API& api = lookup_api();
|
API& api = lookup_api();
|
||||||
PyObject *descr = api.PyArray_DescrFromType(
|
PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor<Type>::value);
|
||||||
(int) format_descriptor<Type>::value()[0]);
|
|
||||||
if (descr == nullptr)
|
if (descr == nullptr)
|
||||||
throw std::runtime_error("NumPy: unsupported buffer format!");
|
throw std::runtime_error("NumPy: unsupported buffer format!");
|
||||||
Py_intptr_t shape = (Py_intptr_t) size;
|
Py_intptr_t shape = (Py_intptr_t) size;
|
||||||
@ -83,7 +92,12 @@ public:
|
|||||||
API& api = lookup_api();
|
API& api = lookup_api();
|
||||||
if (info.format.size() != 1)
|
if (info.format.size() != 1)
|
||||||
throw std::runtime_error("Unsupported buffer format!");
|
throw std::runtime_error("Unsupported buffer format!");
|
||||||
PyObject *descr = api.PyArray_DescrFromType(info.format[0]);
|
int fmt = (int) info.format[0];
|
||||||
|
if (info.format == "Zd")
|
||||||
|
fmt = API::NPY_CDOUBLE;
|
||||||
|
else if (info.format == "Zf")
|
||||||
|
fmt = API::NPY_CFLOAT;
|
||||||
|
PyObject *descr = api.PyArray_DescrFromType(fmt);
|
||||||
if (descr == nullptr)
|
if (descr == nullptr)
|
||||||
throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!");
|
throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!");
|
||||||
PyObject *tmp = api.PyArray_NewFromDescr(
|
PyObject *tmp = api.PyArray_NewFromDescr(
|
||||||
@ -109,12 +123,12 @@ public:
|
|||||||
PYBIND_OBJECT_CVT(array_dtype, array, is_non_null, m_ptr = ensure(m_ptr));
|
PYBIND_OBJECT_CVT(array_dtype, array, is_non_null, m_ptr = ensure(m_ptr));
|
||||||
array_dtype() : array() { }
|
array_dtype() : array() { }
|
||||||
static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
|
static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
|
||||||
static PyObject *ensure(PyObject *ptr) {
|
PyObject *ensure(PyObject *ptr) {
|
||||||
API &api = lookup_api();
|
API &api = lookup_api();
|
||||||
PyObject *descr = api.PyArray_DescrFromType(format_descriptor<T>::value()[0]);
|
PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor<T>::value);
|
||||||
return api.PyArray_FromAny(ptr, descr, 0, 0,
|
return api.PyArray_FromAny(ptr, descr, 0, 0,
|
||||||
API::API_NPY_C_CONTIGUOUS | API::API_NPY_ENSURE_ARRAY |
|
API::NPY_C_CONTIGUOUS | API::NPY_ENSURE_ARRAY |
|
||||||
API::API_NPY_NPY_ARRAY_FORCECAST, nullptr);
|
API::NPY_NPY_ARRAY_FORCECAST, nullptr);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -125,8 +139,19 @@ PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int16_t>) PYBIND_TYPE_CASTER_PYTYPE(array_
|
|||||||
PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int32_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint32_t>)
|
PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int32_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint32_t>)
|
||||||
PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int64_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint64_t>)
|
PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int64_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint64_t>)
|
||||||
PYBIND_TYPE_CASTER_PYTYPE(array_dtype<float>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<double>)
|
PYBIND_TYPE_CASTER_PYTYPE(array_dtype<float>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<double>)
|
||||||
|
PYBIND_TYPE_CASTER_PYTYPE(array_dtype<std::complex<float>>)
|
||||||
|
PYBIND_TYPE_CASTER_PYTYPE(array_dtype<std::complex<double>>)
|
||||||
|
PYBIND_TYPE_CASTER_PYTYPE(array_dtype<bool>)
|
||||||
NAMESPACE_END(detail)
|
NAMESPACE_END(detail)
|
||||||
|
|
||||||
|
#define DECL_FMT(t, n) template<> struct npy_format_descriptor<t> { enum { value = array::API::n }; }
|
||||||
|
DECL_FMT(int8_t, NPY_BYTE); DECL_FMT(uint8_t, NPY_UBYTE); DECL_FMT(int16_t, NPY_SHORT);
|
||||||
|
DECL_FMT(uint16_t, NPY_USHORT); DECL_FMT(int32_t, NPY_INT); DECL_FMT(uint32_t, NPY_UINT);
|
||||||
|
DECL_FMT(int64_t, NPY_LONGLONG); DECL_FMT(uint64_t, NPY_ULONGLONG); DECL_FMT(float, NPY_FLOAT);
|
||||||
|
DECL_FMT(double, NPY_DOUBLE); DECL_FMT(bool, NPY_BOOL); DECL_FMT(std::complex<float>, NPY_CFLOAT);
|
||||||
|
DECL_FMT(std::complex<double>, NPY_CDOUBLE);
|
||||||
|
#undef DECL_FMT
|
||||||
|
|
||||||
template <typename func_type, typename return_type, typename... args_type, size_t... Index>
|
template <typename func_type, typename return_type, typename... args_type, size_t... Index>
|
||||||
std::function<object(array_dtype<args_type>...)>
|
std::function<object(array_dtype<args_type>...)>
|
||||||
vectorize(func_type &&f, return_type (*) (args_type ...),
|
vectorize(func_type &&f, return_type (*) (args_type ...),
|
||||||
@ -171,7 +196,7 @@ template <typename func_type, typename return_type, typename... args_type, size_
|
|||||||
return cast(result[0]);
|
return cast(result[0]);
|
||||||
|
|
||||||
/* Return the result */
|
/* Return the result */
|
||||||
return array(buffer_info(result.data(), sizeof(return_type),
|
return array(buffer_info(result.data(), sizeof(return_type),
|
||||||
format_descriptor<return_type>::value(),
|
format_descriptor<return_type>::value(),
|
||||||
ndim, shape, strides));
|
ndim, shape, strides));
|
||||||
};
|
};
|
||||||
|
@ -393,22 +393,27 @@ protected:
|
|||||||
Py_TYPE(self)->tp_free((PyObject*) self);
|
Py_TYPE(self)->tp_free((PyObject*) self);
|
||||||
}
|
}
|
||||||
|
|
||||||
void install_buffer_funcs(const std::function<buffer_info *(PyObject *)> &func) {
|
void install_buffer_funcs(
|
||||||
|
buffer_info *(*get_buffer)(PyObject *, void *),
|
||||||
|
void *get_buffer_data) {
|
||||||
PyHeapTypeObject *type = (PyHeapTypeObject*) m_ptr;
|
PyHeapTypeObject *type = (PyHeapTypeObject*) m_ptr;
|
||||||
type->ht_type.tp_as_buffer = &type->as_buffer;
|
type->ht_type.tp_as_buffer = &type->as_buffer;
|
||||||
type->as_buffer.bf_getbuffer = getbuffer;
|
type->as_buffer.bf_getbuffer = getbuffer;
|
||||||
type->as_buffer.bf_releasebuffer = releasebuffer;
|
type->as_buffer.bf_releasebuffer = releasebuffer;
|
||||||
((detail::type_info *) capsule(attr("__pybind__")))->get_buffer = func;
|
auto info = ((detail::type_info *) capsule(attr("__pybind__")));
|
||||||
|
info->get_buffer = get_buffer;
|
||||||
|
info->get_buffer_data = get_buffer_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int getbuffer(PyObject *obj, Py_buffer *view, int flags) {
|
static int getbuffer(PyObject *obj, Py_buffer *view, int flags) {
|
||||||
auto const &info_func = ((detail::type_info *) capsule(handle(obj).attr("__pybind__")))->get_buffer;
|
auto const &typeinfo = ((detail::type_info *) capsule(handle(obj).attr("__pybind__")));
|
||||||
if (view == nullptr || obj == nullptr || !info_func) {
|
|
||||||
|
if (view == nullptr || obj == nullptr || !typeinfo || !typeinfo->get_buffer) {
|
||||||
PyErr_SetString(PyExc_BufferError, "Internal error");
|
PyErr_SetString(PyExc_BufferError, "Internal error");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
memset(view, 0, sizeof(Py_buffer));
|
memset(view, 0, sizeof(Py_buffer));
|
||||||
buffer_info *info = info_func(obj);
|
buffer_info *info = typeinfo->get_buffer(obj, typeinfo->get_buffer_data);
|
||||||
view->obj = obj;
|
view->obj = obj;
|
||||||
view->ndim = 1;
|
view->ndim = 1;
|
||||||
view->internal = info;
|
view->internal = info;
|
||||||
@ -483,13 +488,16 @@ public:
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
class_& def_buffer(const std::function<buffer_info(type&)> &func) {
|
template <typename Func>
|
||||||
install_buffer_funcs([func](PyObject *obj) -> buffer_info* {
|
class_& def_buffer(Func &&func) {
|
||||||
|
struct capture { Func func; };
|
||||||
|
capture *ptr = new capture { std::forward<Func>(func) };
|
||||||
|
install_buffer_funcs([](PyObject *obj, void *ptr) -> buffer_info* {
|
||||||
detail::type_caster<type> caster;
|
detail::type_caster<type> caster;
|
||||||
if (!caster.load(obj, false))
|
if (!caster.load(obj, false))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return new buffer_info(func(caster));
|
return new buffer_info(((capture *) ptr)->func(caster));
|
||||||
});
|
}, ptr);
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user