Add function for comparing buffer_info formats to types

Allows equivalent integral types and numpy dtypes
This commit is contained in:
Patrick Stewart 2016-11-21 17:40:43 +00:00 committed by Wenzel Jakob
parent 5467979588
commit 0b6d08a008
5 changed files with 54 additions and 1 deletions

View File

@ -623,6 +623,24 @@ template <typename T> struct format_descriptor<T, detail::enable_if_t<detail::is
template <typename T> constexpr const char format_descriptor<
T, detail::enable_if_t<detail::is_fmt_numeric<T>::value>>::value[2];
NAMESPACE_BEGIN(detail)
template <typename T, typename SFINAE = void> struct compare_buffer_info {
static bool compare(const buffer_info& b) {
return b.format == format_descriptor<T>::format() && b.itemsize == sizeof(T);
}
};
template <typename T> struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
static bool compare(const buffer_info& b) {
return b.itemsize == sizeof(T) && (b.format == format_descriptor<T>::value ||
((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned<T>::value ? "N" : "n")));
}
};
NAMESPACE_END(detail)
/// RAII wrapper that temporarily clears any Python error state
struct error_scope {
PyObject *type, *value, *trace;

View File

@ -703,6 +703,13 @@ struct pyobject_caster<array_t<T, ExtraFlags>> {
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
};
template <typename T>
struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
static bool compare(const buffer_info& b) {
return npy_api::get().PyArray_EquivTypes_(dtype::of<T>().ptr(), dtype(b).ptr());
}
};
template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> {
private:
// NB: the order here must match the one in common.h

View File

@ -319,6 +319,22 @@ py::list test_dtype_methods() {
return list;
}
struct CompareStruct {
bool x;
uint32_t y;
float z;
};
py::list test_compare_buffer_info() {
py::list list;
list.append(py::bool_(py::detail::compare_buffer_info<float>::compare(py::buffer_info(nullptr, sizeof(float), "f", 1))));
list.append(py::bool_(py::detail::compare_buffer_info<unsigned>::compare(py::buffer_info(nullptr, sizeof(int), "I", 1))));
list.append(py::bool_(py::detail::compare_buffer_info<long>::compare(py::buffer_info(nullptr, sizeof(long), "l", 1))));
list.append(py::bool_(py::detail::compare_buffer_info<long>::compare(py::buffer_info(nullptr, sizeof(long), sizeof(long) == sizeof(int) ? "i" : "q", 1))));
list.append(py::bool_(py::detail::compare_buffer_info<CompareStruct>::compare(py::buffer_info(nullptr, sizeof(CompareStruct), "T{?:x:3xI:y:f:z:}", 1))));
return list;
}
test_initializer numpy_dtypes([](py::module &m) {
try {
py::module::import("numpy");
@ -337,6 +353,7 @@ test_initializer numpy_dtypes([](py::module &m) {
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
PYBIND11_NUMPY_DTYPE(TrailingPaddingStruct, a, b);
PYBIND11_NUMPY_DTYPE(CompareStruct, x, y, z);
// ... or after
py::class_<PackedStruct>(m, "PackedStruct");
@ -366,6 +383,7 @@ test_initializer numpy_dtypes([](py::module &m) {
m.def("test_array_ctors", &test_array_ctors);
m.def("test_dtype_ctors", &test_dtype_ctors);
m.def("test_dtype_methods", &test_dtype_methods);
m.def("compare_buffer_info", &test_compare_buffer_info);
m.def("trailing_padding_dtype", &trailing_padding_dtype);
m.def("buffer_to_dtype", &buffer_to_dtype);
m.def("f_simple", [](SimpleStruct s) { return s.uint_ * 10; });

View File

@ -264,3 +264,9 @@ def test_register_dtype():
with pytest.raises(RuntimeError) as excinfo:
register_dtype()
assert 'dtype is already registered' in str(excinfo.value)
@pytest.requires_numpy
def test_compare_buffer_info():
from pybind11_tests import compare_buffer_info
assert all(compare_buffer_info())

View File

@ -55,7 +55,7 @@ def test_vector_buffer():
@pytest.requires_numpy
def test_vector_buffer_numpy():
from pybind11_tests import VectorInt, get_vectorstruct
from pybind11_tests import VectorInt, VectorStruct, get_vectorstruct
a = np.array([1, 2, 3, 4], dtype=np.int32)
with pytest.raises(TypeError):
@ -79,6 +79,10 @@ def test_vector_buffer_numpy():
m[1]['x'] = 99
assert v[1].x == 99
v = VectorStruct(np.zeros(3, dtype=np.dtype([('w', 'bool'), ('x', 'I'),
('y', 'float64'), ('z', 'bool')], align=True)))
assert len(v) == 3
def test_vector_custom():
from pybind11_tests import El, VectorEl, VectorVectorEl