mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-25 14:45:12 +00:00
NumPy "base" feature: integrated feedback by @aldanor
This commit is contained in:
parent
c49d6e508a
commit
fac7c09458
@ -157,7 +157,7 @@ NAMESPACE_END(detail)
|
||||
#define PyArrayDescr_GET_(ptr, attr) \
|
||||
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
|
||||
#define PyArray_FLAGS_(ptr) \
|
||||
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags)
|
||||
PyArray_GET_(ptr, flags)
|
||||
#define PyArray_CHKFLAGS_(ptr, flag) \
|
||||
(flag == (PyArray_FLAGS_(ptr) & flag))
|
||||
|
||||
|
@ -109,4 +109,19 @@ test_initializer numpy_array([](py::module &m) {
|
||||
a
|
||||
);
|
||||
});
|
||||
|
||||
struct ArrayClass {
|
||||
int data[2] = { 1, 2 };
|
||||
ArrayClass() { py::print("ArrayClass()"); }
|
||||
~ArrayClass() { py::print("~ArrayClass()"); }
|
||||
};
|
||||
|
||||
py::class_<ArrayClass>(sm, "ArrayClass")
|
||||
.def(py::init<>())
|
||||
.def("numpy_view", [](py::object &obj) {
|
||||
py::print("ArrayClass::numpy_view()");
|
||||
ArrayClass &a = obj.cast<ArrayClass&>();
|
||||
return py::array_t<int>({2}, {4}, a.data, obj);
|
||||
}
|
||||
);
|
||||
});
|
||||
|
@ -1,4 +1,5 @@
|
||||
import pytest
|
||||
import gc
|
||||
|
||||
with pytest.suppress(ImportError):
|
||||
import numpy as np
|
||||
@ -209,3 +210,31 @@ def test_wrap():
|
||||
A1 = A1.diagonal()
|
||||
A2 = wrap(A1)
|
||||
assert_references(A1, A2)
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_numpy_view(capture):
|
||||
from pybind11_tests.array import ArrayClass
|
||||
with capture:
|
||||
ac = ArrayClass()
|
||||
ac_view_1 = ac.numpy_view()
|
||||
ac_view_2 = ac.numpy_view()
|
||||
assert np.all(ac_view_1 == np.array([1, 2], dtype=np.int32))
|
||||
del ac
|
||||
gc.collect()
|
||||
assert capture == """
|
||||
ArrayClass()
|
||||
ArrayClass::numpy_view()
|
||||
ArrayClass::numpy_view()
|
||||
"""
|
||||
ac_view_1[0] = 4
|
||||
ac_view_1[1] = 3
|
||||
assert ac_view_2[0] == 4
|
||||
assert ac_view_2[1] == 3
|
||||
with capture:
|
||||
del ac_view_1
|
||||
del ac_view_2
|
||||
gc.collect()
|
||||
assert capture == """
|
||||
~ArrayClass()
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user