Permit creation of NumPy arrays with a "base" object that owns the data

This patch adds an extra base handle parameter to most ``py::array`` and
``py::array_t<>`` constructors. If specified along with a pointer to
data, the base object will be registered within NumPy, which increases
the base's reference count. This feature is useful to create shallow
copies of C++ or Python arrays while ensuring that the owners of the
underlying can't be garbage collected while referenced by NumPy.

The commit also adds a simple test function involving a ``wrap()``
function that creates shallow copies of various N-D arrays.
This commit is contained in:
Wenzel Jakob 2016-10-13 00:57:42 +02:00
parent 43f6aa6846
commit 369e9b3937
3 changed files with 124 additions and 24 deletions

View File

@ -156,8 +156,10 @@ NAMESPACE_END(detail)
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr) (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
#define PyArrayDescr_GET_(ptr, attr) \ #define PyArrayDescr_GET_(ptr, attr) \
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr) (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
#define PyArray_FLAGS_(ptr) \
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags)
#define PyArray_CHKFLAGS_(ptr, flag) \ #define PyArray_CHKFLAGS_(ptr, flag) \
(flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag)) (flag == (PyArray_FLAGS_(ptr) & flag))
class dtype : public object { class dtype : public object {
public: public:
@ -259,37 +261,61 @@ public:
}; };
array(const pybind11::dtype &dt, const std::vector<size_t> &shape, array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
const std::vector<size_t>& strides, const void *ptr = nullptr) { const std::vector<size_t> &strides, const void *ptr = nullptr,
handle base = handle()) {
auto& api = detail::npy_api::get(); auto& api = detail::npy_api::get();
auto ndim = shape.size(); auto ndim = shape.size();
if (shape.size() != strides.size()) if (shape.size() != strides.size())
pybind11_fail("NumPy: shape ndim doesn't match strides ndim"); pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
auto descr = dt; auto descr = dt;
int flags = 0;
if (base && ptr) {
array base_array(base, true);
if (base_array.check())
/* Copy flags from base (except baseship bit) */
flags = base_array.flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
else
/* Writable by default, easy to downgrade later on if needed */
flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
}
object tmp(api.PyArray_NewFromDescr_( object tmp(api.PyArray_NewFromDescr_(
api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(), api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(),
(Py_intptr_t *) strides.data(), const_cast<void *>(ptr), 0, nullptr), false); (Py_intptr_t *) strides.data(), const_cast<void *>(ptr), flags, nullptr), false);
if (!tmp) if (!tmp)
pybind11_fail("NumPy: unable to create array!"); pybind11_fail("NumPy: unable to create array!");
if (ptr) if (ptr) {
if (base) {
PyArray_GET_(tmp.ptr(), base) = base.inc_ref().ptr();
} else {
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false); tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
}
}
m_ptr = tmp.release().ptr(); m_ptr = tmp.release().ptr();
} }
array(const pybind11::dtype& dt, const std::vector<size_t>& shape, const void *ptr = nullptr) array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr) { } const void *ptr = nullptr, handle base = handle())
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
array(const pybind11::dtype& dt, size_t count, const void *ptr = nullptr) array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
: array(dt, std::vector<size_t> { count }, ptr) { } handle base = handle())
: array(dt, std::vector<size_t>{ count }, ptr, base) { }
template<typename T> array(const std::vector<size_t>& shape, template<typename T> array(const std::vector<size_t>& shape,
const std::vector<size_t>& strides, const T* ptr) const std::vector<size_t>& strides,
: array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr) { } const T* ptr, handle base = handle())
: array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
template<typename T> array(const std::vector<size_t>& shape, const T* ptr) template <typename T>
: array(shape, default_strides(shape, sizeof(T)), ptr) { } array(const std::vector<size_t> &shape, const T *ptr,
handle base = handle())
: array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
template<typename T> array(size_t count, const T* ptr) template <typename T>
: array(std::vector<size_t> { count }, ptr) { } array(size_t count, const T *ptr, handle base = handle())
: array(std::vector<size_t>{ count }, ptr, base) { }
array(const buffer_info &info) array(const buffer_info &info)
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { } : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
@ -319,6 +345,11 @@ public:
return (size_t) PyArray_GET_(m_ptr, nd); return (size_t) PyArray_GET_(m_ptr, nd);
} }
/// Base object
object base() const {
return object(PyArray_GET_(m_ptr, base), true);
}
/// Dimensions of the array /// Dimensions of the array
const size_t* shape() const { const size_t* shape() const {
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, dimensions)); return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, dimensions));
@ -343,6 +374,11 @@ public:
return strides()[dim]; return strides()[dim];
} }
/// Return the NumPy array flags
int flags() const {
return PyArray_FLAGS_(m_ptr);
}
/// If set, the array is writeable (otherwise the buffer is read-only) /// If set, the array is writeable (otherwise the buffer is read-only)
bool writeable() const { bool writeable() const {
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_); return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
@ -436,14 +472,17 @@ public:
array_t(const buffer_info& info) : array(info) { } array_t(const buffer_info& info) : array(info) { }
array_t(const std::vector<size_t>& shape, const std::vector<size_t>& strides, const T* ptr = nullptr) array_t(const std::vector<size_t> &shape,
: array(shape, strides, ptr) { } const std::vector<size_t> &strides, const T *ptr = nullptr,
handle base = handle())
: array(shape, strides, ptr, base) { }
array_t(const std::vector<size_t>& shape, const T* ptr = nullptr) array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
: array(shape, ptr) { } handle base = handle())
: array(shape, ptr, base) { }
array_t(size_t count, const T* ptr = nullptr) array_t(size_t count, const T *ptr = nullptr, handle base = handle())
: array(count, ptr) { } : array(count, ptr, base) { }
constexpr size_t itemsize() const { constexpr size_t itemsize() const {
return sizeof(T); return sizeof(T);

View File

@ -99,4 +99,14 @@ test_initializer numpy_array([](py::module &m) {
sm.def("make_c_array", [] { sm.def("make_c_array", [] {
return py::array_t<float>({ 2, 2 }, { 8, 4 }); return py::array_t<float>({ 2, 2 }, { 8, 4 });
}); });
sm.def("wrap", [](py::array a) {
return py::array(
a.dtype(),
std::vector<size_t>(a.shape(), a.shape() + a.ndim()),
std::vector<size_t>(a.strides(), a.strides() + a.ndim()),
a.data(),
a
);
});
}); });

