diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index a99c72eee..d437c922f 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -425,6 +425,13 @@ public: return array(api.PyArray_Squeeze_(m_ptr), false); } + /// Ensure that the argument is a NumPy array + static array ensure(object input, int ExtraFlags = 0) { + auto& api = detail::npy_api::get(); + return array(api.PyArray_FromAny_( + input.release().ptr(), nullptr, 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr), false); + } + protected: template friend struct detail::npy_format_descriptor; @@ -466,7 +473,7 @@ protected: template class array_t : public array { public: - PYBIND11_OBJECT_CVT(array_t, array, is_non_null, m_ptr = ensure(m_ptr)); + PYBIND11_OBJECT_CVT(array_t, array, is_non_null, m_ptr = ensure_(m_ptr)); array_t() : array() { } @@ -518,7 +525,7 @@ public: static bool is_non_null(PyObject *ptr) { return ptr != nullptr; } - static PyObject *ensure(PyObject *ptr) { + static PyObject *ensure_(PyObject *ptr) { if (ptr == nullptr) return nullptr; auto& api = detail::npy_api::get();