Stop py::array_t arguments from accepting arrays that do not match the C- or F-contiguity flags (#2484)

* Stop py::array_t arguments from accepting arrays that do not match the C- or F-contiguity flags

* Add trivially-contiguous arrays to the tests
This commit is contained in:
Yannick Jadoul 2020-09-15 14:50:51 +02:00 committed by GitHub
parent f12ec00d70
commit 9df13835c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 86 additions and 1 deletions

View File

@ -934,7 +934,8 @@ public:
static bool check_(handle h) {
const auto &api = detail::npy_api::get();
return api.PyArray_Check_(h.ptr())
&& api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
&& api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr())
&& detail::check_flags(h.ptr(), ExtraFlags & (array::c_style | array::f_style));
}
protected:

View File

@ -385,4 +385,42 @@ TEST_SUBMODULE(numpy_array, sm) {
sm.def("index_using_ellipsis", [](py::array a) {
return a[py::make_tuple(0, py::ellipsis(), 0)];
});
// test_argument_conversions
sm.def("accept_double",
[](py::array_t<double, 0>) {},
py::arg("a"));
sm.def("accept_double_forcecast",
[](py::array_t<double, py::array::forcecast>) {},
py::arg("a"));
sm.def("accept_double_c_style",
[](py::array_t<double, py::array::c_style>) {},
py::arg("a"));
sm.def("accept_double_c_style_forcecast",
[](py::array_t<double, py::array::forcecast | py::array::c_style>) {},
py::arg("a"));
sm.def("accept_double_f_style",
[](py::array_t<double, py::array::f_style>) {},
py::arg("a"));
sm.def("accept_double_f_style_forcecast",
[](py::array_t<double, py::array::forcecast | py::array::f_style>) {},
py::arg("a"));
sm.def("accept_double_noconvert",
[](py::array_t<double, 0>) {},
py::arg("a").noconvert());
sm.def("accept_double_forcecast_noconvert",
[](py::array_t<double, py::array::forcecast>) {},
py::arg("a").noconvert());
sm.def("accept_double_c_style_noconvert",
[](py::array_t<double, py::array::c_style>) {},
py::arg("a").noconvert());
sm.def("accept_double_c_style_forcecast_noconvert",
[](py::array_t<double, py::array::forcecast | py::array::c_style>) {},
py::arg("a").noconvert());
sm.def("accept_double_f_style_noconvert",
[](py::array_t<double, py::array::f_style>) {},
py::arg("a").noconvert());
sm.def("accept_double_f_style_forcecast_noconvert",
[](py::array_t<double, py::array::forcecast | py::array::f_style>) {},
py::arg("a").noconvert());
}

View File

@ -435,6 +435,52 @@ def test_index_using_ellipsis():
assert a.shape == (6,)
@pytest.mark.parametrize("forcecast", [False, True])
@pytest.mark.parametrize("contiguity", [None, 'C', 'F'])
@pytest.mark.parametrize("noconvert", [False, True])
@pytest.mark.filterwarnings(
"ignore:Casting complex values to real discards the imaginary part:numpy.ComplexWarning"
)
def test_argument_conversions(forcecast, contiguity, noconvert):
function_name = "accept_double"
if contiguity == 'C':
function_name += "_c_style"
elif contiguity == 'F':
function_name += "_f_style"
if forcecast:
function_name += "_forcecast"
if noconvert:
function_name += "_noconvert"
function = getattr(m, function_name)
for dtype in [np.dtype('float32'), np.dtype('float64'), np.dtype('complex128')]:
for order in ['C', 'F']:
for shape in [(2, 2), (1, 3, 1, 1), (1, 1, 1), (0,)]:
if not noconvert:
# If noconvert is not passed, only complex128 needs to be truncated and
# "cannot be safely obtained". So without `forcecast`, the argument shouldn't
# be accepted.
should_raise = dtype.name == 'complex128' and not forcecast
else:
# If noconvert is passed, only float64 and the matching order is accepted.
# If at most one dimension has a size greater than 1, the array is also
# trivially contiguous.
trivially_contiguous = sum(1 for d in shape if d > 1) <= 1
should_raise = (
dtype.name != 'float64' or
(contiguity is not None and
contiguity != order and
not trivially_contiguous)
)
array = np.zeros(shape, dtype=dtype, order=order)
if not should_raise:
function(array)
else:
with pytest.raises(TypeError, match="incompatible function arguments"):
function(array)
@pytest.mark.xfail("env.PYPY")
def test_dtype_refcount_leak():
from sys import getrefcount