From 0d44d720cb1b911de2d96849860d0c5beb368f95 Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Wed, 14 Aug 2024 01:42:51 +0700 Subject: [PATCH] Make stl.h `list|set|map_caster` more user friendly. (#4686) * Add `test_pass_std_vector_int()`, `test_pass_std_set_int()` in test_stl * Change `list_caster` to also accept generator objects (`PyGen_Check(src.ptr()`). Note for completeness: This is a more conservative change than https://github.com/google/pywrapcc/pull/30042 * Drop in (currently unpublished) PyCLIF code, use in `list_caster`, adjust tests. * Use `PyObjectTypeIsConvertibleToStdSet()` in `set_caster`, adjust tests. * Use `PyObjectTypeIsConvertibleToStdMap()` in `map_caster`, add tests. * Simplify `list_caster` `load()` implementation, push str/bytes check into `PyObjectTypeIsConvertibleToStdVector()`. * clang-tidy cleanup with a few extra `(... != 0)` to be more consistent. * Also use `PyObjectTypeIsConvertibleToStdVector()` in `array_caster`. * Update comment pointing to clif/python/runtime.cc (code is unchanged). * Comprehensive test coverage, enhanced set_caster load implementation. * Resolve clang-tidy eror. * Add a long C++ comment explaining what led to the `PyObjectTypeIsConvertibleTo*()` implementations. * Minor function name change in test. * strcmp -> std::strcmp (thanks @Skylion007 for catching this) * Add `PyCallable_Check(items)` in `PyObjectTypeIsConvertibleToStdMap()` * Resolve clang-tidy error * Use `PyMapping_Items()` instead of `src.attr("items")()`, to be internally consistent with `PyMapping_Check()` * Update link to PyCLIF sources. * Fix typo (thanks @wangxf123456 for catching this) * Add `test_pass_std_vector_int()`, `test_pass_std_set_int()` in test_stl * Change `list_caster` to also accept generator objects (`PyGen_Check(src.ptr()`). Note for completeness: This is a more conservative change than https://github.com/google/pywrapcc/pull/30042 * Drop in (currently unpublished) PyCLIF code, use in `list_caster`, adjust tests. * Use `PyObjectTypeIsConvertibleToStdSet()` in `set_caster`, adjust tests. * Use `PyObjectTypeIsConvertibleToStdMap()` in `map_caster`, add tests. * Simplify `list_caster` `load()` implementation, push str/bytes check into `PyObjectTypeIsConvertibleToStdVector()`. * clang-tidy cleanup with a few extra `(... != 0)` to be more consistent. * Also use `PyObjectTypeIsConvertibleToStdVector()` in `array_caster`. * Update comment pointing to clif/python/runtime.cc (code is unchanged). * Comprehensive test coverage, enhanced set_caster load implementation. * Resolve clang-tidy eror. * Add a long C++ comment explaining what led to the `PyObjectTypeIsConvertibleTo*()` implementations. * Minor function name change in test. * strcmp -> std::strcmp (thanks @Skylion007 for catching this) * Add `PyCallable_Check(items)` in `PyObjectTypeIsConvertibleToStdMap()` * Resolve clang-tidy error * Use `PyMapping_Items()` instead of `src.attr("items")()`, to be internally consistent with `PyMapping_Check()` * Update link to PyCLIF sources. * Fix typo (thanks @wangxf123456 for catching this) * Fix typo discovered by new version of codespell. --- include/pybind11/stl.h | 212 ++++++++++++++++++++++++++++++++++------- tests/test_stl.cpp | 34 +++++++ tests/test_stl.py | 126 ++++++++++++++++++++++++ 3 files changed, 338 insertions(+), 34 deletions(-) diff --git a/include/pybind11/stl.h b/include/pybind11/stl.h index 71bc5902e..096301bc7 100644 --- a/include/pybind11/stl.h +++ b/include/pybind11/stl.h @@ -13,6 +13,7 @@ #include "detail/common.h" #include +#include #include #include #include @@ -35,6 +36,89 @@ PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) PYBIND11_NAMESPACE_BEGIN(detail) +// +// Begin: Equivalent of +// https://github.com/google/clif/blob/ae4eee1de07cdf115c0c9bf9fec9ff28efce6f6c/clif/python/runtime.cc#L388-L438 +/* +The three `PyObjectTypeIsConvertibleTo*()` functions below are +the result of converging the behaviors of pybind11 and PyCLIF +(http://github.com/google/clif). + +Originally PyCLIF was extremely far on the permissive side of the spectrum, +while pybind11 was very far on the strict side. Originally PyCLIF accepted any +Python iterable as input for a C++ `vector`/`set`/`map` argument, as long as +the elements were convertible. The obvious (in hindsight) problem was that +any empty Python iterable could be passed to any of these C++ types, e.g. `{}` +was accepted for C++ `vector`/`set` arguments, or `[]` for C++ `map` arguments. + +The functions below strike a practical permissive-vs-strict compromise, +informed by tens of thousands of use cases in the wild. A main objective is +to prevent accidents and improve readability: + +- Python literals must match the C++ types. + +- For C++ `set`: The potentially reducing conversion from a Python sequence + (e.g. Python `list` or `tuple`) to a C++ `set` must be explicit, by going + through a Python `set`. + +- However, a Python `set` can still be passed to a C++ `vector`. The rationale + is that this conversion is not reducing. Implicit conversions of this kind + are also fairly commonly used, therefore enforcing explicit conversions + would have an unfavorable cost : benefit ratio; more sloppily speaking, + such an enforcement would be more annoying than helpful. +*/ + +inline bool PyObjectIsInstanceWithOneOfTpNames(PyObject *obj, + std::initializer_list tp_names) { + if (PyType_Check(obj)) { + return false; + } + const char *obj_tp_name = Py_TYPE(obj)->tp_name; + for (const auto *tp_name : tp_names) { + if (std::strcmp(obj_tp_name, tp_name) == 0) { + return true; + } + } + return false; +} + +inline bool PyObjectTypeIsConvertibleToStdVector(PyObject *obj) { + if (PySequence_Check(obj) != 0) { + return !PyUnicode_Check(obj) && !PyBytes_Check(obj); + } + return (PyGen_Check(obj) != 0) || (PyAnySet_Check(obj) != 0) + || PyObjectIsInstanceWithOneOfTpNames( + obj, {"dict_keys", "dict_values", "dict_items", "map", "zip"}); +} + +inline bool PyObjectTypeIsConvertibleToStdSet(PyObject *obj) { + return (PyAnySet_Check(obj) != 0) || PyObjectIsInstanceWithOneOfTpNames(obj, {"dict_keys"}); +} + +inline bool PyObjectTypeIsConvertibleToStdMap(PyObject *obj) { + if (PyDict_Check(obj)) { + return true; + } + // Implicit requirement in the conditions below: + // A type with `.__getitem__()` & `.items()` methods must implement these + // to be compatible with https://docs.python.org/3/c-api/mapping.html + if (PyMapping_Check(obj) == 0) { + return false; + } + PyObject *items = PyObject_GetAttrString(obj, "items"); + if (items == nullptr) { + PyErr_Clear(); + return false; + } + bool is_convertible = (PyCallable_Check(items) != 0); + Py_DECREF(items); + return is_convertible; +} + +// +// End: Equivalent of clif/python/runtime.cc +// + /// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for /// forwarding a container element). Typically used indirect via forwarded_type(), below. template @@ -66,17 +150,10 @@ private: } void reserve_maybe(const anyset &, void *) {} -public: - bool load(handle src, bool convert) { - if (!isinstance(src)) { - return false; - } - auto s = reinterpret_borrow(src); - value.clear(); - reserve_maybe(s, &value); - for (auto entry : s) { + bool convert_iterable(const iterable &itbl, bool convert) { + for (const auto &it : itbl) { key_conv conv; - if (!conv.load(entry, convert)) { + if (!conv.load(it, convert)) { return false; } value.insert(cast_op(std::move(conv))); @@ -84,6 +161,29 @@ public: return true; } + bool convert_anyset(anyset s, bool convert) { + value.clear(); + reserve_maybe(s, &value); + return convert_iterable(s, convert); + } + +public: + bool load(handle src, bool convert) { + if (!PyObjectTypeIsConvertibleToStdSet(src.ptr())) { + return false; + } + if (isinstance(src)) { + value.clear(); + return convert_anyset(reinterpret_borrow(src), convert); + } + if (!convert) { + return false; + } + assert(isinstance(src)); + value.clear(); + return convert_iterable(reinterpret_borrow(src), convert); + } + template static handle cast(T &&src, return_value_policy policy, handle parent) { if (!std::is_lvalue_reference::value) { @@ -115,15 +215,10 @@ private: } void reserve_maybe(const dict &, void *) {} -public: - bool load(handle src, bool convert) { - if (!isinstance(src)) { - return false; - } - auto d = reinterpret_borrow(src); + bool convert_elements(const dict &d, bool convert) { value.clear(); reserve_maybe(d, &value); - for (auto it : d) { + for (const auto &it : d) { key_conv kconv; value_conv vconv; if (!kconv.load(it.first.ptr(), convert) || !vconv.load(it.second.ptr(), convert)) { @@ -134,6 +229,25 @@ public: return true; } +public: + bool load(handle src, bool convert) { + if (!PyObjectTypeIsConvertibleToStdMap(src.ptr())) { + return false; + } + if (isinstance(src)) { + return convert_elements(reinterpret_borrow(src), convert); + } + if (!convert) { + return false; + } + auto items = reinterpret_steal(PyMapping_Items(src.ptr())); + if (!items) { + throw error_already_set(); + } + assert(isinstance(items)); + return convert_elements(dict(reinterpret_borrow(items)), convert); + } + template static handle cast(T &&src, return_value_policy policy, handle parent) { dict d; @@ -166,20 +280,21 @@ struct list_caster { using value_conv = make_caster; bool load(handle src, bool convert) { - if (!isinstance(src) || isinstance(src) || isinstance(src)) { + if (!PyObjectTypeIsConvertibleToStdVector(src.ptr())) { return false; } - auto s = reinterpret_borrow(src); - value.clear(); - reserve_maybe(s, &value); - for (const auto &it : s) { - value_conv conv; - if (!conv.load(it, convert)) { - return false; - } - value.push_back(cast_op(std::move(conv))); + if (isinstance(src)) { + return convert_elements(src, convert); } - return true; + if (!convert) { + return false; + } + // Designed to be behavior-equivalent to passing tuple(src) from Python: + // The conversion to a tuple will first exhaust the generator object, to ensure that + // the generator is not left in an unpredictable (to the caller) partially-consumed + // state. + assert(isinstance(src)); + return convert_elements(tuple(reinterpret_borrow(src)), convert); } private: @@ -189,6 +304,20 @@ private: } void reserve_maybe(const sequence &, void *) {} + bool convert_elements(handle seq, bool convert) { + auto s = reinterpret_borrow(seq); + value.clear(); + reserve_maybe(s, &value); + for (const auto &it : seq) { + value_conv conv; + if (!conv.load(it, convert)) { + return false; + } + value.push_back(cast_op(std::move(conv))); + } + return true; + } + public: template static handle cast(T &&src, return_value_policy policy, handle parent) { @@ -237,12 +366,8 @@ private: return size == Size; } -public: - bool load(handle src, bool convert) { - if (!isinstance(src)) { - return false; - } - auto l = reinterpret_borrow(src); + bool convert_elements(handle seq, bool convert) { + auto l = reinterpret_borrow(seq); if (!require_size(l.size())) { return false; } @@ -257,6 +382,25 @@ public: return true; } +public: + bool load(handle src, bool convert) { + if (!PyObjectTypeIsConvertibleToStdVector(src.ptr())) { + return false; + } + if (isinstance(src)) { + return convert_elements(src, convert); + } + if (!convert) { + return false; + } + // Designed to be behavior-equivalent to passing tuple(src) from Python: + // The conversion to a tuple will first exhaust the generator object, to ensure that + // the generator is not left in an unpredictable (to the caller) partially-consumed + // state. + assert(isinstance(src)); + return convert_elements(tuple(reinterpret_borrow(src)), convert); + } + template static handle cast(T &&src, return_value_policy policy, handle parent) { list l(src.size()); diff --git a/tests/test_stl.cpp b/tests/test_stl.cpp index 48c907ff3..c35f0d4ca 100644 --- a/tests/test_stl.cpp +++ b/tests/test_stl.cpp @@ -167,6 +167,14 @@ struct type_caster> } // namespace detail } // namespace PYBIND11_NAMESPACE +int pass_std_vector_int(const std::vector &v) { + int zum = 100; + for (const int i : v) { + zum += 2 * i; + } + return zum; +} + TEST_SUBMODULE(stl, m) { // test_vector m.def("cast_vector", []() { return std::vector{1}; }); @@ -546,4 +554,30 @@ TEST_SUBMODULE(stl, m) { []() { return new std::vector(4513); }, // Without explicitly specifying `take_ownership`, this function leaks. py::return_value_policy::take_ownership); + + m.def("pass_std_vector_int", pass_std_vector_int); + m.def("pass_std_vector_pair_int", [](const std::vector> &v) { + int zum = 0; + for (const auto &ij : v) { + zum += ij.first * 100 + ij.second; + } + return zum; + }); + m.def("pass_std_array_int_2", [](const std::array &a) { + return pass_std_vector_int(std::vector(a.begin(), a.end())) + 1; + }); + m.def("pass_std_set_int", [](const std::set &s) { + int zum = 200; + for (const int i : s) { + zum += 3 * i; + } + return zum; + }); + m.def("pass_std_map_int", [](const std::map &m) { + int zum = 500; + for (const auto &p : m) { + zum += p.first * 1000 + p.second; + } + return zum; + }); } diff --git a/tests/test_stl.py b/tests/test_stl.py index 65fda54cc..1896c3221 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -381,3 +381,129 @@ def test_return_vector_bool_raw_ptr(): v = m.return_vector_bool_raw_ptr() assert isinstance(v, list) assert len(v) == 4513 + + +@pytest.mark.parametrize( + ("fn", "offset"), [(m.pass_std_vector_int, 0), (m.pass_std_array_int_2, 1)] +) +def test_pass_std_vector_int(fn, offset): + assert fn([7, 13]) == 140 + offset + assert fn({6, 2}) == 116 + offset + assert fn({"x": 8, "y": 11}.values()) == 138 + offset + assert fn({3: None, 9: None}.keys()) == 124 + offset + assert fn(i for i in [4, 17]) == 142 + offset + assert fn(map(lambda i: i * 3, [8, 7])) == 190 + offset # noqa: C417 + with pytest.raises(TypeError): + fn({"x": 0, "y": 1}) + with pytest.raises(TypeError): + fn({}) + + +def test_pass_std_vector_pair_int(): + fn = m.pass_std_vector_pair_int + assert fn({1: 2, 3: 4}.items()) == 406 + assert fn(zip([5, 17], [13, 9])) == 2222 + + +def test_list_caster_fully_consumes_generator_object(): + def gen_invalid(): + yield from [1, 2.0, 3] + + gen_obj = gen_invalid() + with pytest.raises(TypeError): + m.pass_std_vector_int(gen_obj) + assert not tuple(gen_obj) + + +def test_pass_std_set_int(): + fn = m.pass_std_set_int + assert fn({3, 15}) == 254 + assert fn({5: None, 12: None}.keys()) == 251 + with pytest.raises(TypeError): + fn([]) + with pytest.raises(TypeError): + fn({}) + with pytest.raises(TypeError): + fn({}.values()) + with pytest.raises(TypeError): + fn(i for i in []) + + +def test_set_caster_dict_keys_failure(): + dict_keys = {1: None, 2.0: None, 3: None}.keys() + # The asserts does not really exercise anything in pybind11, but if one of + # them fails in some future version of Python, the set_caster load + # implementation may need to be revisited. + assert tuple(dict_keys) == (1, 2.0, 3) + assert tuple(dict_keys) == (1, 2.0, 3) + with pytest.raises(TypeError): + m.pass_std_set_int(dict_keys) + assert tuple(dict_keys) == (1, 2.0, 3) + + +class FakePyMappingMissingItems: + def __getitem__(self, _): + raise RuntimeError("Not expected to be called.") + + +class FakePyMappingWithItems(FakePyMappingMissingItems): + def items(self): + return ((1, 3), (2, 4)) + + +class FakePyMappingBadItems(FakePyMappingMissingItems): + def items(self): + return ((1, 2), (3, "x")) + + +class FakePyMappingItemsNotCallable(FakePyMappingMissingItems): + @property + def items(self): + return ((1, 2), (3, 4)) + + +class FakePyMappingItemsWithArg(FakePyMappingMissingItems): + def items(self, _): + return ((1, 2), (3, 4)) + + +class FakePyMappingGenObj(FakePyMappingMissingItems): + def __init__(self, gen_obj): + super().__init__() + self.gen_obj = gen_obj + + def items(self): + yield from self.gen_obj + + +def test_pass_std_map_int(): + fn = m.pass_std_map_int + assert fn({1: 2, 3: 4}) == 4506 + with pytest.raises(TypeError): + fn([]) + assert fn(FakePyMappingWithItems()) == 3507 + with pytest.raises(TypeError): + fn(FakePyMappingMissingItems()) + with pytest.raises(TypeError): + fn(FakePyMappingBadItems()) + with pytest.raises(TypeError): + fn(FakePyMappingItemsNotCallable()) + with pytest.raises(TypeError): + fn(FakePyMappingItemsWithArg()) + + +@pytest.mark.parametrize( + ("items", "expected_exception"), + [ + (((1, 2), (3, "x"), (4, 5)), TypeError), + (((1, 2), (3, 4, 5), (6, 7)), ValueError), + ], +) +def test_map_caster_fully_consumes_generator_object(items, expected_exception): + def gen_invalid(): + yield from items + + gen_obj = gen_invalid() + with pytest.raises(expected_exception): + m.pass_std_map_int(FakePyMappingGenObj(gen_obj)) + assert not tuple(gen_obj)