aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/test/test_shelve.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_shelve.py')
-rw-r--r--Lib/test/test_shelve.py236
1 files changed, 235 insertions, 1 deletions
diff --git a/Lib/test/test_shelve.py b/Lib/test/test_shelve.py
index 08c6562f2a2..64609ab9dd9 100644
--- a/Lib/test/test_shelve.py
+++ b/Lib/test/test_shelve.py
@@ -1,10 +1,11 @@
+import array
import unittest
import dbm
import shelve
import pickle
import os
-from test.support import os_helper
+from test.support import import_helper, os_helper
from collections.abc import MutableMapping
from test.test_dbm import dbm_iterator
@@ -165,6 +166,239 @@ class TestCase(unittest.TestCase):
with shelve.Shelf({}) as s:
self.assertEqual(s._protocol, pickle.DEFAULT_PROTOCOL)
+ def test_custom_serializer_and_deserializer(self):
+ os.mkdir(self.dirname)
+ self.addCleanup(os_helper.rmtree, self.dirname)
+
+ def serializer(obj, protocol):
+ if isinstance(obj, (bytes, bytearray, str)):
+ if protocol == 5:
+ return obj
+ return type(obj).__name__
+ elif isinstance(obj, array.array):
+ return obj.tobytes()
+ raise TypeError(f"Unsupported type for serialization: {type(obj)}")
+
+ def deserializer(data):
+ if isinstance(data, (bytes, bytearray, str)):
+ return data.decode("utf-8")
+ elif isinstance(data, array.array):
+ return array.array("b", data)
+ raise TypeError(
+ f"Unsupported type for deserialization: {type(data)}"
+ )
+
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ with self.subTest(proto=proto), shelve.open(
+ self.fn,
+ protocol=proto,
+ serializer=serializer,
+ deserializer=deserializer
+ ) as s:
+ bar = "bar"
+ bytes_data = b"Hello, world!"
+ bytearray_data = bytearray(b"\x00\x01\x02\x03\x04")
+ array_data = array.array("i", [1, 2, 3, 4, 5])
+
+ s["foo"] = bar
+ s["bytes_data"] = bytes_data
+ s["bytearray_data"] = bytearray_data
+ s["array_data"] = array_data
+
+ if proto == 5:
+ self.assertEqual(s["foo"], str(bar))
+ self.assertEqual(s["bytes_data"], "Hello, world!")
+ self.assertEqual(
+ s["bytearray_data"], bytearray_data.decode()
+ )
+ self.assertEqual(
+ s["array_data"], array_data.tobytes().decode()
+ )
+ else:
+ self.assertEqual(s["foo"], "str")
+ self.assertEqual(s["bytes_data"], "bytes")
+ self.assertEqual(s["bytearray_data"], "bytearray")
+ self.assertEqual(
+ s["array_data"], array_data.tobytes().decode()
+ )
+
+ def test_custom_incomplete_serializer_and_deserializer(self):
+ dbm_sqlite3 = import_helper.import_module("dbm.sqlite3")
+ os.mkdir(self.dirname)
+ self.addCleanup(os_helper.rmtree, self.dirname)
+
+ with self.assertRaises(dbm_sqlite3.error):
+ def serializer(obj, protocol=None):
+ pass
+
+ def deserializer(data):
+ return data.decode("utf-8")
+
+ with shelve.open(self.fn, serializer=serializer,
+ deserializer=deserializer) as s:
+ s["foo"] = "bar"
+
+ def serializer(obj, protocol=None):
+ return type(obj).__name__.encode("utf-8")
+
+ def deserializer(data):
+ pass
+
+ with shelve.open(self.fn, serializer=serializer,
+ deserializer=deserializer) as s:
+ s["foo"] = "bar"
+ self.assertNotEqual(s["foo"], "bar")
+ self.assertIsNone(s["foo"])
+
+ def test_custom_serializer_and_deserializer_bsd_db_shelf(self):
+ berkeleydb = import_helper.import_module("berkeleydb")
+ os.mkdir(self.dirname)
+ self.addCleanup(os_helper.rmtree, self.dirname)
+
+ def serializer(obj, protocol=None):
+ data = obj.__class__.__name__
+ if protocol == 5:
+ data = str(len(data))
+ return data.encode("utf-8")
+
+ def deserializer(data):
+ return data.decode("utf-8")
+
+ def type_name_len(obj):
+ return f"{(len(type(obj).__name__))}"
+
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ with self.subTest(proto=proto), shelve.BsdDbShelf(
+ berkeleydb.btopen(self.fn),
+ protocol=proto,
+ serializer=serializer,
+ deserializer=deserializer,
+ ) as s:
+ bar = "bar"
+ bytes_obj = b"Hello, world!"
+ bytearray_obj = bytearray(b"\x00\x01\x02\x03\x04")
+ arr_obj = array.array("i", [1, 2, 3, 4, 5])
+
+ s["foo"] = bar
+ s["bytes_data"] = bytes_obj
+ s["bytearray_data"] = bytearray_obj
+ s["array_data"] = arr_obj
+
+ if proto == 5:
+ self.assertEqual(s["foo"], type_name_len(bar))
+ self.assertEqual(s["bytes_data"], type_name_len(bytes_obj))
+ self.assertEqual(s["bytearray_data"],
+ type_name_len(bytearray_obj))
+ self.assertEqual(s["array_data"], type_name_len(arr_obj))
+
+ k, v = s.set_location(b"foo")
+ self.assertEqual(k, "foo")
+ self.assertEqual(v, type_name_len(bar))
+
+ k, v = s.previous()
+ self.assertEqual(k, "bytes_data")
+ self.assertEqual(v, type_name_len(bytes_obj))
+
+ k, v = s.previous()
+ self.assertEqual(k, "bytearray_data")
+ self.assertEqual(v, type_name_len(bytearray_obj))
+
+ k, v = s.previous()
+ self.assertEqual(k, "array_data")
+ self.assertEqual(v, type_name_len(arr_obj))
+
+ k, v = s.next()
+ self.assertEqual(k, "bytearray_data")
+ self.assertEqual(v, type_name_len(bytearray_obj))
+
+ k, v = s.next()
+ self.assertEqual(k, "bytes_data")
+ self.assertEqual(v, type_name_len(bytes_obj))
+
+ k, v = s.first()
+ self.assertEqual(k, "array_data")
+ self.assertEqual(v, type_name_len(arr_obj))
+ else:
+ k, v = s.set_location(b"foo")
+ self.assertEqual(k, "foo")
+ self.assertEqual(v, "str")
+
+ k, v = s.previous()
+ self.assertEqual(k, "bytes_data")
+ self.assertEqual(v, "bytes")
+
+ k, v = s.previous()
+ self.assertEqual(k, "bytearray_data")
+ self.assertEqual(v, "bytearray")
+
+ k, v = s.previous()
+ self.assertEqual(k, "array_data")
+ self.assertEqual(v, "array")
+
+ k, v = s.next()
+ self.assertEqual(k, "bytearray_data")
+ self.assertEqual(v, "bytearray")
+
+ k, v = s.next()
+ self.assertEqual(k, "bytes_data")
+ self.assertEqual(v, "bytes")
+
+ k, v = s.first()
+ self.assertEqual(k, "array_data")
+ self.assertEqual(v, "array")
+
+ self.assertEqual(s["foo"], "str")
+ self.assertEqual(s["bytes_data"], "bytes")
+ self.assertEqual(s["bytearray_data"], "bytearray")
+ self.assertEqual(s["array_data"], "array")
+
+ def test_custom_incomplete_serializer_and_deserializer_bsd_db_shelf(self):
+ berkeleydb = import_helper.import_module("berkeleydb")
+ os.mkdir(self.dirname)
+ self.addCleanup(os_helper.rmtree, self.dirname)
+
+ def serializer(obj, protocol=None):
+ return type(obj).__name__.encode("utf-8")
+
+ def deserializer(data):
+ pass
+
+ with shelve.BsdDbShelf(berkeleydb.btopen(self.fn),
+ serializer=serializer,
+ deserializer=deserializer) as s:
+ s["foo"] = "bar"
+ self.assertIsNone(s["foo"])
+ self.assertNotEqual(s["foo"], "bar")
+
+ def serializer(obj, protocol=None):
+ pass
+
+ def deserializer(data):
+ return data.decode("utf-8")
+
+ with shelve.BsdDbShelf(berkeleydb.btopen(self.fn),
+ serializer=serializer,
+ deserializer=deserializer) as s:
+ s["foo"] = "bar"
+ self.assertEqual(s["foo"], "")
+ self.assertNotEqual(s["foo"], "bar")
+
+ def test_missing_custom_deserializer(self):
+ def serializer(obj, protocol=None):
+ pass
+
+ kwargs = dict(protocol=2, writeback=False, serializer=serializer)
+ self.assertRaises(shelve.ShelveError, shelve.Shelf, {}, **kwargs)
+ self.assertRaises(shelve.ShelveError, shelve.BsdDbShelf, {}, **kwargs)
+
+ def test_missing_custom_serializer(self):
+ def deserializer(data):
+ pass
+
+ kwargs = dict(protocol=2, writeback=False, deserializer=deserializer)
+ self.assertRaises(shelve.ShelveError, shelve.Shelf, {}, **kwargs)
+ self.assertRaises(shelve.ShelveError, shelve.BsdDbShelf, {}, **kwargs)
+
class TestShelveBase:
type2test = shelve.Shelf