diff --git a/example/example-numpy-dtypes.cpp b/example/example-numpy-dtypes.cpp index 2e25670f3..2c7cdc07c 100644 --- a/example/example-numpy-dtypes.cpp +++ b/example/example-numpy-dtypes.cpp @@ -158,15 +158,12 @@ void print_format_descriptors() { } void print_dtypes() { - auto to_str = [](py::object obj) { - return (std::string) (py::str) ((py::object) obj.attr("__str__"))(); - }; - std::cout << to_str(py::dtype_of()) << std::endl; - std::cout << to_str(py::dtype_of()) << std::endl; - std::cout << to_str(py::dtype_of()) << std::endl; - std::cout << to_str(py::dtype_of()) << std::endl; - std::cout << to_str(py::dtype_of()) << std::endl; - std::cout << to_str(py::dtype_of()) << std::endl; + std::cout << (std::string) py::dtype::of().str() << std::endl; + std::cout << (std::string) py::dtype::of().str() << std::endl; + std::cout << (std::string) py::dtype::of().str() << std::endl; + std::cout << (std::string) py::dtype::of().str() << std::endl; + std::cout << (std::string) py::dtype::of().str() << std::endl; + std::cout << (std::string) py::dtype::of().str() << std::endl; } void init_ex_numpy_dtypes(py::module &m) { diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 77f7e6f72..3b52fa39b 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -52,7 +52,12 @@ struct npy_api { return api; } - bool PyArray_Check_(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type_); } + bool PyArray_Check_(PyObject *obj) const { + return (bool) PyObject_TypeCheck(obj, PyArray_Type_); + } + bool PyArrayDescr_Check_(PyObject *obj) const { + return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_); + } PyObject *(*PyArray_DescrFromType_)(int); PyObject *(*PyArray_NewFromDescr_) @@ -61,6 +66,7 @@ struct npy_api { PyObject *(*PyArray_DescrNewFromType_)(int); PyObject *(*PyArray_NewCopy_)(PyObject *, int); PyTypeObject *PyArray_Type_; + PyTypeObject *PyArrayDescr_Type_; PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *); int (*PyArray_DescrConverter_) (PyObject *, PyObject **); bool (*PyArray_EquivTypes_) (PyObject *, PyObject *); @@ -69,6 +75,7 @@ struct npy_api { private: enum functions { API_PyArray_Type = 2, + API_PyArrayDescr_Type = 3, API_PyArray_DescrFromType = 45, API_PyArray_FromAny = 69, API_PyArray_NewCopy = 85, @@ -90,6 +97,7 @@ private: npy_api api; #define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func]; DECL_NPY_API(PyArray_Type); + DECL_NPY_API(PyArrayDescr_Type); DECL_NPY_API(PyArray_DescrFromType); DECL_NPY_API(PyArray_FromAny); DECL_NPY_API(PyArray_NewCopy); @@ -104,6 +112,86 @@ private: }; } +class dtype : public object { +public: + PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_); + + dtype(const buffer_info &info) { + dtype descr(_dtype_from_pep3118()(pybind11::str(info.format))); + m_ptr = descr.strip_padding().release().ptr(); + } + + dtype(std::string format) { + m_ptr = from_args(pybind11::str(format)).release().ptr(); + } + + static dtype from_args(object args) { + // This is essentially the same as calling np.dtype() constructor in Python + PyObject *ptr = nullptr; + if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr) + pybind11_fail("NumPy: failed to create structured dtype"); + return object(ptr, false); + } + + template static dtype of() { + return detail::npy_format_descriptor::dtype(); + } + + size_t itemsize() const { + return (size_t) attr("itemsize").cast(); + } + + bool has_fields() const { + return attr("fields").cast().ptr() != Py_None; + } + + std::string kind() const { + return (std::string) attr("kind").cast(); + } + +private: + static object& _dtype_from_pep3118() { + static object obj = module::import("numpy.core._internal").attr("_dtype_from_pep3118"); + return obj; + } + + dtype strip_padding() { + // Recursively strip all void fields with empty names that are generated for + // padding fields (as of NumPy v1.11). + auto fields = attr("fields").cast(); + if (fields.ptr() == Py_None) + return *this; + + struct field_descr { pybind11::str name; object format; int_ offset; }; + std::vector field_descriptors; + + auto items = fields.attr("items").cast(); + for (auto field : items()) { + auto spec = object(field, true).cast(); + auto name = spec[0].cast(); + auto format = spec[1].cast()[0].cast(); + auto offset = spec[1].cast()[1].cast(); + if (!len(name) && format.kind() == "V") + continue; + field_descriptors.push_back({name, format.strip_padding(), offset}); + } + + std::sort(field_descriptors.begin(), field_descriptors.end(), + [](const field_descr& a, const field_descr& b) { + return (int) a.offset < (int) b.offset; + }); + + list names, formats, offsets; + for (auto& descr : field_descriptors) { + names.append(descr.name); formats.append(descr.format); offsets.append(descr.offset); + } + auto args = dict(); + args["names"] = names; args["formats"] = formats; args["offsets"] = offsets; + args["itemsize"] = (int_) itemsize(); + return dtype::from_args(args); + } +}; + class array : public buffer { public: PYBIND11_OBJECT_DEFAULT(array, buffer, detail::npy_api::get().PyArray_Check_) @@ -116,7 +204,7 @@ public: template array(size_t size, const Type *ptr) { auto& api = detail::npy_api::get(); - PyObject *descr = detail::npy_format_descriptor::dtype().release().ptr(); + auto descr = pybind11::dtype::of().release().ptr(); Py_intptr_t shape = (Py_intptr_t) size; object tmp = object(api.PyArray_NewFromDescr_( api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false); @@ -129,14 +217,9 @@ public: array(const buffer_info &info) { auto& api = detail::npy_api::get(); - - // _dtype_from_pep3118 returns dtypes with padding fields in, so we need to strip them - auto numpy_internal = module::import("numpy.core._internal"); - auto dtype_from_fmt = (object) numpy_internal.attr("_dtype_from_pep3118"); - auto dtype = strip_padding_fields(dtype_from_fmt(pybind11::str(info.format))); - + auto descr = pybind11::dtype(info).release().ptr(); object tmp(api.PyArray_NewFromDescr_( - api.PyArray_Type_, dtype.release().ptr(), (int) info.ndim, (Py_intptr_t *) &info.shape[0], + api.PyArray_Type_, descr, (int) info.ndim, (Py_intptr_t *) &info.shape[0], (Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false); if (!tmp) pybind11_fail("NumPy: unable to create array!"); @@ -145,50 +228,12 @@ public: m_ptr = tmp.release().ptr(); } + pybind11::dtype dtype() { + return attr("dtype").cast(); + } + protected: template friend struct detail::npy_format_descriptor; - - static object strip_padding_fields(object dtype) { - // Recursively strip all void fields with empty names that are generated for - // padding fields (as of NumPy v1.11). - auto fields = dtype.attr("fields").cast(); - if (fields.ptr() == Py_None) - return dtype; - - struct field_descr { pybind11::str name; object format; int_ offset; }; - std::vector field_descriptors; - - auto items = fields.attr("items").cast(); - for (auto field : items()) { - auto spec = object(field, true).cast(); - auto name = spec[0].cast(); - auto format = spec[1].cast()[0].cast(); - auto offset = spec[1].cast()[1].cast(); - if (!len(name) && (std::string) dtype.attr("kind").cast() == "V") - continue; - field_descriptors.push_back({name, strip_padding_fields(format), offset}); - } - - std::sort(field_descriptors.begin(), field_descriptors.end(), - [](const field_descr& a, const field_descr& b) { - return (int) a.offset < (int) b.offset; - }); - - list names, formats, offsets; - for (auto& descr : field_descriptors) { - names.append(descr.name); - formats.append(descr.format); - offsets.append(descr.offset); - } - auto args = dict(); - args["names"] = names; args["formats"] = formats; args["offsets"] = offsets; - args["itemsize"] = dtype.attr("itemsize").cast(); - - PyObject *descr = nullptr; - if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &descr) || !descr) - pybind11_fail("NumPy: failed to create structured dtype"); - return object(descr, false); - } }; template class array_t : public array { @@ -201,8 +246,7 @@ public: if (ptr == nullptr) return nullptr; auto& api = detail::npy_api::get(); - PyObject *descr = detail::npy_format_descriptor::dtype().release().ptr(); - PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0, + PyObject *result = api.PyArray_FromAny_(ptr, pybind11::dtype::of().release().ptr(), 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr); if (!result) PyErr_Clear(); @@ -223,11 +267,6 @@ template struct format_descriptor> { static const char *format() { PYBIND11_DESCR s = detail::_() + detail::_("s"); return s.text(); } }; -template -object dtype_of() { - return detail::npy_format_descriptor::dtype(); -} - NAMESPACE_BEGIN(detail) template struct is_std_array : std::false_type { }; template struct is_std_array> : std::true_type { }; @@ -252,7 +291,7 @@ private: npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_ }; public: enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned::value ? 1 : 0)] }; - static object dtype() { + static pybind11::dtype dtype() { if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) return object(ptr, true); pybind11_fail("Unsupported buffer format!"); @@ -267,7 +306,7 @@ template constexpr const int npy_format_descriptor< #define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor { \ enum { value = npy_api::NumPyName }; \ - static object dtype() { \ + static pybind11::dtype dtype() { \ if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) \ return object(ptr, true); \ pybind11_fail("Unsupported buffer format!"); \ @@ -282,14 +321,9 @@ DECL_FMT(std::complex, NPY_CDOUBLE_, "complex128"); #define DECL_CHAR_FMT \ static PYBIND11_DESCR name() { return _("S") + _(); } \ - static object dtype() { \ - auto& api = npy_api::get(); \ - PyObject *descr = nullptr; \ + static pybind11::dtype dtype() { \ PYBIND11_DESCR fmt = _("S") + _(); \ - pybind11::str py_fmt(fmt.text()); \ - if (!api.PyArray_DescrConverter_(py_fmt.release().ptr(), &descr) || !descr) \ - pybind11_fail("NumPy: failed to create string dtype"); \ - return object(descr, false); \ + return pybind11::dtype::from_args(pybind11::str(fmt.text())); \ } \ static const char *format() { PYBIND11_DESCR s = _() + _("s"); return s.text(); } template struct npy_format_descriptor { DECL_CHAR_FMT }; @@ -301,14 +335,14 @@ struct field_descriptor { size_t offset; size_t size; const char *format; - object descr; + dtype descr; }; template struct npy_format_descriptor::value>::type> { static PYBIND11_DESCR name() { return _("struct"); } - static object dtype() { + static pybind11::dtype dtype() { if (!dtype_()) pybind11_fail("NumPy: unsupported buffer format!"); return object(dtype_(), true); @@ -321,7 +355,6 @@ struct npy_format_descriptor::value> } static void register_dtype(std::initializer_list fields) { - auto& api = npy_api::get(); auto args = dict(); list names { }, offsets { }, formats { }; for (auto field : fields) { @@ -333,10 +366,7 @@ struct npy_format_descriptor::value> } args["names"] = names; args["offsets"] = offsets; args["formats"] = formats; args["itemsize"] = int_(sizeof(T)); - // This is essentially the same as calling np.dtype() constructor in Python and passing - // it a dict of the form {'names': ..., 'formats': ..., 'offsets': ...}. - if (!api.PyArray_DescrConverter_(args.release().ptr(), &dtype_()) || !dtype_()) - pybind11_fail("NumPy: failed to create structured dtype"); + dtype_() = pybind11::dtype::from_args(args).release().ptr(); // There is an existing bug in NumPy (as of v1.11): trailing bytes are // not encoded explicitly into the format string. This will supposedly @@ -366,9 +396,9 @@ struct npy_format_descriptor::value> format_() = oss.str(); // Sanity check: verify that NumPy properly parses our buffer format string + auto& api = npy_api::get(); auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1, { 0 }, { sizeof(T) })); - auto fixed_dtype = array::strip_padding_fields(object(dtype_(), true)); - if (!api.PyArray_EquivTypes_(dtype_(), fixed_dtype.ptr())) + if (!api.PyArray_EquivTypes_(dtype_(), arr.dtype().ptr())) pybind11_fail("NumPy: invalid buffer descriptor!"); }