Use rvalue subcasting when casting an rvalue container

This updates the std::tuple, std::pair and `stl.h` type casters to
forward their contained value according to whether the container being
cast is an lvalue or rvalue reference.  This fixes an issue where
subcaster casts were always called with a const lvalue which meant
nested type casters didn't have the desired `cast()` overload invoked.
For example, this caused Eigen values in a tuple to end up with a
readonly flag (issue #935) and made it impossible to return a container
of move-only types (issue #853).

This fixes both issues by adding templated universal reference `cast()`
methods to the various container types that forward container elements
according to the container reference type.
This commit is contained in:
Jason Rhinelander 2017-07-03 19:12:09 -04:00
parent 897d71687e
commit b57281bb00
8 changed files with 144 additions and 35 deletions

View File

@ -1256,9 +1256,9 @@ public:
}; };
// Base implementation for std::tuple and std::pair // Base implementation for std::tuple and std::pair
template <template<typename...> class TupleType, typename... Tuple> class tuple_caster { template <template<typename...> class Tuple, typename... Ts> class tuple_caster {
using type = TupleType<Tuple...>; using type = Tuple<Ts...>;
static constexpr auto size = sizeof...(Tuple); static constexpr auto size = sizeof...(Ts);
using indices = make_index_sequence<size>; using indices = make_index_sequence<size>;
public: public:
@ -1271,12 +1271,13 @@ public:
return load_impl(seq, convert, indices{}); return load_impl(seq, convert, indices{});
} }
static handle cast(const type &src, return_value_policy policy, handle parent) { template <typename T>
return cast_impl(src, policy, parent, indices{}); static handle cast(T &&src, return_value_policy policy, handle parent) {
return cast_impl(std::forward<T>(src), policy, parent, indices{});
} }
static PYBIND11_DESCR name() { static PYBIND11_DESCR name() {
return type_descr(_("Tuple[") + detail::concat(make_caster<Tuple>::name()...) + _("]")); return type_descr(_("Tuple[") + detail::concat(make_caster<Ts>::name()...) + _("]"));
} }
template <typename T> using cast_op_type = type; template <typename T> using cast_op_type = type;
@ -1286,9 +1287,9 @@ public:
protected: protected:
template <size_t... Is> template <size_t... Is>
type implicit_cast(index_sequence<Is...>) & { return type(cast_op<Tuple>(std::get<Is>(subcasters))...); } type implicit_cast(index_sequence<Is...>) & { return type(cast_op<Ts>(std::get<Is>(subcasters))...); }
template <size_t... Is> template <size_t... Is>
type implicit_cast(index_sequence<Is...>) && { return type(cast_op<Tuple>(std::move(std::get<Is>(subcasters)))...); } type implicit_cast(index_sequence<Is...>) && { return type(cast_op<Ts>(std::move(std::get<Is>(subcasters)))...); }
static constexpr bool load_impl(const sequence &, bool, index_sequence<>) { return true; } static constexpr bool load_impl(const sequence &, bool, index_sequence<>) { return true; }
@ -1301,10 +1302,10 @@ protected:
} }
/* Implementation: Convert a C++ tuple into a Python tuple */ /* Implementation: Convert a C++ tuple into a Python tuple */
template <size_t... Is> template <typename T, size_t... Is>
static handle cast_impl(const type &src, return_value_policy policy, handle parent, index_sequence<Is...>) { static handle cast_impl(T &&src, return_value_policy policy, handle parent, index_sequence<Is...>) {
std::array<object, size> entries {{ std::array<object, size> entries{{
reinterpret_steal<object>(make_caster<Tuple>::cast(std::get<Is>(src), policy, parent))... reinterpret_steal<object>(make_caster<Ts>::cast(std::get<Is>(std::forward<T>(src)), policy, parent))...
}}; }};
for (const auto &entry: entries) for (const auto &entry: entries)
if (!entry) if (!entry)
@ -1316,14 +1317,14 @@ protected:
return result.release(); return result.release();
} }
TupleType<make_caster<Tuple>...> subcasters; Tuple<make_caster<Ts>...> subcasters;
}; };
template <typename T1, typename T2> class type_caster<std::pair<T1, T2>> template <typename T1, typename T2> class type_caster<std::pair<T1, T2>>
: public tuple_caster<std::pair, T1, T2> {}; : public tuple_caster<std::pair, T1, T2> {};
template <typename... Tuple> class type_caster<std::tuple<Tuple...>> template <typename... Ts> class type_caster<std::tuple<Ts...>>
: public tuple_caster<std::tuple, Tuple...> {}; : public tuple_caster<std::tuple, Ts...> {};
/// Helper class which abstracts away certain actions. Users can provide specializations for /// Helper class which abstracts away certain actions. Users can provide specializations for
/// custom holders, but it's only necessary if the type has a non-standard interface. /// custom holders, but it's only necessary if the type has a non-standard interface.

