mirror of
https://github.com/pybind/pybind11.git
synced 2025-02-21 07:59:17 +00:00
Switch away from typenums for numpy descriptors
This commit is contained in:
parent
a67c2b52e4
commit
fab02efb10
@ -91,9 +91,7 @@ public:
|
|||||||
|
|
||||||
template <typename Type> array(size_t size, const Type *ptr) {
|
template <typename Type> array(size_t size, const Type *ptr) {
|
||||||
API& api = lookup_api();
|
API& api = lookup_api();
|
||||||
PyObject *descr = api.PyArray_DescrFromType_(detail::npy_format_descriptor<Type>::typenum());
|
PyObject *descr = detail::npy_format_descriptor<Type>::descr();
|
||||||
if (descr == nullptr)
|
|
||||||
pybind11_fail("NumPy: unsupported buffer format!");
|
|
||||||
Py_intptr_t shape = (Py_intptr_t) size;
|
Py_intptr_t shape = (Py_intptr_t) size;
|
||||||
object tmp = object(api.PyArray_NewFromDescr_(
|
object tmp = object(api.PyArray_NewFromDescr_(
|
||||||
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false);
|
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false);
|
||||||
@ -128,6 +126,8 @@ protected:
|
|||||||
static API api = API::lookup();
|
static API api = API::lookup();
|
||||||
return api;
|
return api;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
|
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
|
||||||
@ -140,7 +140,7 @@ public:
|
|||||||
if (ptr == nullptr)
|
if (ptr == nullptr)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
API &api = lookup_api();
|
API &api = lookup_api();
|
||||||
PyObject *descr = api.PyArray_DescrFromType_(detail::npy_format_descriptor<T>::typenum());
|
PyObject *descr = detail::npy_format_descriptor<T>::descr();
|
||||||
PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0, API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
|
PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0, API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
|
||||||
if (!result)
|
if (!result)
|
||||||
PyErr_Clear();
|
PyErr_Clear();
|
||||||
@ -158,6 +158,10 @@ private:
|
|||||||
array::API::NPY_INT_, array::API::NPY_UINT_, array::API::NPY_LONGLONG_, array::API::NPY_ULONGLONG_ };
|
array::API::NPY_INT_, array::API::NPY_UINT_, array::API::NPY_LONGLONG_, array::API::NPY_ULONGLONG_ };
|
||||||
public:
|
public:
|
||||||
static int typenum() { return values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)]; }
|
static int typenum() { return values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)]; }
|
||||||
|
static PyObject* descr() {
|
||||||
|
if (auto obj = array::lookup_api().PyArray_DescrFromType_(typenum())) return obj;
|
||||||
|
else pybind11_fail("Unsupported buffer format!");
|
||||||
|
}
|
||||||
template <typename T2 = T, typename std::enable_if<std::is_signed<T2>::value, int>::type = 0>
|
template <typename T2 = T, typename std::enable_if<std::is_signed<T2>::value, int>::type = 0>
|
||||||
static PYBIND11_DESCR name() { return _("int") + _<sizeof(T)*8>(); }
|
static PYBIND11_DESCR name() { return _("int") + _<sizeof(T)*8>(); }
|
||||||
template <typename T2 = T, typename std::enable_if<!std::is_signed<T2>::value, int>::type = 0>
|
template <typename T2 = T, typename std::enable_if<!std::is_signed<T2>::value, int>::type = 0>
|
||||||
@ -167,7 +171,11 @@ template <typename T> constexpr const int npy_format_descriptor<
|
|||||||
T, typename std::enable_if<std::is_integral<T>::value>::type>::values[8];
|
T, typename std::enable_if<std::is_integral<T>::value>::type>::values[8];
|
||||||
|
|
||||||
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
|
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
|
||||||
static int typenum() { return array::API::NumPyName; } \
|
static int typenum() { return array::API::NumPyName; } \
|
||||||
|
static PyObject* descr() { \
|
||||||
|
if (auto obj = array::lookup_api().PyArray_DescrFromType_(typenum())) return obj; \
|
||||||
|
else pybind11_fail("Unsupported buffer format!"); \
|
||||||
|
} \
|
||||||
static PYBIND11_DESCR name() { return _(Name); } }
|
static PYBIND11_DESCR name() { return _(Name); } }
|
||||||
DECL_FMT(float, NPY_FLOAT_, "float32");
|
DECL_FMT(float, NPY_FLOAT_, "float32");
|
||||||
DECL_FMT(double, NPY_DOUBLE_, "float64");
|
DECL_FMT(double, NPY_DOUBLE_, "float64");
|
||||||
|
Loading…
Reference in New Issue
Block a user