aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/test/test_free_threading
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_free_threading')
-rw-r--r--Lib/test/test_free_threading/test_heapq.py240
1 files changed, 240 insertions, 0 deletions
diff --git a/Lib/test/test_free_threading/test_heapq.py b/Lib/test/test_free_threading/test_heapq.py
new file mode 100644
index 00000000000..f75fb264c8a
--- /dev/null
+++ b/Lib/test/test_free_threading/test_heapq.py
@@ -0,0 +1,240 @@
+import unittest
+
+import heapq
+
+from enum import Enum
+from threading import Thread, Barrier
+from random import shuffle, randint
+
+from test.support import threading_helper
+from test import test_heapq
+
+
+NTHREADS = 10
+OBJECT_COUNT = 5_000
+
+
+class Heap(Enum):
+ MIN = 1
+ MAX = 2
+
+
+@threading_helper.requires_working_threading()
+class TestHeapq(unittest.TestCase):
+ def setUp(self):
+ self.test_heapq = test_heapq.TestHeapPython()
+
+ def test_racing_heapify(self):
+ heap = list(range(OBJECT_COUNT))
+ shuffle(heap)
+
+ self.run_concurrently(
+ worker_func=heapq.heapify, args=(heap,), nthreads=NTHREADS
+ )
+ self.test_heapq.check_invariant(heap)
+
+ def test_racing_heappush(self):
+ heap = []
+
+ def heappush_func(heap):
+ for item in reversed(range(OBJECT_COUNT)):
+ heapq.heappush(heap, item)
+
+ self.run_concurrently(
+ worker_func=heappush_func, args=(heap,), nthreads=NTHREADS
+ )
+ self.test_heapq.check_invariant(heap)
+
+ def test_racing_heappop(self):
+ heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
+
+ # Each thread pops (OBJECT_COUNT / NTHREADS) items
+ self.assertEqual(OBJECT_COUNT % NTHREADS, 0)
+ per_thread_pop_count = OBJECT_COUNT // NTHREADS
+
+ def heappop_func(heap, pop_count):
+ local_list = []
+ for _ in range(pop_count):
+ item = heapq.heappop(heap)
+ local_list.append(item)
+
+ # Each local list should be sorted
+ self.assertTrue(self.is_sorted_ascending(local_list))
+
+ self.run_concurrently(
+ worker_func=heappop_func,
+ args=(heap, per_thread_pop_count),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(heap), 0)
+
+ def test_racing_heappushpop(self):
+ heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
+ pushpop_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
+
+ def heappushpop_func(heap, pushpop_items):
+ for item in pushpop_items:
+ popped_item = heapq.heappushpop(heap, item)
+ self.assertTrue(popped_item <= item)
+
+ self.run_concurrently(
+ worker_func=heappushpop_func,
+ args=(heap, pushpop_items),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(heap), OBJECT_COUNT)
+ self.test_heapq.check_invariant(heap)
+
+ def test_racing_heapreplace(self):
+ heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
+ replace_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
+
+ def heapreplace_func(heap, replace_items):
+ for item in replace_items:
+ heapq.heapreplace(heap, item)
+
+ self.run_concurrently(
+ worker_func=heapreplace_func,
+ args=(heap, replace_items),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(heap), OBJECT_COUNT)
+ self.test_heapq.check_invariant(heap)
+
+ def test_racing_heapify_max(self):
+ max_heap = list(range(OBJECT_COUNT))
+ shuffle(max_heap)
+
+ self.run_concurrently(
+ worker_func=heapq.heapify_max, args=(max_heap,), nthreads=NTHREADS
+ )
+ self.test_heapq.check_max_invariant(max_heap)
+
+ def test_racing_heappush_max(self):
+ max_heap = []
+
+ def heappush_max_func(max_heap):
+ for item in range(OBJECT_COUNT):
+ heapq.heappush_max(max_heap, item)
+
+ self.run_concurrently(
+ worker_func=heappush_max_func, args=(max_heap,), nthreads=NTHREADS
+ )
+ self.test_heapq.check_max_invariant(max_heap)
+
+ def test_racing_heappop_max(self):
+ max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
+
+ # Each thread pops (OBJECT_COUNT / NTHREADS) items
+ self.assertEqual(OBJECT_COUNT % NTHREADS, 0)
+ per_thread_pop_count = OBJECT_COUNT // NTHREADS
+
+ def heappop_max_func(max_heap, pop_count):
+ local_list = []
+ for _ in range(pop_count):
+ item = heapq.heappop_max(max_heap)
+ local_list.append(item)
+
+ # Each local list should be sorted
+ self.assertTrue(self.is_sorted_descending(local_list))
+
+ self.run_concurrently(
+ worker_func=heappop_max_func,
+ args=(max_heap, per_thread_pop_count),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(max_heap), 0)
+
+ def test_racing_heappushpop_max(self):
+ max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
+ pushpop_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
+
+ def heappushpop_max_func(max_heap, pushpop_items):
+ for item in pushpop_items:
+ popped_item = heapq.heappushpop_max(max_heap, item)
+ self.assertTrue(popped_item >= item)
+
+ self.run_concurrently(
+ worker_func=heappushpop_max_func,
+ args=(max_heap, pushpop_items),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(max_heap), OBJECT_COUNT)
+ self.test_heapq.check_max_invariant(max_heap)
+
+ def test_racing_heapreplace_max(self):
+ max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
+ replace_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
+
+ def heapreplace_max_func(max_heap, replace_items):
+ for item in replace_items:
+ heapq.heapreplace_max(max_heap, item)
+
+ self.run_concurrently(
+ worker_func=heapreplace_max_func,
+ args=(max_heap, replace_items),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(max_heap), OBJECT_COUNT)
+ self.test_heapq.check_max_invariant(max_heap)
+
+ @staticmethod
+ def is_sorted_ascending(lst):
+ """
+ Check if the list is sorted in ascending order (non-decreasing).
+ """
+ return all(lst[i - 1] <= lst[i] for i in range(1, len(lst)))
+
+ @staticmethod
+ def is_sorted_descending(lst):
+ """
+ Check if the list is sorted in descending order (non-increasing).
+ """
+ return all(lst[i - 1] >= lst[i] for i in range(1, len(lst)))
+
+ @staticmethod
+ def create_heap(size, heap_kind):
+ """
+ Create a min/max heap where elements are in the range (0, size - 1) and
+ shuffled before heapify.
+ """
+ heap = list(range(OBJECT_COUNT))
+ shuffle(heap)
+ if heap_kind == Heap.MIN:
+ heapq.heapify(heap)
+ else:
+ heapq.heapify_max(heap)
+
+ return heap
+
+ @staticmethod
+ def create_random_list(a, b, size):
+ """
+ Create a list of random numbers between a and b (inclusive).
+ """
+ return [randint(-a, b) for _ in range(size)]
+
+ def run_concurrently(self, worker_func, args, nthreads):
+ """
+ Run the worker function concurrently in multiple threads.
+ """
+ barrier = Barrier(nthreads)
+
+ def wrapper_func(*args):
+ # Wait for all threads to reach this point before proceeding.
+ barrier.wait()
+ worker_func(*args)
+
+ with threading_helper.catch_threading_exception() as cm:
+ workers = (
+ Thread(target=wrapper_func, args=args) for _ in range(nthreads)
+ )
+ with threading_helper.start_threads(workers):
+ pass
+
+ # Worker threads should not raise any exceptions
+ self.assertIsNone(cm.exc_value)
+
+
+if __name__ == "__main__":
+ unittest.main()