summaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--extmod/uasyncio/funcs.py76
-rw-r--r--tests/extmod/uasyncio_gather.py63
-rw-r--r--tests/extmod/uasyncio_gather.py.exp21
-rw-r--r--tests/extmod/uasyncio_gather_notimpl.py53
-rw-r--r--tests/extmod/uasyncio_gather_notimpl.py.exp14
5 files changed, 202 insertions, 25 deletions
diff --git a/extmod/uasyncio/funcs.py b/extmod/uasyncio/funcs.py
index 0ce48b015c..b19edeb6ef 100644
--- a/extmod/uasyncio/funcs.py
+++ b/extmod/uasyncio/funcs.py
@@ -53,22 +53,68 @@ def wait_for_ms(aw, timeout):
return wait_for(aw, timeout, core.sleep_ms)
+class _Remove:
+ @staticmethod
+ def remove(t):
+ pass
+
+
async def gather(*aws, return_exceptions=False):
+ def done(t, er):
+ nonlocal state
+ if type(state) is not int:
+ # A sub-task already raised an exception, so do nothing.
+ return
+ elif not return_exceptions and not isinstance(er, StopIteration):
+ # A sub-task raised an exception, indicate that to the gather task.
+ state = er
+ else:
+ state -= 1
+ if state:
+ # Still some sub-tasks running.
+ return
+ # Gather waiting is done, schedule the main gather task.
+ core._task_queue.push_head(gather_task)
+
ts = [core._promote_to_task(aw) for aw in aws]
for i in range(len(ts)):
- try:
- # TODO handle cancel of gather itself
- # if ts[i].coro:
- # iter(ts[i]).waiting.push_head(cur_task)
- # try:
- # yield
- # except CancelledError as er:
- # # cancel all waiting tasks
- # raise er
- ts[i] = await ts[i]
- except (core.CancelledError, Exception) as er:
- if return_exceptions:
- ts[i] = er
- else:
- raise er
+ if ts[i].state is not True:
+ # Task is not running, gather not currently supported for this case.
+ raise RuntimeError("can't gather")
+ # Register the callback to call when the task is done.
+ ts[i].state = done
+
+ # Set the state for execution of the gather.
+ gather_task = core.cur_task
+ state = len(ts)
+ cancel_all = False
+
+ # Wait for the a sub-task to need attention.
+ gather_task.data = _Remove
+ try:
+ yield
+ except core.CancelledError as er:
+ cancel_all = True
+ state = er
+
+ # Clean up tasks.
+ for i in range(len(ts)):
+ if ts[i].state is done:
+ # Sub-task is still running, deregister the callback and cancel if needed.
+ ts[i].state = True
+ if cancel_all:
+ ts[i].cancel()
+ elif isinstance(ts[i].data, StopIteration):
+ # Sub-task ran to completion, get its return value.
+ ts[i] = ts[i].data.value
+ else:
+ # Sub-task had an exception with return_exceptions==True, so get its exception.
+ ts[i] = ts[i].data
+
+ # Either this gather was cancelled, or one of the sub-tasks raised an exception with
+ # return_exceptions==False, so reraise the exception here.
+ if state is not 0:
+ raise state
+
+ # Return the list of return values of each sub-task.
return ts
diff --git a/tests/extmod/uasyncio_gather.py b/tests/extmod/uasyncio_gather.py
index 6053873dbc..718e702be6 100644
--- a/tests/extmod/uasyncio_gather.py
+++ b/tests/extmod/uasyncio_gather.py
@@ -27,9 +27,22 @@ async def task(id):
return id
-async def gather_task():
+async def task_loop(id):
+ print("task_loop start", id)
+ while True:
+ await asyncio.sleep(0.02)
+ print("task_loop loop", id)
+
+
+async def task_raise(id):
+ print("task_raise start", id)
+ await asyncio.sleep(0.02)
+ raise ValueError(id)
+
+
+async def gather_task(t0, t1):
print("gather_task")
- await asyncio.gather(task(1), task(2))
+ await asyncio.gather(t0, t1)
print("gather_task2")
@@ -37,19 +50,49 @@ async def main():
# Simple gather with return values
print(await asyncio.gather(factorial("A", 2), factorial("B", 3), factorial("C", 4)))
+ print("====")
+
# Test return_exceptions, where one task is cancelled and the other finishes normally
tasks = [asyncio.create_task(task(1)), asyncio.create_task(task(2))]
tasks[0].cancel()
print(await asyncio.gather(*tasks, return_exceptions=True))
- # Cancel a multi gather
- # TODO doesn't work, Task should not forward cancellation from gather to sub-task
- # but rather CancelledError should cancel the gather directly, which will then cancel
- # all sub-tasks explicitly
- # t = asyncio.create_task(gather_task())
- # await asyncio.sleep(0.01)
- # t.cancel()
- # await asyncio.sleep(0.01)
+ print("====")
+
+ # Test return_exceptions, where one task raises an exception and the other finishes normally.
+ tasks = [asyncio.create_task(task(1)), asyncio.create_task(task_raise(2))]
+ print(await asyncio.gather(*tasks, return_exceptions=True))
+
+ print("====")
+
+ # Test case where one task raises an exception and other task keeps running.
+ tasks = [asyncio.create_task(task_loop(1)), asyncio.create_task(task_raise(2))]
+ try:
+ await asyncio.gather(*tasks)
+ except ValueError as er:
+ print(repr(er))
+ print(tasks[0].done(), tasks[1].done())
+ for t in tasks:
+ t.cancel()
+ await asyncio.sleep(0.04)
+
+ print("====")
+
+ # Test case where both tasks raise an exception.
+ tasks = [asyncio.create_task(task_raise(1)), asyncio.create_task(task_raise(2))]
+ try:
+ await asyncio.gather(*tasks)
+ except ValueError as er:
+ print(repr(er))
+ print(tasks[0].done(), tasks[1].done())
+
+ print("====")
+
+ # Cancel a multi gather.
+ t = asyncio.create_task(gather_task(task(1), task(2)))
+ await asyncio.sleep(0.01)
+ t.cancel()
+ await asyncio.sleep(0.04)
asyncio.run(main())
diff --git a/tests/extmod/uasyncio_gather.py.exp b/tests/extmod/uasyncio_gather.py.exp
index 95310bbe1c..9b30c36b67 100644
--- a/tests/extmod/uasyncio_gather.py.exp
+++ b/tests/extmod/uasyncio_gather.py.exp
@@ -8,6 +8,27 @@ Task B: factorial(3) = 6
Task C: Compute factorial(4)...
Task C: factorial(4) = 24
[2, 6, 24]
+====
start 2
end 2
[CancelledError(), 2]
+====
+start 1
+task_raise start 2
+end 1
+[1, ValueError(2,)]
+====
+task_loop start 1
+task_raise start 2
+task_loop loop 1
+ValueError(2,)
+False True
+====
+task_raise start 1
+task_raise start 2
+ValueError(1,)
+True True
+====
+gather_task
+start 1
+start 2
diff --git a/tests/extmod/uasyncio_gather_notimpl.py b/tests/extmod/uasyncio_gather_notimpl.py
new file mode 100644
index 0000000000..3ebab9bad6
--- /dev/null
+++ b/tests/extmod/uasyncio_gather_notimpl.py
@@ -0,0 +1,53 @@
+# Test uasyncio.gather() function, features that are not implemented.
+
+try:
+ import uasyncio as asyncio
+except ImportError:
+ try:
+ import asyncio
+ except ImportError:
+ print("SKIP")
+ raise SystemExit
+
+
+def custom_handler(loop, context):
+ print(repr(context["exception"]))
+
+
+async def task(id):
+ print("task start", id)
+ await asyncio.sleep(0.01)
+ print("task end", id)
+ return id
+
+
+async def gather_task(t0, t1):
+ print("gather_task start")
+ await asyncio.gather(t0, t1)
+ print("gather_task end")
+
+
+async def main():
+ loop = asyncio.get_event_loop()
+ loop.set_exception_handler(custom_handler)
+
+ # Test case where can't wait on a task being gathered.
+ tasks = [asyncio.create_task(task(1)), asyncio.create_task(task(2))]
+ gt = asyncio.create_task(gather_task(tasks[0], tasks[1]))
+ await asyncio.sleep(0) # let the gather start
+ try:
+ await tasks[0] # can't await because this task is part of the gather
+ except RuntimeError as er:
+ print(repr(er))
+ await gt
+
+ print("====")
+
+ # Test case where can't gather on a task being waited.
+ tasks = [asyncio.create_task(task(1)), asyncio.create_task(task(2))]
+ asyncio.create_task(gather_task(tasks[0], tasks[1]))
+ await tasks[0] # wait on this task before the gather starts
+ await tasks[1]
+
+
+asyncio.run(main())
diff --git a/tests/extmod/uasyncio_gather_notimpl.py.exp b/tests/extmod/uasyncio_gather_notimpl.py.exp
new file mode 100644
index 0000000000..f21614ffbe
--- /dev/null
+++ b/tests/extmod/uasyncio_gather_notimpl.py.exp
@@ -0,0 +1,14 @@
+task start 1
+task start 2
+gather_task start
+RuntimeError("can't wait",)
+task end 1
+task end 2
+gather_task end
+====
+task start 1
+task start 2
+gather_task start
+RuntimeError("can't gather",)
+task end 1
+task end 2