array_t overload resolution support

This makes array_t respect overload resolution and noconvert by failing
to load when `convert = false` if the src isn't already an array of the
correct type.
This commit is contained in:
Jason Rhinelander 2017-02-26 18:03:00 -05:00
parent 38fc542f97
commit c44fe6fda5
3 changed files with 85 additions and 2 deletions

View File

@ -688,7 +688,9 @@ template <typename T, int ExtraFlags>
struct pyobject_caster<array_t<T, ExtraFlags>> {
using type = array_t<T, ExtraFlags>;
bool load(handle src, bool /* convert */) {
bool load(handle src, bool convert) {
if (!convert && !type::check_(src))
return false;
value = type::ensure(src);
return static_cast<bool>(value);
}

View File

@ -151,6 +151,34 @@ test_initializer numpy_array([](py::module &m) {
);
});
// Overload resolution tests:
sm.def("overloaded", [](py::array_t<double>) { return "double"; });
sm.def("overloaded", [](py::array_t<float>) { return "float"; });
sm.def("overloaded", [](py::array_t<int>) { return "int"; });
sm.def("overloaded", [](py::array_t<unsigned short>) { return "unsigned short"; });
sm.def("overloaded", [](py::array_t<long long>) { return "long long"; });
sm.def("overloaded", [](py::array_t<std::complex<double>>) { return "double complex"; });
sm.def("overloaded", [](py::array_t<std::complex<float>>) { return "float complex"; });
sm.def("overloaded2", [](py::array_t<std::complex<double>>) { return "double complex"; });
sm.def("overloaded2", [](py::array_t<double>) { return "double"; });
sm.def("overloaded2", [](py::array_t<std::complex<float>>) { return "float complex"; });
sm.def("overloaded2", [](py::array_t<float>) { return "float"; });
// Only accept the exact types:
sm.def("overloaded3", [](py::array_t<int>) { return "int"; }, py::arg().noconvert());
sm.def("overloaded3", [](py::array_t<double>) { return "double"; }, py::arg().noconvert());
// Make sure we don't do unsafe coercion (e.g. float to int) when not using forcecast, but
// rather that float gets converted via the safe (conversion to double) overload:
sm.def("overloaded4", [](py::array_t<long long, 0>) { return "long long"; });
sm.def("overloaded4", [](py::array_t<double, 0>) { return "double"; });
// But we do allow conversion to int if forcecast is enabled (but only if no overload matches
// without conversion)
sm.def("overloaded5", [](py::array_t<unsigned int>) { return "unsigned int"; });
sm.def("overloaded5", [](py::array_t<double>) { return "double"; });
// Issue 685: ndarray shouldn't go to std::string overload
sm.def("issue685", [](std::string) { return "string"; });
sm.def("issue685", [](py::array) { return "array"; });

View File

@ -264,7 +264,60 @@ def test_constructors():
assert results["array_t<double>"].dtype == np.float64
@pytest.requires_numpy
def test_overload_resolution(msg):
from pybind11_tests.array import overloaded, overloaded2, overloaded3, overloaded4, overloaded5
# Exact overload matches:
assert overloaded(np.array([1], dtype='float64')) == 'double'
assert overloaded(np.array([1], dtype='float32')) == 'float'
assert overloaded(np.array([1], dtype='ushort')) == 'unsigned short'
assert overloaded(np.array([1], dtype='intc')) == 'int'
assert overloaded(np.array([1], dtype='longlong')) == 'long long'
assert overloaded(np.array([1], dtype='complex')) == 'double complex'
assert overloaded(np.array([1], dtype='csingle')) == 'float complex'
# No exact match, should call first convertible version:
assert overloaded(np.array([1], dtype='uint8')) == 'double'
assert overloaded2(np.array([1], dtype='float64')) == 'double'
assert overloaded2(np.array([1], dtype='float32')) == 'float'
assert overloaded2(np.array([1], dtype='complex64')) == 'float complex'
assert overloaded2(np.array([1], dtype='complex128')) == 'double complex'
assert overloaded2(np.array([1], dtype='float32')) == 'float'
assert overloaded3(np.array([1], dtype='float64')) == 'double'
assert overloaded3(np.array([1], dtype='intc')) == 'int'
expected_exc = """
overloaded3(): incompatible function arguments. The following argument types are supported:
1. (arg0: numpy.ndarray[int]) -> str
2. (arg0: numpy.ndarray[float]) -> str
Invoked with:"""
with pytest.raises(TypeError) as excinfo:
overloaded3(np.array([1], dtype='uintc'))
assert msg(excinfo.value) == expected_exc + " array([1], dtype=uint32)"
with pytest.raises(TypeError) as excinfo:
overloaded3(np.array([1], dtype='float32'))
assert msg(excinfo.value) == expected_exc + " array([ 1.], dtype=float32)"
with pytest.raises(TypeError) as excinfo:
overloaded3(np.array([1], dtype='complex'))
assert msg(excinfo.value) == expected_exc + " array([ 1.+0.j])"
# Exact matches:
assert overloaded4(np.array([1], dtype='double')) == 'double'
assert overloaded4(np.array([1], dtype='longlong')) == 'long long'
# Non-exact matches requiring conversion. Since float to integer isn't a
# save conversion, it should go to the double overload, but short can go to
# either (and so should end up on the first-registered, the long long).
assert overloaded4(np.array([1], dtype='float32')) == 'double'
assert overloaded4(np.array([1], dtype='short')) == 'long long'
assert overloaded5(np.array([1], dtype='double')) == 'double'
assert overloaded5(np.array([1], dtype='uintc')) == 'unsigned int'
assert overloaded5(np.array([1], dtype='float32')) == 'unsigned int'
def test_greedy_string_overload(): # issue 685
from pybind11_tests.array import issue685