From 18e1bd2a8969a4cfaa41ba2faedff9d897d1fe15 Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Thu, 18 May 2023 23:09:31 -0700 Subject: [PATCH] Use `py::detail::compare_buffer_info::compare()` to validate the `format_descriptor::format()` strings. --- tests/test_buffers.cpp | 55 +++++++++++++++------------ tests/test_buffers.py | 86 ++++++++++++++++++++++-------------------- 2 files changed, 77 insertions(+), 64 deletions(-) diff --git a/tests/test_buffers.cpp b/tests/test_buffers.cpp index daf36a794..ed9013ae7 100644 --- a/tests/test_buffers.cpp +++ b/tests/test_buffers.cpp @@ -14,32 +14,39 @@ #include "pybind11_tests.h" TEST_SUBMODULE(buffers, m) { - m.def("format_descriptor_format", [](const std::string &cpp_name) { - // https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables - static auto *table = new std::map; - if (table->empty()) { + m.attr("std_is_same_double_long_double") = std::is_same::value; + + m.def("format_descriptor_format_compare", + [](const std::string &cpp_name, const py::buffer &buffer) { + // https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables + static auto *format_table = new std::map; + static auto *compare_table + = new std::map; + if (format_table->empty()) { #define PYBIND11_ASSIGN_HELPER(...) \ - (*table)[#__VA_ARGS__] = py::format_descriptor<__VA_ARGS__>::format(); - PYBIND11_ASSIGN_HELPER(PyObject *) - PYBIND11_ASSIGN_HELPER(bool) - PYBIND11_ASSIGN_HELPER(std::int8_t) - PYBIND11_ASSIGN_HELPER(std::uint8_t) - PYBIND11_ASSIGN_HELPER(std::int16_t) - PYBIND11_ASSIGN_HELPER(std::uint16_t) - PYBIND11_ASSIGN_HELPER(std::int32_t) - PYBIND11_ASSIGN_HELPER(std::uint32_t) - PYBIND11_ASSIGN_HELPER(std::int64_t) - PYBIND11_ASSIGN_HELPER(std::uint64_t) - PYBIND11_ASSIGN_HELPER(float) - PYBIND11_ASSIGN_HELPER(double) - PYBIND11_ASSIGN_HELPER(long double) - PYBIND11_ASSIGN_HELPER(std::complex) - PYBIND11_ASSIGN_HELPER(std::complex) - PYBIND11_ASSIGN_HELPER(std::complex) + (*format_table)[#__VA_ARGS__] = py::format_descriptor<__VA_ARGS__>::format(); \ + (*compare_table)[#__VA_ARGS__] = py::detail::compare_buffer_info<__VA_ARGS__>::compare; + PYBIND11_ASSIGN_HELPER(PyObject *) + PYBIND11_ASSIGN_HELPER(bool) + PYBIND11_ASSIGN_HELPER(std::int8_t) + PYBIND11_ASSIGN_HELPER(std::uint8_t) + PYBIND11_ASSIGN_HELPER(std::int16_t) + PYBIND11_ASSIGN_HELPER(std::uint16_t) + PYBIND11_ASSIGN_HELPER(std::int32_t) + PYBIND11_ASSIGN_HELPER(std::uint32_t) + PYBIND11_ASSIGN_HELPER(std::int64_t) + PYBIND11_ASSIGN_HELPER(std::uint64_t) + PYBIND11_ASSIGN_HELPER(float) + PYBIND11_ASSIGN_HELPER(double) + PYBIND11_ASSIGN_HELPER(long double) + PYBIND11_ASSIGN_HELPER(std::complex) + PYBIND11_ASSIGN_HELPER(std::complex) + PYBIND11_ASSIGN_HELPER(std::complex) #undef PYBIND11_ASSIGN_HELPER - } - return (*table)[cpp_name]; - }); + } + return std::pair((*format_table)[cpp_name], + (*compare_table)[cpp_name](buffer.request())); + }); // test_from_python / test_to_python: class Matrix { diff --git a/tests/test_buffers.py b/tests/test_buffers.py index afa5fc571..108154231 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -10,49 +10,55 @@ from pybind11_tests import buffers as m np = pytest.importorskip("numpy") -if env.WIN: - # Windows does not have these (see e.g. #1908). But who knows, maybe later? - np_float128_or_none = getattr(np, "float128", None) - np_complex256_or_none = getattr(np, "complex256", None) +if m.std_is_same_double_long_double: # Windows. + np_float128 = None + np_complex256 = None else: - np_float128_or_none = np.float128 - np_complex256_or_none = np.complex256 + np_float128 = np.float128 + np_complex256 = np.complex256 + +CPP_NAME_FORMAT_NP_DTYPE_TABLE = [ + item + for item in [ + ("PyObject *", "O", object), + ("bool", "?", np.bool_), + ("std::int8_t", "b", np.int8), + ("std::uint8_t", "B", np.uint8), + ("std::int16_t", "h", np.int16), + ("std::uint16_t", "H", np.uint16), + ("std::int32_t", "i", np.int32), + ("std::uint32_t", "I", np.uint32), + ("std::int64_t", "q", np.int64), + ("std::uint64_t", "Q", np.uint64), + ("float", "f", np.float32), + ("double", "d", np.float64), + ("long double", "g", np_float128), + ("std::complex", "Zf", np.complex64), + ("std::complex", "Zd", np.complex128), + ("std::complex", "Zg", np_complex256), + ] + if item[-1] is not None +] +CPP_NAME_FORMAT_TABLE = [ + (cpp_name, format) for cpp_name, format, _ in CPP_NAME_FORMAT_NP_DTYPE_TABLE +] +CPP_NAME_NP_DTYPE_TABLE = [ + (cpp_name, np_dtype) for cpp_name, _, np_dtype in CPP_NAME_FORMAT_NP_DTYPE_TABLE +] -@pytest.mark.parametrize( - ("cpp_name", "expected_fmts", "np_array_dtype"), - [ - ("PyObject *", ["O"], object), - ("bool", ["?"], np.bool_), - ("std::int8_t", ["b"], np.int8), - ("std::uint8_t", ["B"], np.uint8), - ("std::int16_t", ["h"], np.int16), - ("std::uint16_t", ["H"], np.uint16), - ("std::int32_t", ["i"], np.int32), - ("std::uint32_t", ["I"], np.uint32), - ("std::int64_t", ["q"], np.int64), - ("std::uint64_t", ["Q"], np.uint64), - ("float", ["f"], np.float32), - ("double", ["d"], np.float64), - ("long double", ["g", "d"], np_float128_or_none), - ("std::complex", ["Zf"], np.complex64), - ("std::complex", ["Zd"], np.complex128), - ("std::complex", ["Zg", "Zd"], np_complex256_or_none), - ], -) -def test_format_descriptor_format(cpp_name, expected_fmts, np_array_dtype): - fmt = m.format_descriptor_format(cpp_name) - assert fmt in expected_fmts - - if np_array_dtype is not None: - na = np.array([], dtype=np_array_dtype) - bi = m.get_buffer_info(na) - bif = bi.format - if bif == "l": - bif = "i" if bi.itemsize == 4 else "q" - elif bif == "L": - bif = "I" if bi.itemsize == 4 else "Q" - assert bif == fmt +@pytest.mark.parametrize(("cpp_name", "np_dtype"), CPP_NAME_NP_DTYPE_TABLE) +def test_format_descriptor_format_compare(cpp_name, np_dtype): + np_array = np.array([], dtype=np_dtype) + for other_cpp_name, expected_format in CPP_NAME_FORMAT_TABLE: + format, np_array_is_matching = m.format_descriptor_format_compare( + other_cpp_name, np_array + ) + assert format == expected_format + if other_cpp_name == cpp_name: + assert np_array_is_matching + else: + assert not np_array_is_matching def test_from_python():