diff options
-rw-r--r-- | Lib/test/test_asyncio/test_eager_task_factory.py | 77 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_free_threading.py | 31 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_graph.py | 9 |
3 files changed, 110 insertions, 7 deletions
diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py index 10450c11b68..bb0760a6967 100644 --- a/Lib/test/test_asyncio/test_eager_task_factory.py +++ b/Lib/test/test_asyncio/test_eager_task_factory.py @@ -267,12 +267,33 @@ class EagerTaskFactoryLoopTests: class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase): Task = tasks._PyTask + def setUp(self): + self._current_task = asyncio.current_task + asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._py_current_task + return super().setUp() + + def tearDown(self): + asyncio.current_task = asyncio.tasks.current_task = self._current_task + return super().tearDown() + + @unittest.skipUnless(hasattr(tasks, '_CTask'), 'requires the C _asyncio module') class CEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase): Task = getattr(tasks, '_CTask', None) + def setUp(self): + self._current_task = asyncio.current_task + asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task + return super().setUp() + + def tearDown(self): + asyncio.current_task = asyncio.tasks.current_task = self._current_task + return super().tearDown() + + + @unittest.skip("skip") def test_issue105987(self): code = """if 1: from _asyncio import _swap_current_task @@ -400,31 +421,83 @@ class BaseEagerTaskFactoryTests(BaseTaskCountingTests): class NonEagerTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase): - Task = asyncio.Task + Task = asyncio.tasks._CTask + + def setUp(self): + self._current_task = asyncio.current_task + asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task + return super().setUp() + def tearDown(self): + asyncio.current_task = asyncio.tasks.current_task = self._current_task + return super().tearDown() class EagerTests(BaseEagerTaskFactoryTests, test_utils.TestCase): - Task = asyncio.Task + Task = asyncio.tasks._CTask + + def setUp(self): + self._current_task = asyncio.current_task + asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task + return super().setUp() + + def tearDown(self): + asyncio.current_task = asyncio.tasks.current_task = self._current_task + return super().tearDown() class NonEagerPyTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase): Task = tasks._PyTask + def setUp(self): + self._current_task = asyncio.current_task + asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._py_current_task + return super().setUp() + + def tearDown(self): + asyncio.current_task = asyncio.tasks.current_task = self._current_task + return super().tearDown() + class EagerPyTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase): Task = tasks._PyTask + def setUp(self): + self._current_task = asyncio.current_task + asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._py_current_task + return super().setUp() + + def tearDown(self): + asyncio.current_task = asyncio.tasks.current_task = self._current_task + return super().tearDown() @unittest.skipUnless(hasattr(tasks, '_CTask'), 'requires the C _asyncio module') class NonEagerCTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase): Task = getattr(tasks, '_CTask', None) + def setUp(self): + self._current_task = asyncio.current_task + asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task + return super().setUp() + + def tearDown(self): + asyncio.current_task = asyncio.tasks.current_task = self._current_task + return super().tearDown() + @unittest.skipUnless(hasattr(tasks, '_CTask'), 'requires the C _asyncio module') class EagerCTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase): Task = getattr(tasks, '_CTask', None) + def setUp(self): + self._current_task = asyncio.current_task + asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task + return super().setUp() + + def tearDown(self): + asyncio.current_task = asyncio.tasks.current_task = self._current_task + return super().tearDown() + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_asyncio/test_free_threading.py b/Lib/test/test_asyncio/test_free_threading.py index d0221d87062..199dbbdda5e 100644 --- a/Lib/test/test_asyncio/test_free_threading.py +++ b/Lib/test/test_asyncio/test_free_threading.py @@ -40,7 +40,7 @@ class TestFreeThreading: self.assertEqual(task.get_loop(), loop) self.assertFalse(task.done()) - current = self.current_task() + current = asyncio.current_task() self.assertEqual(current.get_loop(), loop) self.assertSetEqual(all_tasks, tasks | {current}) future.set_result(None) @@ -101,8 +101,12 @@ class TestFreeThreading: async def func(): nonlocal task task = asyncio.current_task() - - thread = Thread(target=lambda: asyncio.run(func())) + def runner(): + with asyncio.Runner() as runner: + loop = runner.get_loop() + loop.set_task_factory(self.factory) + runner.run(func()) + thread = Thread(target=runner) thread.start() thread.join() wr = weakref.ref(task) @@ -164,7 +168,15 @@ class TestFreeThreading: class TestPyFreeThreading(TestFreeThreading, TestCase): all_tasks = staticmethod(asyncio.tasks._py_all_tasks) - current_task = staticmethod(asyncio.tasks._py_current_task) + + def setUp(self): + self._old_current_task = asyncio.current_task + asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._py_current_task + return super().setUp() + + def tearDown(self): + asyncio.current_task = asyncio.tasks.current_task = self._old_current_task + return super().tearDown() def factory(self, loop, coro, **kwargs): return asyncio.tasks._PyTask(coro, loop=loop, **kwargs) @@ -173,7 +185,16 @@ class TestPyFreeThreading(TestFreeThreading, TestCase): @unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio") class TestCFreeThreading(TestFreeThreading, TestCase): all_tasks = staticmethod(getattr(asyncio.tasks, "_c_all_tasks", None)) - current_task = staticmethod(getattr(asyncio.tasks, "_c_current_task", None)) + + def setUp(self): + self._old_current_task = asyncio.current_task + asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task + return super().setUp() + + def tearDown(self): + asyncio.current_task = asyncio.tasks.current_task = self._old_current_task + return super().tearDown() + def factory(self, loop, coro, **kwargs): return asyncio.tasks._CTask(coro, loop=loop, **kwargs) diff --git a/Lib/test/test_asyncio/test_graph.py b/Lib/test/test_asyncio/test_graph.py index fd2160d4ca3..62f6593c31d 100644 --- a/Lib/test/test_asyncio/test_graph.py +++ b/Lib/test/test_asyncio/test_graph.py @@ -369,6 +369,8 @@ class TestCallStackC(CallStackTestBase, unittest.IsolatedAsyncioTestCase): futures.future_discard_from_awaited_by = futures._c_future_discard_from_awaited_by asyncio.future_discard_from_awaited_by = futures.future_discard_from_awaited_by + self._current_task = asyncio.current_task + asyncio.current_task = asyncio.tasks.current_task = tasks._c_current_task def tearDown(self): futures = asyncio.futures @@ -390,6 +392,8 @@ class TestCallStackC(CallStackTestBase, unittest.IsolatedAsyncioTestCase): futures.Future = self._Future del self._Future + asyncio.current_task = asyncio.tasks.current_task = self._current_task + @unittest.skipIf( not hasattr(asyncio.futures, "_py_future_add_to_awaited_by"), @@ -414,6 +418,9 @@ class TestCallStackPy(CallStackTestBase, unittest.IsolatedAsyncioTestCase): futures.future_discard_from_awaited_by = futures._py_future_discard_from_awaited_by asyncio.future_discard_from_awaited_by = futures.future_discard_from_awaited_by + self._current_task = asyncio.current_task + asyncio.current_task = asyncio.tasks.current_task = tasks._py_current_task + def tearDown(self): futures = asyncio.futures @@ -434,3 +441,5 @@ class TestCallStackPy(CallStackTestBase, unittest.IsolatedAsyncioTestCase): asyncio.Future = self._Future futures.Future = self._Future del self._Future + + asyncio.current_task = asyncio.tasks.current_task = self._current_task |