aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/test/test_free_threading
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_free_threading')
-rw-r--r--Lib/test/test_free_threading/test_io.py42
-rw-r--r--Lib/test/test_free_threading/test_itertools.py32
-rw-r--r--Lib/test/test_free_threading/test_itertools_combinatoric.py51
3 files changed, 106 insertions, 19 deletions
diff --git a/Lib/test/test_free_threading/test_io.py b/Lib/test/test_free_threading/test_io.py
index f9bec740ddf..41d89e04da8 100644
--- a/Lib/test/test_free_threading/test_io.py
+++ b/Lib/test/test_free_threading/test_io.py
@@ -1,12 +1,13 @@
+import io
+import _pyio as pyio
import threading
from unittest import TestCase
from test.support import threading_helper
from random import randint
-from io import BytesIO
from sys import getsizeof
-class TestBytesIO(TestCase):
+class ThreadSafetyMixin:
# Test pretty much everything that can break under free-threading.
# Non-deterministic, but at least one of these things will fail if
# BytesIO object is not free-thread safe.
@@ -90,20 +91,27 @@ class TestBytesIO(TestCase):
barrier.wait()
getsizeof(b)
- self.check([write] * 10, BytesIO())
- self.check([writelines] * 10, BytesIO())
- self.check([write] * 10 + [truncate] * 10, BytesIO())
- self.check([truncate] + [read] * 10, BytesIO(b'0\n'*204800))
- self.check([truncate] + [read1] * 10, BytesIO(b'0\n'*204800))
- self.check([truncate] + [readline] * 10, BytesIO(b'0\n'*20480))
- self.check([truncate] + [readlines] * 10, BytesIO(b'0\n'*20480))
- self.check([truncate] + [readinto] * 10, BytesIO(b'0\n'*204800), bytearray(b'0\n'*204800))
- self.check([close] + [write] * 10, BytesIO())
- self.check([truncate] + [getvalue] * 10, BytesIO(b'0\n'*204800))
- self.check([truncate] + [getbuffer] * 10, BytesIO(b'0\n'*204800))
- self.check([truncate] + [iter] * 10, BytesIO(b'0\n'*20480))
- self.check([truncate] + [getstate] * 10, BytesIO(b'0\n'*204800))
- self.check([truncate] + [setstate] * 10, BytesIO(b'0\n'*204800), (b'123', 0, None))
- self.check([truncate] + [sizeof] * 10, BytesIO(b'0\n'*204800))
+ self.check([write] * 10, self.ioclass())
+ self.check([writelines] * 10, self.ioclass())
+ self.check([write] * 10 + [truncate] * 10, self.ioclass())
+ self.check([truncate] + [read] * 10, self.ioclass(b'0\n'*204800))
+ self.check([truncate] + [read1] * 10, self.ioclass(b'0\n'*204800))
+ self.check([truncate] + [readline] * 10, self.ioclass(b'0\n'*20480))
+ self.check([truncate] + [readlines] * 10, self.ioclass(b'0\n'*20480))
+ self.check([truncate] + [readinto] * 10, self.ioclass(b'0\n'*204800), bytearray(b'0\n'*204800))
+ self.check([close] + [write] * 10, self.ioclass())
+ self.check([truncate] + [getvalue] * 10, self.ioclass(b'0\n'*204800))
+ self.check([truncate] + [getbuffer] * 10, self.ioclass(b'0\n'*204800))
+ self.check([truncate] + [iter] * 10, self.ioclass(b'0\n'*20480))
+ self.check([truncate] + [getstate] * 10, self.ioclass(b'0\n'*204800))
+ state = self.ioclass(b'123').__getstate__()
+ self.check([truncate] + [setstate] * 10, self.ioclass(b'0\n'*204800), state)
+ self.check([truncate] + [sizeof] * 10, self.ioclass(b'0\n'*204800))
# no tests for seek or tell because they don't break anything
+
+class CBytesIOTest(ThreadSafetyMixin, TestCase):
+ ioclass = io.BytesIO
+
+class PyBytesIOTest(ThreadSafetyMixin, TestCase):
+ ioclass = pyio.BytesIO
diff --git a/Lib/test/test_free_threading/test_itertools.py b/Lib/test/test_free_threading/test_itertools.py
index b8663ade1d4..9d366041917 100644
--- a/Lib/test/test_free_threading/test_itertools.py
+++ b/Lib/test/test_free_threading/test_itertools.py
@@ -1,6 +1,6 @@
import unittest
from threading import Thread, Barrier
-from itertools import batched, cycle
+from itertools import batched, chain, cycle
from test.support import threading_helper
@@ -17,7 +17,7 @@ class ItertoolsThreading(unittest.TestCase):
barrier.wait()
while True:
try:
- _ = next(it)
+ next(it)
except StopIteration:
break
@@ -62,6 +62,34 @@ class ItertoolsThreading(unittest.TestCase):
barrier.reset()
+ @threading_helper.reap_threads
+ def test_chain(self):
+ number_of_threads = 6
+ number_of_iterations = 20
+
+ barrier = Barrier(number_of_threads)
+ def work(it):
+ barrier.wait()
+ while True:
+ try:
+ next(it)
+ except StopIteration:
+ break
+
+ data = [(1, )] * 200
+ for it in range(number_of_iterations):
+ chain_iterator = chain(*data)
+ worker_threads = []
+ for ii in range(number_of_threads):
+ worker_threads.append(
+ Thread(target=work, args=[chain_iterator]))
+
+ with threading_helper.start_threads(worker_threads):
+ pass
+
+ barrier.reset()
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_free_threading/test_itertools_combinatoric.py b/Lib/test/test_free_threading/test_itertools_combinatoric.py
new file mode 100644
index 00000000000..5b3b88deedd
--- /dev/null
+++ b/Lib/test/test_free_threading/test_itertools_combinatoric.py
@@ -0,0 +1,51 @@
+import unittest
+from threading import Thread, Barrier
+from itertools import combinations, product
+from test.support import threading_helper
+
+
+threading_helper.requires_working_threading(module=True)
+
+def test_concurrent_iteration(iterator, number_of_threads):
+ barrier = Barrier(number_of_threads)
+ def iterator_worker(it):
+ barrier.wait()
+ while True:
+ try:
+ _ = next(it)
+ except StopIteration:
+ return
+
+ worker_threads = []
+ for ii in range(number_of_threads):
+ worker_threads.append(
+ Thread(target=iterator_worker, args=[iterator]))
+
+ with threading_helper.start_threads(worker_threads):
+ pass
+
+ barrier.reset()
+
+class ItertoolsThreading(unittest.TestCase):
+
+ @threading_helper.reap_threads
+ def test_combinations(self):
+ number_of_threads = 10
+ number_of_iterations = 24
+
+ for it in range(number_of_iterations):
+ iterator = combinations((1, 2, 3, 4, 5), 2)
+ test_concurrent_iteration(iterator, number_of_threads)
+
+ @threading_helper.reap_threads
+ def test_product(self):
+ number_of_threads = 10
+ number_of_iterations = 24
+
+ for it in range(number_of_iterations):
+ iterator = product((1, 2, 3, 4, 5), (10, 20, 30))
+ test_concurrent_iteration(iterator, number_of_threads)
+
+
+if __name__ == "__main__":
+ unittest.main()