mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-25 22:52:01 +00:00
Stop py::array_t arguments from accepting arrays that do not match the C- or F-contiguity flags (#2484)
* Stop py::array_t arguments from accepting arrays that do not match the C- or F-contiguity flags * Add trivially-contiguous arrays to the tests
This commit is contained in:
parent
f12ec00d70
commit
9df13835c8
@ -934,7 +934,8 @@ public:
|
||||
static bool check_(handle h) {
|
||||
const auto &api = detail::npy_api::get();
|
||||
return api.PyArray_Check_(h.ptr())
|
||||
&& api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
|
||||
&& api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr())
|
||||
&& detail::check_flags(h.ptr(), ExtraFlags & (array::c_style | array::f_style));
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -385,4 +385,42 @@ TEST_SUBMODULE(numpy_array, sm) {
|
||||
sm.def("index_using_ellipsis", [](py::array a) {
|
||||
return a[py::make_tuple(0, py::ellipsis(), 0)];
|
||||
});
|
||||
|
||||
// test_argument_conversions
|
||||
sm.def("accept_double",
|
||||
[](py::array_t<double, 0>) {},
|
||||
py::arg("a"));
|
||||
sm.def("accept_double_forcecast",
|
||||
[](py::array_t<double, py::array::forcecast>) {},
|
||||
py::arg("a"));
|
||||
sm.def("accept_double_c_style",
|
||||
[](py::array_t<double, py::array::c_style>) {},
|
||||
py::arg("a"));
|
||||
sm.def("accept_double_c_style_forcecast",
|
||||
[](py::array_t<double, py::array::forcecast | py::array::c_style>) {},
|
||||
py::arg("a"));
|
||||
sm.def("accept_double_f_style",
|
||||
[](py::array_t<double, py::array::f_style>) {},
|
||||
py::arg("a"));
|
||||
sm.def("accept_double_f_style_forcecast",
|
||||
[](py::array_t<double, py::array::forcecast | py::array::f_style>) {},
|
||||
py::arg("a"));
|
||||
sm.def("accept_double_noconvert",
|
||||
[](py::array_t<double, 0>) {},
|
||||
py::arg("a").noconvert());
|
||||
sm.def("accept_double_forcecast_noconvert",
|
||||
[](py::array_t<double, py::array::forcecast>) {},
|
||||
py::arg("a").noconvert());
|
||||
sm.def("accept_double_c_style_noconvert",
|
||||
[](py::array_t<double, py::array::c_style>) {},
|
||||
py::arg("a").noconvert());
|
||||
sm.def("accept_double_c_style_forcecast_noconvert",
|
||||
[](py::array_t<double, py::array::forcecast | py::array::c_style>) {},
|
||||
py::arg("a").noconvert());
|
||||
sm.def("accept_double_f_style_noconvert",
|
||||
[](py::array_t<double, py::array::f_style>) {},
|
||||
py::arg("a").noconvert());
|
||||
sm.def("accept_double_f_style_forcecast_noconvert",
|
||||
[](py::array_t<double, py::array::forcecast | py::array::f_style>) {},
|
||||
py::arg("a").noconvert());
|
||||
}
|
||||
|
@ -435,6 +435,52 @@ def test_index_using_ellipsis():
|
||||
assert a.shape == (6,)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("forcecast", [False, True])
|
||||
@pytest.mark.parametrize("contiguity", [None, 'C', 'F'])
|
||||
@pytest.mark.parametrize("noconvert", [False, True])
|
||||
@pytest.mark.filterwarnings(
|
||||
"ignore:Casting complex values to real discards the imaginary part:numpy.ComplexWarning"
|
||||
)
|
||||
def test_argument_conversions(forcecast, contiguity, noconvert):
|
||||
function_name = "accept_double"
|
||||
if contiguity == 'C':
|
||||
function_name += "_c_style"
|
||||
elif contiguity == 'F':
|
||||
function_name += "_f_style"
|
||||
if forcecast:
|
||||
function_name += "_forcecast"
|
||||
if noconvert:
|
||||
function_name += "_noconvert"
|
||||
function = getattr(m, function_name)
|
||||
|
||||
for dtype in [np.dtype('float32'), np.dtype('float64'), np.dtype('complex128')]:
|
||||
for order in ['C', 'F']:
|
||||
for shape in [(2, 2), (1, 3, 1, 1), (1, 1, 1), (0,)]:
|
||||
if not noconvert:
|
||||
# If noconvert is not passed, only complex128 needs to be truncated and
|
||||
# "cannot be safely obtained". So without `forcecast`, the argument shouldn't
|
||||
# be accepted.
|
||||
should_raise = dtype.name == 'complex128' and not forcecast
|
||||
else:
|
||||
# If noconvert is passed, only float64 and the matching order is accepted.
|
||||
# If at most one dimension has a size greater than 1, the array is also
|
||||
# trivially contiguous.
|
||||
trivially_contiguous = sum(1 for d in shape if d > 1) <= 1
|
||||
should_raise = (
|
||||
dtype.name != 'float64' or
|
||||
(contiguity is not None and
|
||||
contiguity != order and
|
||||
not trivially_contiguous)
|
||||
)
|
||||
|
||||
array = np.zeros(shape, dtype=dtype, order=order)
|
||||
if not should_raise:
|
||||
function(array)
|
||||
else:
|
||||
with pytest.raises(TypeError, match="incompatible function arguments"):
|
||||
function(array)
|
||||
|
||||
|
||||
@pytest.mark.xfail("env.PYPY")
|
||||
def test_dtype_refcount_leak():
|
||||
from sys import getrefcount
|
||||
|
Loading…
Reference in New Issue
Block a user