diff options
Diffstat (limited to 'Lib/test/test_free_threading')
-rw-r--r-- | Lib/test/test_free_threading/test_heapq.py | 240 |
1 files changed, 240 insertions, 0 deletions
diff --git a/Lib/test/test_free_threading/test_heapq.py b/Lib/test/test_free_threading/test_heapq.py new file mode 100644 index 00000000000..f75fb264c8a --- /dev/null +++ b/Lib/test/test_free_threading/test_heapq.py @@ -0,0 +1,240 @@ +import unittest + +import heapq + +from enum import Enum +from threading import Thread, Barrier +from random import shuffle, randint + +from test.support import threading_helper +from test import test_heapq + + +NTHREADS = 10 +OBJECT_COUNT = 5_000 + + +class Heap(Enum): + MIN = 1 + MAX = 2 + + +@threading_helper.requires_working_threading() +class TestHeapq(unittest.TestCase): + def setUp(self): + self.test_heapq = test_heapq.TestHeapPython() + + def test_racing_heapify(self): + heap = list(range(OBJECT_COUNT)) + shuffle(heap) + + self.run_concurrently( + worker_func=heapq.heapify, args=(heap,), nthreads=NTHREADS + ) + self.test_heapq.check_invariant(heap) + + def test_racing_heappush(self): + heap = [] + + def heappush_func(heap): + for item in reversed(range(OBJECT_COUNT)): + heapq.heappush(heap, item) + + self.run_concurrently( + worker_func=heappush_func, args=(heap,), nthreads=NTHREADS + ) + self.test_heapq.check_invariant(heap) + + def test_racing_heappop(self): + heap = self.create_heap(OBJECT_COUNT, Heap.MIN) + + # Each thread pops (OBJECT_COUNT / NTHREADS) items + self.assertEqual(OBJECT_COUNT % NTHREADS, 0) + per_thread_pop_count = OBJECT_COUNT // NTHREADS + + def heappop_func(heap, pop_count): + local_list = [] + for _ in range(pop_count): + item = heapq.heappop(heap) + local_list.append(item) + + # Each local list should be sorted + self.assertTrue(self.is_sorted_ascending(local_list)) + + self.run_concurrently( + worker_func=heappop_func, + args=(heap, per_thread_pop_count), + nthreads=NTHREADS, + ) + self.assertEqual(len(heap), 0) + + def test_racing_heappushpop(self): + heap = self.create_heap(OBJECT_COUNT, Heap.MIN) + pushpop_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT) + + def heappushpop_func(heap, pushpop_items): + for item in pushpop_items: + popped_item = heapq.heappushpop(heap, item) + self.assertTrue(popped_item <= item) + + self.run_concurrently( + worker_func=heappushpop_func, + args=(heap, pushpop_items), + nthreads=NTHREADS, + ) + self.assertEqual(len(heap), OBJECT_COUNT) + self.test_heapq.check_invariant(heap) + + def test_racing_heapreplace(self): + heap = self.create_heap(OBJECT_COUNT, Heap.MIN) + replace_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT) + + def heapreplace_func(heap, replace_items): + for item in replace_items: + heapq.heapreplace(heap, item) + + self.run_concurrently( + worker_func=heapreplace_func, + args=(heap, replace_items), + nthreads=NTHREADS, + ) + self.assertEqual(len(heap), OBJECT_COUNT) + self.test_heapq.check_invariant(heap) + + def test_racing_heapify_max(self): + max_heap = list(range(OBJECT_COUNT)) + shuffle(max_heap) + + self.run_concurrently( + worker_func=heapq.heapify_max, args=(max_heap,), nthreads=NTHREADS + ) + self.test_heapq.check_max_invariant(max_heap) + + def test_racing_heappush_max(self): + max_heap = [] + + def heappush_max_func(max_heap): + for item in range(OBJECT_COUNT): + heapq.heappush_max(max_heap, item) + + self.run_concurrently( + worker_func=heappush_max_func, args=(max_heap,), nthreads=NTHREADS + ) + self.test_heapq.check_max_invariant(max_heap) + + def test_racing_heappop_max(self): + max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX) + + # Each thread pops (OBJECT_COUNT / NTHREADS) items + self.assertEqual(OBJECT_COUNT % NTHREADS, 0) + per_thread_pop_count = OBJECT_COUNT // NTHREADS + + def heappop_max_func(max_heap, pop_count): + local_list = [] + for _ in range(pop_count): + item = heapq.heappop_max(max_heap) + local_list.append(item) + + # Each local list should be sorted + self.assertTrue(self.is_sorted_descending(local_list)) + + self.run_concurrently( + worker_func=heappop_max_func, + args=(max_heap, per_thread_pop_count), + nthreads=NTHREADS, + ) + self.assertEqual(len(max_heap), 0) + + def test_racing_heappushpop_max(self): + max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX) + pushpop_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT) + + def heappushpop_max_func(max_heap, pushpop_items): + for item in pushpop_items: + popped_item = heapq.heappushpop_max(max_heap, item) + self.assertTrue(popped_item >= item) + + self.run_concurrently( + worker_func=heappushpop_max_func, + args=(max_heap, pushpop_items), + nthreads=NTHREADS, + ) + self.assertEqual(len(max_heap), OBJECT_COUNT) + self.test_heapq.check_max_invariant(max_heap) + + def test_racing_heapreplace_max(self): + max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX) + replace_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT) + + def heapreplace_max_func(max_heap, replace_items): + for item in replace_items: + heapq.heapreplace_max(max_heap, item) + + self.run_concurrently( + worker_func=heapreplace_max_func, + args=(max_heap, replace_items), + nthreads=NTHREADS, + ) + self.assertEqual(len(max_heap), OBJECT_COUNT) + self.test_heapq.check_max_invariant(max_heap) + + @staticmethod + def is_sorted_ascending(lst): + """ + Check if the list is sorted in ascending order (non-decreasing). + """ + return all(lst[i - 1] <= lst[i] for i in range(1, len(lst))) + + @staticmethod + def is_sorted_descending(lst): + """ + Check if the list is sorted in descending order (non-increasing). + """ + return all(lst[i - 1] >= lst[i] for i in range(1, len(lst))) + + @staticmethod + def create_heap(size, heap_kind): + """ + Create a min/max heap where elements are in the range (0, size - 1) and + shuffled before heapify. + """ + heap = list(range(OBJECT_COUNT)) + shuffle(heap) + if heap_kind == Heap.MIN: + heapq.heapify(heap) + else: + heapq.heapify_max(heap) + + return heap + + @staticmethod + def create_random_list(a, b, size): + """ + Create a list of random numbers between a and b (inclusive). + """ + return [randint(-a, b) for _ in range(size)] + + def run_concurrently(self, worker_func, args, nthreads): + """ + Run the worker function concurrently in multiple threads. + """ + barrier = Barrier(nthreads) + + def wrapper_func(*args): + # Wait for all threads to reach this point before proceeding. + barrier.wait() + worker_func(*args) + + with threading_helper.catch_threading_exception() as cm: + workers = ( + Thread(target=wrapper_func, args=args) for _ in range(nthreads) + ) + with threading_helper.start_threads(workers): + pass + + # Worker threads should not raise any exceptions + self.assertIsNone(cm.exc_value) + + +if __name__ == "__main__": + unittest.main() |