Add numpy wrappers for char[] and std::array<char>

This commit is contained in:
Ivan Smirnov 2016-07-20 00:19:24 +01:00
parent 103d5eadc3
commit f9c0defed7
4 changed files with 84 additions and 5 deletions

View File

@ -65,6 +65,19 @@ struct PartialNestedStruct {
struct UnboundStruct { };
struct StringStruct {
char a[3];
std::array<char, 3> b;
};
std::ostream& operator<<(std::ostream& os, const StringStruct& v) {
os << "a='";
for (size_t i = 0; i < 3 && v.a[i]; i++) os << v.a[i];
os << "',b='";
for (size_t i = 0; i < 3 && v.b[i]; i++) os << v.b[i];
return os << "'";
}
template <typename T>
py::array mkarray_via_buffer(size_t n) {
return py::array(py::buffer_info(nullptr, sizeof(T),
@ -108,6 +121,25 @@ py::array_t<PartialNestedStruct, 0> create_partial_nested(size_t n) {
return arr;
}
py::array_t<StringStruct, 0> create_string_array(bool non_empty) {
auto arr = mkarray_via_buffer<StringStruct>(non_empty ? 4 : 0);
if (non_empty) {
auto req = arr.request();
auto ptr = static_cast<StringStruct*>(req.ptr);
for (size_t i = 0; i < req.size * req.itemsize; i++)
static_cast<char*>(req.ptr)[i] = 0;
ptr[1].a[0] = 'a'; ptr[1].b[0] = 'a';
ptr[2].a[0] = 'a'; ptr[2].b[0] = 'a';
ptr[3].a[0] = 'a'; ptr[3].b[0] = 'a';
ptr[2].a[1] = 'b'; ptr[2].b[1] = 'b';
ptr[3].a[1] = 'b'; ptr[3].b[1] = 'b';
ptr[3].a[2] = 'c'; ptr[3].b[2] = 'c';
}
return arr;
}
template <typename S>
void print_recarray(py::array_t<S, 0> arr) {
auto req = arr.request();
@ -122,6 +154,7 @@ void print_format_descriptors() {
std::cout << py::format_descriptor<NestedStruct>::format() << std::endl;
std::cout << py::format_descriptor<PartialStruct>::format() << std::endl;
std::cout << py::format_descriptor<PartialNestedStruct>::format() << std::endl;
std::cout << py::format_descriptor<StringStruct>::format() << std::endl;
}
void print_dtypes() {
@ -133,6 +166,7 @@ void print_dtypes() {
std::cout << to_str(py::dtype_of<NestedStruct>()) << std::endl;
std::cout << to_str(py::dtype_of<PartialStruct>()) << std::endl;
std::cout << to_str(py::dtype_of<PartialNestedStruct>()) << std::endl;
std::cout << to_str(py::dtype_of<StringStruct>()) << std::endl;
}
void init_ex20(py::module &m) {
@ -141,6 +175,7 @@ void init_ex20(py::module &m) {
PYBIND11_NUMPY_DTYPE(NestedStruct, a, b);
PYBIND11_NUMPY_DTYPE(PartialStruct, x, y, z);
PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a);
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
m.def("create_rec_simple", &create_recarray<SimpleStruct>);
m.def("create_rec_packed", &create_recarray<PackedStruct>);
@ -153,6 +188,8 @@ void init_ex20(py::module &m) {
m.def("print_rec_nested", &print_recarray<NestedStruct>);
m.def("print_dtypes", &print_dtypes);
m.def("get_format_unbound", &get_format_unbound);
m.def("create_string_array", &create_string_array);
m.def("print_string_array", &print_recarray<StringStruct>);
}
#undef PYBIND11_PACKED

View File

@ -6,7 +6,7 @@ import numpy as np
from example import (
create_rec_simple, create_rec_packed, create_rec_nested, print_format_descriptors,
print_rec_simple, print_rec_packed, print_rec_nested, print_dtypes, get_format_unbound,
create_rec_partial, create_rec_partial_nested
create_rec_partial, create_rec_partial_nested, create_string_array, print_string_array
)
@ -72,3 +72,12 @@ check_eq(arr, [((False, 0, 0.0), (True, 1, 1.5)),
print_rec_nested(arr)
assert create_rec_nested.__doc__.strip().endswith('numpy.ndarray[dtype=NestedStruct]')
arr = create_string_array(True)
print(arr.dtype)
print_string_array(arr)
dtype = arr.dtype
assert arr['a'].tolist() == [b'', b'a', b'ab', b'abc']
assert arr['b'].tolist() == [b'', b'a', b'ab', b'abc']
arr = create_string_array(False)
assert dtype == arr.dtype

View File

@ -3,11 +3,13 @@ T{=?:x:=I:y:=f:z:}
T{=T{=?:x:3x=I:y:=f:z:}:a:=T{=?:x:=I:y:=f:z:}:b:}
T{=?:x:3x=I:y:=f:z:12x}
T{8x=T{=?:x:3x=I:y:=f:z:12x}:a:8x}
T{=3s:a:=3s:b:}
{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}
[('x', '?'), ('y', '<u4'), ('z', '<f4')]
[('a', {'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}), ('b', [('x', '?'), ('y', '<u4'), ('z', '<f4')])]
{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}
{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}
[('a', 'S3'), ('b', 'S3')]
s:0,0,0
s:1,1,1.5
s:0,2,3
@ -18,4 +20,9 @@ p:0,2,3
{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}
n:a=s:0,0,0;b=p:1,1,1.5
n:a=s:1,1,1.5;b=p:0,2,3
n:a=s:0,2,3;b=p:1,3,4.5
n:a=s:0,2,3;b=p:1,3,4.5
[('a', 'S3'), ('b', 'S3')]
a='',b=''
a='a',b='a'
a='ab',b='ab'
a='abc',b='abc'

View File

@ -13,6 +13,7 @@
#include "complex.h"
#include <numeric>
#include <algorithm>
#include <array>
#include <cstdlib>
#include <cstring>
#include <sstream>
@ -27,10 +28,14 @@ NAMESPACE_BEGIN(pybind11)
namespace detail {
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
template <typename T> struct is_std_array : std::false_type { };
template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
template <typename T>
struct is_pod_struct {
enum { value = std::is_pod<T>::value && // offsetof only works correctly for POD types
!std::is_array<T>::value &&
!is_std_array<T>::value &&
!std::is_integral<T>::value &&
!std::is_same<T, float>::value &&
!std::is_same<T, double>::value &&
@ -221,9 +226,14 @@ public:
template <typename T>
struct format_descriptor<T, typename std::enable_if<detail::is_pod_struct<T>::value>::type> {
static const char *format() {
return detail::npy_format_descriptor<T>::format();
}
static const char *format() { return detail::npy_format_descriptor<T>::format(); }
};
template <size_t N> struct format_descriptor<char[N]> {
static const char *format() { PYBIND11_DESCR s = detail::_<N>() + detail::_("s"); return s.text(); }
};
template <size_t N> struct format_descriptor<std::array<char, N>> {
static const char *format() { PYBIND11_DESCR s = detail::_<N>() + detail::_("s"); return s.text(); }
};
template <typename T>
@ -268,6 +278,22 @@ DECL_FMT(std::complex<float>, NPY_CFLOAT_, "complex64");
DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
#undef DECL_FMT
#define DECL_CHAR_FMT \
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
static object dtype() { \
auto& api = array::lookup_api(); \
PyObject *descr = nullptr; \
PYBIND11_DESCR fmt = _("S") + _<N>(); \
pybind11::str py_fmt(fmt.text()); \
if (!api.PyArray_DescrConverter_(py_fmt.release().ptr(), &descr) || !descr) \
pybind11_fail("NumPy: failed to create string dtype"); \
return object(descr, false); \
} \
static const char *format() { PYBIND11_DESCR s = _<N>() + _("s"); return s.text(); }
template <size_t N> struct npy_format_descriptor<char[N]> { DECL_CHAR_FMT };
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { DECL_CHAR_FMT };
#undef DECL_CHAR_FMT
struct field_descriptor {
const char *name;
size_t offset;