diff options
Diffstat (limited to 'Lib/test/test_shelve.py')
-rw-r--r-- | Lib/test/test_shelve.py | 236 |
1 files changed, 235 insertions, 1 deletions
diff --git a/Lib/test/test_shelve.py b/Lib/test/test_shelve.py index 08c6562f2a2..64609ab9dd9 100644 --- a/Lib/test/test_shelve.py +++ b/Lib/test/test_shelve.py @@ -1,10 +1,11 @@ +import array import unittest import dbm import shelve import pickle import os -from test.support import os_helper +from test.support import import_helper, os_helper from collections.abc import MutableMapping from test.test_dbm import dbm_iterator @@ -165,6 +166,239 @@ class TestCase(unittest.TestCase): with shelve.Shelf({}) as s: self.assertEqual(s._protocol, pickle.DEFAULT_PROTOCOL) + def test_custom_serializer_and_deserializer(self): + os.mkdir(self.dirname) + self.addCleanup(os_helper.rmtree, self.dirname) + + def serializer(obj, protocol): + if isinstance(obj, (bytes, bytearray, str)): + if protocol == 5: + return obj + return type(obj).__name__ + elif isinstance(obj, array.array): + return obj.tobytes() + raise TypeError(f"Unsupported type for serialization: {type(obj)}") + + def deserializer(data): + if isinstance(data, (bytes, bytearray, str)): + return data.decode("utf-8") + elif isinstance(data, array.array): + return array.array("b", data) + raise TypeError( + f"Unsupported type for deserialization: {type(data)}" + ) + + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto), shelve.open( + self.fn, + protocol=proto, + serializer=serializer, + deserializer=deserializer + ) as s: + bar = "bar" + bytes_data = b"Hello, world!" + bytearray_data = bytearray(b"\x00\x01\x02\x03\x04") + array_data = array.array("i", [1, 2, 3, 4, 5]) + + s["foo"] = bar + s["bytes_data"] = bytes_data + s["bytearray_data"] = bytearray_data + s["array_data"] = array_data + + if proto == 5: + self.assertEqual(s["foo"], str(bar)) + self.assertEqual(s["bytes_data"], "Hello, world!") + self.assertEqual( + s["bytearray_data"], bytearray_data.decode() + ) + self.assertEqual( + s["array_data"], array_data.tobytes().decode() + ) + else: + self.assertEqual(s["foo"], "str") + self.assertEqual(s["bytes_data"], "bytes") + self.assertEqual(s["bytearray_data"], "bytearray") + self.assertEqual( + s["array_data"], array_data.tobytes().decode() + ) + + def test_custom_incomplete_serializer_and_deserializer(self): + dbm_sqlite3 = import_helper.import_module("dbm.sqlite3") + os.mkdir(self.dirname) + self.addCleanup(os_helper.rmtree, self.dirname) + + with self.assertRaises(dbm_sqlite3.error): + def serializer(obj, protocol=None): + pass + + def deserializer(data): + return data.decode("utf-8") + + with shelve.open(self.fn, serializer=serializer, + deserializer=deserializer) as s: + s["foo"] = "bar" + + def serializer(obj, protocol=None): + return type(obj).__name__.encode("utf-8") + + def deserializer(data): + pass + + with shelve.open(self.fn, serializer=serializer, + deserializer=deserializer) as s: + s["foo"] = "bar" + self.assertNotEqual(s["foo"], "bar") + self.assertIsNone(s["foo"]) + + def test_custom_serializer_and_deserializer_bsd_db_shelf(self): + berkeleydb = import_helper.import_module("berkeleydb") + os.mkdir(self.dirname) + self.addCleanup(os_helper.rmtree, self.dirname) + + def serializer(obj, protocol=None): + data = obj.__class__.__name__ + if protocol == 5: + data = str(len(data)) + return data.encode("utf-8") + + def deserializer(data): + return data.decode("utf-8") + + def type_name_len(obj): + return f"{(len(type(obj).__name__))}" + + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto), shelve.BsdDbShelf( + berkeleydb.btopen(self.fn), + protocol=proto, + serializer=serializer, + deserializer=deserializer, + ) as s: + bar = "bar" + bytes_obj = b"Hello, world!" + bytearray_obj = bytearray(b"\x00\x01\x02\x03\x04") + arr_obj = array.array("i", [1, 2, 3, 4, 5]) + + s["foo"] = bar + s["bytes_data"] = bytes_obj + s["bytearray_data"] = bytearray_obj + s["array_data"] = arr_obj + + if proto == 5: + self.assertEqual(s["foo"], type_name_len(bar)) + self.assertEqual(s["bytes_data"], type_name_len(bytes_obj)) + self.assertEqual(s["bytearray_data"], + type_name_len(bytearray_obj)) + self.assertEqual(s["array_data"], type_name_len(arr_obj)) + + k, v = s.set_location(b"foo") + self.assertEqual(k, "foo") + self.assertEqual(v, type_name_len(bar)) + + k, v = s.previous() + self.assertEqual(k, "bytes_data") + self.assertEqual(v, type_name_len(bytes_obj)) + + k, v = s.previous() + self.assertEqual(k, "bytearray_data") + self.assertEqual(v, type_name_len(bytearray_obj)) + + k, v = s.previous() + self.assertEqual(k, "array_data") + self.assertEqual(v, type_name_len(arr_obj)) + + k, v = s.next() + self.assertEqual(k, "bytearray_data") + self.assertEqual(v, type_name_len(bytearray_obj)) + + k, v = s.next() + self.assertEqual(k, "bytes_data") + self.assertEqual(v, type_name_len(bytes_obj)) + + k, v = s.first() + self.assertEqual(k, "array_data") + self.assertEqual(v, type_name_len(arr_obj)) + else: + k, v = s.set_location(b"foo") + self.assertEqual(k, "foo") + self.assertEqual(v, "str") + + k, v = s.previous() + self.assertEqual(k, "bytes_data") + self.assertEqual(v, "bytes") + + k, v = s.previous() + self.assertEqual(k, "bytearray_data") + self.assertEqual(v, "bytearray") + + k, v = s.previous() + self.assertEqual(k, "array_data") + self.assertEqual(v, "array") + + k, v = s.next() + self.assertEqual(k, "bytearray_data") + self.assertEqual(v, "bytearray") + + k, v = s.next() + self.assertEqual(k, "bytes_data") + self.assertEqual(v, "bytes") + + k, v = s.first() + self.assertEqual(k, "array_data") + self.assertEqual(v, "array") + + self.assertEqual(s["foo"], "str") + self.assertEqual(s["bytes_data"], "bytes") + self.assertEqual(s["bytearray_data"], "bytearray") + self.assertEqual(s["array_data"], "array") + + def test_custom_incomplete_serializer_and_deserializer_bsd_db_shelf(self): + berkeleydb = import_helper.import_module("berkeleydb") + os.mkdir(self.dirname) + self.addCleanup(os_helper.rmtree, self.dirname) + + def serializer(obj, protocol=None): + return type(obj).__name__.encode("utf-8") + + def deserializer(data): + pass + + with shelve.BsdDbShelf(berkeleydb.btopen(self.fn), + serializer=serializer, + deserializer=deserializer) as s: + s["foo"] = "bar" + self.assertIsNone(s["foo"]) + self.assertNotEqual(s["foo"], "bar") + + def serializer(obj, protocol=None): + pass + + def deserializer(data): + return data.decode("utf-8") + + with shelve.BsdDbShelf(berkeleydb.btopen(self.fn), + serializer=serializer, + deserializer=deserializer) as s: + s["foo"] = "bar" + self.assertEqual(s["foo"], "") + self.assertNotEqual(s["foo"], "bar") + + def test_missing_custom_deserializer(self): + def serializer(obj, protocol=None): + pass + + kwargs = dict(protocol=2, writeback=False, serializer=serializer) + self.assertRaises(shelve.ShelveError, shelve.Shelf, {}, **kwargs) + self.assertRaises(shelve.ShelveError, shelve.BsdDbShelf, {}, **kwargs) + + def test_missing_custom_serializer(self): + def deserializer(data): + pass + + kwargs = dict(protocol=2, writeback=False, deserializer=deserializer) + self.assertRaises(shelve.ShelveError, shelve.Shelf, {}, **kwargs) + self.assertRaises(shelve.ShelveError, shelve.BsdDbShelf, {}, **kwargs) + class TestShelveBase: type2test = shelve.Shelf |