diff options
Diffstat (limited to 'Lib/test/test_asyncio/test_taskgroups.py')
-rw-r--r-- | Lib/test/test_asyncio/test_taskgroups.py | 62 |
1 files changed, 58 insertions, 4 deletions
diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py index c47bf4ec9ed..870fa8dbbf2 100644 --- a/Lib/test/test_asyncio/test_taskgroups.py +++ b/Lib/test/test_asyncio/test_taskgroups.py @@ -1,6 +1,7 @@ # Adapted with permission from the EdgeDB project; # license: PSFL. +import weakref import sys import gc import asyncio @@ -38,7 +39,25 @@ def no_other_refs(): return [coro] -class TestTaskGroup(unittest.IsolatedAsyncioTestCase): +def set_gc_state(enabled): + was_enabled = gc.isenabled() + if enabled: + gc.enable() + else: + gc.disable() + return was_enabled + + +@contextlib.contextmanager +def disable_gc(): + was_enabled = set_gc_state(enabled=False) + try: + yield + finally: + set_gc_state(enabled=was_enabled) + + +class BaseTestTaskGroup: async def test_taskgroup_01(self): @@ -832,15 +851,15 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase): with self.assertRaisesRegex(RuntimeError, "has not been entered"): tg.create_task(coro) - def test_coro_closed_when_tg_closed(self): + async def test_coro_closed_when_tg_closed(self): async def run_coro_after_tg_closes(): async with taskgroups.TaskGroup() as tg: pass coro = asyncio.sleep(0) with self.assertRaisesRegex(RuntimeError, "is finished"): tg.create_task(coro) - loop = asyncio.get_event_loop() - loop.run_until_complete(run_coro_after_tg_closes()) + + await run_coro_after_tg_closes() async def test_cancelling_level_preserved(self): async def raise_after(t, e): @@ -965,6 +984,30 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase): self.assertIsInstance(exc, _Done) self.assertListEqual(gc.get_referrers(exc), no_other_refs()) + + async def test_exception_refcycles_parent_task_wr(self): + """Test that TaskGroup deletes self._parent_task and create_task() deletes task""" + tg = asyncio.TaskGroup() + exc = None + + class _Done(Exception): + pass + + async def coro_fn(): + async with tg: + raise _Done + + with disable_gc(): + try: + async with asyncio.TaskGroup() as tg2: + task_wr = weakref.ref(tg2.create_task(coro_fn())) + except* _Done as excs: + exc = excs.exceptions[0].exceptions[0] + + self.assertIsNone(task_wr()) + self.assertIsInstance(exc, _Done) + self.assertListEqual(gc.get_referrers(exc), no_other_refs()) + async def test_exception_refcycles_propagate_cancellation_error(self): """Test that TaskGroup deletes propagate_cancellation_error""" tg = asyncio.TaskGroup() @@ -998,5 +1041,16 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase): self.assertListEqual(gc.get_referrers(exc), no_other_refs()) +class TestTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase): + loop_factory = asyncio.EventLoop + +class TestEagerTaskTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase): + @staticmethod + def loop_factory(): + loop = asyncio.EventLoop() + loop.set_task_factory(asyncio.eager_task_factory) + return loop + + if __name__ == "__main__": unittest.main() |