Adding smart_holder_type_caster_load::loaded_as_shared_ptr, currently bypassing smart_holder shared_ptr tracking completely, but the tests pass and are sanitizer clean.

This commit is contained in:
Ralf W. Grosse-Kunstleve 2021-01-23 10:09:20 -08:00
parent 11690b0b91
commit 0f82a0b014
3 changed files with 40 additions and 8 deletions

View File

@ -231,7 +231,8 @@ struct smart_holder_type_caster_load {
return true;
}
T *as_raw_ptr_unowned() {
T *loaded_as_raw_ptr_unowned() {
// BYPASSES smart_holder type checking completely.
if (load_impl.loaded_v_h_cpptype != nullptr) {
if (load_impl.reinterpret_cast_deemed_ok) {
return static_cast<T *>(loaded_smhldr_ptr->vptr.get());
@ -241,7 +242,13 @@ struct smart_holder_type_caster_load {
return static_cast<T *>(implicit_casted);
}
}
return loaded_smhldr_ptr->as_raw_ptr_unowned<T>();
return static_cast<T *>(loaded_smhldr_ptr->vptr.get());
}
std::shared_ptr<T> loaded_as_shared_ptr() {
T *raw_ptr = loaded_as_raw_ptr_unowned();
// BYPASSES smart_holder shared_ptr tracking completely.
return std::shared_ptr<T>(loaded_smhldr_ptr->vptr, raw_ptr);
}
std::unique_ptr<T> loaded_as_unique_ptr() {
@ -350,8 +357,8 @@ struct classh_type_caster : smart_holder_type_caster_load<T> {
operator T&&() && { return this->loaded_smhldr_ptr->template rvalue_ref<T>(); }
operator T const&() { return this->loaded_smhldr_ptr->template lvalue_ref<T>(); }
operator T&() { return this->loaded_smhldr_ptr->template lvalue_ref<T>(); }
operator T const*() { return this->as_raw_ptr_unowned(); }
operator T*() { return this->as_raw_ptr_unowned(); }
operator T const*() { return this->loaded_as_raw_ptr_unowned(); }
operator T*() { return this->loaded_as_raw_ptr_unowned(); }
// clang-format on
@ -487,7 +494,7 @@ struct classh_type_caster<std::shared_ptr<T>> : smart_holder_type_caster_load<T>
template <typename>
using cast_op_type = std::shared_ptr<T>;
operator std::shared_ptr<T>() { return this->loaded_smhldr_ptr->template as_shared_ptr<T>(); }
operator std::shared_ptr<T>() { return this->loaded_as_shared_ptr(); }
};
template <typename T>
@ -505,9 +512,7 @@ struct classh_type_caster<std::shared_ptr<T const>> : smart_holder_type_caster_l
template <typename>
using cast_op_type = std::shared_ptr<T const>;
operator std::shared_ptr<T const>() {
return this->loaded_smhldr_ptr->template as_shared_ptr<T>();
}
operator std::shared_ptr<T const>() { return this->loaded_as_shared_ptr(); } // Mutbl2Const
};
template <typename T>

View File

@ -2,6 +2,8 @@
#include <pybind11/classh.h>
#include <memory>
namespace pybind11_tests {
namespace classh_inheritance {
@ -25,6 +27,12 @@ inline base *rtrn_mptr_drvd_up_cast() { return new drvd; }
inline int pass_cptr_base(base const *b) { return b->id() + 11; }
inline int pass_cptr_drvd(drvd const *d) { return d->id() + 12; }
inline std::shared_ptr<drvd> rtrn_shmp_drvd() { return std::shared_ptr<drvd>(new drvd); }
inline std::shared_ptr<base> rtrn_shmp_drvd_up_cast() { return std::shared_ptr<drvd>(new drvd); }
inline int pass_shcp_base(std::shared_ptr<base const> b) { return b->id() + 21; }
inline int pass_shcp_drvd(std::shared_ptr<drvd const> d) { return d->id() + 22; }
// clang-format on
using base1 = base_template<110>;
@ -69,6 +77,11 @@ TEST_SUBMODULE(classh_inheritance, m) {
m.def("pass_cptr_base", pass_cptr_base);
m.def("pass_cptr_drvd", pass_cptr_drvd);
m.def("rtrn_shmp_drvd", rtrn_shmp_drvd);
m.def("rtrn_shmp_drvd_up_cast", rtrn_shmp_drvd_up_cast);
m.def("pass_shcp_base", pass_shcp_base);
m.def("pass_shcp_drvd", pass_shcp_drvd);
py::classh<base1>(m, "base1").def(py::init<>()); // __init__ needed for Python inheritance.
py::classh<base2>(m, "base2").def(py::init<>());
py::classh<drvd2, base1, base2>(m, "drvd2");

View File

@ -9,6 +9,12 @@ def test_rtrn_mptr_drvd_pass_cptr_base():
assert i == 2 * 100 + 11
def test_rtrn_shmp_drvd_pass_shcp_base():
d = m.rtrn_shmp_drvd()
i = m.pass_shcp_base(d) # load_impl Case 2a
assert i == 2 * 100 + 21
def test_rtrn_mptr_drvd_up_cast_pass_cptr_drvd():
b = m.rtrn_mptr_drvd_up_cast()
# the base return is down-cast immediately.
@ -17,6 +23,14 @@ def test_rtrn_mptr_drvd_up_cast_pass_cptr_drvd():
assert i == 2 * 100 + 12
def test_rtrn_shmp_drvd_up_cast_pass_shcp_drvd():
b = m.rtrn_shmp_drvd_up_cast()
# the base return is down-cast immediately.
assert b.__class__.__name__ == "drvd"
i = m.pass_shcp_drvd(b)
assert i == 2 * 100 + 22
def test_rtrn_mptr_drvd2_pass_cptr_bases():
d = m.rtrn_mptr_drvd2()
i1 = m.pass_cptr_base1(d) # load_impl Case 2c