View File

@ -49,6 +49,19 @@
NAMESPACE_BEGIN(pybind11) NAMESPACE_BEGIN(pybind11)
NAMESPACE_BEGIN(detail) NAMESPACE_BEGIN(detail)
/// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for
/// forwarding a container element). Typically used indirect via forwarded_type(), below.
template <typename T, typename U>
using forwarded_type = conditional_t<
std::is_lvalue_reference<T>::value, remove_reference_t<U> &, remove_reference_t<U> &&>;
/// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically
/// used for forwarding a container's elements.
template <typename T, typename U>
forwarded_type<T, U> forward_like(U &&u) {
return std::forward<detail::forwarded_type<T, U>>(std::forward<U>(u));
}
template <typename Type, typename Key> struct set_caster { template <typename Type, typename Key> struct set_caster {
using type = Type; using type = Type;
using key_conv = make_caster<Key>; using key_conv = make_caster<Key>;
@ -67,10 +80,11 @@ template <typename Type, typename Key> struct set_caster {
return true; return true;
} }
static handle cast(const type &src, return_value_policy policy, handle parent) { template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
pybind11::set s; pybind11::set s;
for (auto const &value: src) { for (auto &value: src) {
auto value_ = reinterpret_steal<object>(key_conv::cast(value, policy, parent)); auto value_ = reinterpret_steal<object>(key_conv::cast(forward_like<T>(value), policy, parent));
if (!value_ || !s.add(value_)) if (!value_ || !s.add(value_))
return handle(); return handle();
} }
@ -100,11 +114,12 @@ template <typename Type, typename Key, typename Value> struct map_caster {
return true; return true;
} }
static handle cast(const Type &src, return_value_policy policy, handle parent) { template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
dict d; dict d;
for (auto const &kv: src) { for (auto &kv: src) {
auto key = reinterpret_steal<object>(key_conv::cast(kv.first, policy, parent)); auto key = reinterpret_steal<object>(key_conv::cast(forward_like<T>(kv.first), policy, parent));
auto value = reinterpret_steal<object>(value_conv::cast(kv.second, policy, parent)); auto value = reinterpret_steal<object>(value_conv::cast(forward_like<T>(kv.second), policy, parent));
if (!key || !value) if (!key || !value)
return handle(); return handle();
d[key] = value; d[key] = value;
@ -140,11 +155,12 @@ private:
void reserve_maybe(sequence, void *) { } void reserve_maybe(sequence, void *) { }
public: public:
static handle cast(const Type &src, return_value_policy policy, handle parent) { template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
list l(src.size()); list l(src.size());
size_t index = 0; size_t index = 0;
for (auto const &value: src) { for (auto &value: src) {
auto value_ = reinterpret_steal<object>(value_conv::cast(value, policy, parent)); auto value_ = reinterpret_steal<object>(value_conv::cast(forward_like<T>(value), policy, parent));
if (!value_) if (!value_)
return handle(); return handle();
PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference
@ -193,11 +209,12 @@ public:
return true; return true;
} }
static handle cast(const ArrayType &src, return_value_policy policy, handle parent) { template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
list l(src.size()); list l(src.size());
size_t index = 0; size_t index = 0;
for (auto const &value: src) { for (auto &value: src) {
auto value_ = reinterpret_steal<object>(value_conv::cast(value, policy, parent)); auto value_ = reinterpret_steal<object>(value_conv::cast(forward_like<T>(value), policy, parent));
if (!value_) if (!value_)
return handle(); return handle();
PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference
@ -230,10 +247,11 @@ template <typename Key, typename Value, typename Hash, typename Equal, typename
template<typename T> struct optional_caster { template<typename T> struct optional_caster {
using value_conv = make_caster<typename T::value_type>; using value_conv = make_caster<typename T::value_type>;
static handle cast(const T& src, return_value_policy policy, handle parent) { template <typename T_>
static handle cast(T_ &&src, return_value_policy policy, handle parent) {
if (!src) if (!src)
return none().inc_ref(); return none().inc_ref();
return value_conv::cast(*src, policy, parent); return value_conv::cast(*std::forward<T_>(src), policy, parent);
} }
bool load(handle src, bool convert) { bool load(handle src, bool convert) {

View File

@ -1,6 +1,11 @@
#pragma once #pragma once
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#if defined(_MSC_VER) && _MSC_VER < 1910
// We get some really long type names here which causes MSVC 2015 to emit warnings
# pragma warning(disable: 4503) // warning C4503: decorated name length exceeded, name was truncated
#endif
namespace py = pybind11; namespace py = pybind11;
using namespace pybind11::literals; using namespace pybind11::literals;
@ -43,3 +48,17 @@ public:
IncType &operator=(const IncType &) = delete; IncType &operator=(const IncType &) = delete;
IncType &operator=(IncType &&) = delete; IncType &operator=(IncType &&) = delete;
}; };
/// Custom cast-only type that casts to a string "rvalue" or "lvalue" depending on the cast context.
/// Used to test recursive casters (e.g. std::tuple, stl containers).
struct RValueCaster {};
NAMESPACE_BEGIN(pybind11)
NAMESPACE_BEGIN(detail)
template<> class type_caster<RValueCaster> {
public:
PYBIND11_TYPE_CASTER(RValueCaster, _("RValueCaster"));
static handle cast(RValueCaster &&, return_value_policy, handle) { return py::str("rvalue").release(); }
static handle cast(const RValueCaster &, return_value_policy, handle) { return py::str("lvalue").release(); }
};
NAMESPACE_END(detail)
NAMESPACE_END(pybind11)

View File

@ -86,6 +86,16 @@ TEST_SUBMODULE(builtin_casters, m) {
return std::make_tuple(std::get<2>(input), std::get<1>(input), std::get<0>(input)); return std::make_tuple(std::get<2>(input), std::get<1>(input), std::get<0>(input));
}, "Return a triple in reversed order"); }, "Return a triple in reversed order");
m.def("empty_tuple", []() { return std::tuple<>(); }); m.def("empty_tuple", []() { return std::tuple<>(); });
static std::pair<RValueCaster, RValueCaster> lvpair;
static std::tuple<RValueCaster, RValueCaster, RValueCaster> lvtuple;
static std::pair<RValueCaster, std::tuple<RValueCaster, std::pair<RValueCaster, RValueCaster>>> lvnested;
m.def("rvalue_pair", []() { return std::make_pair(RValueCaster{}, RValueCaster{}); });
m.def("lvalue_pair", []() -> const decltype(lvpair) & { return lvpair; });
m.def("rvalue_tuple", []() { return std::make_tuple(RValueCaster{}, RValueCaster{}, RValueCaster{}); });
m.def("lvalue_tuple", []() -> const decltype(lvtuple) & { return lvtuple; });
m.def("rvalue_nested", []() {
return std::make_pair(RValueCaster{}, std::make_tuple(RValueCaster{}, std::make_pair(RValueCaster{}, RValueCaster{}))); });
m.def("lvalue_nested", []() -> const decltype(lvnested) & { return lvnested; });
// test_builtins_cast_return_none // test_builtins_cast_return_none
m.def("return_none_string", []() -> std::string * { return nullptr; }); m.def("return_none_string", []() -> std::string * { return nullptr; });

View File

@ -201,6 +201,13 @@ def test_tuple(doc):
Return a triple in reversed order Return a triple in reversed order
""" """
assert m.rvalue_pair() == ("rvalue", "rvalue")
assert m.lvalue_pair() == ("lvalue", "lvalue")
assert m.rvalue_tuple() == ("rvalue", "rvalue", "rvalue")
assert m.lvalue_tuple() == ("lvalue", "lvalue", "lvalue")
assert m.rvalue_nested() == ("rvalue", ("rvalue", ("rvalue", "rvalue")))
assert m.lvalue_nested() == ("lvalue", ("lvalue", ("lvalue", "lvalue")))
def test_builtins_cast_return_none(): def test_builtins_cast_return_none():
"""Casters produced with PYBIND11_TYPE_CASTER() should convert nullptr to None""" """Casters produced with PYBIND11_TYPE_CASTER() should convert nullptr to None"""

View File

@ -61,6 +61,46 @@ TEST_SUBMODULE(stl, m) {
return set.count("key1") && set.count("key2") && set.count("key3"); return set.count("key1") && set.count("key2") && set.count("key3");
}); });
// test_recursive_casting
m.def("cast_rv_vector", []() { return std::vector<RValueCaster>{2}; });
m.def("cast_rv_array", []() { return std::array<RValueCaster, 3>(); });
// NB: map and set keys are `const`, so while we technically do move them (as `const Type &&`),
// casters don't typically do anything with that, which means they fall to the `const Type &`
// caster.
m.def("cast_rv_map", []() { return std::unordered_map<std::string, RValueCaster>{{"a", RValueCaster{}}}; });
m.def("cast_rv_nested", []() {
std::vector<std::array<std::list<std::unordered_map<std::string, RValueCaster>>, 2>> v;
v.emplace_back(); // add an array
v.back()[0].emplace_back(); // add a map to the array
v.back()[0].back().emplace("b", RValueCaster{});
v.back()[0].back().emplace("c", RValueCaster{});
v.back()[1].emplace_back(); // add a map to the array
v.back()[1].back().emplace("a", RValueCaster{});
return v;
});
static std::vector<RValueCaster> lvv{2};
static std::array<RValueCaster, 2> lva;
static std::unordered_map<std::string, RValueCaster> lvm{{"a", RValueCaster{}}, {"b", RValueCaster{}}};
static std::unordered_map<std::string, std::vector<std::list<std::array<RValueCaster, 2>>>> lvn;
lvn["a"].emplace_back(); // add a list
lvn["a"].back().emplace_back(); // add an array
lvn["a"].emplace_back(); // another list
lvn["a"].back().emplace_back(); // add an array
lvn["b"].emplace_back(); // add a list
lvn["b"].back().emplace_back(); // add an array
lvn["b"].back().emplace_back(); // add another array
m.def("cast_lv_vector", []() -> const decltype(lvv) & { return lvv; });
m.def("cast_lv_array", []() -> const decltype(lva) & { return lva; });
m.def("cast_lv_map", []() -> const decltype(lvm) & { return lvm; });
m.def("cast_lv_nested", []() -> const decltype(lvn) & { return lvn; });
// #853:
m.def("cast_unique_ptr_vector", []() {
std::vector<std::unique_ptr<UserType>> v;
v.emplace_back(new UserType{7});
v.emplace_back(new UserType{42});
return v;
});
struct MoveOutContainer { struct MoveOutContainer {
struct Value { int value; }; struct Value { int value; };

View File

@ -58,6 +58,25 @@ def test_set(doc):
assert doc(m.load_set) == "load_set(arg0: Set[str]) -> bool" assert doc(m.load_set) == "load_set(arg0: Set[str]) -> bool"
def test_recursive_casting():
"""Tests that stl casters preserve lvalue/rvalue context for container values"""
assert m.cast_rv_vector() == ["rvalue", "rvalue"]
assert m.cast_lv_vector() == ["lvalue", "lvalue"]
assert m.cast_rv_array() == ["rvalue", "rvalue", "rvalue"]
assert m.cast_lv_array() == ["lvalue", "lvalue"]
assert m.cast_rv_map() == {"a": "rvalue"}
assert m.cast_lv_map() == {"a": "lvalue", "b": "lvalue"}
assert m.cast_rv_nested() == [[[{"b": "rvalue", "c": "rvalue"}], [{"a": "rvalue"}]]]
assert m.cast_lv_nested() == {
"a": [[["lvalue", "lvalue"]], [["lvalue", "lvalue"]]],
"b": [[["lvalue", "lvalue"], ["lvalue", "lvalue"]]]
}
# Issue #853 test case:
z = m.cast_unique_ptr_vector()
assert z[0].value == 7 and z[1].value == 42
def test_move_out_container(): def test_move_out_container():
"""Properties use the `reference_internal` policy by default. If the underlying function """Properties use the `reference_internal` policy by default. If the underlying function
returns an rvalue, the policy is automatically changed to `move` to avoid referencing returns an rvalue, the policy is automatically changed to `move` to avoid referencing

View File

@ -15,11 +15,6 @@
#include <deque> #include <deque>
#include <unordered_map> #include <unordered_map>
#ifdef _MSC_VER
// We get some really long type names here which causes MSVC to emit warnings
# pragma warning(disable: 4503) // warning C4503: decorated name length exceeded, name was truncated
#endif
class El { class El {
public: public:
El() = delete; El() = delete;