From 1e5a7da30dded4f302a13c42a30bfdca4b12cd54 Mon Sep 17 00:00:00 2001 From: Dean Moldovan Date: Thu, 24 Aug 2017 01:53:15 +0200 Subject: [PATCH] Add py::pickle() adaptor for safer __getstate__/__setstate__ bindings This is analogous to `py::init()` vs `__init__` + placement-new. `py::pickle()` reuses most of the implementation details of `py::init()`. --- docs/advanced/classes.rst | 44 ++++++++++++++++++------------ docs/changelog.rst | 5 ++++ docs/upgrade.rst | 33 +++++++++++++++++++++++ include/pybind11/detail/init.h | 49 ++++++++++++++++++++++++++++++++++ include/pybind11/pybind11.h | 13 +++++++++ tests/test_pickling.cpp | 47 ++++++++++++++++++++++++++++++++ tests/test_pickling.py | 12 ++++++--- 7 files changed, 182 insertions(+), 21 deletions(-) diff --git a/docs/advanced/classes.rst b/docs/advanced/classes.rst index 7bcd0385a..be4bc2e77 100644 --- a/docs/advanced/classes.rst +++ b/docs/advanced/classes.rst @@ -687,13 +687,15 @@ throwing a type error. complete example that demonstrates how to work with overloaded operators in more detail. +.. _pickling: + Pickling support ================ Python's ``pickle`` module provides a powerful facility to serialize and de-serialize a Python object graph into a binary data stream. To pickle and -unpickle C++ classes using pybind11, two additional functions must be provided. -Suppose the class in question has the following signature: +unpickle C++ classes using pybind11, a ``py::pickle()`` definition must be +provided. Suppose the class in question has the following signature: .. code-block:: cpp @@ -709,8 +711,9 @@ Suppose the class in question has the following signature: int m_extra = 0; }; -The binding code including the requisite ``__setstate__`` and ``__getstate__`` methods [#f3]_ -looks as follows: +Pickling support in Python is enable by defining the ``__setstate__`` and +``__getstate__`` methods [#f3]_. For pybind11 classes, use ``py::pickle()`` +to bind these two functions: .. code-block:: cpp @@ -719,21 +722,28 @@ looks as follows: .def("value", &Pickleable::value) .def("extra", &Pickleable::extra) .def("setExtra", &Pickleable::setExtra) - .def("__getstate__", [](const Pickleable &p) { - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(p.value(), p.extra()); - }) - .def("__setstate__", [](Pickleable &p, py::tuple t) { - if (t.size() != 2) - throw std::runtime_error("Invalid state!"); + .def(py::pickle( + [](const Pickleable &p) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(p.value(), p.extra()); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 2) + throw std::runtime_error("Invalid state!"); - /* Invoke the in-place constructor. Note that this is needed even - when the object just has a trivial default constructor */ - new (&p) Pickleable(t[0].cast()); + /* Create a new C++ instance */ + Pickleable p(t[0].cast()); - /* Assign any additional state */ - p.setExtra(t[1].cast()); - }); + /* Assign any additional state */ + p.setExtra(t[1].cast()); + + return p; + } + )); + +The ``__setstate__`` part of the ``py::picke()`` definition follows the same +rules as the single-argument version of ``py::init()``. The return type can be +a value, pointer or holder type. See :ref:`custom_constructors` for details. An instance can now be pickled as follows: diff --git a/docs/changelog.rst b/docs/changelog.rst index 23d556370..478b7d7ab 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -91,6 +91,11 @@ v2.2.0 (Not yet released) return std::make_unique(std::to_string(n)); })); +* Similarly to custom constructors, pickling support functions are now bound + using the ``py::pickle()`` adaptor which improves type safety. See the + :doc:`upgrade` and :ref:`pickling` for details. + `#1038 `_. + * Builtin support for converting C++17 standard library types and general conversion improvements: diff --git a/docs/upgrade.rst b/docs/upgrade.rst index 2fe8470b7..bcbc6b135 100644 --- a/docs/upgrade.rst +++ b/docs/upgrade.rst @@ -162,6 +162,39 @@ See :ref:`custom_constructors` for details. })); +New syntax for pickling support +------------------------------- + +Mirroring the custom constructor changes, ``py::pickle()`` is now the preferred +way to get and set object state. See :ref:`pickling` for details. + +.. code-block:: cpp + + // old -- deprecated + py::class(m, "Foo") + ... + .def("__getstate__", [](const Foo &self) { + return py::make_tuple(self.value1(), self.value2(), ...); + }) + .def("__setstate__", [](Foo &self, py::tuple t) { + new (&self) Foo(t[0].cast(), ...); + }); + + // new + py::class(m, "Foo") + ... + .def(py::pickle( + [](const Foo &self) { // __getstate__ + return py::make_tuple(f.value1(), f.value2(), ...); // unchanged + }, + [](py::tuple t) { // __setstate__, note: no `self` argument + return new Foo(t[0].cast(), ...); + // or: return std::make_unique(...); // return by holder + // or: return Foo(...); // return by value (move constructor) + } + )); + + Deprecation of some ``py::object`` APIs --------------------------------------- diff --git a/include/pybind11/detail/init.h b/include/pybind11/detail/init.h index ee2de8c2a..deace19c0 100644 --- a/include/pybind11/detail/init.h +++ b/include/pybind11/detail/init.h @@ -271,6 +271,55 @@ struct factory { } }; +/// Set just the C++ state. Same as `__init__`. +template +void setstate(value_and_holder &v_h, T &&result, bool need_alias) { + construct(v_h, std::forward(result), need_alias); +} + +/// Set both the C++ and Python states +template ::value, int> = 0> +void setstate(value_and_holder &v_h, std::pair &&result, bool need_alias) { + construct(v_h, std::move(result.first), need_alias); + setattr((PyObject *) v_h.inst, "__dict__", result.second); +} + +/// Implementation for py::pickle(GetState, SetState) +template , typename = function_signature_t> +struct pickle_factory; + +template +struct pickle_factory { + static_assert(std::is_same::value, + "The type returned by `__getstate__` must be the same " + "as the argument accepted by `__setstate__`"); + + remove_reference_t get; + remove_reference_t set; + + pickle_factory(Get get, Set set) + : get(std::forward(get)), set(std::forward(set)) { } + + template + void execute(Class &cl, const Extra &...extra) && { + cl.def("__getstate__", std::move(get)); + +#if defined(PYBIND11_CPP14) + cl.def("__setstate__", [func = std::move(set)] +#else + auto &func = set; + cl.def("__setstate__", [func] +#endif + (value_and_holder &v_h, ArgState state) { + setstate(v_h, func(std::forward(state)), + Py_TYPE(v_h.inst) != v_h.type->type); + }, is_new_style_constructor(), extra...); + } +}; + NAMESPACE_END(initimpl) NAMESPACE_END(detail) NAMESPACE_END(pybind11) diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 16e3fdc3e..0e67c4060 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1095,6 +1095,12 @@ public: return *this; } + template + class_ &def(detail::initimpl::pickle_factory &&pf, const Extra &...extra) { + std::move(pf).execute(*this, extra...); + return *this; + } + template class_& def_buffer(Func &&func) { struct capture { Func func; }; capture *ptr = new capture { std::forward(func) }; @@ -1399,6 +1405,13 @@ Ret init(CFunc &&c, AFunc &&a) { return {std::forward(c), std::forward(a)}; } +/// Binds pickling functions `__getstate__` and `__setstate__` and ensures that the type +/// returned by `__getstate__` is the same as the argument accepted by `__setstate__`. +template +detail::initimpl::pickle_factory pickle(GetState &&g, SetState &&s) { + return {std::forward(g), std::forward(s)}; +}; + NAMESPACE_BEGIN(detail) diff --git a/tests/test_pickling.cpp b/tests/test_pickling.cpp index 1e5f4ce74..821462ac4 100644 --- a/tests/test_pickling.cpp +++ b/tests/test_pickling.cpp @@ -25,6 +25,12 @@ TEST_SUBMODULE(pickling, m) { int m_extra1 = 0; int m_extra2 = 0; }; + + class PickleableNew : public Pickleable { + public: + using Pickleable::Pickleable; + }; + py::class_(m, "Pickleable") .def(py::init()) .def("value", &Pickleable::value) @@ -49,6 +55,23 @@ TEST_SUBMODULE(pickling, m) { p.setExtra2(t[2].cast()); }); + py::class_(m, "PickleableNew") + .def(py::init()) + .def(py::pickle( + [](const PickleableNew &p) { + return py::make_tuple(p.value(), p.extra1(), p.extra2()); + }, + [](py::tuple t) { + if (t.size() != 3) + throw std::runtime_error("Invalid state!"); + auto p = PickleableNew(t[0].cast()); + + p.setExtra1(t[1].cast()); + p.setExtra2(t[2].cast()); + return p; + } + )); + #if !defined(PYPY_VERSION) // test_roundtrip_with_dict class PickleableWithDict { @@ -58,6 +81,12 @@ TEST_SUBMODULE(pickling, m) { std::string value; int extra; }; + + class PickleableWithDictNew : public PickleableWithDict { + public: + using PickleableWithDict::PickleableWithDict; + }; + py::class_(m, "PickleableWithDict", py::dynamic_attr()) .def(py::init()) .def_readwrite("value", &PickleableWithDict::value) @@ -79,5 +108,23 @@ TEST_SUBMODULE(pickling, m) { /* Assign Python state */ self.attr("__dict__") = t[2]; }); + + py::class_(m, "PickleableWithDictNew") + .def(py::init()) + .def(py::pickle( + [](py::object self) { + return py::make_tuple(self.attr("value"), self.attr("extra"), self.attr("__dict__")); + }, + [](py::tuple t) { + if (t.size() != 3) + throw std::runtime_error("Invalid state!"); + + auto cpp_state = PickleableWithDictNew(t[0].cast()); + cpp_state.extra = t[1].cast(); + + auto py_state = t[2].cast(); + return std::make_pair(cpp_state, py_state); + } + )); #endif } diff --git a/tests/test_pickling.py b/tests/test_pickling.py index 6cbcdf516..707d34786 100644 --- a/tests/test_pickling.py +++ b/tests/test_pickling.py @@ -7,8 +7,10 @@ except ImportError: import pickle -def test_roundtrip(): - p = m.Pickleable("test_value") +@pytest.mark.parametrize("cls_name", ["Pickleable", "PickleableNew"]) +def test_roundtrip(cls_name): + cls = getattr(m, cls_name) + p = cls("test_value") p.setExtra1(15) p.setExtra2(48) @@ -20,8 +22,10 @@ def test_roundtrip(): @pytest.unsupported_on_pypy -def test_roundtrip_with_dict(): - p = m.PickleableWithDict("test_value") +@pytest.mark.parametrize("cls_name", ["PickleableWithDict", "PickleableWithDictNew"]) +def test_roundtrip_with_dict(cls_name): + cls = getattr(m, cls_name) + p = cls("test_value") p.extra = 15 p.dynamic = "Attribute"