diff --git a/include/pybind11/detail/type_caster_base.h b/include/pybind11/detail/type_caster_base.h index 0898be014..d5d86dc6c 100644 --- a/include/pybind11/detail/type_caster_base.h +++ b/include/pybind11/detail/type_caster_base.h @@ -117,7 +117,6 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector(t->tp_bases)) { check.push_back((PyTypeObject *) parent.ptr()); } - auto const &type_dict = get_internals().registered_types_py; for (size_t i = 0; i < check.size(); i++) { auto *type = check[i]; @@ -176,13 +175,7 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector &all_type_info(PyTypeObject *type) { - auto ins = all_type_info_get_cache(type); - if (ins.second) { - // New cache entry: populate it - all_type_info_populate(type, ins.first->second); - } - - return ins.first->second; + return all_type_info_get_cache(type).first->second; } /** diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 2527d25fa..b4f93f1a6 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -2326,13 +2326,20 @@ keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret) { inline std::pair all_type_info_get_cache(PyTypeObject *type) { auto res = with_internals([type](internals &internals) { - return internals - .registered_types_py + auto ins = internals + .registered_types_py #ifdef __cpp_lib_unordered_map_try_emplace - .try_emplace(type); + .try_emplace(type); #else - .emplace(type, std::vector()); + .emplace(type, std::vector()); #endif + if (ins.second) { + // For free-threading mode, this call must be under + // the with_internals() mutex lock, to avoid that other threads + // continue running with the empty ins.first->second. + all_type_info_populate(type, ins.first->second); + } + return ins; }); if (res.second) { // New cache entry created; set up a weak reference to automatically remove it if the type diff --git a/tests/pybind11_tests.cpp b/tests/pybind11_tests.cpp index 3d2d84e77..818d53a54 100644 --- a/tests/pybind11_tests.cpp +++ b/tests/pybind11_tests.cpp @@ -128,4 +128,9 @@ PYBIND11_MODULE(pybind11_tests, m, py::mod_gil_not_used()) { for (const auto &initializer : initializers()) { initializer(m); } + + py::class_(m, "TestContext") + .def(py::init<>(&TestContext::createNewContextForInit)) + .def("__enter__", &TestContext::contextEnter) + .def("__exit__", &TestContext::contextExit); } diff --git a/tests/pybind11_tests.h b/tests/pybind11_tests.h index 7be58feb6..0eb0398df 100644 --- a/tests/pybind11_tests.h +++ b/tests/pybind11_tests.h @@ -96,3 +96,24 @@ void ignoreOldStyleInitWarnings(F &&body) { )", py::dict(py::arg("body") = py::cpp_function(body))); } + +// See PR #5419 for background. +class TestContext { +public: + TestContext() = delete; + TestContext(const TestContext &) = delete; + TestContext(TestContext &&) = delete; + static TestContext *createNewContextForInit() { return new TestContext("new-context"); } + + pybind11::object contextEnter() { + py::object contextObj = py::cast(*this); + return contextObj; + } + void contextExit(const pybind11::object & /*excType*/, + const pybind11::object & /*excVal*/, + const pybind11::object & /*excTb*/) {} + +private: + explicit TestContext(const std::string &context) : context(context) {} + std::string context; +}; diff --git a/tests/test_class.py b/tests/test_class.py index f424db5c3..01963d012 100644 --- a/tests/test_class.py +++ b/tests/test_class.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from unittest import mock import pytest @@ -508,3 +509,31 @@ def test_pr4220_tripped_over_this(): m.Empty0().get_msg() == "This is really only meant to exercise successful compilation." ) + + +@pytest.mark.skipif(sys.platform.startswith("emscripten"), reason="Requires threads") +def test_all_type_info_multithreaded(): + # See PR #5419 for background. + import threading + + from pybind11_tests import TestContext + + class Context(TestContext): + pass + + num_runs = 10 + num_threads = 4 + barrier = threading.Barrier(num_threads) + + def func(): + barrier.wait() + with Context(): + pass + + for _ in range(num_runs): + threads = [threading.Thread(target=func) for _ in range(num_threads)] + for thread in threads: + thread.start() + + for thread in threads: + thread.join()