pybind11/tests/test_thread.py
Sam Gross 15d9dae14b
Fix data race when using shared variables (free threading) (#5494)
* Fix data race when using shared variables (free threading)

In the free threading build, there's a race between wrapper re-use and
wrapper deallocation. This can happen with a static variable accessed by
multiple threads.

Fixing this requires using some private CPython APIs: _Py_TryIncref and
_PyObject_SetMaybeWeakref. The implementations of these functions are
included until they're made available as public ("unstable") APIs.

Fixes #5489

* style: pre-commit fixes

* Avoid unused parameter

* Add missing return for default build

* Changes from review

* Assign result to local variable

* s/clang-tidy/ruff

* clang-tidy: static is redundant

* Use 'noqa: B018'

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-16 11:13:21 -08:00

69 lines
1.5 KiB
Python

from __future__ import annotations
import sys
import threading
import pytest
from pybind11_tests import thread as m
class Thread(threading.Thread):
    """Worker thread that calls ``fn(i, i)`` for each ``i`` in ``range(10)``.

    An ``Exception`` raised by ``fn`` is captured in the worker and re-raised
    in whichever thread calls :meth:`join`, so a failure inside the worker is
    visible to the caller instead of being lost.
    """

    def __init__(self, fn):
        """Store the two-argument callable ``fn``; the thread is not started."""
        super().__init__()
        self.fn = fn
        # Exception captured by run(), or None if nothing has been raised.
        self.e = None

    def run(self):
        """Call ``self.fn(i, i)`` for ``i`` in 0..9, stopping at the first error."""
        try:
            for i in range(10):
                self.fn(i, i)
        except Exception as e:
            # Deliberately broad: the exception is not discarded, it is handed
            # over to join() so it can be re-raised in the joining thread.
            self.e = e

    def join(self, timeout=None):
        """Wait for the thread, then re-raise any exception captured by run().

        ``timeout`` is forwarded to ``threading.Thread.join`` so this override
        stays call-compatible with the base class. If no exception has been
        captured by the time the wait ends, the call returns normally.
        """
        super().join(timeout)
        # Compare against None instead of relying on truthiness: an exception
        # type may define __bool__/__len__ and evaluate as falsy.
        if self.e is not None:
            raise self.e
@pytest.mark.skipif(sys.platform.startswith("emscripten"), reason="Requires threads")
def test_implicit_conversion():
    """Run ``m.test`` from three ``Thread`` workers at the same time.

    All workers are started before any is joined; they are then joined in the
    reverse of their start order.
    """
    workers = [Thread(m.test) for _ in range(3)]
    for worker in workers:
        worker.start()
    # Join last-started first, mirroring the start order in reverse.
    for worker in reversed(workers):
        worker.join()
@pytest.mark.skipif(sys.platform.startswith("emscripten"), reason="Requires threads")
def test_implicit_conversion_no_gil():
    """Run ``m.test_no_gil`` from three ``Thread`` workers at the same time.

    All workers are started before any is joined; they are then joined in the
    reverse of their start order.
    """
    workers = [Thread(m.test_no_gil) for _ in range(3)]
    for worker in workers:
        worker.start()
    # Join last-started first, mirroring the start order in reverse.
    for worker in reversed(workers):
        worker.join()
@pytest.mark.skipif(sys.platform.startswith("emscripten"), reason="Requires threads")
def test_bind_shared_instance():
    """Read ``m.EmptyStruct.SharedInstance`` repeatedly from four threads.

    Each thread waits on a shared barrier and then performs 1000 attribute
    reads, discarding the result every time.
    NOTE(review): presumably a stress test for a wrapper re-use/deallocation
    race in free-threaded builds -- confirm against the C++ side of the test.
    """
    thread_count = 4
    start_gate = threading.Barrier(thread_count)

    def read_shared_instance():
        # Hold every worker at the barrier so the reads begin together.
        start_gate.wait()
        for _ in range(1000):
            # Bare attribute access on purpose: only the lookup matters.
            m.EmptyStruct.SharedInstance  # noqa: B018

    workers = []
    for _ in range(thread_count):
        workers.append(threading.Thread(target=read_shared_instance))

    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()