Better NumPy support

This commit is contained in:
Wenzel Jakob 2015-07-22 00:59:01 +02:00
parent bd4a529319
commit 2ac80e77aa
4 changed files with 103 additions and 1 deletions

View File

@ -17,7 +17,7 @@ become an excessively large and unnecessary dependency.
Think of this library as a tiny self-contained version of Boost.Python with
everything stripped away that isn't relevant for binding generation. The whole
codebase requires less than 2000 lines of code and just depends on Python and
codebase requires just over 2000 lines of code and just depends on Python and
the C++ standard library. This compact implementation was possible thanks to
some of the new C++11 language features (tuples, lambda functions and variadic
templates), and by only targeting Python 3.x and higher.

View File

@ -206,6 +206,22 @@ public:
TYPE_CASTER(std::string, "str");
};
#ifdef HAVE_WCHAR_H
template <> class type_caster<std::wstring> {
public:
bool load(PyObject *src, bool) {
const wchar_t *ptr = PyUnicode_AsWideCharString(src, nullptr);
if (!ptr) { PyErr_Clear(); return false; }
value = std::wstring(ptr);
return true;
}
static PyObject *cast(const std::wstring &src, return_value_policy /* policy */, PyObject * /* parent */) {
return PyUnicode_FromWideChar(src.c_str(), src.length());
}
TYPE_CASTER(std::wstring, "wstr");
};
#endif
template <> class type_caster<char> {
public:
bool load(PyObject *src, bool) {
@ -474,6 +490,7 @@ TYPE_CASTER_PYTYPE(list)
TYPE_CASTER_PYTYPE(slice)
TYPE_CASTER_PYTYPE(tuple)
TYPE_CASTER_PYTYPE(function)
TYPE_CASTER_PYTYPE(array)
#undef TYPE_CASTER
#undef TYPE_CASTER_PYTYPE

View File

@ -132,7 +132,10 @@ private:
entry = backup;
}
std::string signatures;
int it = 0;
while (entry) { /* Create pydoc entry */
if (sibling.ptr())
signatures += std::to_string(++it) + ". ";
signatures += "Signature : " + std::string(entry->signature) + "\n";
if (!entry->doc.empty())
signatures += "\n" + std::string(entry->doc) + "\n";

View File

@ -322,6 +322,88 @@ private:
Py_buffer *view = nullptr;
};
class array : public buffer {
protected:
struct API {
enum Entries {
API_PyArray_Type = 2,
API_PyArray_DescrFromType = 45,
API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94
};
static API lookup() {
PyObject *numpy = PyImport_ImportModule("numpy.core.multiarray");
PyObject *capsule = numpy ? PyObject_GetAttrString(numpy, "_ARRAY_API") : nullptr;
void **api_ptr = (void **) (capsule ? PyCapsule_GetPointer(capsule, NULL) : nullptr);
Py_XDECREF(capsule);
Py_XDECREF(numpy);
if (api_ptr == nullptr)
throw std::runtime_error("Could not acquire pointer to NumPy API!");
API api;
api.PyArray_DescrFromType = (decltype(api.PyArray_DescrFromType)) api_ptr[API_PyArray_DescrFromType];
api.PyArray_NewFromDescr = (decltype(api.PyArray_NewFromDescr)) api_ptr[API_PyArray_NewFromDescr];
api.PyArray_NewCopy = (decltype(api.PyArray_NewCopy)) api_ptr[API_PyArray_NewCopy];
api.PyArray_Type = (decltype(api.PyArray_Type)) api_ptr[API_PyArray_Type];
return api;
}
bool PyArray_Check(PyObject *obj) const {
return (bool) PyObject_TypeCheck(obj, PyArray_Type);
}
PyObject *(*PyArray_DescrFromType)(int);
PyObject *(*PyArray_NewFromDescr)
(PyTypeObject *, PyObject *, int, Py_intptr_t *,
Py_intptr_t *, void *, int, PyObject *);
PyObject *(*PyArray_NewCopy)(PyObject *, int);
PyTypeObject *PyArray_Type;
};
public:
PYTHON_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check)
template <typename Type> array(size_t size, const Type *ptr) {
API& api = lookup_api();
PyObject *descr = api.PyArray_DescrFromType(
(int) format_descriptor<Type>::value()[0]);
if (descr == nullptr)
throw std::runtime_error("NumPy: unsupported buffer format!");
Py_intptr_t shape = (Py_intptr_t) size;
PyObject *tmp = api.PyArray_NewFromDescr(
api.PyArray_Type, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr);
if (tmp == nullptr)
throw std::runtime_error("NumPy: unable to create array!");
m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */);
Py_DECREF(tmp);
if (m_ptr == nullptr)
throw std::runtime_error("NumPy: unable to copy array!");
}
array(const buffer_info &info) {
API& api = lookup_api();
if (info.format.size() != 1)
throw std::runtime_error("Unsupported buffer format!");
PyObject *descr = api.PyArray_DescrFromType(info.format[0]);
if (descr == nullptr)
throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!");
PyObject *tmp = api.PyArray_NewFromDescr(
api.PyArray_Type, descr, info.ndim, (Py_intptr_t *) &info.shape[0],
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr);
if (tmp == nullptr)
throw std::runtime_error("NumPy: unable to create array!");
m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */);
Py_DECREF(tmp);
if (m_ptr == nullptr)
throw std::runtime_error("NumPy: unable to copy array!");
}
protected:
static API &lookup_api() {
static API api = API::lookup();
return api;
}
};
NAMESPACE_BEGIN(detail)
inline internals &get_internals() {
static internals *internals_ptr = nullptr;