diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 228e02c3d..ab224e1f1 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -212,6 +212,7 @@ constexpr int platform_lookup(int I, Ints... Is) { } struct npy_api { + // If you change this code, please review `normalized_dtype_num` below. enum constants { NPY_ARRAY_C_CONTIGUOUS_ = 0x0001, NPY_ARRAY_F_CONTIGUOUS_ = 0x0002, @@ -384,6 +385,74 @@ private: } }; +// This table normalizes typenums by mapping NPY_INT_, NPY_LONG, ... to NPY_INT32_, NPY_INT64, ... +// This is needed to correctly handle situations where multiple typenums map to the same type, +// e.g. NPY_LONG_ may be equivalent to NPY_INT_ or NPY_LONGLONG_ despite having a different +// typenum. The normalized typenum should always match the values used in npy_format_descriptor. +// If you change this code, please review `enum constants` above. +static constexpr int normalized_dtype_num[npy_api::NPY_VOID_ + 1] = { + // NPY_BOOL_ => + npy_api::NPY_BOOL_, + // NPY_BYTE_ => + npy_api::NPY_BYTE_, + // NPY_UBYTE_ => + npy_api::NPY_UBYTE_, + // NPY_SHORT_ => + npy_api::NPY_INT16_, + // NPY_USHORT_ => + npy_api::NPY_UINT16_, + // NPY_INT_ => + sizeof(int) == sizeof(std::int16_t) ? npy_api::NPY_INT16_ + : sizeof(int) == sizeof(std::int32_t) ? npy_api::NPY_INT32_ + : sizeof(int) == sizeof(std::int64_t) ? npy_api::NPY_INT64_ + : npy_api::NPY_INT_, + // NPY_UINT_ => + sizeof(unsigned int) == sizeof(std::uint16_t) ? npy_api::NPY_UINT16_ + : sizeof(unsigned int) == sizeof(std::uint32_t) ? npy_api::NPY_UINT32_ + : sizeof(unsigned int) == sizeof(std::uint64_t) ? npy_api::NPY_UINT64_ + : npy_api::NPY_UINT_, + // NPY_LONG_ => + sizeof(long) == sizeof(std::int16_t) ? npy_api::NPY_INT16_ + : sizeof(long) == sizeof(std::int32_t) ? npy_api::NPY_INT32_ + : sizeof(long) == sizeof(std::int64_t) ? npy_api::NPY_INT64_ + : npy_api::NPY_LONG_, + // NPY_ULONG_ => + sizeof(unsigned long) == sizeof(std::uint16_t) ? npy_api::NPY_UINT16_ + : sizeof(unsigned long) == sizeof(std::uint32_t) ? npy_api::NPY_UINT32_ + : sizeof(unsigned long) == sizeof(std::uint64_t) ? npy_api::NPY_UINT64_ + : npy_api::NPY_ULONG_, + // NPY_LONGLONG_ => + sizeof(long long) == sizeof(std::int16_t) ? npy_api::NPY_INT16_ + : sizeof(long long) == sizeof(std::int32_t) ? npy_api::NPY_INT32_ + : sizeof(long long) == sizeof(std::int64_t) ? npy_api::NPY_INT64_ + : npy_api::NPY_LONGLONG_, + // NPY_ULONGLONG_ => + sizeof(unsigned long long) == sizeof(std::uint16_t) ? npy_api::NPY_UINT16_ + : sizeof(unsigned long long) == sizeof(std::uint32_t) ? npy_api::NPY_UINT32_ + : sizeof(unsigned long long) == sizeof(std::uint64_t) ? npy_api::NPY_UINT64_ + : npy_api::NPY_ULONGLONG_, + // NPY_FLOAT_ => + npy_api::NPY_FLOAT_, + // NPY_DOUBLE_ => + npy_api::NPY_DOUBLE_, + // NPY_LONGDOUBLE_ => + npy_api::NPY_LONGDOUBLE_, + // NPY_CFLOAT_ => + npy_api::NPY_CFLOAT_, + // NPY_CDOUBLE_ => + npy_api::NPY_CDOUBLE_, + // NPY_CLONGDOUBLE_ => + npy_api::NPY_CLONGDOUBLE_, + // NPY_OBJECT_ => + npy_api::NPY_OBJECT_, + // NPY_STRING_ => + npy_api::NPY_STRING_, + // NPY_UNICODE_ => + npy_api::NPY_UNICODE_, + // NPY_VOID_ => + npy_api::NPY_VOID_, +}; + inline PyArray_Proxy *array_proxy(void *ptr) { return reinterpret_cast(ptr); } inline const PyArray_Proxy *array_proxy(const void *ptr) { @@ -684,6 +753,13 @@ public: return detail::npy_format_descriptor::type>::dtype(); } + /// Return the type number associated with a C++ type. + /// This is the constexpr equivalent of `dtype::of().num()`. + template + static constexpr int num_of() { + return detail::npy_format_descriptor::type>::value; + } + /// Size of the data type in bytes. #ifdef PYBIND11_NUMPY_1_ONLY ssize_t itemsize() const { return detail::array_descriptor_proxy(m_ptr)->elsize; } @@ -725,7 +801,9 @@ public: return detail::array_descriptor_proxy(m_ptr)->type; } - /// type number of dtype. + /// Type number of dtype. Note that different values may be returned for equivalent types, + /// e.g. even though ``long`` may be equivalent to ``int`` or ``long long``, they still have + /// different type numbers. Consider using `normalized_num` to avoid this. int num() const { // Note: The signature, `dtype::num` follows the naming of NumPy's public // Python API (i.e., ``dtype.num``), rather than its internal @@ -733,6 +811,17 @@ public: return detail::array_descriptor_proxy(m_ptr)->type_num; } + /// Type number of dtype, normalized to match the return value of `num_of` for equivalent + /// types. This function can be used to write switch statements that correctly handle + /// equivalent types with different type numbers. + int normalized_num() const { + int value = num(); + if (value >= 0 && value <= detail::npy_api::NPY_VOID_) { + return detail::normalized_dtype_num[value]; + } + return value; + } + /// Single character for byteorder char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; } diff --git a/tests/test_numpy_dtypes.cpp b/tests/test_numpy_dtypes.cpp index 596d90274..b6db439d9 100644 --- a/tests/test_numpy_dtypes.cpp +++ b/tests/test_numpy_dtypes.cpp @@ -11,6 +11,9 @@ #include "pybind11_tests.h" +#include +#include + #ifdef __GNUC__ # define PYBIND11_PACKED(cls) cls __attribute__((__packed__)) #else @@ -297,6 +300,15 @@ py::list test_dtype_ctors() { return list; } +template +py::array_t dispatch_array_increment(py::array_t arr) { + py::array_t res(arr.shape(0)); + for (py::ssize_t i = 0; i < arr.shape(0); ++i) { + res.mutable_at(i) = T(arr.at(i) + 1); + } + return res; +} + struct A {}; struct B {}; @@ -496,6 +508,98 @@ TEST_SUBMODULE(numpy_dtypes, m) { } return list; }); + m.def("test_dtype_num_of", []() -> py::list { + py::list res; +#define TEST_DTYPE(T) res.append(py::make_tuple(py::dtype::of().num(), py::dtype::num_of())); + TEST_DTYPE(bool) + TEST_DTYPE(char) + TEST_DTYPE(unsigned char) + TEST_DTYPE(short) + TEST_DTYPE(unsigned short) + TEST_DTYPE(int) + TEST_DTYPE(unsigned int) + TEST_DTYPE(long) + TEST_DTYPE(unsigned long) + TEST_DTYPE(long long) + TEST_DTYPE(unsigned long long) + TEST_DTYPE(float) + TEST_DTYPE(double) + TEST_DTYPE(long double) + TEST_DTYPE(std::complex) + TEST_DTYPE(std::complex) + TEST_DTYPE(std::complex) + TEST_DTYPE(int8_t) + TEST_DTYPE(uint8_t) + TEST_DTYPE(int16_t) + TEST_DTYPE(uint16_t) + TEST_DTYPE(int32_t) + TEST_DTYPE(uint32_t) + TEST_DTYPE(int64_t) + TEST_DTYPE(uint64_t) +#undef TEST_DTYPE + return res; + }); + m.def("test_dtype_normalized_num", []() -> py::list { + py::list res; +#define TEST_DTYPE(NT, T) \ + res.append(py::make_tuple(py::dtype(py::detail::npy_api::NT).normalized_num(), \ + py::dtype::num_of())); + TEST_DTYPE(NPY_BOOL_, bool) + TEST_DTYPE(NPY_BYTE_, char); + TEST_DTYPE(NPY_UBYTE_, unsigned char); + TEST_DTYPE(NPY_SHORT_, short); + TEST_DTYPE(NPY_USHORT_, unsigned short); + TEST_DTYPE(NPY_INT_, int); + TEST_DTYPE(NPY_UINT_, unsigned int); + TEST_DTYPE(NPY_LONG_, long); + TEST_DTYPE(NPY_ULONG_, unsigned long); + TEST_DTYPE(NPY_LONGLONG_, long long); + TEST_DTYPE(NPY_ULONGLONG_, unsigned long long); + TEST_DTYPE(NPY_FLOAT_, float); + TEST_DTYPE(NPY_DOUBLE_, double); + TEST_DTYPE(NPY_LONGDOUBLE_, long double); + TEST_DTYPE(NPY_CFLOAT_, std::complex); + TEST_DTYPE(NPY_CDOUBLE_, std::complex); + TEST_DTYPE(NPY_CLONGDOUBLE_, std::complex); + TEST_DTYPE(NPY_INT8_, int8_t); + TEST_DTYPE(NPY_UINT8_, uint8_t); + TEST_DTYPE(NPY_INT16_, int16_t); + TEST_DTYPE(NPY_UINT16_, uint16_t); + TEST_DTYPE(NPY_INT32_, int32_t); + TEST_DTYPE(NPY_UINT32_, uint32_t); + TEST_DTYPE(NPY_INT64_, int64_t); + TEST_DTYPE(NPY_UINT64_, uint64_t); +#undef TEST_DTYPE + return res; + }); + m.def("test_dtype_switch", [](const py::array &arr) -> py::array { + switch (arr.dtype().normalized_num()) { + case py::dtype::num_of(): + return dispatch_array_increment(arr); + case py::dtype::num_of(): + return dispatch_array_increment(arr); + case py::dtype::num_of(): + return dispatch_array_increment(arr); + case py::dtype::num_of(): + return dispatch_array_increment(arr); + case py::dtype::num_of(): + return dispatch_array_increment(arr); + case py::dtype::num_of(): + return dispatch_array_increment(arr); + case py::dtype::num_of(): + return dispatch_array_increment(arr); + case py::dtype::num_of(): + return dispatch_array_increment(arr); + case py::dtype::num_of(): + return dispatch_array_increment(arr); + case py::dtype::num_of(): + return dispatch_array_increment(arr); + case py::dtype::num_of(): + return dispatch_array_increment(arr); + default: + throw std::runtime_error("Unsupported dtype"); + } + }); m.def("test_dtype_methods", []() { py::list list; auto dt1 = py::dtype::of(); diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py index 8ae239ed8..5d839933c 100644 --- a/tests/test_numpy_dtypes.py +++ b/tests/test_numpy_dtypes.py @@ -188,6 +188,28 @@ def test_dtype(simple_dtype): chr(np.dtype(ch).flags) for ch in expected_chars ] + for a, b in m.test_dtype_num_of(): + assert a == b + + for a, b in m.test_dtype_normalized_num(): + assert a == b + + arr = np.array([4, 84, 21, 36]) + # Note: "ulong" does not work in NumPy 1.x, so we use "L" + assert (m.test_dtype_switch(arr.astype("byte")) == arr + 1).all() + assert (m.test_dtype_switch(arr.astype("ubyte")) == arr + 1).all() + assert (m.test_dtype_switch(arr.astype("short")) == arr + 1).all() + assert (m.test_dtype_switch(arr.astype("ushort")) == arr + 1).all() + assert (m.test_dtype_switch(arr.astype("intc")) == arr + 1).all() + assert (m.test_dtype_switch(arr.astype("uintc")) == arr + 1).all() + assert (m.test_dtype_switch(arr.astype("long")) == arr + 1).all() + assert (m.test_dtype_switch(arr.astype("L")) == arr + 1).all() + assert (m.test_dtype_switch(arr.astype("longlong")) == arr + 1).all() + assert (m.test_dtype_switch(arr.astype("ulonglong")) == arr + 1).all() + assert (m.test_dtype_switch(arr.astype("single")) == arr + 1).all() + assert (m.test_dtype_switch(arr.astype("double")) == arr + 1).all() + assert (m.test_dtype_switch(arr.astype("longdouble")) == arr + 1).all() + def test_recarray(simple_dtype, packed_dtype): elements = [(False, 0, 0.0, -0.0), (True, 1, 1.5, -2.5), (False, 2, 3.0, -5.0)]