mirror of
https://github.com/pybind/pybind11.git
synced 2025-02-20 23:49:18 +00:00
Add function for comparing buffer_info formats to types
Allows equivalent integral types and numpy dtypes
This commit is contained in:
parent
5467979588
commit
0b6d08a008
@ -623,6 +623,24 @@ template <typename T> struct format_descriptor<T, detail::enable_if_t<detail::is
|
||||
template <typename T> constexpr const char format_descriptor<
|
||||
T, detail::enable_if_t<detail::is_fmt_numeric<T>::value>>::value[2];
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <typename T, typename SFINAE = void> struct compare_buffer_info {
|
||||
static bool compare(const buffer_info& b) {
|
||||
return b.format == format_descriptor<T>::format() && b.itemsize == sizeof(T);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
|
||||
static bool compare(const buffer_info& b) {
|
||||
return b.itemsize == sizeof(T) && (b.format == format_descriptor<T>::value ||
|
||||
((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
|
||||
((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned<T>::value ? "N" : "n")));
|
||||
}
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
/// RAII wrapper that temporarily clears any Python error state
|
||||
struct error_scope {
|
||||
PyObject *type, *value, *trace;
|
||||
|
@ -703,6 +703,13 @@ struct pyobject_caster<array_t<T, ExtraFlags>> {
|
||||
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
|
||||
static bool compare(const buffer_info& b) {
|
||||
return npy_api::get().PyArray_EquivTypes_(dtype::of<T>().ptr(), dtype(b).ptr());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> {
|
||||
private:
|
||||
// NB: the order here must match the one in common.h
|
||||
|
@ -319,6 +319,22 @@ py::list test_dtype_methods() {
|
||||
return list;
|
||||
}
|
||||
|
||||
struct CompareStruct {
|
||||
bool x;
|
||||
uint32_t y;
|
||||
float z;
|
||||
};
|
||||
|
||||
py::list test_compare_buffer_info() {
|
||||
py::list list;
|
||||
list.append(py::bool_(py::detail::compare_buffer_info<float>::compare(py::buffer_info(nullptr, sizeof(float), "f", 1))));
|
||||
list.append(py::bool_(py::detail::compare_buffer_info<unsigned>::compare(py::buffer_info(nullptr, sizeof(int), "I", 1))));
|
||||
list.append(py::bool_(py::detail::compare_buffer_info<long>::compare(py::buffer_info(nullptr, sizeof(long), "l", 1))));
|
||||
list.append(py::bool_(py::detail::compare_buffer_info<long>::compare(py::buffer_info(nullptr, sizeof(long), sizeof(long) == sizeof(int) ? "i" : "q", 1))));
|
||||
list.append(py::bool_(py::detail::compare_buffer_info<CompareStruct>::compare(py::buffer_info(nullptr, sizeof(CompareStruct), "T{?:x:3xI:y:f:z:}", 1))));
|
||||
return list;
|
||||
}
|
||||
|
||||
test_initializer numpy_dtypes([](py::module &m) {
|
||||
try {
|
||||
py::module::import("numpy");
|
||||
@ -337,6 +353,7 @@ test_initializer numpy_dtypes([](py::module &m) {
|
||||
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
|
||||
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
|
||||
PYBIND11_NUMPY_DTYPE(TrailingPaddingStruct, a, b);
|
||||
PYBIND11_NUMPY_DTYPE(CompareStruct, x, y, z);
|
||||
|
||||
// ... or after
|
||||
py::class_<PackedStruct>(m, "PackedStruct");
|
||||
@ -366,6 +383,7 @@ test_initializer numpy_dtypes([](py::module &m) {
|
||||
m.def("test_array_ctors", &test_array_ctors);
|
||||
m.def("test_dtype_ctors", &test_dtype_ctors);
|
||||
m.def("test_dtype_methods", &test_dtype_methods);
|
||||
m.def("compare_buffer_info", &test_compare_buffer_info);
|
||||
m.def("trailing_padding_dtype", &trailing_padding_dtype);
|
||||
m.def("buffer_to_dtype", &buffer_to_dtype);
|
||||
m.def("f_simple", [](SimpleStruct s) { return s.uint_ * 10; });
|
||||
|
@ -264,3 +264,9 @@ def test_register_dtype():
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
register_dtype()
|
||||
assert 'dtype is already registered' in str(excinfo.value)
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_compare_buffer_info():
|
||||
from pybind11_tests import compare_buffer_info
|
||||
assert all(compare_buffer_info())
|
||||
|
@ -55,7 +55,7 @@ def test_vector_buffer():
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_vector_buffer_numpy():
|
||||
from pybind11_tests import VectorInt, get_vectorstruct
|
||||
from pybind11_tests import VectorInt, VectorStruct, get_vectorstruct
|
||||
|
||||
a = np.array([1, 2, 3, 4], dtype=np.int32)
|
||||
with pytest.raises(TypeError):
|
||||
@ -79,6 +79,10 @@ def test_vector_buffer_numpy():
|
||||
m[1]['x'] = 99
|
||||
assert v[1].x == 99
|
||||
|
||||
v = VectorStruct(np.zeros(3, dtype=np.dtype([('w', 'bool'), ('x', 'I'),
|
||||
('y', 'float64'), ('z', 'bool')], align=True)))
|
||||
assert len(v) == 3
|
||||
|
||||
|
||||
def test_vector_custom():
|
||||
from pybind11_tests import El, VectorEl, VectorVectorEl
|
||||
|
Loading…
Reference in New Issue
Block a user