Use correct itemsize when constructing a numpy dtype from a buffer_info

This commit is contained in:
Patrick Stewart 2016-11-22 14:56:52 +00:00 committed by Wenzel Jakob
parent 47681c183d
commit 5271576828
3 changed files with 26 additions and 7 deletions

View File

@ -228,7 +228,8 @@ public:
explicit dtype(const buffer_info &info) { explicit dtype(const buffer_info &info) {
dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format))); dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format)));
m_ptr = descr.strip_padding().release().ptr(); // If info.itemsize == 0, use the value calculated from the format string
m_ptr = descr.strip_padding(info.itemsize ? info.itemsize : descr.itemsize()).release().ptr();
} }
explicit dtype(const std::string &format) { explicit dtype(const std::string &format) {
@ -281,7 +282,7 @@ private:
return reinterpret_borrow<object>(obj); return reinterpret_borrow<object>(obj);
} }
dtype strip_padding() { dtype strip_padding(size_t itemsize) {
// Recursively strip all void fields with empty names that are generated for // Recursively strip all void fields with empty names that are generated for
// padding fields (as of NumPy v1.11). // padding fields (as of NumPy v1.11).
if (!has_fields()) if (!has_fields())
@ -297,7 +298,7 @@ private:
auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>(); auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
if (!len(name) && format.kind() == 'V') if (!len(name) && format.kind() == 'V')
continue; continue;
field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(), offset}); field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset});
} }
std::sort(field_descriptors.begin(), field_descriptors.end(), std::sort(field_descriptors.begin(), field_descriptors.end(),
@ -311,7 +312,7 @@ private:
formats.append(descr.format); formats.append(descr.format);
offsets.append(descr.offset); offsets.append(descr.offset);
} }
return dtype(names, formats, offsets, itemsize()); return dtype(names, formats, offsets, itemsize);
} }
}; };

View File

@ -282,11 +282,24 @@ py::list test_dtype_ctors() {
dict["itemsize"] = py::int_(20); dict["itemsize"] = py::int_(20);
list.append(py::dtype::from_args(dict)); list.append(py::dtype::from_args(dict));
list.append(py::dtype(names, formats, offsets, 20)); list.append(py::dtype(names, formats, offsets, 20));
list.append(py::dtype(py::buffer_info((void *) 0, 1, "I", 1))); list.append(py::dtype(py::buffer_info((void *) 0, sizeof(unsigned int), "I", 1)));
list.append(py::dtype(py::buffer_info((void *) 0, 1, "T{i:a:f:b:}", 1))); list.append(py::dtype(py::buffer_info((void *) 0, 0, "T{i:a:f:b:}", 1)));
return list; return list;
} }
struct TrailingPaddingStruct {
int32_t a;
char b;
};
py::dtype trailing_padding_dtype() {
return py::dtype::of<TrailingPaddingStruct>();
}
py::dtype buffer_to_dtype(py::buffer& buf) {
return py::dtype(buf.request());
}
py::list test_dtype_methods() { py::list test_dtype_methods() {
py::list list; py::list list;
auto dt1 = py::dtype::of<int32_t>(); auto dt1 = py::dtype::of<int32_t>();
@ -314,6 +327,7 @@ test_initializer numpy_dtypes([](py::module &m) {
PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a); PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a);
PYBIND11_NUMPY_DTYPE(StringStruct, a, b); PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2); PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
PYBIND11_NUMPY_DTYPE(TrailingPaddingStruct, a, b);
// ... or after // ... or after
py::class_<PackedStruct>(m, "PackedStruct"); py::class_<PackedStruct>(m, "PackedStruct");
@ -338,6 +352,8 @@ test_initializer numpy_dtypes([](py::module &m) {
m.def("test_array_ctors", &test_array_ctors); m.def("test_array_ctors", &test_array_ctors);
m.def("test_dtype_ctors", &test_dtype_ctors); m.def("test_dtype_ctors", &test_dtype_ctors);
m.def("test_dtype_methods", &test_dtype_methods); m.def("test_dtype_methods", &test_dtype_methods);
m.def("trailing_padding_dtype", &trailing_padding_dtype);
m.def("buffer_to_dtype", &buffer_to_dtype);
m.def("f_simple", [](SimpleStruct s) { return s.y * 10; }); m.def("f_simple", [](SimpleStruct s) { return s.y * 10; });
m.def("f_packed", [](PackedStruct s) { return s.y * 10; }); m.def("f_packed", [](PackedStruct s) { return s.y * 10; });
m.def("f_nested", [](NestedStruct s) { return s.a.y * 10; }); m.def("f_nested", [](NestedStruct s) { return s.a.y * 10; });

View File

@ -42,7 +42,7 @@ def test_format_descriptors():
@pytest.requires_numpy @pytest.requires_numpy
def test_dtype(simple_dtype): def test_dtype(simple_dtype):
from pybind11_tests import print_dtypes, test_dtype_ctors, test_dtype_methods from pybind11_tests import print_dtypes, test_dtype_ctors, test_dtype_methods, trailing_padding_dtype, buffer_to_dtype
assert print_dtypes() == [ assert print_dtypes() == [
"{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}", "{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}",
@ -64,6 +64,8 @@ def test_dtype(simple_dtype):
assert test_dtype_methods() == [np.dtype('int32'), simple_dtype, False, True, assert test_dtype_methods() == [np.dtype('int32'), simple_dtype, False, True,
np.dtype('int32').itemsize, simple_dtype.itemsize] np.dtype('int32').itemsize, simple_dtype.itemsize]
assert trailing_padding_dtype() == buffer_to_dtype(np.zeros(1, trailing_padding_dtype()))
@pytest.requires_numpy @pytest.requires_numpy
def test_recarray(simple_dtype, packed_dtype): def test_recarray(simple_dtype, packed_dtype):