diff options
Diffstat (limited to 'extmod/uasyncio/funcs.py')
-rw-r--r-- | extmod/uasyncio/funcs.py | 76 |
1 files changed, 61 insertions, 15 deletions
diff --git a/extmod/uasyncio/funcs.py b/extmod/uasyncio/funcs.py index 0ce48b015c..b19edeb6ef 100644 --- a/extmod/uasyncio/funcs.py +++ b/extmod/uasyncio/funcs.py @@ -53,22 +53,68 @@ def wait_for_ms(aw, timeout): return wait_for(aw, timeout, core.sleep_ms) +class _Remove: + @staticmethod + def remove(t): + pass + + async def gather(*aws, return_exceptions=False): + def done(t, er): + nonlocal state + if type(state) is not int: + # A sub-task already raised an exception, so do nothing. + return + elif not return_exceptions and not isinstance(er, StopIteration): + # A sub-task raised an exception, indicate that to the gather task. + state = er + else: + state -= 1 + if state: + # Still some sub-tasks running. + return + # Gather waiting is done, schedule the main gather task. + core._task_queue.push_head(gather_task) + ts = [core._promote_to_task(aw) for aw in aws] for i in range(len(ts)): - try: - # TODO handle cancel of gather itself - # if ts[i].coro: - # iter(ts[i]).waiting.push_head(cur_task) - # try: - # yield - # except CancelledError as er: - # # cancel all waiting tasks - # raise er - ts[i] = await ts[i] - except (core.CancelledError, Exception) as er: - if return_exceptions: - ts[i] = er - else: - raise er + if ts[i].state is not True: + # Task is not running, gather not currently supported for this case. + raise RuntimeError("can't gather") + # Register the callback to call when the task is done. + ts[i].state = done + + # Set the state for execution of the gather. + gather_task = core.cur_task + state = len(ts) + cancel_all = False + + # Wait for the a sub-task to need attention. + gather_task.data = _Remove + try: + yield + except core.CancelledError as er: + cancel_all = True + state = er + + # Clean up tasks. + for i in range(len(ts)): + if ts[i].state is done: + # Sub-task is still running, deregister the callback and cancel if needed. + ts[i].state = True + if cancel_all: + ts[i].cancel() + elif isinstance(ts[i].data, StopIteration): + # Sub-task ran to completion, get its return value. + ts[i] = ts[i].data.value + else: + # Sub-task had an exception with return_exceptions==True, so get its exception. + ts[i] = ts[i].data + + # Either this gather was cancelled, or one of the sub-tasks raised an exception with + # return_exceptions==False, so reraise the exception here. + if state is not 0: + raise state + + # Return the list of return values of each sub-task. return ts |