mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-21 20:55:11 +00:00
Fixed data race in all_type_info in free-threading mode (#5419)
* Fix data race all_type_info_populate in free-threading mode Description: - fixed data race all_type_info_populate in free-threading mode - added test For example, we have 2 threads entering `all_type_info`. Both enter `all_type_info_get_cache`` function and there is a first one which inserts a tuple (type, empty_vector) to the map and second is waiting. Inserting thread gets the (iter_to_key, True) and non-inserting thread after waiting gets (iter_to_key, False). Inserting thread than will add a weakref and will then call into `all_type_info_populate`. However, non-inserting thread is not entering `if (ins.second) {` clause and returns `ins.first->second;`` which is just empty_vector. Finally, non-inserting thread is failing the check in `allocate_layout`: ```c++ if (n_types == 0) { pybind11_fail( "instance allocation failed: new instance has no pybind11-registered base types"); } ``` * style: pre-commit fixes * Addressed PR comments --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
f46f5be4fa
commit
ce2f005594
@ -117,7 +117,6 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector<type_
|
|||||||
for (handle parent : reinterpret_borrow<tuple>(t->tp_bases)) {
|
for (handle parent : reinterpret_borrow<tuple>(t->tp_bases)) {
|
||||||
check.push_back((PyTypeObject *) parent.ptr());
|
check.push_back((PyTypeObject *) parent.ptr());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto const &type_dict = get_internals().registered_types_py;
|
auto const &type_dict = get_internals().registered_types_py;
|
||||||
for (size_t i = 0; i < check.size(); i++) {
|
for (size_t i = 0; i < check.size(); i++) {
|
||||||
auto *type = check[i];
|
auto *type = check[i];
|
||||||
@ -176,13 +175,7 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector<type_
|
|||||||
* The value is cached for the lifetime of the Python type.
|
* The value is cached for the lifetime of the Python type.
|
||||||
*/
|
*/
|
||||||
inline const std::vector<detail::type_info *> &all_type_info(PyTypeObject *type) {
|
inline const std::vector<detail::type_info *> &all_type_info(PyTypeObject *type) {
|
||||||
auto ins = all_type_info_get_cache(type);
|
return all_type_info_get_cache(type).first->second;
|
||||||
if (ins.second) {
|
|
||||||
// New cache entry: populate it
|
|
||||||
all_type_info_populate(type, ins.first->second);
|
|
||||||
}
|
|
||||||
|
|
||||||
return ins.first->second;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -2326,13 +2326,20 @@ keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret) {
|
|||||||
inline std::pair<decltype(internals::registered_types_py)::iterator, bool>
|
inline std::pair<decltype(internals::registered_types_py)::iterator, bool>
|
||||||
all_type_info_get_cache(PyTypeObject *type) {
|
all_type_info_get_cache(PyTypeObject *type) {
|
||||||
auto res = with_internals([type](internals &internals) {
|
auto res = with_internals([type](internals &internals) {
|
||||||
return internals
|
auto ins = internals
|
||||||
.registered_types_py
|
.registered_types_py
|
||||||
#ifdef __cpp_lib_unordered_map_try_emplace
|
#ifdef __cpp_lib_unordered_map_try_emplace
|
||||||
.try_emplace(type);
|
.try_emplace(type);
|
||||||
#else
|
#else
|
||||||
.emplace(type, std::vector<detail::type_info *>());
|
.emplace(type, std::vector<detail::type_info *>());
|
||||||
#endif
|
#endif
|
||||||
|
if (ins.second) {
|
||||||
|
// For free-threading mode, this call must be under
|
||||||
|
// the with_internals() mutex lock, to avoid that other threads
|
||||||
|
// continue running with the empty ins.first->second.
|
||||||
|
all_type_info_populate(type, ins.first->second);
|
||||||
|
}
|
||||||
|
return ins;
|
||||||
});
|
});
|
||||||
if (res.second) {
|
if (res.second) {
|
||||||
// New cache entry created; set up a weak reference to automatically remove it if the type
|
// New cache entry created; set up a weak reference to automatically remove it if the type
|
||||||
|
@ -128,4 +128,9 @@ PYBIND11_MODULE(pybind11_tests, m, py::mod_gil_not_used()) {
|
|||||||
for (const auto &initializer : initializers()) {
|
for (const auto &initializer : initializers()) {
|
||||||
initializer(m);
|
initializer(m);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
py::class_<TestContext>(m, "TestContext")
|
||||||
|
.def(py::init<>(&TestContext::createNewContextForInit))
|
||||||
|
.def("__enter__", &TestContext::contextEnter)
|
||||||
|
.def("__exit__", &TestContext::contextExit);
|
||||||
}
|
}
|
||||||
|
@ -96,3 +96,24 @@ void ignoreOldStyleInitWarnings(F &&body) {
|
|||||||
)",
|
)",
|
||||||
py::dict(py::arg("body") = py::cpp_function(body)));
|
py::dict(py::arg("body") = py::cpp_function(body)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// See PR #5419 for background.
|
||||||
|
class TestContext {
|
||||||
|
public:
|
||||||
|
TestContext() = delete;
|
||||||
|
TestContext(const TestContext &) = delete;
|
||||||
|
TestContext(TestContext &&) = delete;
|
||||||
|
static TestContext *createNewContextForInit() { return new TestContext("new-context"); }
|
||||||
|
|
||||||
|
pybind11::object contextEnter() {
|
||||||
|
py::object contextObj = py::cast(*this);
|
||||||
|
return contextObj;
|
||||||
|
}
|
||||||
|
void contextExit(const pybind11::object & /*excType*/,
|
||||||
|
const pybind11::object & /*excVal*/,
|
||||||
|
const pybind11::object & /*excTb*/) {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
explicit TestContext(const std::string &context) : context(context) {}
|
||||||
|
std::string context;
|
||||||
|
};
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -508,3 +509,31 @@ def test_pr4220_tripped_over_this():
|
|||||||
m.Empty0().get_msg()
|
m.Empty0().get_msg()
|
||||||
== "This is really only meant to exercise successful compilation."
|
== "This is really only meant to exercise successful compilation."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform.startswith("emscripten"), reason="Requires threads")
|
||||||
|
def test_all_type_info_multithreaded():
|
||||||
|
# See PR #5419 for background.
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from pybind11_tests import TestContext
|
||||||
|
|
||||||
|
class Context(TestContext):
|
||||||
|
pass
|
||||||
|
|
||||||
|
num_runs = 10
|
||||||
|
num_threads = 4
|
||||||
|
barrier = threading.Barrier(num_threads)
|
||||||
|
|
||||||
|
def func():
|
||||||
|
barrier.wait()
|
||||||
|
with Context():
|
||||||
|
pass
|
||||||
|
|
||||||
|
for _ in range(num_runs):
|
||||||
|
threads = [threading.Thread(target=func) for _ in range(num_threads)]
|
||||||
|
for thread in threads:
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
Loading…
Reference in New Issue
Block a user