mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-13 09:03:54 +00:00
Add buffer_info::compare<T>
to make detail::compare_buffer_info<T>::compare
more visible & accessible.
This commit is contained in:
parent
18e1bd2a89
commit
029b157540
@ -37,6 +37,9 @@ inline std::vector<ssize_t> f_strides(const std::vector<ssize_t> &shape, ssize_t
|
||||
return strides;
|
||||
}
|
||||
|
||||
template <typename T, typename SFINAE = void>
|
||||
struct compare_buffer_info;
|
||||
|
||||
PYBIND11_NAMESPACE_END(detail)
|
||||
|
||||
/// Information record describing a Python buffer object
|
||||
@ -150,6 +153,11 @@ struct buffer_info {
|
||||
Py_buffer *view() const { return m_view; }
|
||||
Py_buffer *&view() { return m_view; }
|
||||
|
||||
template <typename T>
|
||||
static bool compare(const buffer_info &b) {
|
||||
return detail::compare_buffer_info<T>::compare(b);
|
||||
}
|
||||
|
||||
private:
|
||||
struct private_ctr_tag {};
|
||||
|
||||
@ -170,7 +178,7 @@ private:
|
||||
|
||||
PYBIND11_NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <typename T, typename SFINAE = void>
|
||||
template <typename T, typename SFINAE>
|
||||
struct compare_buffer_info {
|
||||
static bool compare(const buffer_info &b) {
|
||||
return b.format == format_descriptor<T>::format() && b.itemsize == (ssize_t) sizeof(T);
|
||||
|
@ -16,7 +16,7 @@
|
||||
TEST_SUBMODULE(buffers, m) {
|
||||
m.attr("std_is_same_double_long_double") = std::is_same<double, long double>::value;
|
||||
|
||||
m.def("format_descriptor_format_compare",
|
||||
m.def("format_descriptor_format_buffer_info_compare",
|
||||
[](const std::string &cpp_name, const py::buffer &buffer) {
|
||||
// https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables
|
||||
static auto *format_table = new std::map<std::string, std::string>;
|
||||
@ -25,7 +25,7 @@ TEST_SUBMODULE(buffers, m) {
|
||||
if (format_table->empty()) {
|
||||
#define PYBIND11_ASSIGN_HELPER(...) \
|
||||
(*format_table)[#__VA_ARGS__] = py::format_descriptor<__VA_ARGS__>::format(); \
|
||||
(*compare_table)[#__VA_ARGS__] = py::detail::compare_buffer_info<__VA_ARGS__>::compare;
|
||||
(*compare_table)[#__VA_ARGS__] = py::buffer_info::compare<__VA_ARGS__>;
|
||||
PYBIND11_ASSIGN_HELPER(PyObject *)
|
||||
PYBIND11_ASSIGN_HELPER(bool)
|
||||
PYBIND11_ASSIGN_HELPER(std::int8_t)
|
||||
|
@ -48,10 +48,10 @@ CPP_NAME_NP_DTYPE_TABLE = [
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("cpp_name", "np_dtype"), CPP_NAME_NP_DTYPE_TABLE)
|
||||
def test_format_descriptor_format_compare(cpp_name, np_dtype):
|
||||
def test_format_descriptor_format_buffer_info_compare(cpp_name, np_dtype):
|
||||
np_array = np.array([], dtype=np_dtype)
|
||||
for other_cpp_name, expected_format in CPP_NAME_FORMAT_TABLE:
|
||||
format, np_array_is_matching = m.format_descriptor_format_compare(
|
||||
format, np_array_is_matching = m.format_descriptor_format_buffer_info_compare(
|
||||
other_cpp_name, np_array
|
||||
)
|
||||
assert format == expected_format
|
||||
|
Loading…
Reference in New Issue
Block a user