diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index a0441efa3..e67d37151 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -108,6 +109,18 @@ inline numpy_internals& get_numpy_internals() { return *ptr; } +template struct same_size { + template using as = bool_constant; +}; + +// Lookup a type according to its size, and return a value corresponding to the NumPy typenum. +template +constexpr int platform_lookup(Int... codes) { + using code_index = std::integral_constant::template as, Check...>()>; + static_assert(code_index::value != sizeof...(Check), "Unable to match type on this platform"); + return std::get(std::make_tuple(codes...)); +} + struct npy_api { enum constants { NPY_ARRAY_C_CONTIGUOUS_ = 0x0001, @@ -126,7 +139,23 @@ struct npy_api { NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_, NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_, NPY_OBJECT_ = 17, - NPY_STRING_, NPY_UNICODE_, NPY_VOID_ + NPY_STRING_, NPY_UNICODE_, NPY_VOID_, + // Platform-dependent normalization + NPY_INT8_ = NPY_BYTE_, + NPY_UINT8_ = NPY_UBYTE_, + NPY_INT16_ = NPY_SHORT_, + NPY_UINT16_ = NPY_USHORT_, + // `npy_common.h` defines the integer aliases. In order, it checks: + // NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR + // and assigns the alias to the first matching size, so we should check in this order. + NPY_INT32_ = platform_lookup( + NPY_LONG_, NPY_INT_, NPY_SHORT_), + NPY_UINT32_ = platform_lookup( + NPY_ULONG_, NPY_UINT_, NPY_USHORT_), + NPY_INT64_ = platform_lookup( + NPY_LONG_, NPY_LONGLONG_, NPY_INT_), + NPY_UINT64_ = platform_lookup( + NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_), }; typedef struct { @@ -1004,8 +1033,8 @@ private: // NB: the order here must match the one in common.h constexpr static const int values[15] = { npy_api::NPY_BOOL_, - npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_, - npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_, + npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_INT16_, npy_api::NPY_UINT16_, + npy_api::NPY_INT32_, npy_api::NPY_UINT32_, npy_api::NPY_INT64_, npy_api::NPY_UINT64_, npy_api::NPY_FLOAT_, npy_api::NPY_DOUBLE_, npy_api::NPY_LONGDOUBLE_, npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_ };