diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index c428e3f13..d2869264c 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -960,9 +960,14 @@ template class type_caster> { private: using caster_t = make_caster; caster_t subcaster; - using subcaster_cast_op_type = typename caster_t::template cast_op_type; - static_assert(std::is_same::type &, subcaster_cast_op_type>::value, - "std::reference_wrapper caster requires T to have a caster with an `T &` operator"); + using reference_t = type&; + using subcaster_cast_op_type = + typename caster_t::template cast_op_type; + + static_assert(std::is_same::type &, subcaster_cast_op_type>::value || + std::is_same::value, + "std::reference_wrapper caster requires T to have a caster with an " + "`operator T &()` or `operator const T &()`"); public: bool load(handle src, bool convert) { return subcaster.load(src, convert); } static constexpr auto name = caster_t::name; @@ -973,7 +978,7 @@ public: return caster_t::cast(&src.get(), policy, parent); } template using cast_op_type = std::reference_wrapper; - operator std::reference_wrapper() { return subcaster.operator subcaster_cast_op_type&(); } + operator std::reference_wrapper() { return cast_op(subcaster); } }; #define PYBIND11_TYPE_CASTER(type, py_name) \ diff --git a/tests/test_builtin_casters.cpp b/tests/test_builtin_casters.cpp index acc9f8fb3..e16c2d62b 100644 --- a/tests/test_builtin_casters.cpp +++ b/tests/test_builtin_casters.cpp @@ -15,6 +15,49 @@ # pragma warning(disable: 4127) // warning C4127: Conditional expression is constant #endif +struct ConstRefCasted { + int tag; +}; + +PYBIND11_NAMESPACE_BEGIN(pybind11) +PYBIND11_NAMESPACE_BEGIN(detail) +template <> +class type_caster { + public: + static constexpr auto name = _(); + + // Input is unimportant, a new value will always be constructed based on the + // cast operator. + bool load(handle, bool) { return true; } + + operator ConstRefCasted&&() { value = {1}; return std::move(value); } + operator ConstRefCasted&() { value = {2}; return value; } + operator ConstRefCasted*() { value = {3}; return &value; } + + operator const ConstRefCasted&() { value = {4}; return value; } + operator const ConstRefCasted*() { value = {5}; return &value; } + + // custom cast_op to explicitly propagate types to the conversion operators. + template + using cast_op_type = + /// const + conditional_t< + std::is_same, const ConstRefCasted*>::value, const ConstRefCasted*, + conditional_t< + std::is_same::value, const ConstRefCasted&, + /// non-const + conditional_t< + std::is_same, ConstRefCasted*>::value, ConstRefCasted*, + conditional_t< + std::is_same::value, ConstRefCasted&, + /* else */ConstRefCasted&&>>>>; + + private: + ConstRefCasted value = {0}; +}; +PYBIND11_NAMESPACE_END(detail) +PYBIND11_NAMESPACE_END(pybind11) + TEST_SUBMODULE(builtin_casters, m) { // test_simple_string m.def("string_roundtrip", [](const char *s) { return s; }); @@ -147,6 +190,17 @@ TEST_SUBMODULE(builtin_casters, m) { // test_reference_wrapper m.def("refwrap_builtin", [](std::reference_wrapper p) { return 10 * p.get(); }); m.def("refwrap_usertype", [](std::reference_wrapper p) { return p.get().value(); }); + m.def("refwrap_usertype_const", [](std::reference_wrapper p) { return p.get().value(); }); + + m.def("refwrap_lvalue", []() -> std::reference_wrapper { + static UserType x(1); + return std::ref(x); + }); + m.def("refwrap_lvalue_const", []() -> std::reference_wrapper { + static UserType x(1); + return std::cref(x); + }); + // Not currently supported (std::pair caster has return-by-value cast operator); // triggers static_assert failure. //m.def("refwrap_pair", [](std::reference_wrapper>) { }); @@ -189,4 +243,14 @@ TEST_SUBMODULE(builtin_casters, m) { py::object o = py::cast(v); return py::cast(o) == v; }); + + // Tests const/non-const propagation in cast_op. + m.def("takes", [](ConstRefCasted x) { return x.tag; }); + m.def("takes_move", [](ConstRefCasted&& x) { return x.tag; }); + m.def("takes_ptr", [](ConstRefCasted* x) { return x->tag; }); + m.def("takes_ref", [](ConstRefCasted& x) { return x.tag; }); + m.def("takes_ref_wrap", [](std::reference_wrapper x) { return x.get().tag; }); + m.def("takes_const_ptr", [](const ConstRefCasted* x) { return x->tag; }); + m.def("takes_const_ref", [](const ConstRefCasted& x) { return x.tag; }); + m.def("takes_const_ref_wrap", [](std::reference_wrapper x) { return x.get().tag; }); } diff --git a/tests/test_builtin_casters.py b/tests/test_builtin_casters.py index bd7996b62..39e8711df 100644 --- a/tests/test_builtin_casters.py +++ b/tests/test_builtin_casters.py @@ -315,6 +315,7 @@ def test_reference_wrapper(): """std::reference_wrapper for builtin and user types""" assert m.refwrap_builtin(42) == 420 assert m.refwrap_usertype(UserType(42)) == 42 + assert m.refwrap_usertype_const(UserType(42)) == 42 with pytest.raises(TypeError) as excinfo: m.refwrap_builtin(None) @@ -324,6 +325,9 @@ def test_reference_wrapper(): m.refwrap_usertype(None) assert "incompatible function arguments" in str(excinfo.value) + assert m.refwrap_lvalue().value == 1 + assert m.refwrap_lvalue_const().value == 1 + a1 = m.refwrap_list(copy=True) a2 = m.refwrap_list(copy=True) assert [x.value for x in a1] == [2, 3] @@ -421,3 +425,21 @@ def test_int_long(): def test_void_caster_2(): assert m.test_void_caster() + + +def test_const_ref_caster(): + """Verifies that const-ref is propagated through type_caster cast_op. + The returned ConstRefCasted type is a mimimal type that is constructed to + reference the casting mode used. + """ + x = False + assert m.takes(x) == 1 + assert m.takes_move(x) == 1 + + assert m.takes_ptr(x) == 3 + assert m.takes_ref(x) == 2 + assert m.takes_ref_wrap(x) == 2 + + assert m.takes_const_ptr(x) == 5 + assert m.takes_const_ref(x) == 4 + assert m.takes_const_ref_wrap(x) == 4