From 2097826346c3e894101663ef20b7e31a6a9f5bf8 Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Mon, 29 Aug 2016 18:16:46 -0400 Subject: [PATCH] Fix template trampoline overload lookup failure Problem ======= The template trampoline pattern documented in PR #322 has a problem with virtual method overloads in intermediate classes in the inheritance chain between the trampoline class and the base class. For example, consider the following inheritance structure, where `B` is the actual class, `PyB` is the trampoline class, and `PyA` is an intermediate class adding A's methods into the trampoline: PyB -> PyA -> B -> A Suppose PyA has a method `some_method()` with a PYBIND11_OVERLOAD in it to overload the virtual `A::some_method()`. If a Python class `C` is defined that inherits from the pybind11-registered `B` and tries to provide an overriding `some_method()`, the PYBIND11_OVERLOADs declared in PyA fails to find this overloaded method, and thus never invoke it (or, if pure virtual and not overridden in PyB, raises an exception). This happens because the base (internal) `PYBIND11_OVERLOAD_INT` macro simply calls `get_overload(this, name)`; `get_overload()` then uses the inferred type of `this` to do a type lookup in `registered_types_cpp`. This is where it fails: `this` will be a `PyA *`, but `PyA` is neither the base type (`B`) nor the trampoline type (`PyB`). As a result, the overload fails and we get a failed overload lookup. The fix ======= The fix is relatively simple: we can cast `this` passed to `get_overload()` to a `const B *`, which lets get_overload look up the correct class. Since trampoline classes should be derived from `B` classes anyway, this cast should be perfectly safe. This does require adding the class name as an argument to the PYBIND11_OVERLOAD_INT macro, but leaves the public macro signatures unchanged. --- include/pybind11/pybind11.h | 8 ++++---- tests/test_virtual_functions.cpp | 9 +++++++-- tests/test_virtual_functions.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index f9e840f16..9f7b2be81 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1396,18 +1396,18 @@ template function get_overload(const T *this_ptr, const char *name) { return get_type_overload(this_ptr, (const detail::type_info *) it->second, name); } -#define PYBIND11_OVERLOAD_INT(ret_type, name, ...) { \ +#define PYBIND11_OVERLOAD_INT(ret_type, cname, name, ...) { \ pybind11::gil_scoped_acquire gil; \ - pybind11::function overload = pybind11::get_overload(this, name); \ + pybind11::function overload = pybind11::get_overload(static_cast(this), name); \ if (overload) \ return overload(__VA_ARGS__).template cast(); } #define PYBIND11_OVERLOAD_NAME(ret_type, cname, name, fn, ...) \ - PYBIND11_OVERLOAD_INT(ret_type, name, __VA_ARGS__) \ + PYBIND11_OVERLOAD_INT(ret_type, cname, name, __VA_ARGS__) \ return cname::fn(__VA_ARGS__) #define PYBIND11_OVERLOAD_PURE_NAME(ret_type, cname, name, fn, ...) \ - PYBIND11_OVERLOAD_INT(ret_type, name, __VA_ARGS__) \ + PYBIND11_OVERLOAD_INT(ret_type, cname, name, __VA_ARGS__) \ pybind11::pybind11_fail("Tried to call pure virtual function \"" #cname "::" name "\""); #define PYBIND11_OVERLOAD(ret_type, cname, fn, ...) \ diff --git a/tests/test_virtual_functions.cpp b/tests/test_virtual_functions.cpp index 8926292b8..f0f4702a9 100644 --- a/tests/test_virtual_functions.cpp +++ b/tests/test_virtual_functions.cpp @@ -145,6 +145,9 @@ public: \ for (unsigned i = 0; i < times; ++i) \ s += "hi"; \ return s; \ + } \ + std::string say_everything() { \ + return say_something(1) + " " + std::to_string(unlucky_number()); \ } A_METHODS }; @@ -253,7 +256,8 @@ void initialize_inherited_virtuals(py::module &m) { py::class_, PyA_Repeat>(m, "A_Repeat") .def(py::init<>()) .def("unlucky_number", &A_Repeat::unlucky_number) - .def("say_something", &A_Repeat::say_something); + .def("say_something", &A_Repeat::say_something) + .def("say_everything", &A_Repeat::say_everything); py::class_, PyB_Repeat>(m, "B_Repeat", py::base()) .def(py::init<>()) .def("lucky_number", &B_Repeat::lucky_number); @@ -266,7 +270,8 @@ void initialize_inherited_virtuals(py::module &m) { py::class_, PyA_Tpl<>>(m, "A_Tpl") .def(py::init<>()) .def("unlucky_number", &A_Tpl::unlucky_number) - .def("say_something", &A_Tpl::say_something); + .def("say_something", &A_Tpl::say_something) + .def("say_everything", &A_Tpl::say_everything); py::class_, PyB_Tpl<>>(m, "B_Tpl", py::base()) .def(py::init<>()) .def("lucky_number", &B_Tpl::lucky_number); diff --git a/tests/test_virtual_functions.py b/tests/test_virtual_functions.py index d65adc6ba..ef05de800 100644 --- a/tests/test_virtual_functions.py +++ b/tests/test_virtual_functions.py @@ -69,20 +69,24 @@ def test_inheriting_repeat(): obj = VI_AR() assert obj.say_something(3) == "hihihi" assert obj.unlucky_number() == 99 + assert obj.say_everything() == "hi 99" obj = VI_AT() assert obj.say_something(3) == "hihihi" assert obj.unlucky_number() == 999 + assert obj.say_everything() == "hi 999" for obj in [B_Repeat(), B_Tpl()]: assert obj.say_something(3) == "B says hi 3 times" assert obj.unlucky_number() == 13 assert obj.lucky_number() == 7.0 + assert obj.say_everything() == "B says hi 1 times 13" for obj in [C_Repeat(), C_Tpl()]: assert obj.say_something(3) == "B says hi 3 times" assert obj.unlucky_number() == 4444 assert obj.lucky_number() == 888.0 + assert obj.say_everything() == "B says hi 1 times 4444" class VI_CR(C_Repeat): def lucky_number(self): @@ -92,6 +96,7 @@ def test_inheriting_repeat(): assert obj.say_something(3) == "B says hi 3 times" assert obj.unlucky_number() == 4444 assert obj.lucky_number() == 889.25 + assert obj.say_everything() == "B says hi 1 times 4444" class VI_CT(C_Tpl): pass @@ -100,6 +105,7 @@ def test_inheriting_repeat(): assert obj.say_something(3) == "B says hi 3 times" assert obj.unlucky_number() == 4444 assert obj.lucky_number() == 888.0 + assert obj.say_everything() == "B says hi 1 times 4444" class VI_CCR(VI_CR): def lucky_number(self): @@ -109,6 +115,7 @@ def test_inheriting_repeat(): assert obj.say_something(3) == "B says hi 3 times" assert obj.unlucky_number() == 4444 assert obj.lucky_number() == 8892.5 + assert obj.say_everything() == "B says hi 1 times 4444" class VI_CCT(VI_CT): def lucky_number(self): @@ -118,6 +125,7 @@ def test_inheriting_repeat(): assert obj.say_something(3) == "B says hi 3 times" assert obj.unlucky_number() == 4444 assert obj.lucky_number() == 888000.0 + assert obj.say_everything() == "B says hi 1 times 4444" class VI_DR(D_Repeat): def unlucky_number(self): @@ -130,11 +138,13 @@ def test_inheriting_repeat(): assert obj.say_something(3) == "B says hi 3 times" assert obj.unlucky_number() == 4444 assert obj.lucky_number() == 888.0 + assert obj.say_everything() == "B says hi 1 times 4444" obj = VI_DR() assert obj.say_something(3) == "B says hi 3 times" assert obj.unlucky_number() == 123 assert obj.lucky_number() == 42.0 + assert obj.say_everything() == "B says hi 1 times 123" class VI_DT(D_Tpl): def say_something(self, times): @@ -150,6 +160,28 @@ def test_inheriting_repeat(): assert obj.say_something(3) == "VI_DT says: quack quack quack" assert obj.unlucky_number() == 1234 assert obj.lucky_number() == -4.25 + assert obj.say_everything() == "VI_DT says: quack 1234" + + class VI_DT2(VI_DT): + def say_something(self, times): + return "VI_DT2: " + ('QUACK' * times) + + def unlucky_number(self): + return -3 + + class VI_BT(B_Tpl): + def say_something(self, times): + return "VI_BT" * times + def unlucky_number(self): + return -7 + def lucky_number(self): + return -1.375 + + obj = VI_BT() + assert obj.say_something(3) == "VI_BTVI_BTVI_BT" + assert obj.unlucky_number() == -7 + assert obj.lucky_number() == -1.375 + assert obj.say_everything() == "VI_BT -7" @pytest.mark.skipif(not hasattr(pybind11_tests, 'NCVirt'), reason="NCVirt test broken on ICPC")