Initial implementation of py::dtype

This commit is contained in:
Ivan Smirnov 2016-07-23 21:55:37 +01:00
parent 05cb58ade2
commit 01f7409550
2 changed files with 112 additions and 85 deletions

View File

@ -158,15 +158,12 @@ void print_format_descriptors() {
}
// Writes the string form of the NumPy dtype of each example struct to std::cout,
// one per line.
// NOTE(review): two spellings of the same six prints appear below -- a local
// `to_str` helper applied to py::dtype_of<T>(), and py::dtype::of<T>().str().
// This text comes from a rendered commit diff, so these are presumably the
// removed and added sides shown together; confirm which one is live.
void print_dtypes() {
// Helper: invokes obj.__str__() and converts the result to std::string.
auto to_str = [](py::object obj) {
return (std::string) (py::str) ((py::object) obj.attr("__str__"))();
};
// Variant going through the free function py::dtype_of<T>() and the helper above.
std::cout << to_str(py::dtype_of<SimpleStruct>()) << std::endl;
std::cout << to_str(py::dtype_of<PackedStruct>()) << std::endl;
std::cout << to_str(py::dtype_of<NestedStruct>()) << std::endl;
std::cout << to_str(py::dtype_of<PartialStruct>()) << std::endl;
std::cout << to_str(py::dtype_of<PartialNestedStruct>()) << std::endl;
std::cout << to_str(py::dtype_of<StringStruct>()) << std::endl;
// Variant going through the static member py::dtype::of<T>() and its str() method.
std::cout << (std::string) py::dtype::of<SimpleStruct>().str() << std::endl;
std::cout << (std::string) py::dtype::of<PackedStruct>().str() << std::endl;
std::cout << (std::string) py::dtype::of<NestedStruct>().str() << std::endl;
std::cout << (std::string) py::dtype::of<PartialStruct>().str() << std::endl;
std::cout << (std::string) py::dtype::of<PartialNestedStruct>().str() << std::endl;
std::cout << (std::string) py::dtype::of<StringStruct>().str() << std::endl;
}
void init_ex_numpy_dtypes(py::module &m) {

View File

@ -52,7 +52,12 @@ struct npy_api {
return api;
}
// Returns true when PyObject_TypeCheck reports that `obj` is an instance of the
// type stored in PyArray_Type_.
// NOTE(review): the one-line and three-line definitions below have the same
// body and differ only in layout; presumably they are the removed and added
// sides of the rendered diff -- confirm that only one is kept.
bool PyArray_Check_(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type_); }
bool PyArray_Check_(PyObject *obj) const {
return (bool) PyObject_TypeCheck(obj, PyArray_Type_);
}
// True when `obj` passes PyObject_TypeCheck against the type held in
// PyArrayDescr_Type_.
bool PyArrayDescr_Check_(PyObject *obj) const {
    const auto type_matches = PyObject_TypeCheck(obj, PyArrayDescr_Type_);
    return type_matches != 0;
}
PyObject *(*PyArray_DescrFromType_)(int);
PyObject *(*PyArray_NewFromDescr_)
@ -61,6 +66,7 @@ struct npy_api {
PyObject *(*PyArray_DescrNewFromType_)(int);
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
PyTypeObject *PyArray_Type_;
PyTypeObject *PyArrayDescr_Type_;
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
@ -69,6 +75,7 @@ struct npy_api {
private:
enum functions {
API_PyArray_Type = 2,
API_PyArrayDescr_Type = 3,
API_PyArray_DescrFromType = 45,
API_PyArray_FromAny = 69,
API_PyArray_NewCopy = 85,
@ -90,6 +97,7 @@ private:
npy_api api;
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
DECL_NPY_API(PyArray_Type);
DECL_NPY_API(PyArrayDescr_Type);
DECL_NPY_API(PyArray_DescrFromType);
DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_NewCopy);
@ -104,6 +112,86 @@ private:
};
}
class dtype : public object {
public:
PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
dtype(const buffer_info &info) {
dtype descr(_dtype_from_pep3118()(pybind11::str(info.format)));
m_ptr = descr.strip_padding().release().ptr();
}
dtype(std::string format) {
m_ptr = from_args(pybind11::str(format)).release().ptr();
}
static dtype from_args(object args) {
// This is essentially the same as calling np.dtype() constructor in Python
PyObject *ptr = nullptr;
if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr)
pybind11_fail("NumPy: failed to create structured dtype");
return object(ptr, false);
}
template <typename T> static dtype of() {
return detail::npy_format_descriptor<T>::dtype();
}
size_t itemsize() const {
return (size_t) attr("itemsize").cast<int_>();
}
bool has_fields() const {
return attr("fields").cast<object>().ptr() != Py_None;
}
std::string kind() const {
return (std::string) attr("kind").cast<pybind11::str>();
}
private:
static object& _dtype_from_pep3118() {
static object obj = module::import("numpy.core._internal").attr("_dtype_from_pep3118");
return obj;
}
dtype strip_padding() {
// Recursively strip all void fields with empty names that are generated for
// padding fields (as of NumPy v1.11).
auto fields = attr("fields").cast<object>();
if (fields.ptr() == Py_None)
return *this;
struct field_descr { pybind11::str name; object format; int_ offset; };
std::vector<field_descr> field_descriptors;
auto items = fields.attr("items").cast<object>();
for (auto field : items()) {
auto spec = object(field, true).cast<tuple>();
auto name = spec[0].cast<pybind11::str>();
auto format = spec[1].cast<tuple>()[0].cast<dtype>();
auto offset = spec[1].cast<tuple>()[1].cast<int_>();
if (!len(name) && format.kind() == "V")
continue;
field_descriptors.push_back({name, format.strip_padding(), offset});
}
std::sort(field_descriptors.begin(), field_descriptors.end(),
[](const field_descr& a, const field_descr& b) {
return (int) a.offset < (int) b.offset;
});
list names, formats, offsets;
for (auto& descr : field_descriptors) {
names.append(descr.name); formats.append(descr.format); offsets.append(descr.offset);
}
auto args = dict();
args["names"] = names; args["formats"] = formats; args["offsets"] = offsets;
args["itemsize"] = (int_) itemsize();
return dtype::from_args(args);
}
};
class array : public buffer {
public:
PYBIND11_OBJECT_DEFAULT(array, buffer, detail::npy_api::get().PyArray_Check_)
@ -116,7 +204,7 @@ public:
template <typename Type> array(size_t size, const Type *ptr) {
auto& api = detail::npy_api::get();
PyObject *descr = detail::npy_format_descriptor<Type>::dtype().release().ptr();
auto descr = pybind11::dtype::of<Type>().release().ptr();
Py_intptr_t shape = (Py_intptr_t) size;
object tmp = object(api.PyArray_NewFromDescr_(
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false);
@ -129,14 +217,9 @@ public:
array(const buffer_info &info) {
auto& api = detail::npy_api::get();
// _dtype_from_pep3118 returns dtypes with padding fields in, so we need to strip them
auto numpy_internal = module::import("numpy.core._internal");
auto dtype_from_fmt = (object) numpy_internal.attr("_dtype_from_pep3118");
auto dtype = strip_padding_fields(dtype_from_fmt(pybind11::str(info.format)));
auto descr = pybind11::dtype(info).release().ptr();
object tmp(api.PyArray_NewFromDescr_(
api.PyArray_Type_, dtype.release().ptr(), (int) info.ndim, (Py_intptr_t *) &info.shape[0],
api.PyArray_Type_, descr, (int) info.ndim, (Py_intptr_t *) &info.shape[0],
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
if (!tmp)
pybind11_fail("NumPy: unable to create array!");
@ -145,50 +228,12 @@ public:
m_ptr = tmp.release().ptr();
}
/// Returns the value of this array's Python-side `dtype` attribute as a
/// pybind11::dtype. Declared const because it only reads an attribute, which
/// also makes it callable on const arrays.
pybind11::dtype dtype() const {
    return attr("dtype").cast<pybind11::dtype>();
}
protected:
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
/// Returns a copy of `dtype` without padding fields.
///
/// Recursively strips all void fields with empty names that are generated for
/// padding fields (as of NumPy v1.11). A dtype whose `fields` attribute is None
/// is returned unchanged. Calls pybind11_fail() if the rebuilt descriptor
/// cannot be created.
static object strip_padding_fields(object dtype) {
    auto fields = dtype.attr("fields").cast<object>();
    if (fields.ptr() == Py_None)
        return dtype;

    struct field_descr { pybind11::str name; object format; int_ offset; };
    std::vector<field_descr> field_descriptors;

    // Each item is (name, spec); spec[0] is the field dtype and spec[1] its
    // byte offset.
    auto items = fields.attr("items").cast<object>();
    for (auto field : items()) {
        auto spec = object(field, true).cast<tuple>();
        auto name = spec[0].cast<pybind11::str>();
        auto format = spec[1].cast<tuple>()[0].cast<object>();
        auto offset = spec[1].cast<tuple>()[1].cast<int_>();
        // The kind must be read from the field's own dtype. Reading it from
        // the enclosing `dtype` made this test depend only on the name, so any
        // empty-named field was dropped regardless of its type.
        if (!len(name) && (std::string) format.attr("kind").cast<pybind11::str>() == "V")
            continue;
        field_descriptors.push_back({name, strip_padding_fields(format), offset});
    }

    // Order the surviving fields by offset before rebuilding the dtype.
    std::sort(field_descriptors.begin(), field_descriptors.end(),
              [](const field_descr& a, const field_descr& b) {
                  return (int) a.offset < (int) b.offset;
              });

    list names, formats, offsets;
    for (auto& descr : field_descriptors) {
        names.append(descr.name);
        formats.append(descr.format);
        offsets.append(descr.offset);
    }

    auto args = dict();
    args["names"] = names; args["formats"] = formats; args["offsets"] = offsets;
    args["itemsize"] = dtype.attr("itemsize").cast<int_>();

    PyObject *descr = nullptr;
    // PyArray_DescrConverter only borrows its input: pass args.ptr() and let
    // `args` release its reference on scope exit instead of leaking it.
    if (!detail::npy_api::get().PyArray_DescrConverter_(args.ptr(), &descr) || !descr)
        pybind11_fail("NumPy: failed to create structured dtype");
    // `descr` is a new reference; adopt it without an extra incref.
    return object(descr, false);
}
};
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
@ -201,8 +246,7 @@ public:
if (ptr == nullptr)
return nullptr;
auto& api = detail::npy_api::get();
PyObject *descr = detail::npy_format_descriptor<T>::dtype().release().ptr();
PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0,
PyObject *result = api.PyArray_FromAny_(ptr, pybind11::dtype::of<T>().release().ptr(), 0, 0,
detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
if (!result)
PyErr_Clear();
@ -223,11 +267,6 @@ template <size_t N> struct format_descriptor<std::array<char, N>> {
// Returns the text of the descriptor built by concatenating detail::_<N>() and
// detail::_("s") (presumably the decimal value of N followed by 's', e.g. "3s";
// confirm against detail::_).
// NOTE(review): `s` is a local declared via PYBIND11_DESCR; whether the pointer
// returned by s.text() outlives this call depends on how that macro expands --
// verify.
static const char *format() { PYBIND11_DESCR s = detail::_<N>() + detail::_("s"); return s.text(); }
};
/// Free-function shortcut: yields whatever
/// detail::npy_format_descriptor<T>::dtype() returns, as a generic object.
template <typename T>
object dtype_of() {
    using descriptor = detail::npy_format_descriptor<T>;
    return descriptor::dtype();
}
NAMESPACE_BEGIN(detail)
// Compile-time predicate: derives from std::true_type exactly when the argument
// is a std::array<Element, Extent>, and from std::false_type for every other type.
template <typename Candidate>
struct is_std_array : std::false_type {};

template <typename Element, size_t Extent>
struct is_std_array<std::array<Element, Extent>> : std::true_type {};
@ -252,7 +291,7 @@ private:
npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_ };
public:
enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)] };
static object dtype() {
static pybind11::dtype dtype() {
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
return object(ptr, true);
pybind11_fail("Unsupported buffer format!");
@ -267,7 +306,7 @@ template <typename T> constexpr const int npy_format_descriptor<
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
enum { value = npy_api::NumPyName }; \
static object dtype() { \
static pybind11::dtype dtype() { \
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) \
return object(ptr, true); \
pybind11_fail("Unsupported buffer format!"); \
@ -282,14 +321,9 @@ DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
// Member definitions shared by the fixed-length character specializations of
// npy_format_descriptor (N is the template parameter of the enclosing struct):
//   name()   - descriptor "S" followed by N
//   dtype()  - dtype created from that same "S" + N text
//   format() - text of N followed by "s"
// NOTE(review): the body below contains two `dtype()` openers -- one returning
// `object` through PyArray_DescrConverter_, one returning pybind11::dtype through
// dtype::from_args. This text comes from a rendered commit diff, so these are
// presumably the removed and added sides shown together; confirm that only one
// survives before compiling.
#define DECL_CHAR_FMT \
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
static object dtype() { \
auto& api = npy_api::get(); \
PyObject *descr = nullptr; \
static pybind11::dtype dtype() { \
PYBIND11_DESCR fmt = _("S") + _<N>(); \
pybind11::str py_fmt(fmt.text()); \
if (!api.PyArray_DescrConverter_(py_fmt.release().ptr(), &descr) || !descr) \
pybind11_fail("NumPy: failed to create string dtype"); \
return object(descr, false); \
return pybind11::dtype::from_args(pybind11::str(fmt.text())); \
} \
static const char *format() { PYBIND11_DESCR s = _<N>() + _("s"); return s.text(); }
// Specialization for char[N]; its whole body comes from DECL_CHAR_FMT.
template <size_t N> struct npy_format_descriptor<char[N]> { DECL_CHAR_FMT };
@ -301,14 +335,14 @@ struct field_descriptor {
size_t offset;
size_t size;
const char *format;
object descr;
dtype descr;
};
template <typename T>
struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>::type> {
static PYBIND11_DESCR name() { return _("struct"); }
static object dtype() {
static pybind11::dtype dtype() {
if (!dtype_())
pybind11_fail("NumPy: unsupported buffer format!");
return object(dtype_(), true);
@ -321,7 +355,6 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
}
static void register_dtype(std::initializer_list<field_descriptor> fields) {
auto& api = npy_api::get();
auto args = dict();
list names { }, offsets { }, formats { };
for (auto field : fields) {
@ -333,10 +366,7 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
}
args["names"] = names; args["offsets"] = offsets; args["formats"] = formats;
args["itemsize"] = int_(sizeof(T));
// This is essentially the same as calling np.dtype() constructor in Python and passing
// it a dict of the form {'names': ..., 'formats': ..., 'offsets': ...}.
if (!api.PyArray_DescrConverter_(args.release().ptr(), &dtype_()) || !dtype_())
pybind11_fail("NumPy: failed to create structured dtype");
dtype_() = pybind11::dtype::from_args(args).release().ptr();
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
// not encoded explicitly into the format string. This will supposedly
@ -366,9 +396,9 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
format_() = oss.str();
// Sanity check: verify that NumPy properly parses our buffer format string
auto& api = npy_api::get();
auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1, { 0 }, { sizeof(T) }));
auto fixed_dtype = array::strip_padding_fields(object(dtype_(), true));
if (!api.PyArray_EquivTypes_(dtype_(), fixed_dtype.ptr()))
if (!api.PyArray_EquivTypes_(dtype_(), arr.dtype().ptr()))
pybind11_fail("NumPy: invalid buffer descriptor!");
}