Add better support for pure virtual methods

This commit is contained in:
Arnim Balzer 2023-03-04 09:53:10 +00:00
parent b3c9615d2d
commit cbb0590ac7
No known key found for this signature in database
GPG Key ID: 785EF903F0917AB7
4 changed files with 133 additions and 19 deletions

View File

@ -867,6 +867,8 @@ which should look as follows:
.. [#f5] https://docs.python.org/3/library/copy.html
.. _multiple_inheritance:
Multiple Inheritance
====================
@ -961,6 +963,63 @@ In effect, this mechanism enforces that the actual class the trampolines are usi
Since the trampolines only need to add their respective trampoline function registrations, the order of the
inheritance of the various trampoline classes does not matter.
If the base classes contain pure virtual methods, another pattern can be applied to reduce the amount of
trampoline code that needs writing. The cost is an additional ``std::same`` call for each pure-virtual
method using the macro ``PYBIND11_OVERRIDE_TEMPLATE``.
.. code-block:: cpp
class Animal {
public:
virtual ~Animal() { }
virtual std::string go(int n_times) = 0;
};
class Dog : public Animal {
public:
std::string go(int n_times) override;
};
template <class AnimalBase = Animal, class PureVirtualBase = Animal>
class PyAnimal : public AnimalBase {
public:
using AnimalBase::AnimalBase; // Inherit constructors
std::string go(int n_times) override { PYBIND11_OVERRIDE_TEMPLATE(PureVirtualBase, std::string, AnimalBase, go, n_times); }
};
using PyDog = PyAnimal<Dog>
class Mutant {
public:
virtual ~Mutant() { }
virtual void transform() = 0;
};
class XMen : public Mutant{
public:
virtual ~Mutant() { }
void transform() override;
};
template <class MutantBase = Mutant, class PyMutantBase = MutantBase, class PureVirtualBase = Mutant>
class PyMutant : public PyMutantBase {
public:
using PyMutantBase::PyMutantBase; // Inherit constructors
void transform() override { PYBIND11_OVERRIDE_TEMPLATE(PureVirtualBase, void, MutantBase, transform, ); }
};
using PyXMen = PyMutant<XMen>
class Chimera : public Dog, public Mutant {
public:
virtual ~Chimera() { }
};
template <class ChimeraBase = Chimera, class PyChimeraBase = ChimeraBase>
class PyChimera : public PyMutant<ChimeraBase, PyAnimal<ChimeraBase, PyChimeraBase>> {
public:
using PyMutant<ChimeraBase, PyAnimal<ChimeraBase, PyChimeraBase>>::PyMutant; // Inherit constructors
};
The first parameter of the :c:macro:`PYBIND11_OVERRIDE_TEMPLATE` is the base class containing
the pure virtual method. Together with the cname parameter, an ``std::same`` call is used to
invoke either :c:macro:`PYBIND11_OVERRIDE_PURE` or :c:macro:`PYBIND11_OVERRIDE`. A corresponding
:c:macro:`PYBIND11_OVERRIDE_TEMPLATE_NAME` implementation is also available. The template parameter
``PureVirtualBase`` can be used in case the pure virtual methods are not implemented in a child class.
Module-local class bindings
===========================

View File

@ -21,6 +21,7 @@
#include <memory>
#include <new>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
@ -2854,6 +2855,51 @@ function get_override(const T *this_ptr, const char *name) {
PYBIND11_OVERRIDE_PURE_NAME( \
PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), #fn, fn, __VA_ARGS__)
/** \rst
Macro to wrap :c:macro:`PYBIND11_OVERRIDE_PURE_NAME` and :c:macro:`PYBIND11_OVERRIDE_PURE`
depending on the base class and cname parameter provided.
See :ref:`_multiple_inheritance` for more information.
.. code-block:: cpp
template<class AnimalBase = Animal>
class PyAnimal : public AnimalBase {
public:
// Inherit the constructors
using AnimalBase::AnimalBase;
// Trampoline (need one for each virtual function)
std::string go(int n_times) override {
PYBIND11_OVERRIDE_TEMPLATE(
Animal, // The base class containing the purely virtual implementation
std::string, // Return type (ret_type)
Dog, // Parent class (cname)
"_go", // Name of method in Python (name)
go, // Name of function in C++ (must match Python name) (fn)
n_times // Argument(s) (...)
);
}
};
\endrst */
#define PYBIND11_OVERRIDE_TEMPLATE_NAME(base, ret_type, cname, name, fn, ...) \
if (std::is_same<base, cname>::value) { \
PYBIND11_OVERRIDE_PURE_NAME( \
PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, fn, __VA_ARGS__); \
} else { \
PYBIND11_OVERRIDE_NAME( \
PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, fn, __VA_ARGS__); \
}
/** \rst
Macro to wrap :c:macro:`PYBIND11_OVERRIDE_NAME` and :c:macro:`PYBIND11_OVERRIDE`
depending on the base class and cname parameter provided.
Uses :c:macro:`PYBIND11_OVERRIDE_TEMPLATE_NAME` under the hood.
See :ref:`_multiple_inheritance` for more information.
\endrst */
#define PYBIND11_OVERRIDE_TEMPLATE(base, ret_type, cname, fn, ...) \
PYBIND11_OVERRIDE_TEMPLATE_NAME( \
PYBIND11_TYPE(base), PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), #fn, fn, __VA_ARGS__)
// Deprecated versions
PYBIND11_DEPRECATED("get_type_overload has been deprecated")

View File

@ -87,12 +87,12 @@ public:
ChainBaseA(const ChainBaseA &) = default;
ChainBaseA(ChainBaseA &&) = default;
virtual ~ChainBaseA() = default;
virtual int resultA() { return 1; }
virtual int resultA() = 0;
};
class ChainChildA : public ChainBaseA {
public:
using ChainBaseA::ChainBaseA;
int resultA() override { return 2; }
int resultA() override { return 1; }
};
class ChainBaseB {
public:
@ -100,12 +100,12 @@ public:
ChainBaseB(const ChainBaseB &) = default;
ChainBaseB(ChainBaseB &&) = default;
virtual ~ChainBaseB() = default;
virtual std::string resultB() { return "A"; }
virtual std::string resultB() = 0;
};
class ChainChildB : public ChainBaseB {
public:
using ChainBaseB::ChainBaseB;
std::string resultB() override { return "B"; }
std::string resultB() override { return "A"; }
};
class Joined : public ChainChildA, public ChainChildB {
public:
@ -114,22 +114,24 @@ public:
Joined(Joined &&) = default;
};
template <class Base = ChainBaseA>
template <class Base = ChainBaseA, typename PureVirtualBase = ChainBaseA>
class TrampolineA : public Base {
public:
using Base::Base;
int resultA() override { PYBIND11_OVERLOAD(int, Base, resultA, ); }
int resultA() override { PYBIND11_OVERRIDE_TEMPLATE(PureVirtualBase, int, Base, resultA, ) }
};
template <class Base = ChainBaseB, class PyBase = Base>
template <class Base = ChainBaseB, class PyBase = Base, typename PureVirtualBase = ChainBaseB>
class TrampolineB : public PyBase {
public:
using PyBase::PyBase;
std::string resultB() override { PYBIND11_OVERLOAD(std::string, Base, resultB, ); }
std::string resultB() override {
PYBIND11_OVERRIDE_TEMPLATE(PureVirtualBase, std::string, Base, resultB, )
}
};
template <class Base = Joined>
class TrampolineJoined : public TrampolineB<Base, TrampolineA<Base>> {
template <class Base = Joined, class PyBase = Base>
class TrampolineJoined : public TrampolineB<Base, TrampolineA<Base, PyBase>> {
public:
using TrampolineB<Base, TrampolineA<Base>>::TrampolineB;
using TrampolineB<Base, TrampolineA<Base, PyBase>>::TrampolineB;
};
} // namespace TrampolineNesting
@ -393,7 +395,6 @@ TEST_SUBMODULE(multiple_inheritance, m) {
.def_readwrite("f", &MVF::f);
namespace TN = TrampolineNesting;
py::class_<TN::ChainBaseA, TN::TrampolineA<>>(m, "ChainBaseA")
.def(py::init<>())
.def("resultA", &TN::ChainBaseA::resultA);
@ -404,6 +405,12 @@ TEST_SUBMODULE(multiple_inheritance, m) {
.def("resultB", &TN::ChainBaseB::resultB);
py::class_<TN::ChainChildB, TN::ChainBaseB, TN::TrampolineB<TN::ChainChildB>>(m, "ChainChildB")
.def(py::init<>());
py::class_<TN::Joined, TN::ChainBaseA, TN::ChainBaseB, TN::TrampolineJoined<>>(m, "Joined")
py::class_<TN::Joined, TN::ChainChildA, TN::ChainChildB, TN::TrampolineJoined<>>(m, "Joined")
.def(py::init<>());
}
// Needed for MSVC linker
namespace TrampolineNesting {
int ChainBaseA::resultA() { return 0; }
std::string ChainBaseB::resultB() { return ""; }
} // namespace TrampolineNesting

View File

@ -494,9 +494,11 @@ def test_python_inherit_from_mi():
def test_trampoline_nesting():
assert m.ChainBaseA().resultA() == 1
assert m.ChainChildA().resultA() == 2
assert m.ChainBaseB().resultB() == "A"
assert m.ChainChildB().resultB() == "B"
assert m.Joined().resultA() == 2
assert m.Joined().resultB() == "B"
with pytest.raises(RuntimeError):
m.ChainBaseA().resultA()
assert m.ChainChildA().resultA() == 1
with pytest.raises(RuntimeError):
m.ChainBaseB().resultB()
assert m.ChainChildB().resultB() == "A"
assert m.Joined().resultA() == 1
assert m.Joined().resultB() == "A"