From fb74df50c9165e768feac4a46f7e41e115996d62 Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Thu, 20 Oct 2016 12:28:08 +0100 Subject: [PATCH 1/2] Implement format/numpy descriptors for enums --- include/pybind11/numpy.h | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index cee40c817..04001d6c5 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -552,6 +552,14 @@ template struct format_descriptor> { static std::string format() { return std::to_string(N) + "s"; } }; +template +struct format_descriptor::value>> { + static std::string format() { + return format_descriptor< + typename std::remove_cv::type>::type>::format(); + } +}; + NAMESPACE_BEGIN(detail) template struct is_std_array : std::false_type { }; template struct is_std_array> : std::true_type { }; @@ -563,6 +571,7 @@ struct is_pod_struct { !std::is_array::value && !is_std_array::value && !std::is_integral::value && + !std::is_enum::value && !std::is_same::type, float>::value && !std::is_same::type, double>::value && !std::is_same::type, bool>::value && @@ -612,6 +621,14 @@ template struct npy_format_descriptor { DECL_CHAR_FMT }; template struct npy_format_descriptor> { DECL_CHAR_FMT }; #undef DECL_CHAR_FMT +template struct npy_format_descriptor::value>> { +private: + using base_descr = npy_format_descriptor::type>; +public: + static PYBIND11_DESCR name() { return base_descr::name(); } + static pybind11::dtype dtype() { return base_descr::dtype(); } +}; + struct field_descriptor { const char *name; size_t offset; From 2f3f3687dc6efd282a500ec6253862d1f075beba Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Thu, 20 Oct 2016 12:28:47 +0100 Subject: [PATCH 2/2] Add tests for numpy enum descriptors --- tests/test_numpy_dtypes.cpp | 31 +++++++++++++++++++++++++++++-- tests/test_numpy_dtypes.py | 23 +++++++++++++++++++++-- 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/tests/test_numpy_dtypes.cpp b/tests/test_numpy_dtypes.cpp index 3041e55d8..86e6e68cc 100644 --- a/tests/test_numpy_dtypes.cpp +++ b/tests/test_numpy_dtypes.cpp @@ -67,6 +67,14 @@ struct StringStruct { std::array b; }; +enum class E1 : int64_t { A = -1, B = 1 }; +enum E2 : uint8_t { X = 1, Y = 2 }; + +PYBIND11_PACKED(struct EnumStruct { + E1 e1; + E2 e2; +}); + std::ostream& operator<<(std::ostream& os, const StringStruct& v) { os << "a='"; for (size_t i = 0; i < 3 && v.a[i]; i++) os << v.a[i]; @@ -75,6 +83,10 @@ std::ostream& operator<<(std::ostream& os, const StringStruct& v) { return os << "'"; } +std::ostream& operator<<(std::ostream& os, const EnumStruct& v) { + return os << "e1=" << (v.e1 == E1::A ? "A" : "B") << ",e2=" << (v.e2 == E2::X ? "X" : "Y"); +} + template py::array mkarray_via_buffer(size_t n) { return py::array(py::buffer_info(nullptr, sizeof(T), @@ -137,6 +149,16 @@ py::array_t create_string_array(bool non_empty) { return arr; } +py::array_t create_enum_array(size_t n) { + auto arr = mkarray_via_buffer(n); + auto ptr = (EnumStruct *) arr.mutable_data(); + for (size_t i = 0; i < n; i++) { + ptr[i].e1 = static_cast(-1 + ((int) i % 2) * 2); + ptr[i].e2 = static_cast(1 + (i % 2)); + } + return arr; +} + template py::list print_recarray(py::array_t arr) { const auto req = arr.request(); @@ -157,7 +179,8 @@ py::list print_format_descriptors() { py::format_descriptor::format(), py::format_descriptor::format(), py::format_descriptor::format(), - py::format_descriptor::format() + py::format_descriptor::format(), + py::format_descriptor::format() }; auto l = py::list(); for (const auto &fmt : fmts) { @@ -173,7 +196,8 @@ py::list print_dtypes() { py::dtype::of().str(), py::dtype::of().str(), py::dtype::of().str(), - py::dtype::of().str() + py::dtype::of().str(), + py::dtype::of().str() }; auto l = py::list(); for (const auto &s : dtypes) { @@ -280,6 +304,7 @@ test_initializer numpy_dtypes([](py::module &m) { PYBIND11_NUMPY_DTYPE(PartialStruct, x, y, z); PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a); PYBIND11_NUMPY_DTYPE(StringStruct, a, b); + PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2); m.def("create_rec_simple", &create_recarray); m.def("create_rec_packed", &create_recarray); @@ -294,6 +319,8 @@ test_initializer numpy_dtypes([](py::module &m) { m.def("get_format_unbound", &get_format_unbound); m.def("create_string_array", &create_string_array); m.def("print_string_array", &print_recarray); + m.def("create_enum_array", &create_enum_array); + m.def("print_enum_array", &print_recarray); m.def("test_array_ctors", &test_array_ctors); m.def("test_dtype_ctors", &test_dtype_ctors); m.def("test_dtype_methods", &test_dtype_methods); diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py index 2f4cab0f0..22f5c662f 100644 --- a/tests/test_numpy_dtypes.py +++ b/tests/test_numpy_dtypes.py @@ -26,7 +26,8 @@ def test_format_descriptors(): "T{=T{=?:x:3x=I:y:=f:z:}:a:=T{=?:x:=I:y:=f:z:}:b:}", "T{=?:x:3x=I:y:=f:z:12x}", "T{8x=T{=?:x:3x=I:y:=f:z:12x}:a:8x}", - "T{=3s:a:=3s:b:}" + "T{=3s:a:=3s:b:}", + 'T{=q:e1:=B:e2:}' ] @@ -40,7 +41,8 @@ def test_dtype(): "[('a', {'names':['x','y','z'], 'formats':['?','