aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/test/test_asyncio/test_tasks.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_asyncio/test_tasks.py')
-rw-r--r--Lib/test/test_asyncio/test_tasks.py109
1 files changed, 99 insertions, 10 deletions
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py
index 26e4f643d1a..96d2658cb4c 100644
--- a/Lib/test/test_asyncio/test_tasks.py
+++ b/Lib/test/test_asyncio/test_tasks.py
@@ -2,10 +2,11 @@
import collections
import contextlib
+import contextvars
import functools
import gc
import io
-import os
+import random
import re
import sys
import types
@@ -1377,9 +1378,9 @@ class BaseTaskTests:
self.cb_added = False
super().__init__(*args, **kwds)
- def add_done_callback(self, fn):
+ def add_done_callback(self, *args, **kwargs):
self.cb_added = True
- super().add_done_callback(fn)
+ super().add_done_callback(*args, **kwargs)
fut = Fut(loop=self.loop)
result = None
@@ -2091,7 +2092,7 @@ class BaseTaskTests:
@mock.patch('asyncio.base_events.logger')
def test_error_in_call_soon(self, m_log):
- def call_soon(callback, *args):
+ def call_soon(callback, *args, **kwargs):
raise ValueError
self.loop.call_soon = call_soon
@@ -2176,6 +2177,91 @@ class BaseTaskTests:
self.loop.run_until_complete(coro())
+ def test_context_1(self):
+ cvar = contextvars.ContextVar('cvar', default='nope')
+
+ async def sub():
+ await asyncio.sleep(0.01, loop=loop)
+ self.assertEqual(cvar.get(), 'nope')
+ cvar.set('something else')
+
+ async def main():
+ self.assertEqual(cvar.get(), 'nope')
+ subtask = self.new_task(loop, sub())
+ cvar.set('yes')
+ self.assertEqual(cvar.get(), 'yes')
+ await subtask
+ self.assertEqual(cvar.get(), 'yes')
+
+ loop = asyncio.new_event_loop()
+ try:
+ task = self.new_task(loop, main())
+ loop.run_until_complete(task)
+ finally:
+ loop.close()
+
+ def test_context_2(self):
+ cvar = contextvars.ContextVar('cvar', default='nope')
+
+ async def main():
+ def fut_on_done(fut):
+ # This change must not pollute the context
+ # of the "main()" task.
+ cvar.set('something else')
+
+ self.assertEqual(cvar.get(), 'nope')
+
+ for j in range(2):
+ fut = self.new_future(loop)
+ fut.add_done_callback(fut_on_done)
+ cvar.set(f'yes{j}')
+ loop.call_soon(fut.set_result, None)
+ await fut
+ self.assertEqual(cvar.get(), f'yes{j}')
+
+ for i in range(3):
+ # Test that task passed its context to add_done_callback:
+ cvar.set(f'yes{i}-{j}')
+ await asyncio.sleep(0.001, loop=loop)
+ self.assertEqual(cvar.get(), f'yes{i}-{j}')
+
+ loop = asyncio.new_event_loop()
+ try:
+ task = self.new_task(loop, main())
+ loop.run_until_complete(task)
+ finally:
+ loop.close()
+
+ self.assertEqual(cvar.get(), 'nope')
+
+ def test_context_3(self):
+ # Run 100 Tasks in parallel, each modifying cvar.
+
+ cvar = contextvars.ContextVar('cvar', default=-1)
+
+ async def sub(num):
+ for i in range(10):
+ cvar.set(num + i)
+ await asyncio.sleep(
+ random.uniform(0.001, 0.05), loop=loop)
+ self.assertEqual(cvar.get(), num + i)
+
+ async def main():
+ tasks = []
+ for i in range(100):
+ task = loop.create_task(sub(random.randint(0, 10)))
+ tasks.append(task)
+
+ await asyncio.gather(*tasks, loop=loop)
+
+ loop = asyncio.new_event_loop()
+ try:
+ loop.run_until_complete(main())
+ finally:
+ loop.close()
+
+ self.assertEqual(cvar.get(), -1)
+
def add_subclass_tests(cls):
BaseTask = cls.Task
@@ -2193,9 +2279,9 @@ def add_subclass_tests(cls):
self.calls['_schedule_callbacks'] += 1
return super()._schedule_callbacks()
- def add_done_callback(self, *args):
+ def add_done_callback(self, *args, **kwargs):
self.calls['add_done_callback'] += 1
- return super().add_done_callback(*args)
+ return super().add_done_callback(*args, **kwargs)
class Task(CommonFuture, BaseTask):
def _step(self, *args):
@@ -2486,10 +2572,13 @@ class PyIntrospectionTests(unittest.TestCase, BaseTaskIntrospectionTests):
@unittest.skipUnless(hasattr(tasks, '_c_register_task'),
'requires the C _asyncio module')
class CIntrospectionTests(unittest.TestCase, BaseTaskIntrospectionTests):
- _register_task = staticmethod(tasks._c_register_task)
- _unregister_task = staticmethod(tasks._c_unregister_task)
- _enter_task = staticmethod(tasks._c_enter_task)
- _leave_task = staticmethod(tasks._c_leave_task)
+ if hasattr(tasks, '_c_register_task'):
+ _register_task = staticmethod(tasks._c_register_task)
+ _unregister_task = staticmethod(tasks._c_unregister_task)
+ _enter_task = staticmethod(tasks._c_enter_task)
+ _leave_task = staticmethod(tasks._c_leave_task)
+ else:
+ _register_task = _unregister_task = _enter_task = _leave_task = None
class BaseCurrentLoopTests: