From 242b146a5165015045d6547d4840fa6afc6161f4 Mon Sep 17 00:00:00 2001 From: Dean Moldovan Date: Thu, 8 Sep 2016 17:02:04 +0200 Subject: [PATCH] Extend attribute and item accessor interface using object_api --- docs/changelog.rst | 2 + include/pybind11/cast.h | 2 +- include/pybind11/numpy.h | 6 +-- include/pybind11/pybind11.h | 11 ++-- include/pybind11/pytypes.h | 101 +++++++++++++++++++++++------------- tests/constructor_stats.h | 2 +- tests/test_python_types.cpp | 37 ++++++++++++- tests/test_python_types.py | 30 +++++++++++ 8 files changed, 143 insertions(+), 48 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index a9886e039..647622448 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -53,6 +53,8 @@ Breaking changes queued for v2.0.0 (Not yet released) * Added ``py::dict`` keyword constructor:``auto d = dict("number"_a=42, "name"_a="World");`` * Added ``py::str::format()`` method and ``_s`` literal: ``py::str s = "1 + 2 = {}"_s.format(3);`` +* Attribute and item accessors now have a more complete interface which makes it possible + to chain attributes ``obj.attr("a")[key].attr("b").attr("method")(1, 2, 3)```. * Various minor improvements of library internals (no user-visible changes) 1.8.1 (July 12, 2016) diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 7fa0348eb..683eb2ccd 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -1219,7 +1219,7 @@ private: void process(list &args_list, detail::args_proxy ap) { for (const auto &a : ap) { - args_list.append(a.cast()); + args_list.append(a); } } diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index fb8d3c6d0..996bb7c6d 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -125,7 +125,7 @@ private: static npy_api lookup() { module m = module::import("numpy.core.multiarray"); - auto c = m.attr("_ARRAY_API").cast(); + auto c = m.attr("_ARRAY_API"); #if PY_MAJOR_VERSION >= 3 void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL); #else @@ -220,9 +220,7 @@ private: struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; }; std::vector field_descriptors; - auto fields = attr("fields").cast(); - auto items = fields.attr("items").cast(); - for (auto field : items()) { + for (auto field : attr("fields").attr("items")()) { auto spec = object(field, true).cast(); auto name = spec[0].cast(); auto format = spec[1].cast()[0].cast(); diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index a0e2e725c..9c6bc32b8 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -176,7 +176,7 @@ protected: if (a.descr) a.descr = strdup(a.descr); else if (a.value) - a.descr = strdup(((std::string) ((object) handle(a.value).attr("__repr__"))().str()).c_str()); + a.descr = strdup(a.value.attr("__repr__")().cast().c_str()); } auto const ®istered_types = detail::get_internals().registered_types_cpp; @@ -723,8 +723,7 @@ protected: if (ob_type == &PyType_Type) { std::string name_ = std::string(ht_type.tp_name) + "__Meta"; #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3 - object ht_qualname(PyUnicode_FromFormat( - "%U__Meta", ((object) attr("__qualname__")).ptr()), false); + object ht_qualname(PyUnicode_FromFormat("%U__Meta", attr("__qualname__").ptr()), false); #endif object name(PYBIND11_FROM_STRING(name_.c_str()), false); object type_holder(PyType_Type.tp_alloc(&PyType_Type, 0), false); @@ -1342,16 +1341,16 @@ PYBIND11_NOINLINE inline void print(tuple args, dict kwargs) { strings[i] = args[i].cast().str(); } auto sep = kwargs.contains("sep") ? kwargs["sep"] : cast(" "); - auto line = sep.attr("join").cast()(strings); + auto line = sep.attr("join")(strings); auto file = kwargs.contains("file") ? kwargs["file"].cast() : module::import("sys").attr("stdout"); - auto write = file.attr("write").cast(); + auto write = file.attr("write"); write(line); write(kwargs.contains("end") ? kwargs["end"] : cast("\n")); if (kwargs.contains("flush") && kwargs["flush"].cast()) { - file.attr("flush").cast()(); + file.attr("flush")(); } } NAMESPACE_END(detail) diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index 180bcdd7c..a5e26345c 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -21,7 +21,18 @@ class str; class iterator; struct arg; struct arg_v; NAMESPACE_BEGIN(detail) -class accessor; class args_proxy; +class args_proxy; + +// Accessor forward declarations +template class accessor; +namespace accessor_policies { + struct obj_attr; + struct str_attr; + struct generic_item; +} +using obj_attr_accessor = accessor; +using str_attr_accessor = accessor; +using item_accessor = accessor; /// Tag and check to identify a class which implements the Python object API class pyobject_tag { }; @@ -36,10 +47,10 @@ class object_api : public pyobject_tag { public: iterator begin() const; iterator end() const; - accessor operator[](handle key) const; - accessor operator[](const char *key) const; - accessor attr(handle key) const; - accessor attr(const char *key) const; + item_accessor operator[](handle key) const; + item_accessor operator[](const char *key) const; + obj_attr_accessor attr(handle key) const; + str_attr_accessor attr(const char *key) const; args_proxy operator*() const; template bool contains(T &&key) const; @@ -177,40 +188,60 @@ inline handle get_function(handle value) { return value; } -class accessor { +template +class accessor : public object_api> { + using key_type = typename Policy::key_type; + public: - accessor(handle obj, handle key, bool attr) - : obj(obj), key(key, true), attr(attr) { } - accessor(handle obj, const char *key, bool attr) - : obj(obj), key(PyUnicode_FromString(key), false), attr(attr) { } - accessor(const accessor &a) : obj(a.obj), key(a.key), attr(a.attr) { } + accessor(handle obj, key_type key) : obj(obj), key(std::move(key)) { } - void operator=(accessor o) { operator=(object(o)); } + void operator=(const accessor &a) { operator=(handle(a)); } + void operator=(const object &o) { operator=(handle(o)); } + void operator=(handle value) { Policy::set(obj, key, value); } - void operator=(const handle &value) { - if (attr) { - if (PyObject_SetAttr(obj.ptr(), key.ptr(), value.ptr()) == -1) - throw error_already_set(); - } else { - if (PyObject_SetItem(obj.ptr(), key.ptr(), value.ptr()) == -1) - throw error_already_set(); - } + operator object() const { return get_cache(); } + PyObject *ptr() const { return get_cache().ptr(); } + template T cast() const { return get_cache().template cast(); } + +private: + const object &get_cache() const { + if (!cache) { cache = Policy::get(obj, key); } + return cache; } - operator object() const { - PyObject *result = attr ? PyObject_GetAttr(obj.ptr(), key.ptr()) - : PyObject_GetItem(obj.ptr(), key.ptr()); +private: + handle obj; + key_type key; + mutable object cache; +}; + +NAMESPACE_BEGIN(accessor_policies) +struct obj_attr { + using key_type = object; + static object get(handle obj, handle key) { return getattr(obj, key); } + static void set(handle obj, handle key, handle val) { setattr(obj, key, val); } +}; + +struct str_attr { + using key_type = const char *; + static object get(handle obj, const char *key) { return getattr(obj, key); } + static void set(handle obj, const char *key, handle val) { setattr(obj, key, val); } +}; + +struct generic_item { + using key_type = object; + + static object get(handle obj, handle key) { + PyObject *result = PyObject_GetItem(obj.ptr(), key.ptr()); if (!result) { throw error_already_set(); } return {result, false}; } - template T cast() const { return operator object().cast(); } - -private: - handle obj; - object key; - bool attr; + static void set(handle obj, handle key, handle val) { + if (PyObject_SetItem(obj.ptr(), key.ptr(), val.ptr()) != 0) { throw error_already_set(); } + } }; +NAMESPACE_END(accessor_policies) struct list_accessor { public: @@ -442,7 +473,7 @@ public: template str format(Args &&...args) const { - return attr("format").cast()(std::forward(args)...); + return attr("format")(std::forward(args)...); } }; @@ -729,13 +760,13 @@ inline size_t len(handle h) { NAMESPACE_BEGIN(detail) template iterator object_api::begin() const { return {PyObject_GetIter(derived().ptr()), false}; } template iterator object_api::end() const { return {nullptr, false}; } -template accessor object_api::operator[](handle key) const { return {derived(), key, false}; } -template accessor object_api::operator[](const char *key) const { return {derived(), key, false}; } -template accessor object_api::attr(handle key) const { return {derived(), key, true}; } -template accessor object_api::attr(const char *key) const { return {derived(), key, true}; } +template item_accessor object_api::operator[](handle key) const { return {derived(), object(key, true)}; } +template item_accessor object_api::operator[](const char *key) const { return {derived(), pybind11::str(key)}; } +template obj_attr_accessor object_api::attr(handle key) const { return {derived(), object(key, true)}; } +template str_attr_accessor object_api::attr(const char *key) const { return {derived(), key}; } template args_proxy object_api::operator*() const { return {derived().ptr()}; } template template bool object_api::contains(T &&key) const { - return attr("__contains__").template cast()(std::forward(key)).template cast(); + return attr("__contains__")(std::forward(key)).template cast(); } template diff --git a/tests/constructor_stats.h b/tests/constructor_stats.h index 69e385ec6..5dd215f19 100644 --- a/tests/constructor_stats.h +++ b/tests/constructor_stats.h @@ -103,7 +103,7 @@ public: int alive() { // Force garbage collection to ensure any pending destructors are invoked: - py::module::import("gc").attr("collect").operator py::object()(); + py::module::import("gc").attr("collect")(); int total = 0; for (const auto &p : _instances) if (p.second > 0) total += p.second; return total; diff --git a/tests/test_python_types.cpp b/tests/test_python_types.cpp index 9dafe777c..4ab90e63a 100644 --- a/tests/test_python_types.cpp +++ b/tests/test_python_types.cpp @@ -203,7 +203,7 @@ test_initializer python_types([](py::module &m) { py::print("no new line here", "end"_a=" -- "); py::print("next print"); - auto py_stderr = py::module::import("sys").attr("stderr").cast(); + auto py_stderr = py::module::import("sys").attr("stderr"); py::print("this goes to stderr", "file"_a=py_stderr); py::print("flush", "flush"_a=true); @@ -222,4 +222,39 @@ test_initializer python_types([](py::module &m) { auto d2 = py::dict("z"_a=3, **d1); return d2; }); + + m.def("test_accessor_api", [](py::object o) { + auto d = py::dict(); + + d["basic_attr"] = o.attr("basic_attr"); + + auto l = py::list(); + for (const auto &item : o.attr("begin_end")) { + l.append(item); + } + d["begin_end"] = l; + + d["operator[object]"] = o.attr("d")["operator[object]"_s]; + d["operator[char *]"] = o.attr("d")["operator[char *]"]; + + d["attr(object)"] = o.attr("sub").attr("attr_obj"); + d["attr(char *)"] = o.attr("sub").attr("attr_char"); + try { + o.attr("sub").attr("missing").ptr(); + } catch (const py::error_already_set &) { + d["missing_attr_ptr"] = "raised"_s; + } + try { + o.attr("missing").attr("doesn't matter"); + } catch (const py::error_already_set &) { + d["missing_attr_chain"] = "raised"_s; + } + + d["is_none"] = py::cast(o.attr("basic_attr").is_none()); + + d["operator()"] = o.attr("func")(1); + d["operator*"] = o.attr("func")(*o.attr("begin_end")); + + return d; + }); }); diff --git a/tests/test_python_types.py b/tests/test_python_types.py index fe58f9321..4f2cdb2c3 100644 --- a/tests/test_python_types.py +++ b/tests/test_python_types.py @@ -248,3 +248,33 @@ def test_dict_api(): from pybind11_tests import test_dict_keyword_constructor assert test_dict_keyword_constructor() == {"x": 1, "y": 2, "z": 3} + + +def test_accessors(): + from pybind11_tests import test_accessor_api + + class SubTestObject: + attr_obj = 1 + attr_char = 2 + + class TestObject: + basic_attr = 1 + begin_end = [1, 2, 3] + d = {"operator[object]": 1, "operator[char *]": 2} + sub = SubTestObject() + + def func(self, x, *args): + return self.basic_attr + x + sum(args) + + d = test_accessor_api(TestObject()) + assert d["basic_attr"] == 1 + assert d["begin_end"] == [1, 2, 3] + assert d["operator[object]"] == 1 + assert d["operator[char *]"] == 2 + assert d["attr(object)"] == 1 + assert d["attr(char *)"] == 2 + assert d["missing_attr_ptr"] == "raised" + assert d["missing_attr_chain"] == "raised" + assert d["is_none"] is False + assert d["operator()"] == 2 + assert d["operator*"] == 7