diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 019f5688e..d1ddc4a9d 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -507,11 +507,21 @@ public: return detail::array_descriptor_proxy(m_ptr)->names != nullptr; } - /// Single-character type code. + /// Single-character code for dtype's kind. + /// For example, floating point types are 'f' and integral types are 'i'. char kind() const { return detail::array_descriptor_proxy(m_ptr)->kind; } + /// Single-character for dtype's type. + /// For example, ``float`` is 'f', ``double`` 'd', ``int`` 'i', and ``long`` 'd'. + char char_() const { + // Note: The signature, `dtype::char_` follows the naming of NumPy's + // public Python API (i.e., ``dtype.char``), rather than its internal + // C API (``PyArray_Descr::type``). + return detail::array_descriptor_proxy(m_ptr)->type; + } + private: static object _dtype_from_pep3118() { static PyObject *obj = module_::import("numpy.core._internal") diff --git a/tests/test_numpy_dtypes.cpp b/tests/test_numpy_dtypes.cpp index b2e5e6075..9dece73ec 100644 --- a/tests/test_numpy_dtypes.cpp +++ b/tests/test_numpy_dtypes.cpp @@ -358,6 +358,14 @@ TEST_SUBMODULE(numpy_dtypes, m) { }); // test_dtype + std::vector dtype_names{ + "byte", "short", "intc", "int_", "longlong", + "ubyte", "ushort", "uintc", "uint", "ulonglong", + "half", "single", "double", "longdouble", + "csingle", "cdouble", "clongdouble", + "bool_", "datetime64", "timedelta64", "object_" + }; + m.def("print_dtypes", []() { py::list l; for (const py::handle &d : { @@ -376,6 +384,18 @@ TEST_SUBMODULE(numpy_dtypes, m) { return l; }); m.def("test_dtype_ctors", &test_dtype_ctors); + m.def("test_dtype_kind", [dtype_names]() { + py::list list; + for (auto& dt_name : dtype_names) + list.append(py::dtype(dt_name).kind()); + return list; + }); + m.def("test_dtype_char_", [dtype_names]() { + py::list list; + for (auto& dt_name : dtype_names) + list.append(py::dtype(dt_name).char_()); + return list; + }); m.def("test_dtype_methods", []() { py::list list; auto dt1 = py::dtype::of(); diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py index f56b776a4..0a5881e49 100644 --- a/tests/test_numpy_dtypes.py +++ b/tests/test_numpy_dtypes.py @@ -169,6 +169,9 @@ def test_dtype(simple_dtype): np.zeros(1, m.trailing_padding_dtype()) ) + assert m.test_dtype_kind() == list("iiiiiuuuuuffffcccbMmO") + assert m.test_dtype_char_() == list("bhilqBHILQefdgFDG?MmO") + def test_recarray(simple_dtype, packed_dtype): elements = [(False, 0, 0.0, -0.0), (True, 1, 1.5, -2.5), (False, 2, 3.0, -5.0)]