From 0b6d08a0081dc50a1bec7937a67043399755ce99 Mon Sep 17 00:00:00 2001 From: Patrick Stewart Date: Mon, 21 Nov 2016 17:40:43 +0000 Subject: [PATCH] Add function for comparing buffer_info formats to types Allows equivalent integral types and numpy dtypes --- include/pybind11/common.h | 18 ++++++++++++++++++ include/pybind11/numpy.h | 7 +++++++ tests/test_numpy_dtypes.cpp | 18 ++++++++++++++++++ tests/test_numpy_dtypes.py | 6 ++++++ tests/test_stl_binders.py | 6 +++++- 5 files changed, 54 insertions(+), 1 deletion(-) diff --git a/include/pybind11/common.h b/include/pybind11/common.h index ef94f3854..44b7008e4 100644 --- a/include/pybind11/common.h +++ b/include/pybind11/common.h @@ -623,6 +623,24 @@ template struct format_descriptor constexpr const char format_descriptor< T, detail::enable_if_t::value>>::value[2]; +NAMESPACE_BEGIN(detail) + +template struct compare_buffer_info { + static bool compare(const buffer_info& b) { + return b.format == format_descriptor::format() && b.itemsize == sizeof(T); + } +}; + +template struct compare_buffer_info::value>> { + static bool compare(const buffer_info& b) { + return b.itemsize == sizeof(T) && (b.format == format_descriptor::value || + ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned::value ? "L" : "l")) || + ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned::value ? "N" : "n"))); + } +}; + +NAMESPACE_END(detail) + /// RAII wrapper that temporarily clears any Python error state struct error_scope { PyObject *type, *value, *trace; diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 0faed31d5..5a766d490 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -703,6 +703,13 @@ struct pyobject_caster> { PYBIND11_TYPE_CASTER(type, handle_type_name::name()); }; +template +struct compare_buffer_info::value>> { + static bool compare(const buffer_info& b) { + return npy_api::get().PyArray_EquivTypes_(dtype::of().ptr(), dtype(b).ptr()); + } +}; + template struct npy_format_descriptor::value>> { private: // NB: the order here must match the one in common.h diff --git a/tests/test_numpy_dtypes.cpp b/tests/test_numpy_dtypes.cpp index d74ecc59e..1f6c85704 100644 --- a/tests/test_numpy_dtypes.cpp +++ b/tests/test_numpy_dtypes.cpp @@ -319,6 +319,22 @@ py::list test_dtype_methods() { return list; } +struct CompareStruct { + bool x; + uint32_t y; + float z; +}; + +py::list test_compare_buffer_info() { + py::list list; + list.append(py::bool_(py::detail::compare_buffer_info::compare(py::buffer_info(nullptr, sizeof(float), "f", 1)))); + list.append(py::bool_(py::detail::compare_buffer_info::compare(py::buffer_info(nullptr, sizeof(int), "I", 1)))); + list.append(py::bool_(py::detail::compare_buffer_info::compare(py::buffer_info(nullptr, sizeof(long), "l", 1)))); + list.append(py::bool_(py::detail::compare_buffer_info::compare(py::buffer_info(nullptr, sizeof(long), sizeof(long) == sizeof(int) ? "i" : "q", 1)))); + list.append(py::bool_(py::detail::compare_buffer_info::compare(py::buffer_info(nullptr, sizeof(CompareStruct), "T{?:x:3xI:y:f:z:}", 1)))); + return list; +} + test_initializer numpy_dtypes([](py::module &m) { try { py::module::import("numpy"); @@ -337,6 +353,7 @@ test_initializer numpy_dtypes([](py::module &m) { PYBIND11_NUMPY_DTYPE(StringStruct, a, b); PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2); PYBIND11_NUMPY_DTYPE(TrailingPaddingStruct, a, b); + PYBIND11_NUMPY_DTYPE(CompareStruct, x, y, z); // ... or after py::class_(m, "PackedStruct"); @@ -366,6 +383,7 @@ test_initializer numpy_dtypes([](py::module &m) { m.def("test_array_ctors", &test_array_ctors); m.def("test_dtype_ctors", &test_dtype_ctors); m.def("test_dtype_methods", &test_dtype_methods); + m.def("compare_buffer_info", &test_compare_buffer_info); m.def("trailing_padding_dtype", &trailing_padding_dtype); m.def("buffer_to_dtype", &buffer_to_dtype); m.def("f_simple", [](SimpleStruct s) { return s.uint_ * 10; }); diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py index 0ef4e939a..f63814f9d 100644 --- a/tests/test_numpy_dtypes.py +++ b/tests/test_numpy_dtypes.py @@ -264,3 +264,9 @@ def test_register_dtype(): with pytest.raises(RuntimeError) as excinfo: register_dtype() assert 'dtype is already registered' in str(excinfo.value) + + +@pytest.requires_numpy +def test_compare_buffer_info(): + from pybind11_tests import compare_buffer_info + assert all(compare_buffer_info()) diff --git a/tests/test_stl_binders.py b/tests/test_stl_binders.py index f8f817e33..0edf9e26e 100644 --- a/tests/test_stl_binders.py +++ b/tests/test_stl_binders.py @@ -55,7 +55,7 @@ def test_vector_buffer(): @pytest.requires_numpy def test_vector_buffer_numpy(): - from pybind11_tests import VectorInt, get_vectorstruct + from pybind11_tests import VectorInt, VectorStruct, get_vectorstruct a = np.array([1, 2, 3, 4], dtype=np.int32) with pytest.raises(TypeError): @@ -79,6 +79,10 @@ def test_vector_buffer_numpy(): m[1]['x'] = 99 assert v[1].x == 99 + v = VectorStruct(np.zeros(3, dtype=np.dtype([('w', 'bool'), ('x', 'I'), + ('y', 'float64'), ('z', 'bool')], align=True))) + assert len(v) == 3 + def test_vector_custom(): from pybind11_tests import El, VectorEl, VectorVectorEl