View File

@ -149,6 +149,7 @@ def test_bounds_check(arr):
index_at(arr, 0, 4) index_at(arr, 0, 4)
assert str(excinfo.value) == 'index 4 is out of bounds for axis 1 with size 3' assert str(excinfo.value) == 'index 4 is out of bounds for axis 1 with size 3'
@pytest.requires_numpy @pytest.requires_numpy
def test_make_c_f_array(): def test_make_c_f_array():
from pybind11_tests.array import ( from pybind11_tests.array import (
@ -158,3 +159,53 @@ def test_make_c_f_array():
assert not make_c_array().flags.f_contiguous assert not make_c_array().flags.f_contiguous
assert make_f_array().flags.f_contiguous assert make_f_array().flags.f_contiguous
assert not make_f_array().flags.c_contiguous assert not make_f_array().flags.c_contiguous
@pytest.requires_numpy
def test_wrap():
from pybind11_tests.array import wrap
def assert_references(A, B):
assert A is not B
assert A.__array_interface__['data'][0] == \
B.__array_interface__['data'][0]
assert A.shape == B.shape
assert A.strides == B.strides
assert A.flags.c_contiguous == B.flags.c_contiguous
assert A.flags.f_contiguous == B.flags.f_contiguous
assert A.flags.writeable == B.flags.writeable
assert A.flags.aligned == B.flags.aligned
assert A.flags.updateifcopy == B.flags.updateifcopy
assert np.all(A == B)
assert not B.flags.owndata
assert B.base is A
if A.flags.writeable and A.ndim == 2:
A[0, 0] = 1234
assert B[0, 0] == 1234
A1 = np.array([1, 2], dtype=np.int16)
assert A1.flags.owndata and A1.base is None
A2 = wrap(A1)
assert_references(A1, A2)
A1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='F')
assert A1.flags.owndata and A1.base is None
A2 = wrap(A1)
assert_references(A1, A2)
A1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='C')
A1.flags.writeable = False
A2 = wrap(A1)
assert_references(A1, A2)
A1 = np.random.random((4, 4, 4))
A2 = wrap(A1)
assert_references(A1, A2)
A1 = A1.transpose()
A2 = wrap(A1)
assert_references(A1, A2)
A1 = A1.diagonal()
A2 = wrap(A1)
assert_references(A1, A2)