aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/asyncio/tasks.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/asyncio/tasks.py')
-rw-r--r--Lib/asyncio/tasks.py30
1 files changed, 26 insertions, 4 deletions
diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py
index 2112dd4b99d..a25854cc4bd 100644
--- a/Lib/asyncio/tasks.py
+++ b/Lib/asyncio/tasks.py
@@ -322,6 +322,7 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
self._loop.call_soon(
self.__step, new_exc, context=self._context)
else:
+ futures.future_add_to_awaited_by(result, self)
result._asyncio_future_blocking = False
result.add_done_callback(
self.__wakeup, context=self._context)
@@ -356,6 +357,7 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
self = None # Needed to break cycles when an exception occurs.
def __wakeup(self, future):
+ futures.future_discard_from_awaited_by(future, self)
try:
future.result()
except BaseException as exc:
@@ -502,6 +504,7 @@ async def _wait(fs, timeout, return_when, loop):
if timeout is not None:
timeout_handle = loop.call_later(timeout, _release_waiter, waiter)
counter = len(fs)
+ cur_task = current_task()
def _on_completion(f):
nonlocal counter
@@ -514,9 +517,11 @@ async def _wait(fs, timeout, return_when, loop):
timeout_handle.cancel()
if not waiter.done():
waiter.set_result(None)
+ futures.future_discard_from_awaited_by(f, cur_task)
for f in fs:
f.add_done_callback(_on_completion)
+ futures.future_add_to_awaited_by(f, cur_task)
try:
await waiter
@@ -802,10 +807,19 @@ def gather(*coros_or_futures, return_exceptions=False):
outer.set_result([])
return outer
- def _done_callback(fut):
+ loop = events._get_running_loop()
+ if loop is not None:
+ cur_task = current_task(loop)
+ else:
+ cur_task = None
+
+ def _done_callback(fut, cur_task=cur_task):
nonlocal nfinished
nfinished += 1
+ if cur_task is not None:
+ futures.future_discard_from_awaited_by(fut, cur_task)
+
if outer is None or outer.done():
if not fut.cancelled():
# Mark exception retrieved.
@@ -862,7 +876,6 @@ def gather(*coros_or_futures, return_exceptions=False):
nfuts = 0
nfinished = 0
done_futs = []
- loop = None
outer = None # bpo-46672
for arg in coros_or_futures:
if arg not in arg_to_fut:
@@ -875,12 +888,13 @@ def gather(*coros_or_futures, return_exceptions=False):
# can't control it, disable the "destroy pending task"
# warning.
fut._log_destroy_pending = False
-
nfuts += 1
arg_to_fut[arg] = fut
if fut.done():
done_futs.append(fut)
else:
+ if cur_task is not None:
+ futures.future_add_to_awaited_by(fut, cur_task)
fut.add_done_callback(_done_callback)
else:
@@ -940,7 +954,15 @@ def shield(arg):
loop = futures._get_loop(inner)
outer = loop.create_future()
- def _inner_done_callback(inner):
+ if loop is not None and (cur_task := current_task(loop)) is not None:
+ futures.future_add_to_awaited_by(inner, cur_task)
+ else:
+ cur_task = None
+
+ def _inner_done_callback(inner, cur_task=cur_task):
+ if cur_task is not None:
+ futures.future_discard_from_awaited_by(inner, cur_task)
+
if outer.cancelled():
if not inner.cancelled():
# Mark inner's result as retrieved.