Fix template trampoline overload lookup failure

Problem
=======

The template trampoline pattern documented in PR #322 has a problem with
virtual method overloads in intermediate classes in the inheritance
chain between the trampoline class and the base class.

For example, consider the following inheritance structure, where `B` is
the actual class, `PyB<B>` is the trampoline class, and `PyA<B>` is an
intermediate class adding A's methods into the trampoline:

    PyB<B> -> PyA<B> -> B -> A

Suppose PyA<B> has a method `some_method()` with a PYBIND11_OVERLOAD in
it to overload the virtual `A::some_method()`.  If a Python class `C` is
defined that inherits from the pybind11-registered `B` and tries to
provide an overriding `some_method()`, the PYBIND11_OVERLOADs declared
in PyA<B> fails to find this overloaded method, and thus never invoke it
(or, if pure virtual and not overridden in PyB<B>, raises an exception).

This happens because the base (internal) `PYBIND11_OVERLOAD_INT` macro
simply calls `get_overload(this, name)`; `get_overload()` then uses the
inferred type of `this` to do a type lookup in `registered_types_cpp`.
This is where it fails: `this` will be a `PyA<B> *`, but `PyA<B>` is
neither the base type (`B`) nor the trampoline type (`PyB<B>`).  As a
result, the overload fails and we get a failed overload lookup.

The fix
=======

The fix is relatively simple: we can cast `this` passed to
`get_overload()` to a `const B *`, which lets get_overload look up the
correct class.  Since trampoline classes should be derived from `B`
classes anyway, this cast should be perfectly safe.

This does require adding the class name as an argument to the
PYBIND11_OVERLOAD_INT macro, but leaves the public macro signatures
unchanged.
This commit is contained in:
Jason Rhinelander 2016-08-29 18:16:46 -04:00
parent d9b3db3e64
commit 2097826346
3 changed files with 43 additions and 6 deletions

View File

