summaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--extmod/asyncio/core.py5
-rw-r--r--extmod/asyncio/funcs.py49
-rw-r--r--tests/extmod/asyncio_gather_finished_early.py65
3 files changed, 102 insertions, 17 deletions
diff --git a/extmod/asyncio/core.py b/extmod/asyncio/core.py
index 214cc52f45..e5af3038f7 100644
--- a/extmod/asyncio/core.py
+++ b/extmod/asyncio/core.py
@@ -219,6 +219,11 @@ def run_until_complete(main_task=None):
elif t.state is None:
# Task is already finished and nothing await'ed on the task,
# so call the exception handler.
+
+ # Save exception raised by the coro for later use.
+ t.data = exc
+
+ # Create exception context and call the exception handler.
_exc_context["exception"] = exc
_exc_context["future"] = t
Loop.call_exception_handler(_exc_context)
diff --git a/extmod/asyncio/funcs.py b/extmod/asyncio/funcs.py
index 599091dfbd..3ef8a76b1d 100644
--- a/extmod/asyncio/funcs.py
+++ b/extmod/asyncio/funcs.py
@@ -63,9 +63,6 @@ class _Remove:
# async
def gather(*aws, return_exceptions=False):
- if not aws:
- return []
-
def done(t, er):
# Sub-task "t" has finished, with exception "er".
nonlocal state
@@ -86,26 +83,39 @@ def gather(*aws, return_exceptions=False):
# Gather waiting is done, schedule the main gather task.
core._task_queue.push(gather_task)
+ # Prepare the sub-tasks for the gather.
+ # The `state` variable counts the number of tasks to wait for, and can be negative
+ # if the gather should not run at all (because a task already had an exception).
ts = [core._promote_to_task(aw) for aw in aws]
+ state = 0
for i in range(len(ts)):
- if ts[i].state is not True:
- # Task is not running, gather not currently supported for this case.
+ if ts[i].state is True:
+ # Task is running, register the callback to call when the task is done.
+ ts[i].state = done
+ state += 1
+ elif not ts[i].state:
+ # Task finished already.
+ if not isinstance(ts[i].data, StopIteration):
+ # Task finished by raising an exception.
+ if not return_exceptions:
+ # Do not run this gather at all.
+ state = -len(ts)
+ else:
+ # Task being waited on, gather not currently supported for this case.
raise RuntimeError("can't gather")
- # Register the callback to call when the task is done.
- ts[i].state = done
# Set the state for execution of the gather.
gather_task = core.cur_task
- state = len(ts)
cancel_all = False
- # Wait for the a sub-task to need attention.
- gather_task.data = _Remove
- try:
- yield
- except core.CancelledError as er:
- cancel_all = True
- state = er
+ # Wait for a sub-task to need attention (if there are any to wait for).
+ if state > 0:
+ gather_task.data = _Remove
+ try:
+ yield
+ except core.CancelledError as er:
+ cancel_all = True
+ state = er
# Clean up tasks.
for i in range(len(ts)):
@@ -118,8 +128,13 @@ def gather(*aws, return_exceptions=False):
# Sub-task ran to completion, get its return value.
ts[i] = ts[i].data.value
else:
- # Sub-task had an exception with return_exceptions==True, so get its exception.
- ts[i] = ts[i].data
+ # Sub-task had an exception.
+ if return_exceptions:
+ # Get the sub-task exception to return in the list of return values.
+ ts[i] = ts[i].data
+ elif isinstance(state, int):
+ # Raise the sub-task exception, if there is not already an exception to raise.
+ state = ts[i].data
# Either this gather was cancelled, or one of the sub-tasks raised an exception with
# return_exceptions==False, so reraise the exception here.
diff --git a/tests/extmod/asyncio_gather_finished_early.py b/tests/extmod/asyncio_gather_finished_early.py
new file mode 100644
index 0000000000..030e79e357
--- /dev/null
+++ b/tests/extmod/asyncio_gather_finished_early.py
@@ -0,0 +1,65 @@
+# Test asyncio.gather() when a task is already finished before the gather starts.
+
+try:
+ import asyncio
+except ImportError:
+ print("SKIP")
+ raise SystemExit
+
+
+# CPython and MicroPython differ in when they signal (and print) that a task raised an
+# uncaught exception. So define an empty custom_handler() to suppress this output.
+def custom_handler(loop, context):
+ pass
+
+
+async def task_that_finishes_early(id, event, fail):
+ print("task_that_finishes_early", id)
+ event.set()
+ if fail:
+ raise ValueError("intentional exception", id)
+
+
+async def task_that_runs():
+ for i in range(5):
+ print("task_that_runs", i)
+ await asyncio.sleep(0)
+
+
+async def main(start_task_that_runs, task_fail, return_exceptions):
+ print("== start", start_task_that_runs, task_fail, return_exceptions)
+
+ # Set exception handler to suppress exception output.
+ loop = asyncio.get_event_loop()
+ loop.set_exception_handler(custom_handler)
+
+ # Create tasks.
+ event_a = asyncio.Event()
+ event_b = asyncio.Event()
+ tasks = []
+ if start_task_that_runs:
+ tasks.append(asyncio.create_task(task_that_runs()))
+ tasks.append(asyncio.create_task(task_that_finishes_early("a", event_a, task_fail)))
+ tasks.append(asyncio.create_task(task_that_finishes_early("b", event_b, task_fail)))
+
+ # Make sure task_that_finishes_early() are both done, before calling gather().
+ await event_a.wait()
+ await event_b.wait()
+
+ # Gather the tasks.
+ try:
+ result = "complete", await asyncio.gather(*tasks, return_exceptions=return_exceptions)
+ except Exception as er:
+ result = "exception", er, start_task_that_runs and tasks[0].done()
+
+ # Wait for the final task to finish (if it was started).
+ if start_task_that_runs:
+ await tasks[0]
+
+ # Print results.
+ print(result)
+
+
+# Run the test in the 8 different combinations of its arguments.
+for i in range(8):
+ asyncio.run(main(bool(i & 4), bool(i & 2), bool(i & 1)))