diff --git a/include/pybind11/common.h b/include/pybind11/common.h index 84035eb3f..2ee469121 100644 --- a/include/pybind11/common.h +++ b/include/pybind11/common.h @@ -455,6 +455,7 @@ PYBIND11_RUNTIME_EXCEPTION(stop_iteration, PyExc_StopIteration) PYBIND11_RUNTIME_EXCEPTION(index_error, PyExc_IndexError) PYBIND11_RUNTIME_EXCEPTION(key_error, PyExc_KeyError) PYBIND11_RUNTIME_EXCEPTION(value_error, PyExc_ValueError) +PYBIND11_RUNTIME_EXCEPTION(import_error, PyExc_ImportError) PYBIND11_RUNTIME_EXCEPTION(type_error, PyExc_TypeError) PYBIND11_RUNTIME_EXCEPTION(cast_error, PyExc_RuntimeError) /// Thrown when pybind11::cast or handle::call fail due to a type casting error PYBIND11_RUNTIME_EXCEPTION(reference_cast_error, PyExc_RuntimeError) /// Used internally diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 1125fd712..a99c72eee 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -22,8 +22,8 @@ #include #if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant +# pragma warning(push) +# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant #endif /* This will be true on all flat address space platforms and allows us to reduce the @@ -156,8 +156,10 @@ NAMESPACE_END(detail) (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr) #define PyArrayDescr_GET_(ptr, attr) \ (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr) +#define PyArray_FLAGS_(ptr) \ + PyArray_GET_(ptr, flags) #define PyArray_CHKFLAGS_(ptr, flag) \ - (flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag)) + (flag == (PyArray_FLAGS_(ptr) & flag)) class dtype : public object { public: @@ -258,38 +260,62 @@ public: forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_ }; - array(const pybind11::dtype& dt, const std::vector& shape, - const std::vector& strides, const void *ptr = nullptr) { + array(const pybind11::dtype &dt, const std::vector &shape, + const std::vector &strides, const void *ptr = nullptr, + handle base = handle()) { auto& api = detail::npy_api::get(); auto ndim = shape.size(); if (shape.size() != strides.size()) pybind11_fail("NumPy: shape ndim doesn't match strides ndim"); auto descr = dt; + + int flags = 0; + if (base && ptr) { + array base_array(base, true); + if (base_array.check()) + /* Copy flags from base (except baseship bit) */ + flags = base_array.flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_; + else + /* Writable by default, easy to downgrade later on if needed */ + flags = detail::npy_api::NPY_ARRAY_WRITEABLE_; + } + object tmp(api.PyArray_NewFromDescr_( api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(), - (Py_intptr_t *) strides.data(), const_cast(ptr), 0, nullptr), false); + (Py_intptr_t *) strides.data(), const_cast(ptr), flags, nullptr), false); if (!tmp) pybind11_fail("NumPy: unable to create array!"); - if (ptr) - tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false); + if (ptr) { + if (base) { + PyArray_GET_(tmp.ptr(), base) = base.inc_ref().ptr(); + } else { + tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false); + } + } m_ptr = tmp.release().ptr(); } - array(const pybind11::dtype& dt, const std::vector& shape, const void *ptr = nullptr) - : array(dt, shape, default_strides(shape, dt.itemsize()), ptr) { } + array(const pybind11::dtype &dt, const std::vector &shape, + const void *ptr = nullptr, handle base = handle()) + : array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { } - array(const pybind11::dtype& dt, size_t count, const void *ptr = nullptr) - : array(dt, std::vector { count }, ptr) { } + array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr, + handle base = handle()) + : array(dt, std::vector{ count }, ptr, base) { } template array(const std::vector& shape, - const std::vector& strides, const T* ptr) - : array(pybind11::dtype::of(), shape, strides, (void *) ptr) { } + const std::vector& strides, + const T* ptr, handle base = handle()) + : array(pybind11::dtype::of(), shape, strides, (void *) ptr, base) { } - template array(const std::vector& shape, const T* ptr) - : array(shape, default_strides(shape, sizeof(T)), ptr) { } + template + array(const std::vector &shape, const T *ptr, + handle base = handle()) + : array(shape, default_strides(shape, sizeof(T)), ptr, base) { } - template array(size_t count, const T* ptr) - : array(std::vector { count }, ptr) { } + template + array(size_t count, const T *ptr, handle base = handle()) + : array(std::vector{ count }, ptr, base) { } array(const buffer_info &info) : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { } @@ -319,6 +345,11 @@ public: return (size_t) PyArray_GET_(m_ptr, nd); } + /// Base object + object base() const { + return object(PyArray_GET_(m_ptr, base), true); + } + /// Dimensions of the array const size_t* shape() const { return reinterpret_cast(PyArray_GET_(m_ptr, dimensions)); @@ -343,6 +374,11 @@ public: return strides()[dim]; } + /// Return the NumPy array flags + int flags() const { + return PyArray_FLAGS_(m_ptr); + } + /// If set, the array is writeable (otherwise the buffer is read-only) bool writeable() const { return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_); @@ -436,14 +472,17 @@ public: array_t(const buffer_info& info) : array(info) { } - array_t(const std::vector& shape, const std::vector& strides, const T* ptr = nullptr) - : array(shape, strides, ptr) { } + array_t(const std::vector &shape, + const std::vector &strides, const T *ptr = nullptr, + handle base = handle()) + : array(shape, strides, ptr, base) { } - array_t(const std::vector& shape, const T* ptr = nullptr) - : array(shape, ptr) { } + array_t(const std::vector &shape, const T *ptr = nullptr, + handle base = handle()) + : array(shape, ptr, base) { } - array_t(size_t count, const T* ptr = nullptr) - : array(count, ptr) { } + array_t(size_t count, const T *ptr = nullptr, handle base = handle()) + : array(count, ptr, base) { } constexpr size_t itemsize() const { return sizeof(T); diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 686fe5361..f7d46cf36 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -567,7 +567,7 @@ public: static module import(const char *name) { PyObject *obj = PyImport_ImportModule(name); if (!obj) - pybind11_fail("Module \"" + std::string(name) + "\" not found!"); + throw import_error("Module \"" + std::string(name) + "\" not found!"); return module(obj, false); } }; @@ -1344,15 +1344,27 @@ PYBIND11_NOINLINE inline void print(tuple args, dict kwargs) { auto sep = kwargs.contains("sep") ? kwargs["sep"] : cast(" "); auto line = sep.attr("join")(strings); - auto file = kwargs.contains("file") ? kwargs["file"].cast() - : module::import("sys").attr("stdout"); + object file; + if (kwargs.contains("file")) { + file = kwargs["file"].cast(); + } else { + try { + file = module::import("sys").attr("stdout"); + } catch (const import_error &) { + /* If print() is called from code that is executed as + part of garbage collection during interpreter shutdown, + importing 'sys' can fail. Give up rather than crashing the + interpreter in this case. */ + return; + } + } + auto write = file.attr("write"); write(line); write(kwargs.contains("end") ? kwargs["end"] : cast("\n")); - if (kwargs.contains("flush") && kwargs["flush"].cast()) { + if (kwargs.contains("flush") && kwargs["flush"].cast()) file.attr("flush")(); - } } NAMESPACE_END(detail) diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index 37a898367..ec4ddacb9 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -99,4 +99,29 @@ test_initializer numpy_array([](py::module &m) { sm.def("make_c_array", [] { return py::array_t({ 2, 2 }, { 8, 4 }); }); + + sm.def("wrap", [](py::array a) { + return py::array( + a.dtype(), + std::vector(a.shape(), a.shape() + a.ndim()), + std::vector(a.strides(), a.strides() + a.ndim()), + a.data(), + a + ); + }); + + struct ArrayClass { + int data[2] = { 1, 2 }; + ArrayClass() { py::print("ArrayClass()"); } + ~ArrayClass() { py::print("~ArrayClass()"); } + }; + + py::class_(sm, "ArrayClass") + .def(py::init<>()) + .def("numpy_view", [](py::object &obj) { + py::print("ArrayClass::numpy_view()"); + ArrayClass &a = obj.cast(); + return py::array_t({2}, {4}, a.data, obj); + } + ); }); diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index 85775e4f3..ae1954a65 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -1,4 +1,5 @@ import pytest +import gc with pytest.suppress(ImportError): import numpy as np @@ -149,6 +150,7 @@ def test_bounds_check(arr): index_at(arr, 0, 4) assert str(excinfo.value) == 'index 4 is out of bounds for axis 1 with size 3' + @pytest.requires_numpy def test_make_c_f_array(): from pybind11_tests.array import ( @@ -158,3 +160,81 @@ def test_make_c_f_array(): assert not make_c_array().flags.f_contiguous assert make_f_array().flags.f_contiguous assert not make_f_array().flags.c_contiguous + + +@pytest.requires_numpy +def test_wrap(): + from pybind11_tests.array import wrap + + def assert_references(A, B): + assert A is not B + assert A.__array_interface__['data'][0] == \ + B.__array_interface__['data'][0] + assert A.shape == B.shape + assert A.strides == B.strides + assert A.flags.c_contiguous == B.flags.c_contiguous + assert A.flags.f_contiguous == B.flags.f_contiguous + assert A.flags.writeable == B.flags.writeable + assert A.flags.aligned == B.flags.aligned + assert A.flags.updateifcopy == B.flags.updateifcopy + assert np.all(A == B) + assert not B.flags.owndata + assert B.base is A + if A.flags.writeable and A.ndim == 2: + A[0, 0] = 1234 + assert B[0, 0] == 1234 + + A1 = np.array([1, 2], dtype=np.int16) + assert A1.flags.owndata and A1.base is None + A2 = wrap(A1) + assert_references(A1, A2) + + A1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='F') + assert A1.flags.owndata and A1.base is None + A2 = wrap(A1) + assert_references(A1, A2) + + A1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='C') + A1.flags.writeable = False + A2 = wrap(A1) + assert_references(A1, A2) + + A1 = np.random.random((4, 4, 4)) + A2 = wrap(A1) + assert_references(A1, A2) + + A1 = A1.transpose() + A2 = wrap(A1) + assert_references(A1, A2) + + A1 = A1.diagonal() + A2 = wrap(A1) + assert_references(A1, A2) + + +@pytest.requires_numpy +def test_numpy_view(capture): + from pybind11_tests.array import ArrayClass + with capture: + ac = ArrayClass() + ac_view_1 = ac.numpy_view() + ac_view_2 = ac.numpy_view() + assert np.all(ac_view_1 == np.array([1, 2], dtype=np.int32)) + del ac + gc.collect() + assert capture == """ + ArrayClass() + ArrayClass::numpy_view() + ArrayClass::numpy_view() + """ + ac_view_1[0] = 4 + ac_view_1[1] = 3 + assert ac_view_2[0] == 4 + assert ac_view_2[1] == 3 + with capture: + del ac_view_1 + del ac_view_2 + gc.collect() + assert capture == """ + ~ArrayClass() + """