diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index b8b7c7a02..c2b88688a 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1979,6 +1979,16 @@ inline std::pair all_t // gets destroyed: weakref((PyObject *) type, cpp_function([type](handle wr) { get_internals().registered_types_py.erase(type); + + // TODO consolidate the erasure code in pybind11_meta_dealloc() in class.h + auto &cache = get_internals().inactive_override_cache; + for (auto it = cache.begin(), last = cache.end(); it != last; ) { + if (it->first == reinterpret_cast(type)) + it = cache.erase(it); + else + ++it; + } + wr.dec_ref(); })).release(); } diff --git a/tests/test_embed/CMakeLists.txt b/tests/test_embed/CMakeLists.txt index 3b89d6e58..edb8961a7 100644 --- a/tests/test_embed/CMakeLists.txt +++ b/tests/test_embed/CMakeLists.txt @@ -25,7 +25,7 @@ pybind11_enable_warnings(test_embed) target_link_libraries(test_embed PRIVATE pybind11::embed Catch2::Catch2 Threads::Threads) if(NOT CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_CURRENT_BINARY_DIR) - file(COPY test_interpreter.py DESTINATION "${CMAKE_CURRENT_BINARY_DIR}") + file(COPY test_interpreter.py test_trampoline.py DESTINATION "${CMAKE_CURRENT_BINARY_DIR}") endif() add_custom_target( diff --git a/tests/test_embed/test_interpreter.cpp b/tests/test_embed/test_interpreter.cpp index 20bcade0a..508975eb3 100644 --- a/tests/test_embed/test_interpreter.cpp +++ b/tests/test_embed/test_interpreter.cpp @@ -37,6 +37,22 @@ class PyWidget final : public Widget { std::string argv0() const override { PYBIND11_OVERRIDE_PURE(std::string, Widget, argv0); } }; +class test_override_cache_helper { + +public: + virtual int func() { return 0; } + + test_override_cache_helper() = default; + virtual ~test_override_cache_helper() = default; + // Non-copyable + test_override_cache_helper &operator=(test_override_cache_helper const &Right) = delete; + test_override_cache_helper(test_override_cache_helper const &Copy) = delete; +}; + +class test_override_cache_helper_trampoline : public test_override_cache_helper { + int func() override { PYBIND11_OVERRIDE(int, test_override_cache_helper, func); } +}; + PYBIND11_EMBEDDED_MODULE(widget_module, m) { py::class_(m, "Widget") .def(py::init()) @@ -45,6 +61,12 @@ PYBIND11_EMBEDDED_MODULE(widget_module, m) { m.def("add", [](int i, int j) { return i + j; }); } +PYBIND11_EMBEDDED_MODULE(trampoline_module, m) { + py::class_>(m, "test_override_cache_helper") + .def(py::init_alias<>()) + .def("func", &test_override_cache_helper::func); +} + PYBIND11_EMBEDDED_MODULE(throw_exception, ) { throw std::runtime_error("C++ Error"); } @@ -73,6 +95,33 @@ TEST_CASE("Pass classes and data between modules defined in C++ and Python") { REQUIRE(cpp_widget.the_answer() == 42); } +TEST_CASE("Override cache") { + auto module_ = py::module_::import("test_trampoline"); + REQUIRE(py::hasattr(module_, "func")); + REQUIRE(py::hasattr(module_, "func2")); + + auto locals = py::dict(**module_.attr("__dict__")); + + int i = 0; + for (; i < 1500; ++i) { + std::shared_ptr p_obj; + std::shared_ptr p_obj2; + + py::object loc_inst = locals["func"](); + p_obj = py::cast>(loc_inst); + + int ret = p_obj->func(); + + REQUIRE(ret == 42); + + loc_inst = locals["func2"](); + + p_obj2 = py::cast>(loc_inst); + + p_obj2->func(); + } +} + TEST_CASE("Import error handling") { REQUIRE_NOTHROW(py::module_::import("widget_module")); REQUIRE_THROWS_WITH(py::module_::import("throw_exception"), diff --git a/tests/test_embed/test_trampoline.py b/tests/test_embed/test_trampoline.py new file mode 100644 index 000000000..87c8fa44c --- /dev/null +++ b/tests/test_embed/test_trampoline.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- + +import trampoline_module + + +def func(): + class Test(trampoline_module.test_override_cache_helper): + def func(self): + return 42 + + return Test() + + +def func2(): + class Test(trampoline_module.test_override_cache_helper): + pass + + return Test() diff --git a/tests/test_virtual_functions.cpp b/tests/test_virtual_functions.cpp index 6e06db9fc..f1a513180 100644 --- a/tests/test_virtual_functions.cpp +++ b/tests/test_virtual_functions.cpp @@ -214,6 +214,25 @@ static void test_gil_from_thread() { t.join(); } +class test_override_cache_helper { + +public: + virtual int func() { return 0; } + + test_override_cache_helper() = default; + virtual ~test_override_cache_helper() = default; + // Non-copyable + test_override_cache_helper &operator=(test_override_cache_helper const &Right) = delete; + test_override_cache_helper(test_override_cache_helper const &Copy) = delete; +}; + +class test_override_cache_helper_trampoline : public test_override_cache_helper { + int func() override { PYBIND11_OVERRIDE(int, test_override_cache_helper, func); } +}; + +inline int test_override_cache(std::shared_ptr const &instance) { return instance->func(); } + + // Forward declaration (so that we can put the main tests here; the inherited virtual approaches are // rather long). @@ -378,6 +397,12 @@ TEST_SUBMODULE(virtual_functions, m) { // .def("str_ref", &OverrideTest::str_ref) .def("A_value", &OverrideTest::A_value) .def("A_ref", &OverrideTest::A_ref); + + py::class_>(m, "test_override_cache_helper") + .def(py::init_alias<>()) + .def("func", &test_override_cache_helper::func); + + m.def("test_override_cache", test_override_cache); } diff --git a/tests/test_virtual_functions.py b/tests/test_virtual_functions.py index 0b550992f..4f25cac4a 100644 --- a/tests/test_virtual_functions.py +++ b/tests/test_virtual_functions.py @@ -439,3 +439,22 @@ def test_issue_1454(): # Fix issue #1454 (crash when acquiring/releasing GIL on another thread in Python 2.7) m.test_gil() m.test_gil_from_thread() + + +def test_python_override(): + def func(): + class Test(m.test_override_cache_helper): + def func(self): + return 42 + + return Test() + + def func2(): + class Test(m.test_override_cache_helper): + pass + + return Test() + + for _ in range(1500): + assert m.test_override_cache(func()) == 42 + assert m.test_override_cache(func2()) == 0