diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index 1e76d7bc1..9f4196d21 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -95,11 +95,8 @@ public: subclass causes a corresponding call to ``__getitem__``. Assigning a `handle` or `object` subclass causes a call to ``__setitem__``. \endrst */ - item_accessor operator[](handle key) const; - /// See above (the only difference is that the key's reference is stolen) - item_accessor operator[](object &&key) const; - /// See above (the only difference is that the key is provided as a string literal) - item_accessor operator[](const char *key) const; + template + item_accessor operator[](T &&key) const; /** \rst Return an internal functor to access the object's attributes. Casting the @@ -2493,16 +2490,9 @@ iterator object_api::end() const { return iterator::sentinel(); } template -item_accessor object_api::operator[](handle key) const { - return {derived(), reinterpret_borrow(key)}; -} -template -item_accessor object_api::operator[](object &&key) const { - return {derived(), std::move(key)}; -} -template -item_accessor object_api::operator[](const char *key) const { - return {derived(), pybind11::str(key)}; +template +item_accessor object_api::operator[](T &&key) const { + return {derived(), detail::object_or_cast(std::forward(key)).template cast()}; } template obj_attr_accessor object_api::attr(handle key) const { diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index 19f65ce7e..9e83af2d0 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -179,6 +179,20 @@ TEST_SUBMODULE(pytypes, m) { py::print("list item {}: {}"_s.format(index++, item)); } }); + m.def("access_list", []() { + py::list l1 = py::list(); + l1.append(1); + l1.append(2); + return l1[1]; + }); + m.def("access_list_as_object", []() { + py::list l1 = py::list(); + l1.append(1); + l1.append(2); + py::object l2 = std::move(l1); + return l2[1]; + }); + // test_none m.def("get_none", [] { return py::none(); }); m.def("print_none", [](const py::none &none) { py::print("none: {}"_s.format(none)); }); @@ -228,6 +242,28 @@ TEST_SUBMODULE(pytypes, m) { [](const py::dict &dict, const py::object &val) { return dict.contains(val); }); m.def("dict_contains", [](const py::dict &dict, const char *val) { return dict.contains(val); }); + m.def("access_dict_with_str", []() { + py::dict d1 = py::dict(); + d1["x"] = 1; + return d1["x"]; + }); + m.def("access_dict_with_int", []() { + py::dict d1 = py::dict(); + d1[1] = 1; + return d1[1]; + }); + m.def("access_dict_as_object_with_str", []() { + py::dict d1 = py::dict(); + d1["x"] = 1; + py::object d2 = std::move(d1); + return d2["x"]; + }); + m.def("access_dict_as_object_with_int", []() { + py::dict d1 = py::dict(); + d1[1] = 1; + py::object d2 = std::move(d1); + return d2[1]; + }); // test_tuple m.def("tuple_no_args", []() { return py::tuple{}; }); @@ -235,6 +271,16 @@ TEST_SUBMODULE(pytypes, m) { m.def("tuple_size_t", []() { return py::tuple{(py::size_t) 0}; }); m.def("get_tuple", []() { return py::make_tuple(42, py::none(), "spam"); }); + m.def("access_tuple", [](py::tuple &tpl) { + return tpl[1]; + }); + m.def("access_tuple_as_object_with_int_index", [](py::object &tpl) { + return tpl[1]; + }); + m.def("access_tuple_as_object_with_int_index_multidimension", [](py::object &tpl) { + return tpl[1][2]; + }); + // test_simple_namespace m.def("get_simple_namespace", []() { auto ns = py::module_::import("types").attr("SimpleNamespace")( diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index 6f015eec8..397a8dc72 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -90,6 +90,8 @@ def test_list(capture, doc): assert doc(m.get_list) == "get_list() -> list" assert doc(m.print_list) == "print_list(arg0: list) -> None" + assert m.access_list() == 2 + assert m.access_list_as_object() == 2 def test_none(doc): assert doc(m.get_none) == "get_none() -> None" @@ -177,6 +179,11 @@ def test_dict(capture, doc): assert m.dict_keyword_constructor() == {"x": 1, "y": 2, "z": 3} + assert m.access_dict_with_str() == 1 + assert m.access_dict_with_int() == 1 + assert m.access_dict_as_object_with_str() == 1 + assert m.access_dict_as_object_with_int() == 1 + class CustomContains: d = {"key": None} @@ -208,6 +215,9 @@ def test_tuple(): assert m.tuple_ssize_t() == () assert m.tuple_size_t() == () assert m.get_tuple() == (42, None, "spam") + assert m.access_tuple((1,2)) == 2 + assert m.access_tuple_as_object_with_int_index((1,2)) == 2 + assert m.access_tuple_as_object_with_int_index_multidimension(((1,2,3),(4,5,6))) == 6 def test_simple_namespace():