diff --git a/include/pybind11/eigen.h b/include/pybind11/eigen.h index 729d0f655..0a1208e16 100644 --- a/include/pybind11/eigen.h +++ b/include/pybind11/eigen.h @@ -54,7 +54,7 @@ struct type_caster::value && !is_eigen_re static constexpr bool isVector = Type::IsVectorAtCompileTime; bool load(handle src, bool) { - array_t buf(src, true); + auto buf = array_t::ensure(src); if (!buf) return false; diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 3cbea0191..77006c887 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -305,7 +305,7 @@ private: class array : public buffer { public: - PYBIND11_OBJECT_DEFAULT(array, buffer, detail::npy_api::get().PyArray_Check_) + PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array) enum { c_style = detail::npy_api::NPY_C_CONTIGUOUS_, @@ -313,6 +313,8 @@ public: forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_ }; + array() : array(0, static_cast(nullptr)) {} + array(const pybind11::dtype &dt, const std::vector &shape, const std::vector &strides, const void *ptr = nullptr, handle base = handle()) { @@ -478,10 +480,12 @@ public: } /// Ensure that the argument is a NumPy array - static array ensure(object input, int ExtraFlags = 0) { - auto& api = detail::npy_api::get(); - return reinterpret_steal(api.PyArray_FromAny_( - input.release().ptr(), nullptr, 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr)); + /// In case of an error, nullptr is returned and the Python error is cleared. + static array ensure(handle h, int ExtraFlags = 0) { + auto result = reinterpret_steal(raw_array(h.ptr(), ExtraFlags)); + if (!result) + PyErr_Clear(); + return result; } protected: @@ -520,8 +524,6 @@ protected: return strides; } -protected: - template void check_dimensions(Ix... index) const { check_dimensions_impl(size_t(0), shape(), size_t(index)...); } @@ -536,15 +538,31 @@ protected: } check_dimensions_impl(axis + 1, shape + 1, index...); } + + /// Create array from any object -- always returns a new reference + static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) { + if (ptr == nullptr) + return nullptr; + return detail::npy_api::get().PyArray_FromAny_( + ptr, nullptr, 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr); + } }; template class array_t : public array { public: - array_t() : array() { } + array_t() : array(0, static_cast(nullptr)) {} + array_t(handle h, borrowed_t) : array(h, borrowed) { } + array_t(handle h, stolen_t) : array(h, stolen) { } - array_t(handle h, bool is_borrowed) : array(h, is_borrowed) { m_ptr = ensure_(m_ptr); } + PYBIND11_DEPRECATED("Use array_t::ensure() instead") + array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen) { + if (!m_ptr) PyErr_Clear(); + if (!is_borrowed) Py_XDECREF(h.ptr()); + } - array_t(const object &o) : array(o) { m_ptr = ensure_(m_ptr); } + array_t(const object &o) : array(raw_array_t(o.ptr()), stolen) { + if (!m_ptr) throw error_already_set(); + } explicit array_t(const buffer_info& info) : array(info) { } @@ -590,17 +608,30 @@ public: return *(static_cast(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize()); } - static PyObject *ensure_(PyObject *ptr) { - if (ptr == nullptr) - return nullptr; - auto& api = detail::npy_api::get(); - PyObject *result = api.PyArray_FromAny_(ptr, pybind11::dtype::of().release().ptr(), 0, 0, - detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr); + /// Ensure that the argument is a NumPy array of the correct dtype. + /// In case of an error, nullptr is returned and the Python error is cleared. + static array_t ensure(handle h) { + auto result = reinterpret_steal(raw_array_t(h.ptr())); if (!result) PyErr_Clear(); - Py_DECREF(ptr); return result; } + + static bool _check(handle h) { + const auto &api = detail::npy_api::get(); + return api.PyArray_Check_(h.ptr()) + && api.PyArray_EquivTypes_(PyArray_GET_(h.ptr(), descr), dtype::of().ptr()); + } + +protected: + /// Create array from any object -- always returns a new reference + static PyObject *raw_array_t(PyObject *ptr) { + if (ptr == nullptr) + return nullptr; + return detail::npy_api::get().PyArray_FromAny_( + ptr, dtype::of().release().ptr(), 0, 0, + detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr); + } }; template @@ -631,7 +662,7 @@ struct pyobject_caster> { using type = array_t; bool load(handle src, bool /* convert */) { - value = type(src, true); + value = type::ensure(src); return static_cast(value); } diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index df6377eb7..14c4c2999 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -126,4 +126,28 @@ test_initializer numpy_array([](py::module &m) { ); sm.def("function_taking_uint64", [](uint64_t) { }); + + sm.def("isinstance_untyped", [](py::object yes, py::object no) { + return py::isinstance(yes) && !py::isinstance(no); + }); + + sm.def("isinstance_typed", [](py::object o) { + return py::isinstance>(o) && !py::isinstance>(o); + }); + + sm.def("default_constructors", []() { + return py::dict( + "array"_a=py::array(), + "array_t"_a=py::array_t(), + "array_t"_a=py::array_t() + ); + }); + + sm.def("converting_constructors", [](py::object o) { + return py::dict( + "array"_a=py::array(o), + "array_t"_a=py::array_t(o), + "array_t"_a=py::array_t(o) + ); + }); }); diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index 40682efc2..cec005419 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -245,3 +245,30 @@ def test_cast_numpy_int64_to_uint64(): from pybind11_tests.array import function_taking_uint64 function_taking_uint64(123) function_taking_uint64(np.uint64(123)) + + +@pytest.requires_numpy +def test_isinstance(): + from pybind11_tests.array import isinstance_untyped, isinstance_typed + + assert isinstance_untyped(np.array([1, 2, 3]), "not an array") + assert isinstance_typed(np.array([1.0, 2.0, 3.0])) + + +@pytest.requires_numpy +def test_constructors(): + from pybind11_tests.array import default_constructors, converting_constructors + + defaults = default_constructors() + for a in defaults.values(): + assert a.size == 0 + assert defaults["array"].dtype == np.array([]).dtype + assert defaults["array_t"].dtype == np.int32 + assert defaults["array_t"].dtype == np.float64 + + results = converting_constructors([1, 2, 3]) + for a in results.values(): + np.testing.assert_array_equal(a, [1, 2, 3]) + assert results["array"].dtype == np.int_ + assert results["array_t"].dtype == np.int32 + assert results["array_t"].dtype == np.float64