From 9df13835c851753b7d385717eb79d8a06906e0c4 Mon Sep 17 00:00:00 2001 From: Yannick Jadoul Date: Tue, 15 Sep 2020 14:50:51 +0200 Subject: [PATCH] Stop py::array_t arguments from accepting arrays that do not match the C- or F-contiguity flags (#2484) * Stop py::array_t arguments from accepting arrays that do not match the C- or F-contiguity flags * Add trivially-contiguous arrays to the tests --- include/pybind11/numpy.h | 3 ++- tests/test_numpy_array.cpp | 38 +++++++++++++++++++++++++++++++ tests/test_numpy_array.py | 46 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 1 deletion(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 0192a8b17..c0b38ce20 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -934,7 +934,8 @@ public: static bool check_(handle h) { const auto &api = detail::npy_api::get(); return api.PyArray_Check_(h.ptr()) - && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of().ptr()); + && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of().ptr()) + && detail::check_flags(h.ptr(), ExtraFlags & (array::c_style | array::f_style)); } protected: diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index e37beb5a5..caa052549 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -385,4 +385,42 @@ TEST_SUBMODULE(numpy_array, sm) { sm.def("index_using_ellipsis", [](py::array a) { return a[py::make_tuple(0, py::ellipsis(), 0)]; }); + + // test_argument_conversions + sm.def("accept_double", + [](py::array_t) {}, + py::arg("a")); + sm.def("accept_double_forcecast", + [](py::array_t) {}, + py::arg("a")); + sm.def("accept_double_c_style", + [](py::array_t) {}, + py::arg("a")); + sm.def("accept_double_c_style_forcecast", + [](py::array_t) {}, + py::arg("a")); + sm.def("accept_double_f_style", + [](py::array_t) {}, + py::arg("a")); + sm.def("accept_double_f_style_forcecast", + [](py::array_t) {}, + py::arg("a")); + sm.def("accept_double_noconvert", + [](py::array_t) {}, + py::arg("a").noconvert()); + sm.def("accept_double_forcecast_noconvert", + [](py::array_t) {}, + py::arg("a").noconvert()); + sm.def("accept_double_c_style_noconvert", + [](py::array_t) {}, + py::arg("a").noconvert()); + sm.def("accept_double_c_style_forcecast_noconvert", + [](py::array_t) {}, + py::arg("a").noconvert()); + sm.def("accept_double_f_style_noconvert", + [](py::array_t) {}, + py::arg("a").noconvert()); + sm.def("accept_double_f_style_forcecast_noconvert", + [](py::array_t) {}, + py::arg("a").noconvert()); } diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index ad3ca58c1..a36e707c1 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -435,6 +435,52 @@ def test_index_using_ellipsis(): assert a.shape == (6,) +@pytest.mark.parametrize("forcecast", [False, True]) +@pytest.mark.parametrize("contiguity", [None, 'C', 'F']) +@pytest.mark.parametrize("noconvert", [False, True]) +@pytest.mark.filterwarnings( + "ignore:Casting complex values to real discards the imaginary part:numpy.ComplexWarning" +) +def test_argument_conversions(forcecast, contiguity, noconvert): + function_name = "accept_double" + if contiguity == 'C': + function_name += "_c_style" + elif contiguity == 'F': + function_name += "_f_style" + if forcecast: + function_name += "_forcecast" + if noconvert: + function_name += "_noconvert" + function = getattr(m, function_name) + + for dtype in [np.dtype('float32'), np.dtype('float64'), np.dtype('complex128')]: + for order in ['C', 'F']: + for shape in [(2, 2), (1, 3, 1, 1), (1, 1, 1), (0,)]: + if not noconvert: + # If noconvert is not passed, only complex128 needs to be truncated and + # "cannot be safely obtained". So without `forcecast`, the argument shouldn't + # be accepted. + should_raise = dtype.name == 'complex128' and not forcecast + else: + # If noconvert is passed, only float64 and the matching order is accepted. + # If at most one dimension has a size greater than 1, the array is also + # trivially contiguous. + trivially_contiguous = sum(1 for d in shape if d > 1) <= 1 + should_raise = ( + dtype.name != 'float64' or + (contiguity is not None and + contiguity != order and + not trivially_contiguous) + ) + + array = np.zeros(shape, dtype=dtype, order=order) + if not should_raise: + function(array) + else: + with pytest.raises(TypeError, match="incompatible function arguments"): + function(array) + + @pytest.mark.xfail("env.PYPY") def test_dtype_refcount_leak(): from sys import getrefcount