From e550589b42cb130da2b7fec22d3aae368b851c26 Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Fri, 3 Feb 2017 18:25:34 -0500 Subject: [PATCH] Prefer non-converting argument overloads This changes the function dispatching code for overloaded functions into a two-pass procedure where we first try all overloads with `convert=false` for all arguments. If no function calls succeeds in the first pass, we then try a second pass where we allow arguments to have `convert=true` (unless, of course, the argument was explicitly specified with `py::arg().noconvert()`). For non-overloaded methods, the two-pass procedure is skipped (we just make the overload-allowed call). The second pass is also skipped if it would result in the same thing (i.e. where all arguments are `.noconvert()` arguments). --- docs/advanced/functions.rst | 30 ++++++++++++ include/pybind11/pybind11.h | 46 ++++++++++++++++++ tests/test_methods_and_attributes.cpp | 70 +++++++++++++++++++-------- tests/test_methods_and_attributes.py | 12 ++++- tests/test_opaque_types.py | 3 +- 5 files changed, 138 insertions(+), 23 deletions(-) diff --git a/docs/advanced/functions.rst b/docs/advanced/functions.rst index e67d7bc37..4a562bd5c 100644 --- a/docs/advanced/functions.rst +++ b/docs/advanced/functions.rst @@ -372,3 +372,33 @@ name, i.e. by specifying ``py::arg().noconvert()``. enable no-convert behaviour for just one of several arguments, you will need to specify a ``py::arg()`` annotation for each argument with the no-convert argument modified to ``py::arg().noconvert()``. + +Overload resolution order +========================= + +When a function or method with multiple overloads is called from Python, +pybind11 determines which overload to call in two passes. The first pass +attempts to call each overload without allowing argument conversion (as if +every argument had been specified as ``py::arg().noconvert()`` as decribed +above). + +If no overload succeeds in the no-conversion first pass, a second pass is +attempted in which argument conversion is allowed (except where prohibited via +an explicit ``py::arg().noconvert()`` attribute in the function definition). + +If the second pass also fails a ``TypeError`` is raised. + +Within each pass, overloads are tried in the order they were registered with +pybind11. + +What this means in practice is that pybind11 will prefer any overload that does +not require conversion of arguments to an overload that does, but otherwise prefers +earlier-defined overloads to later-defined ones. + +.. note:: + + pybind11 does *not* further prioritize based on the number/pattern of + overloaded arguments. That is, pybind11 does not prioritize a function + requiring one conversion over one requiring three, but only prioritizes + overloads requiring no conversion at all to overloads that require + conversion of at least one argument. diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index fc283704e..05d6b068b 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -394,6 +394,15 @@ protected: result = PYBIND11_TRY_NEXT_OVERLOAD; try { + // We do this in two passes: in the first pass, we load arguments with `convert=false`; + // in the second, we allow conversion (except for arguments with an explicit + // py::arg().noconvert()). This lets us prefer calls without conversion, with + // conversion as a fallback. + std::vector second_pass; + + // However, if there are no overloads, we can just skip the no-convert pass entirely + const bool overloaded = it != nullptr && it->next != nullptr; + for (; it != nullptr; it = it->next) { /* For each overload: @@ -525,6 +534,15 @@ protected: pybind11_fail("Internal error: function call dispatcher inserted wrong number of arguments!"); #endif + std::vector second_pass_convert; + if (overloaded) { + // We're in the first no-convert pass, so swap out the conversion flags for a + // set of all-false flags. If the call fails, we'll swap the flags back in for + // the conversion-allowed call below. + second_pass_convert.resize(func.nargs, false); + call.args_convert.swap(second_pass_convert); + } + // 6. Call the function. try { result = func.impl(call); @@ -535,6 +553,34 @@ protected: if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD) break; + if (overloaded) { + // The (overloaded) call failed; if the call has at least one argument that + // permits conversion (i.e. it hasn't been explicitly specified `.noconvert()`) + // then add this call to the list of second pass overloads to try. + for (size_t i = func.is_method ? 1 : 0; i < pos_args; i++) { + if (second_pass_convert[i]) { + // Found one: swap the converting flags back in and store the call for + // the second pass. + call.args_convert.swap(second_pass_convert); + second_pass.push_back(std::move(call)); + break; + } + } + } + } + + if (overloaded && !second_pass.empty() && result.ptr() == PYBIND11_TRY_NEXT_OVERLOAD) { + // The no-conversion pass finished without success, try again with conversion allowed + for (auto &call : second_pass) { + try { + result = call.func.impl(call); + } catch (reference_cast_error &) { + result = PYBIND11_TRY_NEXT_OVERLOAD; + } + + if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD) + break; + } } } catch (error_already_set &e) { e.restore(); diff --git a/tests/test_methods_and_attributes.cpp b/tests/test_methods_and_attributes.cpp index 6f5e5ef2f..e33e4ac39 100644 --- a/tests/test_methods_and_attributes.cpp +++ b/tests/test_methods_and_attributes.cpp @@ -50,10 +50,14 @@ public: int *internal4() { return &value; } // return by pointer const int *internal5() { return &value; } // return by const pointer - py::str overloaded(int, float) { return "(int, float)"; } - py::str overloaded(float, int) { return "(float, int)"; } - py::str overloaded(int, float) const { return "(int, float) const"; } - py::str overloaded(float, int) const { return "(float, int) const"; } + py::str overloaded(int, float) { return "(int, float)"; } + py::str overloaded(float, int) { return "(float, int)"; } + py::str overloaded(int, int) { return "(int, int)"; } + py::str overloaded(float, float) { return "(float, float)"; } + py::str overloaded(int, float) const { return "(int, float) const"; } + py::str overloaded(float, int) const { return "(float, int) const"; } + py::str overloaded(int, int) const { return "(int, int) const"; } + py::str overloaded(float, float) const { return "(float, float) const"; } int value = 0; }; @@ -100,6 +104,7 @@ class CppDerivedDynamicClass : public DynamicClass { }; // py::arg/py::arg_v testing: these arguments just record their argument when invoked class ArgInspector1 { public: std::string arg = "(default arg inspector 1)"; }; class ArgInspector2 { public: std::string arg = "(default arg inspector 2)"; }; +class ArgAlwaysConverts { }; namespace pybind11 { namespace detail { template <> struct type_caster { public: @@ -131,6 +136,18 @@ public: return str(src.arg).release(); } }; +template <> struct type_caster { +public: + PYBIND11_TYPE_CASTER(ArgAlwaysConverts, _("ArgAlwaysConverts")); + + bool load(handle, bool convert) { + return convert; + } + + static handle cast(const ArgAlwaysConverts &, return_value_policy, handle) { + return py::none(); + } +}; }} test_initializer methods_and_attributes([](py::module &m) { @@ -159,15 +176,25 @@ test_initializer methods_and_attributes([](py::module &m) { .def("internal4", &ExampleMandA::internal4) .def("internal5", &ExampleMandA::internal5) #if defined(PYBIND11_OVERLOAD_CAST) - .def("overloaded", py::overload_cast(&ExampleMandA::overloaded)) - .def("overloaded", py::overload_cast(&ExampleMandA::overloaded)) - .def("overloaded_const", py::overload_cast(&ExampleMandA::overloaded, py::const_)) - .def("overloaded_const", py::overload_cast(&ExampleMandA::overloaded, py::const_)) + .def("overloaded", py::overload_cast(&ExampleMandA::overloaded)) + .def("overloaded", py::overload_cast(&ExampleMandA::overloaded)) + .def("overloaded", py::overload_cast(&ExampleMandA::overloaded)) + .def("overloaded", py::overload_cast(&ExampleMandA::overloaded)) + .def("overloaded_float", py::overload_cast(&ExampleMandA::overloaded)) + .def("overloaded_const", py::overload_cast(&ExampleMandA::overloaded, py::const_)) + .def("overloaded_const", py::overload_cast(&ExampleMandA::overloaded, py::const_)) + .def("overloaded_const", py::overload_cast(&ExampleMandA::overloaded, py::const_)) + .def("overloaded_const", py::overload_cast(&ExampleMandA::overloaded, py::const_)) #else - .def("overloaded", static_cast(&ExampleMandA::overloaded)) - .def("overloaded", static_cast(&ExampleMandA::overloaded)) - .def("overloaded_const", static_cast(&ExampleMandA::overloaded)) - .def("overloaded_const", static_cast(&ExampleMandA::overloaded)) + .def("overloaded", static_cast(&ExampleMandA::overloaded)) + .def("overloaded", static_cast(&ExampleMandA::overloaded)) + .def("overloaded", static_cast(&ExampleMandA::overloaded)) + .def("overloaded", static_cast(&ExampleMandA::overloaded)) + .def("overloaded_float", static_cast(&ExampleMandA::overloaded)) + .def("overloaded_const", static_cast(&ExampleMandA::overloaded)) + .def("overloaded_const", static_cast(&ExampleMandA::overloaded)) + .def("overloaded_const", static_cast(&ExampleMandA::overloaded)) + .def("overloaded_const", static_cast(&ExampleMandA::overloaded)) #endif .def("__str__", &ExampleMandA::toString) .def_readwrite("value", &ExampleMandA::value); @@ -220,22 +247,25 @@ test_initializer methods_and_attributes([](py::module &m) { .def(py::init()); #endif + // Test converting. The ArgAlwaysConverts is just there to make the first no-conversion pass + // fail so that our call always ends up happening via the second dispatch (the one that allows + // some conversion). class ArgInspector { public: - ArgInspector1 f(ArgInspector1 a) { return a; } - std::string g(ArgInspector1 a, const ArgInspector1 &b, int c, ArgInspector2 *d) { + ArgInspector1 f(ArgInspector1 a, ArgAlwaysConverts) { return a; } + std::string g(ArgInspector1 a, const ArgInspector1 &b, int c, ArgInspector2 *d, ArgAlwaysConverts) { return a.arg + "\n" + b.arg + "\n" + std::to_string(c) + "\n" + d->arg; } - static ArgInspector2 h(ArgInspector2 a) { return a; } + static ArgInspector2 h(ArgInspector2 a, ArgAlwaysConverts) { return a; } }; py::class_(m, "ArgInspector") .def(py::init<>()) - .def("f", &ArgInspector::f) - .def("g", &ArgInspector::g, "a"_a.noconvert(), "b"_a, "c"_a.noconvert()=13, "d"_a=ArgInspector2()) - .def_static("h", &ArgInspector::h, py::arg().noconvert()) + .def("f", &ArgInspector::f, py::arg(), py::arg() = ArgAlwaysConverts()) + .def("g", &ArgInspector::g, "a"_a.noconvert(), "b"_a, "c"_a.noconvert()=13, "d"_a=ArgInspector2(), py::arg() = ArgAlwaysConverts()) + .def_static("h", &ArgInspector::h, py::arg().noconvert(), py::arg() = ArgAlwaysConverts()) ; - m.def("arg_inspect_func", [](ArgInspector2 a, ArgInspector1 b) { return a.arg + "\n" + b.arg; }, - py::arg().noconvert(false), py::arg_v(nullptr, ArgInspector1()).noconvert(true)); + m.def("arg_inspect_func", [](ArgInspector2 a, ArgInspector1 b, ArgAlwaysConverts) { return a.arg + "\n" + b.arg; }, + py::arg().noconvert(false), py::arg_v(nullptr, ArgInspector1()).noconvert(true), py::arg() = ArgAlwaysConverts()); m.def("floats_preferred", [](double f) { return 0.5 * f; }, py::arg("f")); m.def("floats_only", [](double f) { return 0.5 * f; }, py::arg("f").noconvert()); diff --git a/tests/test_methods_and_attributes.py b/tests/test_methods_and_attributes.py index f890c6aaa..3a5d2e198 100644 --- a/tests/test_methods_and_attributes.py +++ b/tests/test_methods_and_attributes.py @@ -33,8 +33,16 @@ def test_methods_and_attributes(): assert instance1.overloaded(1, 1.0) == "(int, float)" assert instance1.overloaded(2.0, 2) == "(float, int)" - assert instance1.overloaded_const(3, 3.0) == "(int, float) const" - assert instance1.overloaded_const(4.0, 4) == "(float, int) const" + assert instance1.overloaded(3, 3) == "(int, int)" + assert instance1.overloaded(4., 4.) == "(float, float)" + assert instance1.overloaded_const(5, 5.0) == "(int, float) const" + assert instance1.overloaded_const(6.0, 6) == "(float, int) const" + assert instance1.overloaded_const(7, 7) == "(int, int) const" + assert instance1.overloaded_const(8., 8.) == "(float, float) const" + assert instance1.overloaded_float(1, 1) == "(float, float)" + assert instance1.overloaded_float(1, 1.) == "(float, float)" + assert instance1.overloaded_float(1., 1) == "(float, float)" + assert instance1.overloaded_float(1., 1.) == "(float, float)" assert instance1.value == 320 instance1.value = 100 diff --git a/tests/test_opaque_types.py b/tests/test_opaque_types.py index 7781943b4..1cd410208 100644 --- a/tests/test_opaque_types.py +++ b/tests/test_opaque_types.py @@ -28,9 +28,10 @@ def test_pointers(msg): print_opaque_list, return_null_str, get_null_str_value, return_unique_ptr, ConstructorStats) + living_before = ConstructorStats.get(ExampleMandA).alive() assert get_void_ptr_value(return_void_ptr()) == 0x1234 assert get_void_ptr_value(ExampleMandA()) # Should also work for other C++ types - assert ConstructorStats.get(ExampleMandA).alive() == 0 + assert ConstructorStats.get(ExampleMandA).alive() == living_before with pytest.raises(TypeError) as excinfo: get_void_ptr_value([1, 2, 3]) # This should not work