NumPy "base" feature: integrated feedback by @aldanor

This commit is contained in:
Wenzel Jakob 2016-10-13 10:37:52 +02:00
parent c49d6e508a
commit fac7c09458
3 changed files with 45 additions and 1 deletions

View File

@ -157,7 +157,7 @@ NAMESPACE_END(detail)
#define PyArrayDescr_GET_(ptr, attr) \ #define PyArrayDescr_GET_(ptr, attr) \
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr) (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
#define PyArray_FLAGS_(ptr) \ #define PyArray_FLAGS_(ptr) \
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags) PyArray_GET_(ptr, flags)
#define PyArray_CHKFLAGS_(ptr, flag) \ #define PyArray_CHKFLAGS_(ptr, flag) \
(flag == (PyArray_FLAGS_(ptr) & flag)) (flag == (PyArray_FLAGS_(ptr) & flag))

View File

@ -109,4 +109,19 @@ test_initializer numpy_array([](py::module &m) {
a a
); );
}); });
struct ArrayClass {
int data[2] = { 1, 2 };
ArrayClass() { py::print("ArrayClass()"); }
~ArrayClass() { py::print("~ArrayClass()"); }
};
py::class_<ArrayClass>(sm, "ArrayClass")
.def(py::init<>())
.def("numpy_view", [](py::object &obj) {
py::print("ArrayClass::numpy_view()");
ArrayClass &a = obj.cast<ArrayClass&>();
return py::array_t<int>({2}, {4}, a.data, obj);
}
);
}); });

View File

@ -1,4 +1,5 @@
import pytest import pytest
import gc
with pytest.suppress(ImportError): with pytest.suppress(ImportError):
import numpy as np import numpy as np
@ -209,3 +210,31 @@ def test_wrap():
A1 = A1.diagonal() A1 = A1.diagonal()
A2 = wrap(A1) A2 = wrap(A1)
assert_references(A1, A2) assert_references(A1, A2)
@pytest.requires_numpy
def test_numpy_view(capture):
from pybind11_tests.array import ArrayClass
with capture:
ac = ArrayClass()
ac_view_1 = ac.numpy_view()
ac_view_2 = ac.numpy_view()
assert np.all(ac_view_1 == np.array([1, 2], dtype=np.int32))
del ac
gc.collect()
assert capture == """
ArrayClass()
ArrayClass::numpy_view()
ArrayClass::numpy_view()
"""
ac_view_1[0] = 4
ac_view_1[1] = 3
assert ac_view_2[0] == 4
assert ac_view_2[1] == 3
with capture:
del ac_view_1
del ac_view_2
gc.collect()
assert capture == """
~ArrayClass()
"""