Merge pull request #445 from lsst-dm/master

Accept any sequence type as std::vector (or std::list)
This commit is contained in:
Wenzel Jakob 2016-10-15 23:50:06 +02:00 committed by GitHub
commit 946f897da0
3 changed files with 40 additions and 6 deletions

View File

@ -29,12 +29,14 @@ namespace accessor_policies {
struct obj_attr; struct obj_attr;
struct str_attr; struct str_attr;
struct generic_item; struct generic_item;
struct sequence_item;
struct list_item; struct list_item;
struct tuple_item; struct tuple_item;
} }
using obj_attr_accessor = accessor<accessor_policies::obj_attr>; using obj_attr_accessor = accessor<accessor_policies::obj_attr>;
using str_attr_accessor = accessor<accessor_policies::str_attr>; using str_attr_accessor = accessor<accessor_policies::str_attr>;
using item_accessor = accessor<accessor_policies::generic_item>; using item_accessor = accessor<accessor_policies::generic_item>;
using sequence_accessor = accessor<accessor_policies::sequence_item>;
using list_accessor = accessor<accessor_policies::list_item>; using list_accessor = accessor<accessor_policies::list_item>;
using tuple_accessor = accessor<accessor_policies::tuple_item>; using tuple_accessor = accessor<accessor_policies::tuple_item>;
@ -261,6 +263,23 @@ struct generic_item {
} }
}; };
struct sequence_item {
using key_type = size_t;
static object get(handle obj, size_t index) {
PyObject *result = PySequence_GetItem(obj.ptr(), static_cast<ssize_t>(index));
if (!result) { throw error_already_set(); }
return {result, true};
}
static void set(handle obj, size_t index, handle val) {
// PySequence_SetItem does not steal a reference to 'val'
if (PySequence_SetItem(obj.ptr(), static_cast<ssize_t>(index), val.ptr()) != 0) {
throw error_already_set();
}
}
};
struct list_item { struct list_item {
using key_type = size_t; using key_type = size_t;
@ -673,6 +692,13 @@ public:
bool contains(const char *key) const { return PyDict_Contains(ptr(), pybind11::str(key).ptr()) == 1; } bool contains(const char *key) const { return PyDict_Contains(ptr(), pybind11::str(key).ptr()) == 1; }
}; };
class sequence : public object {
public:
PYBIND11_OBJECT(sequence, object, PySequence_Check)
size_t size() const { return (size_t) PySequence_Size(m_ptr); }
detail::sequence_accessor operator[](size_t index) const { return {*this, index}; }
};
class list : public object { class list : public object {
public: public:
PYBIND11_OBJECT(list, object, PyList_Check) PYBIND11_OBJECT(list, object, PyList_Check)

View File

@ -97,13 +97,13 @@ template <typename Type, typename Value> struct list_caster {
using value_conv = make_caster<Value>; using value_conv = make_caster<Value>;
bool load(handle src, bool convert) { bool load(handle src, bool convert) {
list l(src, true); sequence s(src, true);
if (!l.check()) if (!s.check())
return false; return false;
value_conv conv; value_conv conv;
value.clear(); value.clear();
reserve_maybe(l, &value); reserve_maybe(s, &value);
for (auto it : l) { for (auto it : s) {
if (!conv.load(it, convert)) if (!conv.load(it, convert))
return false; return false;
value.push_back((Value) conv); value.push_back((Value) conv);
@ -113,8 +113,8 @@ template <typename Type, typename Value> struct list_caster {
template <typename T = Type, template <typename T = Type,
enable_if_t<std::is_same<decltype(std::declval<T>().reserve(0)), void>::value, int> = 0> enable_if_t<std::is_same<decltype(std::declval<T>().reserve(0)), void>::value, int> = 0>
void reserve_maybe(list l, Type *) { value.reserve(l.size()); } void reserve_maybe(sequence s, Type *) { value.reserve(s.size()); }
void reserve_maybe(list, void *) { } void reserve_maybe(sequence, void *) { }
static handle cast(const Type &src, return_value_policy policy, handle parent) { static handle cast(const Type &src, return_value_policy policy, handle parent) {
list l(src.size()); list l(src.size());

View File

@ -71,6 +71,14 @@ def test_instance(capture):
list item 0: value list item 0: value
list item 1: value2 list item 1: value2
""" """
with capture:
list_result = instance.get_list_2()
list_result.append('value2')
instance.print_list_2(tuple(list_result))
assert capture.unordered == """
list item 0: value
list item 1: value2
"""
array_result = instance.get_array() array_result = instance.get_array()
assert array_result == ['array entry 1', 'array entry 2'] assert array_result == ['array entry 1', 'array entry 2']
with capture: with capture: