Prefer non-converting argument overloads

This changes the function dispatching code for overloaded functions into
a two-pass procedure where we first try all overloads with
`convert=false` for all arguments.  If no function calls succeeds in the
first pass, we then try a second pass where we allow arguments to have
`convert=true` (unless, of course, the argument was explicitly specified
with `py::arg().noconvert()`).

For non-overloaded methods, the two-pass procedure is skipped (we just
make the overload-allowed call).  The second pass is also skipped if it
would result in the same thing (i.e. where all arguments are
`.noconvert()` arguments).
This commit is contained in:
Jason Rhinelander 2017-02-03 18:25:34 -05:00
parent abc29cad02
commit e550589b42
5 changed files with 138 additions and 23 deletions

View File

@ -372,3 +372,33 @@ name, i.e. by specifying ``py::arg().noconvert()``.
enable no-convert behaviour for just one of several arguments, you will
need to specify a ``py::arg()`` annotation for each argument with the
no-convert argument modified to ``py::arg().noconvert()``.
Overload resolution order
=========================
When a function or method with multiple overloads is called from Python,
pybind11 determines which overload to call in two passes. The first pass
attempts to call each overload without allowing argument conversion (as if
every argument had been specified as ``py::arg().noconvert()`` as decribed
above).
If no overload succeeds in the no-conversion first pass, a second pass is
attempted in which argument conversion is allowed (except where prohibited via
an explicit ``py::arg().noconvert()`` attribute in the function definition).
If the second pass also fails a ``TypeError`` is raised.
Within each pass, overloads are tried in the order they were registered with
pybind11.
What this means in practice is that pybind11 will prefer any overload that does
not require conversion of arguments to an overload that does, but otherwise prefers
earlier-defined overloads to later-defined ones.
.. note::
pybind11 does *not* further prioritize based on the number/pattern of
overloaded arguments. That is, pybind11 does not prioritize a function
requiring one conversion over one requiring three, but only prioritizes
overloads requiring no conversion at all to overloads that require
conversion of at least one argument.

View File

