| """Unit tests for contextlib.py, and other context managers.""" | |
| import sys | |
| import tempfile | |
| import unittest | |
| from contextlib import * # Tests __all__ | |
| from test import test_support | |
| try: | |
| import threading | |
| except ImportError: | |
| threading = None | |
| class ContextManagerTestCase(unittest.TestCase): | |
| def test_contextmanager_plain(self): | |
| state = [] | |
| @contextmanager | |
| def woohoo(): | |
| state.append(1) | |
| yield 42 | |
| state.append(999) | |
| with woohoo() as x: | |
| self.assertEqual(state, [1]) | |
| self.assertEqual(x, 42) | |
| state.append(x) | |
| self.assertEqual(state, [1, 42, 999]) | |
| def test_contextmanager_finally(self): | |
| state = [] | |
| @contextmanager | |
| def woohoo(): | |
| state.append(1) | |
| try: | |
| yield 42 | |
| finally: | |
| state.append(999) | |
| with self.assertRaises(ZeroDivisionError): | |
| with woohoo() as x: | |
| self.assertEqual(state, [1]) | |
| self.assertEqual(x, 42) | |
| state.append(x) | |
| raise ZeroDivisionError() | |
| self.assertEqual(state, [1, 42, 999]) | |
| def test_contextmanager_no_reraise(self): | |
| @contextmanager | |
| def whee(): | |
| yield | |
| ctx = whee() | |
| ctx.__enter__() | |
| # Calling __exit__ should not result in an exception | |
| self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None)) | |
| def test_contextmanager_trap_yield_after_throw(self): | |
| @contextmanager | |
| def whoo(): | |
| try: | |
| yield | |
| except: | |
| yield | |
| ctx = whoo() | |
| ctx.__enter__() | |
| self.assertRaises( | |
| RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None | |
| ) | |
| def test_contextmanager_except(self): | |
| state = [] | |
| @contextmanager | |
| def woohoo(): | |
| state.append(1) | |
| try: | |
| yield 42 | |
| except ZeroDivisionError, e: | |
| state.append(e.args[0]) | |
| self.assertEqual(state, [1, 42, 999]) | |
| with woohoo() as x: | |
| self.assertEqual(state, [1]) | |
| self.assertEqual(x, 42) | |
| state.append(x) | |
| raise ZeroDivisionError(999) | |
| self.assertEqual(state, [1, 42, 999]) | |
| def _create_contextmanager_attribs(self): | |
| def attribs(**kw): | |
| def decorate(func): | |
| for k,v in kw.items(): | |
| setattr(func,k,v) | |
| return func | |
| return decorate | |
| @contextmanager | |
| @attribs(foo='bar') | |
| def baz(spam): | |
| """Whee!""" | |
| return baz | |
| def test_contextmanager_attribs(self): | |
| baz = self._create_contextmanager_attribs() | |
| self.assertEqual(baz.__name__,'baz') | |
| self.assertEqual(baz.foo, 'bar') | |
| @unittest.skipIf(sys.flags.optimize >= 2, | |
| "Docstrings are omitted with -O2 and above") | |
| def test_contextmanager_doc_attrib(self): | |
| baz = self._create_contextmanager_attribs() | |
| self.assertEqual(baz.__doc__, "Whee!") | |
| class NestedTestCase(unittest.TestCase): | |
| # XXX This needs more work | |
| def test_nested(self): | |
| @contextmanager | |
| def a(): | |
| yield 1 | |
| @contextmanager | |
| def b(): | |
| yield 2 | |
| @contextmanager | |
| def c(): | |
| yield 3 | |
| with nested(a(), b(), c()) as (x, y, z): | |
| self.assertEqual(x, 1) | |
| self.assertEqual(y, 2) | |
| self.assertEqual(z, 3) | |
| def test_nested_cleanup(self): | |
| state = [] | |
| @contextmanager | |
| def a(): | |
| state.append(1) | |
| try: | |
| yield 2 | |
| finally: | |
| state.append(3) | |
| @contextmanager | |
| def b(): | |
| state.append(4) | |
| try: | |
| yield 5 | |
| finally: | |
| state.append(6) | |
| with self.assertRaises(ZeroDivisionError): | |
| with nested(a(), b()) as (x, y): | |
| state.append(x) | |
| state.append(y) | |
| 1 // 0 | |
| self.assertEqual(state, [1, 4, 2, 5, 6, 3]) | |
| def test_nested_right_exception(self): | |
| @contextmanager | |
| def a(): | |
| yield 1 | |
| class b(object): | |
| def __enter__(self): | |
| return 2 | |
| def __exit__(self, *exc_info): | |
| try: | |
| raise Exception() | |
| except: | |
| pass | |
| with self.assertRaises(ZeroDivisionError): | |
| with nested(a(), b()) as (x, y): | |
| 1 // 0 | |
| self.assertEqual((x, y), (1, 2)) | |
| def test_nested_b_swallows(self): | |
| @contextmanager | |
| def a(): | |
| yield | |
| @contextmanager | |
| def b(): | |
| try: | |
| yield | |
| except: | |
| # Swallow the exception | |
| pass | |
| try: | |
| with nested(a(), b()): | |
| 1 // 0 | |
| except ZeroDivisionError: | |
| self.fail("Didn't swallow ZeroDivisionError") | |
| def test_nested_break(self): | |
| @contextmanager | |
| def a(): | |
| yield | |
| state = 0 | |
| while True: | |
| state += 1 | |
| with nested(a(), a()): | |
| break | |
| state += 10 | |
| self.assertEqual(state, 1) | |
| def test_nested_continue(self): | |
| @contextmanager | |
| def a(): | |
| yield | |
| state = 0 | |
| while state < 3: | |
| state += 1 | |
| with nested(a(), a()): | |
| continue | |
| state += 10 | |
| self.assertEqual(state, 3) | |
| def test_nested_return(self): | |
| @contextmanager | |
| def a(): | |
| try: | |
| yield | |
| except: | |
| pass | |
| def foo(): | |
| with nested(a(), a()): | |
| return 1 | |
| return 10 | |
| self.assertEqual(foo(), 1) | |
| class ClosingTestCase(unittest.TestCase): | |
| # XXX This needs more work | |
| def test_closing(self): | |
| state = [] | |
| class C: | |
| def close(self): | |
| state.append(1) | |
| x = C() | |
| self.assertEqual(state, []) | |
| with closing(x) as y: | |
| self.assertEqual(x, y) | |
| self.assertEqual(state, [1]) | |
| def test_closing_error(self): | |
| state = [] | |
| class C: | |
| def close(self): | |
| state.append(1) | |
| x = C() | |
| self.assertEqual(state, []) | |
| with self.assertRaises(ZeroDivisionError): | |
| with closing(x) as y: | |
| self.assertEqual(x, y) | |
| 1 // 0 | |
| self.assertEqual(state, [1]) | |
| class FileContextTestCase(unittest.TestCase): | |
| def testWithOpen(self): | |
| tfn = tempfile.mktemp() | |
| try: | |
| f = None | |
| with open(tfn, "w") as f: | |
| self.assertFalse(f.closed) | |
| f.write("Booh\n") | |
| self.assertTrue(f.closed) | |
| f = None | |
| with self.assertRaises(ZeroDivisionError): | |
| with open(tfn, "r") as f: | |
| self.assertFalse(f.closed) | |
| self.assertEqual(f.read(), "Booh\n") | |
| 1 // 0 | |
| self.assertTrue(f.closed) | |
| finally: | |
| test_support.unlink(tfn) | |
| @unittest.skipUnless(threading, 'Threading required for this test.') | |
| class LockContextTestCase(unittest.TestCase): | |
| def boilerPlate(self, lock, locked): | |
| self.assertFalse(locked()) | |
| with lock: | |
| self.assertTrue(locked()) | |
| self.assertFalse(locked()) | |
| with self.assertRaises(ZeroDivisionError): | |
| with lock: | |
| self.assertTrue(locked()) | |
| 1 // 0 | |
| self.assertFalse(locked()) | |
| def testWithLock(self): | |
| lock = threading.Lock() | |
| self.boilerPlate(lock, lock.locked) | |
| def testWithRLock(self): | |
| lock = threading.RLock() | |
| self.boilerPlate(lock, lock._is_owned) | |
| def testWithCondition(self): | |
| lock = threading.Condition() | |
| def locked(): | |
| return lock._is_owned() | |
| self.boilerPlate(lock, locked) | |
| def testWithSemaphore(self): | |
| lock = threading.Semaphore() | |
| def locked(): | |
| if lock.acquire(False): | |
| lock.release() | |
| return False | |
| else: | |
| return True | |
| self.boilerPlate(lock, locked) | |
| def testWithBoundedSemaphore(self): | |
| lock = threading.BoundedSemaphore() | |
| def locked(): | |
| if lock.acquire(False): | |
| lock.release() | |
| return False | |
| else: | |
| return True | |
| self.boilerPlate(lock, locked) | |
| # This is needed to make the test actually run under regrtest.py! | |
| def test_main(): | |
| with test_support.check_warnings(("With-statements now directly support " | |
| "multiple context managers", | |
| DeprecationWarning)): | |
| test_support.run_unittest(__name__) | |
| if __name__ == "__main__": | |
| test_main() |