mirror of
https://github.com/pybind/pybind11.git
synced 2025-03-12 07:49:28 +00:00
array_t overload resolution support
This makes array_t respect overload resolution and noconvert by failing to load when `convert = false` if the src isn't already an array of the correct type.
This commit is contained in:
parent
38fc542f97
commit
c44fe6fda5
@ -688,7 +688,9 @@ template <typename T, int ExtraFlags>
|
||||
struct pyobject_caster<array_t<T, ExtraFlags>> {
|
||||
using type = array_t<T, ExtraFlags>;
|
||||
|
||||
bool load(handle src, bool /* convert */) {
|
||||
bool load(handle src, bool convert) {
|
||||
if (!convert && !type::check_(src))
|
||||
return false;
|
||||
value = type::ensure(src);
|
||||
return static_cast<bool>(value);
|
||||
}
|
||||
|
@ -151,6 +151,34 @@ test_initializer numpy_array([](py::module &m) {
|
||||
);
|
||||
});
|
||||
|
||||
// Overload resolution tests:
|
||||
sm.def("overloaded", [](py::array_t<double>) { return "double"; });
|
||||
sm.def("overloaded", [](py::array_t<float>) { return "float"; });
|
||||
sm.def("overloaded", [](py::array_t<int>) { return "int"; });
|
||||
sm.def("overloaded", [](py::array_t<unsigned short>) { return "unsigned short"; });
|
||||
sm.def("overloaded", [](py::array_t<long long>) { return "long long"; });
|
||||
sm.def("overloaded", [](py::array_t<std::complex<double>>) { return "double complex"; });
|
||||
sm.def("overloaded", [](py::array_t<std::complex<float>>) { return "float complex"; });
|
||||
|
||||
sm.def("overloaded2", [](py::array_t<std::complex<double>>) { return "double complex"; });
|
||||
sm.def("overloaded2", [](py::array_t<double>) { return "double"; });
|
||||
sm.def("overloaded2", [](py::array_t<std::complex<float>>) { return "float complex"; });
|
||||
sm.def("overloaded2", [](py::array_t<float>) { return "float"; });
|
||||
|
||||
// Only accept the exact types:
|
||||
sm.def("overloaded3", [](py::array_t<int>) { return "int"; }, py::arg().noconvert());
|
||||
sm.def("overloaded3", [](py::array_t<double>) { return "double"; }, py::arg().noconvert());
|
||||
|
||||
// Make sure we don't do unsafe coercion (e.g. float to int) when not using forcecast, but
|
||||
// rather that float gets converted via the safe (conversion to double) overload:
|
||||
sm.def("overloaded4", [](py::array_t<long long, 0>) { return "long long"; });
|
||||
sm.def("overloaded4", [](py::array_t<double, 0>) { return "double"; });
|
||||
|
||||
// But we do allow conversion to int if forcecast is enabled (but only if no overload matches
|
||||
// without conversion)
|
||||
sm.def("overloaded5", [](py::array_t<unsigned int>) { return "unsigned int"; });
|
||||
sm.def("overloaded5", [](py::array_t<double>) { return "double"; });
|
||||
|
||||
// Issue 685: ndarray shouldn't go to std::string overload
|
||||
sm.def("issue685", [](std::string) { return "string"; });
|
||||
sm.def("issue685", [](py::array) { return "array"; });
|
||||
|
@ -264,7 +264,60 @@ def test_constructors():
|
||||
assert results["array_t<double>"].dtype == np.float64
|
||||
|
||||
|
||||
@pytest.requires_numpy
|
||||
def test_overload_resolution(msg):
|
||||
from pybind11_tests.array import overloaded, overloaded2, overloaded3, overloaded4, overloaded5
|
||||
|
||||
# Exact overload matches:
|
||||
assert overloaded(np.array([1], dtype='float64')) == 'double'
|
||||
assert overloaded(np.array([1], dtype='float32')) == 'float'
|
||||
assert overloaded(np.array([1], dtype='ushort')) == 'unsigned short'
|
||||
assert overloaded(np.array([1], dtype='intc')) == 'int'
|
||||
assert overloaded(np.array([1], dtype='longlong')) == 'long long'
|
||||
assert overloaded(np.array([1], dtype='complex')) == 'double complex'
|
||||
assert overloaded(np.array([1], dtype='csingle')) == 'float complex'
|
||||
|
||||
# No exact match, should call first convertible version:
|
||||
assert overloaded(np.array([1], dtype='uint8')) == 'double'
|
||||
|
||||
assert overloaded2(np.array([1], dtype='float64')) == 'double'
|
||||
assert overloaded2(np.array([1], dtype='float32')) == 'float'
|
||||
assert overloaded2(np.array([1], dtype='complex64')) == 'float complex'
|
||||
assert overloaded2(np.array([1], dtype='complex128')) == 'double complex'
|
||||
assert overloaded2(np.array([1], dtype='float32')) == 'float'
|
||||
|
||||
assert overloaded3(np.array([1], dtype='float64')) == 'double'
|
||||
assert overloaded3(np.array([1], dtype='intc')) == 'int'
|
||||
expected_exc = """
|
||||
overloaded3(): incompatible function arguments. The following argument types are supported:
|
||||
1. (arg0: numpy.ndarray[int]) -> str
|
||||
2. (arg0: numpy.ndarray[float]) -> str
|
||||
|
||||
Invoked with:"""
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
overloaded3(np.array([1], dtype='uintc'))
|
||||
assert msg(excinfo.value) == expected_exc + " array([1], dtype=uint32)"
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
overloaded3(np.array([1], dtype='float32'))
|
||||
assert msg(excinfo.value) == expected_exc + " array([ 1.], dtype=float32)"
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
overloaded3(np.array([1], dtype='complex'))
|
||||
assert msg(excinfo.value) == expected_exc + " array([ 1.+0.j])"
|
||||
|
||||
# Exact matches:
|
||||
assert overloaded4(np.array([1], dtype='double')) == 'double'
|
||||
assert overloaded4(np.array([1], dtype='longlong')) == 'long long'
|
||||
# Non-exact matches requiring conversion. Since float to integer isn't a
|
||||
# save conversion, it should go to the double overload, but short can go to
|
||||
# either (and so should end up on the first-registered, the long long).
|
||||
assert overloaded4(np.array([1], dtype='float32')) == 'double'
|
||||
assert overloaded4(np.array([1], dtype='short')) == 'long long'
|
||||
|
||||
assert overloaded5(np.array([1], dtype='double')) == 'double'
|
||||
assert overloaded5(np.array([1], dtype='uintc')) == 'unsigned int'
|
||||
assert overloaded5(np.array([1], dtype='float32')) == 'unsigned int'
|
||||
|
||||
|
||||
def test_greedy_string_overload(): # issue 685
|
||||
from pybind11_tests.array import issue685
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user