From fac7c0945801bdc37234a9559bfe86ac2f862059 Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Thu, 13 Oct 2016 10:37:52 +0200 Subject: [PATCH] NumPy "base" feature: integrated feedback by @aldanor --- include/pybind11/numpy.h | 2 +- tests/test_numpy_array.cpp | 15 +++++++++++++++ tests/test_numpy_array.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 1 deletion(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 445a6368b..a99c72eee 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -157,7 +157,7 @@ NAMESPACE_END(detail) #define PyArrayDescr_GET_(ptr, attr) \ (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr) #define PyArray_FLAGS_(ptr) \ - (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags) + PyArray_GET_(ptr, flags) #define PyArray_CHKFLAGS_(ptr, flag) \ (flag == (PyArray_FLAGS_(ptr) & flag)) diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index a6bf50de0..ec4ddacb9 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -109,4 +109,19 @@ test_initializer numpy_array([](py::module &m) { a ); }); + + struct ArrayClass { + int data[2] = { 1, 2 }; + ArrayClass() { py::print("ArrayClass()"); } + ~ArrayClass() { py::print("~ArrayClass()"); } + }; + + py::class_(sm, "ArrayClass") + .def(py::init<>()) + .def("numpy_view", [](py::object &obj) { + py::print("ArrayClass::numpy_view()"); + ArrayClass &a = obj.cast(); + return py::array_t({2}, {4}, a.data, obj); + } + ); }); diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index 52350f690..ae1954a65 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -1,4 +1,5 @@ import pytest +import gc with pytest.suppress(ImportError): import numpy as np @@ -209,3 +210,31 @@ def test_wrap(): A1 = A1.diagonal() A2 = wrap(A1) assert_references(A1, A2) + + +@pytest.requires_numpy +def test_numpy_view(capture): + from pybind11_tests.array import ArrayClass + with capture: + ac = ArrayClass() + ac_view_1 = ac.numpy_view() + ac_view_2 = ac.numpy_view() + assert np.all(ac_view_1 == np.array([1, 2], dtype=np.int32)) + del ac + gc.collect() + assert capture == """ + ArrayClass() + ArrayClass::numpy_view() + ArrayClass::numpy_view() + """ + ac_view_1[0] = 4 + ac_view_1[1] = 3 + assert ac_view_2[0] == 4 + assert ac_view_2[1] == 3 + with capture: + del ac_view_1 + del ac_view_2 + gc.collect() + assert capture == """ + ~ArrayClass() + """