Use py::detail::compare_buffer_info<T>::compare() to validate the format_descriptor<T>::format() strings.

This commit is contained in:
Ralf W. Grosse-Kunstleve 2023-05-18 23:09:31 -07:00
parent d432ce75b3
commit 18e1bd2a89
2 changed files with 77 additions and 64 deletions

View File

@ -14,32 +14,39 @@
#include "pybind11_tests.h"
TEST_SUBMODULE(buffers, m) {
m.def("format_descriptor_format", [](const std::string &cpp_name) {
// https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables
static auto *table = new std::map<std::string, std::string>;
if (table->empty()) {
m.attr("std_is_same_double_long_double") = std::is_same<double, long double>::value;
m.def("format_descriptor_format_compare",
[](const std::string &cpp_name, const py::buffer &buffer) {
// https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables
static auto *format_table = new std::map<std::string, std::string>;
static auto *compare_table
= new std::map<std::string, bool (*)(const py::buffer_info &)>;
if (format_table->empty()) {
#define PYBIND11_ASSIGN_HELPER(...) \
(*table)[#__VA_ARGS__] = py::format_descriptor<__VA_ARGS__>::format();
PYBIND11_ASSIGN_HELPER(PyObject *)
PYBIND11_ASSIGN_HELPER(bool)
PYBIND11_ASSIGN_HELPER(std::int8_t)
PYBIND11_ASSIGN_HELPER(std::uint8_t)
PYBIND11_ASSIGN_HELPER(std::int16_t)
PYBIND11_ASSIGN_HELPER(std::uint16_t)
PYBIND11_ASSIGN_HELPER(std::int32_t)
PYBIND11_ASSIGN_HELPER(std::uint32_t)
PYBIND11_ASSIGN_HELPER(std::int64_t)
PYBIND11_ASSIGN_HELPER(std::uint64_t)
PYBIND11_ASSIGN_HELPER(float)
PYBIND11_ASSIGN_HELPER(double)
PYBIND11_ASSIGN_HELPER(long double)
PYBIND11_ASSIGN_HELPER(std::complex<float>)
PYBIND11_ASSIGN_HELPER(std::complex<double>)
PYBIND11_ASSIGN_HELPER(std::complex<long double>)
(*format_table)[#__VA_ARGS__] = py::format_descriptor<__VA_ARGS__>::format(); \
(*compare_table)[#__VA_ARGS__] = py::detail::compare_buffer_info<__VA_ARGS__>::compare;
PYBIND11_ASSIGN_HELPER(PyObject *)
PYBIND11_ASSIGN_HELPER(bool)
PYBIND11_ASSIGN_HELPER(std::int8_t)
PYBIND11_ASSIGN_HELPER(std::uint8_t)
PYBIND11_ASSIGN_HELPER(std::int16_t)
PYBIND11_ASSIGN_HELPER(std::uint16_t)
PYBIND11_ASSIGN_HELPER(std::int32_t)
PYBIND11_ASSIGN_HELPER(std::uint32_t)
PYBIND11_ASSIGN_HELPER(std::int64_t)
PYBIND11_ASSIGN_HELPER(std::uint64_t)
PYBIND11_ASSIGN_HELPER(float)
PYBIND11_ASSIGN_HELPER(double)
PYBIND11_ASSIGN_HELPER(long double)
PYBIND11_ASSIGN_HELPER(std::complex<float>)
PYBIND11_ASSIGN_HELPER(std::complex<double>)
PYBIND11_ASSIGN_HELPER(std::complex<long double>)
#undef PYBIND11_ASSIGN_HELPER
}
return (*table)[cpp_name];
});
}
return std::pair<std::string, bool>((*format_table)[cpp_name],
(*compare_table)[cpp_name](buffer.request()));
});
// test_from_python / test_to_python:
class Matrix {

View File

@ -10,49 +10,55 @@ from pybind11_tests import buffers as m
np = pytest.importorskip("numpy")
if env.WIN:
# Windows does not have these (see e.g. #1908). But who knows, maybe later?
np_float128_or_none = getattr(np, "float128", None)
np_complex256_or_none = getattr(np, "complex256", None)
if m.std_is_same_double_long_double: # Windows.
np_float128 = None
np_complex256 = None
else:
np_float128_or_none = np.float128
np_complex256_or_none = np.complex256
np_float128 = np.float128
np_complex256 = np.complex256
CPP_NAME_FORMAT_NP_DTYPE_TABLE = [
item
for item in [
("PyObject *", "O", object),
("bool", "?", np.bool_),
("std::int8_t", "b", np.int8),
("std::uint8_t", "B", np.uint8),
("std::int16_t", "h", np.int16),
("std::uint16_t", "H", np.uint16),
("std::int32_t", "i", np.int32),
("std::uint32_t", "I", np.uint32),
("std::int64_t", "q", np.int64),
("std::uint64_t", "Q", np.uint64),
("float", "f", np.float32),
("double", "d", np.float64),
("long double", "g", np_float128),
("std::complex<float>", "Zf", np.complex64),
("std::complex<double>", "Zd", np.complex128),
("std::complex<long double>", "Zg", np_complex256),
]
if item[-1] is not None
]
CPP_NAME_FORMAT_TABLE = [
(cpp_name, format) for cpp_name, format, _ in CPP_NAME_FORMAT_NP_DTYPE_TABLE
]
CPP_NAME_NP_DTYPE_TABLE = [
(cpp_name, np_dtype) for cpp_name, _, np_dtype in CPP_NAME_FORMAT_NP_DTYPE_TABLE
]
@pytest.mark.parametrize(
("cpp_name", "expected_fmts", "np_array_dtype"),
[
("PyObject *", ["O"], object),
("bool", ["?"], np.bool_),
("std::int8_t", ["b"], np.int8),
("std::uint8_t", ["B"], np.uint8),
("std::int16_t", ["h"], np.int16),
("std::uint16_t", ["H"], np.uint16),
("std::int32_t", ["i"], np.int32),
("std::uint32_t", ["I"], np.uint32),
("std::int64_t", ["q"], np.int64),
("std::uint64_t", ["Q"], np.uint64),
("float", ["f"], np.float32),
("double", ["d"], np.float64),
("long double", ["g", "d"], np_float128_or_none),
("std::complex<float>", ["Zf"], np.complex64),
("std::complex<double>", ["Zd"], np.complex128),
("std::complex<long double>", ["Zg", "Zd"], np_complex256_or_none),
],
)
def test_format_descriptor_format(cpp_name, expected_fmts, np_array_dtype):
fmt = m.format_descriptor_format(cpp_name)
assert fmt in expected_fmts
if np_array_dtype is not None:
na = np.array([], dtype=np_array_dtype)
bi = m.get_buffer_info(na)
bif = bi.format
if bif == "l":
bif = "i" if bi.itemsize == 4 else "q"
elif bif == "L":
bif = "I" if bi.itemsize == 4 else "Q"
assert bif == fmt
@pytest.mark.parametrize(("cpp_name", "np_dtype"), CPP_NAME_NP_DTYPE_TABLE)
def test_format_descriptor_format_compare(cpp_name, np_dtype):
np_array = np.array([], dtype=np_dtype)
for other_cpp_name, expected_format in CPP_NAME_FORMAT_TABLE:
format, np_array_is_matching = m.format_descriptor_format_compare(
other_cpp_name, np_array
)
assert format == expected_format
if other_cpp_name == cpp_name:
assert np_array_is_matching
else:
assert not np_array_is_matching
def test_from_python():