From 333e889ef2fdc22fece30512961b2b7f2da02570 Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Sat, 14 Nov 2015 19:04:49 +0100 Subject: [PATCH] Improved STL support, support for std::set --- example/example2.cpp | 32 +++++++++++++++++++++++++ example/example2.py | 8 +++++++ example/example2.ref | 28 +++++++++++++++++++++- include/pybind11/cast.h | 1 + include/pybind11/pytypes.h | 46 +++++++++++++++++++++++++----------- include/pybind11/stl.h | 48 ++++++++++++++++++++++++++++++++++---- 6 files changed, 143 insertions(+), 20 deletions(-) diff --git a/example/example2.cpp b/example/example2.cpp index f7972e03e..2e20e8b33 100644 --- a/example/example2.cpp +++ b/example/example2.cpp @@ -27,6 +27,14 @@ public: return dict; } + /* Create and return a Python set */ + py::set get_set() { + py::set set; + set.insert(py::str("key1")); + set.insert(py::str("key2")); + return set; + } + /* Create and return a C++ dictionary */ std::map get_dict_2() { std::map result; @@ -34,6 +42,14 @@ public: return result; } + /* Create and return a C++ set */ + std::set get_set_2() { + std::set result; + result.insert("key1"); + result.insert("key2"); + return result; + } + /* Create, manipulate, and return a Python list */ py::list get_list() { py::list list; @@ -62,6 +78,18 @@ public: std::cout << "key: " << item.first << ", value=" << item.second << std::endl; } + /* Easily iterate over a setionary using a C++11 range-based for loop */ + void print_set(py::set set) { + for (auto item : set) + std::cout << "key: " << item << std::endl; + } + + /* STL data types are automatically casted from Python */ + void print_set_2(const std::set &set) { + for (auto item : set) + std::cout << "key: " << item << std::endl; + } + /* Easily iterate over a list using a C++11 range-based for loop */ void print_list(py::list list) { int index = 0; @@ -105,8 +133,12 @@ void init_ex2(py::module &m) { .def("get_dict_2", &Example2::get_dict_2, "Return a C++ dictionary") .def("get_list", &Example2::get_list, "Return a Python list") .def("get_list_2", &Example2::get_list_2, "Return a C++ list") + .def("get_set", &Example2::get_set, "Return a Python set") + .def("get_set2", &Example2::get_set, "Return a C++ set") .def("print_dict", &Example2::print_dict, "Print entries of a Python dictionary") .def("print_dict_2", &Example2::print_dict_2, "Print entries of a C++ dictionary") + .def("print_set", &Example2::print_set, "Print entries of a Python set") + .def("print_set_2", &Example2::print_set_2, "Print entries of a C++ set") .def("print_list", &Example2::print_list, "Print entries of a Python list") .def("print_list_2", &Example2::print_list_2, "Print entries of a C++ list") .def("pair_passthrough", &Example2::pair_passthrough, "Return a pair in reversed order") diff --git a/example/example2.py b/example/example2.py index 2782da549..f42ee49bc 100755 --- a/example/example2.py +++ b/example/example2.py @@ -29,6 +29,14 @@ dict_result = instance.get_dict_2() dict_result['key2'] = 'value2' instance.print_dict_2(dict_result) +set_result = instance.get_set() +set_result.add(u'key3') +instance.print_set(set_result) + +set_result = instance.get_set2() +set_result.add(u'key3') +instance.print_set_2(set_result) + list_result = instance.get_list() list_result.append('value2') instance.print_list(list_result) diff --git a/example/example2.ref b/example/example2.ref index 341fcb2f3..64d2afa88 100644 --- a/example/example2.ref +++ b/example/example2.ref @@ -6,6 +6,12 @@ key: key2, value=value2 key: key, value=value key: key, value=value key: key2, value=value2 +key: key3 +key: key2 +key: key1 +key: key1 +key: key2 +key: key3 Entry at positon 0: value list item 0: overwritten list item 1: value2 @@ -44,6 +50,16 @@ class EExxaammppllee22(__builtin__.object) | | Return a C++ list | + | ggeett__sseett(...) + | Signature : (Example2) -> set + | + | Return a Python set + | + | ggeett__sseett22(...) + | Signature : (Example2) -> set + | + | Return a C++ set + | | ppaaiirr__ppaasssstthhrroouugghh(...) | Signature : (Example2, (bool, str)) -> (str, bool) | @@ -69,6 +85,16 @@ class EExxaammppllee22(__builtin__.object) | | Print entries of a C++ list | + | pprriinntt__sseett(...) + | Signature : (Example2, set) -> None + | + | Print entries of a Python set + | + | pprriinntt__sseett__22(...) + | Signature : (Example2, set) -> None + | + | Print entries of a C++ set + | | tthhrrooww__eexxcceeppttiioonn(...) | Signature : (Example2) -> None | @@ -85,7 +111,7 @@ class EExxaammppllee22(__builtin__.object) | ____nneeww____ = | T.__new__(S, ...) -> a new object with type S, a subtype of T | - | ____ppyybbiinndd____ = + | ____ppyybbiinndd1111____ = | | nneeww__iinnssttaannccee = | Signature : () -> Example2 diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 3832bc1ef..ac71c9cc8 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -571,6 +571,7 @@ PYBIND11_TYPE_CASTER_PYTYPE(capsule) PYBIND11_TYPE_CASTER_PYTYPE(dict) PYBIND11_TYPE_CASTER_PYTYPE(float_) PYBIND11_TYPE_CASTER_PYTYPE(int_) PYBIND11_TYPE_CASTER_PYTYPE(list) PYBIND11_TYPE_CASTER_PYTYPE(slice) PYBIND11_TYPE_CASTER_PYTYPE(tuple) PYBIND11_TYPE_CASTER_PYTYPE(function) +PYBIND11_TYPE_CASTER_PYTYPE(set) PYBIND11_TYPE_CASTER_PYTYPE(iterator) NAMESPACE_END(detail) diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index d85de8738..3baecf8b3 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -19,6 +19,7 @@ class object; class str; class object; class dict; +class iterator; namespace detail { class accessor; } /// Holds a reference to a Python object (no reference counting) @@ -33,6 +34,8 @@ public: void dec_ref() const { Py_XDECREF(m_ptr); } int ref_count() const { return (int) Py_REFCNT(m_ptr); } handle get_type() { return (PyObject *) Py_TYPE(m_ptr); } + inline iterator begin(); + inline iterator end(); inline detail::accessor operator[](handle key); inline detail::accessor operator[](const char *key); inline detail::accessor attr(handle key); @@ -73,6 +76,23 @@ public: } }; +class iterator : public object { +public: + iterator(PyObject *obj, bool borrowed = false) : object(obj, borrowed) { ++*this; } + iterator& operator++() { + if (ptr()) + value = object(PyIter_Next(ptr()), false); + return *this; + } + bool operator==(const iterator &it) const { return *it == **this; } + bool operator!=(const iterator &it) const { return *it != **this; } + object operator*() { return value; } + const object &operator*() const { return value; } + bool check() const { return PyIter_Check(ptr()); } +private: + object value; +}; + NAMESPACE_BEGIN(detail) class accessor { public: @@ -159,18 +179,6 @@ private: size_t index; }; -class list_iterator { -public: - list_iterator(PyObject *list, ssize_t pos) : list(list), pos(pos) { } - list_iterator& operator++() { ++pos; return *this; } - object operator*() { return object(PyList_GetItem(list, pos), true); } - bool operator==(const list_iterator &it) const { return it.pos == pos; } - bool operator!=(const list_iterator &it) const { return it.pos != pos; } -private: - PyObject *list; - ssize_t pos; -}; - struct dict_iterator { public: dict_iterator(PyObject *dict = nullptr, ssize_t pos = -1) : dict(dict), pos(pos) { } @@ -194,6 +202,8 @@ inline detail::accessor handle::operator[](handle key) { return detail::accessor inline detail::accessor handle::operator[](const char *key) { return detail::accessor(ptr(), key, false); } inline detail::accessor handle::attr(handle key) { return detail::accessor(ptr(), key.ptr(), true); } inline detail::accessor handle::attr(const char *key) { return detail::accessor(ptr(), key, true); } +inline iterator handle::begin() { return iterator(PyObject_GetIter(ptr())); } +inline iterator handle::end() { return iterator(nullptr); } #define PYBIND11_OBJECT_CVT(Name, Parent, CheckFun, CvtStmt) \ Name(const handle &h, bool borrowed) : Parent(h, borrowed) { CvtStmt; } \ @@ -310,6 +320,7 @@ public: size_t size() const { return (size_t) PyDict_Size(m_ptr); } detail::dict_iterator begin() { return (++detail::dict_iterator(ptr(), 0)); } detail::dict_iterator end() { return detail::dict_iterator(); } + void clear() { PyDict_Clear(ptr()); } }; class list : public object { @@ -317,12 +328,19 @@ public: PYBIND11_OBJECT(list, object, PyList_Check) list(size_t size = 0) : object(PyList_New((ssize_t) size), false) { } size_t size() const { return (size_t) PyList_Size(m_ptr); } - detail::list_iterator begin() { return detail::list_iterator(ptr(), 0); } - detail::list_iterator end() { return detail::list_iterator(ptr(), (ssize_t) size()); } detail::list_accessor operator[](size_t index) { return detail::list_accessor(ptr(), index); } void append(const object &object) { PyList_Append(m_ptr, (PyObject *) object.ptr()); } }; +class set : public object { +public: + PYBIND11_OBJECT(set, object, PySet_Check) + set() : object(PySet_New(nullptr), false) { } + size_t size() const { return (size_t) PySet_Size(m_ptr); } + void insert(const object &object) { PySet_Add(m_ptr, (PyObject *) object.ptr()); } + void clear() { PySet_Clear(ptr()); } +}; + class function : public object { public: PYBIND11_OBJECT_DEFAULT(function, object, PyFunction_Check) diff --git a/include/pybind11/stl.h b/include/pybind11/stl.h index fc4637b46..364ad6537 100644 --- a/include/pybind11/stl.h +++ b/include/pybind11/stl.h @@ -11,6 +11,7 @@ #include "pybind11.h" #include +#include #include @@ -22,8 +23,8 @@ NAMESPACE_BEGIN(pybind11) NAMESPACE_BEGIN(detail) -template struct type_caster> { - typedef std::vector type; +template struct type_caster> { + typedef std::vector type; typedef type_caster value_conv; public: bool load(PyObject *src, bool convert) { @@ -32,8 +33,8 @@ public: size_t size = (size_t) PyList_GET_SIZE(src); value.reserve(size); value.clear(); + value_conv conv; for (size_t i=0; i")); }; -template struct type_caster> { +template struct type_caster> { + typedef std::set type; + typedef type_caster value_conv; public: - typedef std::map type; + bool load(PyObject *src, bool convert) { + pybind11::set s(src, true); + if (!s.check()) + return false; + value.clear(); + value_conv conv; + for (const object &o: s) { + if (!conv.load((PyObject *) o.ptr(), convert)) + return false; + value.insert((Value) conv); + } + return true; + } + + static PyObject *cast(const type &src, return_value_policy policy, PyObject *parent) { + PyObject *set = PySet_New(nullptr); + for (auto const &value: src) { + PyObject *value_ = value_conv::cast(value, policy, parent); + if (!value_) { + Py_DECREF(set); + return nullptr; + } + if (PySet_Add(set, value) != 0) { + Py_DECREF(value); + Py_DECREF(set); + return nullptr; + } + } + return set; + } + PYBIND11_TYPE_CASTER(type, detail::descr("set<") + value_conv::name() + detail::descr(">")); +}; + +template struct type_caster> { +public: + typedef std::map type; typedef type_caster key_conv; typedef type_caster value_conv;