mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-24 22:25:10 +00:00
Better NumPy support
This commit is contained in:
parent
bd4a529319
commit
2ac80e77aa
@ -17,7 +17,7 @@ become an excessively large and unnecessary dependency.
|
||||
|
||||
Think of this library as a tiny self-contained version of Boost.Python with
|
||||
everything stripped away that isn't relevant for binding generation. The whole
|
||||
codebase requires less than 2000 lines of code and just depends on Python and
|
||||
codebase requires just over 2000 lines of code and just depends on Python and
|
||||
the C++ standard library. This compact implementation was possible thanks to
|
||||
some of the new C++11 language features (tuples, lambda functions and variadic
|
||||
templates), and by only targeting Python 3.x and higher.
|
||||
|
@ -206,6 +206,22 @@ public:
|
||||
TYPE_CASTER(std::string, "str");
|
||||
};
|
||||
|
||||
#ifdef HAVE_WCHAR_H
|
||||
template <> class type_caster<std::wstring> {
|
||||
public:
|
||||
bool load(PyObject *src, bool) {
|
||||
const wchar_t *ptr = PyUnicode_AsWideCharString(src, nullptr);
|
||||
if (!ptr) { PyErr_Clear(); return false; }
|
||||
value = std::wstring(ptr);
|
||||
return true;
|
||||
}
|
||||
static PyObject *cast(const std::wstring &src, return_value_policy /* policy */, PyObject * /* parent */) {
|
||||
return PyUnicode_FromWideChar(src.c_str(), src.length());
|
||||
}
|
||||
TYPE_CASTER(std::wstring, "wstr");
|
||||
};
|
||||
#endif
|
||||
|
||||
template <> class type_caster<char> {
|
||||
public:
|
||||
bool load(PyObject *src, bool) {
|
||||
@ -474,6 +490,7 @@ TYPE_CASTER_PYTYPE(list)
|
||||
TYPE_CASTER_PYTYPE(slice)
|
||||
TYPE_CASTER_PYTYPE(tuple)
|
||||
TYPE_CASTER_PYTYPE(function)
|
||||
TYPE_CASTER_PYTYPE(array)
|
||||
|
||||
#undef TYPE_CASTER
|
||||
#undef TYPE_CASTER_PYTYPE
|
||||
|
@ -132,7 +132,10 @@ private:
|
||||
entry = backup;
|
||||
}
|
||||
std::string signatures;
|
||||
int it = 0;
|
||||
while (entry) { /* Create pydoc entry */
|
||||
if (sibling.ptr())
|
||||
signatures += std::to_string(++it) + ". ";
|
||||
signatures += "Signature : " + std::string(entry->signature) + "\n";
|
||||
if (!entry->doc.empty())
|
||||
signatures += "\n" + std::string(entry->doc) + "\n";
|
||||
|
@ -322,6 +322,88 @@ private:
|
||||
Py_buffer *view = nullptr;
|
||||
};
|
||||
|
||||
class array : public buffer {
|
||||
protected:
|
||||
struct API {
|
||||
enum Entries {
|
||||
API_PyArray_Type = 2,
|
||||
API_PyArray_DescrFromType = 45,
|
||||
API_PyArray_NewCopy = 85,
|
||||
API_PyArray_NewFromDescr = 94
|
||||
};
|
||||
|
||||
static API lookup() {
|
||||
PyObject *numpy = PyImport_ImportModule("numpy.core.multiarray");
|
||||
PyObject *capsule = numpy ? PyObject_GetAttrString(numpy, "_ARRAY_API") : nullptr;
|
||||
void **api_ptr = (void **) (capsule ? PyCapsule_GetPointer(capsule, NULL) : nullptr);
|
||||
Py_XDECREF(capsule);
|
||||
Py_XDECREF(numpy);
|
||||
if (api_ptr == nullptr)
|
||||
throw std::runtime_error("Could not acquire pointer to NumPy API!");
|
||||
API api;
|
||||
api.PyArray_DescrFromType = (decltype(api.PyArray_DescrFromType)) api_ptr[API_PyArray_DescrFromType];
|
||||
api.PyArray_NewFromDescr = (decltype(api.PyArray_NewFromDescr)) api_ptr[API_PyArray_NewFromDescr];
|
||||
api.PyArray_NewCopy = (decltype(api.PyArray_NewCopy)) api_ptr[API_PyArray_NewCopy];
|
||||
api.PyArray_Type = (decltype(api.PyArray_Type)) api_ptr[API_PyArray_Type];
|
||||
return api;
|
||||
}
|
||||
|
||||
bool PyArray_Check(PyObject *obj) const {
|
||||
return (bool) PyObject_TypeCheck(obj, PyArray_Type);
|
||||
}
|
||||
|
||||
PyObject *(*PyArray_DescrFromType)(int);
|
||||
PyObject *(*PyArray_NewFromDescr)
|
||||
(PyTypeObject *, PyObject *, int, Py_intptr_t *,
|
||||
Py_intptr_t *, void *, int, PyObject *);
|
||||
PyObject *(*PyArray_NewCopy)(PyObject *, int);
|
||||
PyTypeObject *PyArray_Type;
|
||||
};
|
||||
public:
|
||||
PYTHON_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check)
|
||||
|
||||
template <typename Type> array(size_t size, const Type *ptr) {
|
||||
API& api = lookup_api();
|
||||
PyObject *descr = api.PyArray_DescrFromType(
|
||||
(int) format_descriptor<Type>::value()[0]);
|
||||
if (descr == nullptr)
|
||||
throw std::runtime_error("NumPy: unsupported buffer format!");
|
||||
Py_intptr_t shape = (Py_intptr_t) size;
|
||||
PyObject *tmp = api.PyArray_NewFromDescr(
|
||||
api.PyArray_Type, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr);
|
||||
if (tmp == nullptr)
|
||||
throw std::runtime_error("NumPy: unable to create array!");
|
||||
m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */);
|
||||
Py_DECREF(tmp);
|
||||
if (m_ptr == nullptr)
|
||||
throw std::runtime_error("NumPy: unable to copy array!");
|
||||
}
|
||||
|
||||
array(const buffer_info &info) {
|
||||
API& api = lookup_api();
|
||||
if (info.format.size() != 1)
|
||||
throw std::runtime_error("Unsupported buffer format!");
|
||||
PyObject *descr = api.PyArray_DescrFromType(info.format[0]);
|
||||
if (descr == nullptr)
|
||||
throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!");
|
||||
PyObject *tmp = api.PyArray_NewFromDescr(
|
||||
api.PyArray_Type, descr, info.ndim, (Py_intptr_t *) &info.shape[0],
|
||||
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr);
|
||||
if (tmp == nullptr)
|
||||
throw std::runtime_error("NumPy: unable to create array!");
|
||||
m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */);
|
||||
Py_DECREF(tmp);
|
||||
if (m_ptr == nullptr)
|
||||
throw std::runtime_error("NumPy: unable to copy array!");
|
||||
}
|
||||
protected:
|
||||
static API &lookup_api() {
|
||||
static API api = API::lookup();
|
||||
return api;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
inline internals &get_internals() {
|
||||
static internals *internals_ptr = nullptr;
|
||||
|
Loading…
Reference in New Issue
Block a user