diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 22d2dba4c..376a67954 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -908,6 +908,14 @@ struct handle_type_name { template struct pyobject_caster { + template ::value, int> = 0> + pyobject_caster() : value() {} + + // `type` may not be default constructible (e.g. frozenset, anyset). Initializing `value` + // to a nil handle is safe since it will only be accessed if `load` succeeds. + template ::value, int> = 0> + pyobject_caster() : value(reinterpret_steal(handle())) {} + template ::value, int> = 0> bool load(handle src, bool /* convert */) { value = src; diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index 324fa932f..256a2441b 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -1784,25 +1784,35 @@ class kwargs : public dict { PYBIND11_OBJECT_DEFAULT(kwargs, dict, PyDict_Check) }; -class set : public object { +class anyset : public object { public: - PYBIND11_OBJECT_CVT(set, object, PySet_Check, PySet_New) - set() : object(PySet_New(nullptr), stolen_t{}) { + PYBIND11_OBJECT(anyset, object, PyAnySet_Check) + size_t size() const { return static_cast(PySet_Size(m_ptr)); } + bool empty() const { return size() == 0; } + template + bool contains(T &&val) const { + return PySet_Contains(m_ptr, detail::object_or_cast(std::forward(val)).ptr()) == 1; + } +}; + +class set : public anyset { +public: + PYBIND11_OBJECT_CVT(set, anyset, PySet_Check, PySet_New) + set() : anyset(PySet_New(nullptr), stolen_t{}) { if (!m_ptr) { pybind11_fail("Could not allocate set object!"); } } - size_t size() const { return (size_t) PySet_Size(m_ptr); } - bool empty() const { return size() == 0; } template bool add(T &&val) /* py-non-const */ { return PySet_Add(m_ptr, detail::object_or_cast(std::forward(val)).ptr()) == 0; } void clear() /* py-non-const */ { PySet_Clear(m_ptr); } - template - bool contains(T &&val) const { - return PySet_Contains(m_ptr, detail::object_or_cast(std::forward(val)).ptr()) == 1; - } +}; + +class frozenset : public anyset { +public: + PYBIND11_OBJECT_CVT(frozenset, anyset, PyFrozenSet_Check, PyFrozenSet_New) }; class function : public object { diff --git a/include/pybind11/stl.h b/include/pybind11/stl.h index 51b57a92b..625fb210f 100644 --- a/include/pybind11/stl.h +++ b/include/pybind11/stl.h @@ -55,10 +55,10 @@ struct set_caster { using key_conv = make_caster; bool load(handle src, bool convert) { - if (!isinstance(src)) { + if (!isinstance(src)) { return false; } - auto s = reinterpret_borrow(src); + auto s = reinterpret_borrow(src); value.clear(); for (auto entry : s) { key_conv conv; diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index d1e9b81a7..8d296f655 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -75,7 +75,7 @@ TEST_SUBMODULE(pytypes, m) { m.def("get_none", [] { return py::none(); }); m.def("print_none", [](const py::none &none) { py::print("none: {}"_s.format(none)); }); - // test_set + // test_set, test_frozenset m.def("get_set", []() { py::set set; set.add(py::str("key1")); @@ -83,14 +83,26 @@ TEST_SUBMODULE(pytypes, m) { set.add(std::string("key3")); return set; }); - m.def("print_set", [](const py::set &set) { + m.def("get_frozenset", []() { + py::set set; + set.add(py::str("key1")); + set.add("key2"); + set.add(std::string("key3")); + return py::frozenset(set); + }); + m.def("print_anyset", [](const py::anyset &set) { for (auto item : set) { py::print("key:", item); } }); - m.def("set_contains", - [](const py::set &set, const py::object &key) { return set.contains(key); }); - m.def("set_contains", [](const py::set &set, const char *key) { return set.contains(key); }); + m.def("anyset_size", [](const py::anyset &set) { return set.size(); }); + m.def("anyset_empty", [](const py::anyset &set) { return set.empty(); }); + m.def("anyset_contains", + [](const py::anyset &set, const py::object &key) { return set.contains(key); }); + m.def("anyset_contains", + [](const py::anyset &set, const char *key) { return set.contains(key); }); + m.def("set_add", [](py::set &set, const py::object &key) { set.add(key); }); + m.def("set_clear", [](py::set &set) { set.clear(); }); // test_dict m.def("get_dict", []() { return py::dict("key"_a = "value"); }); @@ -310,6 +322,7 @@ TEST_SUBMODULE(pytypes, m) { "list"_a = py::list(d["list"]), "dict"_a = py::dict(d["dict"]), "set"_a = py::set(d["set"]), + "frozenset"_a = py::frozenset(d["frozenset"]), "memoryview"_a = py::memoryview(d["memoryview"])); }); @@ -325,6 +338,7 @@ TEST_SUBMODULE(pytypes, m) { "list"_a = d["list"].cast(), "dict"_a = d["dict"].cast(), "set"_a = d["set"].cast(), + "frozenset"_a = d["frozenset"].cast(), "memoryview"_a = d["memoryview"].cast()); }); diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index c740414ae..9afe62f42 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -66,11 +66,12 @@ def test_none(capture, doc): def test_set(capture, doc): s = m.get_set() + assert isinstance(s, set) assert s == {"key1", "key2", "key3"} + s.add("key4") with capture: - s.add("key4") - m.print_set(s) + m.print_anyset(s) assert ( capture.unordered == """ @@ -81,12 +82,43 @@ def test_set(capture, doc): """ ) - assert not m.set_contains(set(), 42) - assert m.set_contains({42}, 42) - assert m.set_contains({"foo"}, "foo") + m.set_add(s, "key5") + assert m.anyset_size(s) == 5 - assert doc(m.get_list) == "get_list() -> list" - assert doc(m.print_list) == "print_list(arg0: list) -> None" + m.set_clear(s) + assert m.anyset_empty(s) + + assert not m.anyset_contains(set(), 42) + assert m.anyset_contains({42}, 42) + assert m.anyset_contains({"foo"}, "foo") + + assert doc(m.get_set) == "get_set() -> set" + assert doc(m.print_anyset) == "print_anyset(arg0: anyset) -> None" + + +def test_frozenset(capture, doc): + s = m.get_frozenset() + assert isinstance(s, frozenset) + assert s == frozenset({"key1", "key2", "key3"}) + + with capture: + m.print_anyset(s) + assert ( + capture.unordered + == """ + key: key1 + key: key2 + key: key3 + """ + ) + assert m.anyset_size(s) == 3 + assert not m.anyset_empty(s) + + assert not m.anyset_contains(frozenset(), 42) + assert m.anyset_contains(frozenset({42}), 42) + assert m.anyset_contains(frozenset({"foo"}), "foo") + + assert doc(m.get_frozenset) == "get_frozenset() -> frozenset" def test_dict(capture, doc): @@ -302,6 +334,7 @@ def test_constructors(): list: range(3), dict: [("two", 2), ("one", 1), ("three", 3)], set: [4, 4, 5, 6, 6, 6], + frozenset: [4, 4, 5, 6, 6, 6], memoryview: b"abc", } inputs = {k.__name__: v for k, v in data.items()} diff --git a/tests/test_stl.py b/tests/test_stl.py index 975860b85..d30c38211 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -73,6 +73,7 @@ def test_set(doc): assert s == {"key1", "key2"} s.add("key3") assert m.load_set(s) + assert m.load_set(frozenset(s)) assert doc(m.cast_set) == "cast_set() -> Set[str]" assert doc(m.load_set) == "load_set(arg0: Set[str]) -> bool"