diff options
Diffstat (limited to 'extmod/asyncio/funcs.py')
-rw-r--r-- | extmod/asyncio/funcs.py | 49 |
1 files changed, 32 insertions, 17 deletions
diff --git a/extmod/asyncio/funcs.py b/extmod/asyncio/funcs.py index 599091dfbd..3ef8a76b1d 100644 --- a/extmod/asyncio/funcs.py +++ b/extmod/asyncio/funcs.py @@ -63,9 +63,6 @@ class _Remove: # async def gather(*aws, return_exceptions=False): - if not aws: - return [] - def done(t, er): # Sub-task "t" has finished, with exception "er". nonlocal state @@ -86,26 +83,39 @@ def gather(*aws, return_exceptions=False): # Gather waiting is done, schedule the main gather task. core._task_queue.push(gather_task) + # Prepare the sub-tasks for the gather. + # The `state` variable counts the number of tasks to wait for, and can be negative + # if the gather should not run at all (because a task already had an exception). ts = [core._promote_to_task(aw) for aw in aws] + state = 0 for i in range(len(ts)): - if ts[i].state is not True: - # Task is not running, gather not currently supported for this case. + if ts[i].state is True: + # Task is running, register the callback to call when the task is done. + ts[i].state = done + state += 1 + elif not ts[i].state: + # Task finished already. + if not isinstance(ts[i].data, StopIteration): + # Task finished by raising an exception. + if not return_exceptions: + # Do not run this gather at all. + state = -len(ts) + else: + # Task being waited on, gather not currently supported for this case. raise RuntimeError("can't gather") - # Register the callback to call when the task is done. - ts[i].state = done # Set the state for execution of the gather. gather_task = core.cur_task - state = len(ts) cancel_all = False - # Wait for the a sub-task to need attention. - gather_task.data = _Remove - try: - yield - except core.CancelledError as er: - cancel_all = True - state = er + # Wait for a sub-task to need attention (if there are any to wait for). + if state > 0: + gather_task.data = _Remove + try: + yield + except core.CancelledError as er: + cancel_all = True + state = er # Clean up tasks. for i in range(len(ts)): @@ -118,8 +128,13 @@ def gather(*aws, return_exceptions=False): # Sub-task ran to completion, get its return value. ts[i] = ts[i].data.value else: - # Sub-task had an exception with return_exceptions==True, so get its exception. - ts[i] = ts[i].data + # Sub-task had an exception. + if return_exceptions: + # Get the sub-task exception to return in the list of return values. + ts[i] = ts[i].data + elif isinstance(state, int): + # Raise the sub-task exception, if there is not already an exception to raise. + state = ts[i].data # Either this gather was cancelled, or one of the sub-tasks raised an exception with # return_exceptions==False, so reraise the exception here. |