perf: Add object rvalue overload for accessors. Enables reference stealing (#3970)

* Add object rvalue overload for accessors. Enables reference stealing

* Fix comments

* Fix more comment typos

* Fix bug

* reorder declarations for clarity

* fix another perf bug

* should be static

* future proof operator overloads

* Fix perfect forwarding

* Add a couple of tests

* Remove errant include

* Improve test documentation

* Add dict test

* add object attr tests

* Optimize STL map caster and cleanup enum

* Reorder to match declarations

* adjust increfs

* Remove comment

* revert value change

* add missing move
This commit is contained in:
Aaron Gokaslan 2022-06-01 15:19:13 -04:00 committed by GitHub
parent 9f7b3f735a
commit 58802de41b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 97 additions and 12 deletions

View File

@ -2069,12 +2069,12 @@ struct enum_base {
str name(name_); str name(name_);
if (entries.contains(name)) { if (entries.contains(name)) {
std::string type_name = (std::string) str(m_base.attr("__name__")); std::string type_name = (std::string) str(m_base.attr("__name__"));
throw value_error(type_name + ": element \"" + std::string(name_) throw value_error(std::move(type_name) + ": element \"" + std::string(name_)
+ "\" already exists!"); + "\" already exists!");
} }
entries[name] = std::make_pair(value, doc); entries[name] = std::make_pair(value, doc);
m_base.attr(name) = value; m_base.attr(std::move(name)) = std::move(value);
} }
PYBIND11_NOINLINE void export_values() { PYBIND11_NOINLINE void export_values() {
@ -2610,7 +2610,7 @@ PYBIND11_NOINLINE void print(const tuple &args, const dict &kwargs) {
} }
auto write = file.attr("write"); auto write = file.attr("write");
write(line); write(std::move(line));
write(kwargs.contains("end") ? kwargs["end"] : str("\n")); write(kwargs.contains("end") ? kwargs["end"] : str("\n"));
if (kwargs.contains("flush") && kwargs["flush"].cast<bool>()) { if (kwargs.contains("flush") && kwargs["flush"].cast<bool>()) {

View File

@ -85,7 +85,9 @@ public:
or `object` subclass causes a call to ``__setitem__``. or `object` subclass causes a call to ``__setitem__``.
\endrst */ \endrst */
item_accessor operator[](handle key) const; item_accessor operator[](handle key) const;
/// See above (the only difference is that they key is provided as a string literal) /// See above (the only difference is that the key's reference is stolen)
item_accessor operator[](object &&key) const;
/// See above (the only difference is that the key is provided as a string literal)
item_accessor operator[](const char *key) const; item_accessor operator[](const char *key) const;
/** \rst /** \rst
@ -95,7 +97,9 @@ public:
or `object` subclass causes a call to ``setattr``. or `object` subclass causes a call to ``setattr``.
\endrst */ \endrst */
obj_attr_accessor attr(handle key) const; obj_attr_accessor attr(handle key) const;
/// See above (the only difference is that they key is provided as a string literal) /// See above (the only difference is that the key's reference is stolen)
obj_attr_accessor attr(object &&key) const;
/// See above (the only difference is that the key is provided as a string literal)
str_attr_accessor attr(const char *key) const; str_attr_accessor attr(const char *key) const;
/** \rst /** \rst
@ -684,7 +688,7 @@ public:
} }
template <typename T> template <typename T>
void operator=(T &&value) & { void operator=(T &&value) & {
get_cache() = reinterpret_borrow<object>(object_or_cast(std::forward<T>(value))); get_cache() = ensure_object(object_or_cast(std::forward<T>(value)));
} }
template <typename T = Policy> template <typename T = Policy>
@ -712,6 +716,9 @@ public:
} }
private: private:
static object ensure_object(object &&o) { return std::move(o); }
static object ensure_object(handle h) { return reinterpret_borrow<object>(h); }
object &get_cache() const { object &get_cache() const {
if (!cache) { if (!cache) {
cache = Policy::get(obj, key); cache = Policy::get(obj, key);
@ -1711,7 +1718,10 @@ public:
size_t size() const { return (size_t) PyTuple_Size(m_ptr); } size_t size() const { return (size_t) PyTuple_Size(m_ptr); }
bool empty() const { return size() == 0; } bool empty() const { return size() == 0; }
detail::tuple_accessor operator[](size_t index) const { return {*this, index}; } detail::tuple_accessor operator[](size_t index) const { return {*this, index}; }
detail::item_accessor operator[](handle h) const { return object::operator[](h); } template <typename T, detail::enable_if_t<detail::is_pyobject<T>::value, int> = 0>
detail::item_accessor operator[](T &&o) const {
return object::operator[](std::forward<T>(o));
}
detail::tuple_iterator begin() const { return {*this, 0}; } detail::tuple_iterator begin() const { return {*this, 0}; }
detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; } detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; }
}; };
@ -1771,7 +1781,10 @@ public:
} }
bool empty() const { return size() == 0; } bool empty() const { return size() == 0; }
detail::sequence_accessor operator[](size_t index) const { return {*this, index}; } detail::sequence_accessor operator[](size_t index) const { return {*this, index}; }
detail::item_accessor operator[](handle h) const { return object::operator[](h); } template <typename T, detail::enable_if_t<detail::is_pyobject<T>::value, int> = 0>
detail::item_accessor operator[](T &&o) const {
return object::operator[](std::forward<T>(o));
}
detail::sequence_iterator begin() const { return {*this, 0}; } detail::sequence_iterator begin() const { return {*this, 0}; }
detail::sequence_iterator end() const { return {*this, PySequence_Size(m_ptr)}; } detail::sequence_iterator end() const { return {*this, PySequence_Size(m_ptr)}; }
}; };
@ -1790,7 +1803,10 @@ public:
size_t size() const { return (size_t) PyList_Size(m_ptr); } size_t size() const { return (size_t) PyList_Size(m_ptr); }
bool empty() const { return size() == 0; } bool empty() const { return size() == 0; }
detail::list_accessor operator[](size_t index) const { return {*this, index}; } detail::list_accessor operator[](size_t index) const { return {*this, index}; }
detail::item_accessor operator[](handle h) const { return object::operator[](h); } template <typename T, detail::enable_if_t<detail::is_pyobject<T>::value, int> = 0>
detail::item_accessor operator[](T &&o) const {
return object::operator[](std::forward<T>(o));
}
detail::list_iterator begin() const { return {*this, 0}; } detail::list_iterator begin() const { return {*this, 0}; }
detail::list_iterator end() const { return {*this, PyList_GET_SIZE(m_ptr)}; } detail::list_iterator end() const { return {*this, PyList_GET_SIZE(m_ptr)}; }
template <typename T> template <typename T>
@ -2090,6 +2106,10 @@ item_accessor object_api<D>::operator[](handle key) const {
return {derived(), reinterpret_borrow<object>(key)}; return {derived(), reinterpret_borrow<object>(key)};
} }
template <typename D> template <typename D>
item_accessor object_api<D>::operator[](object &&key) const {
return {derived(), std::move(key)};
}
template <typename D>
item_accessor object_api<D>::operator[](const char *key) const { item_accessor object_api<D>::operator[](const char *key) const {
return {derived(), pybind11::str(key)}; return {derived(), pybind11::str(key)};
} }
@ -2098,6 +2118,10 @@ obj_attr_accessor object_api<D>::attr(handle key) const {
return {derived(), reinterpret_borrow<object>(key)}; return {derived(), reinterpret_borrow<object>(key)};
} }
template <typename D> template <typename D>
obj_attr_accessor object_api<D>::attr(object &&key) const {
return {derived(), std::move(key)};
}
template <typename D>
str_attr_accessor object_api<D>::attr(const char *key) const { str_attr_accessor object_api<D>::attr(const char *key) const {
return {derived(), key}; return {derived(), key};
} }

View File

@ -128,7 +128,7 @@ struct map_caster {
if (!key || !value) { if (!key || !value) {
return handle(); return handle();
} }
d[key] = value; d[std::move(key)] = std::move(value);
} }
return d.release(); return d.release();
} }

View File

@ -661,4 +661,38 @@ TEST_SUBMODULE(pytypes, m) {
double v = x.get_value(); double v = x.get_value();
return v * v; return v * v;
}); });
m.def("tuple_rvalue_getter", [](const py::tuple &tup) {
// tests accessing tuple object with rvalue int
for (size_t i = 0; i < tup.size(); i++) {
auto o = py::handle(tup[py::int_(i)]);
if (!o) {
throw py::value_error("tuple is malformed");
}
}
return tup;
});
m.def("list_rvalue_getter", [](const py::list &l) {
// tests accessing list with rvalue int
for (size_t i = 0; i < l.size(); i++) {
auto o = py::handle(l[py::int_(i)]);
if (!o) {
throw py::value_error("list is malformed");
}
}
return l;
});
m.def("populate_dict_rvalue", [](int population) {
auto d = py::dict();
for (int i = 0; i < population; i++) {
d[py::int_(i)] = py::int_(i);
}
return d;
});
m.def("populate_obj_str_attrs", [](py::object &o, int population) {
for (int i = 0; i < population; i++) {
o.attr(py::str(py::int_(i))) = py::str(py::int_(i));
}
return o;
});
} }

View File

@ -1,5 +1,6 @@
import contextlib import contextlib
import sys import sys
import types
import pytest import pytest
@ -320,8 +321,7 @@ def test_accessors():
def test_accessor_moves(): def test_accessor_moves():
inc_refs = m.accessor_moves() inc_refs = m.accessor_moves()
if inc_refs: if inc_refs:
# To be changed in PR #3970: [1, 0, 1, 0, ...] assert inc_refs == [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
assert inc_refs == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
else: else:
pytest.skip("Not defined: PYBIND11_HANDLE_REF_DEBUG") pytest.skip("Not defined: PYBIND11_HANDLE_REF_DEBUG")
@ -707,3 +707,30 @@ def test_implementation_details():
def test_external_float_(): def test_external_float_():
r1 = m.square_float_(2.0) r1 = m.square_float_(2.0)
assert r1 == 4.0 assert r1 == 4.0
def test_tuple_rvalue_getter():
pop = 1000
tup = tuple(range(pop))
m.tuple_rvalue_getter(tup)
def test_list_rvalue_getter():
pop = 1000
my_list = list(range(pop))
m.list_rvalue_getter(my_list)
def test_populate_dict_rvalue():
pop = 1000
my_dict = {i: i for i in range(pop)}
assert m.populate_dict_rvalue(pop) == my_dict
def test_populate_obj_str_attrs():
pop = 1000
o = types.SimpleNamespace(**{str(i): i for i in range(pop)})
new_o = m.populate_obj_str_attrs(o, pop)
new_attrs = {k: v for k, v in new_o.__dict__.items() if not k.startswith("_")}
assert all(isinstance(v, str) for v in new_attrs.values())
assert len(new_attrs) == pop