Add `all_type_info_check_for_divergence()` and some tests.

This commit is contained in:
Ralf W. Grosse-Kunstleve 2023-11-02 17:03:20 -07:00
parent f3bb31e89f
commit 0a9599f775
3 changed files with 90 additions and 0 deletions

View File

@ -115,6 +115,40 @@ inline void all_type_info_add_base_most_derived_first(std::vector<type_info *> &
bases.push_back(addl_base);
}
inline void all_type_info_check_for_divergence(const std::vector<type_info *> &bases) {
using sz_t = std::size_t;
sz_t n = bases.size();
if (n < 3) {
return;
}
std::vector<sz_t> cluster_ids;
cluster_ids.reserve(n);
for (sz_t ci = 0; ci < n; ci++) {
cluster_ids.push_back(ci);
}
for (sz_t i = 0; i < n - 1; i++) {
if (cluster_ids[i] != i) {
continue;
}
for (sz_t j = i + 1; j < n; j++) {
if (PyType_IsSubtype(bases[i]->type, bases[j]->type) != 0) {
sz_t k = cluster_ids[j];
if (k == j) {
cluster_ids[j] = i;
} else {
PyErr_Format(
PyExc_TypeError,
"bases include diverging derived types: base=%s, derived1=%s, derived2=%s",
bases[j]->type->tp_name,
bases[k]->type->tp_name,
bases[i]->type->tp_name);
throw error_already_set();
}
}
}
}
}
// Populates a just-created cache entry.
PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector<type_info *> &bases) {
assert(bases.empty());
@ -168,6 +202,7 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector<type_
}
}
}
all_type_info_check_for_divergence(bases);
}
/**

View File

@ -26,6 +26,18 @@ private:
int drvd_value;
};
struct CppDrvd2 : CppBase {
explicit CppDrvd2(int value) : CppBase(value), drvd2_value(value * 5) {}
int get_drvd2_value() const { return drvd2_value; }
void reset_drvd2_value(int new_value) { drvd2_value = new_value; }
int get_base_value_from_drvd2() const { return get_base_value(); }
void reset_base_value_from_drvd2(int new_value) { reset_base_value(new_value); }
private:
int drvd2_value;
};
} // namespace test_python_multiple_inheritance
TEST_SUBMODULE(python_multiple_inheritance, m) {
@ -42,4 +54,11 @@ TEST_SUBMODULE(python_multiple_inheritance, m) {
.def("reset_drvd_value", &CppDrvd::reset_drvd_value)
.def("get_base_value_from_drvd", &CppDrvd::get_base_value_from_drvd)
.def("reset_base_value_from_drvd", &CppDrvd::reset_base_value_from_drvd);
py::class_<CppDrvd2, CppBase>(m, "CppDrvd2")
.def(py::init<int>())
.def("get_drvd2_value", &CppDrvd2::get_drvd2_value)
.def("reset_drvd2_value", &CppDrvd2::reset_drvd2_value)
.def("get_base_value_from_drvd2", &CppDrvd2::get_base_value_from_drvd2)
.def("reset_base_value_from_drvd2", &CppDrvd2::reset_base_value_from_drvd2);
}

View File

@ -1,6 +1,8 @@
# Adapted from:
# https://github.com/google/clif/blob/5718e4d0807fd3b6a8187dde140069120b81ecef/clif/testing/python/python_multiple_inheritance_test.py
import pytest
from pybind11_tests import python_multiple_inheritance as m
@ -12,6 +14,22 @@ class PPCC(PC, m.CppDrvd):
pass
class PPPCCC(PPCC, m.CppDrvd2):
pass
class PC1(m.CppDrvd):
pass
class PC2(m.CppDrvd2):
pass
class PCD(PC1, PC2):
pass
def test_PC():
d = PC(11)
assert d.get_base_value() == 11
@ -33,3 +51,21 @@ def test_PPCC():
d.reset_base_value_from_drvd(30)
assert d.get_base_value() == 30
assert d.get_base_value_from_drvd() == 30
def NOtest_PPPCCC():
# terminate called after throwing an instance of 'pybind11::error_already_set'
# what(): TypeError: bases include diverging derived types:
# base=pybind11_tests.python_multiple_inheritance.CppBase,
# derived1=pybind11_tests.python_multiple_inheritance.CppDrvd,
# derived2=pybind11_tests.python_multiple_inheritance.CppDrvd2
PPPCCC(11)
def test_PCD():
# This escapes all_type_info_check_for_divergence() because CppBase does not appear in bases.
with pytest.raises(
TypeError,
match=r"CppDrvd2\.__init__\(\) must be called when overriding __init__$",
):
PCD(11)