From d6df11602b8eac21d885f8e0bd587e90bfe73e71 Mon Sep 17 00:00:00 2001 From: Dustin Spicuzza Date: Mon, 17 Oct 2022 14:22:49 -0400 Subject: [PATCH] Allow specifying custom base classes - Useful for object hierarchies that don't use C++ inheritance (such as GObject) --- include/pybind11/attr.h | 36 +++++++++++++++++++++++++++++++++++ tests/CMakeLists.txt | 1 + tests/test_custom_base.cpp | 39 ++++++++++++++++++++++++++++++++++++++ tests/test_custom_base.py | 23 ++++++++++++++++++++++ 4 files changed, 99 insertions(+) create mode 100644 tests/test_custom_base.cpp create mode 100644 tests/test_custom_base.py diff --git a/include/pybind11/attr.h b/include/pybind11/attr.h index db7cd8eff..aed6ba615 100644 --- a/include/pybind11/attr.h +++ b/include/pybind11/attr.h @@ -65,6 +65,31 @@ struct base { base() = default; }; +/** \rst + Annotation indicating that a class should appear to python to derive from + another given type. This is useful for wrapping type systems that don't + utilize standard C++ inheritance. + + You must provide a caster function that casts from the derived type to + the base type. As an example, standard C++ inheritance would do this: + + .. code-block:: c++ + + py::class_ cls(m, "Derived", py::custom_base([](void *o) { + return static_cast(reinterpret_cast(o)); + })); + + .. note:: This is an advanced feature. If you use this, you likely need + to implement polymorphic_type_hook for your type hierarchy. + \endrst */ +template struct custom_base { + using caster = void *(*)(void *); + + explicit custom_base(const caster &f) : fn(f) {} + caster fn; +}; + + /// Keep patient alive while nurse lives template struct keep_alive {}; @@ -551,6 +576,17 @@ struct process_attribute> : process_attribute_default> { static void init(const base &, type_record *r) { r->add_base(typeid(T), nullptr); } }; +/// Process a custom base attribute +template +struct process_attribute> : process_attribute_default> { + static void init(const custom_base &b, type_record *r) + { + r->add_base(typeid(T), b.fn); + // TODO: rename this to 'nonsimple'? + r->multiple_inheritance = true; + } +}; + /// Process a multiple inheritance attribute template <> struct process_attribute : process_attribute_default { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7296cd1b8..702aa830f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -125,6 +125,7 @@ set(PYBIND11_TEST_FILES test_const_name test_constants_and_functions test_copy_move + test_custom_base test_custom_type_casters test_custom_type_setup test_docstring_options diff --git a/tests/test_custom_base.cpp b/tests/test_custom_base.cpp new file mode 100644 index 000000000..dc0ec1f63 --- /dev/null +++ b/tests/test_custom_base.cpp @@ -0,0 +1,39 @@ +/* + tests/test_custom_base.cpp -- test custom type hierarchy support + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#include "pybind11_tests.h" + +namespace { + +struct Base { + int i = 5; +}; + +struct Derived { + int j = 6; + + // just to prove the base can be anywhere + Base base; +}; + +} // namespace + +TEST_SUBMODULE(custom_base, m) { + + py::class_(m, "Base").def_readwrite("i", &Base::i); + + py::class_(m, "Derived", py::custom_base([](void *o) -> void * { + return &reinterpret_cast(o)->base; + })).def_readwrite("j", &Derived::j); + + m.def("create_derived", []() { return new Derived; }); + m.def("create_base", []() { return new Base; }); + + m.def("base_i", [](Base *b) { return b->i; }); + + m.def("derived_j", [](Derived *d) { return d->j; }); +}; \ No newline at end of file diff --git a/tests/test_custom_base.py b/tests/test_custom_base.py new file mode 100644 index 000000000..e59ff7a97 --- /dev/null +++ b/tests/test_custom_base.py @@ -0,0 +1,23 @@ +from pybind11_tests import custom_base as m + + +def test_cb_base(): + b = m.create_base() + + assert isinstance(b, m.Base) + assert b.i == 5 + + assert m.base_i(b) == 5 + + +def test_cb_derived(): + d = m.create_derived() + + assert isinstance(d, m.Derived) + assert isinstance(d, m.Base) + + assert d.i == 5 + assert d.j == 6 + + assert m.base_i(d) == 5 + assert m.derived_j(d) == 6