mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-26 07:02:11 +00:00
Stop py::array_t arguments from accepting arrays that do not match the C- or F-contiguity flags (#2484)
* Stop py::array_t arguments from accepting arrays that do not match the C- or F-contiguity flags * Add trivially-contiguous arrays to the tests
This commit is contained in:
parent
f12ec00d70
commit
9df13835c8
@ -934,7 +934,8 @@ public:
|
|||||||
static bool check_(handle h) {
|
static bool check_(handle h) {
|
||||||
const auto &api = detail::npy_api::get();
|
const auto &api = detail::npy_api::get();
|
||||||
return api.PyArray_Check_(h.ptr())
|
return api.PyArray_Check_(h.ptr())
|
||||||
&& api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
|
&& api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr())
|
||||||
|
&& detail::check_flags(h.ptr(), ExtraFlags & (array::c_style | array::f_style));
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -385,4 +385,42 @@ TEST_SUBMODULE(numpy_array, sm) {
|
|||||||
sm.def("index_using_ellipsis", [](py::array a) {
|
sm.def("index_using_ellipsis", [](py::array a) {
|
||||||
return a[py::make_tuple(0, py::ellipsis(), 0)];
|
return a[py::make_tuple(0, py::ellipsis(), 0)];
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// test_argument_conversions
|
||||||
|
sm.def("accept_double",
|
||||||
|
[](py::array_t<double, 0>) {},
|
||||||
|
py::arg("a"));
|
||||||
|
sm.def("accept_double_forcecast",
|
||||||
|
[](py::array_t<double, py::array::forcecast>) {},
|
||||||
|
py::arg("a"));
|
||||||
|
sm.def("accept_double_c_style",
|
||||||
|
[](py::array_t<double, py::array::c_style>) {},
|
||||||
|
py::arg("a"));
|
||||||
|
sm.def("accept_double_c_style_forcecast",
|
||||||
|
[](py::array_t<double, py::array::forcecast | py::array::c_style>) {},
|
||||||
|
py::arg("a"));
|
||||||
|
sm.def("accept_double_f_style",
|
||||||
|
[](py::array_t<double, py::array::f_style>) {},
|
||||||
|
py::arg("a"));
|
||||||
|
sm.def("accept_double_f_style_forcecast",
|
||||||
|
[](py::array_t<double, py::array::forcecast | py::array::f_style>) {},
|
||||||
|
py::arg("a"));
|
||||||
|
sm.def("accept_double_noconvert",
|
||||||
|
[](py::array_t<double, 0>) {},
|
||||||
|
py::arg("a").noconvert());
|
||||||
|
sm.def("accept_double_forcecast_noconvert",
|
||||||
|
[](py::array_t<double, py::array::forcecast>) {},
|
||||||
|
py::arg("a").noconvert());
|
||||||
|
sm.def("accept_double_c_style_noconvert",
|
||||||
|
[](py::array_t<double, py::array::c_style>) {},
|
||||||
|
py::arg("a").noconvert());
|
||||||
|
sm.def("accept_double_c_style_forcecast_noconvert",
|
||||||
|
[](py::array_t<double, py::array::forcecast | py::array::c_style>) {},
|
||||||
|
py::arg("a").noconvert());
|
||||||
|
sm.def("accept_double_f_style_noconvert",
|
||||||
|
[](py::array_t<double, py::array::f_style>) {},
|
||||||
|
py::arg("a").noconvert());
|
||||||
|
sm.def("accept_double_f_style_forcecast_noconvert",
|
||||||
|
[](py::array_t<double, py::array::forcecast | py::array::f_style>) {},
|
||||||
|
py::arg("a").noconvert());
|
||||||
}
|
}
|
||||||
|
@ -435,6 +435,52 @@ def test_index_using_ellipsis():
|
|||||||
assert a.shape == (6,)
|
assert a.shape == (6,)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("forcecast", [False, True])
|
||||||
|
@pytest.mark.parametrize("contiguity", [None, 'C', 'F'])
|
||||||
|
@pytest.mark.parametrize("noconvert", [False, True])
|
||||||
|
@pytest.mark.filterwarnings(
|
||||||
|
"ignore:Casting complex values to real discards the imaginary part:numpy.ComplexWarning"
|
||||||
|
)
|
||||||
|
def test_argument_conversions(forcecast, contiguity, noconvert):
|
||||||
|
function_name = "accept_double"
|
||||||
|
if contiguity == 'C':
|
||||||
|
function_name += "_c_style"
|
||||||
|
elif contiguity == 'F':
|
||||||
|
function_name += "_f_style"
|
||||||
|
if forcecast:
|
||||||
|
function_name += "_forcecast"
|
||||||
|
if noconvert:
|
||||||
|
function_name += "_noconvert"
|
||||||
|
function = getattr(m, function_name)
|
||||||
|
|
||||||
|
for dtype in [np.dtype('float32'), np.dtype('float64'), np.dtype('complex128')]:
|
||||||
|
for order in ['C', 'F']:
|
||||||
|
for shape in [(2, 2), (1, 3, 1, 1), (1, 1, 1), (0,)]:
|
||||||
|
if not noconvert:
|
||||||
|
# If noconvert is not passed, only complex128 needs to be truncated and
|
||||||
|
# "cannot be safely obtained". So without `forcecast`, the argument shouldn't
|
||||||
|
# be accepted.
|
||||||
|
should_raise = dtype.name == 'complex128' and not forcecast
|
||||||
|
else:
|
||||||
|
# If noconvert is passed, only float64 and the matching order is accepted.
|
||||||
|
# If at most one dimension has a size greater than 1, the array is also
|
||||||
|
# trivially contiguous.
|
||||||
|
trivially_contiguous = sum(1 for d in shape if d > 1) <= 1
|
||||||
|
should_raise = (
|
||||||
|
dtype.name != 'float64' or
|
||||||
|
(contiguity is not None and
|
||||||
|
contiguity != order and
|
||||||
|
not trivially_contiguous)
|
||||||
|
)
|
||||||
|
|
||||||
|
array = np.zeros(shape, dtype=dtype, order=order)
|
||||||
|
if not should_raise:
|
||||||
|
function(array)
|
||||||
|
else:
|
||||||
|
with pytest.raises(TypeError, match="incompatible function arguments"):
|
||||||
|
function(array)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail("env.PYPY")
|
@pytest.mark.xfail("env.PYPY")
|
||||||
def test_dtype_refcount_leak():
|
def test_dtype_refcount_leak():
|
||||||
from sys import getrefcount
|
from sys import getrefcount
|
||||||
|
Loading…
Reference in New Issue
Block a user