From c64e6b1670e57d103e52c4ca3b42e5e55858ca41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gunnar=20L=C3=A4th=C3=A9n?= Date: Tue, 12 Sep 2017 08:05:05 +0200 Subject: [PATCH] Added function for reloading module (#1040) --- docs/advanced/embedding.rst | 7 +++- include/pybind11/pybind11.h | 8 +++++ tests/test_embed/test_interpreter.cpp | 51 +++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/docs/advanced/embedding.rst b/docs/advanced/embedding.rst index bdfc75e0d..393031603 100644 --- a/docs/advanced/embedding.rst +++ b/docs/advanced/embedding.rst @@ -133,6 +133,11 @@ embedding the interpreter. This makes it easy to import local Python files: int n = result.cast(); assert(n == 3); +Modules can be reloaded using `module::reload()` if the source is modified e.g. +by an external process. This can be useful in scenarios where the application +imports a user defined data processing script which needs to be updated after +changes by the user. Note that this function does not reload modules recursively. + .. _embedding_modules: Adding embedded modules @@ -185,7 +190,7 @@ naturally: namespace py = pybind11; PYBIND11_EMBEDDED_MODULE(cpp_module, m) { - m.attr("a") = 1 + m.attr("a") = 1; } int main() { diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 3fcb99ff7..80102c5e8 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -836,6 +836,14 @@ public: return reinterpret_steal(obj); } + /// Reload the module or throws `error_already_set`. + void reload() { + PyObject *obj = PyImport_ReloadModule(ptr()); + if (!obj) + throw error_already_set(); + *this = reinterpret_steal(obj); + } + // Adds an object to the module using the given name. Throws if an object with the given name // already exists. // diff --git a/tests/test_embed/test_interpreter.cpp b/tests/test_embed/test_interpreter.cpp index acbad6bec..6b5f051f2 100644 --- a/tests/test_embed/test_interpreter.cpp +++ b/tests/test_embed/test_interpreter.cpp @@ -2,6 +2,8 @@ #include #include +#include +#include namespace py = pybind11; using namespace py::literals; @@ -216,3 +218,52 @@ TEST_CASE("Threads") { REQUIRE(locals["count"].cast() == num_threads); } + +// Scope exit utility https://stackoverflow.com/a/36644501/7255855 +struct scope_exit { + std::function f_; + explicit scope_exit(std::function f) noexcept : f_(std::move(f)) {} + ~scope_exit() { if (f_) f_(); } +}; + +TEST_CASE("Reload module from file") { + // Disable generation of cached bytecode (.pyc files) for this test, otherwise + // Python might pick up an old version from the cache instead of the new versions + // of the .py files generated below + auto sys = py::module::import("sys"); + bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast(); + sys.attr("dont_write_bytecode") = true; + // Reset the value at scope exit + scope_exit reset_dont_write_bytecode([&]() { + sys.attr("dont_write_bytecode") = dont_write_bytecode; + }); + + std::string module_name = "test_module_reload"; + std::string module_file = module_name + ".py"; + + // Create the module .py file + std::ofstream test_module(module_file); + test_module << "def test():\n"; + test_module << " return 1\n"; + test_module.close(); + // Delete the file at scope exit + scope_exit delete_module_file([&]() { + std::remove(module_file.c_str()); + }); + + // Import the module from file + auto module = py::module::import(module_name.c_str()); + int result = module.attr("test")().cast(); + REQUIRE(result == 1); + + // Update the module .py file with a small change + test_module.open(module_file); + test_module << "def test():\n"; + test_module << " return 2\n"; + test_module.close(); + + // Reload the module + module.reload(); + result = module.attr("test")().cast(); + REQUIRE(result == 2); +}