From 6b52c838d78cb697f68761c75b82d2a1806aedde Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Tue, 6 Sep 2016 12:27:00 -0400 Subject: [PATCH] Allow passing base types as a template parameter This allows a slightly cleaner base type specification of: py::class_("Type") as an alternative to py::class_("Type", py::base()) As with the other template parameters, the order relative to the holder or trampoline types doesn't matter. This also includes a compile-time assertion failure if attempting to specify more than one base class (but is easily extendible to support multiple inheritance, someday, by updating the class_selector::set_bases function to set multiple bases). --- docs/advanced.rst | 13 +++++++------ docs/classes.rst | 16 ++++++++++++---- include/pybind11/cast.h | 10 +++++----- include/pybind11/common.h | 16 ++++++++-------- include/pybind11/pybind11.h | 33 +++++++++++++++++++++++--------- tests/test_inheritance.cpp | 9 +++++++++ tests/test_inheritance.py | 5 ++++- tests/test_issues.cpp | 2 +- tests/test_virtual_functions.cpp | 12 ++++++------ 9 files changed, 76 insertions(+), 40 deletions(-) diff --git a/docs/advanced.rst b/docs/advanced.rst index 6a4a02077..748f91e2e 100644 --- a/docs/advanced.rst +++ b/docs/advanced.rst @@ -1227,7 +1227,7 @@ section. the other existing exception translators. The ``py::exception`` wrapper for creating custom exceptions cannot (yet) - be used as a ``py::base``. + be used as a base type. .. _eigen: @@ -1811,16 +1811,17 @@ However, it can be acquired as follows: .def(py::init()) .def("bark", &Dog::bark); -Alternatively, we can rely on the ``base`` tag, which performs an automated -lookup of the corresponding Python type. However, this also requires invoking -the ``import`` function once to ensure that the pybind11 binding code of the -module ``basic`` has been executed. +Alternatively, you can specify the base class as a template parameter option to +``class_``, which performs an automated lookup of the corresponding Python +type. Like the above code, however, this also requires invoking the ``import`` +function once to ensure that the pybind11 binding code of the module ``basic`` +has been executed: .. code-block:: cpp py::module::import("basic"); - py::class_(m, "Dog", py::base()) + py::class_(m, "Dog") .def(py::init()) .def("bark", &Dog::bark); diff --git a/docs/classes.rst b/docs/classes.rst index 5afb21edc..80f378f68 100644 --- a/docs/classes.rst +++ b/docs/classes.rst @@ -185,9 +185,10 @@ inheritance relationship: std::string bark() const { return "woof!"; } }; -There are two different ways of indicating a hierarchical relationship to -pybind11: the first is by specifying the C++ base class explicitly during -construction using the ``base`` attribute: +There are three different ways of indicating a hierarchical relationship to +pybind11: the first specifies the C++ base class as an extra template +parameter of the :class:`class_`; the second uses a special ``base`` attribute +passed into the constructor: .. code-block:: cpp @@ -195,6 +196,12 @@ construction using the ``base`` attribute: .def(py::init()) .def_readwrite("name", &Pet::name); + // Method 1: template parameter: + py::class_(m, "Dog") + .def(py::init()) + .def("bark", &Dog::bark); + + // Method 2: py::base attribute: py::class_(m, "Dog", py::base() /* <- specify C++ parent type */) .def(py::init()) .def("bark", &Dog::bark); @@ -208,11 +215,12 @@ Alternatively, we can also assign a name to the previously bound ``Pet`` pet.def(py::init()) .def_readwrite("name", &Pet::name); + // Method 3: pass parent class_ object: py::class_(m, "Dog", pet /* <- specify Python parent type */) .def(py::init()) .def("bark", &Dog::bark); -Functionality-wise, both approaches are completely equivalent. Afterwards, +Functionality-wise, all three approaches are completely equivalent. Afterwards, instances will expose fields and methods of both types: .. code-block:: pycon diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index e5579e1b4..f6d86e3cf 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -798,12 +798,12 @@ protected: holder_type holder; }; +// PYBIND11_DECLARE_HOLDER_TYPE holder types: template struct is_holder_type : - // PYBIND11_DECLARE_HOLDER_TYPE holder types: - std::conditional, detail::type_caster>::value, - std::true_type, - std::false_type>::type {}; -template struct is_holder_type> : std::true_type {}; + std::is_base_of, detail::type_caster> {}; +// Specialization for always-supported unique_ptr holders: +template struct is_holder_type> : + std::true_type {}; template struct handle_type_name { static PYBIND11_DESCR name() { return _(); } }; template <> struct handle_type_name { static PYBIND11_DESCR name() { return _(PYBIND11_BYTES_NAME); } }; diff --git a/include/pybind11/common.h b/include/pybind11/common.h index 6aeeb3d7a..c086e8796 100644 --- a/include/pybind11/common.h +++ b/include/pybind11/common.h @@ -358,20 +358,20 @@ template class P, typename T, typename... Ts> struct any_of_t : conditional_t::value, std::true_type, any_of_t> { }; #endif -// Extracts the first type from the template parameter pack matching the predicate, or void if none match. -template class Predicate, class... Ts> struct first_of; -template class Predicate> struct first_of { - using type = void; +// Extracts the first type from the template parameter pack matching the predicate, or Default if none match. +template class Predicate, class Default, class... Ts> struct first_of; +template class Predicate, class Default> struct first_of { + using type = Default; }; -template class Predicate, class T, class... Ts> -struct first_of { +template class Predicate, class Default, class T, class... Ts> +struct first_of { using type = typename std::conditional< Predicate::value, T, - typename first_of::type + typename first_of::type >::type; }; -template class Predicate, class... T> using first_of_t = typename first_of::type; +template class Predicate, class Default, class... T> using first_of_t = typename first_of::type; // Counts the number of types in the template parameter pack matching the predicate template class Predicate, typename... Ts> struct count_t; diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 3f04ba54a..0d16c7a1f 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -802,34 +802,47 @@ protected: static void releasebuffer(PyObject *, Py_buffer *view) { delete (buffer_info *) view->internal; } }; + +template class Predicate, typename... BaseTypes> struct class_selector; +template class Predicate, typename Base, typename... Bases> +struct class_selector { + static inline void set_bases(detail::type_record &record) { + if (Predicate::value) record.base_type = &typeid(Base); + else class_selector::set_bases(record); + } +}; +template class Predicate> +struct class_selector { + static inline void set_bases(detail::type_record &) {} +}; + NAMESPACE_END(detail) template class class_ : public detail::generic_type { template using is_holder = detail::is_holder_type; template using is_subtype = detail::bool_constant::value && !std::is_same::value>; + template using is_base_class = detail::bool_constant::value && !std::is_same::value>; template using is_valid_class_option = detail::bool_constant< is_holder::value || - is_subtype::value + is_subtype::value || + is_base_class::value >; - using extracted_holder_t = typename detail::first_of_t; - public: using type = type_; - using type_alias = detail::first_of_t; + using type_alias = detail::first_of_t; constexpr static bool has_alias = !std::is_void::value; - using holder_type = typename std::conditional< - std::is_void::value, - std::unique_ptr, - extracted_holder_t - >::type; + using holder_type = detail::first_of_t, options...>; using instance_type = detail::instance; static_assert(detail::all_of_t::value, "Unknown/invalid class_ template parameters provided"); + static_assert(detail::count_t::value <= 1, + "Invalid class_ base types: multiple inheritance is not supported"); + PYBIND11_OBJECT(class_, detail::generic_type, PyType_Check) template @@ -843,6 +856,8 @@ public: record.init_holder = init_holder; record.dealloc = dealloc; + detail::class_selector::set_bases(record); + /* Process optional arguments, if any */ detail::process_attributes::init(extra..., &record); diff --git a/tests/test_inheritance.cpp b/tests/test_inheritance.cpp index e1aad9920..798befffc 100644 --- a/tests/test_inheritance.cpp +++ b/tests/test_inheritance.cpp @@ -31,6 +31,11 @@ public: Rabbit(const std::string &name) : Pet(name, "parrot") {} }; +class Hamster : public Pet { +public: + Hamster(const std::string &name) : Pet(name, "rodent") {} +}; + std::string pet_name_species(const Pet &pet) { return pet.name() + " is a " + pet.species(); } @@ -59,6 +64,10 @@ test_initializer inheritance([](py::module &m) { py::class_(m, "Rabbit", py::base()) .def(py::init()); + /* And another: list parent in class template arguments */ + py::class_(m, "Hamster") + .def(py::init()); + m.def("pet_name_species", pet_name_species); m.def("dog_bark", dog_bark); diff --git a/tests/test_inheritance.py b/tests/test_inheritance.py index b55490cf6..d4cea8253 100644 --- a/tests/test_inheritance.py +++ b/tests/test_inheritance.py @@ -2,7 +2,7 @@ import pytest def test_inheritance(msg): - from pybind11_tests import Pet, Dog, Rabbit, dog_bark, pet_name_species + from pybind11_tests import Pet, Dog, Rabbit, Hamster, dog_bark, pet_name_species roger = Rabbit('Rabbit') assert roger.name() + " is a " + roger.species() == "Rabbit is a parrot" @@ -16,6 +16,9 @@ def test_inheritance(msg): assert molly.name() + " is a " + molly.species() == "Molly is a dog" assert pet_name_species(molly) == "Molly is a dog" + fred = Hamster('Fred') + assert fred.name() + " is a " + fred.species() == "Fred is a rodent" + assert dog_bark(molly) == "Woof!" with pytest.raises(TypeError) as excinfo: diff --git a/tests/test_issues.cpp b/tests/test_issues.cpp index 0a5a0b09a..55502fe18 100644 --- a/tests/test_issues.cpp +++ b/tests/test_issues.cpp @@ -96,7 +96,7 @@ void init_issues(py::module &m) { py::class_> (m2, "ElementBase"); - py::class_>(m2, "ElementA", py::base()) + py::class_>(m2, "ElementA") .def(py::init()) .def("value", &ElementA::value); diff --git a/tests/test_virtual_functions.cpp b/tests/test_virtual_functions.cpp index ac5b3fbbd..381b87e0e 100644 --- a/tests/test_virtual_functions.cpp +++ b/tests/test_virtual_functions.cpp @@ -258,12 +258,12 @@ void initialize_inherited_virtuals(py::module &m) { .def("unlucky_number", &A_Repeat::unlucky_number) .def("say_something", &A_Repeat::say_something) .def("say_everything", &A_Repeat::say_everything); - py::class_(m, "B_Repeat", py::base()) + py::class_(m, "B_Repeat") .def(py::init<>()) .def("lucky_number", &B_Repeat::lucky_number); - py::class_(m, "C_Repeat", py::base()) + py::class_(m, "C_Repeat") .def(py::init<>()); - py::class_(m, "D_Repeat", py::base()) + py::class_(m, "D_Repeat") .def(py::init<>()); // Method 2: Templated trampolines @@ -272,12 +272,12 @@ void initialize_inherited_virtuals(py::module &m) { .def("unlucky_number", &A_Tpl::unlucky_number) .def("say_something", &A_Tpl::say_something) .def("say_everything", &A_Tpl::say_everything); - py::class_>(m, "B_Tpl", py::base()) + py::class_>(m, "B_Tpl") .def(py::init<>()) .def("lucky_number", &B_Tpl::lucky_number); - py::class_>(m, "C_Tpl", py::base()) + py::class_>(m, "C_Tpl") .def(py::init<>()); - py::class_>(m, "D_Tpl", py::base()) + py::class_>(m, "D_Tpl") .def(py::init<>()); };