Add better support for pure virtual methods

This commit is contained in:
Arnim Balzer 2023-03-04 09:53:10 +00:00
parent b3c9615d2d
commit cbb0590ac7
No known key found for this signature in database
GPG Key ID: 785EF903F0917AB7
4 changed files with 133 additions and 19 deletions

View File

@ -867,6 +867,8 @@ which should look as follows:
.. [#f5] https://docs.python.org/3/library/copy.html .. [#f5] https://docs.python.org/3/library/copy.html
.. _multiple_inheritance:
Multiple Inheritance Multiple Inheritance
==================== ====================
@ -961,6 +963,63 @@ In effect, this mechanism enforces that the actual class the trampolines are usi
Since the trampolines only need to add their respective trampoline function registrations, the order of the Since the trampolines only need to add their respective trampoline function registrations, the order of the
inheritance of the various trampoline classes does not matter. inheritance of the various trampoline classes does not matter.
If the base classes contain pure virtual methods, another pattern can be applied to reduce the amount of
trampoline code that needs writing. The cost is an additional ``std::same`` call for each pure-virtual
method using the macro ``PYBIND11_OVERRIDE_TEMPLATE``.
.. code-block:: cpp
class Animal {
public:
virtual ~Animal() { }
virtual std::string go(int n_times) = 0;
};
class Dog : public Animal {
public:
std::string go(int n_times) override;
};
template <class AnimalBase = Animal, class PureVirtualBase = Animal>
class PyAnimal : public AnimalBase {
public:
using AnimalBase::AnimalBase; // Inherit constructors
std::string go(int n_times) override { PYBIND11_OVERRIDE_TEMPLATE(PureVirtualBase, std::string, AnimalBase, go, n_times); }
};
using PyDog = PyAnimal<Dog>
class Mutant {
public:
virtual ~Mutant() { }
virtual void transform() = 0;
};
class XMen : public Mutant{
public:
virtual ~Mutant() { }
void transform() override;
};
template <class MutantBase = Mutant, class PyMutantBase = MutantBase, class PureVirtualBase = Mutant>
class PyMutant : public PyMutantBase {
public:
using PyMutantBase::PyMutantBase; // Inherit constructors
void transform() override { PYBIND11_OVERRIDE_TEMPLATE(PureVirtualBase, void, MutantBase, transform, ); }
};
using PyXMen = PyMutant<XMen>
class Chimera : public Dog, public Mutant {
public:
virtual ~Chimera() { }
};
template <class ChimeraBase = Chimera, class PyChimeraBase = ChimeraBase>
class PyChimera : public PyMutant<ChimeraBase, PyAnimal<ChimeraBase, PyChimeraBase>> {
public:
using PyMutant<ChimeraBase, PyAnimal<ChimeraBase, PyChimeraBase>>::PyMutant; // Inherit constructors
};
The first parameter of the :c:macro:`PYBIND11_OVERRIDE_TEMPLATE` is the base class containing
the pure virtual method. Together with the cname parameter, an ``std::same`` call is used to
invoke either :c:macro:`PYBIND11_OVERRIDE_PURE` or :c:macro:`PYBIND11_OVERRIDE`. A corresponding
:c:macro:`PYBIND11_OVERRIDE_TEMPLATE_NAME` implementation is also available. The template parameter
``PureVirtualBase`` can be used in case the pure virtual methods are not implemented in a child class.
Module-local class bindings Module-local class bindings
=========================== ===========================

View File

@ -21,6 +21,7 @@
#include <memory> #include <memory>
#include <new> #include <new>
#include <string> #include <string>
#include <type_traits>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -2854,6 +2855,51 @@ function get_override(const T *this_ptr, const char *name) {
PYBIND11_OVERRIDE_PURE_NAME( \ PYBIND11_OVERRIDE_PURE_NAME( \
PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), #fn, fn, __VA_ARGS__) PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), #fn, fn, __VA_ARGS__)
/** \rst
Macro to wrap :c:macro:`PYBIND11_OVERRIDE_PURE_NAME` and :c:macro:`PYBIND11_OVERRIDE_PURE`
depending on the base class and cname parameter provided.
See :ref:`_multiple_inheritance` for more information.
.. code-block:: cpp
template<class AnimalBase = Animal>
class PyAnimal : public AnimalBase {
public:
// Inherit the constructors
using AnimalBase::AnimalBase;
// Trampoline (need one for each virtual function)
std::string go(int n_times) override {
PYBIND11_OVERRIDE_TEMPLATE(
Animal, // The base class containing the purely virtual implementation
std::string, // Return type (ret_type)
Dog, // Parent class (cname)
"_go", // Name of method in Python (name)
go, // Name of function in C++ (must match Python name) (fn)
n_times // Argument(s) (...)
);
}
};
\endrst */
#define PYBIND11_OVERRIDE_TEMPLATE_NAME(base, ret_type, cname, name, fn, ...) \
if (std::is_same<base, cname>::value) { \
PYBIND11_OVERRIDE_PURE_NAME( \
PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, fn, __VA_ARGS__); \
} else { \
PYBIND11_OVERRIDE_NAME( \
PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, fn, __VA_ARGS__); \
}
/** \rst
Macro to wrap :c:macro:`PYBIND11_OVERRIDE_NAME` and :c:macro:`PYBIND11_OVERRIDE`
depending on the base class and cname parameter provided.
Uses :c:macro:`PYBIND11_OVERRIDE_TEMPLATE_NAME` under the hood.
See :ref:`_multiple_inheritance` for more information.
\endrst */
#define PYBIND11_OVERRIDE_TEMPLATE(base, ret_type, cname, fn, ...) \
PYBIND11_OVERRIDE_TEMPLATE_NAME( \
PYBIND11_TYPE(base), PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), #fn, fn, __VA_ARGS__)
// Deprecated versions // Deprecated versions
PYBIND11_DEPRECATED("get_type_overload has been deprecated") PYBIND11_DEPRECATED("get_type_overload has been deprecated")

View File

@ -87,12 +87,12 @@ public:
ChainBaseA(const ChainBaseA &) = default; ChainBaseA(const ChainBaseA &) = default;
ChainBaseA(ChainBaseA &&) = default; ChainBaseA(ChainBaseA &&) = default;
virtual ~ChainBaseA() = default; virtual ~ChainBaseA() = default;
virtual int resultA() { return 1; } virtual int resultA() = 0;
}; };
class ChainChildA : public ChainBaseA { class ChainChildA : public ChainBaseA {
public: public:
using ChainBaseA::ChainBaseA; using ChainBaseA::ChainBaseA;
int resultA() override { return 2; } int resultA() override { return 1; }
}; };
class ChainBaseB { class ChainBaseB {
public: public:
@ -100,12 +100,12 @@ public:
ChainBaseB(const ChainBaseB &) = default; ChainBaseB(const ChainBaseB &) = default;
ChainBaseB(ChainBaseB &&) = default; ChainBaseB(ChainBaseB &&) = default;
virtual ~ChainBaseB() = default; virtual ~ChainBaseB() = default;
virtual std::string resultB() { return "A"; } virtual std::string resultB() = 0;
}; };
class ChainChildB : public ChainBaseB { class ChainChildB : public ChainBaseB {
public: public:
using ChainBaseB::ChainBaseB; using ChainBaseB::ChainBaseB;
std::string resultB() override { return "B"; } std::string resultB() override { return "A"; }
}; };
class Joined : public ChainChildA, public ChainChildB { class Joined : public ChainChildA, public ChainChildB {
public: public:
@ -114,22 +114,24 @@ public:
Joined(Joined &&) = default; Joined(Joined &&) = default;
}; };
template <class Base = ChainBaseA> template <class Base = ChainBaseA, typename PureVirtualBase = ChainBaseA>
class TrampolineA : public Base { class TrampolineA : public Base {
public: public:
using Base::Base; using Base::Base;
int resultA() override { PYBIND11_OVERLOAD(int, Base, resultA, ); } int resultA() override { PYBIND11_OVERRIDE_TEMPLATE(PureVirtualBase, int, Base, resultA, ) }
}; };
template <class Base = ChainBaseB, class PyBase = Base> template <class Base = ChainBaseB, class PyBase = Base, typename PureVirtualBase = ChainBaseB>
class TrampolineB : public PyBase { class TrampolineB : public PyBase {
public: public:
using PyBase::PyBase; using PyBase::PyBase;
std::string resultB() override { PYBIND11_OVERLOAD(std::string, Base, resultB, ); } std::string resultB() override {
PYBIND11_OVERRIDE_TEMPLATE(PureVirtualBase, std::string, Base, resultB, )
}
}; };
template <class Base = Joined> template <class Base = Joined, class PyBase = Base>
class TrampolineJoined : public TrampolineB<Base, TrampolineA<Base>> { class TrampolineJoined : public TrampolineB<Base, TrampolineA<Base, PyBase>> {
public: public:
using TrampolineB<Base, TrampolineA<Base>>::TrampolineB; using TrampolineB<Base, TrampolineA<Base, PyBase>>::TrampolineB;
}; };
} // namespace TrampolineNesting } // namespace TrampolineNesting
@ -393,7 +395,6 @@ TEST_SUBMODULE(multiple_inheritance, m) {
.def_readwrite("f", &MVF::f); .def_readwrite("f", &MVF::f);
namespace TN = TrampolineNesting; namespace TN = TrampolineNesting;
py::class_<TN::ChainBaseA, TN::TrampolineA<>>(m, "ChainBaseA") py::class_<TN::ChainBaseA, TN::TrampolineA<>>(m, "ChainBaseA")
.def(py::init<>()) .def(py::init<>())
.def("resultA", &TN::ChainBaseA::resultA); .def("resultA", &TN::ChainBaseA::resultA);
@ -404,6 +405,12 @@ TEST_SUBMODULE(multiple_inheritance, m) {
.def("resultB", &TN::ChainBaseB::resultB); .def("resultB", &TN::ChainBaseB::resultB);
py::class_<TN::ChainChildB, TN::ChainBaseB, TN::TrampolineB<TN::ChainChildB>>(m, "ChainChildB") py::class_<TN::ChainChildB, TN::ChainBaseB, TN::TrampolineB<TN::ChainChildB>>(m, "ChainChildB")
.def(py::init<>()); .def(py::init<>());
py::class_<TN::Joined, TN::ChainBaseA, TN::ChainBaseB, TN::TrampolineJoined<>>(m, "Joined") py::class_<TN::Joined, TN::ChainChildA, TN::ChainChildB, TN::TrampolineJoined<>>(m, "Joined")
.def(py::init<>()); .def(py::init<>());
} }
// Needed for MSVC linker
namespace TrampolineNesting {
int ChainBaseA::resultA() { return 0; }
std::string ChainBaseB::resultB() { return ""; }
} // namespace TrampolineNesting

View File

@ -494,9 +494,11 @@ def test_python_inherit_from_mi():
def test_trampoline_nesting(): def test_trampoline_nesting():
assert m.ChainBaseA().resultA() == 1 with pytest.raises(RuntimeError):
assert m.ChainChildA().resultA() == 2 m.ChainBaseA().resultA()
assert m.ChainBaseB().resultB() == "A" assert m.ChainChildA().resultA() == 1
assert m.ChainChildB().resultB() == "B" with pytest.raises(RuntimeError):
assert m.Joined().resultA() == 2 m.ChainBaseB().resultB()
assert m.Joined().resultB() == "B" assert m.ChainChildB().resultB() == "A"
assert m.Joined().resultA() == 1
assert m.Joined().resultB() == "A"