diff options
Diffstat (limited to 'Lib/test')
-rw-r--r-- | Lib/test/_test_multiprocessing.py | 22 | ||||
-rw-r--r-- | Lib/test/support/interpreters/channels.py | 2 | ||||
-rw-r--r-- | Lib/test/test__interpreters.py | 25 | ||||
-rw-r--r-- | Lib/test/test_ast/test_ast.py | 1 | ||||
-rw-r--r-- | Lib/test/test_capi/test_object.py | 1 | ||||
-rw-r--r-- | Lib/test/test_collections.py | 2 | ||||
-rw-r--r-- | Lib/test/test_httpservers.py | 70 | ||||
-rw-r--r-- | Lib/test/test_interpreters/test_api.py | 23 | ||||
-rw-r--r-- | Lib/test/test_ipaddress.py | 11 | ||||
-rw-r--r-- | Lib/test/test_json/test_recursion.py | 1 | ||||
-rw-r--r-- | Lib/test/test_memoryview.py | 20 | ||||
-rw-r--r-- | Lib/test/test_statistics.py | 1 | ||||
-rw-r--r-- | Lib/test/test_unittest/test_result.py | 37 | ||||
-rw-r--r-- | Lib/test/test_unittest/test_runner.py | 52 | ||||
-rw-r--r-- | Lib/test/test_zstd.py | 61 |
15 files changed, 275 insertions, 54 deletions
diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index 6a20a1eb03e..75f31d858d3 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -6844,6 +6844,28 @@ class MiscTestCase(unittest.TestCase): self.assertEqual("332833500", out.decode('utf-8').strip()) self.assertFalse(err, msg=err.decode('utf-8')) + def test_forked_thread_not_started(self): + # gh-134381: Ensure that a thread that has not been started yet in + # the parent process can be started within a forked child process. + + if multiprocessing.get_start_method() != "fork": + self.skipTest("fork specific test") + + q = multiprocessing.Queue() + t = threading.Thread(target=lambda: q.put("done"), daemon=True) + + def child(): + t.start() + t.join() + + p = multiprocessing.Process(target=child) + p.start() + p.join(support.SHORT_TIMEOUT) + + self.assertEqual(p.exitcode, 0) + self.assertEqual(q.get_nowait(), "done") + close_queue(q) + # # Mixins diff --git a/Lib/test/support/interpreters/channels.py b/Lib/test/support/interpreters/channels.py index 3b6e0f0effd..7a2bd7d63f8 100644 --- a/Lib/test/support/interpreters/channels.py +++ b/Lib/test/support/interpreters/channels.py @@ -69,7 +69,7 @@ def list_all(): if not hasattr(send, '_unboundop'): send._set_unbound(unboundop) else: - assert send._unbound[0] == unboundop + assert send._unbound[0] == op channels.append(chan) return channels diff --git a/Lib/test/test__interpreters.py b/Lib/test/test__interpreters.py index ad3ebbfdff6..63fdaad8de7 100644 --- a/Lib/test/test__interpreters.py +++ b/Lib/test/test__interpreters.py @@ -474,15 +474,13 @@ class CommonTests(TestBase): def test_signatures(self): # See https://github.com/python/cpython/issues/126654 - msg = r'_interpreters.exec\(\) argument 3 must be dict, not int' + msg = "expected 'shared' to be a dict" with self.assertRaisesRegex(TypeError, msg): _interpreters.exec(self.id, 'a', 1) with self.assertRaisesRegex(TypeError, msg): _interpreters.exec(self.id, 'a', shared=1) - msg = r'_interpreters.run_string\(\) argument 3 must be dict, not int' with self.assertRaisesRegex(TypeError, msg): _interpreters.run_string(self.id, 'a', shared=1) - msg = r'_interpreters.run_func\(\) argument 3 must be dict, not int' with self.assertRaisesRegex(TypeError, msg): _interpreters.run_func(self.id, lambda: None, shared=1) @@ -954,8 +952,7 @@ class RunFailedTests(TestBase): """) with self.subTest('script'): - with self.assertRaises(SyntaxError): - _interpreters.run_string(self.id, script) + self.assert_run_failed(SyntaxError, script) with self.subTest('module'): modname = 'spam_spam_spam' @@ -1022,19 +1019,12 @@ class RunFuncTests(TestBase): with open(w, 'w', encoding="utf-8") as spipe: with contextlib.redirect_stdout(spipe): print('it worked!', end='') - failed = None def f(): - nonlocal failed - try: - _interpreters.set___main___attrs(self.id, dict(w=w)) - _interpreters.run_func(self.id, script) - except Exception as exc: - failed = exc + _interpreters.set___main___attrs(self.id, dict(w=w)) + _interpreters.run_func(self.id, script) t = threading.Thread(target=f) t.start() t.join() - if failed: - raise Exception from failed with open(r, encoding="utf-8") as outfile: out = outfile.read() @@ -1063,16 +1053,19 @@ class RunFuncTests(TestBase): spam = True def script(): assert spam - with self.assertRaises(ValueError): + + with self.assertRaises(TypeError): _interpreters.run_func(self.id, script) + # XXX This hasn't been fixed yet. + @unittest.expectedFailure def test_return_value(self): def script(): return 'spam' with self.assertRaises(ValueError): _interpreters.run_func(self.id, script) -# @unittest.skip("we're not quite there yet") + @unittest.skip("we're not quite there yet") def test_args(self): with self.subTest('args'): def script(a, b=0): diff --git a/Lib/test/test_ast/test_ast.py b/Lib/test/test_ast/test_ast.py index 1479a8eafd6..46745cfa8f8 100644 --- a/Lib/test/test_ast/test_ast.py +++ b/Lib/test/test_ast/test_ast.py @@ -3292,6 +3292,7 @@ class CommandLineTests(unittest.TestCase): expect = self.text_normalize(expect) self.assertEqual(res, expect) + @support.requires_resource('cpu') def test_invocation(self): # test various combinations of parameters base_flags = ( diff --git a/Lib/test/test_capi/test_object.py b/Lib/test/test_capi/test_object.py index cd772120bde..d4056727d07 100644 --- a/Lib/test/test_capi/test_object.py +++ b/Lib/test/test_capi/test_object.py @@ -221,6 +221,7 @@ class CAPITest(unittest.TestCase): """ self.check_negative_refcount(code) + @support.requires_resource('cpu') def test_decref_delayed(self): # gh-130519: Test that _PyObject_XDecRefDelayed() and QSBR code path # handles destructors that are possibly re-entrant or trigger a GC. diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index 1e93530398b..d9d61e5c205 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -542,6 +542,8 @@ class TestNamedTuple(unittest.TestCase): self.assertEqual(Dot(1)._replace(d=999), (999,)) self.assertEqual(Dot(1)._fields, ('d',)) + @support.requires_resource('cpu') + def test_large_size(self): n = support.exceeds_recursion_limit() names = list(set(''.join([choice(string.ascii_letters) for j in range(10)]) for i in range(n))) diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py index 0af1c45ecb2..2548a7c5f29 100644 --- a/Lib/test/test_httpservers.py +++ b/Lib/test/test_httpservers.py @@ -21,6 +21,7 @@ import email.utils import html import http, http.client import urllib.parse +import urllib.request import tempfile import time import datetime @@ -33,6 +34,8 @@ from test import support from test.support import ( is_apple, import_helper, os_helper, threading_helper ) +from test.support.script_helper import kill_python, spawn_python +from test.support.socket_helper import find_unused_port try: import ssl @@ -1452,6 +1455,73 @@ class CommandLineTestCase(unittest.TestCase): self.assertIn('error', stderr.getvalue()) +class CommandLineRunTimeTestCase(unittest.TestCase): + served_data = os.urandom(32) + served_filename = 'served_filename' + tls_cert = certdata_file('ssl_cert.pem') + tls_key = certdata_file('ssl_key.pem') + tls_password = b'somepass' + tls_password_file = 'ssl_key_password' + + def setUp(self): + super().setUp() + server_dir_context = os_helper.temp_cwd() + server_dir = self.enterContext(server_dir_context) + with open(self.served_filename, 'wb') as f: + f.write(self.served_data) + with open(self.tls_password_file, 'wb') as f: + f.write(self.tls_password) + + def fetch_file(self, path, context=None): + req = urllib.request.Request(path, method='GET') + with urllib.request.urlopen(req, context=context) as res: + return res.read() + + def parse_cli_output(self, output): + match = re.search(r'Serving (HTTP|HTTPS) on (.+) port (\d+)', output) + if match is None: + return None, None, None + return match.group(1).lower(), match.group(2), int(match.group(3)) + + def wait_for_server(self, proc, protocol, bind, port): + """Check that the server has been successfully started.""" + line = proc.stdout.readline().strip() + if support.verbose: + print() + print('python -m http.server: ', line) + return self.parse_cli_output(line) == (protocol, bind, port) + + def test_http_client(self): + bind, port = '127.0.0.1', find_unused_port() + proc = spawn_python('-u', '-m', 'http.server', str(port), '-b', bind, + bufsize=1, text=True) + self.addCleanup(kill_python, proc) + self.addCleanup(proc.terminate) + self.assertTrue(self.wait_for_server(proc, 'http', bind, port)) + res = self.fetch_file(f'http://{bind}:{port}/{self.served_filename}') + self.assertEqual(res, self.served_data) + + @unittest.skipIf(ssl is None, "requires ssl") + def test_https_client(self): + context = ssl.create_default_context() + # allow self-signed certificates + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + bind, port = '127.0.0.1', find_unused_port() + proc = spawn_python('-u', '-m', 'http.server', str(port), '-b', bind, + '--tls-cert', self.tls_cert, + '--tls-key', self.tls_key, + '--tls-password-file', self.tls_password_file, + bufsize=1, text=True) + self.addCleanup(kill_python, proc) + self.addCleanup(proc.terminate) + self.assertTrue(self.wait_for_server(proc, 'https', bind, port)) + url = f'https://{bind}:{port}/{self.served_filename}' + res = self.fetch_file(url, context=context) + self.assertEqual(res, self.served_data) + + def setUpModule(): unittest.addModuleCleanup(os.chdir, os.getcwd()) diff --git a/Lib/test/test_interpreters/test_api.py b/Lib/test/test_interpreters/test_api.py index 165949167ce..1e2d572b1cb 100644 --- a/Lib/test/test_interpreters/test_api.py +++ b/Lib/test/test_interpreters/test_api.py @@ -839,16 +839,9 @@ class TestInterpreterExec(TestBase): interp.exec(10) def test_bytes_for_script(self): - r, w = self.pipe() - RAN = b'R' - DONE = b'D' interp = interpreters.create() - interp.exec(f"""if True: - import os - os.write({w}, {RAN!r}) - """) - os.write(w, DONE) - self.assertEqual(os.read(r, 1), RAN) + with self.assertRaises(TypeError): + interp.exec(b'print("spam")') def test_with_background_threads_still_running(self): r_interp, w_interp = self.pipe() @@ -1017,6 +1010,8 @@ class TestInterpreterCall(TestBase): for i, (callable, args, kwargs) in enumerate([ (call_func_noop, (), {}), + (call_func_return_shareable, (), {}), + (call_func_return_not_shareable, (), {}), (Spam.noop, (), {}), ]): with self.subTest(f'success case #{i+1}'): @@ -1041,8 +1036,6 @@ class TestInterpreterCall(TestBase): (call_func_complex, ('custom', 'spam!'), {}), (call_func_complex, ('custom-inner', 'eggs!'), {}), (call_func_complex, ('???',), {'exc': ValueError('spam')}), - (call_func_return_shareable, (), {}), - (call_func_return_not_shareable, (), {}), ]): with self.subTest(f'invalid case #{i+1}'): with self.assertRaises(Exception): @@ -1058,6 +1051,8 @@ class TestInterpreterCall(TestBase): for i, (callable, args, kwargs) in enumerate([ (call_func_noop, (), {}), + (call_func_return_shareable, (), {}), + (call_func_return_not_shareable, (), {}), (Spam.noop, (), {}), ]): with self.subTest(f'success case #{i+1}'): @@ -1084,8 +1079,6 @@ class TestInterpreterCall(TestBase): (call_func_complex, ('custom', 'spam!'), {}), (call_func_complex, ('custom-inner', 'eggs!'), {}), (call_func_complex, ('???',), {'exc': ValueError('spam')}), - (call_func_return_shareable, (), {}), - (call_func_return_not_shareable, (), {}), ]): with self.subTest(f'invalid case #{i+1}'): if args or kwargs: @@ -1625,8 +1618,8 @@ class LowLevelTests(TestBase): def test_call(self): with self.subTest('no args'): interpid = _interpreters.create() - with self.assertRaises(ValueError): - _interpreters.call(interpid, call_func_return_shareable) + exc = _interpreters.call(interpid, call_func_return_shareable) + self.assertIs(exc, None) with self.subTest('uncaught exception'): interpid = _interpreters.create() diff --git a/Lib/test/test_ipaddress.py b/Lib/test/test_ipaddress.py index a06608d0016..ee95454e64b 100644 --- a/Lib/test/test_ipaddress.py +++ b/Lib/test/test_ipaddress.py @@ -397,6 +397,17 @@ class AddressTestCase_v6(BaseTestCase, CommonTestMixin_v6): # A trailing IPv4 address is two parts assertBadSplit("10:9:8:7:6:5:4:3:42.42.42.42%scope") + def test_bad_address_split_v6_too_long(self): + def assertBadSplit(addr): + msg = r"At most 39 characters expected in %s" + with self.assertAddressError(msg, repr(re.escape(addr[:14]))): + ipaddress.IPv6Address(addr) + + # Long IPv6 address + long_addr = ("0:" * 10000) + "0" + assertBadSplit(long_addr) + assertBadSplit(long_addr + "%zoneid") + def test_bad_address_split_v6_too_many_parts(self): def assertBadSplit(addr): msg = "Exactly 8 parts expected without '::' in %r" diff --git a/Lib/test/test_json/test_recursion.py b/Lib/test/test_json/test_recursion.py index 8f0e5e078ed..5d7b56ff9ad 100644 --- a/Lib/test/test_json/test_recursion.py +++ b/Lib/test/test_json/test_recursion.py @@ -86,6 +86,7 @@ class TestRecursion: @support.skip_wasi_stack_overflow() @support.skip_emscripten_stack_overflow() + @support.requires_resource('cpu') def test_highly_nested_objects_encoding(self): # See #12051 l, d = [], {} diff --git a/Lib/test/test_memoryview.py b/Lib/test/test_memoryview.py index 61b068c630c..64f440f180b 100644 --- a/Lib/test/test_memoryview.py +++ b/Lib/test/test_memoryview.py @@ -743,19 +743,21 @@ class RacingTest(unittest.TestCase): from multiprocessing.managers import SharedMemoryManager except ImportError: self.skipTest("Test requires multiprocessing") - from threading import Thread + from threading import Thread, Event - n = 100 + start = Event() with SharedMemoryManager() as smm: obj = smm.ShareableList(range(100)) - threads = [] - for _ in range(n): - # Issue gh-127085, the `ShareableList.count` is just a convenient way to mess the `exports` - # counter of `memoryview`, this issue has no direct relation with `ShareableList`. - threads.append(Thread(target=obj.count, args=(1,))) - + def test(): + # Issue gh-127085, the `ShareableList.count` is just a + # convenient way to mess the `exports` counter of `memoryview`, + # this issue has no direct relation with `ShareableList`. + start.wait(support.SHORT_TIMEOUT) + for i in range(10): + obj.count(1) + threads = [Thread(target=test) for _ in range(10)] with threading_helper.start_threads(threads): - pass + start.set() del obj diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py index 5980f939185..0dd619dd7c8 100644 --- a/Lib/test/test_statistics.py +++ b/Lib/test/test_statistics.py @@ -2346,6 +2346,7 @@ class TestGeometricMean(unittest.TestCase): class TestKDE(unittest.TestCase): + @support.requires_resource('cpu') def test_kde(self): kde = statistics.kde StatisticsError = statistics.StatisticsError diff --git a/Lib/test/test_unittest/test_result.py b/Lib/test/test_unittest/test_result.py index 9ac4c52449c..3f44e617303 100644 --- a/Lib/test/test_unittest/test_result.py +++ b/Lib/test/test_unittest/test_result.py @@ -1282,14 +1282,22 @@ class TestOutputBuffering(unittest.TestCase): suite(result) expected_out = '\nStdout:\ndo cleanup2\ndo cleanup1\n' self.assertEqual(stdout.getvalue(), expected_out) - self.assertEqual(len(result.errors), 1) + self.assertEqual(len(result.errors), 2) description = 'tearDownModule (Module)' test_case, formatted_exc = result.errors[0] self.assertEqual(test_case.description, description) self.assertIn('ValueError: bad cleanup2', formatted_exc) + self.assertNotIn('ExceptionGroup', formatted_exc) self.assertNotIn('TypeError', formatted_exc) self.assertIn(expected_out, formatted_exc) + test_case, formatted_exc = result.errors[1] + self.assertEqual(test_case.description, description) + self.assertIn('TypeError: bad cleanup1', formatted_exc) + self.assertNotIn('ExceptionGroup', formatted_exc) + self.assertNotIn('ValueError', formatted_exc) + self.assertIn(expected_out, formatted_exc) + def testBufferSetUpModule_DoModuleCleanups(self): with captured_stdout() as stdout: result = unittest.TestResult() @@ -1313,22 +1321,34 @@ class TestOutputBuffering(unittest.TestCase): suite(result) expected_out = '\nStdout:\nset up module\ndo cleanup2\ndo cleanup1\n' self.assertEqual(stdout.getvalue(), expected_out) - self.assertEqual(len(result.errors), 2) + self.assertEqual(len(result.errors), 3) description = 'setUpModule (Module)' test_case, formatted_exc = result.errors[0] self.assertEqual(test_case.description, description) self.assertIn('ZeroDivisionError: division by zero', formatted_exc) + self.assertNotIn('ExceptionGroup', formatted_exc) self.assertNotIn('ValueError', formatted_exc) self.assertNotIn('TypeError', formatted_exc) self.assertIn('\nStdout:\nset up module\n', formatted_exc) + test_case, formatted_exc = result.errors[1] self.assertIn(expected_out, formatted_exc) self.assertEqual(test_case.description, description) self.assertIn('ValueError: bad cleanup2', formatted_exc) + self.assertNotIn('ExceptionGroup', formatted_exc) self.assertNotIn('ZeroDivisionError', formatted_exc) self.assertNotIn('TypeError', formatted_exc) self.assertIn(expected_out, formatted_exc) + test_case, formatted_exc = result.errors[2] + self.assertIn(expected_out, formatted_exc) + self.assertEqual(test_case.description, description) + self.assertIn('TypeError: bad cleanup1', formatted_exc) + self.assertNotIn('ExceptionGroup', formatted_exc) + self.assertNotIn('ZeroDivisionError', formatted_exc) + self.assertNotIn('ValueError', formatted_exc) + self.assertIn(expected_out, formatted_exc) + def testBufferTearDownModule_DoModuleCleanups(self): with captured_stdout() as stdout: result = unittest.TestResult() @@ -1355,21 +1375,32 @@ class TestOutputBuffering(unittest.TestCase): suite(result) expected_out = '\nStdout:\ntear down module\ndo cleanup2\ndo cleanup1\n' self.assertEqual(stdout.getvalue(), expected_out) - self.assertEqual(len(result.errors), 2) + self.assertEqual(len(result.errors), 3) description = 'tearDownModule (Module)' test_case, formatted_exc = result.errors[0] self.assertEqual(test_case.description, description) self.assertIn('ZeroDivisionError: division by zero', formatted_exc) + self.assertNotIn('ExceptionGroup', formatted_exc) self.assertNotIn('ValueError', formatted_exc) self.assertNotIn('TypeError', formatted_exc) self.assertIn('\nStdout:\ntear down module\n', formatted_exc) + test_case, formatted_exc = result.errors[1] self.assertEqual(test_case.description, description) self.assertIn('ValueError: bad cleanup2', formatted_exc) + self.assertNotIn('ExceptionGroup', formatted_exc) self.assertNotIn('ZeroDivisionError', formatted_exc) self.assertNotIn('TypeError', formatted_exc) self.assertIn(expected_out, formatted_exc) + test_case, formatted_exc = result.errors[2] + self.assertEqual(test_case.description, description) + self.assertIn('TypeError: bad cleanup1', formatted_exc) + self.assertNotIn('ExceptionGroup', formatted_exc) + self.assertNotIn('ZeroDivisionError', formatted_exc) + self.assertNotIn('ValueError', formatted_exc) + self.assertIn(expected_out, formatted_exc) + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_unittest/test_runner.py b/Lib/test/test_unittest/test_runner.py index 4d3cfd60b8d..a47e2ebb59d 100644 --- a/Lib/test/test_unittest/test_runner.py +++ b/Lib/test/test_unittest/test_runner.py @@ -13,6 +13,7 @@ from test.test_unittest.support import ( LoggingResult, ResultWithNoStartTestRunStopTestRun, ) +from test.support.testcase import ExceptionIsLikeMixin def resultFactory(*_): @@ -604,7 +605,7 @@ class TestClassCleanup(unittest.TestCase): @support.force_not_colorized_test_class -class TestModuleCleanUp(unittest.TestCase): +class TestModuleCleanUp(ExceptionIsLikeMixin, unittest.TestCase): def test_add_and_do_ModuleCleanup(self): module_cleanups = [] @@ -646,11 +647,50 @@ class TestModuleCleanUp(unittest.TestCase): [(module_cleanup_good, (1, 2, 3), dict(four='hello', five='goodbye')), (module_cleanup_bad, (), {})]) - with self.assertRaises(CustomError) as e: + with self.assertRaises(Exception) as e: unittest.case.doModuleCleanups() - self.assertEqual(str(e.exception), 'CleanUpExc') + self.assertExceptionIsLike(e.exception, + ExceptionGroup('module cleanup failed', + [CustomError('CleanUpExc')])) self.assertEqual(unittest.case._module_cleanups, []) + def test_doModuleCleanup_with_multiple_errors_in_addModuleCleanup(self): + def module_cleanup_bad1(): + raise TypeError('CleanUpExc1') + + def module_cleanup_bad2(): + raise ValueError('CleanUpExc2') + + class Module: + unittest.addModuleCleanup(module_cleanup_bad1) + unittest.addModuleCleanup(module_cleanup_bad2) + with self.assertRaises(ExceptionGroup) as e: + unittest.case.doModuleCleanups() + self.assertExceptionIsLike(e.exception, + ExceptionGroup('module cleanup failed', [ + ValueError('CleanUpExc2'), + TypeError('CleanUpExc1'), + ])) + + def test_doModuleCleanup_with_exception_group_in_addModuleCleanup(self): + def module_cleanup_bad(): + raise ExceptionGroup('CleanUpExc', [ + ValueError('CleanUpExc2'), + TypeError('CleanUpExc1'), + ]) + + class Module: + unittest.addModuleCleanup(module_cleanup_bad) + with self.assertRaises(ExceptionGroup) as e: + unittest.case.doModuleCleanups() + self.assertExceptionIsLike(e.exception, + ExceptionGroup('module cleanup failed', [ + ExceptionGroup('CleanUpExc', [ + ValueError('CleanUpExc2'), + TypeError('CleanUpExc1'), + ]), + ])) + def test_addModuleCleanup_arg_errors(self): cleanups = [] def cleanup(*args, **kwargs): @@ -871,9 +911,11 @@ class TestModuleCleanUp(unittest.TestCase): ordering = [] blowUp = True suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestableTest) - with self.assertRaises(CustomError) as cm: + with self.assertRaises(Exception) as cm: suite.debug() - self.assertEqual(str(cm.exception), 'CleanUpExc') + self.assertExceptionIsLike(cm.exception, + ExceptionGroup('module cleanup failed', + [CustomError('CleanUpExc')])) self.assertEqual(ordering, ['setUpModule', 'setUpClass', 'test', 'tearDownClass', 'tearDownModule', 'cleanup_exc']) self.assertEqual(unittest.case._module_cleanups, []) diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py index 53ca592ea38..34c7c721b1a 100644 --- a/Lib/test/test_zstd.py +++ b/Lib/test/test_zstd.py @@ -1424,11 +1424,12 @@ class FileTestCase(unittest.TestCase): with self.assertRaises(ValueError): ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rw") - with self.assertRaisesRegex(TypeError, r"NOT be CompressionParameter"): + with self.assertRaisesRegex(TypeError, + r"NOT be a CompressionParameter"): ZstdFile(io.BytesIO(), 'rb', options={CompressionParameter.compression_level:5}) with self.assertRaisesRegex(TypeError, - r"NOT be DecompressionParameter"): + r"NOT be a DecompressionParameter"): ZstdFile(io.BytesIO(), 'wb', options={DecompressionParameter.window_log_max:21}) @@ -2430,10 +2431,8 @@ class OpenTestCase(unittest.TestCase): self.assertEqual(f.write(arr), LENGTH) self.assertEqual(f.tell(), LENGTH) -@unittest.skip("it fails for now, see gh-133885") class FreeThreadingMethodTests(unittest.TestCase): - @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled') @threading_helper.reap_threads @threading_helper.requires_working_threading() def test_compress_locking(self): @@ -2470,7 +2469,6 @@ class FreeThreadingMethodTests(unittest.TestCase): actual = b''.join(output) + rest2 self.assertEqual(expected, actual) - @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled') @threading_helper.reap_threads @threading_helper.requires_working_threading() def test_decompress_locking(self): @@ -2506,6 +2504,59 @@ class FreeThreadingMethodTests(unittest.TestCase): actual = b''.join(output) self.assertEqual(expected, actual) + @threading_helper.reap_threads + @threading_helper.requires_working_threading() + def test_compress_shared_dict(self): + num_threads = 8 + + def run_method(b): + level = threading.get_ident() % 4 + # sync threads to increase chance of contention on + # capsule storing dictionary levels + b.wait() + ZstdCompressor(level=level, + zstd_dict=TRAINED_DICT.as_digested_dict) + b.wait() + ZstdCompressor(level=level, + zstd_dict=TRAINED_DICT.as_undigested_dict) + b.wait() + ZstdCompressor(level=level, + zstd_dict=TRAINED_DICT.as_prefix) + threads = [] + + b = threading.Barrier(num_threads) + for i in range(num_threads): + thread = threading.Thread(target=run_method, args=(b,)) + + threads.append(thread) + + with threading_helper.start_threads(threads): + pass + + @threading_helper.reap_threads + @threading_helper.requires_working_threading() + def test_decompress_shared_dict(self): + num_threads = 8 + + def run_method(b): + # sync threads to increase chance of contention on + # decompression dictionary + b.wait() + ZstdDecompressor(zstd_dict=TRAINED_DICT.as_digested_dict) + b.wait() + ZstdDecompressor(zstd_dict=TRAINED_DICT.as_undigested_dict) + b.wait() + ZstdDecompressor(zstd_dict=TRAINED_DICT.as_prefix) + threads = [] + + b = threading.Barrier(num_threads) + for i in range(num_threads): + thread = threading.Thread(target=run_method, args=(b,)) + + threads.append(thread) + + with threading_helper.start_threads(threads): + pass if __name__ == "__main__": |