aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/test/test_functools.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_functools.py')
-rw-r--r--Lib/test/test_functools.py301
1 files changed, 261 insertions, 40 deletions
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index a713314f160..270cab00756 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -1,9 +1,11 @@
import functools
+import collections
import sys
import unittest
-from test import test_support
+from test import support
from weakref import proxy
import pickle
+from random import choice
@staticmethod
def PythonPartial(func, *args, **keywords):
@@ -34,7 +36,7 @@ class TestPartial(unittest.TestCase):
self.assertEqual(p(3, 4, b=30, c=40),
((1, 2, 3, 4), dict(a=10, b=30, c=40)))
p = self.thetype(map, lambda x: x*10)
- self.assertEqual(p([1,2,3,4]), [10, 20, 30, 40])
+ self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])
def test_attributes(self):
p = self.thetype(capture, 1, 2, a=10, b=20)
@@ -45,9 +47,9 @@ class TestPartial(unittest.TestCase):
# attributes should not be writable
if not isinstance(self.thetype, type):
return
- self.assertRaises(TypeError, setattr, p, 'func', map)
- self.assertRaises(TypeError, setattr, p, 'args', (1, 2))
- self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2))
+ self.assertRaises(AttributeError, setattr, p, 'func', map)
+ self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
+ self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
p = self.thetype(hex)
try:
@@ -125,7 +127,7 @@ class TestPartial(unittest.TestCase):
def test_error_propagation(self):
def f(x, y):
- x // y
+ x / y
self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
@@ -139,12 +141,38 @@ class TestPartial(unittest.TestCase):
self.assertRaises(ReferenceError, getattr, p, 'func')
def test_with_bound_and_unbound_methods(self):
- data = map(str, range(10))
+ data = list(map(str, range(10)))
join = self.thetype(str.join, '')
self.assertEqual(join(data), '0123456789')
join = self.thetype(''.join)
self.assertEqual(join(data), '0123456789')
+ def test_repr(self):
+ args = (object(), object())
+ args_repr = ', '.join(repr(a) for a in args)
+ kwargs = {'a': object(), 'b': object()}
+ kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items())
+ if self.thetype is functools.partial:
+ name = 'functools.partial'
+ else:
+ name = self.thetype.__name__
+
+ f = self.thetype(capture)
+ self.assertEqual('{}({!r})'.format(name, capture),
+ repr(f))
+
+ f = self.thetype(capture, *args)
+ self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
+ repr(f))
+
+ f = self.thetype(capture, **kwargs)
+ self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr),
+ repr(f))
+
+ f = self.thetype(capture, *args, **kwargs)
+ self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr),
+ repr(f))
+
def test_pickle(self):
f = self.thetype(signature, 'asdf', bar=True)
f.add_something_to__dict__ = True
@@ -162,6 +190,9 @@ class TestPythonPartial(TestPartial):
thetype = PythonPartial
+ # the python version hasn't a nice repr
+ def test_repr(self): pass
+
# the python version isn't picklable
def test_pickle(self): pass
@@ -181,11 +212,11 @@ class TestUpdateWrapper(unittest.TestCase):
self.assertTrue(wrapped_attr[key] is wrapper_attr[key])
def _default_update(self):
- def f():
+ def f(a:'This is a new annotation'):
"""This is a test"""
pass
f.attr = 'This is also a test'
- def wrapper():
+ def wrapper(b:'This is the prior annotation'):
pass
functools.update_wrapper(wrapper, f)
return wrapper, f
@@ -193,8 +224,11 @@ class TestUpdateWrapper(unittest.TestCase):
def test_default_update(self):
wrapper, f = self._default_update()
self.check_wrapper(wrapper, f)
+ self.assertIs(wrapper.__wrapped__, f)
self.assertEqual(wrapper.__name__, 'f')
self.assertEqual(wrapper.attr, 'This is also a test')
+ self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
+ self.assertNotIn('b', wrapper.__annotations__)
@unittest.skipIf(sys.flags.optimize >= 2,
"Docstrings are omitted with -O2 and above")
@@ -213,6 +247,7 @@ class TestUpdateWrapper(unittest.TestCase):
self.check_wrapper(wrapper, f, (), ())
self.assertEqual(wrapper.__name__, 'wrapper')
self.assertEqual(wrapper.__doc__, None)
+ self.assertEqual(wrapper.__annotations__, {})
self.assertFalse(hasattr(wrapper, 'attr'))
def test_selective_update(self):
@@ -232,6 +267,28 @@ class TestUpdateWrapper(unittest.TestCase):
self.assertEqual(wrapper.attr, 'This is a different test')
self.assertEqual(wrapper.dict_attr, f.dict_attr)
+ def test_missing_attributes(self):
+ def f():
+ pass
+ def wrapper():
+ pass
+ wrapper.dict_attr = {}
+ assign = ('attr',)
+ update = ('dict_attr',)
+ # Missing attributes on wrapped object are ignored
+ functools.update_wrapper(wrapper, f, assign, update)
+ self.assertNotIn('attr', wrapper.__dict__)
+ self.assertEqual(wrapper.dict_attr, {})
+ # Wrapper must have expected attributes for updating
+ del wrapper.dict_attr
+ with self.assertRaises(AttributeError):
+ functools.update_wrapper(wrapper, f, assign, update)
+ wrapper.dict_attr = 1
+ with self.assertRaises(AttributeError):
+ functools.update_wrapper(wrapper, f, assign, update)
+
+ @unittest.skipIf(sys.flags.optimize >= 2,
+ "Docstrings are omitted with -O2 and above")
def test_builtin_update(self):
# Test for bug #1576241
def wrapper():
@@ -239,6 +296,7 @@ class TestUpdateWrapper(unittest.TestCase):
functools.update_wrapper(wrapper, max)
self.assertEqual(wrapper.__name__, 'max')
self.assertTrue(wrapper.__doc__.startswith('max('))
+ self.assertEqual(wrapper.__annotations__, {})
class TestWraps(TestUpdateWrapper):
@@ -297,17 +355,17 @@ class TestWraps(TestUpdateWrapper):
self.assertEqual(wrapper.attr, 'This is a different test')
self.assertEqual(wrapper.dict_attr, f.dict_attr)
-
class TestReduce(unittest.TestCase):
+ func = functools.reduce
def test_reduce(self):
class Squares:
-
def __init__(self, max):
self.max = max
self.sofar = []
- def __len__(self): return len(self.sofar)
+ def __len__(self):
+ return len(self.sofar)
def __getitem__(self, i):
if not 0 <= i < self.max: raise IndexError
@@ -316,27 +374,66 @@ class TestReduce(unittest.TestCase):
self.sofar.append(n*n)
n += 1
return self.sofar[i]
-
- reduce = functools.reduce
- self.assertEqual(reduce(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc')
+ def add(x, y):
+ return x + y
+ self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc')
self.assertEqual(
- reduce(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []),
+ self.func(add, [['a', 'c'], [], ['d', 'w']], []),
['a','c','d','w']
)
- self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040)
+ self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040)
self.assertEqual(
- reduce(lambda x, y: x*y, range(2,21), 1L),
- 2432902008176640000L
+ self.func(lambda x, y: x*y, range(2,21), 1),
+ 2432902008176640000
)
- self.assertEqual(reduce(lambda x, y: x+y, Squares(10)), 285)
- self.assertEqual(reduce(lambda x, y: x+y, Squares(10), 0), 285)
- self.assertEqual(reduce(lambda x, y: x+y, Squares(0), 0), 0)
- self.assertRaises(TypeError, reduce)
- self.assertRaises(TypeError, reduce, 42, 42)
- self.assertRaises(TypeError, reduce, 42, 42, 42)
- self.assertEqual(reduce(42, "1"), "1") # func is never called with one item
- self.assertEqual(reduce(42, "", "1"), "1") # func is never called with one item
- self.assertRaises(TypeError, reduce, 42, (42, 42))
+ self.assertEqual(self.func(add, Squares(10)), 285)
+ self.assertEqual(self.func(add, Squares(10), 0), 285)
+ self.assertEqual(self.func(add, Squares(0), 0), 0)
+ self.assertRaises(TypeError, self.func)
+ self.assertRaises(TypeError, self.func, 42, 42)
+ self.assertRaises(TypeError, self.func, 42, 42, 42)
+ self.assertEqual(self.func(42, "1"), "1") # func is never called with one item
+ self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item
+ self.assertRaises(TypeError, self.func, 42, (42, 42))
+ self.assertRaises(TypeError, self.func, add, []) # arg 2 must not be empty sequence with no initial value
+ self.assertRaises(TypeError, self.func, add, "")
+ self.assertRaises(TypeError, self.func, add, ())
+ self.assertRaises(TypeError, self.func, add, object())
+
+ class TestFailingIter:
+ def __iter__(self):
+ raise RuntimeError
+ self.assertRaises(RuntimeError, self.func, add, TestFailingIter())
+
+ self.assertEqual(self.func(add, [], None), None)
+ self.assertEqual(self.func(add, [], 42), 42)
+
+ class BadSeq:
+ def __getitem__(self, index):
+ raise ValueError
+ self.assertRaises(ValueError, self.func, 42, BadSeq())
+
+ # Test reduce()'s use of iterators.
+ def test_iterator_usage(self):
+ class SequenceClass:
+ def __init__(self, n):
+ self.n = n
+ def __getitem__(self, i):
+ if 0 <= i < self.n:
+ return i
+ else:
+ raise IndexError
+
+ from operator import add
+ self.assertEqual(self.func(add, SequenceClass(5)), 10)
+ self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
+ self.assertRaises(TypeError, self.func, add, SequenceClass(0))
+ self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
+ self.assertEqual(self.func(add, SequenceClass(1)), 0)
+ self.assertEqual(self.func(add, SequenceClass(1), 42), 42)
+
+ d = {"one": 1, "two": 2, "three": 3}
+ self.assertEqual(self.func(add, d), "".join(d.keys()))
class TestCmpToKey(unittest.TestCase):
def test_cmp_to_key(self):
@@ -350,7 +447,8 @@ class TestCmpToKey(unittest.TestCase):
return y - x
key = functools.cmp_to_key(mycmp)
k = key(10)
- self.assertRaises(TypeError, hash(k))
+ self.assertRaises(TypeError, hash, k)
+ self.assertFalse(isinstance(k, collections.Hashable))
class TestTotalOrdering(unittest.TestCase):
@@ -421,14 +519,14 @@ class TestTotalOrdering(unittest.TestCase):
def test_total_ordering_no_overwrite(self):
# new methods should not overwrite existing
@functools.total_ordering
- class A(str):
+ class A(int):
pass
- self.assertTrue(A("a") < A("b"))
- self.assertTrue(A("b") > A("a"))
- self.assertTrue(A("a") <= A("b"))
- self.assertTrue(A("b") >= A("a"))
- self.assertTrue(A("b") <= A("b"))
- self.assertTrue(A("b") >= A("b"))
+ self.assertTrue(A(1) < A(2))
+ self.assertTrue(A(2) > A(1))
+ self.assertTrue(A(1) <= A(2))
+ self.assertTrue(A(2) >= A(1))
+ self.assertTrue(A(2) <= A(2))
+ self.assertTrue(A(2) >= A(2))
def test_no_operations_defined(self):
with self.assertRaises(ValueError):
@@ -452,6 +550,127 @@ class TestTotalOrdering(unittest.TestCase):
with self.assertRaises(TypeError):
TestTO(8) <= ()
+class TestLRU(unittest.TestCase):
+
+ def test_lru(self):
+ def orig(x, y):
+ return 3*x+y
+ f = functools.lru_cache(maxsize=20)(orig)
+ hits, misses, maxsize, currsize = f.cache_info()
+ self.assertEqual(maxsize, 20)
+ self.assertEqual(currsize, 0)
+ self.assertEqual(hits, 0)
+ self.assertEqual(misses, 0)
+
+ domain = range(5)
+ for i in range(1000):
+ x, y = choice(domain), choice(domain)
+ actual = f(x, y)
+ expected = orig(x, y)
+ self.assertEqual(actual, expected)
+ hits, misses, maxsize, currsize = f.cache_info()
+ self.assertTrue(hits > misses)
+ self.assertEqual(hits + misses, 1000)
+ self.assertEqual(currsize, 20)
+
+ f.cache_clear() # test clearing
+ hits, misses, maxsize, currsize = f.cache_info()
+ self.assertEqual(hits, 0)
+ self.assertEqual(misses, 0)
+ self.assertEqual(currsize, 0)
+ f(x, y)
+ hits, misses, maxsize, currsize = f.cache_info()
+ self.assertEqual(hits, 0)
+ self.assertEqual(misses, 1)
+ self.assertEqual(currsize, 1)
+
+ # Test bypassing the cache
+ self.assertIs(f.__wrapped__, orig)
+ f.__wrapped__(x, y)
+ hits, misses, maxsize, currsize = f.cache_info()
+ self.assertEqual(hits, 0)
+ self.assertEqual(misses, 1)
+ self.assertEqual(currsize, 1)
+
+ # test size zero (which means "never-cache")
+ @functools.lru_cache(0)
+ def f():
+ nonlocal f_cnt
+ f_cnt += 1
+ return 20
+ self.assertEqual(f.cache_info().maxsize, 0)
+ f_cnt = 0
+ for i in range(5):
+ self.assertEqual(f(), 20)
+ self.assertEqual(f_cnt, 5)
+ hits, misses, maxsize, currsize = f.cache_info()
+ self.assertEqual(hits, 0)
+ self.assertEqual(misses, 5)
+ self.assertEqual(currsize, 0)
+
+ # test size one
+ @functools.lru_cache(1)
+ def f():
+ nonlocal f_cnt
+ f_cnt += 1
+ return 20
+ self.assertEqual(f.cache_info().maxsize, 1)
+ f_cnt = 0
+ for i in range(5):
+ self.assertEqual(f(), 20)
+ self.assertEqual(f_cnt, 1)
+ hits, misses, maxsize, currsize = f.cache_info()
+ self.assertEqual(hits, 4)
+ self.assertEqual(misses, 1)
+ self.assertEqual(currsize, 1)
+
+ # test size two
+ @functools.lru_cache(2)
+ def f(x):
+ nonlocal f_cnt
+ f_cnt += 1
+ return x*10
+ self.assertEqual(f.cache_info().maxsize, 2)
+ f_cnt = 0
+ for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
+ # * * * *
+ self.assertEqual(f(x), x*10)
+ self.assertEqual(f_cnt, 4)
+ hits, misses, maxsize, currsize = f.cache_info()
+ self.assertEqual(hits, 12)
+ self.assertEqual(misses, 4)
+ self.assertEqual(currsize, 2)
+
+ def test_lru_with_maxsize_none(self):
+ @functools.lru_cache(maxsize=None)
+ def fib(n):
+ if n < 2:
+ return n
+ return fib(n-1) + fib(n-2)
+ self.assertEqual([fib(n) for n in range(16)],
+ [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
+ self.assertEqual(fib.cache_info(),
+ functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
+ fib.cache_clear()
+ self.assertEqual(fib.cache_info(),
+ functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
+
+ def test_lru_with_exceptions(self):
+ # Verify that user_function exceptions get passed through without
+ # creating a hard-to-read chained exception.
+ # http://bugs.python.org/issue13177
+ for maxsize in (None, 100):
+ @functools.lru_cache(maxsize)
+ def func(i):
+ return 'abc'[i]
+ self.assertEqual(func(0), 'a')
+ with self.assertRaises(IndexError) as cm:
+ func(15)
+ self.assertIsNone(cm.exception.__context__)
+ # Verify that the previous exception did not result in a cached entry
+ with self.assertRaises(IndexError):
+ func(15)
+
def test_main(verbose=None):
test_classes = (
TestPartial,
@@ -459,20 +678,22 @@ def test_main(verbose=None):
TestPythonPartial,
TestUpdateWrapper,
TestTotalOrdering,
+ TestCmpToKey,
TestWraps,
TestReduce,
+ TestLRU,
)
- test_support.run_unittest(*test_classes)
+ support.run_unittest(*test_classes)
# verify reference counting
if verbose and hasattr(sys, "gettotalrefcount"):
import gc
counts = [None] * 5
- for i in xrange(len(counts)):
- test_support.run_unittest(*test_classes)
+ for i in range(len(counts)):
+ support.run_unittest(*test_classes)
gc.collect()
counts[i] = sys.gettotalrefcount()
- print counts
+ print(counts)
if __name__ == '__main__':
test_main(verbose=True)