diff options
Diffstat (limited to 'Lib/test/pickletester.py')
-rw-r--r-- | Lib/test/pickletester.py | 97 |
1 files changed, 86 insertions, 11 deletions
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py index 5963175ddd1..444a1036085 100644 --- a/Lib/test/pickletester.py +++ b/Lib/test/pickletester.py @@ -1153,16 +1153,62 @@ class AbstractPickleTests(unittest.TestCase): self.assertGreaterEqual(num_additems, 2) def test_simple_newobj(self): - x = object.__new__(SimpleNewObj) # avoid __init__ + x = SimpleNewObj.__new__(SimpleNewObj, 0xface) # avoid __init__ x.abc = 666 for proto in protocols: - s = self.dumps(x, proto) - self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), - 2 <= proto < 4) - self.assertEqual(opcode_in_pickle(pickle.NEWOBJ_EX, s), - proto >= 4) - y = self.loads(s) # will raise TypeError if __init__ called - self.assert_is_copy(x, y) + with self.subTest(proto=proto): + s = self.dumps(x, proto) + if proto < 1: + self.assertIn(b'\nL64206', s) # LONG + else: + self.assertIn(b'M\xce\xfa', s) # BININT2 + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), + 2 <= proto) + self.assertFalse(opcode_in_pickle(pickle.NEWOBJ_EX, s)) + y = self.loads(s) # will raise TypeError if __init__ called + self.assert_is_copy(x, y) + + def test_complex_newobj(self): + x = ComplexNewObj.__new__(ComplexNewObj, 0xface) # avoid __init__ + x.abc = 666 + for proto in protocols: + with self.subTest(proto=proto): + s = self.dumps(x, proto) + if proto < 1: + self.assertIn(b'\nL64206', s) # LONG + elif proto < 2: + self.assertIn(b'M\xce\xfa', s) # BININT2 + elif proto < 4: + self.assertIn(b'X\x04\x00\x00\x00FACE', s) # BINUNICODE + else: + self.assertIn(b'\x8c\x04FACE', s) # SHORT_BINUNICODE + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), + 2 <= proto) + self.assertFalse(opcode_in_pickle(pickle.NEWOBJ_EX, s)) + y = self.loads(s) # will raise TypeError if __init__ called + self.assert_is_copy(x, y) + + def test_complex_newobj_ex(self): + x = ComplexNewObjEx.__new__(ComplexNewObjEx, 0xface) # avoid __init__ + x.abc = 666 + for proto in protocols: + with self.subTest(proto=proto): + if 2 <= proto < 4: + self.assertRaises(ValueError, self.dumps, x, proto) + continue + s = self.dumps(x, proto) + if proto < 1: + self.assertIn(b'\nL64206', s) # LONG + elif proto < 2: + self.assertIn(b'M\xce\xfa', s) # BININT2 + else: + assert proto >= 4 + self.assertIn(b'\x8c\x04FACE', s) # SHORT_BINUNICODE + self.assertFalse(opcode_in_pickle(pickle.NEWOBJ, s)) + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ_EX, s), + 4 <= proto) + y = self.loads(s) # will raise TypeError if __init__ called + self.assert_is_copy(x, y) def test_newobj_list_slots(self): x = SlotList([1, 2, 3]) @@ -1636,6 +1682,27 @@ class AbstractPickleTests(unittest.TestCase): unpickled = self.loads(self.dumps(method, proto)) self.assertEqual(method(*args), unpickled(*args)) + def test_local_lookup_error(self): + # Test that whichmodule() errors out cleanly when looking up + # an assumed globally-reachable object fails. + def f(): + pass + # Since the function is local, lookup will fail + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + with self.assertRaises((AttributeError, pickle.PicklingError)): + pickletools.dis(self.dumps(f, proto)) + # Same without a __module__ attribute (exercises a different path + # in _pickle.c). + del f.__module__ + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + with self.assertRaises((AttributeError, pickle.PicklingError)): + pickletools.dis(self.dumps(f, proto)) + # Yet a different path. + f.__name__ = f.__qualname__ + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + with self.assertRaises((AttributeError, pickle.PicklingError)): + pickletools.dis(self.dumps(f, proto)) + class BigmemPickleTests(unittest.TestCase): @@ -1870,12 +1937,20 @@ myclasses = [MyInt, MyFloat, class SlotList(MyList): __slots__ = ["foo"] -class SimpleNewObj(object): - def __init__(self, a, b, c): +class SimpleNewObj(int): + def __init__(self, *args, **kwargs): # raise an error, to make sure this isn't called raise TypeError("SimpleNewObj.__init__() didn't expect to get called") def __eq__(self, other): - return self.__dict__ == other.__dict__ + return int(self) == int(other) and self.__dict__ == other.__dict__ + +class ComplexNewObj(SimpleNewObj): + def __getnewargs__(self): + return ('%X' % self, 16) + +class ComplexNewObjEx(SimpleNewObj): + def __getnewargs_ex__(self): + return ('%X' % self,), {'base': 16} class BadGetattr: def __getattr__(self, key): |