mirror of
https://github.com/pybind/pybind11.git
synced 2025-02-22 00:19:18 +00:00
Use py::detail::compare_buffer_info<T>::compare()
to validate the format_descriptor<T>::format()
strings.
This commit is contained in:
parent
d432ce75b3
commit
18e1bd2a89
@ -14,12 +14,18 @@
|
|||||||
#include "pybind11_tests.h"
|
#include "pybind11_tests.h"
|
||||||
|
|
||||||
TEST_SUBMODULE(buffers, m) {
|
TEST_SUBMODULE(buffers, m) {
|
||||||
m.def("format_descriptor_format", [](const std::string &cpp_name) {
|
m.attr("std_is_same_double_long_double") = std::is_same<double, long double>::value;
|
||||||
|
|
||||||
|
m.def("format_descriptor_format_compare",
|
||||||
|
[](const std::string &cpp_name, const py::buffer &buffer) {
|
||||||
// https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables
|
// https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables
|
||||||
static auto *table = new std::map<std::string, std::string>;
|
static auto *format_table = new std::map<std::string, std::string>;
|
||||||
if (table->empty()) {
|
static auto *compare_table
|
||||||
|
= new std::map<std::string, bool (*)(const py::buffer_info &)>;
|
||||||
|
if (format_table->empty()) {
|
||||||
#define PYBIND11_ASSIGN_HELPER(...) \
|
#define PYBIND11_ASSIGN_HELPER(...) \
|
||||||
(*table)[#__VA_ARGS__] = py::format_descriptor<__VA_ARGS__>::format();
|
(*format_table)[#__VA_ARGS__] = py::format_descriptor<__VA_ARGS__>::format(); \
|
||||||
|
(*compare_table)[#__VA_ARGS__] = py::detail::compare_buffer_info<__VA_ARGS__>::compare;
|
||||||
PYBIND11_ASSIGN_HELPER(PyObject *)
|
PYBIND11_ASSIGN_HELPER(PyObject *)
|
||||||
PYBIND11_ASSIGN_HELPER(bool)
|
PYBIND11_ASSIGN_HELPER(bool)
|
||||||
PYBIND11_ASSIGN_HELPER(std::int8_t)
|
PYBIND11_ASSIGN_HELPER(std::int8_t)
|
||||||
@ -38,7 +44,8 @@ TEST_SUBMODULE(buffers, m) {
|
|||||||
PYBIND11_ASSIGN_HELPER(std::complex<long double>)
|
PYBIND11_ASSIGN_HELPER(std::complex<long double>)
|
||||||
#undef PYBIND11_ASSIGN_HELPER
|
#undef PYBIND11_ASSIGN_HELPER
|
||||||
}
|
}
|
||||||
return (*table)[cpp_name];
|
return std::pair<std::string, bool>((*format_table)[cpp_name],
|
||||||
|
(*compare_table)[cpp_name](buffer.request()));
|
||||||
});
|
});
|
||||||
|
|
||||||
// test_from_python / test_to_python:
|
// test_from_python / test_to_python:
|
||||||
|
@ -10,49 +10,55 @@ from pybind11_tests import buffers as m
|
|||||||
|
|
||||||
np = pytest.importorskip("numpy")
|
np = pytest.importorskip("numpy")
|
||||||
|
|
||||||
if env.WIN:
|
if m.std_is_same_double_long_double: # Windows.
|
||||||
# Windows does not have these (see e.g. #1908). But who knows, maybe later?
|
np_float128 = None
|
||||||
np_float128_or_none = getattr(np, "float128", None)
|
np_complex256 = None
|
||||||
np_complex256_or_none = getattr(np, "complex256", None)
|
|
||||||
else:
|
else:
|
||||||
np_float128_or_none = np.float128
|
np_float128 = np.float128
|
||||||
np_complex256_or_none = np.complex256
|
np_complex256 = np.complex256
|
||||||
|
|
||||||
|
CPP_NAME_FORMAT_NP_DTYPE_TABLE = [
|
||||||
|
item
|
||||||
|
for item in [
|
||||||
|
("PyObject *", "O", object),
|
||||||
|
("bool", "?", np.bool_),
|
||||||
|
("std::int8_t", "b", np.int8),
|
||||||
|
("std::uint8_t", "B", np.uint8),
|
||||||
|
("std::int16_t", "h", np.int16),
|
||||||
|
("std::uint16_t", "H", np.uint16),
|
||||||
|
("std::int32_t", "i", np.int32),
|
||||||
|
("std::uint32_t", "I", np.uint32),
|
||||||
|
("std::int64_t", "q", np.int64),
|
||||||
|
("std::uint64_t", "Q", np.uint64),
|
||||||
|
("float", "f", np.float32),
|
||||||
|
("double", "d", np.float64),
|
||||||
|
("long double", "g", np_float128),
|
||||||
|
("std::complex<float>", "Zf", np.complex64),
|
||||||
|
("std::complex<double>", "Zd", np.complex128),
|
||||||
|
("std::complex<long double>", "Zg", np_complex256),
|
||||||
|
]
|
||||||
|
if item[-1] is not None
|
||||||
|
]
|
||||||
|
CPP_NAME_FORMAT_TABLE = [
|
||||||
|
(cpp_name, format) for cpp_name, format, _ in CPP_NAME_FORMAT_NP_DTYPE_TABLE
|
||||||
|
]
|
||||||
|
CPP_NAME_NP_DTYPE_TABLE = [
|
||||||
|
(cpp_name, np_dtype) for cpp_name, _, np_dtype in CPP_NAME_FORMAT_NP_DTYPE_TABLE
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(("cpp_name", "np_dtype"), CPP_NAME_NP_DTYPE_TABLE)
|
||||||
("cpp_name", "expected_fmts", "np_array_dtype"),
|
def test_format_descriptor_format_compare(cpp_name, np_dtype):
|
||||||
[
|
np_array = np.array([], dtype=np_dtype)
|
||||||
("PyObject *", ["O"], object),
|
for other_cpp_name, expected_format in CPP_NAME_FORMAT_TABLE:
|
||||||
("bool", ["?"], np.bool_),
|
format, np_array_is_matching = m.format_descriptor_format_compare(
|
||||||
("std::int8_t", ["b"], np.int8),
|
other_cpp_name, np_array
|
||||||
("std::uint8_t", ["B"], np.uint8),
|
)
|
||||||
("std::int16_t", ["h"], np.int16),
|
assert format == expected_format
|
||||||
("std::uint16_t", ["H"], np.uint16),
|
if other_cpp_name == cpp_name:
|
||||||
("std::int32_t", ["i"], np.int32),
|
assert np_array_is_matching
|
||||||
("std::uint32_t", ["I"], np.uint32),
|
else:
|
||||||
("std::int64_t", ["q"], np.int64),
|
assert not np_array_is_matching
|
||||||
("std::uint64_t", ["Q"], np.uint64),
|
|
||||||
("float", ["f"], np.float32),
|
|
||||||
("double", ["d"], np.float64),
|
|
||||||
("long double", ["g", "d"], np_float128_or_none),
|
|
||||||
("std::complex<float>", ["Zf"], np.complex64),
|
|
||||||
("std::complex<double>", ["Zd"], np.complex128),
|
|
||||||
("std::complex<long double>", ["Zg", "Zd"], np_complex256_or_none),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_format_descriptor_format(cpp_name, expected_fmts, np_array_dtype):
|
|
||||||
fmt = m.format_descriptor_format(cpp_name)
|
|
||||||
assert fmt in expected_fmts
|
|
||||||
|
|
||||||
if np_array_dtype is not None:
|
|
||||||
na = np.array([], dtype=np_array_dtype)
|
|
||||||
bi = m.get_buffer_info(na)
|
|
||||||
bif = bi.format
|
|
||||||
if bif == "l":
|
|
||||||
bif = "i" if bi.itemsize == 4 else "q"
|
|
||||||
elif bif == "L":
|
|
||||||
bif = "I" if bi.itemsize == 4 else "Q"
|
|
||||||
assert bif == fmt
|
|
||||||
|
|
||||||
|
|
||||||
def test_from_python():
|
def test_from_python():
|
||||||
|
Loading…
Reference in New Issue
Block a user