mirror of
https://github.com/pybind/pybind11.git
synced 2025-01-19 01:15:52 +00:00
Add numpy wrappers for char[] and std::array<char>
This commit is contained in:
parent
103d5eadc3
commit
f9c0defed7
@ -65,6 +65,19 @@ struct PartialNestedStruct {
|
||||
|
||||
struct UnboundStruct { };
|
||||
|
||||
struct StringStruct {
|
||||
char a[3];
|
||||
std::array<char, 3> b;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const StringStruct& v) {
|
||||
os << "a='";
|
||||
for (size_t i = 0; i < 3 && v.a[i]; i++) os << v.a[i];
|
||||
os << "',b='";
|
||||
for (size_t i = 0; i < 3 && v.b[i]; i++) os << v.b[i];
|
||||
return os << "'";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
py::array mkarray_via_buffer(size_t n) {
|
||||
return py::array(py::buffer_info(nullptr, sizeof(T),
|
||||
@ -108,6 +121,25 @@ py::array_t<PartialNestedStruct, 0> create_partial_nested(size_t n) {
|
||||
return arr;
|
||||
}
|
||||
|
||||
py::array_t<StringStruct, 0> create_string_array(bool non_empty) {
|
||||
auto arr = mkarray_via_buffer<StringStruct>(non_empty ? 4 : 0);
|
||||
if (non_empty) {
|
||||
auto req = arr.request();
|
||||
auto ptr = static_cast<StringStruct*>(req.ptr);
|
||||
for (size_t i = 0; i < req.size * req.itemsize; i++)
|
||||
static_cast<char*>(req.ptr)[i] = 0;
|
||||
ptr[1].a[0] = 'a'; ptr[1].b[0] = 'a';
|
||||
ptr[2].a[0] = 'a'; ptr[2].b[0] = 'a';
|
||||
ptr[3].a[0] = 'a'; ptr[3].b[0] = 'a';
|
||||
|
||||
ptr[2].a[1] = 'b'; ptr[2].b[1] = 'b';
|
||||
ptr[3].a[1] = 'b'; ptr[3].b[1] = 'b';
|
||||
|
||||
ptr[3].a[2] = 'c'; ptr[3].b[2] = 'c';
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
|
||||
template <typename S>
|
||||
void print_recarray(py::array_t<S, 0> arr) {
|
||||
auto req = arr.request();
|
||||
@ -122,6 +154,7 @@ void print_format_descriptors() {
|
||||
std::cout << py::format_descriptor<NestedStruct>::format() << std::endl;
|
||||
std::cout << py::format_descriptor<PartialStruct>::format() << std::endl;
|
||||
std::cout << py::format_descriptor<PartialNestedStruct>::format() << std::endl;
|
||||
std::cout << py::format_descriptor<StringStruct>::format() << std::endl;
|
||||
}
|
||||
|
||||
void print_dtypes() {
|
||||
@ -133,6 +166,7 @@ void print_dtypes() {
|
||||
std::cout << to_str(py::dtype_of<NestedStruct>()) << std::endl;
|
||||
std::cout << to_str(py::dtype_of<PartialStruct>()) << std::endl;
|
||||
std::cout << to_str(py::dtype_of<PartialNestedStruct>()) << std::endl;
|
||||
std::cout << to_str(py::dtype_of<StringStruct>()) << std::endl;
|
||||
}
|
||||
|
||||
void init_ex20(py::module &m) {
|
||||
@ -141,6 +175,7 @@ void init_ex20(py::module &m) {
|
||||
PYBIND11_NUMPY_DTYPE(NestedStruct, a, b);
|
||||
PYBIND11_NUMPY_DTYPE(PartialStruct, x, y, z);
|
||||
PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a);
|
||||
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
|
||||
|
||||
m.def("create_rec_simple", &create_recarray<SimpleStruct>);
|
||||
m.def("create_rec_packed", &create_recarray<PackedStruct>);
|
||||
@ -153,6 +188,8 @@ void init_ex20(py::module &m) {
|
||||
m.def("print_rec_nested", &print_recarray<NestedStruct>);
|
||||
m.def("print_dtypes", &print_dtypes);
|
||||
m.def("get_format_unbound", &get_format_unbound);
|
||||
m.def("create_string_array", &create_string_array);
|
||||
m.def("print_string_array", &print_recarray<StringStruct>);
|
||||
}
|
||||
|
||||
#undef PYBIND11_PACKED
|
||||
|
@ -6,7 +6,7 @@ import numpy as np
|
||||
from example import (
|
||||
create_rec_simple, create_rec_packed, create_rec_nested, print_format_descriptors,
|
||||
print_rec_simple, print_rec_packed, print_rec_nested, print_dtypes, get_format_unbound,
|
||||
create_rec_partial, create_rec_partial_nested
|
||||
create_rec_partial, create_rec_partial_nested, create_string_array, print_string_array
|
||||
)
|
||||
|
||||
|
||||
@ -72,3 +72,12 @@ check_eq(arr, [((False, 0, 0.0), (True, 1, 1.5)),
|
||||
print_rec_nested(arr)
|
||||
|
||||
assert create_rec_nested.__doc__.strip().endswith('numpy.ndarray[dtype=NestedStruct]')
|
||||
|
||||
arr = create_string_array(True)
|
||||
print(arr.dtype)
|
||||
print_string_array(arr)
|
||||
dtype = arr.dtype
|
||||
assert arr['a'].tolist() == [b'', b'a', b'ab', b'abc']
|
||||
assert arr['b'].tolist() == [b'', b'a', b'ab', b'abc']
|
||||
arr = create_string_array(False)
|
||||
assert dtype == arr.dtype
|
||||
|
@ -3,11 +3,13 @@ T{=?:x:=I:y:=f:z:}
|
||||
T{=T{=?:x:3x=I:y:=f:z:}:a:=T{=?:x:=I:y:=f:z:}:b:}
|
||||
T{=?:x:3x=I:y:=f:z:12x}
|
||||
T{8x=T{=?:x:3x=I:y:=f:z:12x}:a:8x}
|
||||
T{=3s:a:=3s:b:}
|
||||
{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}
|
||||
[('x', '?'), ('y', '<u4'), ('z', '<f4')]
|
||||
[('a', {'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}), ('b', [('x', '?'), ('y', '<u4'), ('z', '<f4')])]
|
||||
{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}
|
||||
{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}
|
||||
[('a', 'S3'), ('b', 'S3')]
|
||||
s:0,0,0
|
||||
s:1,1,1.5
|
||||
s:0,2,3
|
||||
@ -18,4 +20,9 @@ p:0,2,3
|
||||
{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}
|
||||
n:a=s:0,0,0;b=p:1,1,1.5
|
||||
n:a=s:1,1,1.5;b=p:0,2,3
|
||||
n:a=s:0,2,3;b=p:1,3,4.5
|
||||
n:a=s:0,2,3;b=p:1,3,4.5
|
||||
[('a', 'S3'), ('b', 'S3')]
|
||||
a='',b=''
|
||||
a='a',b='a'
|
||||
a='ab',b='ab'
|
||||
a='abc',b='abc'
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include "complex.h"
|
||||
#include <numeric>
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
@ -27,10 +28,14 @@ NAMESPACE_BEGIN(pybind11)
|
||||
namespace detail {
|
||||
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
|
||||
|
||||
template <typename T> struct is_std_array : std::false_type { };
|
||||
template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
|
||||
|
||||
template <typename T>
|
||||
struct is_pod_struct {
|
||||
enum { value = std::is_pod<T>::value && // offsetof only works correctly for POD types
|
||||
!std::is_array<T>::value &&
|
||||
!is_std_array<T>::value &&
|
||||
!std::is_integral<T>::value &&
|
||||
!std::is_same<T, float>::value &&
|
||||
!std::is_same<T, double>::value &&
|
||||
@ -221,9 +226,14 @@ public:
|
||||
|
||||
template <typename T>
|
||||
struct format_descriptor<T, typename std::enable_if<detail::is_pod_struct<T>::value>::type> {
|
||||
static const char *format() {
|
||||
return detail::npy_format_descriptor<T>::format();
|
||||
}
|
||||
static const char *format() { return detail::npy_format_descriptor<T>::format(); }
|
||||
};
|
||||
|
||||
template <size_t N> struct format_descriptor<char[N]> {
|
||||
static const char *format() { PYBIND11_DESCR s = detail::_<N>() + detail::_("s"); return s.text(); }
|
||||
};
|
||||
template <size_t N> struct format_descriptor<std::array<char, N>> {
|
||||
static const char *format() { PYBIND11_DESCR s = detail::_<N>() + detail::_("s"); return s.text(); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@ -268,6 +278,22 @@ DECL_FMT(std::complex<float>, NPY_CFLOAT_, "complex64");
|
||||
DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
|
||||
#undef DECL_FMT
|
||||
|
||||
#define DECL_CHAR_FMT \
|
||||
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
|
||||
static object dtype() { \
|
||||
auto& api = array::lookup_api(); \
|
||||
PyObject *descr = nullptr; \
|
||||
PYBIND11_DESCR fmt = _("S") + _<N>(); \
|
||||
pybind11::str py_fmt(fmt.text()); \
|
||||
if (!api.PyArray_DescrConverter_(py_fmt.release().ptr(), &descr) || !descr) \
|
||||
pybind11_fail("NumPy: failed to create string dtype"); \
|
||||
return object(descr, false); \
|
||||
} \
|
||||
static const char *format() { PYBIND11_DESCR s = _<N>() + _("s"); return s.text(); }
|
||||
template <size_t N> struct npy_format_descriptor<char[N]> { DECL_CHAR_FMT };
|
||||
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { DECL_CHAR_FMT };
|
||||
#undef DECL_CHAR_FMT
|
||||
|
||||
struct field_descriptor {
|
||||
const char *name;
|
||||
size_t offset;
|
||||
|
Loading…
Reference in New Issue
Block a user