mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-29 16:37:13 +00:00
Merge pull request #440 from wjakob/master
Permit creation of NumPy arrays with a "base" object that owns the data
This commit is contained in:
commit
00488a3e2c
@ -455,6 +455,7 @@ PYBIND11_RUNTIME_EXCEPTION(stop_iteration, PyExc_StopIteration)
|
|||||||
PYBIND11_RUNTIME_EXCEPTION(index_error, PyExc_IndexError)
|
PYBIND11_RUNTIME_EXCEPTION(index_error, PyExc_IndexError)
|
||||||
PYBIND11_RUNTIME_EXCEPTION(key_error, PyExc_KeyError)
|
PYBIND11_RUNTIME_EXCEPTION(key_error, PyExc_KeyError)
|
||||||
PYBIND11_RUNTIME_EXCEPTION(value_error, PyExc_ValueError)
|
PYBIND11_RUNTIME_EXCEPTION(value_error, PyExc_ValueError)
|
||||||
|
PYBIND11_RUNTIME_EXCEPTION(import_error, PyExc_ImportError)
|
||||||
PYBIND11_RUNTIME_EXCEPTION(type_error, PyExc_TypeError)
|
PYBIND11_RUNTIME_EXCEPTION(type_error, PyExc_TypeError)
|
||||||
PYBIND11_RUNTIME_EXCEPTION(cast_error, PyExc_RuntimeError) /// Thrown when pybind11::cast or handle::call fail due to a type casting error
|
PYBIND11_RUNTIME_EXCEPTION(cast_error, PyExc_RuntimeError) /// Thrown when pybind11::cast or handle::call fail due to a type casting error
|
||||||
PYBIND11_RUNTIME_EXCEPTION(reference_cast_error, PyExc_RuntimeError) /// Used internally
|
PYBIND11_RUNTIME_EXCEPTION(reference_cast_error, PyExc_RuntimeError) /// Used internally
|
||||||
|
@ -156,8 +156,10 @@ NAMESPACE_END(detail)
|
|||||||
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
|
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
|
||||||
#define PyArrayDescr_GET_(ptr, attr) \
|
#define PyArrayDescr_GET_(ptr, attr) \
|
||||||
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
|
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
|
||||||
|
#define PyArray_FLAGS_(ptr) \
|
||||||
|
PyArray_GET_(ptr, flags)
|
||||||
#define PyArray_CHKFLAGS_(ptr, flag) \
|
#define PyArray_CHKFLAGS_(ptr, flag) \
|
||||||
(flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag))
|
(flag == (PyArray_FLAGS_(ptr) & flag))
|
||||||
|
|
||||||
class dtype : public object {
|
class dtype : public object {
|
||||||
public:
|
public:
|
||||||
@ -259,37 +261,61 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
|
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
|
||||||
const std::vector<size_t>& strides, const void *ptr = nullptr) {
|
const std::vector<size_t> &strides, const void *ptr = nullptr,
|
||||||
|
handle base = handle()) {
|
||||||
auto& api = detail::npy_api::get();
|
auto& api = detail::npy_api::get();
|
||||||
auto ndim = shape.size();
|
auto ndim = shape.size();
|
||||||
if (shape.size() != strides.size())
|
if (shape.size() != strides.size())
|
||||||
pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
|
pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
|
||||||
auto descr = dt;
|
auto descr = dt;
|
||||||
|
|
||||||
|
int flags = 0;
|
||||||
|
if (base && ptr) {
|
||||||
|
array base_array(base, true);
|
||||||
|
if (base_array.check())
|
||||||
|
/* Copy flags from base (except baseship bit) */
|
||||||
|
flags = base_array.flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
|
||||||
|
else
|
||||||
|
/* Writable by default, easy to downgrade later on if needed */
|
||||||
|
flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
|
||||||
|
}
|
||||||
|
|
||||||
object tmp(api.PyArray_NewFromDescr_(
|
object tmp(api.PyArray_NewFromDescr_(
|
||||||
api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(),
|
api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(),
|
||||||
(Py_intptr_t *) strides.data(), const_cast<void *>(ptr), 0, nullptr), false);
|
(Py_intptr_t *) strides.data(), const_cast<void *>(ptr), flags, nullptr), false);
|
||||||
if (!tmp)
|
if (!tmp)
|
||||||
pybind11_fail("NumPy: unable to create array!");
|
pybind11_fail("NumPy: unable to create array!");
|
||||||
if (ptr)
|
if (ptr) {
|
||||||
|
if (base) {
|
||||||
|
PyArray_GET_(tmp.ptr(), base) = base.inc_ref().ptr();
|
||||||
|
} else {
|
||||||
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
|
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
|
||||||
|
}
|
||||||
|
}
|
||||||
m_ptr = tmp.release().ptr();
|
m_ptr = tmp.release().ptr();
|
||||||
}
|
}
|
||||||
|
|
||||||
array(const pybind11::dtype& dt, const std::vector<size_t>& shape, const void *ptr = nullptr)
|
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
|
||||||
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr) { }
|
const void *ptr = nullptr, handle base = handle())
|
||||||
|
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
|
||||||
|
|
||||||
array(const pybind11::dtype& dt, size_t count, const void *ptr = nullptr)
|
array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
|
||||||
: array(dt, std::vector<size_t> { count }, ptr) { }
|
handle base = handle())
|
||||||
|
: array(dt, std::vector<size_t>{ count }, ptr, base) { }
|
||||||
|
|
||||||
template<typename T> array(const std::vector<size_t>& shape,
|
template<typename T> array(const std::vector<size_t>& shape,
|
||||||
const std::vector<size_t>& strides, const T* ptr)
|
const std::vector<size_t>& strides,
|
||||||
: array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr) { }
|
const T* ptr, handle base = handle())
|
||||||
|
: array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
|
||||||
|
|
||||||
template<typename T> array(const std::vector<size_t>& shape, const T* ptr)
|
template <typename T>
|
||||||
: array(shape, default_strides(shape, sizeof(T)), ptr) { }
|
array(const std::vector<size_t> &shape, const T *ptr,
|
||||||
|
handle base = handle())
|
||||||
|
: array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
|
||||||
|
|
||||||
template<typename T> array(size_t count, const T* ptr)
|
template <typename T>
|
||||||
: array(std::vector<size_t> { count }, ptr) { }
|
array(size_t count, const T *ptr, handle base = handle())
|
||||||
|
: array(std::vector<size_t>{ count }, ptr, base) { }
|
||||||
|
|
||||||
array(const buffer_info &info)
|
array(const buffer_info &info)
|
||||||
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
|
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
|
||||||
@ -319,6 +345,11 @@ public:
|
|||||||
return (size_t) PyArray_GET_(m_ptr, nd);
|
return (size_t) PyArray_GET_(m_ptr, nd);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Base object
|
||||||
|
object base() const {
|
||||||
|
return object(PyArray_GET_(m_ptr, base), true);
|
||||||
|
}
|
||||||
|
|
||||||
/// Dimensions of the array
|
/// Dimensions of the array
|
||||||
const size_t* shape() const {
|
const size_t* shape() const {
|
||||||
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, dimensions));
|
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, dimensions));
|
||||||
@ -343,6 +374,11 @@ public:
|
|||||||
return strides()[dim];
|
return strides()[dim];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return the NumPy array flags
|
||||||
|
int flags() const {
|
||||||
|
return PyArray_FLAGS_(m_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
/// If set, the array is writeable (otherwise the buffer is read-only)
|
/// If set, the array is writeable (otherwise the buffer is read-only)
|
||||||
bool writeable() const {
|
bool writeable() const {
|
||||||
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
|
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
|
||||||
@ -436,14 +472,17 @@ public:
|
|||||||
|
|
||||||
array_t(const buffer_info& info) : array(info) { }
|
array_t(const buffer_info& info) : array(info) { }
|
||||||
|
|
||||||
array_t(const std::vector<size_t>& shape, const std::vector<size_t>& strides, const T* ptr = nullptr)
|
array_t(const std::vector<size_t> &shape,
|
||||||
: array(shape, strides, ptr) { }
|
const std::vector<size_t> &strides, const T *ptr = nullptr,
|
||||||
|
handle base = handle())
|
||||||
|
: array(shape, strides, ptr, base) { }
|
||||||
|
|
||||||
array_t(const std::vector<size_t>& shape, const T* ptr = nullptr)
|
array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
|
||||||
: array(shape, ptr) { }
|
handle base = handle())
|
||||||
|
: array(shape, ptr, base) { }
|
||||||
|
|
||||||
array_t(size_t count, const T* ptr = nullptr)
|
array_t(size_t count, const T *ptr = nullptr, handle base = handle())
|
||||||
: array(count, ptr) { }
|
: array(count, ptr, base) { }
|
||||||
|
|
||||||
constexpr size_t itemsize() const {
|
constexpr size_t itemsize() const {
|
||||||
return sizeof(T);
|
return sizeof(T);
|
||||||
|
@ -567,7 +567,7 @@ public:
|
|||||||
static module import(const char *name) {
|
static module import(const char *name) {
|
||||||
PyObject *obj = PyImport_ImportModule(name);
|
PyObject *obj = PyImport_ImportModule(name);
|
||||||
if (!obj)
|
if (!obj)
|
||||||
pybind11_fail("Module \"" + std::string(name) + "\" not found!");
|
throw import_error("Module \"" + std::string(name) + "\" not found!");
|
||||||
return module(obj, false);
|
return module(obj, false);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -1344,16 +1344,28 @@ PYBIND11_NOINLINE inline void print(tuple args, dict kwargs) {
|
|||||||
auto sep = kwargs.contains("sep") ? kwargs["sep"] : cast(" ");
|
auto sep = kwargs.contains("sep") ? kwargs["sep"] : cast(" ");
|
||||||
auto line = sep.attr("join")(strings);
|
auto line = sep.attr("join")(strings);
|
||||||
|
|
||||||
auto file = kwargs.contains("file") ? kwargs["file"].cast<object>()
|
object file;
|
||||||
: module::import("sys").attr("stdout");
|
if (kwargs.contains("file")) {
|
||||||
|
file = kwargs["file"].cast<object>();
|
||||||
|
} else {
|
||||||
|
try {
|
||||||
|
file = module::import("sys").attr("stdout");
|
||||||
|
} catch (const import_error &) {
|
||||||
|
/* If print() is called from code that is executed as
|
||||||
|
part of garbage collection during interpreter shutdown,
|
||||||
|
importing 'sys' can fail. Give up rather than crashing the
|
||||||
|
interpreter in this case. */
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
auto write = file.attr("write");
|
auto write = file.attr("write");
|
||||||
write(line);
|
write(line);
|
||||||
write(kwargs.contains("end") ? kwargs["end"] : cast("\n"));
|
write(kwargs.contains("end") ? kwargs["end"] : cast("\n"));
|
||||||
|
|
||||||
if (kwargs.contains("flush") && kwargs["flush"].cast<bool>()) {
|
if (kwargs.contains("flush") && kwargs["flush"].cast<bool>())
|
||||||
file.attr("flush")();
|
file.attr("flush")();
|
||||||
}
|
}
|
||||||
}
|
|
||||||
NAMESPACE_END(detail)
|
NAMESPACE_END(detail)
|
||||||
|
|
||||||
template <return_value_policy policy = return_value_policy::automatic_reference, typename... Args>
|
template <return_value_policy policy = return_value_policy::automatic_reference, typename... Args>
|
||||||
|
@ -99,4 +99,29 @@ test_initializer numpy_array([](py::module &m) {
|
|||||||
sm.def("make_c_array", [] {
|
sm.def("make_c_array", [] {
|
||||||
return py::array_t<float>({ 2, 2 }, { 8, 4 });
|
return py::array_t<float>({ 2, 2 }, { 8, 4 });
|
||||||
});
|
});
|
||||||
|
|
||||||
|
sm.def("wrap", [](py::array a) {
|
||||||
|
return py::array(
|
||||||
|
a.dtype(),
|
||||||
|
std::vector<size_t>(a.shape(), a.shape() + a.ndim()),
|
||||||
|
std::vector<size_t>(a.strides(), a.strides() + a.ndim()),
|
||||||
|
a.data(),
|
||||||
|
a
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
struct ArrayClass {
|
||||||
|
int data[2] = { 1, 2 };
|
||||||
|
ArrayClass() { py::print("ArrayClass()"); }
|
||||||
|
~ArrayClass() { py::print("~ArrayClass()"); }
|
||||||
|
};
|
||||||
|
|
||||||
|
py::class_<ArrayClass>(sm, "ArrayClass")
|
||||||
|
.def(py::init<>())
|
||||||
|
.def("numpy_view", [](py::object &obj) {
|
||||||
|
py::print("ArrayClass::numpy_view()");
|
||||||
|
ArrayClass &a = obj.cast<ArrayClass&>();
|
||||||
|
return py::array_t<int>({2}, {4}, a.data, obj);
|
||||||
|
}
|
||||||
|
);
|
||||||
});
|
});
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
import gc
|
||||||
|
|
||||||
with pytest.suppress(ImportError):
|
with pytest.suppress(ImportError):
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -149,6 +150,7 @@ def test_bounds_check(arr):
|
|||||||
index_at(arr, 0, 4)
|
index_at(arr, 0, 4)
|
||||||
assert str(excinfo.value) == 'index 4 is out of bounds for axis 1 with size 3'
|
assert str(excinfo.value) == 'index 4 is out of bounds for axis 1 with size 3'
|
||||||
|
|
||||||
|
|
||||||
@pytest.requires_numpy
|
@pytest.requires_numpy
|
||||||
def test_make_c_f_array():
|
def test_make_c_f_array():
|
||||||
from pybind11_tests.array import (
|
from pybind11_tests.array import (
|
||||||
@ -158,3 +160,81 @@ def test_make_c_f_array():
|
|||||||
assert not make_c_array().flags.f_contiguous
|
assert not make_c_array().flags.f_contiguous
|
||||||
assert make_f_array().flags.f_contiguous
|
assert make_f_array().flags.f_contiguous
|
||||||
assert not make_f_array().flags.c_contiguous
|
assert not make_f_array().flags.c_contiguous
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.requires_numpy
|
||||||
|
def test_wrap():
|
||||||
|
from pybind11_tests.array import wrap
|
||||||
|
|
||||||
|
def assert_references(A, B):
|
||||||
|
assert A is not B
|
||||||
|
assert A.__array_interface__['data'][0] == \
|
||||||
|
B.__array_interface__['data'][0]
|
||||||
|
assert A.shape == B.shape
|
||||||
|
assert A.strides == B.strides
|
||||||
|
assert A.flags.c_contiguous == B.flags.c_contiguous
|
||||||
|
assert A.flags.f_contiguous == B.flags.f_contiguous
|
||||||
|
assert A.flags.writeable == B.flags.writeable
|
||||||
|
assert A.flags.aligned == B.flags.aligned
|
||||||
|
assert A.flags.updateifcopy == B.flags.updateifcopy
|
||||||
|
assert np.all(A == B)
|
||||||
|
assert not B.flags.owndata
|
||||||
|
assert B.base is A
|
||||||
|
if A.flags.writeable and A.ndim == 2:
|
||||||
|
A[0, 0] = 1234
|
||||||
|
assert B[0, 0] == 1234
|
||||||
|
|
||||||
|
A1 = np.array([1, 2], dtype=np.int16)
|
||||||
|
assert A1.flags.owndata and A1.base is None
|
||||||
|
A2 = wrap(A1)
|
||||||
|
assert_references(A1, A2)
|
||||||
|
|
||||||
|
A1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='F')
|
||||||
|
assert A1.flags.owndata and A1.base is None
|
||||||
|
A2 = wrap(A1)
|
||||||
|
assert_references(A1, A2)
|
||||||
|
|
||||||
|
A1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='C')
|
||||||
|
A1.flags.writeable = False
|
||||||
|
A2 = wrap(A1)
|
||||||
|
assert_references(A1, A2)
|
||||||
|
|
||||||
|
A1 = np.random.random((4, 4, 4))
|
||||||
|
A2 = wrap(A1)
|
||||||
|
assert_references(A1, A2)
|
||||||
|
|
||||||
|
A1 = A1.transpose()
|
||||||
|
A2 = wrap(A1)
|
||||||
|
assert_references(A1, A2)
|
||||||
|
|
||||||
|
A1 = A1.diagonal()
|
||||||
|
A2 = wrap(A1)
|
||||||
|
assert_references(A1, A2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.requires_numpy
|
||||||
|
def test_numpy_view(capture):
|
||||||
|
from pybind11_tests.array import ArrayClass
|
||||||
|
with capture:
|
||||||
|
ac = ArrayClass()
|
||||||
|
ac_view_1 = ac.numpy_view()
|
||||||
|
ac_view_2 = ac.numpy_view()
|
||||||
|
assert np.all(ac_view_1 == np.array([1, 2], dtype=np.int32))
|
||||||
|
del ac
|
||||||
|
gc.collect()
|
||||||
|
assert capture == """
|
||||||
|
ArrayClass()
|
||||||
|
ArrayClass::numpy_view()
|
||||||
|
ArrayClass::numpy_view()
|
||||||
|
"""
|
||||||
|
ac_view_1[0] = 4
|
||||||
|
ac_view_1[1] = 3
|
||||||
|
assert ac_view_2[0] == 4
|
||||||
|
assert ac_view_2[1] == 3
|
||||||
|
with capture:
|
||||||
|
del ac_view_1
|
||||||
|
del ac_view_2
|
||||||
|
gc.collect()
|
||||||
|
assert capture == """
|
||||||
|
~ArrayClass()
|
||||||
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user