diff options
Diffstat (limited to 'Lib/test/test_functools.py')
-rw-r--r-- | Lib/test/test_functools.py | 301 |
1 files changed, 261 insertions, 40 deletions
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index a713314f160..270cab00756 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -1,9 +1,11 @@ import functools +import collections import sys import unittest -from test import test_support +from test import support from weakref import proxy import pickle +from random import choice @staticmethod def PythonPartial(func, *args, **keywords): @@ -34,7 +36,7 @@ class TestPartial(unittest.TestCase): self.assertEqual(p(3, 4, b=30, c=40), ((1, 2, 3, 4), dict(a=10, b=30, c=40))) p = self.thetype(map, lambda x: x*10) - self.assertEqual(p([1,2,3,4]), [10, 20, 30, 40]) + self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40]) def test_attributes(self): p = self.thetype(capture, 1, 2, a=10, b=20) @@ -45,9 +47,9 @@ class TestPartial(unittest.TestCase): # attributes should not be writable if not isinstance(self.thetype, type): return - self.assertRaises(TypeError, setattr, p, 'func', map) - self.assertRaises(TypeError, setattr, p, 'args', (1, 2)) - self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2)) + self.assertRaises(AttributeError, setattr, p, 'func', map) + self.assertRaises(AttributeError, setattr, p, 'args', (1, 2)) + self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2)) p = self.thetype(hex) try: @@ -125,7 +127,7 @@ class TestPartial(unittest.TestCase): def test_error_propagation(self): def f(x, y): - x // y + x / y self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0)) self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0) self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0) @@ -139,12 +141,38 @@ class TestPartial(unittest.TestCase): self.assertRaises(ReferenceError, getattr, p, 'func') def test_with_bound_and_unbound_methods(self): - data = map(str, range(10)) + data = list(map(str, range(10))) join = self.thetype(str.join, '') self.assertEqual(join(data), '0123456789') join = self.thetype(''.join) self.assertEqual(join(data), '0123456789') + def test_repr(self): + args = (object(), object()) + args_repr = ', '.join(repr(a) for a in args) + kwargs = {'a': object(), 'b': object()} + kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items()) + if self.thetype is functools.partial: + name = 'functools.partial' + else: + name = self.thetype.__name__ + + f = self.thetype(capture) + self.assertEqual('{}({!r})'.format(name, capture), + repr(f)) + + f = self.thetype(capture, *args) + self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr), + repr(f)) + + f = self.thetype(capture, **kwargs) + self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr), + repr(f)) + + f = self.thetype(capture, *args, **kwargs) + self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr), + repr(f)) + def test_pickle(self): f = self.thetype(signature, 'asdf', bar=True) f.add_something_to__dict__ = True @@ -162,6 +190,9 @@ class TestPythonPartial(TestPartial): thetype = PythonPartial + # the python version hasn't a nice repr + def test_repr(self): pass + # the python version isn't picklable def test_pickle(self): pass @@ -181,11 +212,11 @@ class TestUpdateWrapper(unittest.TestCase): self.assertTrue(wrapped_attr[key] is wrapper_attr[key]) def _default_update(self): - def f(): + def f(a:'This is a new annotation'): """This is a test""" pass f.attr = 'This is also a test' - def wrapper(): + def wrapper(b:'This is the prior annotation'): pass functools.update_wrapper(wrapper, f) return wrapper, f @@ -193,8 +224,11 @@ class TestUpdateWrapper(unittest.TestCase): def test_default_update(self): wrapper, f = self._default_update() self.check_wrapper(wrapper, f) + self.assertIs(wrapper.__wrapped__, f) self.assertEqual(wrapper.__name__, 'f') self.assertEqual(wrapper.attr, 'This is also a test') + self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation') + self.assertNotIn('b', wrapper.__annotations__) @unittest.skipIf(sys.flags.optimize >= 2, "Docstrings are omitted with -O2 and above") @@ -213,6 +247,7 @@ class TestUpdateWrapper(unittest.TestCase): self.check_wrapper(wrapper, f, (), ()) self.assertEqual(wrapper.__name__, 'wrapper') self.assertEqual(wrapper.__doc__, None) + self.assertEqual(wrapper.__annotations__, {}) self.assertFalse(hasattr(wrapper, 'attr')) def test_selective_update(self): @@ -232,6 +267,28 @@ class TestUpdateWrapper(unittest.TestCase): self.assertEqual(wrapper.attr, 'This is a different test') self.assertEqual(wrapper.dict_attr, f.dict_attr) + def test_missing_attributes(self): + def f(): + pass + def wrapper(): + pass + wrapper.dict_attr = {} + assign = ('attr',) + update = ('dict_attr',) + # Missing attributes on wrapped object are ignored + functools.update_wrapper(wrapper, f, assign, update) + self.assertNotIn('attr', wrapper.__dict__) + self.assertEqual(wrapper.dict_attr, {}) + # Wrapper must have expected attributes for updating + del wrapper.dict_attr + with self.assertRaises(AttributeError): + functools.update_wrapper(wrapper, f, assign, update) + wrapper.dict_attr = 1 + with self.assertRaises(AttributeError): + functools.update_wrapper(wrapper, f, assign, update) + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") def test_builtin_update(self): # Test for bug #1576241 def wrapper(): @@ -239,6 +296,7 @@ class TestUpdateWrapper(unittest.TestCase): functools.update_wrapper(wrapper, max) self.assertEqual(wrapper.__name__, 'max') self.assertTrue(wrapper.__doc__.startswith('max(')) + self.assertEqual(wrapper.__annotations__, {}) class TestWraps(TestUpdateWrapper): @@ -297,17 +355,17 @@ class TestWraps(TestUpdateWrapper): self.assertEqual(wrapper.attr, 'This is a different test') self.assertEqual(wrapper.dict_attr, f.dict_attr) - class TestReduce(unittest.TestCase): + func = functools.reduce def test_reduce(self): class Squares: - def __init__(self, max): self.max = max self.sofar = [] - def __len__(self): return len(self.sofar) + def __len__(self): + return len(self.sofar) def __getitem__(self, i): if not 0 <= i < self.max: raise IndexError @@ -316,27 +374,66 @@ class TestReduce(unittest.TestCase): self.sofar.append(n*n) n += 1 return self.sofar[i] - - reduce = functools.reduce - self.assertEqual(reduce(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc') + def add(x, y): + return x + y + self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc') self.assertEqual( - reduce(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []), + self.func(add, [['a', 'c'], [], ['d', 'w']], []), ['a','c','d','w'] ) - self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040) + self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040) self.assertEqual( - reduce(lambda x, y: x*y, range(2,21), 1L), - 2432902008176640000L + self.func(lambda x, y: x*y, range(2,21), 1), + 2432902008176640000 ) - self.assertEqual(reduce(lambda x, y: x+y, Squares(10)), 285) - self.assertEqual(reduce(lambda x, y: x+y, Squares(10), 0), 285) - self.assertEqual(reduce(lambda x, y: x+y, Squares(0), 0), 0) - self.assertRaises(TypeError, reduce) - self.assertRaises(TypeError, reduce, 42, 42) - self.assertRaises(TypeError, reduce, 42, 42, 42) - self.assertEqual(reduce(42, "1"), "1") # func is never called with one item - self.assertEqual(reduce(42, "", "1"), "1") # func is never called with one item - self.assertRaises(TypeError, reduce, 42, (42, 42)) + self.assertEqual(self.func(add, Squares(10)), 285) + self.assertEqual(self.func(add, Squares(10), 0), 285) + self.assertEqual(self.func(add, Squares(0), 0), 0) + self.assertRaises(TypeError, self.func) + self.assertRaises(TypeError, self.func, 42, 42) + self.assertRaises(TypeError, self.func, 42, 42, 42) + self.assertEqual(self.func(42, "1"), "1") # func is never called with one item + self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item + self.assertRaises(TypeError, self.func, 42, (42, 42)) + self.assertRaises(TypeError, self.func, add, []) # arg 2 must not be empty sequence with no initial value + self.assertRaises(TypeError, self.func, add, "") + self.assertRaises(TypeError, self.func, add, ()) + self.assertRaises(TypeError, self.func, add, object()) + + class TestFailingIter: + def __iter__(self): + raise RuntimeError + self.assertRaises(RuntimeError, self.func, add, TestFailingIter()) + + self.assertEqual(self.func(add, [], None), None) + self.assertEqual(self.func(add, [], 42), 42) + + class BadSeq: + def __getitem__(self, index): + raise ValueError + self.assertRaises(ValueError, self.func, 42, BadSeq()) + + # Test reduce()'s use of iterators. + def test_iterator_usage(self): + class SequenceClass: + def __init__(self, n): + self.n = n + def __getitem__(self, i): + if 0 <= i < self.n: + return i + else: + raise IndexError + + from operator import add + self.assertEqual(self.func(add, SequenceClass(5)), 10) + self.assertEqual(self.func(add, SequenceClass(5), 42), 52) + self.assertRaises(TypeError, self.func, add, SequenceClass(0)) + self.assertEqual(self.func(add, SequenceClass(0), 42), 42) + self.assertEqual(self.func(add, SequenceClass(1)), 0) + self.assertEqual(self.func(add, SequenceClass(1), 42), 42) + + d = {"one": 1, "two": 2, "three": 3} + self.assertEqual(self.func(add, d), "".join(d.keys())) class TestCmpToKey(unittest.TestCase): def test_cmp_to_key(self): @@ -350,7 +447,8 @@ class TestCmpToKey(unittest.TestCase): return y - x key = functools.cmp_to_key(mycmp) k = key(10) - self.assertRaises(TypeError, hash(k)) + self.assertRaises(TypeError, hash, k) + self.assertFalse(isinstance(k, collections.Hashable)) class TestTotalOrdering(unittest.TestCase): @@ -421,14 +519,14 @@ class TestTotalOrdering(unittest.TestCase): def test_total_ordering_no_overwrite(self): # new methods should not overwrite existing @functools.total_ordering - class A(str): + class A(int): pass - self.assertTrue(A("a") < A("b")) - self.assertTrue(A("b") > A("a")) - self.assertTrue(A("a") <= A("b")) - self.assertTrue(A("b") >= A("a")) - self.assertTrue(A("b") <= A("b")) - self.assertTrue(A("b") >= A("b")) + self.assertTrue(A(1) < A(2)) + self.assertTrue(A(2) > A(1)) + self.assertTrue(A(1) <= A(2)) + self.assertTrue(A(2) >= A(1)) + self.assertTrue(A(2) <= A(2)) + self.assertTrue(A(2) >= A(2)) def test_no_operations_defined(self): with self.assertRaises(ValueError): @@ -452,6 +550,127 @@ class TestTotalOrdering(unittest.TestCase): with self.assertRaises(TypeError): TestTO(8) <= () +class TestLRU(unittest.TestCase): + + def test_lru(self): + def orig(x, y): + return 3*x+y + f = functools.lru_cache(maxsize=20)(orig) + hits, misses, maxsize, currsize = f.cache_info() + self.assertEqual(maxsize, 20) + self.assertEqual(currsize, 0) + self.assertEqual(hits, 0) + self.assertEqual(misses, 0) + + domain = range(5) + for i in range(1000): + x, y = choice(domain), choice(domain) + actual = f(x, y) + expected = orig(x, y) + self.assertEqual(actual, expected) + hits, misses, maxsize, currsize = f.cache_info() + self.assertTrue(hits > misses) + self.assertEqual(hits + misses, 1000) + self.assertEqual(currsize, 20) + + f.cache_clear() # test clearing + hits, misses, maxsize, currsize = f.cache_info() + self.assertEqual(hits, 0) + self.assertEqual(misses, 0) + self.assertEqual(currsize, 0) + f(x, y) + hits, misses, maxsize, currsize = f.cache_info() + self.assertEqual(hits, 0) + self.assertEqual(misses, 1) + self.assertEqual(currsize, 1) + + # Test bypassing the cache + self.assertIs(f.__wrapped__, orig) + f.__wrapped__(x, y) + hits, misses, maxsize, currsize = f.cache_info() + self.assertEqual(hits, 0) + self.assertEqual(misses, 1) + self.assertEqual(currsize, 1) + + # test size zero (which means "never-cache") + @functools.lru_cache(0) + def f(): + nonlocal f_cnt + f_cnt += 1 + return 20 + self.assertEqual(f.cache_info().maxsize, 0) + f_cnt = 0 + for i in range(5): + self.assertEqual(f(), 20) + self.assertEqual(f_cnt, 5) + hits, misses, maxsize, currsize = f.cache_info() + self.assertEqual(hits, 0) + self.assertEqual(misses, 5) + self.assertEqual(currsize, 0) + + # test size one + @functools.lru_cache(1) + def f(): + nonlocal f_cnt + f_cnt += 1 + return 20 + self.assertEqual(f.cache_info().maxsize, 1) + f_cnt = 0 + for i in range(5): + self.assertEqual(f(), 20) + self.assertEqual(f_cnt, 1) + hits, misses, maxsize, currsize = f.cache_info() + self.assertEqual(hits, 4) + self.assertEqual(misses, 1) + self.assertEqual(currsize, 1) + + # test size two + @functools.lru_cache(2) + def f(x): + nonlocal f_cnt + f_cnt += 1 + return x*10 + self.assertEqual(f.cache_info().maxsize, 2) + f_cnt = 0 + for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7: + # * * * * + self.assertEqual(f(x), x*10) + self.assertEqual(f_cnt, 4) + hits, misses, maxsize, currsize = f.cache_info() + self.assertEqual(hits, 12) + self.assertEqual(misses, 4) + self.assertEqual(currsize, 2) + + def test_lru_with_maxsize_none(self): + @functools.lru_cache(maxsize=None) + def fib(n): + if n < 2: + return n + return fib(n-1) + fib(n-2) + self.assertEqual([fib(n) for n in range(16)], + [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) + self.assertEqual(fib.cache_info(), + functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) + fib.cache_clear() + self.assertEqual(fib.cache_info(), + functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) + + def test_lru_with_exceptions(self): + # Verify that user_function exceptions get passed through without + # creating a hard-to-read chained exception. + # http://bugs.python.org/issue13177 + for maxsize in (None, 100): + @functools.lru_cache(maxsize) + def func(i): + return 'abc'[i] + self.assertEqual(func(0), 'a') + with self.assertRaises(IndexError) as cm: + func(15) + self.assertIsNone(cm.exception.__context__) + # Verify that the previous exception did not result in a cached entry + with self.assertRaises(IndexError): + func(15) + def test_main(verbose=None): test_classes = ( TestPartial, @@ -459,20 +678,22 @@ def test_main(verbose=None): TestPythonPartial, TestUpdateWrapper, TestTotalOrdering, + TestCmpToKey, TestWraps, TestReduce, + TestLRU, ) - test_support.run_unittest(*test_classes) + support.run_unittest(*test_classes) # verify reference counting if verbose and hasattr(sys, "gettotalrefcount"): import gc counts = [None] * 5 - for i in xrange(len(counts)): - test_support.run_unittest(*test_classes) + for i in range(len(counts)): + support.run_unittest(*test_classes) gc.collect() counts[i] = sys.gettotalrefcount() - print counts + print(counts) if __name__ == '__main__': test_main(verbose=True) |