diff options
Diffstat (limited to 'Lib/test/pickletester.py')
-rw-r--r-- | Lib/test/pickletester.py | 659 |
1 files changed, 513 insertions, 146 deletions
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py index 4d59bde8511..a2bea600edb 100644 --- a/Lib/test/pickletester.py +++ b/Lib/test/pickletester.py @@ -1,9 +1,11 @@ +import copyreg import io -import unittest import pickle import pickletools +import random +import struct import sys -import copyreg +import unittest import weakref from http.cookies import SimpleCookie @@ -93,6 +95,9 @@ class E(C): def __getinitargs__(self): return () +class H(object): + pass + import __main__ __main__.C = C C.__module__ = "__main__" @@ -100,6 +105,8 @@ __main__.D = D D.__module__ = "__main__" __main__.E = E E.__module__ = "__main__" +__main__.H = H +H.__module__ = "__main__" class myint(int): def __init__(self, x): @@ -491,31 +498,52 @@ def create_data(): x.append(5) return x + class AbstractPickleTests(unittest.TestCase): # Subclass must define self.dumps, self.loads. + optimized = False + _testdata = create_data() def setUp(self): pass + def assert_is_copy(self, obj, objcopy, msg=None): + """Utility method to verify if two objects are copies of each others. + """ + if msg is None: + msg = "{!r} is not a copy of {!r}".format(obj, objcopy) + self.assertEqual(obj, objcopy, msg=msg) + self.assertIs(type(obj), type(objcopy), msg=msg) + if hasattr(obj, '__dict__'): + self.assertDictEqual(obj.__dict__, objcopy.__dict__, msg=msg) + self.assertIsNot(obj.__dict__, objcopy.__dict__, msg=msg) + if hasattr(obj, '__slots__'): + self.assertListEqual(obj.__slots__, objcopy.__slots__, msg=msg) + for slot in obj.__slots__: + self.assertEqual( + hasattr(obj, slot), hasattr(objcopy, slot), msg=msg) + self.assertEqual(getattr(obj, slot, None), + getattr(objcopy, slot, None), msg=msg) + def test_misc(self): # test various datatypes not tested by testdata for proto in protocols: x = myint(4) s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) x = (1, ()) s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) x = initarg(1, x) s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) # XXX test __reduce__ protocol? @@ -524,16 +552,16 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(expected, proto) got = self.loads(s) - self.assertEqual(expected, got) + self.assert_is_copy(expected, got) def test_load_from_data0(self): - self.assertEqual(self._testdata, self.loads(DATA0)) + self.assert_is_copy(self._testdata, self.loads(DATA0)) def test_load_from_data1(self): - self.assertEqual(self._testdata, self.loads(DATA1)) + self.assert_is_copy(self._testdata, self.loads(DATA1)) def test_load_from_data2(self): - self.assertEqual(self._testdata, self.loads(DATA2)) + self.assert_is_copy(self._testdata, self.loads(DATA2)) def test_load_classic_instance(self): # See issue5180. Test loading 2.x pickles that @@ -555,7 +583,7 @@ class AbstractPickleTests(unittest.TestCase): b"X\n" b"p0\n" b"(dp1\nb.").replace(b'X', xname) - self.assertEqual(X(*args), self.loads(pickle0)) + self.assert_is_copy(X(*args), self.loads(pickle0)) # Protocol 1 (binary mode pickle) """ @@ -572,7 +600,7 @@ class AbstractPickleTests(unittest.TestCase): pickle1 = (b'(c__main__\n' b'X\n' b'q\x00oq\x01}q\x02b.').replace(b'X', xname) - self.assertEqual(X(*args), self.loads(pickle1)) + self.assert_is_copy(X(*args), self.loads(pickle1)) # Protocol 2 (pickle2 = b'\x80\x02' + pickle1) """ @@ -590,7 +618,7 @@ class AbstractPickleTests(unittest.TestCase): pickle2 = (b'\x80\x02(c__main__\n' b'X\n' b'q\x00oq\x01}q\x02b.').replace(b'X', xname) - self.assertEqual(X(*args), self.loads(pickle2)) + self.assert_is_copy(X(*args), self.loads(pickle2)) # There are gratuitous differences between pickles produced by # pickle and cPickle, largely because cPickle starts PUT indices at @@ -615,6 +643,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(l, proto) x = self.loads(s) + self.assertIsInstance(x, list) self.assertEqual(len(x), 1) self.assertTrue(x is x[0]) @@ -624,6 +653,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(t, proto) x = self.loads(s) + self.assertIsInstance(x, tuple) self.assertEqual(len(x), 1) self.assertEqual(len(x[0]), 1) self.assertTrue(x is x[0][0]) @@ -634,15 +664,39 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(d, proto) x = self.loads(s) + self.assertIsInstance(x, dict) self.assertEqual(list(x.keys()), [1]) self.assertTrue(x[1] is x) + def test_recursive_set(self): + h = H() + y = set({h}) + h.attr = y + for proto in protocols: + s = self.dumps(y, proto) + x = self.loads(s) + self.assertIsInstance(x, set) + self.assertIs(list(x)[0].attr, x) + self.assertEqual(len(x), 1) + + def test_recursive_frozenset(self): + h = H() + y = frozenset({h}) + h.attr = y + for proto in protocols: + s = self.dumps(y, proto) + x = self.loads(s) + self.assertIsInstance(x, frozenset) + self.assertIs(list(x)[0].attr, x) + self.assertEqual(len(x), 1) + def test_recursive_inst(self): i = C() i.attr = i for proto in protocols: s = self.dumps(i, proto) x = self.loads(s) + self.assertIsInstance(x, C) self.assertEqual(dir(x), dir(i)) self.assertIs(x.attr, x) @@ -655,6 +709,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(l, proto) x = self.loads(s) + self.assertIsInstance(x, list) self.assertEqual(len(x), 1) self.assertEqual(dir(x[0]), dir(i)) self.assertEqual(list(x[0].attr.keys()), [1]) @@ -662,31 +717,8 @@ class AbstractPickleTests(unittest.TestCase): def test_get(self): self.assertRaises(KeyError, self.loads, b'g0\np0') - self.assertEqual(self.loads(b'((Kdtp0\nh\x00l.))'), [(100,), (100,)]) - - def test_insecure_strings(self): - # XXX Some of these tests are temporarily disabled - insecure = [b"abc", b"2 + 2", # not quoted - ## b"'abc' + 'def'", # not a single quoted string - b"'abc", # quote is not closed - b"'abc\"", # open quote and close quote don't match - b"'abc' ?", # junk after close quote - b"'\\'", # trailing backslash - # Variations on issue #17710 - b"'", - b'"', - b"' ", - b"' ", - b"' ", - b"' ", - b'" ', - # some tests of the quoting rules - ## b"'abc\"\''", - ## b"'\\\\a\'\'\'\\\'\\\\\''", - ] - for b in insecure: - buf = b"S" + b + b"\012p0\012." - self.assertRaises(ValueError, self.loads, buf) + self.assert_is_copy([(100,), (100,)], + self.loads(b'((Kdtp0\nh\x00l.))')) def test_unicode(self): endcases = ['', '<\\u>', '<\\\u1234>', '<\n>', @@ -697,26 +729,26 @@ class AbstractPickleTests(unittest.TestCase): for u in endcases: p = self.dumps(u, proto) u2 = self.loads(p) - self.assertEqual(u2, u) + self.assert_is_copy(u, u2) def test_unicode_high_plane(self): t = '\U00012345' for proto in protocols: p = self.dumps(t, proto) t2 = self.loads(p) - self.assertEqual(t2, t) + self.assert_is_copy(t, t2) def test_bytes(self): for proto in protocols: for s in b'', b'xyz', b'xyz'*100: p = self.dumps(s, proto) - self.assertEqual(self.loads(p), s) + self.assert_is_copy(s, self.loads(p)) for s in [bytes([i]) for i in range(256)]: p = self.dumps(s, proto) - self.assertEqual(self.loads(p), s) + self.assert_is_copy(s, self.loads(p)) for s in [bytes([i, i]) for i in range(256)]: p = self.dumps(s, proto) - self.assertEqual(self.loads(p), s) + self.assert_is_copy(s, self.loads(p)) def test_ints(self): import sys @@ -726,14 +758,14 @@ class AbstractPickleTests(unittest.TestCase): for expected in (-n, n): s = self.dumps(expected, proto) n2 = self.loads(s) - self.assertEqual(expected, n2) + self.assert_is_copy(expected, n2) n = n >> 1 def test_maxint64(self): maxint64 = (1 << 63) - 1 data = b'I' + str(maxint64).encode("ascii") + b'\n.' got = self.loads(data) - self.assertEqual(got, maxint64) + self.assert_is_copy(maxint64, got) # Try too with a bogus literal. data = b'I' + str(maxint64).encode("ascii") + b'JUNK\n.' @@ -748,7 +780,7 @@ class AbstractPickleTests(unittest.TestCase): for n in npos, -npos: pickle = self.dumps(n, proto) got = self.loads(pickle) - self.assertEqual(n, got) + self.assert_is_copy(n, got) # Try a monster. This is quadratic-time in protos 0 & 1, so don't # bother with those. nbase = int("deadbeeffeedface", 16) @@ -756,6 +788,10 @@ class AbstractPickleTests(unittest.TestCase): for n in nbase, -nbase: p = self.dumps(n, 2) got = self.loads(p) + # assert_is_copy is very expensive here as it precomputes + # a failure message by computing the repr() of n and got, + # we just do the check ourselves. + self.assertIs(type(got), int) self.assertEqual(n, got) def test_float(self): @@ -766,7 +802,7 @@ class AbstractPickleTests(unittest.TestCase): for value in test_values: pickle = self.dumps(value, proto) got = self.loads(pickle) - self.assertEqual(value, got) + self.assert_is_copy(value, got) @run_with_locale('LC_ALL', 'de_DE', 'fr_FR') def test_float_format(self): @@ -774,10 +810,18 @@ class AbstractPickleTests(unittest.TestCase): self.assertEqual(self.dumps(1.2, 0)[0:3], b'F1.') def test_reduce(self): - pass + for proto in protocols: + inst = AAA() + dumped = self.dumps(inst, proto) + loaded = self.loads(dumped) + self.assertEqual(loaded, REDUCE_A) def test_getinitargs(self): - pass + for proto in protocols: + inst = initarg(1, 2) + dumped = self.dumps(inst, proto) + loaded = self.loads(dumped) + self.assert_is_copy(inst, loaded) def test_pop_empty_stack(self): # Test issue7455 @@ -798,6 +842,7 @@ class AbstractPickleTests(unittest.TestCase): s = self.dumps(a, proto) b = self.loads(s) self.assertEqual(a, b) + self.assertIs(type(a), type(b)) def test_structseq(self): import time @@ -807,29 +852,29 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(t, proto) u = self.loads(s) - self.assertEqual(t, u) + self.assert_is_copy(t, u) if hasattr(os, "stat"): t = os.stat(os.curdir) s = self.dumps(t, proto) u = self.loads(s) - self.assertEqual(t, u) + self.assert_is_copy(t, u) if hasattr(os, "statvfs"): t = os.statvfs(os.curdir) s = self.dumps(t, proto) u = self.loads(s) - self.assertEqual(t, u) + self.assert_is_copy(t, u) def test_ellipsis(self): for proto in protocols: s = self.dumps(..., proto) u = self.loads(s) - self.assertEqual(..., u) + self.assertIs(..., u) def test_notimplemented(self): for proto in protocols: s = self.dumps(NotImplemented, proto) u = self.loads(s) - self.assertEqual(NotImplemented, u) + self.assertIs(NotImplemented, u) def test_singleton_types(self): # Issue #6477: Test that types of built-in singletons can be pickled. @@ -843,21 +888,21 @@ class AbstractPickleTests(unittest.TestCase): # Tests for protocol 2 def test_proto(self): - build_none = pickle.NONE + pickle.STOP for proto in protocols: - expected = build_none + pickled = self.dumps(None, proto) if proto >= 2: - expected = pickle.PROTO + bytes([proto]) + expected - p = self.dumps(None, proto) - self.assertEqual(p, expected) + proto_header = pickle.PROTO + bytes([proto]) + self.assertTrue(pickled.startswith(proto_header)) + else: + self.assertEqual(count_opcode(pickle.PROTO, pickled), 0) oob = protocols[-1] + 1 # a future protocol + build_none = pickle.NONE + pickle.STOP badpickle = pickle.PROTO + bytes([oob]) + build_none try: self.loads(badpickle) - except ValueError as detail: - self.assertTrue(str(detail).startswith( - "unsupported pickle protocol")) + except ValueError as err: + self.assertIn("unsupported pickle protocol", str(err)) else: self.fail("expected bad protocol number to raise ValueError") @@ -866,7 +911,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) self.assertEqual(opcode_in_pickle(pickle.LONG1, s), proto >= 2) def test_long4(self): @@ -874,7 +919,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) self.assertEqual(opcode_in_pickle(pickle.LONG4, s), proto >= 2) def test_short_tuples(self): @@ -912,9 +957,9 @@ class AbstractPickleTests(unittest.TestCase): for x in a, b, c, d, e: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y, (proto, x, s, y)) - expected = expected_opcode[proto, len(x)] - self.assertEqual(opcode_in_pickle(expected, s), True) + self.assert_is_copy(x, y) + expected = expected_opcode[min(proto, 3), len(x)] + self.assertTrue(opcode_in_pickle(expected, s)) def test_singletons(self): # Map (proto, singleton) to expected opcode. @@ -938,8 +983,8 @@ class AbstractPickleTests(unittest.TestCase): s = self.dumps(x, proto) y = self.loads(s) self.assertTrue(x is y, (proto, x, s, y)) - expected = expected_opcode[proto, x] - self.assertEqual(opcode_in_pickle(expected, s), True) + expected = expected_opcode[min(proto, 3), x] + self.assertTrue(opcode_in_pickle(expected, s)) def test_newobj_tuple(self): x = MyTuple([1, 2, 3]) @@ -948,8 +993,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(tuple(x), tuple(y)) - self.assertEqual(x.__dict__, y.__dict__) + self.assert_is_copy(x, y) def test_newobj_list(self): x = MyList([1, 2, 3]) @@ -958,8 +1002,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(list(x), list(y)) - self.assertEqual(x.__dict__, y.__dict__) + self.assert_is_copy(x, y) def test_newobj_generic(self): for proto in protocols: @@ -970,6 +1013,7 @@ class AbstractPickleTests(unittest.TestCase): s = self.dumps(x, proto) y = self.loads(s) detail = (proto, C, B, x, y, type(y)) + self.assert_is_copy(x, y) # XXX revisit self.assertEqual(B(x), B(y), detail) self.assertEqual(x.__dict__, y.__dict__, detail) @@ -1008,11 +1052,10 @@ class AbstractPickleTests(unittest.TestCase): s1 = self.dumps(x, 1) self.assertIn(__name__.encode("utf-8"), s1) self.assertIn(b"MyList", s1) - self.assertEqual(opcode_in_pickle(opcode, s1), False) + self.assertFalse(opcode_in_pickle(opcode, s1)) y = self.loads(s1) - self.assertEqual(list(x), list(y)) - self.assertEqual(x.__dict__, y.__dict__) + self.assert_is_copy(x, y) # Dump using protocol 2 for test. s2 = self.dumps(x, 2) @@ -1021,9 +1064,7 @@ class AbstractPickleTests(unittest.TestCase): self.assertEqual(opcode_in_pickle(opcode, s2), True, repr(s2)) y = self.loads(s2) - self.assertEqual(list(x), list(y)) - self.assertEqual(x.__dict__, y.__dict__) - + self.assert_is_copy(x, y) finally: e.restore() @@ -1047,7 +1088,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) num_appends = count_opcode(pickle.APPENDS, s) self.assertEqual(num_appends, proto > 0) @@ -1056,7 +1097,7 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) num_appends = count_opcode(pickle.APPENDS, s) if proto == 0: self.assertEqual(num_appends, 0) @@ -1070,7 +1111,7 @@ class AbstractPickleTests(unittest.TestCase): s = self.dumps(x, proto) self.assertIsInstance(s, bytes_types) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) num_setitems = count_opcode(pickle.SETITEMS, s) self.assertEqual(num_setitems, proto > 0) @@ -1079,22 +1120,49 @@ class AbstractPickleTests(unittest.TestCase): for proto in protocols: s = self.dumps(x, proto) y = self.loads(s) - self.assertEqual(x, y) + self.assert_is_copy(x, y) num_setitems = count_opcode(pickle.SETITEMS, s) if proto == 0: self.assertEqual(num_setitems, 0) else: self.assertTrue(num_setitems >= 2) + def test_set_chunking(self): + n = 10 # too small to chunk + x = set(range(n)) + for proto in protocols: + s = self.dumps(x, proto) + y = self.loads(s) + self.assert_is_copy(x, y) + num_additems = count_opcode(pickle.ADDITEMS, s) + if proto < 4: + self.assertEqual(num_additems, 0) + else: + self.assertEqual(num_additems, 1) + + n = 2500 # expect at least two chunks when proto >= 4 + x = set(range(n)) + for proto in protocols: + s = self.dumps(x, proto) + y = self.loads(s) + self.assert_is_copy(x, y) + num_additems = count_opcode(pickle.ADDITEMS, s) + if proto < 4: + self.assertEqual(num_additems, 0) + else: + self.assertGreaterEqual(num_additems, 2) + def test_simple_newobj(self): x = object.__new__(SimpleNewObj) # avoid __init__ x.abc = 666 for proto in protocols: s = self.dumps(x, proto) - self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), proto >= 2) + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), + 2 <= proto < 4) + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ_EX, s), + proto >= 4) y = self.loads(s) # will raise TypeError if __init__ called - self.assertEqual(y.abc, 666) - self.assertEqual(x.__dict__, y.__dict__) + self.assert_is_copy(x, y) def test_newobj_list_slots(self): x = SlotList([1, 2, 3]) @@ -1102,10 +1170,7 @@ class AbstractPickleTests(unittest.TestCase): x.bar = "hello" s = self.dumps(x, 2) y = self.loads(s) - self.assertEqual(list(x), list(y)) - self.assertEqual(x.__dict__, y.__dict__) - self.assertEqual(x.foo, y.foo) - self.assertEqual(x.bar, y.bar) + self.assert_is_copy(x, y) def test_reduce_overrides_default_reduce_ex(self): for proto in protocols: @@ -1154,11 +1219,10 @@ class AbstractPickleTests(unittest.TestCase): @no_tracing def test_bad_getattr(self): + # Issue #3514: crash when there is an infinite loop in __getattr__ x = BadGetattr() - for proto in 0, 1: + for proto in protocols: self.assertRaises(RuntimeError, self.dumps, x, proto) - # protocol 2 don't raise a RuntimeError. - d = self.dumps(x, 2) def test_reduce_bad_iterator(self): # Issue4176: crash when 4th and 5th items of __reduce__() @@ -1191,11 +1255,10 @@ class AbstractPickleTests(unittest.TestCase): obj = [dict(large_dict), dict(large_dict), dict(large_dict)] for proto in protocols: - dumped = self.dumps(obj, proto) - loaded = self.loads(dumped) - self.assertEqual(loaded, obj, - "Failed protocol %d: %r != %r" - % (proto, obj, loaded)) + with self.subTest(proto=proto): + dumped = self.dumps(obj, proto) + loaded = self.loads(dumped) + self.assert_is_copy(obj, loaded) def test_attribute_name_interning(self): # Test that attribute names of pickled objects are interned when @@ -1248,6 +1311,35 @@ class AbstractPickleTests(unittest.TestCase): dumped = self.dumps(set([3]), 2) self.assertEqual(dumped, DATA6) + def test_load_python2_str_as_bytes(self): + # From Python 2: pickle.dumps('a\x00\xa0', protocol=0) + self.assertEqual(self.loads(b"S'a\\x00\\xa0'\n.", + encoding="bytes"), b'a\x00\xa0') + # From Python 2: pickle.dumps('a\x00\xa0', protocol=1) + self.assertEqual(self.loads(b'U\x03a\x00\xa0.', + encoding="bytes"), b'a\x00\xa0') + # From Python 2: pickle.dumps('a\x00\xa0', protocol=2) + self.assertEqual(self.loads(b'\x80\x02U\x03a\x00\xa0.', + encoding="bytes"), b'a\x00\xa0') + + def test_load_python2_unicode_as_str(self): + # From Python 2: pickle.dumps(u'π', protocol=0) + self.assertEqual(self.loads(b'V\\u03c0\n.', + encoding='bytes'), 'π') + # From Python 2: pickle.dumps(u'π', protocol=1) + self.assertEqual(self.loads(b'X\x02\x00\x00\x00\xcf\x80.', + encoding="bytes"), 'π') + # From Python 2: pickle.dumps(u'π', protocol=2) + self.assertEqual(self.loads(b'\x80\x02X\x02\x00\x00\x00\xcf\x80.', + encoding="bytes"), 'π') + + def test_load_long_python2_str_as_bytes(self): + # From Python 2: pickle.dumps('x' * 300, protocol=1) + self.assertEqual(self.loads(pickle.BINSTRING + + struct.pack("<I", 300) + + b'x' * 300 + pickle.STOP, + encoding='bytes'), b'x' * 300) + def test_large_pickles(self): # Test the correctness of internal buffering routines when handling # large data. @@ -1266,11 +1358,14 @@ class AbstractPickleTests(unittest.TestCase): def test_int_pickling_efficiency(self): # Test compacity of int representation (see issue #12744) for proto in protocols: - sizes = [len(self.dumps(2**n, proto)) for n in range(70)] - # the size function is monotonic - self.assertEqual(sorted(sizes), sizes) - if proto >= 2: - self.assertLessEqual(sizes[-1], 14) + with self.subTest(proto=proto): + pickles = [self.dumps(2**n, proto) for n in range(70)] + sizes = list(map(len, pickles)) + # the size function is monotonic + self.assertEqual(sorted(sizes), sizes) + if proto >= 2: + for p in pickles: + self.assertFalse(opcode_in_pickle(pickle.LONG, p)) def check_negative_32b_binXXX(self, dumped): if sys.maxsize > 2**32: @@ -1301,6 +1396,35 @@ class AbstractPickleTests(unittest.TestCase): dumped = b'\x80\x03X\x01\x00\x00\x00ar\xff\xff\xff\xff.' self.assertRaises(ValueError, self.loads, dumped) + def test_badly_escaped_string(self): + self.assertRaises(ValueError, self.loads, b"S'\\'\n.") + + def test_badly_quoted_string(self): + # Issue #17710 + badpickles = [b"S'\n.", + b'S"\n.', + b'S\' \n.', + b'S" \n.', + b'S\'"\n.', + b'S"\'\n.', + b"S' ' \n.", + b'S" " \n.', + b"S ''\n.", + b'S ""\n.', + b'S \n.', + b'S\n.', + b'S.'] + for p in badpickles: + self.assertRaises(pickle.UnpicklingError, self.loads, p) + + def test_correctly_quoted_string(self): + goodpickles = [(b"S''\n.", ''), + (b'S""\n.', ''), + (b'S"\\n"\n.', '\n'), + (b"S'\\n'\n.", '\n')] + for p, expected in goodpickles: + self.assertEqual(self.loads(p), expected) + def _check_pickling_with_opcode(self, obj, opcode, proto): pickled = self.dumps(obj, proto) self.assertTrue(opcode_in_pickle(opcode, pickled)) @@ -1324,6 +1448,194 @@ class AbstractPickleTests(unittest.TestCase): else: self._check_pickling_with_opcode(obj, pickle.SETITEMS, proto) + # Exercise framing (proto >= 4) for significant workloads + + FRAME_SIZE_TARGET = 64 * 1024 + + def check_frame_opcodes(self, pickled): + """ + Check the arguments of FRAME opcodes in a protocol 4+ pickle. + """ + frame_opcode_size = 9 + last_arg = last_pos = None + for op, arg, pos in pickletools.genops(pickled): + if op.name != 'FRAME': + continue + if last_pos is not None: + # The previous frame's size should be equal to the number + # of bytes up to the current frame. + frame_size = pos - last_pos - frame_opcode_size + self.assertEqual(frame_size, last_arg) + last_arg, last_pos = arg, pos + # The last frame's size should be equal to the number of bytes up + # to the pickle's end. + frame_size = len(pickled) - last_pos - frame_opcode_size + self.assertEqual(frame_size, last_arg) + + def test_framing_many_objects(self): + obj = list(range(10**5)) + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + pickled = self.dumps(obj, proto) + unpickled = self.loads(pickled) + self.assertEqual(obj, unpickled) + bytes_per_frame = (len(pickled) / + count_opcode(pickle.FRAME, pickled)) + self.assertGreater(bytes_per_frame, + self.FRAME_SIZE_TARGET / 2) + self.assertLessEqual(bytes_per_frame, + self.FRAME_SIZE_TARGET * 1) + self.check_frame_opcodes(pickled) + + def test_framing_large_objects(self): + N = 1024 * 1024 + obj = [b'x' * N, b'y' * N, b'z' * N] + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + pickled = self.dumps(obj, proto) + unpickled = self.loads(pickled) + self.assertEqual(obj, unpickled) + n_frames = count_opcode(pickle.FRAME, pickled) + self.assertGreaterEqual(n_frames, len(obj)) + self.check_frame_opcodes(pickled) + + def test_optional_frames(self): + if pickle.HIGHEST_PROTOCOL < 4: + return + + def remove_frames(pickled, keep_frame=None): + """Remove frame opcodes from the given pickle.""" + frame_starts = [] + # 1 byte for the opcode and 8 for the argument + frame_opcode_size = 9 + for opcode, _, pos in pickletools.genops(pickled): + if opcode.name == 'FRAME': + frame_starts.append(pos) + + newpickle = bytearray() + last_frame_end = 0 + for i, pos in enumerate(frame_starts): + if keep_frame and keep_frame(i): + continue + newpickle += pickled[last_frame_end:pos] + last_frame_end = pos + frame_opcode_size + newpickle += pickled[last_frame_end:] + return newpickle + + frame_size = self.FRAME_SIZE_TARGET + num_frames = 20 + obj = [bytes([i]) * frame_size for i in range(num_frames)] + + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + pickled = self.dumps(obj, proto) + + frameless_pickle = remove_frames(pickled) + self.assertEqual(count_opcode(pickle.FRAME, frameless_pickle), 0) + self.assertEqual(obj, self.loads(frameless_pickle)) + + some_frames_pickle = remove_frames(pickled, lambda i: i % 2) + self.assertLess(count_opcode(pickle.FRAME, some_frames_pickle), + count_opcode(pickle.FRAME, pickled)) + self.assertEqual(obj, self.loads(some_frames_pickle)) + + def test_nested_names(self): + global Nested + class Nested: + class A: + class B: + class C: + pass + + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for obj in [Nested.A, Nested.A.B, Nested.A.B.C]: + with self.subTest(proto=proto, obj=obj): + unpickled = self.loads(self.dumps(obj, proto)) + self.assertIs(obj, unpickled) + + def test_py_methods(self): + global PyMethodsTest + class PyMethodsTest: + @staticmethod + def cheese(): + return "cheese" + @classmethod + def wine(cls): + assert cls is PyMethodsTest + return "wine" + def biscuits(self): + assert isinstance(self, PyMethodsTest) + return "biscuits" + class Nested: + "Nested class" + @staticmethod + def ketchup(): + return "ketchup" + @classmethod + def maple(cls): + assert cls is PyMethodsTest.Nested + return "maple" + def pie(self): + assert isinstance(self, PyMethodsTest.Nested) + return "pie" + + py_methods = ( + PyMethodsTest.cheese, + PyMethodsTest.wine, + PyMethodsTest().biscuits, + PyMethodsTest.Nested.ketchup, + PyMethodsTest.Nested.maple, + PyMethodsTest.Nested().pie + ) + py_unbound_methods = ( + (PyMethodsTest.biscuits, PyMethodsTest), + (PyMethodsTest.Nested.pie, PyMethodsTest.Nested) + ) + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for method in py_methods: + with self.subTest(proto=proto, method=method): + unpickled = self.loads(self.dumps(method, proto)) + self.assertEqual(method(), unpickled()) + for method, cls in py_unbound_methods: + obj = cls() + with self.subTest(proto=proto, method=method): + unpickled = self.loads(self.dumps(method, proto)) + self.assertEqual(method(obj), unpickled(obj)) + + def test_c_methods(self): + global Subclass + class Subclass(tuple): + class Nested(str): + pass + + c_methods = ( + # bound built-in method + ("abcd".index, ("c",)), + # unbound built-in method + (str.index, ("abcd", "c")), + # bound "slot" method + ([1, 2, 3].__len__, ()), + # unbound "slot" method + (list.__len__, ([1, 2, 3],)), + # bound "coexist" method + ({1, 2}.__contains__, (2,)), + # unbound "coexist" method + (set.__contains__, ({1, 2}, 2)), + # built-in class method + (dict.fromkeys, (("a", 1), ("b", 2))), + # built-in static method + (bytearray.maketrans, (b"abc", b"xyz")), + # subclass methods + (Subclass([1,2,2]).count, (2,)), + (Subclass.count, (Subclass([1,2,2]), 2)), + (Subclass.Nested("sweet").count, ("e",)), + (Subclass.Nested.count, (Subclass.Nested("sweet"), "e")), + ) + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for method, args in c_methods: + with self.subTest(proto=proto, method=method): + unpickled = self.loads(self.dumps(method, proto)) + self.assertEqual(method(*args), unpickled(*args)) + class BigmemPickleTests(unittest.TestCase): @@ -1336,8 +1648,9 @@ class BigmemPickleTests(unittest.TestCase): for proto in protocols: if proto < 2: continue - with self.assertRaises((ValueError, OverflowError)): - self.dumps(data, protocol=proto) + with self.subTest(proto=proto): + with self.assertRaises((ValueError, OverflowError)): + self.dumps(data, protocol=proto) finally: data = None @@ -1352,24 +1665,44 @@ class BigmemPickleTests(unittest.TestCase): for proto in protocols: if proto < 3: continue - try: - pickled = self.dumps(data, protocol=proto) - self.assertTrue(b"abcd" in pickled[:15]) - self.assertTrue(b"abcd" in pickled[-15:]) - finally: - pickled = None + with self.subTest(proto=proto): + try: + pickled = self.dumps(data, protocol=proto) + header = (pickle.BINBYTES + + struct.pack("<I", len(data))) + data_start = pickled.index(data) + self.assertEqual( + header, + pickled[data_start-len(header):data_start]) + finally: + pickled = None finally: data = None @bigmemtest(size=_4G, memuse=2.5, dry_run=False) def test_huge_bytes_64b(self, size): - data = b"a" * size + data = b"acbd" * (size // 4) try: for proto in protocols: if proto < 3: continue - with self.assertRaises((ValueError, OverflowError)): - self.dumps(data, protocol=proto) + with self.subTest(proto=proto): + if proto == 3: + # Protocol 3 does not support large bytes objects. + # Verify that we do not crash when processing one. + with self.assertRaises((ValueError, OverflowError)): + self.dumps(data, protocol=proto) + continue + try: + pickled = self.dumps(data, protocol=proto) + header = (pickle.BINBYTES8 + + struct.pack("<Q", len(data))) + data_start = pickled.index(data) + self.assertEqual( + header, + pickled[data_start-len(header):data_start]) + finally: + pickled = None finally: data = None @@ -1381,27 +1714,52 @@ class BigmemPickleTests(unittest.TestCase): data = "abcd" * (size // 4) try: for proto in protocols: - try: - pickled = self.dumps(data, protocol=proto) - self.assertTrue(b"abcd" in pickled[:15]) - self.assertTrue(b"abcd" in pickled[-15:]) - finally: - pickled = None + if proto == 0: + continue + with self.subTest(proto=proto): + try: + pickled = self.dumps(data, protocol=proto) + header = (pickle.BINUNICODE + + struct.pack("<I", len(data))) + data_start = pickled.index(b'abcd') + self.assertEqual( + header, + pickled[data_start-len(header):data_start]) + self.assertEqual((pickled.rindex(b"abcd") + len(b"abcd") - + pickled.index(b"abcd")), len(data)) + finally: + pickled = None finally: data = None - # BINUNICODE (protocols 1, 2 and 3) cannot carry more than - # 2**32 - 1 bytes of utf-8 encoded unicode. + # BINUNICODE (protocols 1, 2 and 3) cannot carry more than 2**32 - 1 bytes + # of utf-8 encoded unicode. BINUNICODE8 (protocol 4) supports these huge + # unicode strings however. @bigmemtest(size=_4G, memuse=8, dry_run=False) def test_huge_str_64b(self, size): - data = "a" * size + data = "abcd" * (size // 4) try: for proto in protocols: if proto == 0: continue - with self.assertRaises((ValueError, OverflowError)): - self.dumps(data, protocol=proto) + with self.subTest(proto=proto): + if proto < 4: + with self.assertRaises((ValueError, OverflowError)): + self.dumps(data, protocol=proto) + continue + try: + pickled = self.dumps(data, protocol=proto) + header = (pickle.BINUNICODE8 + + struct.pack("<Q", len(data))) + data_start = pickled.index(b'abcd') + self.assertEqual( + header, + pickled[data_start-len(header):data_start]) + self.assertEqual((pickled.rindex(b"abcd") + len(b"abcd") - + pickled.index(b"abcd")), len(data)) + finally: + pickled = None finally: data = None @@ -1445,8 +1803,8 @@ class REX_five(object): return object.__reduce__(self) class REX_six(object): - """This class is used to check the 4th argument (list iterator) of the reduce - protocol. + """This class is used to check the 4th argument (list iterator) of + the reduce protocol. """ def __init__(self, items=None): self.items = items if items is not None else [] @@ -1458,8 +1816,8 @@ class REX_six(object): return type(self), (), None, iter(self.items), None class REX_seven(object): - """This class is used to check the 5th argument (dict iterator) of the reduce - protocol. + """This class is used to check the 5th argument (dict iterator) of + the reduce protocol. """ def __init__(self, table=None): self.table = table if table is not None else {} @@ -1497,10 +1855,16 @@ class MyList(list): class MyDict(dict): sample = {"a": 1, "b": 2} +class MySet(set): + sample = {"a", "b"} + +class MyFrozenSet(frozenset): + sample = frozenset({"a", "b"}) + myclasses = [MyInt, MyFloat, MyComplex, MyStr, MyUnicode, - MyTuple, MyList, MyDict] + MyTuple, MyList, MyDict, MySet, MyFrozenSet] class SlotList(MyList): @@ -1510,6 +1874,8 @@ class SimpleNewObj(object): def __init__(self, a, b, c): # raise an error, to make sure this isn't called raise TypeError("SimpleNewObj.__init__() didn't expect to get called") + def __eq__(self, other): + return self.__dict__ == other.__dict__ class BadGetattr: def __getattr__(self, key): @@ -1546,7 +1912,7 @@ class AbstractPickleModuleTests(unittest.TestCase): def test_highest_protocol(self): # Of course this needs to be changed when HIGHEST_PROTOCOL changes. - self.assertEqual(pickle.HIGHEST_PROTOCOL, 3) + self.assertEqual(pickle.HIGHEST_PROTOCOL, 4) def test_callapi(self): f = io.BytesIO() @@ -1731,22 +2097,23 @@ class AbstractPicklerUnpicklerObjectTests(unittest.TestCase): def _check_multiple_unpicklings(self, ioclass): for proto in protocols: - data1 = [(x, str(x)) for x in range(2000)] + [b"abcde", len] - f = ioclass() - pickler = self.pickler_class(f, protocol=proto) - pickler.dump(data1) - pickled = f.getvalue() - - N = 5 - f = ioclass(pickled * N) - unpickler = self.unpickler_class(f) - for i in range(N): - if f.seekable(): - pos = f.tell() - self.assertEqual(unpickler.load(), data1) - if f.seekable(): - self.assertEqual(f.tell(), pos + len(pickled)) - self.assertRaises(EOFError, unpickler.load) + with self.subTest(proto=proto): + data1 = [(x, str(x)) for x in range(2000)] + [b"abcde", len] + f = ioclass() + pickler = self.pickler_class(f, protocol=proto) + pickler.dump(data1) + pickled = f.getvalue() + + N = 5 + f = ioclass(pickled * N) + unpickler = self.unpickler_class(f) + for i in range(N): + if f.seekable(): + pos = f.tell() + self.assertEqual(unpickler.load(), data1) + if f.seekable(): + self.assertEqual(f.tell(), pos + len(pickled)) + self.assertRaises(EOFError, unpickler.load) def test_multiple_unpicklings_seekable(self): self._check_multiple_unpicklings(io.BytesIO) |