aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/test/test_memoryview.py
diff options
context:
space:
mode:
authorKen Jin <kenjin@python.org>2022-06-17 23:14:53 +0800
committerGitHub <noreply@github.com>2022-06-17 23:14:53 +0800
commit11190c4ad0d3722b8d263758ac802985131a5462 (patch)
tree7b14666e1fe7000cbafc13e3b884b9916baeb5d9 /Lib/test/test_memoryview.py
parenta51742ab82ad2a57841058fc9a16dac82d8337cf (diff)
downloadcpython-11190c4ad0d3722b8d263758ac802985131a5462.tar.gz
cpython-11190c4ad0d3722b8d263758ac802985131a5462.zip
gh-92888: Fix memoryview bad `__index__` use after free (GH-92946)
Co-authored-by: chilaxan <35645806+chilaxan@users.noreply.github.com> Co-authored-by: Serhiy Storchaka <3659035+serhiy-storchaka@users.noreply.github.com>
Diffstat (limited to 'Lib/test/test_memoryview.py')
-rw-r--r--Lib/test/test_memoryview.py101
1 files changed, 101 insertions, 0 deletions
diff --git a/Lib/test/test_memoryview.py b/Lib/test/test_memoryview.py
index d7e3f0c0eff..9d1e1f3063c 100644
--- a/Lib/test/test_memoryview.py
+++ b/Lib/test/test_memoryview.py
@@ -545,6 +545,107 @@ class OtherTest(unittest.TestCase):
with self.assertRaises(TypeError):
pickle.dumps(m, proto)
+ def test_use_released_memory(self):
+ # gh-92888: Previously it was possible to use a memoryview even after
+ # backing buffer is freed in certain cases. This tests that those
+ # cases raise an exception.
+ size = 128
+ def release():
+ m.release()
+ nonlocal ba
+ ba = bytearray(size)
+ class MyIndex:
+ def __index__(self):
+ release()
+ return 4
+ class MyFloat:
+ def __float__(self):
+ release()
+ return 4.25
+ class MyBool:
+ def __bool__(self):
+ release()
+ return True
+
+ ba = None
+ m = memoryview(bytearray(b'\xff'*size))
+ with self.assertRaises(ValueError):
+ m[MyIndex()]
+
+ ba = None
+ m = memoryview(bytearray(b'\xff'*size))
+ self.assertEqual(list(m[:MyIndex()]), [255] * 4)
+
+ ba = None
+ m = memoryview(bytearray(b'\xff'*size))
+ self.assertEqual(list(m[MyIndex():8]), [255] * 4)
+
+ ba = None
+ m = memoryview(bytearray(b'\xff'*size)).cast('B', (64, 2))
+ with self.assertRaisesRegex(ValueError, "operation forbidden"):
+ m[MyIndex(), 0]
+
+ ba = None
+ m = memoryview(bytearray(b'\xff'*size)).cast('B', (2, 64))
+ with self.assertRaisesRegex(ValueError, "operation forbidden"):
+ m[0, MyIndex()]
+
+ ba = None
+ m = memoryview(bytearray(b'\xff'*size))
+ with self.assertRaisesRegex(ValueError, "operation forbidden"):
+ m[MyIndex()] = 42
+ self.assertEqual(ba[:8], b'\0'*8)
+
+ ba = None
+ m = memoryview(bytearray(b'\xff'*size))
+ with self.assertRaisesRegex(ValueError, "operation forbidden"):
+ m[:MyIndex()] = b'spam'
+ self.assertEqual(ba[:8], b'\0'*8)
+
+ ba = None
+ m = memoryview(bytearray(b'\xff'*size))
+ with self.assertRaisesRegex(ValueError, "operation forbidden"):
+ m[MyIndex():8] = b'spam'
+ self.assertEqual(ba[:8], b'\0'*8)
+
+ ba = None
+ m = memoryview(bytearray(b'\xff'*size)).cast('B', (64, 2))
+ with self.assertRaisesRegex(ValueError, "operation forbidden"):
+ m[MyIndex(), 0] = 42
+ self.assertEqual(ba[8:16], b'\0'*8)
+ ba = None
+ m = memoryview(bytearray(b'\xff'*size)).cast('B', (2, 64))
+ with self.assertRaisesRegex(ValueError, "operation forbidden"):
+ m[0, MyIndex()] = 42
+ self.assertEqual(ba[:8], b'\0'*8)
+
+ ba = None
+ m = memoryview(bytearray(b'\xff'*size))
+ with self.assertRaisesRegex(ValueError, "operation forbidden"):
+ m[0] = MyIndex()
+ self.assertEqual(ba[:8], b'\0'*8)
+
+ for fmt in 'bhilqnBHILQN':
+ with self.subTest(fmt=fmt):
+ ba = None
+ m = memoryview(bytearray(b'\xff'*size)).cast(fmt)
+ with self.assertRaisesRegex(ValueError, "operation forbidden"):
+ m[0] = MyIndex()
+ self.assertEqual(ba[:8], b'\0'*8)
+
+ for fmt in 'fd':
+ with self.subTest(fmt=fmt):
+ ba = None
+ m = memoryview(bytearray(b'\xff'*size)).cast(fmt)
+ with self.assertRaisesRegex(ValueError, "operation forbidden"):
+ m[0] = MyFloat()
+ self.assertEqual(ba[:8], b'\0'*8)
+
+ ba = None
+ m = memoryview(bytearray(b'\xff'*size)).cast('?')
+ with self.assertRaisesRegex(ValueError, "operation forbidden"):
+ m[0] = MyBool()
+ self.assertEqual(ba[:8], b'\0'*8)
if __name__ == "__main__":
unittest.main()