Extend attribute and item accessor interface using object_api

This commit is contained in:
Dean Moldovan 2016-09-08 17:02:04 +02:00
parent 865e43034b
commit 242b146a51
8 changed files with 143 additions and 48 deletions

View File

@ -53,6 +53,8 @@ Breaking changes queued for v2.0.0 (Not yet released)
* Added ``py::dict`` keyword constructor:``auto d = dict("number"_a=42, "name"_a="World");`` * Added ``py::dict`` keyword constructor:``auto d = dict("number"_a=42, "name"_a="World");``
* Added ``py::str::format()`` method and ``_s`` literal: * Added ``py::str::format()`` method and ``_s`` literal:
``py::str s = "1 + 2 = {}"_s.format(3);`` ``py::str s = "1 + 2 = {}"_s.format(3);``
* Attribute and item accessors now have a more complete interface which makes it possible
to chain attributes ``obj.attr("a")[key].attr("b").attr("method")(1, 2, 3)```.
* Various minor improvements of library internals (no user-visible changes) * Various minor improvements of library internals (no user-visible changes)
1.8.1 (July 12, 2016) 1.8.1 (July 12, 2016)

View File

@ -1219,7 +1219,7 @@ private:
void process(list &args_list, detail::args_proxy ap) { void process(list &args_list, detail::args_proxy ap) {
for (const auto &a : ap) { for (const auto &a : ap) {
args_list.append(a.cast<object>()); args_list.append(a);
} }
} }

View File

@ -125,7 +125,7 @@ private:
static npy_api lookup() { static npy_api lookup() {
module m = module::import("numpy.core.multiarray"); module m = module::import("numpy.core.multiarray");
auto c = m.attr("_ARRAY_API").cast<object>(); auto c = m.attr("_ARRAY_API");
#if PY_MAJOR_VERSION >= 3 #if PY_MAJOR_VERSION >= 3
void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL); void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL);
#else #else
@ -220,9 +220,7 @@ private:
struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; }; struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
std::vector<field_descr> field_descriptors; std::vector<field_descr> field_descriptors;
auto fields = attr("fields").cast<object>(); for (auto field : attr("fields").attr("items")()) {
auto items = fields.attr("items").cast<object>();
for (auto field : items()) {
auto spec = object(field, true).cast<tuple>(); auto spec = object(field, true).cast<tuple>();
auto name = spec[0].cast<pybind11::str>(); auto name = spec[0].cast<pybind11::str>();
auto format = spec[1].cast<tuple>()[0].cast<dtype>(); auto format = spec[1].cast<tuple>()[0].cast<dtype>();

View File

@ -176,7 +176,7 @@ protected:
if (a.descr) if (a.descr)
a.descr = strdup(a.descr); a.descr = strdup(a.descr);
else if (a.value) else if (a.value)
a.descr = strdup(((std::string) ((object) handle(a.value).attr("__repr__"))().str()).c_str()); a.descr = strdup(a.value.attr("__repr__")().cast<std::string>().c_str());
} }
auto const &registered_types = detail::get_internals().registered_types_cpp; auto const &registered_types = detail::get_internals().registered_types_cpp;
@ -723,8 +723,7 @@ protected:
if (ob_type == &PyType_Type) { if (ob_type == &PyType_Type) {
std::string name_ = std::string(ht_type.tp_name) + "__Meta"; std::string name_ = std::string(ht_type.tp_name) + "__Meta";
#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3 #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3
object ht_qualname(PyUnicode_FromFormat( object ht_qualname(PyUnicode_FromFormat("%U__Meta", attr("__qualname__").ptr()), false);
"%U__Meta", ((object) attr("__qualname__")).ptr()), false);
#endif #endif
object name(PYBIND11_FROM_STRING(name_.c_str()), false); object name(PYBIND11_FROM_STRING(name_.c_str()), false);
object type_holder(PyType_Type.tp_alloc(&PyType_Type, 0), false); object type_holder(PyType_Type.tp_alloc(&PyType_Type, 0), false);
@ -1342,16 +1341,16 @@ PYBIND11_NOINLINE inline void print(tuple args, dict kwargs) {
strings[i] = args[i].cast<object>().str(); strings[i] = args[i].cast<object>().str();
} }
auto sep = kwargs.contains("sep") ? kwargs["sep"] : cast(" "); auto sep = kwargs.contains("sep") ? kwargs["sep"] : cast(" ");
auto line = sep.attr("join").cast<object>()(strings); auto line = sep.attr("join")(strings);
auto file = kwargs.contains("file") ? kwargs["file"].cast<object>() auto file = kwargs.contains("file") ? kwargs["file"].cast<object>()
: module::import("sys").attr("stdout"); : module::import("sys").attr("stdout");
auto write = file.attr("write").cast<object>(); auto write = file.attr("write");
write(line); write(line);
write(kwargs.contains("end") ? kwargs["end"] : cast("\n")); write(kwargs.contains("end") ? kwargs["end"] : cast("\n"));
if (kwargs.contains("flush") && kwargs["flush"].cast<bool>()) { if (kwargs.contains("flush") && kwargs["flush"].cast<bool>()) {
file.attr("flush").cast<object>()(); file.attr("flush")();
} }
} }
NAMESPACE_END(detail) NAMESPACE_END(detail)

View File

@ -21,7 +21,18 @@ class str; class iterator;
struct arg; struct arg_v; struct arg; struct arg_v;
NAMESPACE_BEGIN(detail) NAMESPACE_BEGIN(detail)
class accessor; class args_proxy; class args_proxy;
// Accessor forward declarations
template <typename Policy> class accessor;
namespace accessor_policies {
struct obj_attr;
struct str_attr;
struct generic_item;
}
using obj_attr_accessor = accessor<accessor_policies::obj_attr>;
using str_attr_accessor = accessor<accessor_policies::str_attr>;
using item_accessor = accessor<accessor_policies::generic_item>;
/// Tag and check to identify a class which implements the Python object API /// Tag and check to identify a class which implements the Python object API
class pyobject_tag { }; class pyobject_tag { };
@ -36,10 +47,10 @@ class object_api : public pyobject_tag {
public: public:
iterator begin() const; iterator begin() const;
iterator end() const; iterator end() const;
accessor operator[](handle key) const; item_accessor operator[](handle key) const;
accessor operator[](const char *key) const; item_accessor operator[](const char *key) const;
accessor attr(handle key) const; obj_attr_accessor attr(handle key) const;
accessor attr(const char *key) const; str_attr_accessor attr(const char *key) const;
args_proxy operator*() const; args_proxy operator*() const;
template <typename T> bool contains(T &&key) const; template <typename T> bool contains(T &&key) const;
@ -177,40 +188,60 @@ inline handle get_function(handle value) {
return value; return value;
} }
class accessor { template <typename Policy>
class accessor : public object_api<accessor<Policy>> {
using key_type = typename Policy::key_type;
public: public:
accessor(handle obj, handle key, bool attr) accessor(handle obj, key_type key) : obj(obj), key(std::move(key)) { }
: obj(obj), key(key, true), attr(attr) { }
accessor(handle obj, const char *key, bool attr)
: obj(obj), key(PyUnicode_FromString(key), false), attr(attr) { }
accessor(const accessor &a) : obj(a.obj), key(a.key), attr(a.attr) { }
void operator=(accessor o) { operator=(object(o)); } void operator=(const accessor &a) { operator=(handle(a)); }
void operator=(const object &o) { operator=(handle(o)); }
void operator=(handle value) { Policy::set(obj, key, value); }
void operator=(const handle &value) { operator object() const { return get_cache(); }
if (attr) { PyObject *ptr() const { return get_cache().ptr(); }
if (PyObject_SetAttr(obj.ptr(), key.ptr(), value.ptr()) == -1) template <typename T> T cast() const { return get_cache().template cast<T>(); }
throw error_already_set();
} else { private:
if (PyObject_SetItem(obj.ptr(), key.ptr(), value.ptr()) == -1) const object &get_cache() const {
throw error_already_set(); if (!cache) { cache = Policy::get(obj, key); }
} return cache;
} }
operator object() const { private:
PyObject *result = attr ? PyObject_GetAttr(obj.ptr(), key.ptr()) handle obj;
: PyObject_GetItem(obj.ptr(), key.ptr()); key_type key;
mutable object cache;
};
NAMESPACE_BEGIN(accessor_policies)
struct obj_attr {
using key_type = object;
static object get(handle obj, handle key) { return getattr(obj, key); }
static void set(handle obj, handle key, handle val) { setattr(obj, key, val); }
};
struct str_attr {
using key_type = const char *;
static object get(handle obj, const char *key) { return getattr(obj, key); }
static void set(handle obj, const char *key, handle val) { setattr(obj, key, val); }
};
struct generic_item {
using key_type = object;
static object get(handle obj, handle key) {
PyObject *result = PyObject_GetItem(obj.ptr(), key.ptr());
if (!result) { throw error_already_set(); } if (!result) { throw error_already_set(); }
return {result, false}; return {result, false};
} }
template <typename T> T cast() const { return operator object().cast<T>(); } static void set(handle obj, handle key, handle val) {
if (PyObject_SetItem(obj.ptr(), key.ptr(), val.ptr()) != 0) { throw error_already_set(); }
private: }
handle obj;
object key;
bool attr;
}; };
NAMESPACE_END(accessor_policies)
struct list_accessor { struct list_accessor {
public: public:
@ -442,7 +473,7 @@ public:
template <typename... Args> template <typename... Args>
str format(Args &&...args) const { str format(Args &&...args) const {
return attr("format").cast<object>()(std::forward<Args>(args)...); return attr("format")(std::forward<Args>(args)...);
} }
}; };
@ -729,13 +760,13 @@ inline size_t len(handle h) {
NAMESPACE_BEGIN(detail) NAMESPACE_BEGIN(detail)
template <typename D> iterator object_api<D>::begin() const { return {PyObject_GetIter(derived().ptr()), false}; } template <typename D> iterator object_api<D>::begin() const { return {PyObject_GetIter(derived().ptr()), false}; }
template <typename D> iterator object_api<D>::end() const { return {nullptr, false}; } template <typename D> iterator object_api<D>::end() const { return {nullptr, false}; }
template <typename D> accessor object_api<D>::operator[](handle key) const { return {derived(), key, false}; } template <typename D> item_accessor object_api<D>::operator[](handle key) const { return {derived(), object(key, true)}; }
template <typename D> accessor object_api<D>::operator[](const char *key) const { return {derived(), key, false}; } template <typename D> item_accessor object_api<D>::operator[](const char *key) const { return {derived(), pybind11::str(key)}; }
template <typename D> accessor object_api<D>::attr(handle key) const { return {derived(), key, true}; } template <typename D> obj_attr_accessor object_api<D>::attr(handle key) const { return {derived(), object(key, true)}; }
template <typename D> accessor object_api<D>::attr(const char *key) const { return {derived(), key, true}; } template <typename D> str_attr_accessor object_api<D>::attr(const char *key) const { return {derived(), key}; }
template <typename D> args_proxy object_api<D>::operator*() const { return {derived().ptr()}; } template <typename D> args_proxy object_api<D>::operator*() const { return {derived().ptr()}; }
template <typename D> template <typename T> bool object_api<D>::contains(T &&key) const { template <typename D> template <typename T> bool object_api<D>::contains(T &&key) const {
return attr("__contains__").template cast<object>()(std::forward<T>(key)).template cast<bool>(); return attr("__contains__")(std::forward<T>(key)).template cast<bool>();
} }
template <typename D> template <typename D>

View File

@ -103,7 +103,7 @@ public:
int alive() { int alive() {
// Force garbage collection to ensure any pending destructors are invoked: // Force garbage collection to ensure any pending destructors are invoked:
py::module::import("gc").attr("collect").operator py::object()(); py::module::import("gc").attr("collect")();
int total = 0; int total = 0;
for (const auto &p : _instances) if (p.second > 0) total += p.second; for (const auto &p : _instances) if (p.second > 0) total += p.second;
return total; return total;

View File

@ -203,7 +203,7 @@ test_initializer python_types([](py::module &m) {
py::print("no new line here", "end"_a=" -- "); py::print("no new line here", "end"_a=" -- ");
py::print("next print"); py::print("next print");
auto py_stderr = py::module::import("sys").attr("stderr").cast<py::object>(); auto py_stderr = py::module::import("sys").attr("stderr");
py::print("this goes to stderr", "file"_a=py_stderr); py::print("this goes to stderr", "file"_a=py_stderr);
py::print("flush", "flush"_a=true); py::print("flush", "flush"_a=true);
@ -222,4 +222,39 @@ test_initializer python_types([](py::module &m) {
auto d2 = py::dict("z"_a=3, **d1); auto d2 = py::dict("z"_a=3, **d1);
return d2; return d2;
}); });
m.def("test_accessor_api", [](py::object o) {
auto d = py::dict();
d["basic_attr"] = o.attr("basic_attr");
auto l = py::list();
for (const auto &item : o.attr("begin_end")) {
l.append(item);
}
d["begin_end"] = l;
d["operator[object]"] = o.attr("d")["operator[object]"_s];
d["operator[char *]"] = o.attr("d")["operator[char *]"];
d["attr(object)"] = o.attr("sub").attr("attr_obj");
d["attr(char *)"] = o.attr("sub").attr("attr_char");
try {
o.attr("sub").attr("missing").ptr();
} catch (const py::error_already_set &) {
d["missing_attr_ptr"] = "raised"_s;
}
try {
o.attr("missing").attr("doesn't matter");
} catch (const py::error_already_set &) {
d["missing_attr_chain"] = "raised"_s;
}
d["is_none"] = py::cast(o.attr("basic_attr").is_none());
d["operator()"] = o.attr("func")(1);
d["operator*"] = o.attr("func")(*o.attr("begin_end"));
return d;
});
}); });

View File

@ -248,3 +248,33 @@ def test_dict_api():
from pybind11_tests import test_dict_keyword_constructor from pybind11_tests import test_dict_keyword_constructor
assert test_dict_keyword_constructor() == {"x": 1, "y": 2, "z": 3} assert test_dict_keyword_constructor() == {"x": 1, "y": 2, "z": 3}
def test_accessors():
from pybind11_tests import test_accessor_api
class SubTestObject:
attr_obj = 1
attr_char = 2
class TestObject:
basic_attr = 1
begin_end = [1, 2, 3]
d = {"operator[object]": 1, "operator[char *]": 2}
sub = SubTestObject()
def func(self, x, *args):
return self.basic_attr + x + sum(args)
d = test_accessor_api(TestObject())
assert d["basic_attr"] == 1
assert d["begin_end"] == [1, 2, 3]
assert d["operator[object]"] == 1
assert d["operator[char *]"] == 2
assert d["attr(object)"] == 1
assert d["attr(char *)"] == 2
assert d["missing_attr_ptr"] == "raised"
assert d["missing_attr_chain"] == "raised"
assert d["is_none"] is False
assert d["operator()"] == 2
assert d["operator*"] == 7