@ -394,6 +394,15 @@ protected:
result = PYBIND11_TRY_NEXT_OVERLOAD;
try {
// We do this in two passes: in the first pass, we load arguments with `convert=false`;
// in the second, we allow conversion (except for arguments with an explicit
// py::arg().noconvert()). This lets us prefer calls without conversion, with
// conversion as a fallback.
std::vector<function_call> second_pass;
// However, if there are no overloads, we can just skip the no-convert pass entirely
const bool overloaded = it != nullptr && it->next != nullptr;
for (; it != nullptr; it = it->next) {
/* For each overload:
@ -525,6 +534,15 @@ protected:
pybind11_fail("Internal error: function call dispatcher inserted wrong number of arguments!");
#endif
std::vector<bool> second_pass_convert;
if (overloaded) {
// We're in the first no-convert pass, so swap out the conversion flags for a
// set of all-false flags. If the call fails, we'll swap the flags back in for
// the conversion-allowed call below.
second_pass_convert.resize(func.nargs, false);
call.args_convert.swap(second_pass_convert);
}
// 6. Call the function.
try {
result = func.impl(call);
@ -535,6 +553,34 @@ protected:
if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD)
break;
if (overloaded) {
// The (overloaded) call failed; if the call has at least one argument that
// permits conversion (i.e. it hasn't been explicitly specified `.noconvert()`)
// then add this call to the list of second pass overloads to try.
for (size_t i = func.is_method ? 1 : 0; i < pos_args; i++) {
if (second_pass_convert[i]) {
// Found one: swap the converting flags back in and store the call for
// the second pass.
call.args_convert.swap(second_pass_convert);
second_pass.push_back(std::move(call));
break;
}
}
}
}
if (overloaded && !second_pass.empty() && result.ptr() == PYBIND11_TRY_NEXT_OVERLOAD) {
// The no-conversion pass finished without success, try again with conversion allowed
for (auto &call : second_pass) {
try {
result = call.func.impl(call);
} catch (reference_cast_error &) {
result = PYBIND11_TRY_NEXT_OVERLOAD;
}
if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD)
break;
}
}
} catch (error_already_set &e) {
e.restore();

View File

@ -50,10 +50,14 @@ public:
int *internal4() { return &value; } // return by pointer
const int *internal5() { return &value; } // return by const pointer
py::str overloaded(int, float) { return "(int, float)"; }
py::str overloaded(float, int) { return "(float, int)"; }
py::str overloaded(int, float) const { return "(int, float) const"; }
py::str overloaded(float, int) const { return "(float, int) const"; }
py::str overloaded(int, float) { return "(int, float)"; }
py::str overloaded(float, int) { return "(float, int)"; }
py::str overloaded(int, int) { return "(int, int)"; }
py::str overloaded(float, float) { return "(float, float)"; }
py::str overloaded(int, float) const { return "(int, float) const"; }
py::str overloaded(float, int) const { return "(float, int) const"; }
py::str overloaded(int, int) const { return "(int, int) const"; }
py::str overloaded(float, float) const { return "(float, float) const"; }
int value = 0;
};
@ -100,6 +104,7 @@ class CppDerivedDynamicClass : public DynamicClass { };
// py::arg/py::arg_v testing: these arguments just record their argument when invoked
class ArgInspector1 { public: std::string arg = "(default arg inspector 1)"; };
class ArgInspector2 { public: std::string arg = "(default arg inspector 2)"; };
class ArgAlwaysConverts { };
namespace pybind11 { namespace detail {
template <> struct type_caster<ArgInspector1> {
public:
@ -131,6 +136,18 @@ public:
return str(src.arg).release();
}
};
template <> struct type_caster<ArgAlwaysConverts> {
public:
PYBIND11_TYPE_CASTER(ArgAlwaysConverts, _("ArgAlwaysConverts"));
bool load(handle, bool convert) {
return convert;
}
static handle cast(const ArgAlwaysConverts &, return_value_policy, handle) {
return py::none();
}
};
}}
test_initializer methods_and_attributes([](py::module &m) {
@ -159,15 +176,25 @@ test_initializer methods_and_attributes([](py::module &m) {
.def("internal4", &ExampleMandA::internal4)
.def("internal5", &ExampleMandA::internal5)
#if defined(PYBIND11_OVERLOAD_CAST)
.def("overloaded", py::overload_cast<int, float>(&ExampleMandA::overloaded))
.def("overloaded", py::overload_cast<float, int>(&ExampleMandA::overloaded))
.def("overloaded_const", py::overload_cast<int, float>(&ExampleMandA::overloaded, py::const_))
.def("overloaded_const", py::overload_cast<float, int>(&ExampleMandA::overloaded, py::const_))
.def("overloaded", py::overload_cast<int, float>(&ExampleMandA::overloaded))
.def("overloaded", py::overload_cast<float, int>(&ExampleMandA::overloaded))
.def("overloaded", py::overload_cast<int, int>(&ExampleMandA::overloaded))
.def("overloaded", py::overload_cast<float, float>(&ExampleMandA::overloaded))
.def("overloaded_float", py::overload_cast<float, float>(&ExampleMandA::overloaded))
.def("overloaded_const", py::overload_cast<int, float>(&ExampleMandA::overloaded, py::const_))
.def("overloaded_const", py::overload_cast<float, int>(&ExampleMandA::overloaded, py::const_))
.def("overloaded_const", py::overload_cast<int, int>(&ExampleMandA::overloaded, py::const_))
.def("overloaded_const", py::overload_cast<float, float>(&ExampleMandA::overloaded, py::const_))
#else
.def("overloaded", static_cast<py::str (ExampleMandA::*)(int, float)>(&ExampleMandA::overloaded))
.def("overloaded", static_cast<py::str (ExampleMandA::*)(float, int)>(&ExampleMandA::overloaded))
.def("overloaded_const", static_cast<py::str (ExampleMandA::*)(int, float) const>(&ExampleMandA::overloaded))
.def("overloaded_const", static_cast<py::str (ExampleMandA::*)(float, int) const>(&ExampleMandA::overloaded))
.def("overloaded", static_cast<py::str (ExampleMandA::*)(int, float)>(&ExampleMandA::overloaded))
.def("overloaded", static_cast<py::str (ExampleMandA::*)(float, int)>(&ExampleMandA::overloaded))
.def("overloaded", static_cast<py::str (ExampleMandA::*)(int, int)>(&ExampleMandA::overloaded))
.def("overloaded", static_cast<py::str (ExampleMandA::*)(float, float)>(&ExampleMandA::overloaded))
.def("overloaded_float", static_cast<py::str (ExampleMandA::*)(float, float)>(&ExampleMandA::overloaded))
.def("overloaded_const", static_cast<py::str (ExampleMandA::*)(int, float) const>(&ExampleMandA::overloaded))
.def("overloaded_const", static_cast<py::str (ExampleMandA::*)(float, int) const>(&ExampleMandA::overloaded))
.def("overloaded_const", static_cast<py::str (ExampleMandA::*)(int, int) const>(&ExampleMandA::overloaded))
.def("overloaded_const", static_cast<py::str (ExampleMandA::*)(float, float) const>(&ExampleMandA::overloaded))
#endif
.def("__str__", &ExampleMandA::toString)
.def_readwrite("value", &ExampleMandA::value);
@ -220,22 +247,25 @@ test_initializer methods_and_attributes([](py::module &m) {
.def(py::init());
#endif
// Test converting. The ArgAlwaysConverts is just there to make the first no-conversion pass
// fail so that our call always ends up happening via the second dispatch (the one that allows
// some conversion).
class ArgInspector {
public:
ArgInspector1 f(ArgInspector1 a) { return a; }
std::string g(ArgInspector1 a, const ArgInspector1 &b, int c, ArgInspector2 *d) {
ArgInspector1 f(ArgInspector1 a, ArgAlwaysConverts) { return a; }
std::string g(ArgInspector1 a, const ArgInspector1 &b, int c, ArgInspector2 *d, ArgAlwaysConverts) {
return a.arg + "\n" + b.arg + "\n" + std::to_string(c) + "\n" + d->arg;
}
static ArgInspector2 h(ArgInspector2 a) { return a; }
static ArgInspector2 h(ArgInspector2 a, ArgAlwaysConverts) { return a; }
};
py::class_<ArgInspector>(m, "ArgInspector")
.def(py::init<>())
.def("f", &ArgInspector::f)
.def("g", &ArgInspector::g, "a"_a.noconvert(), "b"_a, "c"_a.noconvert()=13, "d"_a=ArgInspector2())
.def_static("h", &ArgInspector::h, py::arg().noconvert())
.def("f", &ArgInspector::f, py::arg(), py::arg() = ArgAlwaysConverts())
.def("g", &ArgInspector::g, "a"_a.noconvert(), "b"_a, "c"_a.noconvert()=13, "d"_a=ArgInspector2(), py::arg() = ArgAlwaysConverts())
.def_static("h", &ArgInspector::h, py::arg().noconvert(), py::arg() = ArgAlwaysConverts())
;
m.def("arg_inspect_func", [](ArgInspector2 a, ArgInspector1 b) { return a.arg + "\n" + b.arg; },
py::arg().noconvert(false), py::arg_v(nullptr, ArgInspector1()).noconvert(true));
m.def("arg_inspect_func", [](ArgInspector2 a, ArgInspector1 b, ArgAlwaysConverts) { return a.arg + "\n" + b.arg; },
py::arg().noconvert(false), py::arg_v(nullptr, ArgInspector1()).noconvert(true), py::arg() = ArgAlwaysConverts());
m.def("floats_preferred", [](double f) { return 0.5 * f; }, py::arg("f"));
m.def("floats_only", [](double f) { return 0.5 * f; }, py::arg("f").noconvert());

View File

@ -33,8 +33,16 @@ def test_methods_and_attributes():
assert instance1.overloaded(1, 1.0) == "(int, float)"
assert instance1.overloaded(2.0, 2) == "(float, int)"
assert instance1.overloaded_const(3, 3.0) == "(int, float) const"
assert instance1.overloaded_const(4.0, 4) == "(float, int) const"
assert instance1.overloaded(3, 3) == "(int, int)"
assert instance1.overloaded(4., 4.) == "(float, float)"
assert instance1.overloaded_const(5, 5.0) == "(int, float) const"
assert instance1.overloaded_const(6.0, 6) == "(float, int) const"
assert instance1.overloaded_const(7, 7) == "(int, int) const"
assert instance1.overloaded_const(8., 8.) == "(float, float) const"
assert instance1.overloaded_float(1, 1) == "(float, float)"
assert instance1.overloaded_float(1, 1.) == "(float, float)"
assert instance1.overloaded_float(1., 1) == "(float, float)"
assert instance1.overloaded_float(1., 1.) == "(float, float)"
assert instance1.value == 320
instance1.value = 100

View File

@ -28,9 +28,10 @@ def test_pointers(msg):
print_opaque_list, return_null_str, get_null_str_value,
return_unique_ptr, ConstructorStats)
living_before = ConstructorStats.get(ExampleMandA).alive()
assert get_void_ptr_value(return_void_ptr()) == 0x1234
assert get_void_ptr_value(ExampleMandA()) # Should also work for other C++ types
assert ConstructorStats.get(ExampleMandA).alive() == 0
assert ConstructorStats.get(ExampleMandA).alive() == living_before
with pytest.raises(TypeError) as excinfo:
get_void_ptr_value([1, 2, 3]) # This should not work