From 8fb307cd650511ba019c4493275cb6684ad308bc Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 22 Jul 2015 13:33:45 +0300 Subject: Issue #24619: New approach for tokenizing async/await. This commit fixes how one-line async-defs and defs are tracked by tokenizer. It allows to correctly parse invalid code such as: >>> async def f(): ... def g(): pass ... async = 10 and valid code such as: >>> async def f(): ... async def g(): pass ... await z As a consequence, is is now possible to have one-line 'async def foo(): await ..' functions: >>> async def foo(): return await bar() --- Lib/test/test_coroutines.py | 226 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 218 insertions(+), 8 deletions(-) (limited to 'Lib/test/test_coroutines.py') diff --git a/Lib/test/test_coroutines.py b/Lib/test/test_coroutines.py index 9d97123b82e..3ba2f2383c8 100644 --- a/Lib/test/test_coroutines.py +++ b/Lib/test/test_coroutines.py @@ -67,11 +67,11 @@ def silence_coro_gc(): class AsyncBadSyntaxTest(unittest.TestCase): def test_badsyntax_1(self): - with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + with self.assertRaisesRegex(SyntaxError, "'await' outside"): import test.badsyntax_async1 def test_badsyntax_2(self): - with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + with self.assertRaisesRegex(SyntaxError, "'await' outside"): import test.badsyntax_async2 def test_badsyntax_3(self): @@ -103,10 +103,6 @@ class AsyncBadSyntaxTest(unittest.TestCase): import test.badsyntax_async8 def test_badsyntax_9(self): - with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): - import test.badsyntax_async9 - - def test_badsyntax_10(self): ns = {} for comp in {'(await a for a in b)', '[await a for a in b]', @@ -116,6 +112,221 @@ class AsyncBadSyntaxTest(unittest.TestCase): with self.assertRaisesRegex(SyntaxError, 'await.*in comprehen'): exec('async def f():\n\t{}'.format(comp), ns, ns) + def test_badsyntax_10(self): + # Tests for issue 24619 + + samples = [ + """async def foo(): + def bar(): pass + await = 1 + """, + + """async def foo(): + + def bar(): pass + await = 1 + """, + + """async def foo(): + def bar(): pass + if 1: + await = 1 + """, + + """def foo(): + async def bar(): pass + if 1: + await a + """, + + """def foo(): + async def bar(): pass + await a + """, + + """def foo(): + def baz(): pass + async def bar(): pass + await a + """, + + """def foo(): + def baz(): pass + # 456 + async def bar(): pass + # 123 + await a + """, + + """async def foo(): + def baz(): pass + # 456 + async def bar(): pass + # 123 + await = 2 + """, + + """def foo(): + + def baz(): pass + + async def bar(): pass + + await a + """, + + """async def foo(): + + def baz(): pass + + async def bar(): pass + + await = 2 + """, + + """async def foo(): + def async(): pass + """, + + """async def foo(): + def await(): pass + """, + + """async def foo(): + def bar(): + await + """, + + """async def foo(): + return lambda async: await + """, + + """async def foo(): + return lambda a: await + """, + + """async def foo(a: await b): + pass + """, + + """def baz(): + async def foo(a: await b): + pass + """, + + """async def foo(async): + pass + """, + + """async def foo(): + def bar(): + def baz(): + async = 1 + """, + + """async def foo(): + def bar(): + def baz(): + pass + async = 1 + """, + + """def foo(): + async def bar(): + + async def baz(): + pass + + def baz(): + 42 + + async = 1 + """, + + """async def foo(): + def bar(): + def baz(): + pass\nawait foo() + """, + + """def foo(): + def bar(): + async def baz(): + pass\nawait foo() + """, + + """async def foo(await): + pass + """, + + """def foo(): + + async def bar(): pass + + await a + """, + + """def foo(): + async def bar(): + pass\nawait a + """] + + ns = {} + for code in samples: + with self.subTest(code=code), self.assertRaises(SyntaxError): + exec(code, ns, ns) + + def test_goodsyntax_1(self): + # Tests for issue 24619 + + def foo(await): + async def foo(): pass + async def foo(): + pass + return await + 1 + self.assertEqual(foo(10), 11) + + def foo(await): + async def foo(): pass + async def foo(): pass + return await + 2 + self.assertEqual(foo(20), 22) + + def foo(await): + + async def foo(): pass + + async def foo(): pass + + return await + 2 + self.assertEqual(foo(20), 22) + + def foo(await): + """spam""" + async def foo(): \ + pass + # 123 + async def foo(): pass + # 456 + return await + 2 + self.assertEqual(foo(20), 22) + + def foo(await): + def foo(): pass + def foo(): pass + async def bar(): return await_ + await_ = await + try: + bar().send(None) + except StopIteration as ex: + return ex.args[0] + self.assertEqual(foo(42), 42) + + async def f(): + async def g(): pass + await z + self.assertTrue(inspect.iscoroutinefunction(f)) + class TokenizerRegrTest(unittest.TestCase): @@ -461,8 +672,7 @@ class CoroutineTest(unittest.TestCase): class Awaitable: pass - async def foo(): - return (await Awaitable()) + async def foo(): return await Awaitable() with self.assertRaisesRegex( TypeError, "object Awaitable can't be used in 'await' expression"): -- cgit v1.2.3