@ -1396,18 +1396,18 @@ template <class T> function get_overload(const T *this_ptr, const char *name) {
return get_type_overload(this_ptr, (const detail::type_info *) it->second, name); return get_type_overload(this_ptr, (const detail::type_info *) it->second, name);
} }
#define PYBIND11_OVERLOAD_INT(ret_type, name, ...) { \ #define PYBIND11_OVERLOAD_INT(ret_type, cname, name, ...) { \
pybind11::gil_scoped_acquire gil; \ pybind11::gil_scoped_acquire gil; \
pybind11::function overload = pybind11::get_overload(this, name); \ pybind11::function overload = pybind11::get_overload(static_cast<const cname *>(this), name); \
if (overload) \ if (overload) \
return overload(__VA_ARGS__).template cast<ret_type>(); } return overload(__VA_ARGS__).template cast<ret_type>(); }
#define PYBIND11_OVERLOAD_NAME(ret_type, cname, name, fn, ...) \ #define PYBIND11_OVERLOAD_NAME(ret_type, cname, name, fn, ...) \
PYBIND11_OVERLOAD_INT(ret_type, name, __VA_ARGS__) \ PYBIND11_OVERLOAD_INT(ret_type, cname, name, __VA_ARGS__) \
return cname::fn(__VA_ARGS__) return cname::fn(__VA_ARGS__)
#define PYBIND11_OVERLOAD_PURE_NAME(ret_type, cname, name, fn, ...) \ #define PYBIND11_OVERLOAD_PURE_NAME(ret_type, cname, name, fn, ...) \
PYBIND11_OVERLOAD_INT(ret_type, name, __VA_ARGS__) \ PYBIND11_OVERLOAD_INT(ret_type, cname, name, __VA_ARGS__) \
pybind11::pybind11_fail("Tried to call pure virtual function \"" #cname "::" name "\""); pybind11::pybind11_fail("Tried to call pure virtual function \"" #cname "::" name "\"");
#define PYBIND11_OVERLOAD(ret_type, cname, fn, ...) \ #define PYBIND11_OVERLOAD(ret_type, cname, fn, ...) \

View File

@ -145,6 +145,9 @@ public: \
for (unsigned i = 0; i < times; ++i) \ for (unsigned i = 0; i < times; ++i) \
s += "hi"; \ s += "hi"; \
return s; \ return s; \
} \
std::string say_everything() { \
return say_something(1) + " " + std::to_string(unlucky_number()); \
} }
A_METHODS A_METHODS
}; };
@ -253,7 +256,8 @@ void initialize_inherited_virtuals(py::module &m) {
py::class_<A_Repeat, std::unique_ptr<A_Repeat>, PyA_Repeat>(m, "A_Repeat") py::class_<A_Repeat, std::unique_ptr<A_Repeat>, PyA_Repeat>(m, "A_Repeat")
.def(py::init<>()) .def(py::init<>())
.def("unlucky_number", &A_Repeat::unlucky_number) .def("unlucky_number", &A_Repeat::unlucky_number)
.def("say_something", &A_Repeat::say_something); .def("say_something", &A_Repeat::say_something)
.def("say_everything", &A_Repeat::say_everything);
py::class_<B_Repeat, std::unique_ptr<B_Repeat>, PyB_Repeat>(m, "B_Repeat", py::base<A_Repeat>()) py::class_<B_Repeat, std::unique_ptr<B_Repeat>, PyB_Repeat>(m, "B_Repeat", py::base<A_Repeat>())
.def(py::init<>()) .def(py::init<>())
.def("lucky_number", &B_Repeat::lucky_number); .def("lucky_number", &B_Repeat::lucky_number);
@ -266,7 +270,8 @@ void initialize_inherited_virtuals(py::module &m) {
py::class_<A_Tpl, std::unique_ptr<A_Tpl>, PyA_Tpl<>>(m, "A_Tpl") py::class_<A_Tpl, std::unique_ptr<A_Tpl>, PyA_Tpl<>>(m, "A_Tpl")
.def(py::init<>()) .def(py::init<>())
.def("unlucky_number", &A_Tpl::unlucky_number) .def("unlucky_number", &A_Tpl::unlucky_number)
.def("say_something", &A_Tpl::say_something); .def("say_something", &A_Tpl::say_something)
.def("say_everything", &A_Tpl::say_everything);
py::class_<B_Tpl, std::unique_ptr<B_Tpl>, PyB_Tpl<>>(m, "B_Tpl", py::base<A_Tpl>()) py::class_<B_Tpl, std::unique_ptr<B_Tpl>, PyB_Tpl<>>(m, "B_Tpl", py::base<A_Tpl>())
.def(py::init<>()) .def(py::init<>())
.def("lucky_number", &B_Tpl::lucky_number); .def("lucky_number", &B_Tpl::lucky_number);

View File

@ -69,20 +69,24 @@ def test_inheriting_repeat():
obj = VI_AR() obj = VI_AR()
assert obj.say_something(3) == "hihihi" assert obj.say_something(3) == "hihihi"
assert obj.unlucky_number() == 99 assert obj.unlucky_number() == 99
assert obj.say_everything() == "hi 99"
obj = VI_AT() obj = VI_AT()
assert obj.say_something(3) == "hihihi" assert obj.say_something(3) == "hihihi"
assert obj.unlucky_number() == 999 assert obj.unlucky_number() == 999
assert obj.say_everything() == "hi 999"
for obj in [B_Repeat(), B_Tpl()]: for obj in [B_Repeat(), B_Tpl()]:
assert obj.say_something(3) == "B says hi 3 times" assert obj.say_something(3) == "B says hi 3 times"
assert obj.unlucky_number() == 13 assert obj.unlucky_number() == 13
assert obj.lucky_number() == 7.0 assert obj.lucky_number() == 7.0
assert obj.say_everything() == "B says hi 1 times 13"
for obj in [C_Repeat(), C_Tpl()]: for obj in [C_Repeat(), C_Tpl()]:
assert obj.say_something(3) == "B says hi 3 times" assert obj.say_something(3) == "B says hi 3 times"
assert obj.unlucky_number() == 4444 assert obj.unlucky_number() == 4444
assert obj.lucky_number() == 888.0 assert obj.lucky_number() == 888.0
assert obj.say_everything() == "B says hi 1 times 4444"
class VI_CR(C_Repeat): class VI_CR(C_Repeat):
def lucky_number(self): def lucky_number(self):
@ -92,6 +96,7 @@ def test_inheriting_repeat():
assert obj.say_something(3) == "B says hi 3 times" assert obj.say_something(3) == "B says hi 3 times"
assert obj.unlucky_number() == 4444 assert obj.unlucky_number() == 4444
assert obj.lucky_number() == 889.25 assert obj.lucky_number() == 889.25
assert obj.say_everything() == "B says hi 1 times 4444"
class VI_CT(C_Tpl): class VI_CT(C_Tpl):
pass pass
@ -100,6 +105,7 @@ def test_inheriting_repeat():
assert obj.say_something(3) == "B says hi 3 times" assert obj.say_something(3) == "B says hi 3 times"
assert obj.unlucky_number() == 4444 assert obj.unlucky_number() == 4444
assert obj.lucky_number() == 888.0 assert obj.lucky_number() == 888.0
assert obj.say_everything() == "B says hi 1 times 4444"
class VI_CCR(VI_CR): class VI_CCR(VI_CR):
def lucky_number(self): def lucky_number(self):
@ -109,6 +115,7 @@ def test_inheriting_repeat():
assert obj.say_something(3) == "B says hi 3 times" assert obj.say_something(3) == "B says hi 3 times"
assert obj.unlucky_number() == 4444 assert obj.unlucky_number() == 4444
assert obj.lucky_number() == 8892.5 assert obj.lucky_number() == 8892.5
assert obj.say_everything() == "B says hi 1 times 4444"
class VI_CCT(VI_CT): class VI_CCT(VI_CT):
def lucky_number(self): def lucky_number(self):
@ -118,6 +125,7 @@ def test_inheriting_repeat():
assert obj.say_something(3) == "B says hi 3 times" assert obj.say_something(3) == "B says hi 3 times"
assert obj.unlucky_number() == 4444 assert obj.unlucky_number() == 4444
assert obj.lucky_number() == 888000.0 assert obj.lucky_number() == 888000.0
assert obj.say_everything() == "B says hi 1 times 4444"
class VI_DR(D_Repeat): class VI_DR(D_Repeat):
def unlucky_number(self): def unlucky_number(self):
@ -130,11 +138,13 @@ def test_inheriting_repeat():
assert obj.say_something(3) == "B says hi 3 times" assert obj.say_something(3) == "B says hi 3 times"
assert obj.unlucky_number() == 4444 assert obj.unlucky_number() == 4444
assert obj.lucky_number() == 888.0 assert obj.lucky_number() == 888.0
assert obj.say_everything() == "B says hi 1 times 4444"
obj = VI_DR() obj = VI_DR()
assert obj.say_something(3) == "B says hi 3 times" assert obj.say_something(3) == "B says hi 3 times"
assert obj.unlucky_number() == 123 assert obj.unlucky_number() == 123
assert obj.lucky_number() == 42.0 assert obj.lucky_number() == 42.0
assert obj.say_everything() == "B says hi 1 times 123"
class VI_DT(D_Tpl): class VI_DT(D_Tpl):
def say_something(self, times): def say_something(self, times):
@ -150,6 +160,28 @@ def test_inheriting_repeat():
assert obj.say_something(3) == "VI_DT says: quack quack quack" assert obj.say_something(3) == "VI_DT says: quack quack quack"
assert obj.unlucky_number() == 1234 assert obj.unlucky_number() == 1234
assert obj.lucky_number() == -4.25 assert obj.lucky_number() == -4.25
assert obj.say_everything() == "VI_DT says: quack 1234"
class VI_DT2(VI_DT):
def say_something(self, times):
return "VI_DT2: " + ('QUACK' * times)
def unlucky_number(self):
return -3
class VI_BT(B_Tpl):
def say_something(self, times):
return "VI_BT" * times
def unlucky_number(self):
return -7
def lucky_number(self):
return -1.375
obj = VI_BT()
assert obj.say_something(3) == "VI_BTVI_BTVI_BT"
assert obj.unlucky_number() == -7
assert obj.lucky_number() == -1.375
assert obj.say_everything() == "VI_BT -7"
@pytest.mark.skipif(not hasattr(pybind11_tests, 'NCVirt'), @pytest.mark.skipif(not hasattr(pybind11_tests, 'NCVirt'),
reason="NCVirt test broken on ICPC") reason="NCVirt test broken on ICPC")