diff options
Diffstat (limited to 'Lib/test/test_contextlib.py')
-rw-r--r-- | Lib/test/test_contextlib.py | 349 |
1 files changed, 191 insertions, 158 deletions
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py index 015a0c51951..d6bb5b818e1 100644 --- a/Lib/test/test_contextlib.py +++ b/Lib/test/test_contextlib.py @@ -1,15 +1,15 @@ """Unit tests for contextlib.py, and other context managers.""" - import sys -import os -import decimal import tempfile import unittest -import threading from contextlib import * # Tests __all__ from test import support -import warnings +try: + import threading +except ImportError: + threading = None + class ContextManagerTestCase(unittest.TestCase): @@ -35,16 +35,12 @@ class ContextManagerTestCase(unittest.TestCase): yield 42 finally: state.append(999) - try: + with self.assertRaises(ZeroDivisionError): with woohoo() as x: self.assertEqual(state, [1]) self.assertEqual(x, 42) state.append(x) raise ZeroDivisionError() - except ZeroDivisionError: - pass - else: - self.fail("Expected ZeroDivisionError") self.assertEqual(state, [1, 42, 999]) def test_contextmanager_no_reraise(self): @@ -86,7 +82,7 @@ class ContextManagerTestCase(unittest.TestCase): raise ZeroDivisionError(999) self.assertEqual(state, [1, 42, 999]) - def test_contextmanager_attribs(self): + def _create_contextmanager_attribs(self): def attribs(**kw): def decorate(func): for k,v in kw.items(): @@ -97,131 +93,18 @@ class ContextManagerTestCase(unittest.TestCase): @attribs(foo='bar') def baz(spam): """Whee!""" + return baz + + def test_contextmanager_attribs(self): + baz = self._create_contextmanager_attribs() self.assertEqual(baz.__name__,'baz') self.assertEqual(baz.foo, 'bar') - self.assertEqual(baz.__doc__, "Whee!") -class NestedTestCase(unittest.TestCase): - - # XXX This needs more work - - def test_nested(self): - @contextmanager - def a(): - yield 1 - @contextmanager - def b(): - yield 2 - @contextmanager - def c(): - yield 3 - with nested(a(), b(), c()) as (x, y, z): - self.assertEqual(x, 1) - self.assertEqual(y, 2) - self.assertEqual(z, 3) - - def test_nested_cleanup(self): - state = [] - @contextmanager - def a(): - state.append(1) - try: - yield 2 - finally: - state.append(3) - @contextmanager - def b(): - state.append(4) - try: - yield 5 - finally: - state.append(6) - try: - with nested(a(), b()) as (x, y): - state.append(x) - state.append(y) - 1/0 - except ZeroDivisionError: - self.assertEqual(state, [1, 4, 2, 5, 6, 3]) - else: - self.fail("Didn't raise ZeroDivisionError") - - def test_nested_right_exception(self): - state = [] - @contextmanager - def a(): - yield 1 - class b(object): - def __enter__(self): - return 2 - def __exit__(self, *exc_info): - try: - raise Exception() - except: - pass - try: - with nested(a(), b()) as (x, y): - 1/0 - except ZeroDivisionError: - self.assertEqual((x, y), (1, 2)) - except Exception: - self.fail("Reraised wrong exception") - else: - self.fail("Didn't raise ZeroDivisionError") - - def test_nested_b_swallows(self): - @contextmanager - def a(): - yield - @contextmanager - def b(): - try: - yield - except: - # Swallow the exception - pass - try: - with nested(a(), b()): - 1/0 - except ZeroDivisionError: - self.fail("Didn't swallow ZeroDivisionError") - - def test_nested_break(self): - @contextmanager - def a(): - yield - state = 0 - while True: - state += 1 - with nested(a(), a()): - break - state += 10 - self.assertEqual(state, 1) - - def test_nested_continue(self): - @contextmanager - def a(): - yield - state = 0 - while state < 3: - state += 1 - with nested(a(), a()): - continue - state += 10 - self.assertEqual(state, 3) - - def test_nested_return(self): - @contextmanager - def a(): - try: - yield - except: - pass - def foo(): - with nested(a(), a()): - return 1 - return 10 - self.assertEqual(foo(), 1) + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def test_contextmanager_doc_attrib(self): + baz = self._create_contextmanager_attribs() + self.assertEqual(baz.__doc__, "Whee!") class ClosingTestCase(unittest.TestCase): @@ -245,14 +128,11 @@ class ClosingTestCase(unittest.TestCase): state.append(1) x = C() self.assertEqual(state, []) - try: + with self.assertRaises(ZeroDivisionError): with closing(x) as y: self.assertEqual(x, y) - 1/0 - except ZeroDivisionError: - self.assertEqual(state, [1]) - else: - self.fail("Didn't raise ZeroDivisionError") + 1 / 0 + self.assertEqual(state, [1]) class FileContextTestCase(unittest.TestCase): @@ -265,21 +145,16 @@ class FileContextTestCase(unittest.TestCase): f.write("Booh\n") self.assertTrue(f.closed) f = None - try: + with self.assertRaises(ZeroDivisionError): with open(tfn, "r") as f: self.assertFalse(f.closed) self.assertEqual(f.read(), "Booh\n") - 1/0 - except ZeroDivisionError: - self.assertTrue(f.closed) - else: - self.fail("Didn't raise ZeroDivisionError") + 1 / 0 + self.assertTrue(f.closed) finally: - try: - os.remove(tfn) - except os.error: - pass + support.unlink(tfn) +@unittest.skipUnless(threading, 'Threading required for this test.') class LockContextTestCase(unittest.TestCase): def boilerPlate(self, lock, locked): @@ -287,14 +162,11 @@ class LockContextTestCase(unittest.TestCase): with lock: self.assertTrue(locked()) self.assertFalse(locked()) - try: + with self.assertRaises(ZeroDivisionError): with lock: self.assertTrue(locked()) - 1/0 - except ZeroDivisionError: - self.assertFalse(locked()) - else: - self.fail("Didn't raise ZeroDivisionError") + 1 / 0 + self.assertFalse(locked()) def testWithLock(self): lock = threading.Lock() @@ -330,11 +202,172 @@ class LockContextTestCase(unittest.TestCase): return True self.boilerPlate(lock, locked) + +class mycontext(ContextDecorator): + started = False + exc = None + catch = False + + def __enter__(self): + self.started = True + return self + + def __exit__(self, *exc): + self.exc = exc + return self.catch + + +class TestContextDecorator(unittest.TestCase): + + def test_contextdecorator(self): + context = mycontext() + with context as result: + self.assertIs(result, context) + self.assertTrue(context.started) + + self.assertEqual(context.exc, (None, None, None)) + + + def test_contextdecorator_with_exception(self): + context = mycontext() + + with self.assertRaisesRegex(NameError, 'foo'): + with context: + raise NameError('foo') + self.assertIsNotNone(context.exc) + self.assertIs(context.exc[0], NameError) + + context = mycontext() + context.catch = True + with context: + raise NameError('foo') + self.assertIsNotNone(context.exc) + self.assertIs(context.exc[0], NameError) + + + def test_decorator(self): + context = mycontext() + + @context + def test(): + self.assertIsNone(context.exc) + self.assertTrue(context.started) + test() + self.assertEqual(context.exc, (None, None, None)) + + + def test_decorator_with_exception(self): + context = mycontext() + + @context + def test(): + self.assertIsNone(context.exc) + self.assertTrue(context.started) + raise NameError('foo') + + with self.assertRaisesRegex(NameError, 'foo'): + test() + self.assertIsNotNone(context.exc) + self.assertIs(context.exc[0], NameError) + + + def test_decorating_method(self): + context = mycontext() + + class Test(object): + + @context + def method(self, a, b, c=None): + self.a = a + self.b = b + self.c = c + + # these tests are for argument passing when used as a decorator + test = Test() + test.method(1, 2) + self.assertEqual(test.a, 1) + self.assertEqual(test.b, 2) + self.assertEqual(test.c, None) + + test = Test() + test.method('a', 'b', 'c') + self.assertEqual(test.a, 'a') + self.assertEqual(test.b, 'b') + self.assertEqual(test.c, 'c') + + test = Test() + test.method(a=1, b=2) + self.assertEqual(test.a, 1) + self.assertEqual(test.b, 2) + + + def test_typo_enter(self): + class mycontext(ContextDecorator): + def __unter__(self): + pass + def __exit__(self, *exc): + pass + + with self.assertRaises(AttributeError): + with mycontext(): + pass + + + def test_typo_exit(self): + class mycontext(ContextDecorator): + def __enter__(self): + pass + def __uxit__(self, *exc): + pass + + with self.assertRaises(AttributeError): + with mycontext(): + pass + + + def test_contextdecorator_as_mixin(self): + class somecontext(object): + started = False + exc = None + + def __enter__(self): + self.started = True + return self + + def __exit__(self, *exc): + self.exc = exc + + class mycontext(somecontext, ContextDecorator): + pass + + context = mycontext() + @context + def test(): + self.assertIsNone(context.exc) + self.assertTrue(context.started) + test() + self.assertEqual(context.exc, (None, None, None)) + + + def test_contextmanager_as_decorator(self): + state = [] + @contextmanager + def woohoo(y): + state.append(y) + yield + state.append(999) + + @woohoo(1) + def test(x): + self.assertEqual(state, [1]) + state.append(x) + test('something') + self.assertEqual(state, [1, 'something', 999]) + + # This is needed to make the test actually run under regrtest.py! def test_main(): - with warnings.catch_warnings(): - warnings.simplefilter('ignore') - support.run_unittest(__name__) + support.run_unittest(__name__) if __name__ == "__main__": test_main() |