From 58802de41bc9c78425b66c3b6f22392241aac8de Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Wed, 1 Jun 2022 15:19:13 -0400 Subject: [PATCH] perf: Add object rvalue overload for accessors. Enables reference stealing (#3970) * Add object rvalue overload for accessors. Enables reference stealing * Fix comments * Fix more comment typos * Fix bug * reorder declarations for clarity * fix another perf bug * should be static * future proof operator overloads * Fix perfect forwarding * Add a couple of tests * Remove errant include * Improve test documentation * Add dict test * add object attr tests * Optimize STL map caster and cleanup enum * Reorder to match declarations * adjust increfs * Remove comment * revert value change * add missing move --- include/pybind11/pybind11.h | 6 +++--- include/pybind11/pytypes.h | 36 ++++++++++++++++++++++++++++++------ include/pybind11/stl.h | 2 +- tests/test_pytypes.cpp | 34 ++++++++++++++++++++++++++++++++++ tests/test_pytypes.py | 31 +++++++++++++++++++++++++++++-- 5 files changed, 97 insertions(+), 12 deletions(-) diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index be206e1a6..cfa442067 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -2069,12 +2069,12 @@ struct enum_base { str name(name_); if (entries.contains(name)) { std::string type_name = (std::string) str(m_base.attr("__name__")); - throw value_error(type_name + ": element \"" + std::string(name_) + throw value_error(std::move(type_name) + ": element \"" + std::string(name_) + "\" already exists!"); } entries[name] = std::make_pair(value, doc); - m_base.attr(name) = value; + m_base.attr(std::move(name)) = std::move(value); } PYBIND11_NOINLINE void export_values() { @@ -2610,7 +2610,7 @@ PYBIND11_NOINLINE void print(const tuple &args, const dict &kwargs) { } auto write = file.attr("write"); - write(line); + write(std::move(line)); write(kwargs.contains("end") ? kwargs["end"] : str("\n")); if (kwargs.contains("flush") && kwargs["flush"].cast()) { diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index 0c5690064..27807953b 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -85,7 +85,9 @@ public: or `object` subclass causes a call to ``__setitem__``. \endrst */ item_accessor operator[](handle key) const; - /// See above (the only difference is that they key is provided as a string literal) + /// See above (the only difference is that the key's reference is stolen) + item_accessor operator[](object &&key) const; + /// See above (the only difference is that the key is provided as a string literal) item_accessor operator[](const char *key) const; /** \rst @@ -95,7 +97,9 @@ public: or `object` subclass causes a call to ``setattr``. \endrst */ obj_attr_accessor attr(handle key) const; - /// See above (the only difference is that they key is provided as a string literal) + /// See above (the only difference is that the key's reference is stolen) + obj_attr_accessor attr(object &&key) const; + /// See above (the only difference is that the key is provided as a string literal) str_attr_accessor attr(const char *key) const; /** \rst @@ -684,7 +688,7 @@ public: } template void operator=(T &&value) & { - get_cache() = reinterpret_borrow(object_or_cast(std::forward(value))); + get_cache() = ensure_object(object_or_cast(std::forward(value))); } template @@ -712,6 +716,9 @@ public: } private: + static object ensure_object(object &&o) { return std::move(o); } + static object ensure_object(handle h) { return reinterpret_borrow(h); } + object &get_cache() const { if (!cache) { cache = Policy::get(obj, key); @@ -1711,7 +1718,10 @@ public: size_t size() const { return (size_t) PyTuple_Size(m_ptr); } bool empty() const { return size() == 0; } detail::tuple_accessor operator[](size_t index) const { return {*this, index}; } - detail::item_accessor operator[](handle h) const { return object::operator[](h); } + template ::value, int> = 0> + detail::item_accessor operator[](T &&o) const { + return object::operator[](std::forward(o)); + } detail::tuple_iterator begin() const { return {*this, 0}; } detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; } }; @@ -1771,7 +1781,10 @@ public: } bool empty() const { return size() == 0; } detail::sequence_accessor operator[](size_t index) const { return {*this, index}; } - detail::item_accessor operator[](handle h) const { return object::operator[](h); } + template ::value, int> = 0> + detail::item_accessor operator[](T &&o) const { + return object::operator[](std::forward(o)); + } detail::sequence_iterator begin() const { return {*this, 0}; } detail::sequence_iterator end() const { return {*this, PySequence_Size(m_ptr)}; } }; @@ -1790,7 +1803,10 @@ public: size_t size() const { return (size_t) PyList_Size(m_ptr); } bool empty() const { return size() == 0; } detail::list_accessor operator[](size_t index) const { return {*this, index}; } - detail::item_accessor operator[](handle h) const { return object::operator[](h); } + template ::value, int> = 0> + detail::item_accessor operator[](T &&o) const { + return object::operator[](std::forward(o)); + } detail::list_iterator begin() const { return {*this, 0}; } detail::list_iterator end() const { return {*this, PyList_GET_SIZE(m_ptr)}; } template @@ -2090,6 +2106,10 @@ item_accessor object_api::operator[](handle key) const { return {derived(), reinterpret_borrow(key)}; } template +item_accessor object_api::operator[](object &&key) const { + return {derived(), std::move(key)}; +} +template item_accessor object_api::operator[](const char *key) const { return {derived(), pybind11::str(key)}; } @@ -2098,6 +2118,10 @@ obj_attr_accessor object_api::attr(handle key) const { return {derived(), reinterpret_borrow(key)}; } template +obj_attr_accessor object_api::attr(object &&key) const { + return {derived(), std::move(key)}; +} +template str_attr_accessor object_api::attr(const char *key) const { return {derived(), key}; } diff --git a/include/pybind11/stl.h b/include/pybind11/stl.h index 597bce61d..ab30ecac0 100644 --- a/include/pybind11/stl.h +++ b/include/pybind11/stl.h @@ -128,7 +128,7 @@ struct map_caster { if (!key || !value) { return handle(); } - d[key] = value; + d[std::move(key)] = std::move(value); } return d.release(); } diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index ef214acc3..cb81007c3 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -661,4 +661,38 @@ TEST_SUBMODULE(pytypes, m) { double v = x.get_value(); return v * v; }); + + m.def("tuple_rvalue_getter", [](const py::tuple &tup) { + // tests accessing tuple object with rvalue int + for (size_t i = 0; i < tup.size(); i++) { + auto o = py::handle(tup[py::int_(i)]); + if (!o) { + throw py::value_error("tuple is malformed"); + } + } + return tup; + }); + m.def("list_rvalue_getter", [](const py::list &l) { + // tests accessing list with rvalue int + for (size_t i = 0; i < l.size(); i++) { + auto o = py::handle(l[py::int_(i)]); + if (!o) { + throw py::value_error("list is malformed"); + } + } + return l; + }); + m.def("populate_dict_rvalue", [](int population) { + auto d = py::dict(); + for (int i = 0; i < population; i++) { + d[py::int_(i)] = py::int_(i); + } + return d; + }); + m.def("populate_obj_str_attrs", [](py::object &o, int population) { + for (int i = 0; i < population; i++) { + o.attr(py::str(py::int_(i))) = py::str(py::int_(i)); + } + return o; + }); } diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index b91a7e156..3e9d51a27 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -1,5 +1,6 @@ import contextlib import sys +import types import pytest @@ -320,8 +321,7 @@ def test_accessors(): def test_accessor_moves(): inc_refs = m.accessor_moves() if inc_refs: - # To be changed in PR #3970: [1, 0, 1, 0, ...] - assert inc_refs == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + assert inc_refs == [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0] else: pytest.skip("Not defined: PYBIND11_HANDLE_REF_DEBUG") @@ -707,3 +707,30 @@ def test_implementation_details(): def test_external_float_(): r1 = m.square_float_(2.0) assert r1 == 4.0 + + +def test_tuple_rvalue_getter(): + pop = 1000 + tup = tuple(range(pop)) + m.tuple_rvalue_getter(tup) + + +def test_list_rvalue_getter(): + pop = 1000 + my_list = list(range(pop)) + m.list_rvalue_getter(my_list) + + +def test_populate_dict_rvalue(): + pop = 1000 + my_dict = {i: i for i in range(pop)} + assert m.populate_dict_rvalue(pop) == my_dict + + +def test_populate_obj_str_attrs(): + pop = 1000 + o = types.SimpleNamespace(**{str(i): i for i in range(pop)}) + new_o = m.populate_obj_str_attrs(o, pop) + new_attrs = {k: v for k, v in new_o.__dict__.items() if not k.startswith("_")} + assert all(isinstance(v, str) for v in new_attrs.values()) + assert len(new_attrs) == pop