mirror of
https://github.com/pybind/pybind11.git
synced 2025-02-17 06:00:51 +00:00
Add NumPy Scalar.
This commit is contained in:
parent
a224d0cca5
commit
e493242b3b
@ -37,6 +37,8 @@ PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
|
||||
class array; // Forward declaration
|
||||
|
||||
template<typename> struct numpy_scalar; // Forward declaration
|
||||
|
||||
PYBIND11_NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <> struct handle_type_name<array> { static constexpr auto name = _("numpy.ndarray"); };
|
||||
@ -110,16 +112,12 @@ inline numpy_internals& get_numpy_internals() {
|
||||
return *ptr;
|
||||
}
|
||||
|
||||
template <typename T> struct same_size {
|
||||
template <typename U> using as = bool_constant<sizeof(T) == sizeof(U)>;
|
||||
};
|
||||
|
||||
template <typename Concrete> constexpr int platform_lookup() { return -1; }
|
||||
template <std::size_t> constexpr int platform_lookup() { return -1; }
|
||||
|
||||
// Lookup a type according to its size, and return a value corresponding to the NumPy typenum.
|
||||
template <typename Concrete, typename T, typename... Ts, typename... Ints>
|
||||
template <std::size_t size, typename T, typename... Ts, typename... Ints>
|
||||
constexpr int platform_lookup(int I, Ints... Is) {
|
||||
return sizeof(Concrete) == sizeof(T) ? I : platform_lookup<Concrete, Ts...>(Is...);
|
||||
return sizeof(size) == sizeof(T) ? I : platform_lookup<size, Ts...>(Is...);
|
||||
}
|
||||
|
||||
struct npy_api {
|
||||
@ -149,14 +147,23 @@ struct npy_api {
|
||||
// `npy_common.h` defines the integer aliases. In order, it checks:
|
||||
// NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
|
||||
// and assigns the alias to the first matching size, so we should check in this order.
|
||||
NPY_INT32_ = platform_lookup<std::int32_t, long, int, short>(
|
||||
NPY_INT32_ = platform_lookup<4, long, int, short>(
|
||||
NPY_LONG_, NPY_INT_, NPY_SHORT_),
|
||||
NPY_UINT32_ = platform_lookup<std::uint32_t, unsigned long, unsigned int, unsigned short>(
|
||||
NPY_UINT32_ = platform_lookup<4, unsigned long, unsigned int, unsigned short>(
|
||||
NPY_ULONG_, NPY_UINT_, NPY_USHORT_),
|
||||
NPY_INT64_ = platform_lookup<std::int64_t, long, long long, int>(
|
||||
NPY_INT64_ = platform_lookup<8, long, long long, int>(
|
||||
NPY_LONG_, NPY_LONGLONG_, NPY_INT_),
|
||||
NPY_UINT64_ = platform_lookup<std::uint64_t, unsigned long, unsigned long long, unsigned int>(
|
||||
NPY_UINT64_ = platform_lookup<8, unsigned long, unsigned long long, unsigned int>(
|
||||
NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
|
||||
NPY_FLOAT32_ = platform_lookup<4, double, float, long double>(
|
||||
NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
|
||||
NPY_FLOAT64_ = platform_lookup<8, double, float, long double>(
|
||||
NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
|
||||
NPY_COMPLEX64_ = platform_lookup<8, double, float, long double>(
|
||||
NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
|
||||
NPY_COMPLEX128_ = platform_lookup<8, double, float, long double>(
|
||||
NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
|
||||
NPY_CHAR_ = std::is_signed<char>::value ? NPY_BYTE_ : NPY_UBYTE_,
|
||||
};
|
||||
|
||||
struct PyArray_Dims {
|
||||
@ -178,6 +185,7 @@ struct npy_api {
|
||||
|
||||
unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
|
||||
PyObject *(*PyArray_DescrFromType_)(int);
|
||||
PyObject *(*PyArray_TypeObjectFromType_)(int);
|
||||
PyObject *(*PyArray_NewFromDescr_)
|
||||
(PyTypeObject *, PyObject *, int, Py_intptr_t const *,
|
||||
Py_intptr_t const *, void *, int, PyObject *);
|
||||
@ -189,6 +197,8 @@ struct npy_api {
|
||||
PyTypeObject *PyVoidArrType_Type_;
|
||||
PyTypeObject *PyArrayDescr_Type_;
|
||||
PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
|
||||
PyObject *(*PyArray_Scalar_)(void *, PyObject *, PyObject *);
|
||||
void (*PyArray_ScalarAsCtype_)(PyObject *, void *);
|
||||
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
|
||||
int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
|
||||
bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
|
||||
@ -208,7 +218,10 @@ private:
|
||||
API_PyArrayDescr_Type = 3,
|
||||
API_PyVoidArrType_Type = 39,
|
||||
API_PyArray_DescrFromType = 45,
|
||||
API_PyArray_TypeObjectFromType = 46,
|
||||
API_PyArray_DescrFromScalar = 57,
|
||||
API_PyArray_Scalar = 60,
|
||||
API_PyArray_ScalarAsCtype = 62,
|
||||
API_PyArray_FromAny = 69,
|
||||
API_PyArray_Resize = 80,
|
||||
API_PyArray_CopyInto = 82,
|
||||
@ -241,7 +254,10 @@ private:
|
||||
DECL_NPY_API(PyVoidArrType_Type);
|
||||
DECL_NPY_API(PyArrayDescr_Type);
|
||||
DECL_NPY_API(PyArray_DescrFromType);
|
||||
DECL_NPY_API(PyArray_TypeObjectFromType);
|
||||
DECL_NPY_API(PyArray_DescrFromScalar);
|
||||
DECL_NPY_API(PyArray_Scalar);
|
||||
DECL_NPY_API(PyArray_ScalarAsCtype);
|
||||
DECL_NPY_API(PyArray_FromAny);
|
||||
DECL_NPY_API(PyArray_Resize);
|
||||
DECL_NPY_API(PyArray_CopyInto);
|
||||
@ -261,6 +277,74 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> struct is_complex : std::false_type { };
|
||||
template <typename T> struct is_complex<std::complex<T>> : std::true_type { };
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct npy_format_descriptor_name;
|
||||
|
||||
template <typename T>
|
||||
struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
|
||||
static constexpr auto name = _<std::is_same<T, bool>::value>(
|
||||
_("bool"), _<std::is_signed<T>::value>("int", "uint") + _<sizeof(T)*8>()
|
||||
);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
|
||||
static constexpr auto name = _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
|
||||
_("float") + _<sizeof(T)*8>(), _("longdouble")
|
||||
);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
|
||||
static constexpr auto name = _<std::is_same<typename T::value_type, float>::value
|
||||
|| std::is_same<typename T::value_type, double>::value>(
|
||||
_("complex") + _<sizeof(typename T::value_type)*16>(), _("longcomplex")
|
||||
);
|
||||
};
|
||||
|
||||
template<typename T> struct numpy_scalar_info {};
|
||||
|
||||
#define DECL_NPY_SCALAR(ctype_, typenum_) \
|
||||
template<> struct numpy_scalar_info<ctype_> { \
|
||||
static constexpr auto name = npy_format_descriptor_name<ctype_>::name; \
|
||||
static constexpr int typenum = npy_api::typenum_##_; \
|
||||
}
|
||||
|
||||
// boolean type
|
||||
DECL_NPY_SCALAR(bool, NPY_BOOL);
|
||||
|
||||
// character types
|
||||
DECL_NPY_SCALAR(char, NPY_CHAR);
|
||||
DECL_NPY_SCALAR(signed char, NPY_BYTE);
|
||||
DECL_NPY_SCALAR(unsigned char, NPY_UBYTE);
|
||||
|
||||
// signed integer types
|
||||
DECL_NPY_SCALAR(short, NPY_SHORT);
|
||||
DECL_NPY_SCALAR(int, NPY_INT);
|
||||
DECL_NPY_SCALAR(long, NPY_LONG);
|
||||
DECL_NPY_SCALAR(long long, NPY_LONGLONG);
|
||||
|
||||
// unsigned integer types
|
||||
DECL_NPY_SCALAR(unsigned short, NPY_USHORT);
|
||||
DECL_NPY_SCALAR(unsigned int, NPY_UINT);
|
||||
DECL_NPY_SCALAR(unsigned long, NPY_ULONG);
|
||||
DECL_NPY_SCALAR(unsigned long long, NPY_ULONGLONG);
|
||||
|
||||
// floating point types
|
||||
DECL_NPY_SCALAR(float, NPY_FLOAT);
|
||||
DECL_NPY_SCALAR(double, NPY_DOUBLE);
|
||||
DECL_NPY_SCALAR(long double, NPY_LONGDOUBLE);
|
||||
|
||||
// complex types
|
||||
DECL_NPY_SCALAR(std::complex<float>, NPY_CFLOAT);
|
||||
DECL_NPY_SCALAR(std::complex<double>, NPY_CDOUBLE);
|
||||
DECL_NPY_SCALAR(std::complex<long double>, NPY_CLONGDOUBLE);
|
||||
|
||||
#undef DECL_NPY_SCALAR
|
||||
|
||||
inline PyArray_Proxy* array_proxy(void* ptr) {
|
||||
return reinterpret_cast<PyArray_Proxy*>(ptr);
|
||||
}
|
||||
@ -283,8 +367,6 @@ inline bool check_flags(const void* ptr, int flag) {
|
||||
|
||||
template <typename T> struct is_std_array : std::false_type { };
|
||||
template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
|
||||
template <typename T> struct is_complex : std::false_type { };
|
||||
template <typename T> struct is_complex<std::complex<T>> : std::true_type { };
|
||||
|
||||
template <typename T> struct array_info_scalar {
|
||||
using type = T;
|
||||
@ -459,8 +541,56 @@ struct type_caster<unchecked_reference<T, Dim>> {
|
||||
template <typename T, ssize_t Dim>
|
||||
struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {};
|
||||
|
||||
template<typename T>
|
||||
struct type_caster<numpy_scalar<T>> {
|
||||
using value_type = T;
|
||||
using type_info = numpy_scalar_info<T>;
|
||||
|
||||
PYBIND11_TYPE_CASTER(numpy_scalar<T>, type_info::name);
|
||||
|
||||
static handle& target_type() {
|
||||
static handle tp = npy_api::get().PyArray_TypeObjectFromType_(type_info::typenum);
|
||||
return tp;
|
||||
}
|
||||
|
||||
static handle& target_dtype() {
|
||||
static handle tp = npy_api::get().PyArray_DescrFromType_(type_info::typenum);
|
||||
return tp;
|
||||
}
|
||||
|
||||
bool load(handle src, bool) {
|
||||
if (isinstance(src, target_type())) {
|
||||
npy_api::get().PyArray_ScalarAsCtype_(src.ptr(), &value.value);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static handle cast(numpy_scalar<T> src, return_value_policy, handle) {
|
||||
return npy_api::get().PyArray_Scalar_(&src.value, target_dtype().ptr(), nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
PYBIND11_NAMESPACE_END(detail)
|
||||
|
||||
template<typename T>
|
||||
struct numpy_scalar {
|
||||
using value_type = T;
|
||||
|
||||
value_type value;
|
||||
|
||||
numpy_scalar() = default;
|
||||
numpy_scalar(value_type value) : value(value) {}
|
||||
|
||||
operator value_type() { return value; }
|
||||
numpy_scalar& operator=(value_type value) { this->value = value; return *this; }
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
numpy_scalar<T> make_scalar(T value) {
|
||||
return numpy_scalar<T>(value);
|
||||
}
|
||||
|
||||
class dtype : public object {
|
||||
public:
|
||||
PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
|
||||
@ -1051,36 +1181,6 @@ struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::valu
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct npy_format_descriptor_name;
|
||||
|
||||
template <typename T>
|
||||
struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
|
||||
static constexpr auto name = _<std::is_same<T, bool>::value>(
|
||||
_("bool"), _<std::is_signed<T>::value>("numpy.int", "numpy.uint") + _<sizeof(T)*8>()
|
||||
);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
|
||||
static constexpr auto name = _<std::is_same<T, float>::value
|
||||
|| std::is_same<T, const float>::value
|
||||
|| std::is_same<T, double>::value
|
||||
|| std::is_same<T, const double>::value>(
|
||||
_("numpy.float") + _<sizeof(T)*8>(), _("numpy.longdouble")
|
||||
);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
|
||||
static constexpr auto name = _<std::is_same<typename T::value_type, float>::value
|
||||
|| std::is_same<typename T::value_type, const float>::value
|
||||
|| std::is_same<typename T::value_type, double>::value
|
||||
|| std::is_same<typename T::value_type, const double>::value>(
|
||||
_("numpy.complex") + _<sizeof(typename T::value_type)*16>(), _("numpy.longcomplex")
|
||||
);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>>
|
||||
: npy_format_descriptor_name<T> {
|
||||
|
Loading…
Reference in New Issue
Block a user