summaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--py/vm.c23
-rw-r--r--tests/basics/try_finally_return.py49
-rw-r--r--tests/basics/with_return.py47
3 files changed, 110 insertions, 9 deletions
diff --git a/py/vm.c b/py/vm.c
index 2611be683a..393b8a1db7 100644
--- a/py/vm.c
+++ b/py/vm.c
@@ -637,10 +637,14 @@ unwind_jump:;
unum -= 1;
assert(exc_sp >= exc_stack);
if (MP_TAGPTR_TAG1(exc_sp->val_sp)) {
+ // Getting here the stack looks like:
+ // (..., X, dest_ip)
+ // where X is pointed to by exc_sp->val_sp and in the case
+ // of a "with" block contains the context manager info.
// We're going to run "finally" code as a coroutine
// (not calling it recursively). Set up a sentinel
// on a stack so it can return back to us when it is
- // done (when END_FINALLY reached).
+ // done (when WITH_CLEANUP or END_FINALLY reached).
PUSH((void*)unum); // push number of exception handlers left to unwind
PUSH(MP_OBJ_NEW_SMALL_INT(UNWIND_JUMP)); // push sentinel
ip = exc_sp->handler; // get exception handler byte code address
@@ -1016,15 +1020,24 @@ unwind_jump:;
unwind_return:
while (exc_sp >= exc_stack) {
if (MP_TAGPTR_TAG1(exc_sp->val_sp)) {
+ // Getting here the stack looks like:
+ // (..., X, [iter0, iter1, ...,] ret_val)
+ // where X is pointed to by exc_sp->val_sp and in the case
+ // of a "with" block contains the context manager info.
+ // There may be 0 or more for-iterators between X and the
+ // return value, and these must be removed before control can
+ // pass to the finally code. We simply copy the ret_value down
+ // over these iterators, if they exist. If they don't then the
+ // following is a null operation.
+ mp_obj_t *finally_sp = MP_TAGPTR_PTR(exc_sp->val_sp);
+ finally_sp[1] = sp[0];
+ sp = &finally_sp[1];
// We're going to run "finally" code as a coroutine
// (not calling it recursively). Set up a sentinel
// on a stack so it can return back to us when it is
- // done (when END_FINALLY reached).
+ // done (when WITH_CLEANUP or END_FINALLY reached).
PUSH(MP_OBJ_NEW_SMALL_INT(UNWIND_RETURN));
ip = exc_sp->handler;
- // We don't need to do anything with sp, finally is just
- // syntactic sugar for sequential execution??
- // sp =
exc_sp--;
goto dispatch_loop;
}
diff --git a/tests/basics/try_finally_return.py b/tests/basics/try_finally_return.py
index 4adf3f0977..31a507e8d0 100644
--- a/tests/basics/try_finally_return.py
+++ b/tests/basics/try_finally_return.py
@@ -21,3 +21,52 @@ def func3():
print("finally 3")
print(func3())
+
+# for loop within try-finally
+def f():
+ try:
+ for i in [1, 2]:
+ return i
+ finally:
+ print('finally')
+print(f())
+
+# multiple for loops within try-finally
+def f():
+ try:
+ for i in [1, 2]:
+ for j in [3, 4]:
+ return (i, j)
+ finally:
+ print('finally')
+print(f())
+
+# multiple for loops and nested try-finally's
+def f():
+ try:
+ for i in [1, 2]:
+ for j in [3, 4]:
+ try:
+ for k in [5, 6]:
+ for l in [7, 8]:
+ return (i, j, k, l)
+ finally:
+ print('finally 2')
+ finally:
+ print('finally 1')
+print(f())
+
+# multiple for loops that are optimised, and nested try-finally's
+def f():
+ try:
+ for i in range(1, 3):
+ for j in range(3, 5):
+ try:
+ for k in range(5, 7):
+ for l in range(7, 9):
+ return (i, j, k, l)
+ finally:
+ print('finally 2')
+ finally:
+ print('finally 1')
+print(f())
diff --git a/tests/basics/with_return.py b/tests/basics/with_return.py
index cb0135c8b3..fd848f1331 100644
--- a/tests/basics/with_return.py
+++ b/tests/basics/with_return.py
@@ -1,14 +1,53 @@
class CtxMgr:
+ def __init__(self, id):
+ self.id = id
def __enter__(self):
- print("__enter__")
+ print("__enter__", self.id)
return self
def __exit__(self, a, b, c):
- print("__exit__", repr(a), repr(b))
+ print("__exit__", self.id, repr(a), repr(b))
+# simple case
def foo():
- with CtxMgr():
+ with CtxMgr(1):
return 4
-
print(foo())
+
+# for loop within with (iterator needs removing upon return)
+def f():
+ with CtxMgr(1):
+ for i in [1, 2]:
+ return i
+print(f())
+
+# multiple for loops within with
+def f():
+ with CtxMgr(1):
+ for i in [1, 2]:
+ for j in [3, 4]:
+ return (i, j)
+print(f())
+
+# multiple for loops within nested withs
+def f():
+ with CtxMgr(1):
+ for i in [1, 2]:
+ for j in [3, 4]:
+ with CtxMgr(2):
+ for k in [5, 6]:
+ for l in [7, 8]:
+ return (i, j, k, l)
+print(f())
+
+# multiple for loops that are optimised, and nested withs
+def f():
+ with CtxMgr(1):
+ for i in range(1, 3):
+ for j in range(3, 5):
+ with CtxMgr(2):
+ for k in range(5, 7):
+ for l in range(7, 9):
+ return (i, j, k, l)
+print(f())