mirror of
https://github.com/pybind/pybind11.git
synced 2025-02-21 07:59:17 +00:00
Initial implementation of py::dtype
This commit is contained in:
parent
05cb58ade2
commit
01f7409550
@ -158,15 +158,12 @@ void print_format_descriptors() {
|
||||
}
|
||||
|
||||
void print_dtypes() {
|
||||
auto to_str = [](py::object obj) {
|
||||
return (std::string) (py::str) ((py::object) obj.attr("__str__"))();
|
||||
};
|
||||
std::cout << to_str(py::dtype_of<SimpleStruct>()) << std::endl;
|
||||
std::cout << to_str(py::dtype_of<PackedStruct>()) << std::endl;
|
||||
std::cout << to_str(py::dtype_of<NestedStruct>()) << std::endl;
|
||||
std::cout << to_str(py::dtype_of<PartialStruct>()) << std::endl;
|
||||
std::cout << to_str(py::dtype_of<PartialNestedStruct>()) << std::endl;
|
||||
std::cout << to_str(py::dtype_of<StringStruct>()) << std::endl;
|
||||
std::cout << (std::string) py::dtype::of<SimpleStruct>().str() << std::endl;
|
||||
std::cout << (std::string) py::dtype::of<PackedStruct>().str() << std::endl;
|
||||
std::cout << (std::string) py::dtype::of<NestedStruct>().str() << std::endl;
|
||||
std::cout << (std::string) py::dtype::of<PartialStruct>().str() << std::endl;
|
||||
std::cout << (std::string) py::dtype::of<PartialNestedStruct>().str() << std::endl;
|
||||
std::cout << (std::string) py::dtype::of<StringStruct>().str() << std::endl;
|
||||
}
|
||||
|
||||
void init_ex_numpy_dtypes(py::module &m) {
|
||||
|
@ -52,7 +52,12 @@ struct npy_api {
|
||||
return api;
|
||||
}
|
||||
|
||||
bool PyArray_Check_(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type_); }
|
||||
bool PyArray_Check_(PyObject *obj) const {
|
||||
return (bool) PyObject_TypeCheck(obj, PyArray_Type_);
|
||||
}
|
||||
bool PyArrayDescr_Check_(PyObject *obj) const {
|
||||
return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
|
||||
}
|
||||
|
||||
PyObject *(*PyArray_DescrFromType_)(int);
|
||||
PyObject *(*PyArray_NewFromDescr_)
|
||||
@ -61,6 +66,7 @@ struct npy_api {
|
||||
PyObject *(*PyArray_DescrNewFromType_)(int);
|
||||
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
|
||||
PyTypeObject *PyArray_Type_;
|
||||
PyTypeObject *PyArrayDescr_Type_;
|
||||
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
|
||||
int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
|
||||
bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
|
||||
@ -69,6 +75,7 @@ struct npy_api {
|
||||
private:
|
||||
enum functions {
|
||||
API_PyArray_Type = 2,
|
||||
API_PyArrayDescr_Type = 3,
|
||||
API_PyArray_DescrFromType = 45,
|
||||
API_PyArray_FromAny = 69,
|
||||
API_PyArray_NewCopy = 85,
|
||||
@ -90,6 +97,7 @@ private:
|
||||
npy_api api;
|
||||
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
|
||||
DECL_NPY_API(PyArray_Type);
|
||||
DECL_NPY_API(PyArrayDescr_Type);
|
||||
DECL_NPY_API(PyArray_DescrFromType);
|
||||
DECL_NPY_API(PyArray_FromAny);
|
||||
DECL_NPY_API(PyArray_NewCopy);
|
||||
@ -104,6 +112,86 @@ private:
|
||||
};
|
||||
}
|
||||
|
||||
class dtype : public object {
|
||||
public:
|
||||
PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
|
||||
|
||||
dtype(const buffer_info &info) {
|
||||
dtype descr(_dtype_from_pep3118()(pybind11::str(info.format)));
|
||||
m_ptr = descr.strip_padding().release().ptr();
|
||||
}
|
||||
|
||||
dtype(std::string format) {
|
||||
m_ptr = from_args(pybind11::str(format)).release().ptr();
|
||||
}
|
||||
|
||||
static dtype from_args(object args) {
|
||||
// This is essentially the same as calling np.dtype() constructor in Python
|
||||
PyObject *ptr = nullptr;
|
||||
if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr)
|
||||
pybind11_fail("NumPy: failed to create structured dtype");
|
||||
return object(ptr, false);
|
||||
}
|
||||
|
||||
template <typename T> static dtype of() {
|
||||
return detail::npy_format_descriptor<T>::dtype();
|
||||
}
|
||||
|
||||
size_t itemsize() const {
|
||||
return (size_t) attr("itemsize").cast<int_>();
|
||||
}
|
||||
|
||||
bool has_fields() const {
|
||||
return attr("fields").cast<object>().ptr() != Py_None;
|
||||
}
|
||||
|
||||
std::string kind() const {
|
||||
return (std::string) attr("kind").cast<pybind11::str>();
|
||||
}
|
||||
|
||||
private:
|
||||
static object& _dtype_from_pep3118() {
|
||||
static object obj = module::import("numpy.core._internal").attr("_dtype_from_pep3118");
|
||||
return obj;
|
||||
}
|
||||
|
||||
dtype strip_padding() {
|
||||
// Recursively strip all void fields with empty names that are generated for
|
||||
// padding fields (as of NumPy v1.11).
|
||||
auto fields = attr("fields").cast<object>();
|
||||
if (fields.ptr() == Py_None)
|
||||
return *this;
|
||||
|
||||
struct field_descr { pybind11::str name; object format; int_ offset; };
|
||||
std::vector<field_descr> field_descriptors;
|
||||
|
||||
auto items = fields.attr("items").cast<object>();
|
||||
for (auto field : items()) {
|
||||
auto spec = object(field, true).cast<tuple>();
|
||||
auto name = spec[0].cast<pybind11::str>();
|
||||
auto format = spec[1].cast<tuple>()[0].cast<dtype>();
|
||||
auto offset = spec[1].cast<tuple>()[1].cast<int_>();
|
||||
if (!len(name) && format.kind() == "V")
|
||||
continue;
|
||||
field_descriptors.push_back({name, format.strip_padding(), offset});
|
||||
}
|
||||
|
||||
std::sort(field_descriptors.begin(), field_descriptors.end(),
|
||||
[](const field_descr& a, const field_descr& b) {
|
||||
return (int) a.offset < (int) b.offset;
|
||||
});
|
||||
|
||||
list names, formats, offsets;
|
||||
for (auto& descr : field_descriptors) {
|
||||
names.append(descr.name); formats.append(descr.format); offsets.append(descr.offset);
|
||||
}
|
||||
auto args = dict();
|
||||
args["names"] = names; args["formats"] = formats; args["offsets"] = offsets;
|
||||
args["itemsize"] = (int_) itemsize();
|
||||
return dtype::from_args(args);
|
||||
}
|
||||
};
|
||||
|
||||
class array : public buffer {
|
||||
public:
|
||||
PYBIND11_OBJECT_DEFAULT(array, buffer, detail::npy_api::get().PyArray_Check_)
|
||||
@ -116,7 +204,7 @@ public:
|
||||
|
||||
template <typename Type> array(size_t size, const Type *ptr) {
|
||||
auto& api = detail::npy_api::get();
|
||||
PyObject *descr = detail::npy_format_descriptor<Type>::dtype().release().ptr();
|
||||
auto descr = pybind11::dtype::of<Type>().release().ptr();
|
||||
Py_intptr_t shape = (Py_intptr_t) size;
|
||||
object tmp = object(api.PyArray_NewFromDescr_(
|
||||
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false);
|
||||
@ -129,14 +217,9 @@ public:
|
||||
|
||||
array(const buffer_info &info) {
|
||||
auto& api = detail::npy_api::get();
|
||||
|
||||
// _dtype_from_pep3118 returns dtypes with padding fields in, so we need to strip them
|
||||
auto numpy_internal = module::import("numpy.core._internal");
|
||||
auto dtype_from_fmt = (object) numpy_internal.attr("_dtype_from_pep3118");
|
||||
auto dtype = strip_padding_fields(dtype_from_fmt(pybind11::str(info.format)));
|
||||
|
||||
auto descr = pybind11::dtype(info).release().ptr();
|
||||
object tmp(api.PyArray_NewFromDescr_(
|
||||
api.PyArray_Type_, dtype.release().ptr(), (int) info.ndim, (Py_intptr_t *) &info.shape[0],
|
||||
api.PyArray_Type_, descr, (int) info.ndim, (Py_intptr_t *) &info.shape[0],
|
||||
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
|
||||
if (!tmp)
|
||||
pybind11_fail("NumPy: unable to create array!");
|
||||
@ -145,50 +228,12 @@ public:
|
||||
m_ptr = tmp.release().ptr();
|
||||
}
|
||||
|
||||
pybind11::dtype dtype() {
|
||||
return attr("dtype").cast<pybind11::dtype>();
|
||||
}
|
||||
|
||||
protected:
|
||||
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
|
||||
|
||||
static object strip_padding_fields(object dtype) {
|
||||
// Recursively strip all void fields with empty names that are generated for
|
||||
// padding fields (as of NumPy v1.11).
|
||||
auto fields = dtype.attr("fields").cast<object>();
|
||||
if (fields.ptr() == Py_None)
|
||||
return dtype;
|
||||
|
||||
struct field_descr { pybind11::str name; object format; int_ offset; };
|
||||
std::vector<field_descr> field_descriptors;
|
||||
|
||||
auto items = fields.attr("items").cast<object>();
|
||||
for (auto field : items()) {
|
||||
auto spec = object(field, true).cast<tuple>();
|
||||
auto name = spec[0].cast<pybind11::str>();
|
||||
auto format = spec[1].cast<tuple>()[0].cast<object>();
|
||||
auto offset = spec[1].cast<tuple>()[1].cast<int_>();
|
||||
if (!len(name) && (std::string) dtype.attr("kind").cast<pybind11::str>() == "V")
|
||||
continue;
|
||||
field_descriptors.push_back({name, strip_padding_fields(format), offset});
|
||||
}
|
||||
|
||||
std::sort(field_descriptors.begin(), field_descriptors.end(),
|
||||
[](const field_descr& a, const field_descr& b) {
|
||||
return (int) a.offset < (int) b.offset;
|
||||
});
|
||||
|
||||
list names, formats, offsets;
|
||||
for (auto& descr : field_descriptors) {
|
||||
names.append(descr.name);
|
||||
formats.append(descr.format);
|
||||
offsets.append(descr.offset);
|
||||
}
|
||||
auto args = dict();
|
||||
args["names"] = names; args["formats"] = formats; args["offsets"] = offsets;
|
||||
args["itemsize"] = dtype.attr("itemsize").cast<int_>();
|
||||
|
||||
PyObject *descr = nullptr;
|
||||
if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &descr) || !descr)
|
||||
pybind11_fail("NumPy: failed to create structured dtype");
|
||||
return object(descr, false);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
|
||||
@ -201,8 +246,7 @@ public:
|
||||
if (ptr == nullptr)
|
||||
return nullptr;
|
||||
auto& api = detail::npy_api::get();
|
||||
PyObject *descr = detail::npy_format_descriptor<T>::dtype().release().ptr();
|
||||
PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0,
|
||||
PyObject *result = api.PyArray_FromAny_(ptr, pybind11::dtype::of<T>().release().ptr(), 0, 0,
|
||||
detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
|
||||
if (!result)
|
||||
PyErr_Clear();
|
||||
@ -223,11 +267,6 @@ template <size_t N> struct format_descriptor<std::array<char, N>> {
|
||||
static const char *format() { PYBIND11_DESCR s = detail::_<N>() + detail::_("s"); return s.text(); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
object dtype_of() {
|
||||
return detail::npy_format_descriptor<T>::dtype();
|
||||
}
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
template <typename T> struct is_std_array : std::false_type { };
|
||||
template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
|
||||
@ -252,7 +291,7 @@ private:
|
||||
npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_ };
|
||||
public:
|
||||
enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)] };
|
||||
static object dtype() {
|
||||
static pybind11::dtype dtype() {
|
||||
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
|
||||
return object(ptr, true);
|
||||
pybind11_fail("Unsupported buffer format!");
|
||||
@ -267,7 +306,7 @@ template <typename T> constexpr const int npy_format_descriptor<
|
||||
|
||||
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
|
||||
enum { value = npy_api::NumPyName }; \
|
||||
static object dtype() { \
|
||||
static pybind11::dtype dtype() { \
|
||||
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) \
|
||||
return object(ptr, true); \
|
||||
pybind11_fail("Unsupported buffer format!"); \
|
||||
@ -282,14 +321,9 @@ DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
|
||||
|
||||
#define DECL_CHAR_FMT \
|
||||
static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
|
||||
static object dtype() { \
|
||||
auto& api = npy_api::get(); \
|
||||
PyObject *descr = nullptr; \
|
||||
static pybind11::dtype dtype() { \
|
||||
PYBIND11_DESCR fmt = _("S") + _<N>(); \
|
||||
pybind11::str py_fmt(fmt.text()); \
|
||||
if (!api.PyArray_DescrConverter_(py_fmt.release().ptr(), &descr) || !descr) \
|
||||
pybind11_fail("NumPy: failed to create string dtype"); \
|
||||
return object(descr, false); \
|
||||
return pybind11::dtype::from_args(pybind11::str(fmt.text())); \
|
||||
} \
|
||||
static const char *format() { PYBIND11_DESCR s = _<N>() + _("s"); return s.text(); }
|
||||
template <size_t N> struct npy_format_descriptor<char[N]> { DECL_CHAR_FMT };
|
||||
@ -301,14 +335,14 @@ struct field_descriptor {
|
||||
size_t offset;
|
||||
size_t size;
|
||||
const char *format;
|
||||
object descr;
|
||||
dtype descr;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>::type> {
|
||||
static PYBIND11_DESCR name() { return _("struct"); }
|
||||
|
||||
static object dtype() {
|
||||
static pybind11::dtype dtype() {
|
||||
if (!dtype_())
|
||||
pybind11_fail("NumPy: unsupported buffer format!");
|
||||
return object(dtype_(), true);
|
||||
@ -321,7 +355,6 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
|
||||
}
|
||||
|
||||
static void register_dtype(std::initializer_list<field_descriptor> fields) {
|
||||
auto& api = npy_api::get();
|
||||
auto args = dict();
|
||||
list names { }, offsets { }, formats { };
|
||||
for (auto field : fields) {
|
||||
@ -333,10 +366,7 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
|
||||
}
|
||||
args["names"] = names; args["offsets"] = offsets; args["formats"] = formats;
|
||||
args["itemsize"] = int_(sizeof(T));
|
||||
// This is essentially the same as calling np.dtype() constructor in Python and passing
|
||||
// it a dict of the form {'names': ..., 'formats': ..., 'offsets': ...}.
|
||||
if (!api.PyArray_DescrConverter_(args.release().ptr(), &dtype_()) || !dtype_())
|
||||
pybind11_fail("NumPy: failed to create structured dtype");
|
||||
dtype_() = pybind11::dtype::from_args(args).release().ptr();
|
||||
|
||||
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
|
||||
// not encoded explicitly into the format string. This will supposedly
|
||||
@ -366,9 +396,9 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
|
||||
format_() = oss.str();
|
||||
|
||||
// Sanity check: verify that NumPy properly parses our buffer format string
|
||||
auto& api = npy_api::get();
|
||||
auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1, { 0 }, { sizeof(T) }));
|
||||
auto fixed_dtype = array::strip_padding_fields(object(dtype_(), true));
|
||||
if (!api.PyArray_EquivTypes_(dtype_(), fixed_dtype.ptr()))
|
||||
if (!api.PyArray_EquivTypes_(dtype_(), arr.dtype().ptr()))
|
||||
pybind11_fail("NumPy: invalid buffer descriptor!");
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user