mirror of
https://github.com/pybind/pybind11.git
synced 2025-01-19 01:15:52 +00:00
Use correct itemsize when constructing a numpy dtype from a buffer_info
This commit is contained in:
parent
47681c183d
commit
5271576828
@ -228,7 +228,8 @@ public:
|
||||
|
||||
explicit dtype(const buffer_info &info) {
|
||||
dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format)));
|
||||
m_ptr = descr.strip_padding().release().ptr();
|
||||
// If info.itemsize == 0, use the value calculated from the format string
|
||||
m_ptr = descr.strip_padding(info.itemsize ? info.itemsize : descr.itemsize()).release().ptr();
|
||||
}
|
||||
|
||||
explicit dtype(const std::string &format) {
|
||||
@ -281,7 +282,7 @@ private:
|
||||
return reinterpret_borrow<object>(obj);
|
||||
}
|
||||
|
||||
dtype strip_padding() {
|
||||
dtype strip_padding(size_t itemsize) {
|
||||
// Recursively strip all void fields with empty names that are generated for
|
||||
// padding fields (as of NumPy v1.11).
|
||||
if (!has_fields())
|
||||
@ -297,7 +298,7 @@ private:
|
||||
auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
|
||||
if (!len(name) && format.kind() == 'V')
|
||||
continue;
|
||||
field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(), offset});
|
||||
field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset});
|
||||
}
|
||||
|
||||
std::sort(field_descriptors.begin(), field_descriptors.end(),
|
||||
@ -311,7 +312,7 @@ private:
|
||||
formats.append(descr.format);
|
||||
offsets.append(descr.offset);
|
||||
}
|
||||
return dtype(names, formats, offsets, itemsize());
|
||||
return dtype(names, formats, offsets, itemsize);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -282,11 +282,24 @@ py::list test_dtype_ctors() {
|
||||
dict["itemsize"] = py::int_(20);
|
||||
list.append(py::dtype::from_args(dict));
|
||||
list.append(py::dtype(names, formats, offsets, 20));
|
||||
list.append(py::dtype(py::buffer_info((void *) 0, 1, "I", 1)));
|
||||
list.append(py::dtype(py::buffer_info((void *) 0, 1, "T{i:a:f:b:}", 1)));
|
||||
list.append(py::dtype(py::buffer_info((void *) 0, sizeof(unsigned int), "I", 1)));
|
||||
list.append(py::dtype(py::buffer_info((void *) 0, 0, "T{i:a:f:b:}", 1)));
|
||||
return list;
|
||||
}
|
||||
|
||||
struct TrailingPaddingStruct {
|
||||
int32_t a;
|
||||
char b;
|
||||
};
|
||||
|
||||
py::dtype trailing_padding_dtype() {
|
||||
return py::dtype::of<TrailingPaddingStruct>();
|
||||
}
|
||||
|
||||
py::dtype buffer_to_dtype(py::buffer& buf) {
|
||||
return py::dtype(buf.request());
|
||||
}
|
||||
|
||||
py::list test_dtype_methods() {
|
||||
py::list list;
|
||||
auto dt1 = py::dtype::of<int32_t>();
|
||||
@ -314,6 +327,7 @@ test_initializer numpy_dtypes([](py::module &m) {
|
||||
PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a);
|
||||
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
|
||||
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
|
||||
PYBIND11_NUMPY_DTYPE(TrailingPaddingStruct, a, b);
|
||||
|
||||
// ... or after
|
||||
py::class_<PackedStruct>(m, "PackedStruct");
|
||||
@ -338,6 +352,8 @@ test_initializer numpy_dtypes([](py::module &m) {
|
||||
m.def("test_array_ctors", &test_array_ctors);
|
||||
m.def("test_dtype_ctors", &test_dtype_ctors);
|
||||
m.def("test_dtype_methods", &test_dtype_methods);
|
||||
m.def("trailing_padding_dtype", &trailing_padding_dtype);
|
||||
m.def("buffer_to_dtype", &buffer_to_dtype);
|
||||
m.def("f_simple", [](SimpleStruct s) { return s.y * 10; });
|
||||
m.def("f_packed", [](PackedStruct s) { return s.y * 10; });
|
||||
m.def("f_nested", [](NestedStruct s) { return s.a.y * 10; });
|
||||
|
@ -42,7 +42,7 @@ def test_format_descriptors():
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_dtype(simple_dtype):
|
||||
from pybind11_tests import print_dtypes, test_dtype_ctors, test_dtype_methods
|
||||
from pybind11_tests import print_dtypes, test_dtype_ctors, test_dtype_methods, trailing_padding_dtype, buffer_to_dtype
|
||||
|
||||
assert print_dtypes() == [
|
||||
"{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}",
|
||||
@ -64,6 +64,8 @@ def test_dtype(simple_dtype):
|
||||
assert test_dtype_methods() == [np.dtype('int32'), simple_dtype, False, True,
|
||||
np.dtype('int32').itemsize, simple_dtype.itemsize]
|
||||
|
||||
assert trailing_padding_dtype() == buffer_to_dtype(np.zeros(1, trailing_padding_dtype()))
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_recarray(simple_dtype, packed_dtype):
|
||||
|
Loading…
Reference in New Issue
Block a user