diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index dd7f43d16..ec4c53b43 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -1029,7 +1029,10 @@ struct npy_format_descriptor_name::value>> { template struct npy_format_descriptor_name::value>> { - static constexpr auto name = _::value || std::is_same::value>( + static constexpr auto name = _::value + || std::is_same::value + || std::is_same::value + || std::is_same::value>( _("numpy.float") + _(), _("numpy.longdouble") ); }; @@ -1037,7 +1040,9 @@ struct npy_format_descriptor_name::valu template struct npy_format_descriptor_name::value>> { static constexpr auto name = _::value - || std::is_same::value>( + || std::is_same::value + || std::is_same::value + || std::is_same::value>( _("numpy.complex") + _(), _("numpy.longcomplex") ); }; diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index dca7145f9..204ea8367 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -437,4 +437,10 @@ TEST_SUBMODULE(numpy_array, sm) { sm.def("accept_double_f_style_forcecast_noconvert", [](py::array_t) {}, "a"_a.noconvert()); + + // Check that types returns correct npy format descriptor + sm.def("test_fmt_desc_float", [](py::array_t) {}); + sm.def("test_fmt_desc_double", [](py::array_t) {}); + sm.def("test_fmt_desc_const_float", [](py::array_t) {}); + sm.def("test_fmt_desc_const_double", [](py::array_t) {}); } diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index 02f3ecfc0..548c84bab 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -482,6 +482,19 @@ def test_index_using_ellipsis(): assert a.shape == (6,) +@pytest.mark.parametrize( + "test_func", + [ + m.test_fmt_desc_float, + m.test_fmt_desc_double, + m.test_fmt_desc_const_float, + m.test_fmt_desc_const_double, + ], +) +def test_format_descriptors_for_floating_point_types(test_func): + assert "numpy.ndarray[numpy.float" in test_func.__doc__ + + @pytest.mark.parametrize("forcecast", [False, True]) @pytest.mark.parametrize("contiguity", [None, "C", "F"]) @pytest.mark.parametrize("noconvert", [False, True])