mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-25 06:35:12 +00:00
Better NumPy support
This commit is contained in:
parent
bd4a529319
commit
2ac80e77aa
@ -17,7 +17,7 @@ become an excessively large and unnecessary dependency.
|
|||||||
|
|
||||||
Think of this library as a tiny self-contained version of Boost.Python with
|
Think of this library as a tiny self-contained version of Boost.Python with
|
||||||
everything stripped away that isn't relevant for binding generation. The whole
|
everything stripped away that isn't relevant for binding generation. The whole
|
||||||
codebase requires less than 2000 lines of code and just depends on Python and
|
codebase requires just over 2000 lines of code and just depends on Python and
|
||||||
the C++ standard library. This compact implementation was possible thanks to
|
the C++ standard library. This compact implementation was possible thanks to
|
||||||
some of the new C++11 language features (tuples, lambda functions and variadic
|
some of the new C++11 language features (tuples, lambda functions and variadic
|
||||||
templates), and by only targeting Python 3.x and higher.
|
templates), and by only targeting Python 3.x and higher.
|
||||||
|
@ -206,6 +206,22 @@ public:
|
|||||||
TYPE_CASTER(std::string, "str");
|
TYPE_CASTER(std::string, "str");
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#ifdef HAVE_WCHAR_H
|
||||||
|
template <> class type_caster<std::wstring> {
|
||||||
|
public:
|
||||||
|
bool load(PyObject *src, bool) {
|
||||||
|
const wchar_t *ptr = PyUnicode_AsWideCharString(src, nullptr);
|
||||||
|
if (!ptr) { PyErr_Clear(); return false; }
|
||||||
|
value = std::wstring(ptr);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
static PyObject *cast(const std::wstring &src, return_value_policy /* policy */, PyObject * /* parent */) {
|
||||||
|
return PyUnicode_FromWideChar(src.c_str(), src.length());
|
||||||
|
}
|
||||||
|
TYPE_CASTER(std::wstring, "wstr");
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
template <> class type_caster<char> {
|
template <> class type_caster<char> {
|
||||||
public:
|
public:
|
||||||
bool load(PyObject *src, bool) {
|
bool load(PyObject *src, bool) {
|
||||||
@ -474,6 +490,7 @@ TYPE_CASTER_PYTYPE(list)
|
|||||||
TYPE_CASTER_PYTYPE(slice)
|
TYPE_CASTER_PYTYPE(slice)
|
||||||
TYPE_CASTER_PYTYPE(tuple)
|
TYPE_CASTER_PYTYPE(tuple)
|
||||||
TYPE_CASTER_PYTYPE(function)
|
TYPE_CASTER_PYTYPE(function)
|
||||||
|
TYPE_CASTER_PYTYPE(array)
|
||||||
|
|
||||||
#undef TYPE_CASTER
|
#undef TYPE_CASTER
|
||||||
#undef TYPE_CASTER_PYTYPE
|
#undef TYPE_CASTER_PYTYPE
|
||||||
|
@ -132,7 +132,10 @@ private:
|
|||||||
entry = backup;
|
entry = backup;
|
||||||
}
|
}
|
||||||
std::string signatures;
|
std::string signatures;
|
||||||
|
int it = 0;
|
||||||
while (entry) { /* Create pydoc entry */
|
while (entry) { /* Create pydoc entry */
|
||||||
|
if (sibling.ptr())
|
||||||
|
signatures += std::to_string(++it) + ". ";
|
||||||
signatures += "Signature : " + std::string(entry->signature) + "\n";
|
signatures += "Signature : " + std::string(entry->signature) + "\n";
|
||||||
if (!entry->doc.empty())
|
if (!entry->doc.empty())
|
||||||
signatures += "\n" + std::string(entry->doc) + "\n";
|
signatures += "\n" + std::string(entry->doc) + "\n";
|
||||||
|
@ -322,6 +322,88 @@ private:
|
|||||||
Py_buffer *view = nullptr;
|
Py_buffer *view = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class array : public buffer {
|
||||||
|
protected:
|
||||||
|
struct API {
|
||||||
|
enum Entries {
|
||||||
|
API_PyArray_Type = 2,
|
||||||
|
API_PyArray_DescrFromType = 45,
|
||||||
|
API_PyArray_NewCopy = 85,
|
||||||
|
API_PyArray_NewFromDescr = 94
|
||||||
|
};
|
||||||
|
|
||||||
|
static API lookup() {
|
||||||
|
PyObject *numpy = PyImport_ImportModule("numpy.core.multiarray");
|
||||||
|
PyObject *capsule = numpy ? PyObject_GetAttrString(numpy, "_ARRAY_API") : nullptr;
|
||||||
|
void **api_ptr = (void **) (capsule ? PyCapsule_GetPointer(capsule, NULL) : nullptr);
|
||||||
|
Py_XDECREF(capsule);
|
||||||
|
Py_XDECREF(numpy);
|
||||||
|
if (api_ptr == nullptr)
|
||||||
|
throw std::runtime_error("Could not acquire pointer to NumPy API!");
|
||||||
|
API api;
|
||||||
|
api.PyArray_DescrFromType = (decltype(api.PyArray_DescrFromType)) api_ptr[API_PyArray_DescrFromType];
|
||||||
|
api.PyArray_NewFromDescr = (decltype(api.PyArray_NewFromDescr)) api_ptr[API_PyArray_NewFromDescr];
|
||||||
|
api.PyArray_NewCopy = (decltype(api.PyArray_NewCopy)) api_ptr[API_PyArray_NewCopy];
|
||||||
|
api.PyArray_Type = (decltype(api.PyArray_Type)) api_ptr[API_PyArray_Type];
|
||||||
|
return api;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool PyArray_Check(PyObject *obj) const {
|
||||||
|
return (bool) PyObject_TypeCheck(obj, PyArray_Type);
|
||||||
|
}
|
||||||
|
|
||||||
|
PyObject *(*PyArray_DescrFromType)(int);
|
||||||
|
PyObject *(*PyArray_NewFromDescr)
|
||||||
|
(PyTypeObject *, PyObject *, int, Py_intptr_t *,
|
||||||
|
Py_intptr_t *, void *, int, PyObject *);
|
||||||
|
PyObject *(*PyArray_NewCopy)(PyObject *, int);
|
||||||
|
PyTypeObject *PyArray_Type;
|
||||||
|
};
|
||||||
|
public:
|
||||||
|
PYTHON_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check)
|
||||||
|
|
||||||
|
template <typename Type> array(size_t size, const Type *ptr) {
|
||||||
|
API& api = lookup_api();
|
||||||
|
PyObject *descr = api.PyArray_DescrFromType(
|
||||||
|
(int) format_descriptor<Type>::value()[0]);
|
||||||
|
if (descr == nullptr)
|
||||||
|
throw std::runtime_error("NumPy: unsupported buffer format!");
|
||||||
|
Py_intptr_t shape = (Py_intptr_t) size;
|
||||||
|
PyObject *tmp = api.PyArray_NewFromDescr(
|
||||||
|
api.PyArray_Type, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr);
|
||||||
|
if (tmp == nullptr)
|
||||||
|
throw std::runtime_error("NumPy: unable to create array!");
|
||||||
|
m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */);
|
||||||
|
Py_DECREF(tmp);
|
||||||
|
if (m_ptr == nullptr)
|
||||||
|
throw std::runtime_error("NumPy: unable to copy array!");
|
||||||
|
}
|
||||||
|
|
||||||
|
array(const buffer_info &info) {
|
||||||
|
API& api = lookup_api();
|
||||||
|
if (info.format.size() != 1)
|
||||||
|
throw std::runtime_error("Unsupported buffer format!");
|
||||||
|
PyObject *descr = api.PyArray_DescrFromType(info.format[0]);
|
||||||
|
if (descr == nullptr)
|
||||||
|
throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!");
|
||||||
|
PyObject *tmp = api.PyArray_NewFromDescr(
|
||||||
|
api.PyArray_Type, descr, info.ndim, (Py_intptr_t *) &info.shape[0],
|
||||||
|
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr);
|
||||||
|
if (tmp == nullptr)
|
||||||
|
throw std::runtime_error("NumPy: unable to create array!");
|
||||||
|
m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */);
|
||||||
|
Py_DECREF(tmp);
|
||||||
|
if (m_ptr == nullptr)
|
||||||
|
throw std::runtime_error("NumPy: unable to copy array!");
|
||||||
|
}
|
||||||
|
protected:
|
||||||
|
static API &lookup_api() {
|
||||||
|
static API api = API::lookup();
|
||||||
|
return api;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
NAMESPACE_BEGIN(detail)
|
NAMESPACE_BEGIN(detail)
|
||||||
inline internals &get_internals() {
|
inline internals &get_internals() {
|
||||||
static internals *internals_ptr = nullptr;
|
static internals *internals_ptr = nullptr;
|
||||||
|
Loading…
Reference in New Issue
Block a user