mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-22 13:15:12 +00:00
Added function for reloading module (#1040)
This commit is contained in:
parent
2cf87a54d8
commit
c64e6b1670
@ -133,6 +133,11 @@ embedding the interpreter. This makes it easy to import local Python files:
|
|||||||
int n = result.cast<int>();
|
int n = result.cast<int>();
|
||||||
assert(n == 3);
|
assert(n == 3);
|
||||||
|
|
||||||
|
Modules can be reloaded using `module::reload()` if the source is modified e.g.
|
||||||
|
by an external process. This can be useful in scenarios where the application
|
||||||
|
imports a user defined data processing script which needs to be updated after
|
||||||
|
changes by the user. Note that this function does not reload modules recursively.
|
||||||
|
|
||||||
.. _embedding_modules:
|
.. _embedding_modules:
|
||||||
|
|
||||||
Adding embedded modules
|
Adding embedded modules
|
||||||
@ -185,7 +190,7 @@ naturally:
|
|||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
PYBIND11_EMBEDDED_MODULE(cpp_module, m) {
|
PYBIND11_EMBEDDED_MODULE(cpp_module, m) {
|
||||||
m.attr("a") = 1
|
m.attr("a") = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
|
@ -836,6 +836,14 @@ public:
|
|||||||
return reinterpret_steal<module>(obj);
|
return reinterpret_steal<module>(obj);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Reload the module or throws `error_already_set`.
|
||||||
|
void reload() {
|
||||||
|
PyObject *obj = PyImport_ReloadModule(ptr());
|
||||||
|
if (!obj)
|
||||||
|
throw error_already_set();
|
||||||
|
*this = reinterpret_steal<module>(obj);
|
||||||
|
}
|
||||||
|
|
||||||
// Adds an object to the module using the given name. Throws if an object with the given name
|
// Adds an object to the module using the given name. Throws if an object with the given name
|
||||||
// already exists.
|
// already exists.
|
||||||
//
|
//
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
#include <catch.hpp>
|
#include <catch.hpp>
|
||||||
|
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
#include <fstream>
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
using namespace py::literals;
|
using namespace py::literals;
|
||||||
@ -216,3 +218,52 @@ TEST_CASE("Threads") {
|
|||||||
|
|
||||||
REQUIRE(locals["count"].cast<int>() == num_threads);
|
REQUIRE(locals["count"].cast<int>() == num_threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Scope exit utility https://stackoverflow.com/a/36644501/7255855
|
||||||
|
struct scope_exit {
|
||||||
|
std::function<void()> f_;
|
||||||
|
explicit scope_exit(std::function<void()> f) noexcept : f_(std::move(f)) {}
|
||||||
|
~scope_exit() { if (f_) f_(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_CASE("Reload module from file") {
|
||||||
|
// Disable generation of cached bytecode (.pyc files) for this test, otherwise
|
||||||
|
// Python might pick up an old version from the cache instead of the new versions
|
||||||
|
// of the .py files generated below
|
||||||
|
auto sys = py::module::import("sys");
|
||||||
|
bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast<bool>();
|
||||||
|
sys.attr("dont_write_bytecode") = true;
|
||||||
|
// Reset the value at scope exit
|
||||||
|
scope_exit reset_dont_write_bytecode([&]() {
|
||||||
|
sys.attr("dont_write_bytecode") = dont_write_bytecode;
|
||||||
|
});
|
||||||
|
|
||||||
|
std::string module_name = "test_module_reload";
|
||||||
|
std::string module_file = module_name + ".py";
|
||||||
|
|
||||||
|
// Create the module .py file
|
||||||
|
std::ofstream test_module(module_file);
|
||||||
|
test_module << "def test():\n";
|
||||||
|
test_module << " return 1\n";
|
||||||
|
test_module.close();
|
||||||
|
// Delete the file at scope exit
|
||||||
|
scope_exit delete_module_file([&]() {
|
||||||
|
std::remove(module_file.c_str());
|
||||||
|
});
|
||||||
|
|
||||||
|
// Import the module from file
|
||||||
|
auto module = py::module::import(module_name.c_str());
|
||||||
|
int result = module.attr("test")().cast<int>();
|
||||||
|
REQUIRE(result == 1);
|
||||||
|
|
||||||
|
// Update the module .py file with a small change
|
||||||
|
test_module.open(module_file);
|
||||||
|
test_module << "def test():\n";
|
||||||
|
test_module << " return 2\n";
|
||||||
|
test_module.close();
|
||||||
|
|
||||||
|
// Reload the module
|
||||||
|
module.reload();
|
||||||
|
result = module.attr("test")().cast<int>();
|
||||||
|
REQUIRE(result == 2);
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user