Added function for reloading module (#1040)

This commit is contained in:
Gunnar Läthén 2017-09-12 08:05:05 +02:00 committed by Dean Moldovan
parent 2cf87a54d8
commit c64e6b1670
3 changed files with 65 additions and 1 deletions

View File

@ -133,6 +133,11 @@ embedding the interpreter. This makes it easy to import local Python files:
int n = result.cast<int>(); int n = result.cast<int>();
assert(n == 3); assert(n == 3);
Modules can be reloaded using `module::reload()` if the source is modified e.g.
by an external process. This can be useful in scenarios where the application
imports a user defined data processing script which needs to be updated after
changes by the user. Note that this function does not reload modules recursively.
.. _embedding_modules: .. _embedding_modules:
Adding embedded modules Adding embedded modules
@ -185,7 +190,7 @@ naturally:
namespace py = pybind11; namespace py = pybind11;
PYBIND11_EMBEDDED_MODULE(cpp_module, m) { PYBIND11_EMBEDDED_MODULE(cpp_module, m) {
m.attr("a") = 1 m.attr("a") = 1;
} }
int main() { int main() {

View File

@ -836,6 +836,14 @@ public:
return reinterpret_steal<module>(obj); return reinterpret_steal<module>(obj);
} }
/// Reload the module or throws `error_already_set`.
void reload() {
PyObject *obj = PyImport_ReloadModule(ptr());
if (!obj)
throw error_already_set();
*this = reinterpret_steal<module>(obj);
}
// Adds an object to the module using the given name. Throws if an object with the given name // Adds an object to the module using the given name. Throws if an object with the given name
// already exists. // already exists.
// //

View File

@ -2,6 +2,8 @@
#include <catch.hpp> #include <catch.hpp>
#include <thread> #include <thread>
#include <fstream>
#include <functional>
namespace py = pybind11; namespace py = pybind11;
using namespace py::literals; using namespace py::literals;
@ -216,3 +218,52 @@ TEST_CASE("Threads") {
REQUIRE(locals["count"].cast<int>() == num_threads); REQUIRE(locals["count"].cast<int>() == num_threads);
} }
// Scope exit utility https://stackoverflow.com/a/36644501/7255855
struct scope_exit {
std::function<void()> f_;
explicit scope_exit(std::function<void()> f) noexcept : f_(std::move(f)) {}
~scope_exit() { if (f_) f_(); }
};
TEST_CASE("Reload module from file") {
// Disable generation of cached bytecode (.pyc files) for this test, otherwise
// Python might pick up an old version from the cache instead of the new versions
// of the .py files generated below
auto sys = py::module::import("sys");
bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast<bool>();
sys.attr("dont_write_bytecode") = true;
// Reset the value at scope exit
scope_exit reset_dont_write_bytecode([&]() {
sys.attr("dont_write_bytecode") = dont_write_bytecode;
});
std::string module_name = "test_module_reload";
std::string module_file = module_name + ".py";
// Create the module .py file
std::ofstream test_module(module_file);
test_module << "def test():\n";
test_module << " return 1\n";
test_module.close();
// Delete the file at scope exit
scope_exit delete_module_file([&]() {
std::remove(module_file.c_str());
});
// Import the module from file
auto module = py::module::import(module_name.c_str());
int result = module.attr("test")().cast<int>();
REQUIRE(result == 1);
// Update the module .py file with a small change
test_module.open(module_file);
test_module << "def test():\n";
test_module << " return 2\n";
test_module.close();
// Reload the module
module.reload();
result = module.attr("test")().cast<int>();
REQUIRE(result == 2);
}