Strip padding fields in dtypes, update the tests

This commit is contained in:
Ivan Smirnov 2016-07-06 00:28:12 +01:00
parent 13022f1b8c
commit 8fa09cb871
4 changed files with 178 additions and 56 deletions

View File

@ -44,6 +44,19 @@ std::ostream& operator<<(std::ostream& os, const NestedStruct& v) {
return os << "n:a=" << v.a << ";b=" << v.b;
}
struct PartialStruct {
bool x;
uint32_t y;
float z;
long dummy2;
};
struct PartialNestedStruct {
long dummy1;
PartialStruct a;
long dummy2;
};
struct UnboundStruct { };
template <typename T>
@ -54,7 +67,7 @@ py::array mkarray_via_buffer(size_t n) {
}
template <typename S>
py::array_t<S> create_recarray(size_t n) {
py::array_t<S, 0> create_recarray(size_t n) {
auto arr = mkarray_via_buffer<S>(n);
auto ptr = static_cast<S*>(arr.request().ptr);
for (size_t i = 0; i < n; i++) {
@ -67,7 +80,7 @@ std::string get_format_unbound() {
return py::format_descriptor<UnboundStruct>::format();
}
py::array_t<NestedStruct> create_nested(size_t n) {
py::array_t<NestedStruct, 0> create_nested(size_t n) {
auto arr = mkarray_via_buffer<NestedStruct>(n);
auto ptr = static_cast<NestedStruct*>(arr.request().ptr);
for (size_t i = 0; i < n; i++) {
@ -77,8 +90,17 @@ py::array_t<NestedStruct> create_nested(size_t n) {
return arr;
}
py::array_t<PartialNestedStruct, 0> create_partial_nested(size_t n) {
auto arr = mkarray_via_buffer<PartialNestedStruct>(n);
auto ptr = static_cast<PartialNestedStruct*>(arr.request().ptr);
for (size_t i = 0; i < n; i++) {
ptr[i].a.x = i % 2; ptr[i].a.y = (uint32_t) i; ptr[i].a.z = (float) i * 1.5f;
}
return arr;
}
template <typename S>
void print_recarray(py::array_t<S> arr) {
void print_recarray(py::array_t<S, 0> arr) {
auto buf = arr.request();
auto ptr = static_cast<S*>(buf.ptr);
for (size_t i = 0; i < buf.size; i++)
@ -89,6 +111,8 @@ void print_format_descriptors() {
std::cout << py::format_descriptor<SimpleStruct>::format() << std::endl;
std::cout << py::format_descriptor<PackedStruct>::format() << std::endl;
std::cout << py::format_descriptor<NestedStruct>::format() << std::endl;
std::cout << py::format_descriptor<PartialStruct>::format() << std::endl;
std::cout << py::format_descriptor<PartialNestedStruct>::format() << std::endl;
}
void print_dtypes() {
@ -98,16 +122,22 @@ void print_dtypes() {
std::cout << to_str(py::dtype_of<SimpleStruct>()) << std::endl;
std::cout << to_str(py::dtype_of<PackedStruct>()) << std::endl;
std::cout << to_str(py::dtype_of<NestedStruct>()) << std::endl;
std::cout << to_str(py::dtype_of<PartialStruct>()) << std::endl;
std::cout << to_str(py::dtype_of<PartialNestedStruct>()) << std::endl;
}
void init_ex20(py::module &m) {
PYBIND11_NUMPY_DTYPE(SimpleStruct, x, y, z);
PYBIND11_NUMPY_DTYPE(PackedStruct, x, y, z);
PYBIND11_NUMPY_DTYPE(NestedStruct, a, b);
PYBIND11_NUMPY_DTYPE(PartialStruct, x, y, z);
PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a);
m.def("create_rec_simple", &create_recarray<SimpleStruct>);
m.def("create_rec_packed", &create_recarray<PackedStruct>);
m.def("create_rec_nested", &create_nested);
m.def("create_rec_partial", &create_recarray<PartialStruct>);
m.def("create_rec_partial_nested", &create_partial_nested);
m.def("print_format_descriptors", &print_format_descriptors);
m.def("print_rec_simple", &print_recarray<SimpleStruct>);
m.def("print_rec_packed", &print_recarray<PackedStruct>);

View File

@ -5,7 +5,8 @@ import unittest
import numpy as np
from example import (
create_rec_simple, create_rec_packed, create_rec_nested, print_format_descriptors,
print_rec_simple, print_rec_packed, print_rec_nested, print_dtypes, get_format_unbound
print_rec_simple, print_rec_packed, print_rec_nested, print_dtypes, get_format_unbound,
create_rec_partial, create_rec_partial_nested
)
@ -23,6 +24,8 @@ simple_dtype = np.dtype({'names': ['x', 'y', 'z'],
'offsets': [0, 4, 8]})
packed_dtype = np.dtype([('x', '?'), ('y', 'u4'), ('z', 'f4')])
elements = [(False, 0, 0.0), (True, 1, 1.5), (False, 2, 3.0)]
for func, dtype in [(create_rec_simple, simple_dtype), (create_rec_packed, packed_dtype)]:
arr = func(0)
assert arr.dtype == dtype
@ -31,14 +34,30 @@ for func, dtype in [(create_rec_simple, simple_dtype), (create_rec_packed, packe
arr = func(3)
assert arr.dtype == dtype
check_eq(arr, [(False, 0, 0.0), (True, 1, 1.5), (False, 2, 3.0)], simple_dtype)
check_eq(arr, [(False, 0, 0.0), (True, 1, 1.5), (False, 2, 3.0)], packed_dtype)
check_eq(arr, elements, simple_dtype)
check_eq(arr, elements, packed_dtype)
if dtype == simple_dtype:
print_rec_simple(arr)
else:
print_rec_packed(arr)
arr = create_rec_partial(3)
print(arr.dtype)
partial_dtype = arr.dtype
assert '' not in arr.dtype.fields
assert partial_dtype.itemsize > simple_dtype.itemsize
check_eq(arr, elements, simple_dtype)
check_eq(arr, elements, packed_dtype)
arr = create_rec_partial_nested(3)
print(arr.dtype)
assert '' not in arr.dtype.fields
assert '' not in arr.dtype.fields['a'][0].fields
assert arr.dtype.itemsize > partial_dtype.itemsize
np.testing.assert_equal(arr['a'], create_rec_partial(3))
nested_dtype = np.dtype([('a', simple_dtype), ('b', packed_dtype)])
arr = create_rec_nested(0)

View File

@ -1,15 +1,21 @@
T{?:x:xxxI:y:f:z:}
T{?:x:=I:y:f:z:}
T{T{?:x:xxxI:y:f:z:}:a:T{?:x:=I:y:f:z:}:b:}
T{=?:x:3x=I:y:=f:z:}
T{=?:x:=I:y:=f:z:}
T{=T{=?:x:3x=I:y:=f:z:}:a:=T{=?:x:=I:y:=f:z:}:b:}
T{=?:x:3x=I:y:=f:z:12x}
T{8x=T{=?:x:3x=I:y:=f:z:12x}:a:8x}
{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}
[('x', '?'), ('y', '<u4'), ('z', '<f4')]
[('a', {'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}), ('b', [('x', '?'), ('y', '<u4'), ('z', '<f4')])]
{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}
{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}
s:0,0,0
s:1,1,1.5
s:0,2,3
p:0,0,0
p:1,1,1.5
p:0,2,3
{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}
{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}
n:a=s:0,0,0;b=p:1,1,1.5
n:a=s:1,1,1.5;b=p:0,2,3
n:a=s:0,2,3;b=p:1,3,4.5

View File

@ -15,6 +15,7 @@
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <initializer_list>
#if defined(_MSC_VER)
@ -26,6 +27,8 @@ NAMESPACE_BEGIN(pybind11)
namespace detail {
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
object fix_dtype(object);
template <typename T>
struct is_pod_struct {
enum { value = std::is_pod<T>::value && // offsetof only works correctly for POD types
@ -47,7 +50,9 @@ public:
API_PyArray_FromAny = 69,
API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94,
API_PyArray_DescrNewFromType = 9,
API_PyArray_DescrConverter = 174,
API_PyArray_EquivTypes = 182,
API_PyArray_GetArrayParamsFromObject = 278,
NPY_C_CONTIGUOUS_ = 0x0001,
@ -61,7 +66,9 @@ public:
NPY_LONG_, NPY_ULONG_,
NPY_LONGLONG_, NPY_ULONGLONG_,
NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
NPY_OBJECT_ = 17,
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
};
static API lookup() {
@ -79,7 +86,9 @@ public:
DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_NewCopy);
DECL_NPY_API(PyArray_NewFromDescr);
DECL_NPY_API(PyArray_DescrNewFromType);
DECL_NPY_API(PyArray_DescrConverter);
DECL_NPY_API(PyArray_EquivTypes);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
#undef DECL_NPY_API
return api;
@ -91,10 +100,12 @@ public:
PyObject *(*PyArray_NewFromDescr_)
(PyTypeObject *, PyObject *, int, Py_intptr_t *,
Py_intptr_t *, void *, int, PyObject *);
PyObject *(*PyArray_DescrNewFromType_)(int);
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
PyTypeObject *PyArray_Type_;
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
Py_ssize_t *, PyObject **, PyObject *);
};
@ -113,52 +124,83 @@ public:
Py_intptr_t shape = (Py_intptr_t) size;
object tmp = object(api.PyArray_NewFromDescr_(
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false);
if (ptr && tmp)
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
if (!tmp)
pybind11_fail("NumPy: unable to create array!");
if (ptr)
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
m_ptr = tmp.release().ptr();
}
array(const buffer_info &info) {
PyObject *arr = nullptr, *descr = nullptr;
int ndim = 0;
Py_ssize_t dims[32];
API& api = lookup_api();
auto& api = lookup_api();
// Allocate non-zeroed memory if it hasn't been provided by the caller.
// Normally, we could leave this null for NumPy to allocate memory for us, but
// since we need a memoryview, the data pointer has to be non-null. NumPy uses
// malloc if NPY_NEEDS_INIT is not set (in which case it uses calloc); however,
// we don't have a desriptor yet (only a buffer format string), so we can't
// access the flags. As long as we're not dealing with object dtypes/fields
// though, the memory doesn't have to be zeroed so we use malloc.
auto buf_info = info;
if (!buf_info.ptr)
// always allocate at least 1 element, same way as NumPy does it
buf_info.ptr = std::malloc(std::max(info.size, (size_t) 1) * info.itemsize);
if (!buf_info.ptr)
pybind11_fail("NumPy: failed to allocate memory for buffer");
// _dtype_from_pep3118 returns dtypes with padding fields in, however the array
// constructor seems to then consume them, so we don't need to strip them ourselves
auto numpy_internal = module::import("numpy.core._internal");
auto dtype_from_fmt = (object) numpy_internal.attr("_dtype_from_pep3118");
auto dtype = dtype_from_fmt(pybind11::str(info.format));
auto dtype2 = strip_padding_fields(dtype);
// PyArray_GetArrayParamsFromObject seems to be the only low-level API function
// that will accept arbitrary buffers (including structured types)
auto view = memoryview(buf_info);
auto res = api.PyArray_GetArrayParamsFromObject_(view.ptr(), nullptr, 1, &descr,
&ndim, dims, &arr, nullptr);
if (res < 0 || !arr || descr)
// We expect arr to have a pointer to a newly created array, in which case all
// other parameters like descr would be set to null, according to the API.
pybind11_fail("NumPy: unable to convert buffer to an array");
m_ptr = arr;
object tmp(api.PyArray_NewFromDescr_(
api.PyArray_Type_, dtype2.release().ptr(), (int) info.ndim, (Py_intptr_t *) &info.shape[0],
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
if (!tmp)
pybind11_fail("NumPy: unable to create array!");
if (info.ptr)
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
m_ptr = tmp.release().ptr();
auto d = (object) this->attr("dtype");
}
protected:
// protected:
static API &lookup_api() {
static API api = API::lookup();
return api;
}
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
static object strip_padding_fields(object dtype) {
// Recursively strip all void fields with empty names that are generated for
// padding fields (as of NumPy v1.11).
auto fields = dtype.attr("fields").cast<object>();
if (fields.ptr() == Py_None)
return dtype;
struct field_descr { pybind11::str name; object format; int_ offset; };
std::vector<field_descr> field_descriptors;
auto items = fields.attr("items").cast<object>();
for (auto field : items()) {
auto spec = object(field, true).cast<tuple>();
auto name = spec[0].cast<pybind11::str>();
auto format = spec[1].cast<tuple>()[0].cast<object>();
auto offset = spec[1].cast<tuple>()[1].cast<int_>();
if (!len(name) && (std::string) dtype.attr("kind").cast<pybind11::str>() == "V")
continue;
field_descriptors.push_back({name, strip_padding_fields(format), offset});
}
std::sort(field_descriptors.begin(), field_descriptors.end(),
[](const field_descr& a, const field_descr& b) {
return (int) a.offset < (int) b.offset;
});
list names, formats, offsets;
for (auto& descr : field_descriptors) {
names.append(descr.name);
formats.append(descr.format);
offsets.append(descr.offset);
}
auto args = dict();
args["names"] = names; args["formats"] = formats; args["offsets"] = offsets;
args["itemsize"] = dtype.attr("itemsize").cast<int_>();
PyObject *descr = nullptr;
if (!lookup_api().PyArray_DescrConverter_(args.release().ptr(), &descr) || !descr)
pybind11_fail("NumPy: failed to create structured dtype");
return object(descr, false);
}
};
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
@ -233,9 +275,12 @@ DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
struct field_descriptor {
const char *name;
size_t offset;
size_t size;
const char *format;
object descr;
};
template <typename T>
struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>::type> {
static PYBIND11_DESCR name() { return _("user-defined"); }
@ -253,7 +298,7 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
}
static void register_dtype(std::initializer_list<field_descriptor> fields) {
array::API& api = array::lookup_api();
auto& api = array::lookup_api();
auto args = dict();
list names { }, offsets { }, formats { };
for (auto field : fields) {
@ -263,26 +308,47 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
offsets.append(int_(field.offset));
formats.append(field.descr);
}
args["names"] = names;
args["offsets"] = offsets;
args["formats"] = formats;
args["names"] = names; args["offsets"] = offsets; args["formats"] = formats;
args["itemsize"] = int_(sizeof(T));
// This is essentially the same as calling np.dtype() constructor in Python and passing
// it a dict of the form {'names': ..., 'formats': ..., 'offsets': ...}.
if (!api.PyArray_DescrConverter_(args.release().ptr(), &dtype_()) || !dtype_())
pybind11_fail("NumPy: failed to create structured dtype");
// Let NumPy figure the buffer format string for us: memoryview(np.empty(0, dtype)).format
auto np = module::import("numpy");
auto empty = (object) np.attr("empty");
if (auto arr = (object) empty(int_(0), dtype())) {
if (auto view = PyMemoryView_FromObject(arr.ptr())) {
if (auto info = PyMemoryView_GET_BUFFER(view)) {
std::strncpy(format_(), info->format, 4096);
return;
}
}
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
// not encoded explicitly into the format string. This will supposedly
// get fixed in v1.12; for further details, see these:
// - https://github.com/numpy/numpy/issues/7797
// - https://github.com/numpy/numpy/pull/7798
// Because of this, we won't use numpy's logic to generate buffer format
// strings and will just do it ourselves.
std::vector<field_descriptor> ordered_fields(fields);
std::sort(ordered_fields.begin(), ordered_fields.end(),
[](const field_descriptor& a, const field_descriptor &b) {
return a.offset < b.offset;
});
size_t offset = 0;
std::ostringstream oss;
oss << "T{";
for (auto& field : ordered_fields) {
if (field.offset > offset)
oss << (field.offset - offset) << 'x';
// note that '=' is required to cover the case of unaligned fields
oss << '=' << field.format << ':' << field.name << ':';
offset = field.offset + field.size;
}
pybind11_fail("NumPy: failed to extract buffer format");
if (sizeof(T) > offset)
oss << (sizeof(T) - offset) << 'x';
oss << '}';
std::strncpy(format_(), oss.str().c_str(), 4096);
// Sanity check: verify that NumPy properly parses our buffer format string
auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1, { 0 }, { sizeof(T) }));
auto dtype = (object) arr.attr("dtype");
auto fixed_dtype = dtype;
// auto fixed_dtype = array::strip_padding_fields(object(dtype_(), true));
// if (!api.PyArray_EquivTypes_(dtype_(), fixed_dtype.ptr()))
// pybind11_fail("NumPy: invalid buffer descriptor!");
}
private:
@ -293,7 +359,8 @@ private:
// Extract name, offset and format descriptor for a struct field
#define PYBIND11_FIELD_DESCRIPTOR(Type, Field) \
::pybind11::detail::field_descriptor { \
#Field, offsetof(Type, Field), \
#Field, offsetof(Type, Field), sizeof(decltype(static_cast<Type*>(0)->Field)), \
::pybind11::format_descriptor<decltype(static_cast<Type*>(0)->Field)>::format(), \
::pybind11::detail::npy_format_descriptor<decltype(static_cast<Type*>(0)->Field)>::dtype() \
}