From c01a1c1ade0ebeda166d9a070318b0abc6391086 Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Fri, 14 Oct 2016 01:08:03 +0200 Subject: [PATCH] added array::ensure() function wrapping PyArray_FromAny This convenience function ensures that a py::object is either a py::array, or the implementation will try to convert it into one. Layout requirements (such as c_style or f_style) can be also be provided. --- include/pybind11/numpy.h | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index a99c72eee..d437c922f 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -425,6 +425,13 @@ public: return array(api.PyArray_Squeeze_(m_ptr), false); } + /// Ensure that the argument is a NumPy array + static array ensure(object input, int ExtraFlags = 0) { + auto& api = detail::npy_api::get(); + return array(api.PyArray_FromAny_( + input.release().ptr(), nullptr, 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr), false); + } + protected: template friend struct detail::npy_format_descriptor; @@ -466,7 +473,7 @@ protected: template class array_t : public array { public: - PYBIND11_OBJECT_CVT(array_t, array, is_non_null, m_ptr = ensure(m_ptr)); + PYBIND11_OBJECT_CVT(array_t, array, is_non_null, m_ptr = ensure_(m_ptr)); array_t() : array() { } @@ -518,7 +525,7 @@ public: static bool is_non_null(PyObject *ptr) { return ptr != nullptr; } - static PyObject *ensure(PyObject *ptr) { + static PyObject *ensure_(PyObject *ptr) { if (ptr == nullptr) return nullptr; auto& api = detail::npy_api::get();