diff options
Diffstat (limited to 'Lib/test/test_asyncio/test_tasks.py')
-rw-r--r-- | Lib/test/test_asyncio/test_tasks.py | 109 |
1 files changed, 99 insertions, 10 deletions
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py index 26e4f643d1a..96d2658cb4c 100644 --- a/Lib/test/test_asyncio/test_tasks.py +++ b/Lib/test/test_asyncio/test_tasks.py @@ -2,10 +2,11 @@ import collections import contextlib +import contextvars import functools import gc import io -import os +import random import re import sys import types @@ -1377,9 +1378,9 @@ class BaseTaskTests: self.cb_added = False super().__init__(*args, **kwds) - def add_done_callback(self, fn): + def add_done_callback(self, *args, **kwargs): self.cb_added = True - super().add_done_callback(fn) + super().add_done_callback(*args, **kwargs) fut = Fut(loop=self.loop) result = None @@ -2091,7 +2092,7 @@ class BaseTaskTests: @mock.patch('asyncio.base_events.logger') def test_error_in_call_soon(self, m_log): - def call_soon(callback, *args): + def call_soon(callback, *args, **kwargs): raise ValueError self.loop.call_soon = call_soon @@ -2176,6 +2177,91 @@ class BaseTaskTests: self.loop.run_until_complete(coro()) + def test_context_1(self): + cvar = contextvars.ContextVar('cvar', default='nope') + + async def sub(): + await asyncio.sleep(0.01, loop=loop) + self.assertEqual(cvar.get(), 'nope') + cvar.set('something else') + + async def main(): + self.assertEqual(cvar.get(), 'nope') + subtask = self.new_task(loop, sub()) + cvar.set('yes') + self.assertEqual(cvar.get(), 'yes') + await subtask + self.assertEqual(cvar.get(), 'yes') + + loop = asyncio.new_event_loop() + try: + task = self.new_task(loop, main()) + loop.run_until_complete(task) + finally: + loop.close() + + def test_context_2(self): + cvar = contextvars.ContextVar('cvar', default='nope') + + async def main(): + def fut_on_done(fut): + # This change must not pollute the context + # of the "main()" task. + cvar.set('something else') + + self.assertEqual(cvar.get(), 'nope') + + for j in range(2): + fut = self.new_future(loop) + fut.add_done_callback(fut_on_done) + cvar.set(f'yes{j}') + loop.call_soon(fut.set_result, None) + await fut + self.assertEqual(cvar.get(), f'yes{j}') + + for i in range(3): + # Test that task passed its context to add_done_callback: + cvar.set(f'yes{i}-{j}') + await asyncio.sleep(0.001, loop=loop) + self.assertEqual(cvar.get(), f'yes{i}-{j}') + + loop = asyncio.new_event_loop() + try: + task = self.new_task(loop, main()) + loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual(cvar.get(), 'nope') + + def test_context_3(self): + # Run 100 Tasks in parallel, each modifying cvar. + + cvar = contextvars.ContextVar('cvar', default=-1) + + async def sub(num): + for i in range(10): + cvar.set(num + i) + await asyncio.sleep( + random.uniform(0.001, 0.05), loop=loop) + self.assertEqual(cvar.get(), num + i) + + async def main(): + tasks = [] + for i in range(100): + task = loop.create_task(sub(random.randint(0, 10))) + tasks.append(task) + + await asyncio.gather(*tasks, loop=loop) + + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(main()) + finally: + loop.close() + + self.assertEqual(cvar.get(), -1) + def add_subclass_tests(cls): BaseTask = cls.Task @@ -2193,9 +2279,9 @@ def add_subclass_tests(cls): self.calls['_schedule_callbacks'] += 1 return super()._schedule_callbacks() - def add_done_callback(self, *args): + def add_done_callback(self, *args, **kwargs): self.calls['add_done_callback'] += 1 - return super().add_done_callback(*args) + return super().add_done_callback(*args, **kwargs) class Task(CommonFuture, BaseTask): def _step(self, *args): @@ -2486,10 +2572,13 @@ class PyIntrospectionTests(unittest.TestCase, BaseTaskIntrospectionTests): @unittest.skipUnless(hasattr(tasks, '_c_register_task'), 'requires the C _asyncio module') class CIntrospectionTests(unittest.TestCase, BaseTaskIntrospectionTests): - _register_task = staticmethod(tasks._c_register_task) - _unregister_task = staticmethod(tasks._c_unregister_task) - _enter_task = staticmethod(tasks._c_enter_task) - _leave_task = staticmethod(tasks._c_leave_task) + if hasattr(tasks, '_c_register_task'): + _register_task = staticmethod(tasks._c_register_task) + _unregister_task = staticmethod(tasks._c_unregister_task) + _enter_task = staticmethod(tasks._c_enter_task) + _leave_task = staticmethod(tasks._c_leave_task) + else: + _register_task = _unregister_task = _enter_task = _leave_task = None class BaseCurrentLoopTests: |