Diffstat (limited to 'Lib/test')
-rw-r--r--Lib/test/.ruff.toml14
-rw-r--r--Lib/test/_code_definitions.py160
-rw-r--r--Lib/test/_test_embed_structseq.py2
-rw-r--r--Lib/test/_test_gc_fast_cycles.py48
-rw-r--r--Lib/test/_test_multiprocessing.py84
-rw-r--r--Lib/test/audit-tests.py28
-rw-r--r--Lib/test/datetimetester.py35
-rw-r--r--Lib/test/libregrtest/main.py23
-rw-r--r--Lib/test/libregrtest/setup.py11
-rw-r--r--Lib/test/libregrtest/single.py2
-rw-r--r--Lib/test/libregrtest/tsan.py2
-rw-r--r--Lib/test/libregrtest/utils.py42
-rw-r--r--Lib/test/lock_tests.py43
-rw-r--r--Lib/test/mapping_tests.py4
-rw-r--r--Lib/test/mp_preload_flush.py15
-rw-r--r--Lib/test/pickletester.py38
-rw-r--r--Lib/test/pythoninfo.py24
-rwxr-xr-xLib/test/re_tests.py2
-rw-r--r--Lib/test/subprocessdata/fd_status.py4
-rw-r--r--Lib/test/support/__init__.py125
-rw-r--r--Lib/test/support/ast_helper.py3
-rw-r--r--Lib/test/support/channels.py (renamed from Lib/test/support/interpreters/channels.py)93
-rw-r--r--Lib/test/support/hashlib_helper.py221
-rw-r--r--Lib/test/support/import_helper.py108
-rw-r--r--Lib/test/support/interpreters/__init__.py258
-rw-r--r--Lib/test/support/interpreters/_crossinterp.py102
-rw-r--r--Lib/test/support/interpreters/queues.py313
-rw-r--r--Lib/test/support/strace_helper.py7
-rw-r--r--Lib/test/support/warnings_helper.py3
-rw-r--r--Lib/test/test__interpchannels.py8
-rw-r--r--Lib/test/test__interpreters.py37
-rw-r--r--Lib/test/test__osx_support.py4
-rw-r--r--Lib/test/test_abstract_numbers.py30
-rw-r--r--Lib/test/test_annotationlib.py381
-rw-r--r--Lib/test/test_argparse.py299
-rw-r--r--Lib/test/test_asdl_parser.py8
-rw-r--r--Lib/test/test_ast/data/ast_repr.txt7
-rw-r--r--Lib/test/test_ast/snippets.py11
-rw-r--r--Lib/test/test_ast/test_ast.py497
-rw-r--r--Lib/test/test_asyncgen.py9
-rw-r--r--Lib/test/test_asyncio/test_eager_task_factory.py37
-rw-r--r--Lib/test/test_asyncio/test_futures.py58
-rw-r--r--Lib/test/test_asyncio/test_locks.py2
-rw-r--r--Lib/test/test_asyncio/test_selector_events.py16
-rw-r--r--Lib/test/test_asyncio/test_ssl.py10
-rw-r--r--Lib/test/test_asyncio/test_tasks.py73
-rw-r--r--Lib/test/test_asyncio/test_tools.py1706
-rw-r--r--Lib/test/test_audit.py10
-rw-r--r--Lib/test/test_base64.py10
-rw-r--r--Lib/test/test_baseexception.py8
-rw-r--r--Lib/test/test_binascii.py6
-rw-r--r--Lib/test/test_binop.py2
-rw-r--r--Lib/test/test_buffer.py4
-rw-r--r--Lib/test/test_bufio.py2
-rw-r--r--Lib/test/test_build_details.py16
-rw-r--r--Lib/test/test_builtin.py8
-rw-r--r--Lib/test/test_bytes.py10
-rw-r--r--Lib/test/test_bz2.py4
-rw-r--r--Lib/test/test_calendar.py5
-rw-r--r--Lib/test/test_call.py4
-rw-r--r--Lib/test/test_capi/test_bytearray.py6
-rw-r--r--Lib/test/test_capi/test_bytes.py8
-rw-r--r--Lib/test/test_capi/test_config.py3
-rw-r--r--Lib/test/test_capi/test_float.py26
-rw-r--r--Lib/test/test_capi/test_import.py2
-rw-r--r--Lib/test/test_capi/test_misc.py6
-rw-r--r--Lib/test/test_capi/test_object.py24
-rw-r--r--Lib/test/test_capi/test_opt.py547
-rw-r--r--Lib/test/test_capi/test_sys.py64
-rw-r--r--Lib/test/test_capi/test_type.py26
-rw-r--r--Lib/test/test_capi/test_unicode.py21
-rw-r--r--Lib/test/test_class.py1
-rw-r--r--Lib/test/test_clinic.py51
-rw-r--r--Lib/test/test_cmd.py32
-rw-r--r--Lib/test/test_cmd_line.py52
-rw-r--r--Lib/test/test_cmd_line_script.py11
-rw-r--r--Lib/test/test_code.py389
-rw-r--r--Lib/test/test_code_module.py2
-rw-r--r--Lib/test/test_codeccallbacks.py38
-rw-r--r--Lib/test/test_codecs.py105
-rw-r--r--Lib/test/test_codeop.py2
-rw-r--r--Lib/test/test_collections.py2
-rw-r--r--Lib/test/test_compileall.py2
-rw-r--r--Lib/test/test_compiler_assemble.py2
-rw-r--r--Lib/test/test_concurrent_futures/test_future.py57
-rw-r--r--Lib/test/test_concurrent_futures/test_init.py4
-rw-r--r--Lib/test/test_concurrent_futures/test_interpreter_pool.py265
-rw-r--r--Lib/test/test_concurrent_futures/test_shutdown.py58
-rw-r--r--Lib/test/test_configparser.py12
-rw-r--r--Lib/test/test_contextlib.py8
-rw-r--r--Lib/test/test_contextlib_async.py8
-rw-r--r--Lib/test/test_copy.py5
-rw-r--r--Lib/test/test_coroutines.py2
-rw-r--r--Lib/test/test_cprofile.py19
-rw-r--r--Lib/test/test_crossinterp.py1134
-rw-r--r--Lib/test/test_csv.py55
-rw-r--r--Lib/test/test_ctypes/_support.py1
-rw-r--r--Lib/test/test_ctypes/test_aligned_structures.py1
-rw-r--r--Lib/test/test_ctypes/test_bitfields.py5
-rw-r--r--Lib/test/test_ctypes/test_byteswap.py5
-rw-r--r--Lib/test/test_ctypes/test_c_simple_type_meta.py234
-rw-r--r--Lib/test/test_ctypes/test_generated_structs.py13
-rw-r--r--Lib/test/test_ctypes/test_incomplete.py27
-rw-r--r--Lib/test/test_ctypes/test_keeprefs.py7
-rw-r--r--Lib/test/test_ctypes/test_parameters.py4
-rw-r--r--Lib/test/test_ctypes/test_pep3118.py3
-rw-r--r--Lib/test/test_ctypes/test_pointers.py268
-rw-r--r--Lib/test/test_ctypes/test_structunion.py18
-rw-r--r--Lib/test/test_ctypes/test_structures.py55
-rw-r--r--Lib/test/test_ctypes/test_unaligned_structures.py2
-rw-r--r--Lib/test/test_ctypes/test_values.py3
-rw-r--r--Lib/test/test_ctypes/test_win32.py5
-rw-r--r--Lib/test/test_curses.py42
-rw-r--r--Lib/test/test_dataclasses/__init__.py78
-rw-r--r--Lib/test/test_dbm.py63
-rw-r--r--Lib/test/test_dbm_gnu.py27
-rw-r--r--Lib/test/test_dbm_sqlite3.py4
-rw-r--r--Lib/test/test_decimal.py3
-rw-r--r--Lib/test/test_deque.py2
-rw-r--r--Lib/test/test_descr.py51
-rw-r--r--Lib/test/test_dict.py86
-rw-r--r--Lib/test/test_difflib.py6
-rw-r--r--Lib/test/test_difflib_expect.html48
-rw-r--r--Lib/test/test_dis.py217
-rw-r--r--Lib/test/test_doctest/sample_doctest_errors.py46
-rw-r--r--Lib/test/test_doctest/test_doctest.py447
-rw-r--r--Lib/test/test_doctest/test_doctest_errors.txt14
-rw-r--r--Lib/test/test_doctest/test_doctest_skip.txt2
-rw-r--r--Lib/test/test_doctest/test_doctest_skip2.txt6
-rw-r--r--Lib/test/test_dynamicclassattribute.py4
-rw-r--r--Lib/test/test_email/test__header_value_parser.py78
-rw-r--r--Lib/test/test_email/test_email.py30
-rw-r--r--Lib/test/test_email/test_utils.py10
-rw-r--r--Lib/test/test_embed.py15
-rw-r--r--Lib/test/test_enum.py27
-rw-r--r--Lib/test/test_errno.py6
-rw-r--r--Lib/test/test_exception_group.py14
-rw-r--r--Lib/test/test_exceptions.py23
-rw-r--r--Lib/test/test_external_inspection.py664
-rw-r--r--Lib/test/test_faulthandler.py23
-rw-r--r--Lib/test/test_fcntl.py59
-rw-r--r--Lib/test/test_fileinput.py2
-rw-r--r--Lib/test/test_fileio.py12
-rw-r--r--Lib/test/test_float.py2
-rw-r--r--Lib/test/test_fnmatch.py98
-rw-r--r--Lib/test/test_format.py10
-rw-r--r--Lib/test/test_fractions.py216
-rw-r--r--Lib/test/test_free_threading/test_dict.py16
-rw-r--r--Lib/test/test_free_threading/test_functools.py75
-rw-r--r--Lib/test/test_free_threading/test_generators.py51
-rw-r--r--Lib/test/test_free_threading/test_heapq.py267
-rw-r--r--Lib/test/test_free_threading/test_io.py109
-rw-r--r--Lib/test/test_free_threading/test_itertools.py95
-rw-r--r--Lib/test/test_free_threading/test_itertools_batched.py38
-rw-r--r--Lib/test/test_free_threading/test_itertools_combinatoric.py51
-rw-r--r--Lib/test/test_fstring.py13
-rw-r--r--Lib/test/test_functools.py34
-rw-r--r--Lib/test/test_future_stmt/test_future.py5
-rw-r--r--Lib/test/test_gc.py66
-rw-r--r--Lib/test/test_generated_cases.py467
-rw-r--r--Lib/test/test_genericalias.py18
-rw-r--r--Lib/test/test_genericpath.py6
-rw-r--r--Lib/test/test_getpass.py39
-rw-r--r--Lib/test/test_gettext.py33
-rw-r--r--Lib/test/test_glob.py80
-rw-r--r--Lib/test/test_grammar.py41
-rw-r--r--Lib/test/test_gzip.py9
-rw-r--r--Lib/test/test_hashlib.py253
-rw-r--r--Lib/test/test_heapq.py197
-rw-r--r--Lib/test/test_hmac.py153
-rw-r--r--Lib/test/test_htmlparser.py213
-rw-r--r--Lib/test/test_http_cookiejar.py187
-rw-r--r--Lib/test/test_httpservers.py738
-rw-r--r--Lib/test/test_idle.py2
-rw-r--r--Lib/test/test_import/__init__.py6
-rw-r--r--Lib/test/test_importlib/import_/test_relative_imports.py15
-rw-r--r--Lib/test/test_importlib/test_locks.py1
-rw-r--r--Lib/test/test_importlib/test_threaded_import.py15
-rw-r--r--Lib/test/test_inspect/test_inspect.py122
-rw-r--r--Lib/test/test_int.py2
-rw-r--r--Lib/test/test_interpreters/test_api.py884
-rw-r--r--Lib/test/test_interpreters/test_channels.py56
-rw-r--r--Lib/test/test_interpreters/test_lifecycle.py4
-rw-r--r--Lib/test/test_interpreters/test_queues.py271
-rw-r--r--Lib/test/test_interpreters/test_stress.py32
-rw-r--r--Lib/test/test_interpreters/utils.py3
-rw-r--r--Lib/test/test_io.py103
-rw-r--r--Lib/test/test_ioctl.py16
-rw-r--r--Lib/test/test_ipaddress.py54
-rw-r--r--Lib/test/test_isinstance.py2
-rw-r--r--Lib/test/test_iter.py2
-rw-r--r--Lib/test/test_json/test_dump.py8
-rw-r--r--Lib/test/test_json/test_fail.py2
-rw-r--r--Lib/test/test_json/test_recursion.py3
-rw-r--r--Lib/test/test_json/test_tool.py87
-rw-r--r--Lib/test/test_launcher.py8
-rw-r--r--Lib/test/test_linecache.py37
-rw-r--r--Lib/test/test_list.py15
-rw-r--r--Lib/test/test_listcomps.py2
-rw-r--r--Lib/test/test_locale.py13
-rw-r--r--Lib/test/test_logging.py27
-rw-r--r--Lib/test/test_lzma.py4
-rw-r--r--Lib/test/test_math.py43
-rw-r--r--Lib/test/test_memoryio.py4
-rw-r--r--Lib/test/test_memoryview.py20
-rw-r--r--Lib/test/test_mimetypes.py8
-rw-r--r--Lib/test/test_minidom.py182
-rw-r--r--Lib/test/test_monitoring.py15
-rw-r--r--Lib/test/test_multibytecodec.py3
-rw-r--r--Lib/test/test_netrc.py13
-rw-r--r--Lib/test/test_ntpath.py360
-rw-r--r--Lib/test/test_opcache.py21
-rw-r--r--Lib/test/test_operator.py3
-rw-r--r--Lib/test/test_optparse.py11
-rw-r--r--Lib/test/test_ordered_dict.py8
-rw-r--r--Lib/test/test_os.py28
-rw-r--r--Lib/test/test_pathlib/support/lexical_path.py11
-rw-r--r--Lib/test/test_pathlib/support/local_path.py10
-rw-r--r--Lib/test/test_pathlib/support/zip_path.py45
-rw-r--r--Lib/test/test_pathlib/test_join_windows.py17
-rw-r--r--Lib/test/test_pathlib/test_pathlib.py36
-rw-r--r--Lib/test/test_pdb.py118
-rw-r--r--Lib/test/test_peepholer.py126
-rw-r--r--Lib/test/test_peg_generator/test_c_parser.py4
-rw-r--r--Lib/test/test_peg_generator/test_pegen.py6
-rw-r--r--Lib/test/test_perf_profiler.py9
-rw-r--r--Lib/test/test_pickle.py14
-rw-r--r--Lib/test/test_platform.py93
-rw-r--r--Lib/test/test_positional_only_arg.py12
-rw-r--r--Lib/test/test_posix.py57
-rw-r--r--Lib/test/test_posixpath.py388
-rw-r--r--Lib/test/test_pprint.py355
-rw-r--r--Lib/test/test_property.py4
-rw-r--r--Lib/test/test_pstats.py7
-rw-r--r--Lib/test/test_pty.py1
-rw-r--r--Lib/test/test_pulldom.py4
-rw-r--r--Lib/test/test_pyclbr.py2
-rw-r--r--Lib/test/test_pydoc/test_pydoc.py86
-rw-r--r--Lib/test/test_pyrepl/support.py3
-rw-r--r--Lib/test/test_pyrepl/test_eventqueue.py78
-rw-r--r--Lib/test/test_pyrepl/test_interact.py2
-rw-r--r--Lib/test/test_pyrepl/test_pyrepl.py245
-rw-r--r--Lib/test/test_pyrepl/test_reader.py211
-rw-r--r--Lib/test/test_pyrepl/test_unix_console.py13
-rw-r--r--Lib/test/test_pyrepl/test_utils.py37
-rw-r--r--Lib/test/test_pyrepl/test_windows_console.py242
-rw-r--r--Lib/test/test_queue.py20
-rw-r--r--Lib/test/test_random.py340
-rw-r--r--Lib/test/test_re.py19
-rw-r--r--Lib/test/test_readline.py8
-rw-r--r--Lib/test/test_regrtest.py65
-rw-r--r--Lib/test/test_remote_pdb.py929
-rw-r--r--Lib/test/test_repl.py6
-rw-r--r--Lib/test/test_reprlib.py97
-rw-r--r--Lib/test/test_rlcompleter.py27
-rw-r--r--Lib/test/test_runpy.py2
-rw-r--r--Lib/test/test_sax.py6
-rw-r--r--Lib/test/test_scope.py2
-rw-r--r--Lib/test/test_script_helper.py3
-rw-r--r--Lib/test/test_set.py6
-rw-r--r--Lib/test/test_shlex.py2
-rw-r--r--Lib/test/test_shutil.py12
-rw-r--r--Lib/test/test_site.py22
-rw-r--r--Lib/test/test_socket.py31
-rw-r--r--Lib/test/test_source_encoding.py3
-rw-r--r--Lib/test/test_sqlite3/test_cli.py160
-rw-r--r--Lib/test/test_sqlite3/test_dbapi.py14
-rw-r--r--Lib/test/test_sqlite3/test_factory.py15
-rw-r--r--Lib/test/test_sqlite3/test_hooks.py22
-rw-r--r--Lib/test/test_sqlite3/test_userfunctions.py55
-rw-r--r--Lib/test/test_ssl.py54
-rw-r--r--Lib/test/test_stable_abi_ctypes.py4
-rw-r--r--Lib/test/test_stat.py6
-rw-r--r--Lib/test/test_statistics.py25
-rw-r--r--Lib/test/test_str.py8
-rw-r--r--Lib/test/test_strftime.py16
-rw-r--r--Lib/test/test_string/__init__.py5
-rw-r--r--Lib/test/test_string/_support.py54
-rw-r--r--Lib/test/test_string/test_string.py (renamed from Lib/test/test_string.py)8
-rw-r--r--Lib/test/test_string/test_templatelib.py160
-rw-r--r--Lib/test/test_strptime.py37
-rw-r--r--Lib/test/test_strtod.py2
-rw-r--r--Lib/test/test_struct.py20
-rw-r--r--Lib/test/test_structseq.py4
-rw-r--r--Lib/test/test_subprocess.py22
-rw-r--r--Lib/test/test_super.py10
-rw-r--r--Lib/test/test_support.py5
-rw-r--r--Lib/test/test_syntax.py239
-rw-r--r--Lib/test/test_sys.py192
-rw-r--r--Lib/test/test_sysconfig.py20
-rw-r--r--Lib/test/test_tarfile.py405
-rw-r--r--Lib/test/test_tempfile.py16
-rw-r--r--Lib/test/test_termios.py4
-rw-r--r--Lib/test/test_threadedtempfile.py4
-rw-r--r--Lib/test/test_threading.py65
-rw-r--r--Lib/test/test_time.py12
-rw-r--r--Lib/test/test_timeit.py4
-rw-r--r--Lib/test/test_tkinter/support.py2
-rw-r--r--Lib/test/test_tkinter/test_misc.py24
-rw-r--r--Lib/test/test_tkinter/test_widgets.py5
-rw-r--r--Lib/test/test_tkinter/widget_tests.py2
-rw-r--r--Lib/test/test_tokenize.py78
-rw-r--r--Lib/test/test_tools/i18n_data/docstrings.py2
-rw-r--r--Lib/test/test_tools/test_i18n.py8
-rw-r--r--Lib/test/test_tools/test_msgfmt.py127
-rw-r--r--Lib/test/test_traceback.py165
-rw-r--r--Lib/test/test_tstring.py314
-rw-r--r--Lib/test/test_ttk/test_extensions.py6
-rw-r--r--Lib/test/test_ttk/test_widgets.py6
-rw-r--r--Lib/test/test_type_annotations.py41
-rw-r--r--Lib/test/test_type_comments.py2
-rw-r--r--Lib/test/test_types.py40
-rw-r--r--Lib/test/test_typing.py279
-rw-r--r--Lib/test/test_unittest/test_case.py8
-rw-r--r--Lib/test/test_unittest/test_result.py37
-rw-r--r--Lib/test/test_unittest/test_runner.py52
-rw-r--r--Lib/test/test_unittest/testmock/testhelpers.py21
-rw-r--r--Lib/test/test_unparse.py20
-rw-r--r--Lib/test/test_urllib.py19
-rw-r--r--Lib/test/test_urlparse.py398
-rw-r--r--Lib/test/test_userdict.py2
-rwxr-xr-xLib/test/test_uuid.py28
-rw-r--r--Lib/test/test_venv.py6
-rw-r--r--Lib/test/test_warnings/__init__.py70
-rw-r--r--Lib/test/test_wave.py26
-rw-r--r--Lib/test/test_weakref.py6
-rw-r--r--Lib/test/test_weakset.py4
-rw-r--r--Lib/test/test_webbrowser.py1
-rw-r--r--Lib/test/test_winconsoleio.py6
-rw-r--r--Lib/test/test_with.py2
-rw-r--r--Lib/test/test_wmi.py4
-rw-r--r--Lib/test/test_wsgiref.py14
-rw-r--r--Lib/test/test_xml_etree.py49
-rw-r--r--Lib/test/test_xxlimited.py2
-rw-r--r--Lib/test/test_zipapp.py6
-rw-r--r--Lib/test/test_zipfile/__main__.py2
-rw-r--r--Lib/test/test_zipfile/_path/_test_params.py2
-rw-r--r--Lib/test/test_zipfile/_path/test_complexity.py2
-rw-r--r--Lib/test/test_zipfile/_path/test_path.py22
-rw-r--r--Lib/test/test_zipfile/_path/write-alpharep.py1
-rw-r--r--Lib/test/test_zipfile/test_core.py64
-rw-r--r--Lib/test/test_zipimport.py4
-rw-r--r--Lib/test/test_zlib.py108
-rw-r--r--Lib/test/test_zoneinfo/test_zoneinfo.py5
-rw-r--r--Lib/test/test_zstd.py2794
345 files changed, 22582 insertions, 4807 deletions
diff --git a/Lib/test/.ruff.toml b/Lib/test/.ruff.toml
index fa8b2b42579..f1a967203ce 100644
--- a/Lib/test/.ruff.toml
+++ b/Lib/test/.ruff.toml
@@ -1,4 +1,5 @@
-fix = true
+extend = "../../.ruff.toml" # Inherit the project-wide settings
+
extend-exclude = [
# Excluded (run with the other AC files in its own separate ruff job in pre-commit)
"test_clinic.py",
@@ -7,6 +8,10 @@ extend-exclude = [
# Non UTF-8 files
"encoded_modules/module_iso_8859_1.py",
"encoded_modules/module_koi8_r.py",
+ # SyntaxError because of t-strings
+ "test_annotationlib.py",
+ "test_string/test_templatelib.py",
+ "test_tstring.py",
# New grammar constructions may not yet be recognized by Ruff,
# and tests re-use the same names as only the grammar is being checked.
"test_grammar.py",
@@ -14,5 +19,12 @@ extend-exclude = [
[lint]
select = [
+ "F401", # Unused import
"F811", # Redefinition of unused variable (useful for finding test methods with the same name)
]
+
+[lint.per-file-ignores]
+"*/**/__main__.py" = ["F401"] # Unused import
+"test_import/*.py" = ["F401"] # Unused import
+"test_importlib/*.py" = ["F401"] # Unused import
+"typinganndata/partialexecution/*.py" = ["F401"] # Unused import
diff --git a/Lib/test/_code_definitions.py b/Lib/test/_code_definitions.py
index 06cf6a10231..70c44da2ec6 100644
--- a/Lib/test/_code_definitions.py
+++ b/Lib/test/_code_definitions.py
@@ -1,4 +1,32 @@
+def simple_script():
+ assert True
+
+
+def complex_script():
+ obj = 'a string'
+ pickle = __import__('pickle')
+ def spam_minimal():
+ pass
+ spam_minimal()
+ data = pickle.dumps(obj)
+ res = pickle.loads(data)
+ assert res == obj, (res, obj)
+
+
+def script_with_globals():
+ obj1, obj2 = spam(42)
+ assert obj1 == 42
+ assert obj2 is None
+
+
+def script_with_explicit_empty_return():
+ return None
+
+
+def script_with_return():
+ return True
+
def spam_minimal():
# no arg defaults or kwarg defaults
@@ -12,6 +40,70 @@ def spam_minimal():
return
+def spam_with_builtins():
+ x = 42
+ values = (42,)
+ checks = tuple(callable(v) for v in values)
+ res = callable(values), tuple(values), list(values), checks
+ print(res)
+
+
+def spam_with_globals_and_builtins():
+ func1 = spam
+ func2 = spam_minimal
+ funcs = (func1, func2)
+ checks = tuple(callable(f) for f in funcs)
+ res = callable(funcs), tuple(funcs), list(funcs), checks
+ print(res)
+
+
+def spam_with_global_and_attr_same_name():
+ try:
+ spam_minimal.spam_minimal
+ except AttributeError:
+ pass
+
+
+def spam_full_args(a, b, /, c, d, *args, e, f, **kwargs):
+ return (a, b, c, d, e, f, args, kwargs)
+
+
+def spam_full_args_with_defaults(a=-1, b=-2, /, c=-3, d=-4, *args,
+ e=-5, f=-6, **kwargs):
+ return (a, b, c, d, e, f, args, kwargs)
+
+
+def spam_args_attrs_and_builtins(a, b, /, c, d, *args, e, f, **kwargs):
+ if args.__len__() > 2:
+ return None
+ return a, b, c, d, e, f, args, kwargs
+
+
+def spam_returns_arg(x):
+ return x
+
+
+def spam_raises():
+ raise Exception('spam!')
+
+
+def spam_with_inner_not_closure():
+ def eggs():
+ pass
+ eggs()
+
+
+def spam_with_inner_closure():
+ x = 42
+ def eggs():
+ print(x)
+ eggs()
+
+
+def spam_annotated(a: int, b: str, c: object) -> tuple:
+ return a, b, c
+
+
def spam_full(a, b, /, c, d:int=1, *args, e, f:object=None, **kwargs) -> tuple:
# arg defaults, kwarg defaults
# annotations
@@ -97,7 +189,23 @@ ham_C_closure, *_ = eggs_closure_C(2)
TOP_FUNCTIONS = [
# shallow
+ simple_script,
+ complex_script,
+ script_with_globals,
+ script_with_explicit_empty_return,
+ script_with_return,
spam_minimal,
+ spam_with_builtins,
+ spam_with_globals_and_builtins,
+ spam_with_global_and_attr_same_name,
+ spam_full_args,
+ spam_full_args_with_defaults,
+ spam_args_attrs_and_builtins,
+ spam_returns_arg,
+ spam_raises,
+ spam_with_inner_not_closure,
+ spam_with_inner_closure,
+ spam_annotated,
spam_full,
spam,
# outer func
@@ -127,6 +235,58 @@ FUNCTIONS = [
*NESTED_FUNCTIONS,
]
+STATELESS_FUNCTIONS = [
+ simple_script,
+ complex_script,
+ script_with_explicit_empty_return,
+ script_with_return,
+ spam,
+ spam_minimal,
+ spam_with_builtins,
+ spam_full_args,
+ spam_args_attrs_and_builtins,
+ spam_returns_arg,
+ spam_raises,
+ spam_annotated,
+ spam_with_inner_not_closure,
+ spam_with_inner_closure,
+ spam_N,
+ spam_C,
+ spam_NN,
+ spam_NC,
+ spam_CN,
+ spam_CC,
+ eggs_nested,
+ eggs_nested_N,
+ ham_nested,
+ ham_C_nested
+]
+STATELESS_CODE = [
+ *STATELESS_FUNCTIONS,
+ script_with_globals,
+ spam_full_args_with_defaults,
+ spam_with_globals_and_builtins,
+ spam_with_global_and_attr_same_name,
+ spam_full,
+]
+
+PURE_SCRIPT_FUNCTIONS = [
+ simple_script,
+ complex_script,
+ script_with_explicit_empty_return,
+ spam_minimal,
+ spam_with_builtins,
+ spam_raises,
+ spam_with_inner_not_closure,
+ spam_with_inner_closure,
+]
+SCRIPT_FUNCTIONS = [
+ *PURE_SCRIPT_FUNCTIONS,
+ script_with_globals,
+ spam_with_globals_and_builtins,
+ spam_with_global_and_attr_same_name,
+]
+
# generators
diff --git a/Lib/test/_test_embed_structseq.py b/Lib/test/_test_embed_structseq.py
index 154662efce9..4cac84d7a46 100644
--- a/Lib/test/_test_embed_structseq.py
+++ b/Lib/test/_test_embed_structseq.py
@@ -11,7 +11,7 @@ class TestStructSeq(unittest.TestCase):
# ob_refcnt
self.assertGreaterEqual(sys.getrefcount(obj_type), 1)
# tp_base
- self.assertTrue(issubclass(obj_type, tuple))
+ self.assertIsSubclass(obj_type, tuple)
# tp_bases
self.assertEqual(obj_type.__bases__, (tuple,))
# tp_dict
diff --git a/Lib/test/_test_gc_fast_cycles.py b/Lib/test/_test_gc_fast_cycles.py
new file mode 100644
index 00000000000..4e2c7d72a02
--- /dev/null
+++ b/Lib/test/_test_gc_fast_cycles.py
@@ -0,0 +1,48 @@
+# Run by test_gc.
+from test import support
+import _testinternalcapi
+import gc
+import unittest
+
+class IncrementalGCTests(unittest.TestCase):
+
+ # Use small increments to emulate longer running process in a shorter time
+ @support.gc_threshold(200, 10)
+ def test_incremental_gc_handles_fast_cycle_creation(self):
+
+ class LinkedList:
+
+ #Use slots to reduce number of implicit objects
+ __slots__ = "next", "prev", "surprise"
+
+ def __init__(self, next=None, prev=None):
+ self.next = next
+ if next is not None:
+ next.prev = self
+ self.prev = prev
+ if prev is not None:
+ prev.next = self
+
+ def make_ll(depth):
+ head = LinkedList()
+ for i in range(depth):
+ head = LinkedList(head, head.prev)
+ return head
+
+ head = make_ll(1000)
+
+ assert(gc.isenabled())
+ olds = []
+ initial_heap_size = _testinternalcapi.get_tracked_heap_size()
+ for i in range(20_000):
+ newhead = make_ll(20)
+ newhead.surprise = head
+ olds.append(newhead)
+ if len(olds) == 20:
+ new_objects = _testinternalcapi.get_tracked_heap_size() - initial_heap_size
+ self.assertLess(new_objects, 27_000, f"Heap growing. Reached limit after {i} iterations")
+ del olds[:]
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py
index 4dc9a31d22f..a1259ff1d63 100644
--- a/Lib/test/_test_multiprocessing.py
+++ b/Lib/test/_test_multiprocessing.py
@@ -513,9 +513,14 @@ class _TestProcess(BaseTestCase):
time.sleep(100)
@classmethod
- def _sleep_no_int_handler(cls):
+ def _sleep_some_event(cls, event):
+ event.set()
+ time.sleep(100)
+
+ @classmethod
+ def _sleep_no_int_handler(cls, event):
signal.signal(signal.SIGINT, signal.SIG_DFL)
- cls._sleep_some()
+ cls._sleep_some_event(event)
@classmethod
def _test_sleep(cls, delay):
@@ -525,7 +530,10 @@ class _TestProcess(BaseTestCase):
if self.TYPE == 'threads':
self.skipTest('test not appropriate for {}'.format(self.TYPE))
- p = self.Process(target=target or self._sleep_some)
+ event = self.Event()
+ if not target:
+ target = self._sleep_some_event
+ p = self.Process(target=target, args=(event,))
p.daemon = True
p.start()
@@ -543,8 +551,11 @@ class _TestProcess(BaseTestCase):
self.assertTimingAlmostEqual(join.elapsed, 0.0)
self.assertEqual(p.is_alive(), True)
- # XXX maybe terminating too soon causes the problems on Gentoo...
- time.sleep(1)
+ timeout = support.SHORT_TIMEOUT
+ if not event.wait(timeout):
+ p.terminate()
+ p.join()
+ self.fail(f"event not signaled in {timeout} seconds")
meth(p)
@@ -2463,6 +2474,12 @@ class _TestValue(BaseTestCase):
self.assertNotHasAttr(arr5, 'get_lock')
self.assertNotHasAttr(arr5, 'get_obj')
+ @unittest.skipIf(c_int is None, "requires _ctypes")
+ def test_invalid_typecode(self):
+ with self.assertRaisesRegex(TypeError, 'bad typecode'):
+ self.Value('x', None)
+ with self.assertRaisesRegex(TypeError, 'bad typecode'):
+ self.RawValue('x', None)
class _TestArray(BaseTestCase):
@@ -2543,6 +2560,12 @@ class _TestArray(BaseTestCase):
self.assertNotHasAttr(arr5, 'get_lock')
self.assertNotHasAttr(arr5, 'get_obj')
+ @unittest.skipIf(c_int is None, "requires _ctypes")
+ def test_invalid_typecode(self):
+ with self.assertRaisesRegex(TypeError, 'bad typecode'):
+ self.Array('x', [])
+ with self.assertRaisesRegex(TypeError, 'bad typecode'):
+ self.RawArray('x', [])
#
#
#
@@ -6778,6 +6801,35 @@ class _TestSpawnedSysPath(BaseTestCase):
self.assertEqual(child_sys_path[1:], sys.path[1:])
self.assertIsNone(import_error, msg=f"child could not import {self._mod_name}")
+ def test_std_streams_flushed_after_preload(self):
+ # gh-135335: Check fork server flushes standard streams after
+ # preloading modules
+ if multiprocessing.get_start_method() != "forkserver":
+ self.skipTest("forkserver specific test")
+
+ # Create a test module in the temporary directory on the child's path
+ # TODO: This can all be simplified once gh-126631 is fixed and we can
+ # use __main__ instead of a module.
+ dirname = os.path.join(self._temp_dir, 'preloaded_module')
+ init_name = os.path.join(dirname, '__init__.py')
+ os.mkdir(dirname)
+ with open(init_name, "w") as f:
+ cmd = '''if 1:
+ import sys
+ print('stderr', end='', file=sys.stderr)
+ print('stdout', end='', file=sys.stdout)
+ '''
+ f.write(cmd)
+
+ name = os.path.join(os.path.dirname(__file__), 'mp_preload_flush.py')
+ env = {'PYTHONPATH': self._temp_dir}
+ _, out, err = test.support.script_helper.assert_python_ok(name, **env)
+
+ # Check stderr first, as it is more likely to be useful to see in the
+ # event of a failure.
+ self.assertEqual(err.decode().rstrip(), 'stderr')
+ self.assertEqual(out.decode().rstrip(), 'stdout')
+
class MiscTestCase(unittest.TestCase):
def test__all__(self):
@@ -6821,6 +6873,28 @@ class MiscTestCase(unittest.TestCase):
self.assertEqual("332833500", out.decode('utf-8').strip())
self.assertFalse(err, msg=err.decode('utf-8'))
+ def test_forked_thread_not_started(self):
+ # gh-134381: Ensure that a thread that has not been started yet in
+ # the parent process can be started within a forked child process.
+
+ if multiprocessing.get_start_method() != "fork":
+ self.skipTest("fork specific test")
+
+ q = multiprocessing.Queue()
+ t = threading.Thread(target=lambda: q.put("done"), daemon=True)
+
+ def child():
+ t.start()
+ t.join()
+
+ p = multiprocessing.Process(target=child)
+ p.start()
+ p.join(support.SHORT_TIMEOUT)
+
+ self.assertEqual(p.exitcode, 0)
+ self.assertEqual(q.get_nowait(), "done")
+ close_queue(q)
+
#
# Mixins
diff --git a/Lib/test/audit-tests.py b/Lib/test/audit-tests.py
index 08b638e4b8d..6884ac0dbe6 100644
--- a/Lib/test/audit-tests.py
+++ b/Lib/test/audit-tests.py
@@ -643,6 +643,34 @@ def test_assert_unicode():
else:
raise RuntimeError("Expected sys.audit(9) to fail.")
+def test_sys_remote_exec():
+ import tempfile
+
+ pid = os.getpid()
+ event_pid = -1
+ event_script_path = ""
+ remote_event_script_path = ""
+ def hook(event, args):
+ if event not in ["sys.remote_exec", "cpython.remote_debugger_script"]:
+ return
+ print(event, args)
+ match event:
+ case "sys.remote_exec":
+ nonlocal event_pid, event_script_path
+ event_pid = args[0]
+ event_script_path = args[1]
+ case "cpython.remote_debugger_script":
+ nonlocal remote_event_script_path
+ remote_event_script_path = args[0]
+
+ sys.addaudithook(hook)
+ with tempfile.NamedTemporaryFile(mode='w+', delete=True) as tmp_file:
+ tmp_file.write("a = 1+1\n")
+ tmp_file.flush()
+ sys.remote_exec(pid, tmp_file.name)
+ assertEqual(event_pid, pid)
+ assertEqual(event_script_path, tmp_file.name)
+ assertEqual(remote_event_script_path, tmp_file.name)
if __name__ == "__main__":
from test.support import suppress_msvcrt_asserts
diff --git a/Lib/test/datetimetester.py b/Lib/test/datetimetester.py
index 55844ec35a9..93b3382b9c6 100644
--- a/Lib/test/datetimetester.py
+++ b/Lib/test/datetimetester.py
@@ -183,7 +183,7 @@ class TestTZInfo(unittest.TestCase):
def __init__(self, offset, name):
self.__offset = offset
self.__name = name
- self.assertTrue(issubclass(NotEnough, tzinfo))
+ self.assertIsSubclass(NotEnough, tzinfo)
ne = NotEnough(3, "NotByALongShot")
self.assertIsInstance(ne, tzinfo)
@@ -232,7 +232,7 @@ class TestTZInfo(unittest.TestCase):
self.assertIs(type(derived), otype)
self.assertEqual(derived.utcoffset(None), offset)
self.assertEqual(derived.tzname(None), oname)
- self.assertFalse(hasattr(derived, 'spam'))
+ self.assertNotHasAttr(derived, 'spam')
def test_issue23600(self):
DSTDIFF = DSTOFFSET = timedelta(hours=1)
@@ -773,6 +773,9 @@ class TestTimeDelta(HarmlessMixedComparison, unittest.TestCase):
microseconds=999999)),
"999999999 days, 23:59:59.999999")
+ # test the Doc/library/datetime.rst recipe
+ eq(f'-({-td(hours=-1)!s})', "-(1:00:00)")
+
def test_repr(self):
name = 'datetime.' + self.theclass.__name__
self.assertEqual(repr(self.theclass(1)),
@@ -810,7 +813,7 @@ class TestTimeDelta(HarmlessMixedComparison, unittest.TestCase):
# Verify td -> string -> td identity.
s = repr(td)
- self.assertTrue(s.startswith('datetime.'))
+ self.assertStartsWith(s, 'datetime.')
s = s[9:]
td2 = eval(s)
self.assertEqual(td, td2)
@@ -1228,7 +1231,7 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase):
self.theclass.today()):
# Verify dt -> string -> date identity.
s = repr(dt)
- self.assertTrue(s.startswith('datetime.'))
+ self.assertStartsWith(s, 'datetime.')
s = s[9:]
dt2 = eval(s)
self.assertEqual(dt, dt2)
@@ -2215,7 +2218,7 @@ class TestDateTime(TestDate):
self.theclass.now()):
# Verify dt -> string -> datetime identity.
s = repr(dt)
- self.assertTrue(s.startswith('datetime.'))
+ self.assertStartsWith(s, 'datetime.')
s = s[9:]
dt2 = eval(s)
self.assertEqual(dt, dt2)
@@ -2969,6 +2972,17 @@ class TestDateTime(TestDate):
with self._assertNotWarns(DeprecationWarning):
self.theclass.strptime('02-29,2024', '%m-%d,%Y')
+ def test_strptime_z_empty(self):
+ for directive in ('z',):
+ string = '2025-04-25 11:42:47'
+ format = f'%Y-%m-%d %H:%M:%S%{directive}'
+ target = self.theclass(2025, 4, 25, 11, 42, 47)
+ with self.subTest(string=string,
+ format=format,
+ target=target):
+ result = self.theclass.strptime(string, format)
+ self.assertEqual(result, target)
+
def test_more_timetuple(self):
# This tests fields beyond those tested by the TestDate.test_timetuple.
t = self.theclass(2004, 12, 31, 6, 22, 33)
@@ -3568,6 +3582,10 @@ class TestDateTime(TestDate):
'2009-04-19T12:30:45.400 +02:30', # Space between ms and timezone (gh-130959)
'2009-04-19T12:30:45.400 ', # Trailing space (gh-130959)
'2009-04-19T12:30:45. 400', # Space before fraction (gh-130959)
+ '2009-04-19T12:30:45+00:90:00', # Time zone field out from range
+ '2009-04-19T12:30:45+00:00:90', # Time zone field out from range
+ '2009-04-19T12:30:45-00:90:00', # Time zone field out from range
+ '2009-04-19T12:30:45-00:00:90', # Time zone field out from range
]
for bad_str in bad_strs:
@@ -3669,7 +3687,7 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase):
# Verify t -> string -> time identity.
s = repr(t)
- self.assertTrue(s.startswith('datetime.'))
+ self.assertStartsWith(s, 'datetime.')
s = s[9:]
t2 = eval(s)
self.assertEqual(t, t2)
@@ -4792,6 +4810,11 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase):
'12:30:45.400 +02:30', # Space between ms and timezone (gh-130959)
'12:30:45.400 ', # Trailing space (gh-130959)
'12:30:45. 400', # Space before fraction (gh-130959)
+ '24:00:00.000001', # Has non-zero microseconds on 24:00
+ '24:00:01.000000', # Has non-zero seconds on 24:00
+ '24:01:00.000000', # Has non-zero minutes on 24:00
+ '12:30:45+00:90:00', # Time zone field out from range
+ '12:30:45+00:00:90', # Time zone field out from range
]
for bad_str in bad_strs:
diff --git a/Lib/test/libregrtest/main.py b/Lib/test/libregrtest/main.py
index 713cbedb299..a2d01b157ac 100644
--- a/Lib/test/libregrtest/main.py
+++ b/Lib/test/libregrtest/main.py
@@ -190,6 +190,12 @@ class Regrtest:
strip_py_suffix(tests)
+ exclude_tests = set()
+ if self.exclude:
+ for arg in self.cmdline_args:
+ exclude_tests.add(arg)
+ self.cmdline_args = []
+
if self.pgo:
# add default PGO tests if no tests are specified
setup_pgo_tests(self.cmdline_args, self.pgo_extended)
@@ -200,17 +206,15 @@ class Regrtest:
if self.tsan_parallel:
setup_tsan_parallel_tests(self.cmdline_args)
- exclude_tests = set()
- if self.exclude:
- for arg in self.cmdline_args:
- exclude_tests.add(arg)
- self.cmdline_args = []
-
alltests = findtests(testdir=self.test_dir,
exclude=exclude_tests)
if not self.fromfile:
selected = tests or self.cmdline_args
+ if exclude_tests:
+ # Support "--pgo/--tsan -x test_xxx" command
+ selected = [name for name in selected
+ if name not in exclude_tests]
if selected:
selected = split_test_packages(selected)
else:
@@ -543,8 +547,6 @@ class Regrtest:
self.first_runtests = runtests
self.logger.set_tests(runtests)
- setup_process()
-
if (runtests.hunt_refleak is not None) and (not self.num_workers):
# gh-109739: WindowsLoadTracker thread interferes with refleak check
use_load_tracker = False
@@ -721,10 +723,7 @@ class Regrtest:
self._execute_python(cmd, environ)
def _init(self):
- # Set sys.stdout encoder error handler to backslashreplace,
- # similar to sys.stderr error handler, to avoid UnicodeEncodeError
- # when printing a traceback or any other non-encodable character.
- sys.stdout.reconfigure(errors="backslashreplace")
+ setup_process()
if self.junit_filename and not os.path.isabs(self.junit_filename):
self.junit_filename = os.path.abspath(self.junit_filename)
diff --git a/Lib/test/libregrtest/setup.py b/Lib/test/libregrtest/setup.py
index c0346aa934d..9bfc414cd61 100644
--- a/Lib/test/libregrtest/setup.py
+++ b/Lib/test/libregrtest/setup.py
@@ -1,5 +1,6 @@
import faulthandler
import gc
+import io
import os
import random
import signal
@@ -40,7 +41,7 @@ def setup_process() -> None:
faulthandler.enable(all_threads=True, file=stderr_fd)
# Display the Python traceback on SIGALRM or SIGUSR1 signal
- signals = []
+ signals: list[signal.Signals] = []
if hasattr(signal, 'SIGALRM'):
signals.append(signal.SIGALRM)
if hasattr(signal, 'SIGUSR1'):
@@ -52,6 +53,14 @@ def setup_process() -> None:
support.record_original_stdout(sys.stdout)
+ # Set sys.stdout encoder error handler to backslashreplace,
+ # similar to sys.stderr error handler, to avoid UnicodeEncodeError
+ # when printing a traceback or any other non-encodable character.
+ #
+ # Use an assertion to fix mypy error.
+ assert isinstance(sys.stdout, io.TextIOWrapper)
+ sys.stdout.reconfigure(errors="backslashreplace")
+
# Some times __path__ and __file__ are not absolute (e.g. while running from
# Lib/) and, if we change the CWD to run the tests in a temporary dir, some
# imports might fail. This affects only the modules imported before os.chdir().
diff --git a/Lib/test/libregrtest/single.py b/Lib/test/libregrtest/single.py
index 57d7b649d2e..958a915626a 100644
--- a/Lib/test/libregrtest/single.py
+++ b/Lib/test/libregrtest/single.py
@@ -283,7 +283,7 @@ def _runtest(result: TestResult, runtests: RunTests) -> None:
try:
setup_tests(runtests)
- if output_on_failure:
+ if output_on_failure or runtests.pgo:
support.verbose = True
stream = io.StringIO()
diff --git a/Lib/test/libregrtest/tsan.py b/Lib/test/libregrtest/tsan.py
index d984a735bdf..3545c5f999f 100644
--- a/Lib/test/libregrtest/tsan.py
+++ b/Lib/test/libregrtest/tsan.py
@@ -8,7 +8,7 @@ TSAN_TESTS = [
'test_capi.test_pyatomic',
'test_code',
'test_ctypes',
- # 'test_concurrent_futures', # gh-130605: too many data races
+ 'test_concurrent_futures',
'test_enum',
'test_functools',
'test_httpservers',
diff --git a/Lib/test/libregrtest/utils.py b/Lib/test/libregrtest/utils.py
index c4a1506c9a7..72b8ea89e62 100644
--- a/Lib/test/libregrtest/utils.py
+++ b/Lib/test/libregrtest/utils.py
@@ -31,7 +31,7 @@ WORKER_WORK_DIR_PREFIX = WORK_DIR_PREFIX + 'worker_'
EXIT_TIMEOUT = 120.0
-ALL_RESOURCES = ('audio', 'curses', 'largefile', 'network',
+ALL_RESOURCES = ('audio', 'console', 'curses', 'largefile', 'network',
'decimal', 'cpu', 'subprocess', 'urlfetch', 'gui', 'walltime')
# Other resources excluded from --use=all:
@@ -335,43 +335,11 @@ def get_build_info():
build.append('with_assert')
# --enable-experimental-jit
- tier2 = re.search('-D_Py_TIER2=([0-9]+)', cflags)
- if tier2:
- tier2 = int(tier2.group(1))
-
- if not sys.flags.ignore_environment:
- PYTHON_JIT = os.environ.get('PYTHON_JIT', None)
- if PYTHON_JIT:
- PYTHON_JIT = (PYTHON_JIT != '0')
- else:
- PYTHON_JIT = None
-
- if tier2 == 1: # =yes
- if PYTHON_JIT == False:
- jit = 'JIT=off'
- else:
- jit = 'JIT'
- elif tier2 == 3: # =yes-off
- if PYTHON_JIT:
- jit = 'JIT'
+ if sys._jit.is_available():
+ if sys._jit.is_enabled():
+ build.append("JIT")
else:
- jit = 'JIT=off'
- elif tier2 == 4: # =interpreter
- if PYTHON_JIT == False:
- jit = 'JIT-interpreter=off'
- else:
- jit = 'JIT-interpreter'
- elif tier2 == 6: # =interpreter-off (Secret option!)
- if PYTHON_JIT:
- jit = 'JIT-interpreter'
- else:
- jit = 'JIT-interpreter=off'
- elif '-D_Py_JIT' in cflags:
- jit = 'JIT'
- else:
- jit = None
- if jit:
- build.append(jit)
+ build.append("JIT (disabled)")
# --enable-framework=name
framework = sysconfig.get_config_var('PYTHONFRAMEWORK')
diff --git a/Lib/test/lock_tests.py b/Lib/test/lock_tests.py
index 009e04e9c0b..691029a1a54 100644
--- a/Lib/test/lock_tests.py
+++ b/Lib/test/lock_tests.py
@@ -124,6 +124,11 @@ class BaseLockTests(BaseTestCase):
lock = self.locktype()
del lock
+ def test_constructor_noargs(self):
+ self.assertRaises(TypeError, self.locktype, 1)
+ self.assertRaises(TypeError, self.locktype, x=1)
+ self.assertRaises(TypeError, self.locktype, 1, x=2)
+
def test_repr(self):
lock = self.locktype()
self.assertRegex(repr(lock), "<unlocked .* object (.*)?at .*>")
@@ -332,6 +337,26 @@ class RLockTests(BaseLockTests):
"""
Tests for recursive locks.
"""
+ def test_repr_count(self):
+ # see gh-134322: check that count values are correct:
+ # when a rlock is just created,
+ # in a second thread when rlock is acquired in the main thread.
+ lock = self.locktype()
+ self.assertIn("count=0", repr(lock))
+ self.assertIn("<unlocked", repr(lock))
+ lock.acquire()
+ lock.acquire()
+ self.assertIn("count=2", repr(lock))
+ self.assertIn("<locked", repr(lock))
+
+ result = []
+ def call_repr():
+ result.append(repr(lock))
+ with Bunch(call_repr, 1):
+ pass
+ self.assertIn("count=2", result[0])
+ self.assertIn("<locked", result[0])
+
def test_reacquire(self):
lock = self.locktype()
lock.acquire()
@@ -365,6 +390,24 @@ class RLockTests(BaseLockTests):
lock.release()
self.assertFalse(lock.locked())
+ def test_locked_with_2threads(self):
+ # see gh-134323: check that a rlock which
+ # is acquired in a different thread,
+ # is still locked in the main thread.
+ result = []
+ rlock = self.locktype()
+ self.assertFalse(rlock.locked())
+ def acquire():
+ result.append(rlock.locked())
+ rlock.acquire()
+ result.append(rlock.locked())
+
+ with Bunch(acquire, 1):
+ pass
+ self.assertTrue(rlock.locked())
+ self.assertFalse(result[0])
+ self.assertTrue(result[1])
+
def test_release_save_unacquired(self):
# Cannot _release_save an unacquired lock
lock = self.locktype()
diff --git a/Lib/test/mapping_tests.py b/Lib/test/mapping_tests.py
index 9d38da5a86e..20306e1526d 100644
--- a/Lib/test/mapping_tests.py
+++ b/Lib/test/mapping_tests.py
@@ -70,8 +70,8 @@ class BasicTestMappingProtocol(unittest.TestCase):
if not d: self.fail("Full mapping must compare to True")
# keys(), items(), iterkeys() ...
def check_iterandlist(iter, lst, ref):
- self.assertTrue(hasattr(iter, '__next__'))
- self.assertTrue(hasattr(iter, '__iter__'))
+ self.assertHasAttr(iter, '__next__')
+ self.assertHasAttr(iter, '__iter__')
x = list(iter)
self.assertTrue(set(x)==set(lst)==set(ref))
check_iterandlist(iter(d.keys()), list(d.keys()),
diff --git a/Lib/test/mp_preload_flush.py b/Lib/test/mp_preload_flush.py
new file mode 100644
index 00000000000..3501554d366
--- /dev/null
+++ b/Lib/test/mp_preload_flush.py
@@ -0,0 +1,15 @@
+import multiprocessing
+import sys
+
+modname = 'preloaded_module'
+if __name__ == '__main__':
+ if modname in sys.modules:
+ raise AssertionError(f'{modname!r} is not in sys.modules')
+ multiprocessing.set_start_method('forkserver')
+ multiprocessing.set_forkserver_preload([modname])
+ for _ in range(2):
+ p = multiprocessing.Process()
+ p.start()
+ p.join()
+elif modname not in sys.modules:
+ raise AssertionError(f'{modname!r} is not in sys.modules')
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py
index bdc7ef62943..9a3a26a8400 100644
--- a/Lib/test/pickletester.py
+++ b/Lib/test/pickletester.py
@@ -1100,6 +1100,11 @@ class AbstractUnpickleTests:
self.check_unpickling_error((pickle.UnpicklingError, OverflowError),
dumped)
+ def test_large_binstring(self):
+ errmsg = 'BINSTRING pickle has negative byte count'
+ with self.assertRaisesRegex(pickle.UnpicklingError, errmsg):
+ self.loads(b'T\0\0\0\x80')
+
def test_get(self):
pickled = b'((lp100000\ng100000\nt.'
unpickled = self.loads(pickled)
@@ -2272,7 +2277,11 @@ class AbstractPicklingErrorTests:
def test_nested_lookup_error(self):
# Nested name does not exist
- obj = REX('AbstractPickleTests.spam')
+ global TestGlobal
+ class TestGlobal:
+ class A:
+ pass
+ obj = REX('TestGlobal.A.B.C')
obj.__module__ = __name__
for proto in protocols:
with self.subTest(proto=proto):
@@ -2280,9 +2289,9 @@ class AbstractPicklingErrorTests:
self.dumps(obj, proto)
self.assertEqual(str(cm.exception),
f"Can't pickle {obj!r}: "
- f"it's not found as {__name__}.AbstractPickleTests.spam")
+ f"it's not found as {__name__}.TestGlobal.A.B.C")
self.assertEqual(str(cm.exception.__context__),
- "type object 'AbstractPickleTests' has no attribute 'spam'")
+ "type object 'A' has no attribute 'B'")
obj.__module__ = None
for proto in protocols:
@@ -2290,21 +2299,25 @@ class AbstractPicklingErrorTests:
with self.assertRaises(pickle.PicklingError) as cm:
self.dumps(obj, proto)
self.assertEqual(str(cm.exception),
- f"Can't pickle {obj!r}: it's not found as __main__.AbstractPickleTests.spam")
+ f"Can't pickle {obj!r}: "
+ f"it's not found as __main__.TestGlobal.A.B.C")
self.assertEqual(str(cm.exception.__context__),
- "module '__main__' has no attribute 'AbstractPickleTests'")
+ "module '__main__' has no attribute 'TestGlobal'")
def test_wrong_object_lookup_error(self):
# Name is bound to different object
- obj = REX('AbstractPickleTests')
+ global TestGlobal
+ class TestGlobal:
+ pass
+ obj = REX('TestGlobal')
obj.__module__ = __name__
- AbstractPickleTests.ham = []
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError) as cm:
self.dumps(obj, proto)
self.assertEqual(str(cm.exception),
- f"Can't pickle {obj!r}: it's not the same object as {__name__}.AbstractPickleTests")
+ f"Can't pickle {obj!r}: "
+ f"it's not the same object as {__name__}.TestGlobal")
self.assertIsNone(cm.exception.__context__)
obj.__module__ = None
@@ -2313,9 +2326,10 @@ class AbstractPicklingErrorTests:
with self.assertRaises(pickle.PicklingError) as cm:
self.dumps(obj, proto)
self.assertEqual(str(cm.exception),
- f"Can't pickle {obj!r}: it's not found as __main__.AbstractPickleTests")
+ f"Can't pickle {obj!r}: "
+ f"it's not found as __main__.TestGlobal")
self.assertEqual(str(cm.exception.__context__),
- "module '__main__' has no attribute 'AbstractPickleTests'")
+ "module '__main__' has no attribute 'TestGlobal'")
def test_local_lookup_error(self):
# Test that whichmodule() errors out cleanly when looking up
@@ -3059,7 +3073,7 @@ class AbstractPickleTests:
pickled = self.dumps(None, proto)
if proto >= 2:
proto_header = pickle.PROTO + bytes([proto])
- self.assertTrue(pickled.startswith(proto_header))
+ self.assertStartsWith(pickled, proto_header)
else:
self.assertEqual(count_opcode(pickle.PROTO, pickled), 0)
@@ -4998,7 +5012,7 @@ class AbstractDispatchTableTests:
p = self.pickler_class(f, 0)
with self.assertRaises(AttributeError):
p.dispatch_table
- self.assertFalse(hasattr(p, 'dispatch_table'))
+ self.assertNotHasAttr(p, 'dispatch_table')
def test_class_dispatch_table(self):
# A dispatch_table attribute can be specified class-wide
diff --git a/Lib/test/pythoninfo.py b/Lib/test/pythoninfo.py
index 682815c3fdd..80a262c18a5 100644
--- a/Lib/test/pythoninfo.py
+++ b/Lib/test/pythoninfo.py
@@ -658,6 +658,16 @@ def collect_zlib(info_add):
copy_attributes(info_add, zlib, 'zlib.%s', attributes)
+def collect_zstd(info_add):
+ try:
+ import _zstd
+ except ImportError:
+ return
+
+ attributes = ('zstd_version',)
+ copy_attributes(info_add, _zstd, 'zstd.%s', attributes)
+
+
def collect_expat(info_add):
try:
from xml.parsers import expat
@@ -910,10 +920,17 @@ def collect_windows(info_add):
try:
import _winapi
- dll_path = _winapi.GetModuleFileName(sys.dllhandle)
- info_add('windows.dll_path', dll_path)
- except (ImportError, AttributeError):
+ except ImportError:
pass
+ else:
+ try:
+ dll_path = _winapi.GetModuleFileName(sys.dllhandle)
+ info_add('windows.dll_path', dll_path)
+ except AttributeError:
+ pass
+
+ call_func(info_add, 'windows.ansi_code_page', _winapi, 'GetACP')
+ call_func(info_add, 'windows.oem_code_page', _winapi, 'GetOEMCP')
# windows.version_caption: "wmic os get Caption,Version /value" command
import subprocess
@@ -1051,6 +1068,7 @@ def collect_info(info):
collect_tkinter,
collect_windows,
collect_zlib,
+ collect_zstd,
collect_libregrtest_utils,
# Collecting from tests should be last as they have side effects.
diff --git a/Lib/test/re_tests.py b/Lib/test/re_tests.py
index 85b026736ca..e50f5d52bbd 100755
--- a/Lib/test/re_tests.py
+++ b/Lib/test/re_tests.py
@@ -531,7 +531,7 @@ xyzabc
(r'a[ ]*?\ (\d+).*', 'a 10', SUCCEED, 'found', 'a 10'),
(r'a[ ]*?\ (\d+).*', 'a 10', SUCCEED, 'found', 'a 10'),
# bug 127259: \Z shouldn't depend on multiline mode
- (r'(?ms).*?x\s*\Z(.*)','xx\nx\n', SUCCEED, 'g1', ''),
+ (r'(?ms).*?x\s*\z(.*)','xx\nx\n', SUCCEED, 'g1', ''),
# bug 128899: uppercase literals under the ignorecase flag
(r'(?i)M+', 'MMM', SUCCEED, 'found', 'MMM'),
(r'(?i)m+', 'MMM', SUCCEED, 'found', 'MMM'),
diff --git a/Lib/test/subprocessdata/fd_status.py b/Lib/test/subprocessdata/fd_status.py
index d12bd95abee..90e785981ae 100644
--- a/Lib/test/subprocessdata/fd_status.py
+++ b/Lib/test/subprocessdata/fd_status.py
@@ -2,7 +2,7 @@
file descriptors on stdout.
Usage:
-fd_stats.py: check all file descriptors
+fd_status.py: check all file descriptors (up to 255)
fd_status.py fd1 fd2 ...: check only specified file descriptors
"""
@@ -18,7 +18,7 @@ if __name__ == "__main__":
_MAXFD = os.sysconf("SC_OPEN_MAX")
except:
_MAXFD = 256
- test_fds = range(0, _MAXFD)
+ test_fds = range(0, min(_MAXFD, 256))
else:
test_fds = map(int, sys.argv[1:])
for fd in test_fds:
diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py
index 82f88109498..fd39d3f7c95 100644
--- a/Lib/test/support/__init__.py
+++ b/Lib/test/support/__init__.py
@@ -33,7 +33,7 @@ __all__ = [
"is_resource_enabled", "requires", "requires_freebsd_version",
"requires_gil_enabled", "requires_linux_version", "requires_mac_ver",
"check_syntax_error",
- "requires_gzip", "requires_bz2", "requires_lzma",
+ "requires_gzip", "requires_bz2", "requires_lzma", "requires_zstd",
"bigmemtest", "bigaddrspacetest", "cpython_only", "get_attribute",
"requires_IEEE_754", "requires_zlib",
"has_fork_support", "requires_fork",
@@ -46,6 +46,7 @@ __all__ = [
# sys
"MS_WINDOWS", "is_jython", "is_android", "is_emscripten", "is_wasi",
"is_apple_mobile", "check_impl_detail", "unix_shell", "setswitchinterval",
+ "support_remote_exec_only",
# os
"get_pagesize",
# network
@@ -527,6 +528,13 @@ def requires_lzma(reason='requires lzma'):
lzma = None
return unittest.skipUnless(lzma, reason)
+def requires_zstd(reason='requires zstd'):
+ try:
+ from compression import zstd
+ except ImportError:
+ zstd = None
+ return unittest.skipUnless(zstd, reason)
+
def has_no_debug_ranges():
try:
import _testcapi
@@ -689,9 +697,11 @@ def sortdict(dict):
return "{%s}" % withcommas
-def run_code(code: str) -> dict[str, object]:
+def run_code(code: str, extra_names: dict[str, object] | None = None) -> dict[str, object]:
"""Run a piece of code after dedenting it, and return its global namespace."""
ns = {}
+ if extra_names:
+ ns.update(extra_names)
exec(textwrap.dedent(code), ns)
return ns
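
(As an illustration, hypothetical and not part of this diff: with the new
extra_names parameter, run_code() can seed the namespace the code is
executed in.)

    from test import support

    ns = support.run_code("""
        result = factor * 2
    """, extra_names={"factor": 21})
    assert ns["result"] == 42
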
@@ -936,6 +946,31 @@ def check_sizeof(test, o, size):
% (type(o), result, size)
test.assertEqual(result, size, msg)
+def subTests(arg_names, arg_values, /, *, _do_cleanups=False):
+ """Run multiple subtests with different parameters.
+ """
+ single_param = False
+ if isinstance(arg_names, str):
+ arg_names = arg_names.replace(',',' ').split()
+ if len(arg_names) == 1:
+ single_param = True
+ arg_values = tuple(arg_values)
+ def decorator(func):
+ if isinstance(func, type):
+ raise TypeError('subTests() can only decorate methods, not classes')
+ @functools.wraps(func)
+ def wrapper(self, /, *args, **kwargs):
+ for values in arg_values:
+ if single_param:
+ values = (values,)
+ subtest_kwargs = dict(zip(arg_names, values))
+ with self.subTest(**subtest_kwargs):
+ func(self, *args, **kwargs, **subtest_kwargs)
+ if _do_cleanups:
+ self.doCleanups()
+ return wrapper
+ return decorator
+
#=======================================================================
# Decorator/context manager for running a code in a different locale,
# correctly resetting it afterwards.
@@ -1075,7 +1110,7 @@ def set_memlimit(limit: str) -> None:
global real_max_memuse
memlimit = _parse_memlimit(limit)
if memlimit < _2G - 1:
- raise ValueError('Memory limit {limit!r} too low to be useful')
+ raise ValueError(f'Memory limit {limit!r} too low to be useful')
real_max_memuse = memlimit
memlimit = min(memlimit, MAX_Py_ssize_t)
@@ -1092,7 +1127,6 @@ class _MemoryWatchdog:
self.started = False
def start(self):
- import warnings
try:
f = open(self.procfile, 'r')
except OSError as e:
@@ -1332,8 +1366,8 @@ MISSING_C_DOCSTRINGS = (check_impl_detail() and
sys.platform != 'win32' and
not sysconfig.get_config_var('WITH_DOC_STRINGS'))
-HAVE_DOCSTRINGS = (_check_docstrings.__doc__ is not None and
- not MISSING_C_DOCSTRINGS)
+HAVE_PY_DOCSTRINGS = _check_docstrings.__doc__ is not None
+HAVE_DOCSTRINGS = (HAVE_PY_DOCSTRINGS and not MISSING_C_DOCSTRINGS)
requires_docstrings = unittest.skipUnless(HAVE_DOCSTRINGS,
"test requires docstrings")
@@ -2299,6 +2333,7 @@ def check_disallow_instantiation(testcase, tp, *args, **kwds):
qualname = f"{name}"
msg = f"cannot create '{re.escape(qualname)}' instances"
testcase.assertRaisesRegex(TypeError, msg, tp, *args, **kwds)
+ testcase.assertRaisesRegex(TypeError, msg, tp.__new__, tp, *args, **kwds)
def get_recursion_depth():
"""Get the recursion depth of the caller function.
@@ -2350,7 +2385,7 @@ def infinite_recursion(max_depth=None):
# very deep recursion.
max_depth = 20_000
elif max_depth < 3:
- raise ValueError("max_depth must be at least 3, got {max_depth}")
+ raise ValueError(f"max_depth must be at least 3, got {max_depth}")
depth = get_recursion_depth()
depth = max(depth - 1, 1) # Ignore infinite_recursion() frame.
limit = depth + max_depth
@@ -2648,13 +2683,9 @@ skip_on_s390x = unittest.skipIf(is_s390x, 'skipped on s390x')
Py_TRACE_REFS = hasattr(sys, 'getobjects')
-try:
- from _testinternalcapi import jit_enabled
-except ImportError:
- requires_jit_enabled = requires_jit_disabled = unittest.skip("requires _testinternalcapi")
-else:
- requires_jit_enabled = unittest.skipUnless(jit_enabled(), "requires JIT enabled")
- requires_jit_disabled = unittest.skipIf(jit_enabled(), "requires JIT disabled")
+_JIT_ENABLED = sys._jit.is_enabled()
+requires_jit_enabled = unittest.skipUnless(_JIT_ENABLED, "requires JIT enabled")
+requires_jit_disabled = unittest.skipIf(_JIT_ENABLED, "requires JIT disabled")
_BASE_COPY_SRC_DIR_IGNORED_NAMES = frozenset({
@@ -2723,7 +2754,7 @@ def iter_builtin_types():
# Fall back to making a best-effort guess.
if hasattr(object, '__flags__'):
# Look for any type object with the Py_TPFLAGS_STATIC_BUILTIN flag set.
- import datetime
+ import datetime # noqa: F401
seen = set()
for cls, subs in walk_class_hierarchy(object):
if cls in seen:
@@ -2855,36 +2886,59 @@ def iter_slot_wrappers(cls):
@contextlib.contextmanager
-def no_color():
+def force_color(color: bool):
import _colorize
from .os_helper import EnvironmentVarGuard
with (
- swap_attr(_colorize, "can_colorize", lambda file=None: False),
+ swap_attr(_colorize, "can_colorize", lambda file=None: color),
EnvironmentVarGuard() as env,
):
env.unset("FORCE_COLOR", "NO_COLOR", "PYTHON_COLORS")
- env.set("NO_COLOR", "1")
+ env.set("FORCE_COLOR" if color else "NO_COLOR", "1")
yield
+def force_colorized(func):
+ """Force the terminal to be colorized."""
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ with force_color(True):
+ return func(*args, **kwargs)
+ return wrapper
+
+
def force_not_colorized(func):
- """Force the terminal not to be colorized."""
+ """Force the terminal NOT to be colorized."""
@functools.wraps(func)
def wrapper(*args, **kwargs):
- with no_color():
+ with force_color(False):
return func(*args, **kwargs)
return wrapper
+def force_colorized_test_class(cls):
+ """Force the terminal to be colorized for the entire test class."""
+ original_setUpClass = cls.setUpClass
+
+ @classmethod
+ @functools.wraps(cls.setUpClass)
+ def new_setUpClass(cls):
+ cls.enterClassContext(force_color(True))
+ original_setUpClass()
+
+ cls.setUpClass = new_setUpClass
+ return cls
+
+
def force_not_colorized_test_class(cls):
- """Force the terminal not to be colorized for the entire test class."""
+ """Force the terminal NOT to be colorized for the entire test class."""
original_setUpClass = cls.setUpClass
@classmethod
@functools.wraps(cls.setUpClass)
def new_setUpClass(cls):
- cls.enterClassContext(no_color())
+ cls.enterClassContext(force_color(False))
original_setUpClass()
cls.setUpClass = new_setUpClass
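
(As an illustration, hypothetical and not part of this diff: the reworked
helpers can now force colorized output on as well as off.)

    import unittest
    from test import support

    class ColorReprTests(unittest.TestCase):
        @support.force_colorized
        def test_with_color(self):
            # Inside this test _colorize.can_colorize() reports True and
            # FORCE_COLOR=1 is set in the environment.
            ...

        @support.force_not_colorized
        def test_without_color(self):
            # Same mechanism, with NO_COLOR=1 and colorization reported off.
            ...
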
@@ -2901,12 +2955,6 @@ def make_clean_env() -> dict[str, str]:
return clean_env
-def initialized_with_pyrepl():
- """Detect whether PyREPL was used during Python initialization."""
- # If the main module has a __file__ attribute it's a Python module, which means PyREPL.
- return hasattr(sys.modules["__main__"], "__file__")
-
-
WINDOWS_STATUS = {
0xC0000005: "STATUS_ACCESS_VIOLATION",
0xC00000FD: "STATUS_STACK_OVERFLOW",
@@ -3023,6 +3071,27 @@ def is_libssl_fips_mode():
return False # more of a maybe, unless we add this to the _ssl module.
return get_fips_mode() != 0
+def _supports_remote_attaching():
+ PROCESS_VM_READV_SUPPORTED = False
+
+ try:
+ from _remote_debugging import PROCESS_VM_READV_SUPPORTED
+ except ImportError:
+ pass
+
+ return PROCESS_VM_READV_SUPPORTED
+
+def _support_remote_exec_only_impl():
+ if not sys.is_remote_debug_enabled():
+ return unittest.skip("Remote debugging is not enabled")
+ if sys.platform not in ("darwin", "linux", "win32"):
+ return unittest.skip("Test only runs on Linux, Windows and macOS")
+ if sys.platform == "linux" and not _supports_remote_attaching():
+ return unittest.skip("Test only runs on Linux with process_vm_readv support")
+ return _id
+
+def support_remote_exec_only(test):
+ return _support_remote_exec_only_impl()(test)
class EqualToForwardRef:
"""Helper to ease use of annotationlib.ForwardRef in tests.
diff --git a/Lib/test/support/ast_helper.py b/Lib/test/support/ast_helper.py
index 8a0415b6aae..173d299afee 100644
--- a/Lib/test/support/ast_helper.py
+++ b/Lib/test/support/ast_helper.py
@@ -16,6 +16,9 @@ class ASTTestMixin:
self.fail(f"{type(a)!r} is not {type(b)!r}")
if isinstance(a, ast.AST):
for field in a._fields:
+ if isinstance(a, ast.Constant) and field == "kind":
+ # Skip the 'kind' field for ast.Constant
+ continue
value1 = getattr(a, field, missing)
value2 = getattr(b, field, missing)
# Singletons are equal by definition, so further
diff --git a/Lib/test/support/interpreters/channels.py b/Lib/test/support/channels.py
index d2bd93d77f7..b2de24d9d3e 100644
--- a/Lib/test/support/interpreters/channels.py
+++ b/Lib/test/support/channels.py
@@ -2,14 +2,14 @@
import time
import _interpchannels as _channels
-from . import _crossinterp
+from concurrent.interpreters import _crossinterp
# aliases:
from _interpchannels import (
- ChannelError, ChannelNotFoundError, ChannelClosedError,
- ChannelEmptyError, ChannelNotEmptyError,
+ ChannelError, ChannelNotFoundError, ChannelClosedError, # noqa: F401
+ ChannelEmptyError, ChannelNotEmptyError, # noqa: F401
)
-from ._crossinterp import (
+from concurrent.interpreters._crossinterp import (
UNBOUND_ERROR, UNBOUND_REMOVE,
)
@@ -55,15 +55,23 @@ def create(*, unbounditems=UNBOUND):
"""
unbound = _serialize_unbound(unbounditems)
unboundop, = unbound
- cid = _channels.create(unboundop)
- recv, send = RecvChannel(cid), SendChannel(cid, _unbound=unbound)
+ cid = _channels.create(unboundop, -1)
+ recv, send = RecvChannel(cid), SendChannel(cid)
+ send._set_unbound(unboundop, unbounditems)
return recv, send
def list_all():
"""Return a list of (recv, send) for all open channels."""
- return [(RecvChannel(cid), SendChannel(cid, _unbound=unbound))
- for cid, unbound in _channels.list_all()]
+ channels = []
+ for cid, unboundop, _ in _channels.list_all():
+ chan = _, send = RecvChannel(cid), SendChannel(cid)
+ if not hasattr(send, '_unboundop'):
+ send._set_unbound(unboundop)
+ else:
+ assert send._unbound[0] == unboundop
+ channels.append(chan)
+ return channels
class _ChannelEnd:
@@ -175,16 +183,33 @@ class SendChannel(_ChannelEnd):
_end = 'send'
- def __new__(cls, cid, *, _unbound=None):
- if _unbound is None:
- try:
- op = _channels.get_channel_defaults(cid)
- _unbound = (op,)
- except ChannelNotFoundError:
- _unbound = _serialize_unbound(UNBOUND)
- self = super().__new__(cls, cid)
- self._unbound = _unbound
- return self
+# def __new__(cls, cid, *, _unbound=None):
+# if _unbound is None:
+# try:
+# op = _channels.get_channel_defaults(cid)
+# _unbound = (op,)
+# except ChannelNotFoundError:
+# _unbound = _serialize_unbound(UNBOUND)
+# self = super().__new__(cls, cid)
+# self._unbound = _unbound
+# return self
+
+ def _set_unbound(self, op, items=None):
+ assert not hasattr(self, '_unbound')
+ if items is None:
+ items = _resolve_unbound(op)
+ unbound = (op, items)
+ self._unbound = unbound
+ return unbound
+
+ @property
+ def unbounditems(self):
+ try:
+ _, items = self._unbound
+ except AttributeError:
+ op, _ = _channels.get_queue_defaults(self._id)
+ _, items = self._set_unbound(op)
+ return items
@property
def is_closed(self):
@@ -192,61 +217,61 @@ class SendChannel(_ChannelEnd):
return info.closed or info.closing
def send(self, obj, timeout=None, *,
- unbound=None,
+ unbounditems=None,
):
"""Send the object (i.e. its data) to the channel's receiving end.
This blocks until the object is received.
"""
- if unbound is None:
- unboundop, = self._unbound
+ if unbounditems is None:
+ unboundop = -1
else:
- unboundop, = _serialize_unbound(unbound)
+ unboundop, = _serialize_unbound(unbounditems)
_channels.send(self._id, obj, unboundop, timeout=timeout, blocking=True)
def send_nowait(self, obj, *,
- unbound=None,
+ unbounditems=None,
):
"""Send the object to the channel's receiving end.
If the object is immediately received then return True
(else False). Otherwise this is the same as send().
"""
- if unbound is None:
- unboundop, = self._unbound
+ if unbounditems is None:
+ unboundop = -1
else:
- unboundop, = _serialize_unbound(unbound)
+ unboundop, = _serialize_unbound(unbounditems)
# XXX Note that at the moment channel_send() only ever returns
# None. This should be fixed when channel_send_wait() is added.
# See bpo-32604 and gh-19829.
return _channels.send(self._id, obj, unboundop, blocking=False)
def send_buffer(self, obj, timeout=None, *,
- unbound=None,
+ unbounditems=None,
):
"""Send the object's buffer to the channel's receiving end.
This blocks until the object is received.
"""
- if unbound is None:
- unboundop, = self._unbound
+ if unbounditems is None:
+ unboundop = -1
else:
- unboundop, = _serialize_unbound(unbound)
+ unboundop, = _serialize_unbound(unbounditems)
_channels.send_buffer(self._id, obj, unboundop,
timeout=timeout, blocking=True)
def send_buffer_nowait(self, obj, *,
- unbound=None,
+ unbounditems=None,
):
"""Send the object's buffer to the channel's receiving end.
If the object is immediately received then return True
(else False). Otherwise this is the same as send().
"""
- if unbound is None:
- unboundop, = self._unbound
+ if unbounditems is None:
+ unboundop = -1
else:
- unboundop, = _serialize_unbound(unbound)
+ unboundop, = _serialize_unbound(unbounditems)
return _channels.send_buffer(self._id, obj, unboundop, blocking=False)
def close(self):
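The hunks above rename the per-call "unbound" keyword to "unbounditems" and fall back to the channel-level default when it is omitted. A brief sketch of the new spelling, not part of the patch (payload values are illustrative), assuming the relocated test.support.channels module:

    from test.support import channels

    recv, send = channels.create(unbounditems=channels.UNBOUND_REMOVE)
    send.send_nowait('spam')                                       # channel default
    send.send_nowait('eggs', unbounditems=channels.UNBOUND_ERROR)  # per-call override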
diff --git a/Lib/test/support/hashlib_helper.py b/Lib/test/support/hashlib_helper.py
index 5043f08dd93..7032257b068 100644
--- a/Lib/test/support/hashlib_helper.py
+++ b/Lib/test/support/hashlib_helper.py
@@ -23,6 +23,22 @@ def requires_builtin_hmac():
return unittest.skipIf(_hmac is None, "requires _hmac")
+def _missing_hash(digestname, implementation=None, *, exc=None):
+ parts = ["missing", implementation, f"hash algorithm: {digestname!r}"]
+ msg = " ".join(filter(None, parts))
+ raise unittest.SkipTest(msg) from exc
+
+
+def _openssl_availabillity(digestname, *, usedforsecurity):
+ try:
+ _hashlib.new(digestname, usedforsecurity=usedforsecurity)
+ except AttributeError:
+ assert _hashlib is None
+ _missing_hash(digestname, "OpenSSL")
+ except ValueError as exc:
+ _missing_hash(digestname, "OpenSSL", exc=exc)
+
+
def _decorate_func_or_class(func_or_class, decorator_func):
if not isinstance(func_or_class, type):
return decorator_func(func_or_class)
@@ -71,8 +87,7 @@ def requires_hashdigest(digestname, openssl=None, usedforsecurity=True):
try:
test_availability()
except ValueError as exc:
- msg = f"missing hash algorithm: {digestname!r}"
- raise unittest.SkipTest(msg) from exc
+ _missing_hash(digestname, exc=exc)
return func(*args, **kwargs)
return wrapper
@@ -87,14 +102,44 @@ def requires_openssl_hashdigest(digestname, *, usedforsecurity=True):
The hashing algorithm may be missing or blocked by a strict crypto policy.
"""
def decorator_func(func):
- @requires_hashlib()
+ @requires_hashlib() # avoid checking at each call
@functools.wraps(func)
def wrapper(*args, **kwargs):
+ _openssl_availabillity(digestname, usedforsecurity=usedforsecurity)
+ return func(*args, **kwargs)
+ return wrapper
+
+ def decorator(func_or_class):
+ return _decorate_func_or_class(func_or_class, decorator_func)
+ return decorator
+
+
+def find_openssl_hashdigest_constructor(digestname, *, usedforsecurity=True):
+ """Find the OpenSSL hash function constructor by its name."""
+ assert isinstance(digestname, str), digestname
+ _openssl_availabillity(digestname, usedforsecurity=usedforsecurity)
+ # This returns a function of the form _hashlib.openssl_<name> and
+ # not a lambda function as it is rejected by _hashlib.hmac_new().
+ return getattr(_hashlib, f"openssl_{digestname}")
+
+
+def requires_builtin_hashdigest(
+ module_name, digestname, *, usedforsecurity=True
+):
+ """Decorator raising SkipTest if a HACL* hashing algorithm is missing.
+
+ - The *module_name* is the C extension module name based on HACL*.
+ - The *digestname* is one of its members, e.g., 'md5'.
+ """
+ def decorator_func(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ module = import_module(module_name)
try:
- _hashlib.new(digestname, usedforsecurity=usedforsecurity)
- except ValueError:
- msg = f"missing OpenSSL hash algorithm: {digestname!r}"
- raise unittest.SkipTest(msg)
+ getattr(module, digestname)
+ except AttributeError:
+ fullname = f'{module_name}.{digestname}'
+ _missing_hash(fullname, implementation="HACL")
return func(*args, **kwargs)
return wrapper
@@ -103,6 +148,168 @@ def requires_openssl_hashdigest(digestname, *, usedforsecurity=True):
return decorator
+def find_builtin_hashdigest_constructor(
+ module_name, digestname, *, usedforsecurity=True
+):
+ """Find the HACL* hash function constructor.
+
+ - The *module_name* is the C extension module name based on HACL*.
+ - The *digestname* is one of its members, e.g., 'md5'.
+ """
+ module = import_module(module_name)
+ try:
+ constructor = getattr(module, digestname)
+ constructor(b'', usedforsecurity=usedforsecurity)
+ except (AttributeError, TypeError, ValueError):
+ _missing_hash(f'{module_name}.{digestname}', implementation="HACL")
+ return constructor
+
+
+class HashFunctionsTrait:
+ """Mixin trait class containing hash functions.
+
+ This class is assumed to have all unittest.TestCase methods but should
+ not directly inherit from it, to prevent the test suite from being run on it.
+
+ Subclasses should implement the hash functions by returning an object
+ that can be recognized as a valid digestmod parameter for both hashlib
+ and HMAC. In particular, it cannot be a lambda function as it will not
+ be recognized by hashlib (it will still be accepted by the pure Python
+ implementation of HMAC).
+ """
+
+ ALGORITHMS = [
+ 'md5', 'sha1',
+ 'sha224', 'sha256', 'sha384', 'sha512',
+ 'sha3_224', 'sha3_256', 'sha3_384', 'sha3_512',
+ ]
+
+ # Default 'usedforsecurity' to use when looking up a hash function.
+ usedforsecurity = True
+
+ def _find_constructor(self, name):
+ # By default, a missing algorithm skips the test that uses it.
+ self.assertIn(name, self.ALGORITHMS)
+ self.skipTest(f"missing hash function: {name}")
+
+ @property
+ def md5(self):
+ return self._find_constructor("md5")
+
+ @property
+ def sha1(self):
+ return self._find_constructor("sha1")
+
+ @property
+ def sha224(self):
+ return self._find_constructor("sha224")
+
+ @property
+ def sha256(self):
+ return self._find_constructor("sha256")
+
+ @property
+ def sha384(self):
+ return self._find_constructor("sha384")
+
+ @property
+ def sha512(self):
+ return self._find_constructor("sha512")
+
+ @property
+ def sha3_224(self):
+ return self._find_constructor("sha3_224")
+
+ @property
+ def sha3_256(self):
+ return self._find_constructor("sha3_256")
+
+ @property
+ def sha3_384(self):
+ return self._find_constructor("sha3_384")
+
+ @property
+ def sha3_512(self):
+ return self._find_constructor("sha3_512")
+
+
+class NamedHashFunctionsTrait(HashFunctionsTrait):
+ """Trait containing named hash functions.
+
+ Hash functions are available if and only if they are available in hashlib.
+ """
+
+ def _find_constructor(self, name):
+ self.assertIn(name, self.ALGORITHMS)
+ return name
+
+
+class OpenSSLHashFunctionsTrait(HashFunctionsTrait):
+ """Trait containing OpenSSL hash functions.
+
+ Hash functions are available if and only if they are available in _hashlib.
+ """
+
+ def _find_constructor(self, name):
+ self.assertIn(name, self.ALGORITHMS)
+ return find_openssl_hashdigest_constructor(
+ name, usedforsecurity=self.usedforsecurity
+ )
+
+
+class BuiltinHashFunctionsTrait(HashFunctionsTrait):
+ """Trait containing HACL* hash functions.
+
+ Hash functions are available if and only if they are available in C.
+ In particular, HACL* HMAC-MD5 may be available even though HACL* md5
+ is not since the former is unconditionally built.
+ """
+
+ def _find_constructor_in(self, module, name):
+ self.assertIn(name, self.ALGORITHMS)
+ return find_builtin_hashdigest_constructor(module, name)
+
+ @property
+ def md5(self):
+ return self._find_constructor_in("_md5", "md5")
+
+ @property
+ def sha1(self):
+ return self._find_constructor_in("_sha1", "sha1")
+
+ @property
+ def sha224(self):
+ return self._find_constructor_in("_sha2", "sha224")
+
+ @property
+ def sha256(self):
+ return self._find_constructor_in("_sha2", "sha256")
+
+ @property
+ def sha384(self):
+ return self._find_constructor_in("_sha2", "sha384")
+
+ @property
+ def sha512(self):
+ return self._find_constructor_in("_sha2", "sha512")
+
+ @property
+ def sha3_224(self):
+ return self._find_constructor_in("_sha3", "sha3_224")
+
+ @property
+ def sha3_256(self):
+ return self._find_constructor_in("_sha3","sha3_256")
+
+ @property
+ def sha3_384(self):
+ return self._find_constructor_in("_sha3","sha3_384")
+
+ @property
+ def sha3_512(self):
+ return self._find_constructor_in("_sha3","sha3_512")
+
+
def find_gil_minsize(modules_names, default=2048):
"""Get the largest GIL_MINSIZE value for the given cryptographic modules.
diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py
index 42cfe9cfa8c..0af63501f93 100644
--- a/Lib/test/support/import_helper.py
+++ b/Lib/test/support/import_helper.py
@@ -1,6 +1,7 @@
import contextlib
import _imp
import importlib
+import importlib.machinery
import importlib.util
import os
import shutil
@@ -332,3 +333,110 @@ def ensure_lazy_imports(imported_module, modules_to_block):
)
from .script_helper import assert_python_ok
assert_python_ok("-S", "-c", script)
+
+
+@contextlib.contextmanager
+def module_restored(name):
+ """A context manager that restores a module to the original state."""
+ missing = object()
+ orig = sys.modules.get(name, missing)
+ if orig is None:
+ mod = importlib.import_module(name)
+ else:
+ mod = type(sys)(name)
+ mod.__dict__.update(orig.__dict__)
+ sys.modules[name] = mod
+ try:
+ yield mod
+ finally:
+ if orig is missing:
+ sys.modules.pop(name, None)
+ else:
+ sys.modules[name] = orig
+
+
+def create_module(name, loader=None, *, ispkg=False):
+ """Return a new, empty module."""
+ spec = importlib.machinery.ModuleSpec(
+ name,
+ loader,
+ origin='<import_helper>',
+ is_package=ispkg,
+ )
+ return importlib.util.module_from_spec(spec)
+
+
+def _ensure_module(name, ispkg, addparent, clearnone):
+ try:
+ mod = orig = sys.modules[name]
+ except KeyError:
+ mod = orig = None
+ missing = True
+ else:
+ missing = False
+ if mod is not None:
+ # It was already imported.
+ return mod, orig, missing
+ # Otherwise, None means it was explicitly disabled.
+
+ assert name != '__main__'
+ if not missing:
+ assert orig is None, (name, sys.modules[name])
+ if not clearnone:
+ raise ModuleNotFoundError(name)
+ del sys.modules[name]
+ # Try normal import, then fall back to adding the module.
+ try:
+ mod = importlib.import_module(name)
+ except ModuleNotFoundError:
+ if addparent and not clearnone:
+ addparent = None
+ mod = _add_module(name, ispkg, addparent)
+ return mod, orig, missing
+
+
+def _add_module(spec, ispkg, addparent):
+ if isinstance(spec, str):
+ name = spec
+ mod = create_module(name, ispkg=ispkg)
+ spec = mod.__spec__
+ else:
+ name = spec.name
+ mod = importlib.util.module_from_spec(spec)
+ sys.modules[name] = mod
+ if addparent is not False and spec.parent:
+ _ensure_module(spec.parent, True, addparent, bool(addparent))
+ return mod
+
+
+def add_module(spec, *, parents=True):
+ """Return the module after creating it and adding it to sys.modules.
+
+ If parents is True then also create any missing parents.
+ """
+ return _add_module(spec, False, parents)
+
+
+def add_package(spec, *, parents=True):
+ """Return the module after creating it and adding it to sys.modules.
+
+ If parents is True then also create any missing parents.
+ """
+ return _add_module(spec, True, parents)
+
+
+def ensure_module_imported(name, *, clearnone=True):
+ """Return the corresponding module.
+
+ If it was already imported then return that. Otherwise, try
+ importing it (optionally clear it first if None). If that fails
+ then create a new empty module.
+
+ It can be helpful to combine this with ready_to_import() and/or
+ isolated_modules().
+ """
+ if sys.modules.get(name) is not None:
+ mod = sys.modules[name]
+ else:
+ mod, _, _ = _ensure_module(name, False, True, clearnone)
+ return mod
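The helpers added above cover two common test needs: patching a throwaway copy of an already-imported module, and injecting a synthetic module into sys.modules. A brief sketch, not part of the patch (module names are hypothetical):

    import sys
    import string  # make sure the target module is imported first
    from test.support import import_helper

    # Mutate a copy of an imported module; the original is restored on exit.
    with import_helper.module_restored('string') as mod:
        mod.capwords = lambda s: s.upper()

    # Register an empty synthetic module (missing parents are created too).
    fake = import_helper.add_module('_fake_mod_for_tests')
    assert sys.modules['_fake_mod_for_tests'] is fake
    sys.modules.pop('_fake_mod_for_tests', None)  # clean up after the check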
diff --git a/Lib/test/support/interpreters/__init__.py b/Lib/test/support/interpreters/__init__.py
deleted file mode 100644
index e067f259364..00000000000
--- a/Lib/test/support/interpreters/__init__.py
+++ /dev/null
@@ -1,258 +0,0 @@
-"""Subinterpreters High Level Module."""
-
-import threading
-import weakref
-import _interpreters
-
-# aliases:
-from _interpreters import (
- InterpreterError, InterpreterNotFoundError, NotShareableError,
- is_shareable,
-)
-
-
-__all__ = [
- 'get_current', 'get_main', 'create', 'list_all', 'is_shareable',
- 'Interpreter',
- 'InterpreterError', 'InterpreterNotFoundError', 'ExecutionFailed',
- 'NotShareableError',
- 'create_queue', 'Queue', 'QueueEmpty', 'QueueFull',
-]
-
-
-_queuemod = None
-
-def __getattr__(name):
- if name in ('Queue', 'QueueEmpty', 'QueueFull', 'create_queue'):
- global create_queue, Queue, QueueEmpty, QueueFull
- ns = globals()
- from .queues import (
- create as create_queue,
- Queue, QueueEmpty, QueueFull,
- )
- return ns[name]
- else:
- raise AttributeError(name)
-
-
-_EXEC_FAILURE_STR = """
-{superstr}
-
-Uncaught in the interpreter:
-
-{formatted}
-""".strip()
-
-class ExecutionFailed(InterpreterError):
- """An unhandled exception happened during execution.
-
- This is raised from Interpreter.exec() and Interpreter.call().
- """
-
- def __init__(self, excinfo):
- msg = excinfo.formatted
- if not msg:
- if excinfo.type and excinfo.msg:
- msg = f'{excinfo.type.__name__}: {excinfo.msg}'
- else:
- msg = excinfo.type.__name__ or excinfo.msg
- super().__init__(msg)
- self.excinfo = excinfo
-
- def __str__(self):
- try:
- formatted = self.excinfo.errdisplay
- except Exception:
- return super().__str__()
- else:
- return _EXEC_FAILURE_STR.format(
- superstr=super().__str__(),
- formatted=formatted,
- )
-
-
-def create():
- """Return a new (idle) Python interpreter."""
- id = _interpreters.create(reqrefs=True)
- return Interpreter(id, _ownsref=True)
-
-
-def list_all():
- """Return all existing interpreters."""
- return [Interpreter(id, _whence=whence)
- for id, whence in _interpreters.list_all(require_ready=True)]
-
-
-def get_current():
- """Return the currently running interpreter."""
- id, whence = _interpreters.get_current()
- return Interpreter(id, _whence=whence)
-
-
-def get_main():
- """Return the main interpreter."""
- id, whence = _interpreters.get_main()
- assert whence == _interpreters.WHENCE_RUNTIME, repr(whence)
- return Interpreter(id, _whence=whence)
-
-
-_known = weakref.WeakValueDictionary()
-
-class Interpreter:
- """A single Python interpreter.
-
- Attributes:
-
- "id" - the unique process-global ID number for the interpreter
- "whence" - indicates where the interpreter was created
-
- If the interpreter wasn't created by this module
- then any method that modifies the interpreter will fail,
- i.e. .close(), .prepare_main(), .exec(), and .call()
- """
-
- _WHENCE_TO_STR = {
- _interpreters.WHENCE_UNKNOWN: 'unknown',
- _interpreters.WHENCE_RUNTIME: 'runtime init',
- _interpreters.WHENCE_LEGACY_CAPI: 'legacy C-API',
- _interpreters.WHENCE_CAPI: 'C-API',
- _interpreters.WHENCE_XI: 'cross-interpreter C-API',
- _interpreters.WHENCE_STDLIB: '_interpreters module',
- }
-
- def __new__(cls, id, /, _whence=None, _ownsref=None):
- # There is only one instance for any given ID.
- if not isinstance(id, int):
- raise TypeError(f'id must be an int, got {id!r}')
- id = int(id)
- if _whence is None:
- if _ownsref:
- _whence = _interpreters.WHENCE_STDLIB
- else:
- _whence = _interpreters.whence(id)
- assert _whence in cls._WHENCE_TO_STR, repr(_whence)
- if _ownsref is None:
- _ownsref = (_whence == _interpreters.WHENCE_STDLIB)
- try:
- self = _known[id]
- assert hasattr(self, '_ownsref')
- except KeyError:
- self = super().__new__(cls)
- _known[id] = self
- self._id = id
- self._whence = _whence
- self._ownsref = _ownsref
- if _ownsref:
- # This may raise InterpreterNotFoundError:
- _interpreters.incref(id)
- return self
-
- def __repr__(self):
- return f'{type(self).__name__}({self.id})'
-
- def __hash__(self):
- return hash(self._id)
-
- def __del__(self):
- self._decref()
-
- # for pickling:
- def __getnewargs__(self):
- return (self._id,)
-
- # for pickling:
- def __getstate__(self):
- return None
-
- def _decref(self):
- if not self._ownsref:
- return
- self._ownsref = False
- try:
- _interpreters.decref(self._id)
- except InterpreterNotFoundError:
- pass
-
- @property
- def id(self):
- return self._id
-
- @property
- def whence(self):
- return self._WHENCE_TO_STR[self._whence]
-
- def is_running(self):
- """Return whether or not the identified interpreter is running."""
- return _interpreters.is_running(self._id)
-
- # Everything past here is available only to interpreters created by
- # interpreters.create().
-
- def close(self):
- """Finalize and destroy the interpreter.
-
- Attempting to destroy the current interpreter results
- in an InterpreterError.
- """
- return _interpreters.destroy(self._id, restrict=True)
-
- def prepare_main(self, ns=None, /, **kwargs):
- """Bind the given values into the interpreter's __main__.
-
- The values must be shareable.
- """
- ns = dict(ns, **kwargs) if ns is not None else kwargs
- _interpreters.set___main___attrs(self._id, ns, restrict=True)
-
- def exec(self, code, /):
- """Run the given source code in the interpreter.
-
- This is essentially the same as calling the builtin "exec"
- with this interpreter, using the __dict__ of its __main__
- module as both globals and locals.
-
- There is no return value.
-
- If the code raises an unhandled exception then an ExecutionFailed
- exception is raised, which summarizes the unhandled exception.
- The actual exception is discarded because objects cannot be
- shared between interpreters.
-
- This blocks the current Python thread until done. During
- that time, the previous interpreter is allowed to run
- in other threads.
- """
- excinfo = _interpreters.exec(self._id, code, restrict=True)
- if excinfo is not None:
- raise ExecutionFailed(excinfo)
-
- def call(self, callable, /):
- """Call the object in the interpreter with given args/kwargs.
-
- Only functions that take no arguments and have no closure
- are supported.
-
- The return value is discarded.
-
- If the callable raises an exception then the error display
- (including full traceback) is send back between the interpreters
- and an ExecutionFailed exception is raised, much like what
- happens with Interpreter.exec().
- """
- # XXX Support args and kwargs.
- # XXX Support arbitrary callables.
- # XXX Support returning the return value (e.g. via pickle).
- excinfo = _interpreters.call(self._id, callable, restrict=True)
- if excinfo is not None:
- raise ExecutionFailed(excinfo)
-
- def call_in_thread(self, callable, /):
- """Return a new thread that calls the object in the interpreter.
-
- The return value and any raised exception are discarded.
- """
- def task():
- self.call(callable)
- t = threading.Thread(target=task)
- t.start()
- return t
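The high-level wrapper deleted here is superseded by the stdlib concurrent.interpreters package that the rest of this patch imports from. Assuming the public API kept the same names after the move, the equivalent of the old helper would look roughly like:

    from concurrent import interpreters

    interp = interpreters.create()
    interp.exec("x = 6 * 7")
    interp.close()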
diff --git a/Lib/test/support/interpreters/_crossinterp.py b/Lib/test/support/interpreters/_crossinterp.py
deleted file mode 100644
index 544e197ba4c..00000000000
--- a/Lib/test/support/interpreters/_crossinterp.py
+++ /dev/null
@@ -1,102 +0,0 @@
-"""Common code between queues and channels."""
-
-
-class ItemInterpreterDestroyed(Exception):
- """Raised when trying to get an item whose interpreter was destroyed."""
-
-
-class classonly:
- """A non-data descriptor that makes a value only visible on the class.
-
- This is like the "classmethod" builtin, but does not show up on
- instances of the class. It may be used as a decorator.
- """
-
- def __init__(self, value):
- self.value = value
- self.getter = classmethod(value).__get__
- self.name = None
-
- def __set_name__(self, cls, name):
- if self.name is not None:
- raise TypeError('already used')
- self.name = name
-
- def __get__(self, obj, cls):
- if obj is not None:
- raise AttributeError(self.name)
- # called on the class
- return self.getter(None, cls)
-
-
-class UnboundItem:
- """Represents a cross-interpreter item no longer bound to an interpreter.
-
- An item is unbound when the interpreter that added it to the
- cross-interpreter container is destroyed.
- """
-
- __slots__ = ()
-
- @classonly
- def singleton(cls, kind, module, name='UNBOUND'):
- doc = cls.__doc__.replace('cross-interpreter container', kind)
- doc = doc.replace('cross-interpreter', kind)
- subclass = type(
- f'Unbound{kind.capitalize()}Item',
- (cls,),
- dict(
- _MODULE=module,
- _NAME=name,
- __doc__=doc,
- ),
- )
- return object.__new__(subclass)
-
- _MODULE = __name__
- _NAME = 'UNBOUND'
-
- def __new__(cls):
- raise Exception(f'use {cls._MODULE}.{cls._NAME}')
-
- def __repr__(self):
- return f'{self._MODULE}.{self._NAME}'
-# return f'interpreters.queues.UNBOUND'
-
-
-UNBOUND = object.__new__(UnboundItem)
-UNBOUND_ERROR = object()
-UNBOUND_REMOVE = object()
-
-_UNBOUND_CONSTANT_TO_FLAG = {
- UNBOUND_REMOVE: 1,
- UNBOUND_ERROR: 2,
- UNBOUND: 3,
-}
-_UNBOUND_FLAG_TO_CONSTANT = {v: k
- for k, v in _UNBOUND_CONSTANT_TO_FLAG.items()}
-
-
-def serialize_unbound(unbound):
- op = unbound
- try:
- flag = _UNBOUND_CONSTANT_TO_FLAG[op]
- except KeyError:
- raise NotImplementedError(f'unsupported unbound replacement op {op!r}')
- return flag,
-
-
-def resolve_unbound(flag, exctype_destroyed):
- try:
- op = _UNBOUND_FLAG_TO_CONSTANT[flag]
- except KeyError:
- raise NotImplementedError(f'unsupported unbound replacement op {flag!r}')
- if op is UNBOUND_REMOVE:
- # "remove" not possible here
- raise NotImplementedError
- elif op is UNBOUND_ERROR:
- raise exctype_destroyed("item's original interpreter destroyed")
- elif op is UNBOUND:
- return UNBOUND
- else:
- raise NotImplementedError(repr(op))
diff --git a/Lib/test/support/interpreters/queues.py b/Lib/test/support/interpreters/queues.py
deleted file mode 100644
index deb8e8613af..00000000000
--- a/Lib/test/support/interpreters/queues.py
+++ /dev/null
@@ -1,313 +0,0 @@
-"""Cross-interpreter Queues High Level Module."""
-
-import pickle
-import queue
-import time
-import weakref
-import _interpqueues as _queues
-from . import _crossinterp
-
-# aliases:
-from _interpqueues import (
- QueueError, QueueNotFoundError,
-)
-from ._crossinterp import (
- UNBOUND_ERROR, UNBOUND_REMOVE,
-)
-
-__all__ = [
- 'UNBOUND', 'UNBOUND_ERROR', 'UNBOUND_REMOVE',
- 'create', 'list_all',
- 'Queue',
- 'QueueError', 'QueueNotFoundError', 'QueueEmpty', 'QueueFull',
- 'ItemInterpreterDestroyed',
-]
-
-
-class QueueEmpty(QueueError, queue.Empty):
- """Raised from get_nowait() when the queue is empty.
-
- It is also raised from get() if it times out.
- """
-
-
-class QueueFull(QueueError, queue.Full):
- """Raised from put_nowait() when the queue is full.
-
- It is also raised from put() if it times out.
- """
-
-
-class ItemInterpreterDestroyed(QueueError,
- _crossinterp.ItemInterpreterDestroyed):
- """Raised from get() and get_nowait()."""
-
-
-_SHARED_ONLY = 0
-_PICKLED = 1
-
-
-UNBOUND = _crossinterp.UnboundItem.singleton('queue', __name__)
-
-
-def _serialize_unbound(unbound):
- if unbound is UNBOUND:
- unbound = _crossinterp.UNBOUND
- return _crossinterp.serialize_unbound(unbound)
-
-
-def _resolve_unbound(flag):
- resolved = _crossinterp.resolve_unbound(flag, ItemInterpreterDestroyed)
- if resolved is _crossinterp.UNBOUND:
- resolved = UNBOUND
- return resolved
-
-
-def create(maxsize=0, *, syncobj=False, unbounditems=UNBOUND):
- """Return a new cross-interpreter queue.
-
- The queue may be used to pass data safely between interpreters.
-
- "syncobj" sets the default for Queue.put()
- and Queue.put_nowait().
-
- "unbounditems" likewise sets the default. See Queue.put() for
- supported values. The default value is UNBOUND, which replaces
- the unbound item.
- """
- fmt = _SHARED_ONLY if syncobj else _PICKLED
- unbound = _serialize_unbound(unbounditems)
- unboundop, = unbound
- qid = _queues.create(maxsize, fmt, unboundop)
- return Queue(qid, _fmt=fmt, _unbound=unbound)
-
-
-def list_all():
- """Return a list of all open queues."""
- return [Queue(qid, _fmt=fmt, _unbound=(unboundop,))
- for qid, fmt, unboundop in _queues.list_all()]
-
-
-_known_queues = weakref.WeakValueDictionary()
-
-class Queue:
- """A cross-interpreter queue."""
-
- def __new__(cls, id, /, *, _fmt=None, _unbound=None):
- # There is only one instance for any given ID.
- if isinstance(id, int):
- id = int(id)
- else:
- raise TypeError(f'id must be an int, got {id!r}')
- if _fmt is None:
- if _unbound is None:
- _fmt, op = _queues.get_queue_defaults(id)
- _unbound = (op,)
- else:
- _fmt, _ = _queues.get_queue_defaults(id)
- elif _unbound is None:
- _, op = _queues.get_queue_defaults(id)
- _unbound = (op,)
- try:
- self = _known_queues[id]
- except KeyError:
- self = super().__new__(cls)
- self._id = id
- self._fmt = _fmt
- self._unbound = _unbound
- _known_queues[id] = self
- _queues.bind(id)
- return self
-
- def __del__(self):
- try:
- _queues.release(self._id)
- except QueueNotFoundError:
- pass
- try:
- del _known_queues[self._id]
- except KeyError:
- pass
-
- def __repr__(self):
- return f'{type(self).__name__}({self.id})'
-
- def __hash__(self):
- return hash(self._id)
-
- # for pickling:
- def __getnewargs__(self):
- return (self._id,)
-
- # for pickling:
- def __getstate__(self):
- return None
-
- @property
- def id(self):
- return self._id
-
- @property
- def maxsize(self):
- try:
- return self._maxsize
- except AttributeError:
- self._maxsize = _queues.get_maxsize(self._id)
- return self._maxsize
-
- def empty(self):
- return self.qsize() == 0
-
- def full(self):
- return _queues.is_full(self._id)
-
- def qsize(self):
- return _queues.get_count(self._id)
-
- def put(self, obj, timeout=None, *,
- syncobj=None,
- unbound=None,
- _delay=10 / 1000, # 10 milliseconds
- ):
- """Add the object to the queue.
-
- This blocks while the queue is full.
-
- If "syncobj" is None (the default) then it uses the
- queue's default, set with create_queue().
-
- If "syncobj" is false then all objects are supported,
- at the expense of worse performance.
-
- If "syncobj" is true then the object must be "shareable".
- Examples of "shareable" objects include the builtin singletons,
- str, and memoryview. One benefit is that such objects are
- passed through the queue efficiently.
-
- The key difference, though, is conceptual: the corresponding
- object returned from Queue.get() will be strictly equivalent
- to the given obj. In other words, the two objects will be
- effectively indistinguishable from each other, even if the
- object is mutable. The received object may actually be the
- same object, or a copy (immutable values only), or a proxy.
- Regardless, the received object should be treated as though
- the original has been shared directly, whether or not it
- actually is. That's a slightly different and stronger promise
- than just (initial) equality, which is all "syncobj=False"
- can promise.
-
- "unbound" controls the behavior of Queue.get() for the given
- object if the current interpreter (calling put()) is later
- destroyed.
-
- If "unbound" is None (the default) then it uses the
- queue's default, set with create_queue(),
- which is usually UNBOUND.
-
- If "unbound" is UNBOUND_ERROR then get() will raise an
- ItemInterpreterDestroyed exception if the original interpreter
- has been destroyed. This does not otherwise affect the queue;
- the next call to put() will work like normal, returning the next
- item in the queue.
-
- If "unbound" is UNBOUND_REMOVE then the item will be removed
- from the queue as soon as the original interpreter is destroyed.
- Be aware that this will introduce an imbalance between put()
- and get() calls.
-
- If "unbound" is UNBOUND then it is returned by get() in place
- of the unbound item.
- """
- if syncobj is None:
- fmt = self._fmt
- else:
- fmt = _SHARED_ONLY if syncobj else _PICKLED
- if unbound is None:
- unboundop, = self._unbound
- else:
- unboundop, = _serialize_unbound(unbound)
- if timeout is not None:
- timeout = int(timeout)
- if timeout < 0:
- raise ValueError(f'timeout value must be non-negative')
- end = time.time() + timeout
- if fmt is _PICKLED:
- obj = pickle.dumps(obj)
- while True:
- try:
- _queues.put(self._id, obj, fmt, unboundop)
- except QueueFull as exc:
- if timeout is not None and time.time() >= end:
- raise # re-raise
- time.sleep(_delay)
- else:
- break
-
- def put_nowait(self, obj, *, syncobj=None, unbound=None):
- if syncobj is None:
- fmt = self._fmt
- else:
- fmt = _SHARED_ONLY if syncobj else _PICKLED
- if unbound is None:
- unboundop, = self._unbound
- else:
- unboundop, = _serialize_unbound(unbound)
- if fmt is _PICKLED:
- obj = pickle.dumps(obj)
- _queues.put(self._id, obj, fmt, unboundop)
-
- def get(self, timeout=None, *,
- _delay=10 / 1000, # 10 milliseconds
- ):
- """Return the next object from the queue.
-
- This blocks while the queue is empty.
-
- If the next item's original interpreter has been destroyed
- then the "next object" is determined by the value of the
- "unbound" argument to put().
- """
- if timeout is not None:
- timeout = int(timeout)
- if timeout < 0:
- raise ValueError(f'timeout value must be non-negative')
- end = time.time() + timeout
- while True:
- try:
- obj, fmt, unboundop = _queues.get(self._id)
- except QueueEmpty as exc:
- if timeout is not None and time.time() >= end:
- raise # re-raise
- time.sleep(_delay)
- else:
- break
- if unboundop is not None:
- assert obj is None, repr(obj)
- return _resolve_unbound(unboundop)
- if fmt == _PICKLED:
- obj = pickle.loads(obj)
- else:
- assert fmt == _SHARED_ONLY
- return obj
-
- def get_nowait(self):
- """Return the next object from the channel.
-
- If the queue is empty then raise QueueEmpty. Otherwise this
- is the same as get().
- """
- try:
- obj, fmt, unboundop = _queues.get(self._id)
- except QueueEmpty as exc:
- raise # re-raise
- if unboundop is not None:
- assert obj is None, repr(obj)
- return _resolve_unbound(unboundop)
- if fmt == _PICKLED:
- obj = pickle.loads(obj)
- else:
- assert fmt == _SHARED_ONLY
- return obj
-
-
-_queues._register_heap_types(Queue, QueueEmpty, QueueFull)
diff --git a/Lib/test/support/strace_helper.py b/Lib/test/support/strace_helper.py
index 798d6c68869..cf95f7bdc7d 100644
--- a/Lib/test/support/strace_helper.py
+++ b/Lib/test/support/strace_helper.py
@@ -38,7 +38,7 @@ class StraceResult:
This assumes the program under inspection doesn't print any non-utf8
strings which would mix into the strace output."""
- decoded_events = self.event_bytes.decode('utf-8')
+ decoded_events = self.event_bytes.decode('utf-8', 'surrogateescape')
matches = [
_syscall_regex.match(event)
for event in decoded_events.splitlines()
@@ -178,7 +178,10 @@ def get_syscalls(code, strace_flags, prelude="", cleanup="",
# Moderately expensive (spawns a subprocess), so share results when possible.
@cache
def _can_strace():
- res = strace_python("import sys; sys.exit(0)", [], check=False)
+ res = strace_python("import sys; sys.exit(0)",
+ # --trace option needs strace 5.5 (gh-133741)
+ ["--trace=%process"],
+ check=False)
if res.strace_returncode == 0 and res.python_returncode == 0:
assert res.events(), "Should have parsed multiple calls"
return True
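The probe above now restricts tracing to process-management syscalls, which doubles as a check that the installed strace accepts --trace=%process (strace 5.5+). A sketch of calling the helper directly with the same flag, not part of the patch (output handling is illustrative only):

    from test.support import strace_helper

    res = strace_helper.strace_python("print('hi')",
                                      ["--trace=%process"],
                                      check=False)
    if res.strace_returncode == 0:
        for event in res.events():  # parsed syscall events
            print(event)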
diff --git a/Lib/test/support/warnings_helper.py b/Lib/test/support/warnings_helper.py
index a6e43dff200..5f6f14afd74 100644
--- a/Lib/test/support/warnings_helper.py
+++ b/Lib/test/support/warnings_helper.py
@@ -23,8 +23,7 @@ def check_syntax_warning(testcase, statement, errtext='',
testcase.assertEqual(len(warns), 1, warns)
warn, = warns
- testcase.assertTrue(issubclass(warn.category, SyntaxWarning),
- warn.category)
+ testcase.assertIsSubclass(warn.category, SyntaxWarning)
if errtext:
testcase.assertRegex(str(warn.message), errtext)
testcase.assertEqual(warn.filename, '<testcase>')
diff --git a/Lib/test/test__interpchannels.py b/Lib/test/test__interpchannels.py
index e4c1ad85451..858d31a73cf 100644
--- a/Lib/test/test__interpchannels.py
+++ b/Lib/test/test__interpchannels.py
@@ -9,7 +9,7 @@ import unittest
from test.support import import_helper, skip_if_sanitizer
_channels = import_helper.import_module('_interpchannels')
-from test.support.interpreters import _crossinterp
+from concurrent.interpreters import _crossinterp
from test.test__interpreters import (
_interpreters,
_run_output,
@@ -247,7 +247,7 @@ def _run_action(cid, action, end, state):
def clean_up_channels():
- for cid, _ in _channels.list_all():
+ for cid, _, _ in _channels.list_all():
try:
_channels.destroy(cid)
except _channels.ChannelNotFoundError:
@@ -373,11 +373,11 @@ class ChannelTests(TestBase):
self.assertIsInstance(cid, _channels.ChannelID)
def test_sequential_ids(self):
- before = [cid for cid, _ in _channels.list_all()]
+ before = [cid for cid, _, _ in _channels.list_all()]
id1 = _channels.create(REPLACE)
id2 = _channels.create(REPLACE)
id3 = _channels.create(REPLACE)
- after = [cid for cid, _ in _channels.list_all()]
+ after = [cid for cid, _, _ in _channels.list_all()]
self.assertEqual(id2, int(id1) + 1)
self.assertEqual(id3, int(id2) + 1)
diff --git a/Lib/test/test__interpreters.py b/Lib/test/test__interpreters.py
index 0c43f46300f..a32d5d81d2b 100644
--- a/Lib/test/test__interpreters.py
+++ b/Lib/test/test__interpreters.py
@@ -474,15 +474,32 @@ class CommonTests(TestBase):
def test_signatures(self):
# See https://github.com/python/cpython/issues/126654
- msg = "expected 'shared' to be a dict"
+ msg = r'_interpreters.exec\(\) argument 3 must be dict, not int'
with self.assertRaisesRegex(TypeError, msg):
_interpreters.exec(self.id, 'a', 1)
with self.assertRaisesRegex(TypeError, msg):
_interpreters.exec(self.id, 'a', shared=1)
+ msg = r'_interpreters.run_string\(\) argument 3 must be dict, not int'
with self.assertRaisesRegex(TypeError, msg):
_interpreters.run_string(self.id, 'a', shared=1)
+ msg = r'_interpreters.run_func\(\) argument 3 must be dict, not int'
with self.assertRaisesRegex(TypeError, msg):
_interpreters.run_func(self.id, lambda: None, shared=1)
+ # See https://github.com/python/cpython/issues/135855
+ msg = r'_interpreters.set___main___attrs\(\) argument 2 must be dict, not int'
+ with self.assertRaisesRegex(TypeError, msg):
+ _interpreters.set___main___attrs(self.id, 1)
+
+ def test_invalid_shared_none(self):
+ msg = r'must be dict, not None'
+ with self.assertRaisesRegex(TypeError, msg):
+ _interpreters.exec(self.id, 'a', shared=None)
+ with self.assertRaisesRegex(TypeError, msg):
+ _interpreters.run_string(self.id, 'a', shared=None)
+ with self.assertRaisesRegex(TypeError, msg):
+ _interpreters.run_func(self.id, lambda: None, shared=None)
+ with self.assertRaisesRegex(TypeError, msg):
+ _interpreters.set___main___attrs(self.id, None)
def test_invalid_shared_encoding(self):
# See https://github.com/python/cpython/issues/127196
@@ -952,7 +969,8 @@ class RunFailedTests(TestBase):
""")
with self.subTest('script'):
- self.assert_run_failed(SyntaxError, script)
+ with self.assertRaises(SyntaxError):
+ _interpreters.run_string(self.id, script)
with self.subTest('module'):
modname = 'spam_spam_spam'
@@ -1019,12 +1037,19 @@ class RunFuncTests(TestBase):
with open(w, 'w', encoding="utf-8") as spipe:
with contextlib.redirect_stdout(spipe):
print('it worked!', end='')
+ failed = None
def f():
- _interpreters.set___main___attrs(self.id, dict(w=w))
- _interpreters.run_func(self.id, script)
+ nonlocal failed
+ try:
+ _interpreters.set___main___attrs(self.id, dict(w=w))
+ _interpreters.run_func(self.id, script)
+ except Exception as exc:
+ failed = exc
t = threading.Thread(target=f)
t.start()
t.join()
+ if failed:
+ raise Exception from failed
with open(r, encoding="utf-8") as outfile:
out = outfile.read()
@@ -1053,18 +1078,16 @@ class RunFuncTests(TestBase):
spam = True
def script():
assert spam
-
with self.assertRaises(ValueError):
_interpreters.run_func(self.id, script)
- # XXX This hasn't been fixed yet.
- @unittest.expectedFailure
def test_return_value(self):
def script():
return 'spam'
with self.assertRaises(ValueError):
_interpreters.run_func(self.id, script)
+# @unittest.skip("we're not quite there yet")
def test_args(self):
with self.subTest('args'):
def script(a, b=0):
diff --git a/Lib/test/test__osx_support.py b/Lib/test/test__osx_support.py
index 53aa26620a6..0813c4804c1 100644
--- a/Lib/test/test__osx_support.py
+++ b/Lib/test/test__osx_support.py
@@ -66,8 +66,8 @@ class Test_OSXSupport(unittest.TestCase):
'cc not found - check xcode-select')
def test__get_system_version(self):
- self.assertTrue(platform.mac_ver()[0].startswith(
- _osx_support._get_system_version()))
+ self.assertStartsWith(platform.mac_ver()[0],
+ _osx_support._get_system_version())
def test__remove_original_values(self):
config_vars = {
diff --git a/Lib/test/test_abstract_numbers.py b/Lib/test/test_abstract_numbers.py
index 72232b670cd..cf071d2c933 100644
--- a/Lib/test/test_abstract_numbers.py
+++ b/Lib/test/test_abstract_numbers.py
@@ -24,11 +24,11 @@ def concretize(cls):
class TestNumbers(unittest.TestCase):
def test_int(self):
- self.assertTrue(issubclass(int, Integral))
- self.assertTrue(issubclass(int, Rational))
- self.assertTrue(issubclass(int, Real))
- self.assertTrue(issubclass(int, Complex))
- self.assertTrue(issubclass(int, Number))
+ self.assertIsSubclass(int, Integral)
+ self.assertIsSubclass(int, Rational)
+ self.assertIsSubclass(int, Real)
+ self.assertIsSubclass(int, Complex)
+ self.assertIsSubclass(int, Number)
self.assertEqual(7, int(7).real)
self.assertEqual(0, int(7).imag)
@@ -38,11 +38,11 @@ class TestNumbers(unittest.TestCase):
self.assertEqual(1, int(7).denominator)
def test_float(self):
- self.assertFalse(issubclass(float, Integral))
- self.assertFalse(issubclass(float, Rational))
- self.assertTrue(issubclass(float, Real))
- self.assertTrue(issubclass(float, Complex))
- self.assertTrue(issubclass(float, Number))
+ self.assertNotIsSubclass(float, Integral)
+ self.assertNotIsSubclass(float, Rational)
+ self.assertIsSubclass(float, Real)
+ self.assertIsSubclass(float, Complex)
+ self.assertIsSubclass(float, Number)
self.assertEqual(7.3, float(7.3).real)
self.assertEqual(0, float(7.3).imag)
@@ -50,11 +50,11 @@ class TestNumbers(unittest.TestCase):
self.assertEqual(-7.3, float(-7.3).conjugate())
def test_complex(self):
- self.assertFalse(issubclass(complex, Integral))
- self.assertFalse(issubclass(complex, Rational))
- self.assertFalse(issubclass(complex, Real))
- self.assertTrue(issubclass(complex, Complex))
- self.assertTrue(issubclass(complex, Number))
+ self.assertNotIsSubclass(complex, Integral)
+ self.assertNotIsSubclass(complex, Rational)
+ self.assertNotIsSubclass(complex, Real)
+ self.assertIsSubclass(complex, Complex)
+ self.assertIsSubclass(complex, Number)
c1, c2 = complex(3, 2), complex(4,1)
# XXX: This is not ideal, but see the comment in math_trunc().
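The rewrites above rely on the newer unittest assertion helpers that this patch migrates to throughout. A minimal sketch, assuming those methods are available on unittest.TestCase as the migrated tests imply:

    import unittest

    class AssertionHelpersDemo(unittest.TestCase):
        def test_helpers(self):
            self.assertIsSubclass(bool, int)
            self.assertNotIsSubclass(int, bool)
            self.assertHasAttr([], 'append')
            self.assertStartsWith('usage: PROG [-h]', 'usage:')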
diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py
index be55f044b15..ae0e73f08c5 100644
--- a/Lib/test/test_annotationlib.py
+++ b/Lib/test/test_annotationlib.py
@@ -1,18 +1,19 @@
"""Tests for the annotations module."""
+import textwrap
import annotationlib
import builtins
import collections
import functools
import itertools
import pickle
+from string.templatelib import Template
import typing
import unittest
from annotationlib import (
Format,
ForwardRef,
get_annotations,
- get_annotate_function,
annotations_to_string,
type_repr,
)
@@ -121,6 +122,28 @@ class TestForwardRefFormat(unittest.TestCase):
self.assertIsInstance(gamma_anno, ForwardRef)
self.assertEqual(gamma_anno, support.EqualToForwardRef("some < obj", owner=f))
+ def test_partially_nonexistent_union(self):
+ # Test that unions written with '|' syntax equal typing.Union[] unions containing forward references
+ class UnionForwardrefs:
+ pipe: str | undefined
+ union: Union[str, undefined]
+
+ annos = get_annotations(UnionForwardrefs, format=Format.FORWARDREF)
+
+ pipe = annos["pipe"]
+ self.assertIsInstance(pipe, ForwardRef)
+ self.assertEqual(
+ pipe.evaluate(globals={"undefined": int}),
+ str | int,
+ )
+ union = annos["union"]
+ self.assertIsInstance(union, Union)
+ arg1, arg2 = typing.get_args(union)
+ self.assertIs(arg1, str)
+ self.assertEqual(
+ arg2, support.EqualToForwardRef("undefined", is_class=True, owner=UnionForwardrefs)
+ )
+
class TestStringFormat(unittest.TestCase):
def test_closure(self):
@@ -251,6 +274,126 @@ class TestStringFormat(unittest.TestCase):
},
)
+ def test_template_str(self):
+ def f(
+ x: t"{a}",
+ y: list[t"{a}"],
+ z: t"{a:b} {c!r} {d!s:t}",
+ a: t"a{b}c{d}e{f}g",
+ b: t"{a:{1}}",
+ c: t"{a | b * c}",
+ ): pass
+
+ annos = get_annotations(f, format=Format.STRING)
+ self.assertEqual(annos, {
+ "x": "t'{a}'",
+ "y": "list[t'{a}']",
+ "z": "t'{a:b} {c!r} {d!s:t}'",
+ "a": "t'a{b}c{d}e{f}g'",
+ # interpolations in the format spec are eagerly evaluated so we can't recover the source
+ "b": "t'{a:1}'",
+ "c": "t'{a | b * c}'",
+ })
+
+ def g(
+ x: t"{a}",
+ ): ...
+
+ annos = get_annotations(g, format=Format.FORWARDREF)
+ templ = annos["x"]
+ # Template and Interpolation don't have __eq__ so we have to compare manually
+ self.assertIsInstance(templ, Template)
+ self.assertEqual(templ.strings, ("", ""))
+ self.assertEqual(len(templ.interpolations), 1)
+ interp = templ.interpolations[0]
+ self.assertEqual(interp.value, support.EqualToForwardRef("a", owner=g))
+ self.assertEqual(interp.expression, "a")
+ self.assertIsNone(interp.conversion)
+ self.assertEqual(interp.format_spec, "")
+
+ def test_getitem(self):
+ def f(x: undef1[str, undef2]):
+ pass
+ anno = get_annotations(f, format=Format.STRING)
+ self.assertEqual(anno, {"x": "undef1[str, undef2]"})
+
+ anno = get_annotations(f, format=Format.FORWARDREF)
+ fwdref = anno["x"]
+ self.assertIsInstance(fwdref, ForwardRef)
+ self.assertEqual(
+ fwdref.evaluate(globals={"undef1": dict, "undef2": float}), dict[str, float]
+ )
+
+ def test_slice(self):
+ def f(x: a[b:c]):
+ pass
+ anno = get_annotations(f, format=Format.STRING)
+ self.assertEqual(anno, {"x": "a[b:c]"})
+
+ def f(x: a[b:c, d:e]):
+ pass
+ anno = get_annotations(f, format=Format.STRING)
+ self.assertEqual(anno, {"x": "a[b:c, d:e]"})
+
+ obj = slice(1, 1, 1)
+ def f(x: obj):
+ pass
+ anno = get_annotations(f, format=Format.STRING)
+ self.assertEqual(anno, {"x": "obj"})
+
+ def test_literals(self):
+ def f(
+ a: 1,
+ b: 1.0,
+ c: "hello",
+ d: b"hello",
+ e: True,
+ f: None,
+ g: ...,
+ h: 1j,
+ ):
+ pass
+
+ anno = get_annotations(f, format=Format.STRING)
+ self.assertEqual(
+ anno,
+ {
+ "a": "1",
+ "b": "1.0",
+ "c": 'hello',
+ "d": "b'hello'",
+ "e": "True",
+ "f": "None",
+ "g": "...",
+ "h": "1j",
+ },
+ )
+
+ def test_displays(self):
+ # Simple case first
+ def f(x: a[[int, str], float]):
+ pass
+ anno = get_annotations(f, format=Format.STRING)
+ self.assertEqual(anno, {"x": "a[[int, str], float]"})
+
+ def g(
+ w: a[[int, str], float],
+ x: a[{int}, 3],
+ y: a[{int: str}, 4],
+ z: a[(int, str), 5],
+ ):
+ pass
+ anno = get_annotations(g, format=Format.STRING)
+ self.assertEqual(
+ anno,
+ {
+ "w": "a[[int, str], float]",
+ "x": "a[{int}, 3]",
+ "y": "a[{int: str}, 4]",
+ "z": "a[(int, str), 5]",
+ },
+ )
+
def test_nested_expressions(self):
def f(
nested: list[Annotated[set[int], "set of ints", 4j]],
@@ -296,6 +439,17 @@ class TestStringFormat(unittest.TestCase):
with self.assertRaisesRegex(TypeError, format_msg):
get_annotations(f, format=Format.STRING)
+ def test_shenanigans(self):
+ # In cases like this we can't reconstruct the source; test that we do something
+ # halfway reasonable.
+ def f(x: x | (1).__class__, y: (1).__class__):
+ pass
+
+ self.assertEqual(
+ get_annotations(f, format=Format.STRING),
+ {"x": "x | <class 'int'>", "y": "<class 'int'>"},
+ )
+
class TestGetAnnotations(unittest.TestCase):
def test_builtin_type(self):
@@ -661,6 +815,70 @@ class TestGetAnnotations(unittest.TestCase):
{"x": int},
)
+ def test_stringized_annotation_permutations(self):
+ def define_class(name, has_future, has_annos, base_text, extra_names=None):
+ lines = []
+ if has_future:
+ lines.append("from __future__ import annotations")
+ lines.append(f"class {name}({base_text}):")
+ if has_annos:
+ lines.append(f" {name}_attr: int")
+ else:
+ lines.append(" pass")
+ code = "\n".join(lines)
+ ns = support.run_code(code, extra_names=extra_names)
+ return ns[name]
+
+ def check_annotations(cls, has_future, has_annos):
+ if has_annos:
+ if has_future:
+ anno = "int"
+ else:
+ anno = int
+ self.assertEqual(get_annotations(cls), {f"{cls.__name__}_attr": anno})
+ else:
+ self.assertEqual(get_annotations(cls), {})
+
+ for meta_future, base_future, child_future, meta_has_annos, base_has_annos, child_has_annos in itertools.product(
+ (False, True),
+ (False, True),
+ (False, True),
+ (False, True),
+ (False, True),
+ (False, True),
+ ):
+ with self.subTest(
+ meta_future=meta_future,
+ base_future=base_future,
+ child_future=child_future,
+ meta_has_annos=meta_has_annos,
+ base_has_annos=base_has_annos,
+ child_has_annos=child_has_annos,
+ ):
+ meta = define_class(
+ "Meta",
+ has_future=meta_future,
+ has_annos=meta_has_annos,
+ base_text="type",
+ )
+ base = define_class(
+ "Base",
+ has_future=base_future,
+ has_annos=base_has_annos,
+ base_text="metaclass=Meta",
+ extra_names={"Meta": meta},
+ )
+ child = define_class(
+ "Child",
+ has_future=child_future,
+ has_annos=child_has_annos,
+ base_text="Base",
+ extra_names={"Base": base},
+ )
+ check_annotations(meta, meta_future, meta_has_annos)
+ check_annotations(base, base_future, base_has_annos)
+ check_annotations(child, child_future, child_has_annos)
+
def test_modify_annotations(self):
def f(x: int):
pass
@@ -901,6 +1119,73 @@ class TestGetAnnotations(unittest.TestCase):
set(results.generic_func.__type_params__),
)
+ def test_partial_evaluation(self):
+ def f(
+ x: builtins.undef,
+ y: list[int],
+ z: 1 + int,
+ a: builtins.int,
+ b: [builtins.undef, builtins.int],
+ ):
+ pass
+
+ self.assertEqual(
+ get_annotations(f, format=Format.FORWARDREF),
+ {
+ "x": support.EqualToForwardRef("builtins.undef", owner=f),
+ "y": list[int],
+ "z": support.EqualToForwardRef("1 + int", owner=f),
+ "a": int,
+ "b": [
+ support.EqualToForwardRef("builtins.undef", owner=f),
+ # We can't resolve this because we have to evaluate the whole annotation
+ support.EqualToForwardRef("builtins.int", owner=f),
+ ],
+ },
+ )
+
+ self.assertEqual(
+ get_annotations(f, format=Format.STRING),
+ {
+ "x": "builtins.undef",
+ "y": "list[int]",
+ "z": "1 + int",
+ "a": "builtins.int",
+ "b": "[builtins.undef, builtins.int]",
+ },
+ )
+
+ def test_partial_evaluation_error(self):
+ def f(x: range[1]):
+ pass
+ with self.assertRaisesRegex(
+ TypeError, "type 'range' is not subscriptable"
+ ):
+ f.__annotations__
+
+ self.assertEqual(
+ get_annotations(f, format=Format.FORWARDREF),
+ {
+ "x": support.EqualToForwardRef("range[1]", owner=f),
+ },
+ )
+
+ def test_partial_evaluation_cell(self):
+ obj = object()
+
+ class RaisesAttributeError:
+ attriberr: obj.missing
+
+ anno = get_annotations(RaisesAttributeError, format=Format.FORWARDREF)
+ self.assertEqual(
+ anno,
+ {
+ "attriberr": support.EqualToForwardRef(
+ "obj.missing", is_class=True, owner=RaisesAttributeError
+ )
+ },
+ )
+
class TestCallEvaluateFunction(unittest.TestCase):
def test_evaluation(self):
@@ -933,13 +1218,13 @@ class MetaclassTests(unittest.TestCase):
b: float
self.assertEqual(get_annotations(Meta), {"a": int})
- self.assertEqual(get_annotate_function(Meta)(Format.VALUE), {"a": int})
+ self.assertEqual(Meta.__annotate__(Format.VALUE), {"a": int})
self.assertEqual(get_annotations(X), {})
- self.assertIs(get_annotate_function(X), None)
+ self.assertIs(X.__annotate__, None)
self.assertEqual(get_annotations(Y), {"b": float})
- self.assertEqual(get_annotate_function(Y)(Format.VALUE), {"b": float})
+ self.assertEqual(Y.__annotate__(Format.VALUE), {"b": float})
def test_unannotated_meta(self):
class Meta(type):
@@ -952,13 +1237,13 @@ class MetaclassTests(unittest.TestCase):
pass
self.assertEqual(get_annotations(Meta), {})
- self.assertIs(get_annotate_function(Meta), None)
+ self.assertIs(Meta.__annotate__, None)
self.assertEqual(get_annotations(Y), {})
- self.assertIs(get_annotate_function(Y), None)
+ self.assertIs(Y.__annotate__, None)
self.assertEqual(get_annotations(X), {"a": str})
- self.assertEqual(get_annotate_function(X)(Format.VALUE), {"a": str})
+ self.assertEqual(X.__annotate__(Format.VALUE), {"a": str})
def test_ordering(self):
# Based on a sample by David Ellis
@@ -996,7 +1281,7 @@ class MetaclassTests(unittest.TestCase):
for c in classes:
with self.subTest(c=c):
self.assertEqual(get_annotations(c), c.expected_annotations)
- annotate_func = get_annotate_function(c)
+ annotate_func = getattr(c, "__annotate__", None)
if c.expected_annotations:
self.assertEqual(
annotate_func(Format.VALUE), c.expected_annotations
@@ -1005,25 +1290,39 @@ class MetaclassTests(unittest.TestCase):
self.assertIs(annotate_func, None)
-class TestGetAnnotateFunction(unittest.TestCase):
- def test_static_class(self):
- self.assertIsNone(get_annotate_function(object))
- self.assertIsNone(get_annotate_function(int))
-
- def test_unannotated_class(self):
- class C:
- pass
+class TestGetAnnotateFromClassNamespace(unittest.TestCase):
+ def test_with_metaclass(self):
+ class Meta(type):
+ def __new__(mcls, name, bases, ns):
+ annotate = annotationlib.get_annotate_from_class_namespace(ns)
+ expected = ns["expected_annotate"]
+ with self.subTest(name=name):
+ if expected:
+ self.assertIsNotNone(annotate)
+ else:
+ self.assertIsNone(annotate)
+ return super().__new__(mcls, name, bases, ns)
+
+ class HasAnnotations(metaclass=Meta):
+ expected_annotate = True
+ a: int
- self.assertIsNone(get_annotate_function(C))
+ class NoAnnotations(metaclass=Meta):
+ expected_annotate = False
- D = type("D", (), {})
- self.assertIsNone(get_annotate_function(D))
+ class CustomAnnotate(metaclass=Meta):
+ expected_annotate = True
+ def __annotate__(format):
+ return {}
- def test_annotated_class(self):
- class C:
- a: int
+ code = """
+ from __future__ import annotations
- self.assertEqual(get_annotate_function(C)(Format.VALUE), {"a": int})
+ class HasFutureAnnotations(metaclass=Meta):
+ expected_annotate = False
+ a: int
+ """
+ exec(textwrap.dedent(code), {"Meta": Meta})
class TestTypeRepr(unittest.TestCase):
@@ -1240,6 +1539,38 @@ class TestForwardRefClass(unittest.TestCase):
with self.assertRaises(TypeError):
pickle.dumps(fr, proto)
+ def test_evaluate_string_format(self):
+ fr = ForwardRef("set[Any]")
+ self.assertEqual(fr.evaluate(format=Format.STRING), "set[Any]")
+
+ def test_evaluate_forwardref_format(self):
+ fr = ForwardRef("undef")
+ evaluated = fr.evaluate(format=Format.FORWARDREF)
+ self.assertIs(fr, evaluated)
+
+ fr = ForwardRef("set[undefined]")
+ evaluated = fr.evaluate(format=Format.FORWARDREF)
+ self.assertEqual(
+ evaluated,
+ set[support.EqualToForwardRef("undefined")],
+ )
+
+ fr = ForwardRef("a + b")
+ self.assertEqual(
+ fr.evaluate(format=Format.FORWARDREF),
+ support.EqualToForwardRef("a + b"),
+ )
+ self.assertEqual(
+ fr.evaluate(format=Format.FORWARDREF, locals={"a": 1, "b": 2}),
+ 3,
+ )
+
+ fr = ForwardRef('"a" + 1')
+ self.assertEqual(
+ fr.evaluate(format=Format.FORWARDREF),
+ support.EqualToForwardRef('"a" + 1'),
+ )
+
def test_evaluate_with_type_params(self):
class Gen[T]:
alias = int
@@ -1319,9 +1650,11 @@ class TestForwardRefClass(unittest.TestCase):
with support.swap_attr(builtins, "int", dict):
self.assertIs(ForwardRef("int").evaluate(), dict)
- with self.assertRaises(NameError):
+ with self.assertRaises(NameError, msg="name 'doesntexist' is not defined") as exc:
ForwardRef("doesntexist").evaluate()
+ self.assertEqual(exc.exception.name, "doesntexist")
+
def test_fwdref_invalid_syntax(self):
fr = ForwardRef("if")
with self.assertRaises(SyntaxError):
diff --git a/Lib/test/test_argparse.py b/Lib/test/test_argparse.py
index 488a3a4ed20..08ff41368d9 100644
--- a/Lib/test/test_argparse.py
+++ b/Lib/test/test_argparse.py
@@ -1,5 +1,6 @@
# Author: Steven J. Bethard <steven.bethard@gmail.com>.
+import _colorize
import contextlib
import functools
import inspect
@@ -5468,11 +5469,60 @@ class TestHelpMetavarTypeFormatter(HelpTestCase):
version = ''
-class TestHelpUsageLongSubparserCommand(TestCase):
- """Test that subparser commands are formatted correctly in help"""
+class TestHelpCustomHelpFormatter(TestCase):
maxDiff = None
- def test_parent_help(self):
+ def test_custom_formatter_function(self):
+ def custom_formatter(prog):
+ return argparse.RawTextHelpFormatter(prog, indent_increment=5)
+
+ parser = argparse.ArgumentParser(
+ prog='PROG',
+ prefix_chars='-+',
+ formatter_class=custom_formatter
+ )
+ parser.add_argument('+f', '++foo', help="foo help")
+ parser.add_argument('spam', help="spam help")
+
+ parser_help = parser.format_help()
+ self.assertEqual(parser_help, textwrap.dedent('''\
+ usage: PROG [-h] [+f FOO] spam
+
+ positional arguments:
+ spam spam help
+
+ options:
+ -h, --help show this help message and exit
+ +f, ++foo FOO foo help
+ '''))
+
+ def test_custom_formatter_class(self):
+ class CustomFormatter(argparse.RawTextHelpFormatter):
+ def __init__(self, prog):
+ super().__init__(prog, indent_increment=5)
+
+ parser = argparse.ArgumentParser(
+ prog='PROG',
+ prefix_chars='-+',
+ formatter_class=CustomFormatter
+ )
+ parser.add_argument('+f', '++foo', help="foo help")
+ parser.add_argument('spam', help="spam help")
+
+ parser_help = parser.format_help()
+ self.assertEqual(parser_help, textwrap.dedent('''\
+ usage: PROG [-h] [+f FOO] spam
+
+ positional arguments:
+ spam spam help
+
+ options:
+ -h, --help show this help message and exit
+ +f, ++foo FOO foo help
+ '''))
+
+ def test_usage_long_subparser_command(self):
+ """Test that subparser commands are formatted correctly in help"""
def custom_formatter(prog):
return argparse.RawTextHelpFormatter(prog, max_help_position=50)
@@ -6755,7 +6805,7 @@ class TestImportStar(TestCase):
def test(self):
for name in argparse.__all__:
- self.assertTrue(hasattr(argparse, name))
+ self.assertHasAttr(argparse, name)
def test_all_exports_everything_but_modules(self):
items = [
@@ -6972,7 +7022,7 @@ class TestProgName(TestCase):
def check_usage(self, expected, *args, **kwargs):
res = script_helper.assert_python_ok('-Xutf8', *args, '-h', **kwargs)
- self.assertEqual(res.out.splitlines()[0].decode(),
+ self.assertEqual(os.fsdecode(res.out.splitlines()[0]),
f'usage: {expected} [-h]')
def test_script(self, compiled=False):
@@ -7046,6 +7096,245 @@ class TestTranslations(TestTranslationsBase):
self.assertMsgidsEqual(argparse)
+# ===========
+# Color tests
+# ===========
+
+
+class TestColorized(TestCase):
+ maxDiff = None
+
+ def setUp(self):
+ super().setUp()
+ # Ensure color even if run with NO_COLOR=1
+ _colorize.can_colorize = lambda *args, **kwargs: True
+ self.theme = _colorize.get_theme(force_color=True).argparse
+
+ def test_argparse_color(self):
+ # Arrange: create a parser with a bit of everything
+ parser = argparse.ArgumentParser(
+ color=True,
+ description="Colorful help",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ prefix_chars="-+",
+ prog="PROG",
+ )
+ group = parser.add_mutually_exclusive_group()
+ group.add_argument(
+ "-v", "--verbose", action="store_true", help="more spam"
+ )
+ group.add_argument(
+ "-q", "--quiet", action="store_true", help="less spam"
+ )
+ parser.add_argument("x", type=int, help="the base")
+ parser.add_argument(
+ "y", type=int, help="the exponent", deprecated=True
+ )
+ parser.add_argument(
+ "this_indeed_is_a_very_long_action_name",
+ type=int,
+ help="the exponent",
+ )
+ parser.add_argument(
+ "-o", "--optional1", action="store_true", deprecated=True
+ )
+ parser.add_argument("--optional2", help="pick one")
+ parser.add_argument("--optional3", choices=("X", "Y", "Z"))
+ parser.add_argument(
+ "--optional4", choices=("X", "Y", "Z"), help="pick one"
+ )
+ parser.add_argument(
+ "--optional5", choices=("X", "Y", "Z"), help="pick one"
+ )
+ parser.add_argument(
+ "--optional6", choices=("X", "Y", "Z"), help="pick one"
+ )
+ parser.add_argument(
+ "-p",
+ "--optional7",
+ choices=("Aaaaa", "Bbbbb", "Ccccc", "Ddddd"),
+ help="pick one",
+ )
+
+ parser.add_argument("+f")
+ parser.add_argument("++bar")
+ parser.add_argument("-+baz")
+ parser.add_argument("-c", "--count")
+
+ subparsers = parser.add_subparsers(
+ title="subcommands",
+ description="valid subcommands",
+ help="additional help",
+ )
+ subparsers.add_parser("sub1", deprecated=True, help="sub1 help")
+ sub2 = subparsers.add_parser("sub2", deprecated=True, help="sub2 help")
+ sub2.add_argument("--baz", choices=("X", "Y", "Z"), help="baz help")
+
+ prog = self.theme.prog
+ heading = self.theme.heading
+ long = self.theme.summary_long_option
+ short = self.theme.summary_short_option
+ label = self.theme.summary_label
+ pos = self.theme.summary_action
+ long_b = self.theme.long_option
+ short_b = self.theme.short_option
+ label_b = self.theme.label
+ pos_b = self.theme.action
+ reset = self.theme.reset
+
+ # Act
+ help_text = parser.format_help()
+
+ # Assert
+ self.assertEqual(
+ help_text,
+ textwrap.dedent(
+ f"""\
+ {heading}usage: {reset}{prog}PROG{reset} [{short}-h{reset}] [{short}-v{reset} | {short}-q{reset}] [{short}-o{reset}] [{long}--optional2 {label}OPTIONAL2{reset}] [{long}--optional3 {label}{{X,Y,Z}}{reset}]
+ [{long}--optional4 {label}{{X,Y,Z}}{reset}] [{long}--optional5 {label}{{X,Y,Z}}{reset}] [{long}--optional6 {label}{{X,Y,Z}}{reset}]
+ [{short}-p {label}{{Aaaaa,Bbbbb,Ccccc,Ddddd}}{reset}] [{short}+f {label}F{reset}] [{long}++bar {label}BAR{reset}] [{long}-+baz {label}BAZ{reset}]
+ [{short}-c {label}COUNT{reset}]
+ {pos}x{reset} {pos}y{reset} {pos}this_indeed_is_a_very_long_action_name{reset} {pos}{{sub1,sub2}} ...{reset}
+
+ Colorful help
+
+ {heading}positional arguments:{reset}
+ {pos_b}x{reset} the base
+ {pos_b}y{reset} the exponent
+ {pos_b}this_indeed_is_a_very_long_action_name{reset}
+ the exponent
+
+ {heading}options:{reset}
+ {short_b}-h{reset}, {long_b}--help{reset} show this help message and exit
+ {short_b}-v{reset}, {long_b}--verbose{reset} more spam (default: False)
+ {short_b}-q{reset}, {long_b}--quiet{reset} less spam (default: False)
+ {short_b}-o{reset}, {long_b}--optional1{reset}
+ {long_b}--optional2{reset} {label_b}OPTIONAL2{reset}
+ pick one (default: None)
+ {long_b}--optional3{reset} {label_b}{{X,Y,Z}}{reset}
+ {long_b}--optional4{reset} {label_b}{{X,Y,Z}}{reset} pick one (default: None)
+ {long_b}--optional5{reset} {label_b}{{X,Y,Z}}{reset} pick one (default: None)
+ {long_b}--optional6{reset} {label_b}{{X,Y,Z}}{reset} pick one (default: None)
+ {short_b}-p{reset}, {long_b}--optional7{reset} {label_b}{{Aaaaa,Bbbbb,Ccccc,Ddddd}}{reset}
+ pick one (default: None)
+ {short_b}+f{reset} {label_b}F{reset}
+ {long_b}++bar{reset} {label_b}BAR{reset}
+ {long_b}-+baz{reset} {label_b}BAZ{reset}
+ {short_b}-c{reset}, {long_b}--count{reset} {label_b}COUNT{reset}
+
+ {heading}subcommands:{reset}
+ valid subcommands
+
+ {pos_b}{{sub1,sub2}}{reset} additional help
+ {pos_b}sub1{reset} sub1 help
+ {pos_b}sub2{reset} sub2 help
+ """
+ ),
+ )
+
+ def test_argparse_color_usage(self):
+ # Arrange
+ parser = argparse.ArgumentParser(
+ add_help=False,
+ color=True,
+ description="Test prog and usage colors",
+ prog="PROG",
+ usage="[prefix] %(prog)s [suffix]",
+ )
+ heading = self.theme.heading
+ prog = self.theme.prog
+ reset = self.theme.reset
+ usage = self.theme.prog_extra
+
+ # Act
+ help_text = parser.format_help()
+
+ # Assert
+ self.assertEqual(
+ help_text,
+ textwrap.dedent(
+ f"""\
+ {heading}usage: {reset}{usage}[prefix] {prog}PROG{reset}{usage} [suffix]{reset}
+
+ Test prog and usage colors
+ """
+ ),
+ )
+
+ def test_custom_formatter_function(self):
+ def custom_formatter(prog):
+ return argparse.RawTextHelpFormatter(prog, indent_increment=5)
+
+ parser = argparse.ArgumentParser(
+ prog="PROG",
+ prefix_chars="-+",
+ formatter_class=custom_formatter,
+ color=True,
+ )
+ parser.add_argument('+f', '++foo', help="foo help")
+ parser.add_argument('spam', help="spam help")
+
+ prog = self.theme.prog
+ heading = self.theme.heading
+ short = self.theme.summary_short_option
+ label = self.theme.summary_label
+ pos = self.theme.summary_action
+ long_b = self.theme.long_option
+ short_b = self.theme.short_option
+ label_b = self.theme.label
+ pos_b = self.theme.action
+ reset = self.theme.reset
+
+ parser_help = parser.format_help()
+ self.assertEqual(parser_help, textwrap.dedent(f'''\
+ {heading}usage: {reset}{prog}PROG{reset} [{short}-h{reset}] [{short}+f {label}FOO{reset}] {pos}spam{reset}
+
+ {heading}positional arguments:{reset}
+ {pos_b}spam{reset} spam help
+
+ {heading}options:{reset}
+ {short_b}-h{reset}, {long_b}--help{reset} show this help message and exit
+ {short_b}+f{reset}, {long_b}++foo{reset} {label_b}FOO{reset} foo help
+ '''))
+
+ def test_custom_formatter_class(self):
+ class CustomFormatter(argparse.RawTextHelpFormatter):
+ def __init__(self, prog):
+ super().__init__(prog, indent_increment=5)
+
+ parser = argparse.ArgumentParser(
+ prog="PROG",
+ prefix_chars="-+",
+ formatter_class=CustomFormatter,
+ color=True,
+ )
+ parser.add_argument('+f', '++foo', help="foo help")
+ parser.add_argument('spam', help="spam help")
+
+ prog = self.theme.prog
+ heading = self.theme.heading
+ short = self.theme.summary_short_option
+ label = self.theme.summary_label
+ pos = self.theme.summary_action
+ long_b = self.theme.long_option
+ short_b = self.theme.short_option
+ label_b = self.theme.label
+ pos_b = self.theme.action
+ reset = self.theme.reset
+
+ parser_help = parser.format_help()
+ self.assertEqual(parser_help, textwrap.dedent(f'''\
+ {heading}usage: {reset}{prog}PROG{reset} [{short}-h{reset}] [{short}+f {label}FOO{reset}] {pos}spam{reset}
+
+ {heading}positional arguments:{reset}
+ {pos_b}spam{reset} spam help
+
+ {heading}options:{reset}
+ {short_b}-h{reset}, {long_b}--help{reset} show this help message and exit
+ {short_b}+f{reset}, {long_b}++foo{reset} {label_b}FOO{reset} foo help
+ '''))
+
+
def tearDownModule():
# Remove global references to avoid looking like we have refleaks.
RFile.seen = {}
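
(Standalone sketch, not from the patch: the custom-formatter tests above rely on formatter_class accepting any callable that takes the program name, not only a HelpFormatter subclass.)

import argparse

def wide_formatter(prog):
    # any callable taking `prog` is accepted as formatter_class
    return argparse.RawTextHelpFormatter(prog, max_help_position=50)

parser = argparse.ArgumentParser(prog="PROG", formatter_class=wide_formatter)
parser.add_argument("--spam", help="spam help")
print(parser.format_help())
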
diff --git a/Lib/test/test_asdl_parser.py b/Lib/test/test_asdl_parser.py
index 2c198a6b8b2..b9df6568123 100644
--- a/Lib/test/test_asdl_parser.py
+++ b/Lib/test/test_asdl_parser.py
@@ -62,17 +62,17 @@ class TestAsdlParser(unittest.TestCase):
alias = self.types['alias']
self.assertEqual(
str(alias),
- 'Product([Field(identifier, name), Field(identifier, asname, opt=True)], '
+ 'Product([Field(identifier, name), Field(identifier, asname, quantifiers=[OPTIONAL])], '
'[Field(int, lineno), Field(int, col_offset), '
- 'Field(int, end_lineno, opt=True), Field(int, end_col_offset, opt=True)])')
+ 'Field(int, end_lineno, quantifiers=[OPTIONAL]), Field(int, end_col_offset, quantifiers=[OPTIONAL])])')
def test_attributes(self):
stmt = self.types['stmt']
self.assertEqual(len(stmt.attributes), 4)
self.assertEqual(repr(stmt.attributes[0]), 'Field(int, lineno)')
self.assertEqual(repr(stmt.attributes[1]), 'Field(int, col_offset)')
- self.assertEqual(repr(stmt.attributes[2]), 'Field(int, end_lineno, opt=True)')
- self.assertEqual(repr(stmt.attributes[3]), 'Field(int, end_col_offset, opt=True)')
+ self.assertEqual(repr(stmt.attributes[2]), 'Field(int, end_lineno, quantifiers=[OPTIONAL])')
+ self.assertEqual(repr(stmt.attributes[3]), 'Field(int, end_col_offset, quantifiers=[OPTIONAL])')
def test_constructor_fields(self):
ehandler = self.types['excepthandler']
diff --git a/Lib/test/test_ast/data/ast_repr.txt b/Lib/test/test_ast/data/ast_repr.txt
index 3778b9e70a4..1c1985519cd 100644
--- a/Lib/test/test_ast/data/ast_repr.txt
+++ b/Lib/test/test_ast/data/ast_repr.txt
@@ -206,4 +206,9 @@ Module(body=[Expr(value=IfExp(test=Name(...), body=Call(...), orelse=Call(...)))
Module(body=[Expr(value=JoinedStr(values=[FormattedValue(...)]))], type_ignores=[])
Module(body=[Expr(value=JoinedStr(values=[FormattedValue(...)]))], type_ignores=[])
Module(body=[Expr(value=JoinedStr(values=[FormattedValue(...)]))], type_ignores=[])
-Module(body=[Expr(value=JoinedStr(values=[Constant(...), ..., Constant(...)]))], type_ignores=[]) \ No newline at end of file
+Module(body=[Expr(value=JoinedStr(values=[Constant(...), ..., Constant(...)]))], type_ignores=[])
+Module(body=[Expr(value=TemplateStr(values=[Interpolation(...)]))], type_ignores=[])
+Module(body=[Expr(value=TemplateStr(values=[Interpolation(...)]))], type_ignores=[])
+Module(body=[Expr(value=TemplateStr(values=[Interpolation(...)]))], type_ignores=[])
+Module(body=[Expr(value=TemplateStr(values=[Interpolation(...)]))], type_ignores=[])
+Module(body=[Expr(value=TemplateStr(values=[Constant(...), ..., Constant(...)]))], type_ignores=[]) \ No newline at end of file
diff --git a/Lib/test/test_ast/snippets.py b/Lib/test/test_ast/snippets.py
index 28d32b2941f..b76f98901d2 100644
--- a/Lib/test/test_ast/snippets.py
+++ b/Lib/test/test_ast/snippets.py
@@ -364,6 +364,12 @@ eval_tests = [
"f'{a:.2f}'",
"f'{a!r}'",
"f'foo({a})'",
+ # TemplateStr and Interpolation
+ "t'{a}'",
+ "t'{a:.2f}'",
+ "t'{a!r}'",
+ "t'{a!r:.2f}'",
+ "t'foo({a})'",
]
@@ -597,5 +603,10 @@ eval_results = [
('Expression', ('JoinedStr', (1, 0, 1, 10), [('FormattedValue', (1, 2, 1, 9), ('Name', (1, 3, 1, 4), 'a', ('Load',)), -1, ('JoinedStr', (1, 4, 1, 8), [('Constant', (1, 5, 1, 8), '.2f', None)]))])),
('Expression', ('JoinedStr', (1, 0, 1, 8), [('FormattedValue', (1, 2, 1, 7), ('Name', (1, 3, 1, 4), 'a', ('Load',)), 114, None)])),
('Expression', ('JoinedStr', (1, 0, 1, 11), [('Constant', (1, 2, 1, 6), 'foo(', None), ('FormattedValue', (1, 6, 1, 9), ('Name', (1, 7, 1, 8), 'a', ('Load',)), -1, None), ('Constant', (1, 9, 1, 10), ')', None)])),
+('Expression', ('TemplateStr', (1, 0, 1, 6), [('Interpolation', (1, 2, 1, 5), ('Name', (1, 3, 1, 4), 'a', ('Load',)), 'a', -1, None)])),
+('Expression', ('TemplateStr', (1, 0, 1, 10), [('Interpolation', (1, 2, 1, 9), ('Name', (1, 3, 1, 4), 'a', ('Load',)), 'a', -1, ('JoinedStr', (1, 4, 1, 8), [('Constant', (1, 5, 1, 8), '.2f', None)]))])),
+('Expression', ('TemplateStr', (1, 0, 1, 8), [('Interpolation', (1, 2, 1, 7), ('Name', (1, 3, 1, 4), 'a', ('Load',)), 'a', 114, None)])),
+('Expression', ('TemplateStr', (1, 0, 1, 12), [('Interpolation', (1, 2, 1, 11), ('Name', (1, 3, 1, 4), 'a', ('Load',)), 'a', 114, ('JoinedStr', (1, 6, 1, 10), [('Constant', (1, 7, 1, 10), '.2f', None)]))])),
+('Expression', ('TemplateStr', (1, 0, 1, 11), [('Constant', (1, 2, 1, 6), 'foo(', None), ('Interpolation', (1, 6, 1, 9), ('Name', (1, 7, 1, 8), 'a', ('Load',)), 'a', -1, None), ('Constant', (1, 9, 1, 10), ')', None)])),
]
main()
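
(A rough, standalone illustration of the TemplateStr/Interpolation nodes the new snippets expect; assumes PEP 750 template strings are available, i.e. Python 3.14+.)

import ast

tree = ast.parse("t'foo({a})'", mode="eval")
print(type(tree.body).__name__)     # TemplateStr
for part in tree.body.values:
    print(type(part).__name__)      # Constant, Interpolation, Constant
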
diff --git a/Lib/test/test_ast/test_ast.py b/Lib/test/test_ast/test_ast.py
index dd459487afe..cc46529c0ef 100644
--- a/Lib/test/test_ast/test_ast.py
+++ b/Lib/test/test_ast/test_ast.py
@@ -1,16 +1,20 @@
import _ast_unparse
import ast
import builtins
+import contextlib
import copy
import dis
import enum
+import itertools
import os
import re
import sys
+import tempfile
import textwrap
import types
import unittest
import weakref
+from io import StringIO
from pathlib import Path
from textwrap import dedent
try:
@@ -19,9 +23,10 @@ except ImportError:
_testinternalcapi = None
from test import support
-from test.support import os_helper, script_helper
+from test.support import os_helper
from test.support import skip_emscripten_stack_overflow, skip_wasi_stack_overflow
from test.support.ast_helper import ASTTestMixin
+from test.support.import_helper import ensure_lazy_imports
from test.test_ast.utils import to_tuple
from test.test_ast.snippets import (
eval_tests, eval_results, exec_tests, exec_results, single_tests, single_results
@@ -43,6 +48,12 @@ def ast_repr_update_snapshots() -> None:
AST_REPR_DATA_FILE.write_text("\n".join(data))
+class LazyImportTest(unittest.TestCase):
+ @support.cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("ast", {"contextlib", "enum", "inspect", "re", "collections", "argparse"})
+
+
class AST_Tests(unittest.TestCase):
maxDiff = None
@@ -264,12 +275,12 @@ class AST_Tests(unittest.TestCase):
self.assertEqual(alias.end_col_offset, 17)
def test_base_classes(self):
- self.assertTrue(issubclass(ast.For, ast.stmt))
- self.assertTrue(issubclass(ast.Name, ast.expr))
- self.assertTrue(issubclass(ast.stmt, ast.AST))
- self.assertTrue(issubclass(ast.expr, ast.AST))
- self.assertTrue(issubclass(ast.comprehension, ast.AST))
- self.assertTrue(issubclass(ast.Gt, ast.AST))
+ self.assertIsSubclass(ast.For, ast.stmt)
+ self.assertIsSubclass(ast.Name, ast.expr)
+ self.assertIsSubclass(ast.stmt, ast.AST)
+ self.assertIsSubclass(ast.expr, ast.AST)
+ self.assertIsSubclass(ast.comprehension, ast.AST)
+ self.assertIsSubclass(ast.Gt, ast.AST)
def test_field_attr_existence(self):
for name, item in ast.__dict__.items():
@@ -675,6 +686,91 @@ class AST_Tests(unittest.TestCase):
with self.assertRaises(SyntaxError):
ast.parse('(x := 0)', feature_version=(3, 7))
+ def test_pep750_tstring(self):
+ code = 't""'
+ ast.parse(code, feature_version=(3, 14))
+ with self.assertRaises(SyntaxError):
+ ast.parse(code, feature_version=(3, 13))
+
+ def test_pep758_except_without_parens(self):
+ code = textwrap.dedent("""
+ try:
+ ...
+ except ValueError, TypeError:
+ ...
+ """)
+ ast.parse(code, feature_version=(3, 14))
+ with self.assertRaises(SyntaxError):
+ ast.parse(code, feature_version=(3, 13))
+
+ def test_pep758_except_with_single_expr(self):
+ single_expr = textwrap.dedent("""
+ try:
+ ...
+ except{0} TypeError:
+ ...
+ """)
+
+ single_expr_with_as = textwrap.dedent("""
+ try:
+ ...
+ except{0} TypeError as exc:
+ ...
+ """)
+
+ single_tuple_expr = textwrap.dedent("""
+ try:
+ ...
+ except{0} (TypeError,):
+ ...
+ """)
+
+ single_tuple_expr_with_as = textwrap.dedent("""
+ try:
+ ...
+ except{0} (TypeError,) as exc:
+ ...
+ """)
+
+ single_parens_expr = textwrap.dedent("""
+ try:
+ ...
+ except{0} (TypeError):
+ ...
+ """)
+
+ single_parens_expr_with_as = textwrap.dedent("""
+ try:
+ ...
+ except{0} (TypeError) as exc:
+ ...
+ """)
+
+ for code_template in [
+ single_expr,
+ single_expr_with_as,
+ single_tuple_expr,
+ single_tuple_expr_with_as,
+ single_parens_expr,
+ single_parens_expr_with_as,
+ ]:
+ for star in [True, False]:
+ code = code_template.format('*' if star else '')
+ with self.subTest(code=code, star=star):
+ ast.parse(code, feature_version=(3, 14))
+ ast.parse(code, feature_version=(3, 13))
+
+ def test_pep758_except_star_without_parens(self):
+ code = textwrap.dedent("""
+ try:
+ ...
+ except* ValueError, TypeError:
+ ...
+ """)
+ ast.parse(code, feature_version=(3, 14))
+ with self.assertRaises(SyntaxError):
+ ast.parse(code, feature_version=(3, 13))
+
def test_conditional_context_managers_parse_with_low_feature_version(self):
# regression test for gh-115881
ast.parse('with (x() if y else z()): ...', feature_version=(3, 8))
@@ -725,6 +821,17 @@ class AST_Tests(unittest.TestCase):
with self.assertRaisesRegex(ValueError, f"identifier field can't represent '{constant}' constant"):
compile(expr, "<test>", "eval")
+ def test_constant_as_unicode_name(self):
+ constants = [
+ ("True", b"Tru\xe1\xb5\x89"),
+ ("False", b"Fal\xc5\xbfe"),
+ ("None", b"N\xc2\xbane"),
+ ]
+ for constant in constants:
+ with self.assertRaisesRegex(ValueError,
+ f"identifier field can't represent '{constant[0]}' constant"):
+ ast.parse(constant[1], mode="eval")
+
def test_precedence_enum(self):
class _Precedence(enum.IntEnum):
"""Precedence table that originated from python grammar."""
@@ -880,6 +987,25 @@ class AST_Tests(unittest.TestCase):
for src in srcs:
ast.parse(src)
+ def test_tstring(self):
+ # Test AST structure for simple t-string
+ tree = ast.parse('t"Hello"')
+ self.assertIsInstance(tree.body[0].value, ast.TemplateStr)
+ self.assertIsInstance(tree.body[0].value.values[0], ast.Constant)
+
+ # Test AST for t-string with interpolation
+ tree = ast.parse('t"Hello {name}"')
+ self.assertIsInstance(tree.body[0].value, ast.TemplateStr)
+ self.assertIsInstance(tree.body[0].value.values[0], ast.Constant)
+ self.assertIsInstance(tree.body[0].value.values[1], ast.Interpolation)
+
+ # Test AST for implicit concat of t-string with f-string
+ tree = ast.parse('t"Hello {name}" f"{name}"')
+ self.assertIsInstance(tree.body[0].value, ast.TemplateStr)
+ self.assertIsInstance(tree.body[0].value.values[0], ast.Constant)
+ self.assertIsInstance(tree.body[0].value.values[1], ast.Interpolation)
+ self.assertIsInstance(tree.body[0].value.values[2], ast.FormattedValue)
+
class CopyTests(unittest.TestCase):
"""Test copying and pickling AST nodes."""
@@ -975,7 +1101,7 @@ class CopyTests(unittest.TestCase):
def test_replace_interface(self):
for klass in self.iter_ast_classes():
with self.subTest(klass=klass):
- self.assertTrue(hasattr(klass, '__replace__'))
+ self.assertHasAttr(klass, '__replace__')
fields = set(klass._fields)
with self.subTest(klass=klass, fields=fields):
@@ -1189,13 +1315,22 @@ class CopyTests(unittest.TestCase):
self.assertIs(repl.id, 'y')
self.assertIs(repl.ctx, context)
+ def test_replace_accept_missing_field_with_default(self):
+ node = ast.FunctionDef(name="foo", args=ast.arguments())
+ self.assertIs(node.returns, None)
+ self.assertEqual(node.decorator_list, [])
+ node2 = copy.replace(node, name="bar")
+ self.assertEqual(node2.name, "bar")
+ self.assertIs(node2.returns, None)
+ self.assertEqual(node2.decorator_list, [])
+
def test_replace_reject_known_custom_instance_fields_commits(self):
node = ast.parse('x').body[0].value
node.extra = extra = object() # add instance 'extra' field
context = node.ctx
# explicit rejection of known instance fields
- self.assertTrue(hasattr(node, 'extra'))
+ self.assertHasAttr(node, 'extra')
msg = "Name.__replace__ got an unexpected keyword argument 'extra'."
with self.assertRaisesRegex(TypeError, re.escape(msg)):
copy.replace(node, extra=1)
@@ -1237,17 +1372,17 @@ class ASTHelpers_Test(unittest.TestCase):
def test_dump(self):
node = ast.parse('spam(eggs, "and cheese")')
self.assertEqual(ast.dump(node),
- "Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load()), "
- "args=[Name(id='eggs', ctx=Load()), Constant(value='and cheese')]))])"
+ "Module(body=[Expr(value=Call(func=Name(id='spam'), "
+ "args=[Name(id='eggs'), Constant(value='and cheese')]))])"
)
self.assertEqual(ast.dump(node, annotate_fields=False),
- "Module([Expr(Call(Name('spam', Load()), [Name('eggs', Load()), "
+ "Module([Expr(Call(Name('spam'), [Name('eggs'), "
"Constant('and cheese')]))])"
)
self.assertEqual(ast.dump(node, include_attributes=True),
- "Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load(), "
+ "Module(body=[Expr(value=Call(func=Name(id='spam', "
"lineno=1, col_offset=0, end_lineno=1, end_col_offset=4), "
- "args=[Name(id='eggs', ctx=Load(), lineno=1, col_offset=5, "
+ "args=[Name(id='eggs', lineno=1, col_offset=5, "
"end_lineno=1, end_col_offset=9), Constant(value='and cheese', "
"lineno=1, col_offset=11, end_lineno=1, end_col_offset=23)], "
"lineno=1, col_offset=0, end_lineno=1, end_col_offset=24), "
@@ -1261,18 +1396,18 @@ Module(
body=[
Expr(
value=Call(
- func=Name(id='spam', ctx=Load()),
+ func=Name(id='spam'),
args=[
- Name(id='eggs', ctx=Load()),
+ Name(id='eggs'),
Constant(value='and cheese')]))])""")
self.assertEqual(ast.dump(node, annotate_fields=False, indent='\t'), """\
Module(
\t[
\t\tExpr(
\t\t\tCall(
-\t\t\t\tName('spam', Load()),
+\t\t\t\tName('spam'),
\t\t\t\t[
-\t\t\t\t\tName('eggs', Load()),
+\t\t\t\t\tName('eggs'),
\t\t\t\t\tConstant('and cheese')]))])""")
self.assertEqual(ast.dump(node, include_attributes=True, indent=3), """\
Module(
@@ -1281,7 +1416,6 @@ Module(
value=Call(
func=Name(
id='spam',
- ctx=Load(),
lineno=1,
col_offset=0,
end_lineno=1,
@@ -1289,7 +1423,6 @@ Module(
args=[
Name(
id='eggs',
- ctx=Load(),
lineno=1,
col_offset=5,
end_lineno=1,
@@ -1319,23 +1452,23 @@ Module(
)
node = ast.Raise(exc=ast.Name(id='e', ctx=ast.Load()), lineno=3, col_offset=4)
self.assertEqual(ast.dump(node),
- "Raise(exc=Name(id='e', ctx=Load()))"
+ "Raise(exc=Name(id='e'))"
)
self.assertEqual(ast.dump(node, annotate_fields=False),
- "Raise(Name('e', Load()))"
+ "Raise(Name('e'))"
)
self.assertEqual(ast.dump(node, include_attributes=True),
- "Raise(exc=Name(id='e', ctx=Load()), lineno=3, col_offset=4)"
+ "Raise(exc=Name(id='e'), lineno=3, col_offset=4)"
)
self.assertEqual(ast.dump(node, annotate_fields=False, include_attributes=True),
- "Raise(Name('e', Load()), lineno=3, col_offset=4)"
+ "Raise(Name('e'), lineno=3, col_offset=4)"
)
node = ast.Raise(cause=ast.Name(id='e', ctx=ast.Load()))
self.assertEqual(ast.dump(node),
- "Raise(cause=Name(id='e', ctx=Load()))"
+ "Raise(cause=Name(id='e'))"
)
self.assertEqual(ast.dump(node, annotate_fields=False),
- "Raise(cause=Name('e', Load()))"
+ "Raise(cause=Name('e'))"
)
# Arguments:
node = ast.arguments(args=[ast.arg("x")])
@@ -1367,10 +1500,10 @@ Module(
[ast.Name('dataclass', ctx=ast.Load())],
)
self.assertEqual(ast.dump(node),
- "ClassDef(name='T', keywords=[keyword(arg='a', value=Constant(value=None))], decorator_list=[Name(id='dataclass', ctx=Load())])",
+ "ClassDef(name='T', keywords=[keyword(arg='a', value=Constant(value=None))], decorator_list=[Name(id='dataclass')])",
)
self.assertEqual(ast.dump(node, annotate_fields=False),
- "ClassDef('T', [], [keyword('a', Constant(None))], [], [Name('dataclass', Load())])",
+ "ClassDef('T', [], [keyword('a', Constant(None))], [], [Name('dataclass')])",
)
def test_dump_show_empty(self):
@@ -1398,7 +1531,7 @@ Module(
check_node(
# Corner case: there are no real `Name` instances with `id=''`:
ast.Name(id='', ctx=ast.Load()),
- empty="Name(id='', ctx=Load())",
+ empty="Name(id='')",
full="Name(id='', ctx=Load())",
)
@@ -1409,39 +1542,63 @@ Module(
)
check_node(
+ ast.MatchSingleton(value=[]),
+ empty="MatchSingleton(value=[])",
+ full="MatchSingleton(value=[])",
+ )
+
+ check_node(
ast.Constant(value=None),
empty="Constant(value=None)",
full="Constant(value=None)",
)
check_node(
+ ast.Constant(value=[]),
+ empty="Constant(value=[])",
+ full="Constant(value=[])",
+ )
+
+ check_node(
ast.Constant(value=''),
empty="Constant(value='')",
full="Constant(value='')",
)
+ check_node(
+ ast.Interpolation(value=ast.Constant(42), str=None, conversion=-1),
+ empty="Interpolation(value=Constant(value=42), str=None, conversion=-1)",
+ full="Interpolation(value=Constant(value=42), str=None, conversion=-1)",
+ )
+
+ check_node(
+ ast.Interpolation(value=ast.Constant(42), str=[], conversion=-1),
+ empty="Interpolation(value=Constant(value=42), str=[], conversion=-1)",
+ full="Interpolation(value=Constant(value=42), str=[], conversion=-1)",
+ )
+
check_text(
"def a(b: int = 0, *, c): ...",
- empty="Module(body=[FunctionDef(name='a', args=arguments(args=[arg(arg='b', annotation=Name(id='int', ctx=Load()))], kwonlyargs=[arg(arg='c')], kw_defaults=[None], defaults=[Constant(value=0)]), body=[Expr(value=Constant(value=Ellipsis))])])",
+ empty="Module(body=[FunctionDef(name='a', args=arguments(args=[arg(arg='b', annotation=Name(id='int'))], kwonlyargs=[arg(arg='c')], kw_defaults=[None], defaults=[Constant(value=0)]), body=[Expr(value=Constant(value=Ellipsis))])])",
full="Module(body=[FunctionDef(name='a', args=arguments(posonlyargs=[], args=[arg(arg='b', annotation=Name(id='int', ctx=Load()))], kwonlyargs=[arg(arg='c')], kw_defaults=[None], defaults=[Constant(value=0)]), body=[Expr(value=Constant(value=Ellipsis))], decorator_list=[], type_params=[])], type_ignores=[])",
)
check_text(
"def a(b: int = 0, *, c): ...",
- empty="Module(body=[FunctionDef(name='a', args=arguments(args=[arg(arg='b', annotation=Name(id='int', ctx=Load(), lineno=1, col_offset=9, end_lineno=1, end_col_offset=12), lineno=1, col_offset=6, end_lineno=1, end_col_offset=12)], kwonlyargs=[arg(arg='c', lineno=1, col_offset=21, end_lineno=1, end_col_offset=22)], kw_defaults=[None], defaults=[Constant(value=0, lineno=1, col_offset=15, end_lineno=1, end_col_offset=16)]), body=[Expr(value=Constant(value=Ellipsis, lineno=1, col_offset=25, end_lineno=1, end_col_offset=28), lineno=1, col_offset=25, end_lineno=1, end_col_offset=28)], lineno=1, col_offset=0, end_lineno=1, end_col_offset=28)])",
+ empty="Module(body=[FunctionDef(name='a', args=arguments(args=[arg(arg='b', annotation=Name(id='int', lineno=1, col_offset=9, end_lineno=1, end_col_offset=12), lineno=1, col_offset=6, end_lineno=1, end_col_offset=12)], kwonlyargs=[arg(arg='c', lineno=1, col_offset=21, end_lineno=1, end_col_offset=22)], kw_defaults=[None], defaults=[Constant(value=0, lineno=1, col_offset=15, end_lineno=1, end_col_offset=16)]), body=[Expr(value=Constant(value=Ellipsis, lineno=1, col_offset=25, end_lineno=1, end_col_offset=28), lineno=1, col_offset=25, end_lineno=1, end_col_offset=28)], lineno=1, col_offset=0, end_lineno=1, end_col_offset=28)])",
full="Module(body=[FunctionDef(name='a', args=arguments(posonlyargs=[], args=[arg(arg='b', annotation=Name(id='int', ctx=Load(), lineno=1, col_offset=9, end_lineno=1, end_col_offset=12), lineno=1, col_offset=6, end_lineno=1, end_col_offset=12)], kwonlyargs=[arg(arg='c', lineno=1, col_offset=21, end_lineno=1, end_col_offset=22)], kw_defaults=[None], defaults=[Constant(value=0, lineno=1, col_offset=15, end_lineno=1, end_col_offset=16)]), body=[Expr(value=Constant(value=Ellipsis, lineno=1, col_offset=25, end_lineno=1, end_col_offset=28), lineno=1, col_offset=25, end_lineno=1, end_col_offset=28)], decorator_list=[], type_params=[], lineno=1, col_offset=0, end_lineno=1, end_col_offset=28)], type_ignores=[])",
include_attributes=True,
)
check_text(
'spam(eggs, "and cheese")',
- empty="Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load()), args=[Name(id='eggs', ctx=Load()), Constant(value='and cheese')]))])",
+ empty="Module(body=[Expr(value=Call(func=Name(id='spam'), args=[Name(id='eggs'), Constant(value='and cheese')]))])",
full="Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load()), args=[Name(id='eggs', ctx=Load()), Constant(value='and cheese')], keywords=[]))], type_ignores=[])",
)
check_text(
'spam(eggs, text="and cheese")',
- empty="Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load()), args=[Name(id='eggs', ctx=Load())], keywords=[keyword(arg='text', value=Constant(value='and cheese'))]))])",
+ empty="Module(body=[Expr(value=Call(func=Name(id='spam'), args=[Name(id='eggs')], keywords=[keyword(arg='text', value=Constant(value='and cheese'))]))])",
full="Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load()), args=[Name(id='eggs', ctx=Load())], keywords=[keyword(arg='text', value=Constant(value='and cheese'))]))], type_ignores=[])",
)
@@ -1475,12 +1632,12 @@ Module(
self.assertEqual(src, ast.fix_missing_locations(src))
self.maxDiff = None
self.assertEqual(ast.dump(src, include_attributes=True),
- "Module(body=[Expr(value=Call(func=Name(id='write', ctx=Load(), "
+ "Module(body=[Expr(value=Call(func=Name(id='write', "
"lineno=1, col_offset=0, end_lineno=1, end_col_offset=5), "
"args=[Constant(value='spam', lineno=1, col_offset=6, end_lineno=1, "
"end_col_offset=12)], lineno=1, col_offset=0, end_lineno=1, "
"end_col_offset=13), lineno=1, col_offset=0, end_lineno=1, "
- "end_col_offset=13), Expr(value=Call(func=Name(id='spam', ctx=Load(), "
+ "end_col_offset=13), Expr(value=Call(func=Name(id='spam', "
"lineno=1, col_offset=0, end_lineno=1, end_col_offset=0), "
"args=[Constant(value='eggs', lineno=1, col_offset=0, end_lineno=1, "
"end_col_offset=0)], lineno=1, col_offset=0, end_lineno=1, "
@@ -2936,7 +3093,7 @@ class ASTConstructorTests(unittest.TestCase):
with self.assertWarnsRegex(DeprecationWarning,
r"FunctionDef\.__init__ missing 1 required positional argument: 'name'"):
node = ast.FunctionDef(args=args)
- self.assertFalse(hasattr(node, "name"))
+ self.assertNotHasAttr(node, "name")
self.assertEqual(node.decorator_list, [])
node = ast.FunctionDef(name='foo', args=args)
self.assertEqual(node.name, 'foo')
@@ -3128,23 +3285,263 @@ class ModuleStateTests(unittest.TestCase):
self.assertEqual(res, 0)
-class ASTMainTests(unittest.TestCase):
- # Tests `ast.main()` function.
+class CommandLineTests(unittest.TestCase):
+ def setUp(self):
+ self.filename = tempfile.mktemp()
+ self.addCleanup(os_helper.unlink, self.filename)
+
+ @staticmethod
+ def text_normalize(string):
+ return textwrap.dedent(string).strip()
+
+ def set_source(self, content):
+ Path(self.filename).write_text(self.text_normalize(content))
+
+ def invoke_ast(self, *flags):
+ stderr = StringIO()
+ stdout = StringIO()
+ with (
+ contextlib.redirect_stdout(stdout),
+ contextlib.redirect_stderr(stderr),
+ ):
+ ast.main(args=[*flags, self.filename])
+ self.assertEqual(stderr.getvalue(), '')
+ return stdout.getvalue().strip()
- def test_cli_file_input(self):
- code = "print(1, 2, 3)"
- expected = ast.dump(ast.parse(code), indent=3)
+ def check_output(self, source, expect, *flags):
+ self.set_source(source)
+ res = self.invoke_ast(*flags)
+ expect = self.text_normalize(expect)
+ self.assertEqual(res, expect)
- with os_helper.temp_dir() as tmp_dir:
- filename = os.path.join(tmp_dir, "test_module.py")
- with open(filename, 'w', encoding='utf-8') as f:
- f.write(code)
- res, _ = script_helper.run_python_until_end("-m", "ast", filename)
+ @support.requires_resource('cpu')
+ def test_invocation(self):
+ # test various combinations of parameters
+ base_flags = (
+ ('-m=exec', '--mode=exec'),
+ ('--no-type-comments', '--no-type-comments'),
+ ('-a', '--include-attributes'),
+ ('-i=4', '--indent=4'),
+ ('--feature-version=3.13', '--feature-version=3.13'),
+ ('-O=-1', '--optimize=-1'),
+ ('--show-empty', '--show-empty'),
+ )
+ self.set_source('''
+ print(1, 2, 3)
+ def f(x: int) -> int:
+ x -= 1
+ return x
+ ''')
+
+ for r in range(1, len(base_flags) + 1):
+ for choices in itertools.combinations(base_flags, r=r):
+ for args in itertools.product(*choices):
+ with self.subTest(flags=args):
+ self.invoke_ast(*args)
+
+ @support.force_not_colorized
+ def test_help_message(self):
+ for flag in ('-h', '--help', '--unknown'):
+ with self.subTest(flag=flag):
+ output = StringIO()
+ with self.assertRaises(SystemExit):
+ with contextlib.redirect_stderr(output):
+ ast.main(args=flag)
+ self.assertStartsWith(output.getvalue(), 'usage: ')
+
+ def test_exec_mode_flag(self):
+ # test 'python -m ast -m/--mode exec'
+ source = 'x: bool = 1 # type: ignore[assignment]'
+ expect = '''
+ Module(
+ body=[
+ AnnAssign(
+ target=Name(id='x', ctx=Store()),
+ annotation=Name(id='bool'),
+ value=Constant(value=1),
+ simple=1)],
+ type_ignores=[
+ TypeIgnore(lineno=1, tag='[assignment]')])
+ '''
+ for flag in ('-m=exec', '--mode=exec'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_single_mode_flag(self):
+ # test 'python -m ast -m/--mode single'
+ source = 'pass'
+ expect = '''
+ Interactive(
+ body=[
+ Pass()])
+ '''
+ for flag in ('-m=single', '--mode=single'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_eval_mode_flag(self):
+ # test 'python -m ast -m/--mode eval'
+ source = 'print(1, 2, 3)'
+ expect = '''
+ Expression(
+ body=Call(
+ func=Name(id='print'),
+ args=[
+ Constant(value=1),
+ Constant(value=2),
+ Constant(value=3)]))
+ '''
+ for flag in ('-m=eval', '--mode=eval'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_func_type_mode_flag(self):
+ # test 'python -m ast -m/--mode func_type'
+ source = '(int, str) -> list[int]'
+ expect = '''
+ FunctionType(
+ argtypes=[
+ Name(id='int'),
+ Name(id='str')],
+ returns=Subscript(
+ value=Name(id='list'),
+ slice=Name(id='int')))
+ '''
+ for flag in ('-m=func_type', '--mode=func_type'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_no_type_comments_flag(self):
+ # test 'python -m ast --no-type-comments'
+ source = 'x: bool = 1 # type: ignore[assignment]'
+ expect = '''
+ Module(
+ body=[
+ AnnAssign(
+ target=Name(id='x', ctx=Store()),
+ annotation=Name(id='bool'),
+ value=Constant(value=1),
+ simple=1)])
+ '''
+ self.check_output(source, expect, '--no-type-comments')
+
+ def test_include_attributes_flag(self):
+ # test 'python -m ast -a/--include-attributes'
+ source = 'pass'
+ expect = '''
+ Module(
+ body=[
+ Pass(
+ lineno=1,
+ col_offset=0,
+ end_lineno=1,
+ end_col_offset=4)])
+ '''
+ for flag in ('-a', '--include-attributes'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_indent_flag(self):
+ # test 'python -m ast -i/--indent 0'
+ source = 'pass'
+ expect = '''
+ Module(
+ body=[
+ Pass()])
+ '''
+ for flag in ('-i=0', '--indent=0'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_feature_version_flag(self):
+ # test 'python -m ast --feature-version 3.9/3.10'
+ source = '''
+ match x:
+ case 1:
+ pass
+ '''
+ expect = '''
+ Module(
+ body=[
+ Match(
+ subject=Name(id='x'),
+ cases=[
+ match_case(
+ pattern=MatchValue(
+ value=Constant(value=1)),
+ body=[
+ Pass()])])])
+ '''
+ self.check_output(source, expect, '--feature-version=3.10')
+ with self.assertRaises(SyntaxError):
+ self.invoke_ast('--feature-version=3.9')
- self.assertEqual(res.err, b"")
- self.assertEqual(expected.splitlines(),
- res.out.decode("utf8").splitlines())
- self.assertEqual(res.rc, 0)
+ def test_no_optimize_flag(self):
+ # test 'python -m ast -O/--optimize -1/0'
+ source = '''
+ match a:
+ case 1+2j:
+ pass
+ '''
+ expect = '''
+ Module(
+ body=[
+ Match(
+ subject=Name(id='a'),
+ cases=[
+ match_case(
+ pattern=MatchValue(
+ value=BinOp(
+ left=Constant(value=1),
+ op=Add(),
+ right=Constant(value=2j))),
+ body=[
+ Pass()])])])
+ '''
+ for flag in ('-O=-1', '--optimize=-1', '-O=0', '--optimize=0'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_optimize_flag(self):
+ # test 'python -m ast -O/--optimize 1/2'
+ source = '''
+ match a:
+ case 1+2j:
+ pass
+ '''
+ expect = '''
+ Module(
+ body=[
+ Match(
+ subject=Name(id='a'),
+ cases=[
+ match_case(
+ pattern=MatchValue(
+ value=Constant(value=(1+2j))),
+ body=[
+ Pass()])])])
+ '''
+ for flag in ('-O=1', '--optimize=1', '-O=2', '--optimize=2'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_show_empty_flag(self):
+ # test 'python -m ast --show-empty'
+ source = 'print(1, 2, 3)'
+ expect = '''
+ Module(
+ body=[
+ Expr(
+ value=Call(
+ func=Name(id='print', ctx=Load()),
+ args=[
+ Constant(value=1),
+ Constant(value=2),
+ Constant(value=3)],
+ keywords=[]))],
+ type_ignores=[])
+ '''
+ self.check_output(source, expect, '--show-empty')
class ASTOptimiziationTests(unittest.TestCase):
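
(Sketch only: the new CommandLineTests drive ast.main() in-process instead of spawning a subprocess. The snippet below exercises the same entry point with one of the flags tested above, assuming the patched ast.main(args=...) signature; the file name is illustrative.)

import ast
import pathlib
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    path = pathlib.Path(tmp, "example_mod.py")
    path.write_text("print(1, 2, 3)\n")
    ast.main(args=["--indent=4", str(path)])   # prints the dump to stdout
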
diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py
index 2c44647bf3e..636cb33dd98 100644
--- a/Lib/test/test_asyncgen.py
+++ b/Lib/test/test_asyncgen.py
@@ -2021,6 +2021,15 @@ class TestUnawaitedWarnings(unittest.TestCase):
g.athrow(RuntimeError)
gc_collect()
+ def test_athrow_throws_immediately(self):
+ async def gen():
+ yield 1
+
+ g = gen()
+ msg = "athrow expected at least 1 argument, got 0"
+ with self.assertRaisesRegex(TypeError, msg):
+ g.athrow()
+
def test_aclose(self):
async def gen():
yield 1
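
(Standalone sketch of the behaviour pinned down above: with this change, athrow() called without arguments fails at call time instead of when the returned awaitable is awaited.)

import asyncio

async def agen():
    yield 1

async def main():
    g = agen()
    try:
        g.athrow()        # raises immediately on patched interpreters
    except TypeError as exc:
        print(exc)        # athrow expected at least 1 argument, got 0
    await g.aclose()

asyncio.run(main())
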
diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py
index a2fb1022ae4..9f3b6f9acef 100644
--- a/Lib/test/test_asyncio/test_eager_task_factory.py
+++ b/Lib/test/test_asyncio/test_eager_task_factory.py
@@ -263,6 +263,24 @@ class EagerTaskFactoryLoopTests:
self.run_coro(run())
+ def test_eager_start_false(self):
+ name = None
+
+ async def asyncfn():
+ nonlocal name
+ name = asyncio.current_task().get_name()
+
+ async def main():
+ t = asyncio.get_running_loop().create_task(
+ asyncfn(), eager_start=False, name="example"
+ )
+ self.assertFalse(t.done())
+ self.assertIsNone(name)
+ await t
+ self.assertEqual(name, "example")
+
+ self.run_coro(main())
+
class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
Task = tasks._PyTask
@@ -505,5 +523,24 @@ class EagerCTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
asyncio.current_task = asyncio.tasks.current_task = self._current_task
return super().tearDown()
+
+class DefaultTaskFactoryEagerStart(test_utils.TestCase):
+ def test_eager_start_true_with_default_factory(self):
+ name = None
+
+ async def asyncfn():
+ nonlocal name
+ name = asyncio.current_task().get_name()
+
+ async def main():
+ t = asyncio.get_running_loop().create_task(
+ asyncfn(), eager_start=True, name="example"
+ )
+ self.assertTrue(t.done())
+ self.assertEqual(name, "example")
+ await t
+
+ asyncio.run(main(), loop_factory=asyncio.EventLoop)
+
if __name__ == '__main__':
unittest.main()
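
(A minimal sketch of the eager_start keyword these tests pin down; assumes an interpreter where loop.create_task() accepts eager_start, as exercised by the patch.)

import asyncio

async def work():
    return asyncio.current_task().get_name()

async def main():
    loop = asyncio.get_running_loop()
    eager = loop.create_task(work(), eager_start=True, name="eager")
    lazy = loop.create_task(work(), eager_start=False, name="lazy")
    print(eager.done(), lazy.done())   # True False: the eager task already ran
    print(await eager, await lazy)     # eager lazy

asyncio.run(main(), loop_factory=asyncio.EventLoop)
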
diff --git a/Lib/test/test_asyncio/test_futures.py b/Lib/test/test_asyncio/test_futures.py
index 8b51522278a..39bef465bdb 100644
--- a/Lib/test/test_asyncio/test_futures.py
+++ b/Lib/test/test_asyncio/test_futures.py
@@ -413,7 +413,7 @@ class BaseFutureTests:
def test_copy_state(self):
from asyncio.futures import _copy_future_state
- f = self._new_future(loop=self.loop)
+ f = concurrent.futures.Future()
f.set_result(10)
newf = self._new_future(loop=self.loop)
@@ -421,7 +421,7 @@ class BaseFutureTests:
self.assertTrue(newf.done())
self.assertEqual(newf.result(), 10)
- f_exception = self._new_future(loop=self.loop)
+ f_exception = concurrent.futures.Future()
f_exception.set_exception(RuntimeError())
newf_exception = self._new_future(loop=self.loop)
@@ -429,7 +429,7 @@ class BaseFutureTests:
self.assertTrue(newf_exception.done())
self.assertRaises(RuntimeError, newf_exception.result)
- f_cancelled = self._new_future(loop=self.loop)
+ f_cancelled = concurrent.futures.Future()
f_cancelled.cancel()
newf_cancelled = self._new_future(loop=self.loop)
@@ -441,7 +441,7 @@ class BaseFutureTests:
except BaseException as e:
f_exc = e
- f_conexc = self._new_future(loop=self.loop)
+ f_conexc = concurrent.futures.Future()
f_conexc.set_exception(f_exc)
newf_conexc = self._new_future(loop=self.loop)
@@ -454,6 +454,56 @@ class BaseFutureTests:
newf_tb = ''.join(traceback.format_tb(newf_exc.__traceback__))
self.assertEqual(newf_tb.count('raise concurrent.futures.InvalidStateError'), 1)
+ def test_copy_state_from_concurrent_futures(self):
+ """Test _copy_future_state from concurrent.futures.Future.
+
+ This tests the optimized path using _get_snapshot when available.
+ """
+ from asyncio.futures import _copy_future_state
+
+ # Test with a result
+ f_concurrent = concurrent.futures.Future()
+ f_concurrent.set_result(42)
+ f_asyncio = self._new_future(loop=self.loop)
+ _copy_future_state(f_concurrent, f_asyncio)
+ self.assertTrue(f_asyncio.done())
+ self.assertEqual(f_asyncio.result(), 42)
+
+ # Test with an exception
+ f_concurrent_exc = concurrent.futures.Future()
+ f_concurrent_exc.set_exception(ValueError("test exception"))
+ f_asyncio_exc = self._new_future(loop=self.loop)
+ _copy_future_state(f_concurrent_exc, f_asyncio_exc)
+ self.assertTrue(f_asyncio_exc.done())
+ with self.assertRaises(ValueError) as cm:
+ f_asyncio_exc.result()
+ self.assertEqual(str(cm.exception), "test exception")
+
+ # Test with cancelled state
+ f_concurrent_cancelled = concurrent.futures.Future()
+ f_concurrent_cancelled.cancel()
+ f_asyncio_cancelled = self._new_future(loop=self.loop)
+ _copy_future_state(f_concurrent_cancelled, f_asyncio_cancelled)
+ self.assertTrue(f_asyncio_cancelled.cancelled())
+
+ # Test that destination already cancelled prevents copy
+ f_concurrent_result = concurrent.futures.Future()
+ f_concurrent_result.set_result(10)
+ f_asyncio_precancelled = self._new_future(loop=self.loop)
+ f_asyncio_precancelled.cancel()
+ _copy_future_state(f_concurrent_result, f_asyncio_precancelled)
+ self.assertTrue(f_asyncio_precancelled.cancelled())
+
+ # Test exception type conversion
+ f_concurrent_invalid = concurrent.futures.Future()
+ f_concurrent_invalid.set_exception(concurrent.futures.InvalidStateError("invalid"))
+ f_asyncio_invalid = self._new_future(loop=self.loop)
+ _copy_future_state(f_concurrent_invalid, f_asyncio_invalid)
+ self.assertTrue(f_asyncio_invalid.done())
+ with self.assertRaises(asyncio.exceptions.InvalidStateError) as cm:
+ f_asyncio_invalid.result()
+ self.assertEqual(str(cm.exception), "invalid")
+
def test_iter(self):
fut = self._new_future(loop=self.loop)
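
(_copy_future_state is a private asyncio helper; this standalone sketch mirrors only the happy path tested above and may break if the internal API changes.)

import asyncio
import concurrent.futures
from asyncio.futures import _copy_future_state  # private API

async def main():
    src = concurrent.futures.Future()
    src.set_result(42)
    dst = asyncio.get_running_loop().create_future()
    _copy_future_state(src, dst)
    print(dst.done(), dst.result())    # True 42

asyncio.run(main())
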
diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py
index 3bb3e5c4ca0..047f03cbb14 100644
--- a/Lib/test/test_asyncio/test_locks.py
+++ b/Lib/test/test_asyncio/test_locks.py
@@ -14,7 +14,7 @@ STR_RGX_REPR = (
r'(, value:\d)?'
r'(, waiters:\d+)?'
r'(, waiters:\d+\/\d+)?' # barrier
- r')\]>\Z'
+ r')\]>\z'
)
RGX_REPR = re.compile(STR_RGX_REPR)
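
(The only change here swaps the end-of-string anchor \Z for \z; since the pattern still compiles under these tests, the target interpreter accepts both spellings. A one-line check, which raises re.error on interpreters without \z support:)

import re
print(bool(re.search(r"spam\z", "spam")))   # True where \z is accepted as an anchor
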
diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py
index de81936b745..aab6a779170 100644
--- a/Lib/test/test_asyncio/test_selector_events.py
+++ b/Lib/test/test_asyncio/test_selector_events.py
@@ -347,6 +347,18 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
selectors.EVENT_WRITE)])
self.loop._remove_writer.assert_called_with(1)
+ def test_accept_connection_zero_one(self):
+ for backlog in [0, 1]:
+ sock = mock.Mock()
+ sock.accept.return_value = (mock.Mock(), mock.Mock())
+ with self.subTest(backlog):
+ mock_obj = mock.patch.object
+ with mock_obj(self.loop, '_accept_connection2') as accept2_mock:
+ self.loop._accept_connection(
+ mock.Mock(), sock, backlog=backlog)
+ self.loop.run_until_complete(asyncio.sleep(0))
+ self.assertEqual(sock.accept.call_count, backlog + 1)
+
def test_accept_connection_multiple(self):
sock = mock.Mock()
sock.accept.return_value = (mock.Mock(), mock.Mock())
@@ -362,7 +374,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
self.loop._accept_connection(
mock.Mock(), sock, backlog=backlog)
self.loop.run_until_complete(asyncio.sleep(0))
- self.assertEqual(sock.accept.call_count, backlog)
+ self.assertEqual(sock.accept.call_count, backlog + 1)
def test_accept_connection_skip_connectionabortederror(self):
sock = mock.Mock()
@@ -388,7 +400,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
# as in test_accept_connection_multiple avoid task pending
# warnings by using asyncio.sleep(0)
self.loop.run_until_complete(asyncio.sleep(0))
- self.assertEqual(sock.accept.call_count, backlog)
+ self.assertEqual(sock.accept.call_count, backlog + 1)
class SelectorTransportTests(test_utils.TestCase):
diff --git a/Lib/test/test_asyncio/test_ssl.py b/Lib/test/test_asyncio/test_ssl.py
index 986ecc2c5a9..3a7185cd897 100644
--- a/Lib/test/test_asyncio/test_ssl.py
+++ b/Lib/test/test_asyncio/test_ssl.py
@@ -195,9 +195,10 @@ class TestSSL(test_utils.TestCase):
except (BrokenPipeError, ConnectionError):
pass
- def test_create_server_ssl_1(self):
+ @support.bigmemtest(size=25, memuse=90*2**20, dry_run=False)
+ def test_create_server_ssl_1(self, size):
CNT = 0 # number of clients that were successful
- TOTAL_CNT = 25 # total number of clients that test will create
+ TOTAL_CNT = size # total number of clients that test will create
TIMEOUT = support.LONG_TIMEOUT # timeout for this test
A_DATA = b'A' * 1024 * BUF_MULTIPLIER
@@ -1038,9 +1039,10 @@ class TestSSL(test_utils.TestCase):
self.loop.run_until_complete(run_main())
- def test_create_server_ssl_over_ssl(self):
+ @support.bigmemtest(size=25, memuse=90*2**20, dry_run=False)
+ def test_create_server_ssl_over_ssl(self, size):
CNT = 0 # number of clients that were successful
- TOTAL_CNT = 25 # total number of clients that test will create
+ TOTAL_CNT = size # total number of clients that test will create
TIMEOUT = support.LONG_TIMEOUT # timeout for this test
A_DATA = b'A' * 1024 * BUF_MULTIPLIER
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py
index 8d7f1733454..f6f976f213a 100644
--- a/Lib/test/test_asyncio/test_tasks.py
+++ b/Lib/test/test_asyncio/test_tasks.py
@@ -89,8 +89,8 @@ class BaseTaskTests:
Future = None
all_tasks = None
- def new_task(self, loop, coro, name='TestTask', context=None):
- return self.__class__.Task(coro, loop=loop, name=name, context=context)
+ def new_task(self, loop, coro, name='TestTask', context=None, eager_start=None):
+ return self.__class__.Task(coro, loop=loop, name=name, context=context, eager_start=eager_start)
def new_future(self, loop):
return self.__class__.Future(loop=loop)
@@ -2116,6 +2116,46 @@ class BaseTaskTests:
self.assertTrue(outer.cancelled())
self.assertEqual(0, 0 if outer._callbacks is None else len(outer._callbacks))
+ def test_shield_cancel_outer_result(self):
+ mock_handler = mock.Mock()
+ self.loop.set_exception_handler(mock_handler)
+ inner = self.new_future(self.loop)
+ outer = asyncio.shield(inner)
+ test_utils.run_briefly(self.loop)
+ outer.cancel()
+ test_utils.run_briefly(self.loop)
+ inner.set_result(1)
+ test_utils.run_briefly(self.loop)
+ mock_handler.assert_not_called()
+
+ def test_shield_cancel_outer_exception(self):
+ mock_handler = mock.Mock()
+ self.loop.set_exception_handler(mock_handler)
+ inner = self.new_future(self.loop)
+ outer = asyncio.shield(inner)
+ test_utils.run_briefly(self.loop)
+ outer.cancel()
+ test_utils.run_briefly(self.loop)
+ inner.set_exception(Exception('foo'))
+ test_utils.run_briefly(self.loop)
+ mock_handler.assert_called_once()
+
+ def test_shield_duplicate_log_once(self):
+ mock_handler = mock.Mock()
+ self.loop.set_exception_handler(mock_handler)
+ inner = self.new_future(self.loop)
+ outer = asyncio.shield(inner)
+ test_utils.run_briefly(self.loop)
+ outer.cancel()
+ test_utils.run_briefly(self.loop)
+ outer = asyncio.shield(inner)
+ test_utils.run_briefly(self.loop)
+ outer.cancel()
+ test_utils.run_briefly(self.loop)
+ inner.set_exception(Exception('foo'))
+ test_utils.run_briefly(self.loop)
+ mock_handler.assert_called_once()
+
def test_shield_shortcut(self):
fut = self.new_future(self.loop)
fut.set_result(42)
@@ -2686,6 +2726,35 @@ class BaseTaskTests:
self.assertEqual([None, 1, 2], ret)
+ def test_eager_start_true(self):
+ name = None
+
+ async def asyncfn():
+ nonlocal name
+ name = self.current_task().get_name()
+
+ async def main():
+ t = self.new_task(coro=asyncfn(), loop=asyncio.get_running_loop(), eager_start=True, name="example")
+ self.assertTrue(t.done())
+ self.assertEqual(name, "example")
+ await t
+
+ asyncio.run(main(), loop_factory=asyncio.EventLoop)
+
+ def test_eager_start_false(self):
+ name = None
+
+ async def asyncfn():
+ nonlocal name
+ name = self.current_task().get_name()
+
+ async def main():
+ t = self.new_task(coro=asyncfn(), loop=asyncio.get_running_loop(), eager_start=False, name="example")
+ self.assertFalse(t.done())
+ self.assertIsNone(name)
+ await t
+ self.assertEqual(name, "example")
+
+ asyncio.run(main(), loop_factory=asyncio.EventLoop)
+
def test_get_coro(self):
loop = asyncio.new_event_loop()
coro = coroutine_function()
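
(Small standalone sketch of the shield behaviour the new tests above pin down: cancelling the outer future does not cancel the shielded inner one.)

import asyncio

async def main():
    inner = asyncio.get_running_loop().create_future()
    outer = asyncio.shield(inner)
    outer.cancel()
    await asyncio.sleep(0)
    inner.set_result(1)                        # inner still completes normally
    print(outer.cancelled(), inner.result())   # True 1

asyncio.run(main())
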
diff --git a/Lib/test/test_asyncio/test_tools.py b/Lib/test/test_asyncio/test_tools.py
new file mode 100644
index 00000000000..34e94830204
--- /dev/null
+++ b/Lib/test/test_asyncio/test_tools.py
@@ -0,0 +1,1706 @@
+import unittest
+
+from asyncio import tools
+
+from collections import namedtuple
+
+FrameInfo = namedtuple('FrameInfo', ['funcname', 'filename', 'lineno'])
+CoroInfo = namedtuple('CoroInfo', ['call_stack', 'task_name'])
+TaskInfo = namedtuple('TaskInfo', ['task_id', 'task_name', 'coroutine_stack', 'awaited_by'])
+AwaitedInfo = namedtuple('AwaitedInfo', ['thread_id', 'awaited_by'])
+
+
+# mock output of get_all_awaited_by function.
+TEST_INPUTS_TREE = [
+ [
+ # test case containing a task called timer that is awaited in two
+ # different subtasks (root1 and root2), both part of a TaskGroup, which
+ # call awaiter functions.
+ (
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="timer",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("awaiter3", "/path/to/app.py", 130),
+ FrameInfo("awaiter2", "/path/to/app.py", 120),
+ FrameInfo("awaiter", "/path/to/app.py", 110)
+ ],
+ task_name=4
+ ),
+ CoroInfo(
+ call_stack=[
+ FrameInfo("awaiterB3", "/path/to/app.py", 190),
+ FrameInfo("awaiterB2", "/path/to/app.py", 180),
+ FrameInfo("awaiterB", "/path/to/app.py", 170)
+ ],
+ task_name=5
+ ),
+ CoroInfo(
+ call_stack=[
+ FrameInfo("awaiterB3", "/path/to/app.py", 190),
+ FrameInfo("awaiterB2", "/path/to/app.py", 180),
+ FrameInfo("awaiterB", "/path/to/app.py", 170)
+ ],
+ task_name=6
+ ),
+ CoroInfo(
+ call_stack=[
+ FrameInfo("awaiter3", "/path/to/app.py", 130),
+ FrameInfo("awaiter2", "/path/to/app.py", 120),
+ FrameInfo("awaiter", "/path/to/app.py", 110)
+ ],
+ task_name=7
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=8,
+ task_name="root1",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("main", "", 0)
+ ],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=9,
+ task_name="root2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("main", "", 0)
+ ],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="child1_1",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("blocho_caller", "", 0),
+ FrameInfo("bloch", "", 0)
+ ],
+ task_name=8
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=6,
+ task_name="child2_1",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("blocho_caller", "", 0),
+ FrameInfo("bloch", "", 0)
+ ],
+ task_name=8
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=7,
+ task_name="child1_2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("blocho_caller", "", 0),
+ FrameInfo("bloch", "", 0)
+ ],
+ task_name=9
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=5,
+ task_name="child2_2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("blocho_caller", "", 0),
+ FrameInfo("bloch", "", 0)
+ ],
+ task_name=9
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ),
+ (
+ [
+ [
+ "└── (T) Task-1",
+ " └── main",
+ " └── __aexit__",
+ " └── _aexit",
+ " ├── (T) root1",
+ " │ └── bloch",
+ " │ └── blocho_caller",
+ " │ └── __aexit__",
+ " │ └── _aexit",
+ " │ ├── (T) child1_1",
+ " │ │ └── awaiter /path/to/app.py:110",
+ " │ │ └── awaiter2 /path/to/app.py:120",
+ " │ │ └── awaiter3 /path/to/app.py:130",
+ " │ │ └── (T) timer",
+ " │ └── (T) child2_1",
+ " │ └── awaiterB /path/to/app.py:170",
+ " │ └── awaiterB2 /path/to/app.py:180",
+ " │ └── awaiterB3 /path/to/app.py:190",
+ " │ └── (T) timer",
+ " └── (T) root2",
+ " └── bloch",
+ " └── blocho_caller",
+ " └── __aexit__",
+ " └── _aexit",
+ " ├── (T) child1_2",
+ " │ └── awaiter /path/to/app.py:110",
+ " │ └── awaiter2 /path/to/app.py:120",
+ " │ └── awaiter3 /path/to/app.py:130",
+ " │ └── (T) timer",
+ " └── (T) child2_2",
+ " └── awaiterB /path/to/app.py:170",
+ " └── awaiterB2 /path/to/app.py:180",
+ " └── awaiterB3 /path/to/app.py:190",
+ " └── (T) timer",
+ ]
+ ]
+ ),
+ ],
+ [
+ # test case containing two roots
+ (
+ AwaitedInfo(
+ thread_id=9,
+ awaited_by=[
+ TaskInfo(
+ task_id=5,
+ task_name="Task-5",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=6,
+ task_name="Task-6",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main2", "", 0)],
+ task_name=5
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=7,
+ task_name="Task-7",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main2", "", 0)],
+ task_name=5
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=8,
+ task_name="Task-8",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main2", "", 0)],
+ task_name=5
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(
+ thread_id=10,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=2,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=1
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="Task-3",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=1
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="Task-4",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=1
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=11, awaited_by=[]),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ),
+ (
+ [
+ [
+ "└── (T) Task-5",
+ " └── main2",
+ " ├── (T) Task-6",
+ " ├── (T) Task-7",
+ " └── (T) Task-8",
+ ],
+ [
+ "└── (T) Task-1",
+ " └── main",
+ " ├── (T) Task-2",
+ " ├── (T) Task-3",
+ " └── (T) Task-4",
+ ],
+ ]
+ ),
+ ],
+ [
+ # test case containing two roots, one of them without subtasks
+ (
+ [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-5",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ ),
+ AwaitedInfo(
+ thread_id=3,
+ awaited_by=[
+ TaskInfo(
+ task_id=4,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=5,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=4
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=6,
+ task_name="Task-3",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=4
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=7,
+ task_name="Task-4",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=4
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=8, awaited_by=[]),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ]
+ ),
+ (
+ [
+ ["└── (T) Task-5"],
+ [
+ "└── (T) Task-1",
+ " └── main",
+ " ├── (T) Task-2",
+ " ├── (T) Task-3",
+ " └── (T) Task-4",
+ ],
+ ]
+ ),
+ ],
+]
+
+TEST_INPUTS_CYCLES_TREE = [
+ [
+ # this test case contains a cycle: two tasks awaiting each other.
+ (
+ [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="a",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("awaiter2", "", 0)],
+ task_name=4
+ ),
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="b",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("awaiter", "", 0)],
+ task_name=3
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ]
+ ),
+ ([[4, 3, 4]]),
+ ],
+ [
+ # this test case contains two cycles
+ (
+ [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="A",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0),
+ FrameInfo("task_b", "", 0)
+ ],
+ task_name=4
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="B",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0),
+ FrameInfo("task_c", "", 0)
+ ],
+ task_name=5
+ ),
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0),
+ FrameInfo("task_a", "", 0)
+ ],
+ task_name=3
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=5,
+ task_name="C",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0)
+ ],
+ task_name=6
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=6,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0),
+ FrameInfo("task_b", "", 0)
+ ],
+ task_name=4
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ]
+ ),
+ ([[4, 3, 4], [4, 6, 5, 4]]),
+ ],
+]
+
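+# Each expected row below has seven columns, matching the per-row checks in
+# TestAsyncioToolsBasic.test_table_output_format further down: thread id,
+# hex task id, task name, coroutine stack, coroutine (awaiter) chain,
+# awaiter name, and hex awaiter id.  A single root task with no awaiter
+# therefore looks like:
+#
+#     [1, "0x2", "Task-1", "", "", "", "0x0"]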
+TEST_INPUTS_TABLE = [
+ [
+ # test case containing a task called "timer" that is awaited from two
+ # different subtasks belonging to a TaskGroup (root1 and root2), which
+ # call awaiter functions.
+ (
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="timer",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("awaiter3", "", 0),
+ FrameInfo("awaiter2", "", 0),
+ FrameInfo("awaiter", "", 0)
+ ],
+ task_name=4
+ ),
+ CoroInfo(
+ call_stack=[
+ FrameInfo("awaiter1_3", "", 0),
+ FrameInfo("awaiter1_2", "", 0),
+ FrameInfo("awaiter1", "", 0)
+ ],
+ task_name=5
+ ),
+ CoroInfo(
+ call_stack=[
+ FrameInfo("awaiter1_3", "", 0),
+ FrameInfo("awaiter1_2", "", 0),
+ FrameInfo("awaiter1", "", 0)
+ ],
+ task_name=6
+ ),
+ CoroInfo(
+ call_stack=[
+ FrameInfo("awaiter3", "", 0),
+ FrameInfo("awaiter2", "", 0),
+ FrameInfo("awaiter", "", 0)
+ ],
+ task_name=7
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=8,
+ task_name="root1",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("main", "", 0)
+ ],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=9,
+ task_name="root2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("main", "", 0)
+ ],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="child1_1",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("blocho_caller", "", 0),
+ FrameInfo("bloch", "", 0)
+ ],
+ task_name=8
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=6,
+ task_name="child2_1",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("blocho_caller", "", 0),
+ FrameInfo("bloch", "", 0)
+ ],
+ task_name=8
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=7,
+ task_name="child1_2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("blocho_caller", "", 0),
+ FrameInfo("bloch", "", 0)
+ ],
+ task_name=9
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=5,
+ task_name="child2_2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("blocho_caller", "", 0),
+ FrameInfo("bloch", "", 0)
+ ],
+ task_name=9
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ),
+ (
+ [
+ [1, "0x2", "Task-1", "", "", "", "0x0"],
+ [
+ 1,
+ "0x3",
+ "timer",
+ "",
+ "awaiter3 -> awaiter2 -> awaiter",
+ "child1_1",
+ "0x4",
+ ],
+ [
+ 1,
+ "0x3",
+ "timer",
+ "",
+ "awaiter1_3 -> awaiter1_2 -> awaiter1",
+ "child2_2",
+ "0x5",
+ ],
+ [
+ 1,
+ "0x3",
+ "timer",
+ "",
+ "awaiter1_3 -> awaiter1_2 -> awaiter1",
+ "child2_1",
+ "0x6",
+ ],
+ [
+ 1,
+ "0x3",
+ "timer",
+ "",
+ "awaiter3 -> awaiter2 -> awaiter",
+ "child1_2",
+ "0x7",
+ ],
+ [
+ 1,
+ "0x8",
+ "root1",
+ "",
+ "_aexit -> __aexit__ -> main",
+ "Task-1",
+ "0x2",
+ ],
+ [
+ 1,
+ "0x9",
+ "root2",
+ "",
+ "_aexit -> __aexit__ -> main",
+ "Task-1",
+ "0x2",
+ ],
+ [
+ 1,
+ "0x4",
+ "child1_1",
+ "",
+ "_aexit -> __aexit__ -> blocho_caller -> bloch",
+ "root1",
+ "0x8",
+ ],
+ [
+ 1,
+ "0x6",
+ "child2_1",
+ "",
+ "_aexit -> __aexit__ -> blocho_caller -> bloch",
+ "root1",
+ "0x8",
+ ],
+ [
+ 1,
+ "0x7",
+ "child1_2",
+ "",
+ "_aexit -> __aexit__ -> blocho_caller -> bloch",
+ "root2",
+ "0x9",
+ ],
+ [
+ 1,
+ "0x5",
+ "child2_2",
+ "",
+ "_aexit -> __aexit__ -> blocho_caller -> bloch",
+ "root2",
+ "0x9",
+ ],
+ ]
+ ),
+ ],
+ [
+ # test case containing two roots
+ (
+ AwaitedInfo(
+ thread_id=9,
+ awaited_by=[
+ TaskInfo(
+ task_id=5,
+ task_name="Task-5",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=6,
+ task_name="Task-6",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main2", "", 0)],
+ task_name=5
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=7,
+ task_name="Task-7",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main2", "", 0)],
+ task_name=5
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=8,
+ task_name="Task-8",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main2", "", 0)],
+ task_name=5
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(
+ thread_id=10,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=2,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=1
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="Task-3",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=1
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="Task-4",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=1
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=11, awaited_by=[]),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ),
+ (
+ [
+ [9, "0x5", "Task-5", "", "", "", "0x0"],
+ [9, "0x6", "Task-6", "", "main2", "Task-5", "0x5"],
+ [9, "0x7", "Task-7", "", "main2", "Task-5", "0x5"],
+ [9, "0x8", "Task-8", "", "main2", "Task-5", "0x5"],
+ [10, "0x1", "Task-1", "", "", "", "0x0"],
+ [10, "0x2", "Task-2", "", "main", "Task-1", "0x1"],
+ [10, "0x3", "Task-3", "", "main", "Task-1", "0x1"],
+ [10, "0x4", "Task-4", "", "main", "Task-1", "0x1"],
+ ]
+ ),
+ ],
+ [
+ # test case containing two roots, one of them without subtasks
+ (
+ [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-5",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ ),
+ AwaitedInfo(
+ thread_id=3,
+ awaited_by=[
+ TaskInfo(
+ task_id=4,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=5,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=4
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=6,
+ task_name="Task-3",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=4
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=7,
+ task_name="Task-4",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=4
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=8, awaited_by=[]),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ]
+ ),
+ (
+ [
+ [1, "0x2", "Task-5", "", "", "", "0x0"],
+ [3, "0x4", "Task-1", "", "", "", "0x0"],
+ [3, "0x5", "Task-2", "", "main", "Task-1", "0x4"],
+ [3, "0x6", "Task-3", "", "main", "Task-1", "0x4"],
+ [3, "0x7", "Task-4", "", "main", "Task-1", "0x4"],
+ ]
+ ),
+ ],
+ # CASES WITH CYCLES
+ [
+ # this test case contains a cycle: two tasks awaiting each other.
+ (
+ [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="a",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("awaiter2", "", 0)],
+ task_name=4
+ ),
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="b",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("awaiter", "", 0)],
+ task_name=3
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ]
+ ),
+ (
+ [
+ [1, "0x2", "Task-1", "", "", "", "0x0"],
+ [1, "0x3", "a", "", "awaiter2", "b", "0x4"],
+ [1, "0x3", "a", "", "main", "Task-1", "0x2"],
+ [1, "0x4", "b", "", "awaiter", "a", "0x3"],
+ ]
+ ),
+ ],
+ [
+ # this test case contains two cycles
+ (
+ [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="A",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0),
+ FrameInfo("task_b", "", 0)
+ ],
+ task_name=4
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="B",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0),
+ FrameInfo("task_c", "", 0)
+ ],
+ task_name=5
+ ),
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0),
+ FrameInfo("task_a", "", 0)
+ ],
+ task_name=3
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=5,
+ task_name="C",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0)
+ ],
+ task_name=6
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=6,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0),
+ FrameInfo("task_b", "", 0)
+ ],
+ task_name=4
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ]
+ ),
+ (
+ [
+ [1, "0x2", "Task-1", "", "", "", "0x0"],
+ [
+ 1,
+ "0x3",
+ "A",
+ "",
+ "nested -> nested -> task_b",
+ "B",
+ "0x4",
+ ],
+ [
+ 1,
+ "0x4",
+ "B",
+ "",
+ "nested -> nested -> task_c",
+ "C",
+ "0x5",
+ ],
+ [
+ 1,
+ "0x4",
+ "B",
+ "",
+ "nested -> nested -> task_a",
+ "A",
+ "0x3",
+ ],
+ [
+ 1,
+ "0x5",
+ "C",
+ "",
+ "nested -> nested",
+ "Task-2",
+ "0x6",
+ ],
+ [
+ 1,
+ "0x6",
+ "Task-2",
+ "",
+ "nested -> nested -> task_b",
+ "B",
+ "0x4",
+ ],
+ ]
+ ),
+ ],
+]
+
+
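+# A minimal usage sketch of the two helpers exercised by the classes below,
+# using the same AwaitedInfo/TaskInfo fixtures as above (outputs taken from
+# the single-task cases):
+#
+#     info = [AwaitedInfo(thread_id=1, awaited_by=[
+#         TaskInfo(task_id=2, task_name="Task-1",
+#                  coroutine_stack=[], awaited_by=[])])]
+#     tools.build_async_tree(info)  # [["└── (T) Task-1"]]
+#     tools.build_task_table(info)  # [[1, '0x2', 'Task-1', '', '', '', '0x0']]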
+class TestAsyncioToolsTree(unittest.TestCase):
+ def test_asyncio_utils(self):
+ for input_, tree in TEST_INPUTS_TREE:
+ with self.subTest(input_):
+ result = tools.build_async_tree(input_)
+ self.assertEqual(result, tree)
+
+ def test_asyncio_utils_cycles(self):
+ for input_, cycles in TEST_INPUTS_CYCLES_TREE:
+ with self.subTest(input_):
+ try:
+ tools.build_async_tree(input_)
+ except tools.CycleFoundException as e:
+ self.assertEqual(e.cycles, cycles)
+
+
+class TestAsyncioToolsTable(unittest.TestCase):
+ def test_asyncio_utils(self):
+ for input_, table in TEST_INPUTS_TABLE:
+ with self.subTest(input_):
+ result = tools.build_task_table(input_)
+ self.assertEqual(result, table)
+
+
+class TestAsyncioToolsBasic(unittest.TestCase):
+ def test_empty_input_tree(self):
+ """Test build_async_tree with empty input."""
+ result = []
+ expected_output = []
+ self.assertEqual(tools.build_async_tree(result), expected_output)
+
+ def test_empty_input_table(self):
+ """Test build_task_table with empty input."""
+ result = []
+ expected_output = []
+ self.assertEqual(tools.build_task_table(result), expected_output)
+
+ def test_only_independent_tasks_tree(self):
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=10,
+ task_name="taskA",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=11,
+ task_name="taskB",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ expected = [["└── (T) taskA"], ["└── (T) taskB"]]
+ result = tools.build_async_tree(input_)
+ self.assertEqual(sorted(result), sorted(expected))
+
+ def test_only_independent_tasks_table(self):
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=10,
+ task_name="taskA",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=11,
+ task_name="taskB",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ self.assertEqual(
+ tools.build_task_table(input_),
+ [[1, '0xa', 'taskA', '', '', '', '0x0'], [1, '0xb', 'taskB', '', '', '', '0x0']]
+ )
+
+ def test_single_task_tree(self):
+ """Test build_async_tree with a single task and no awaits."""
+ result = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ expected_output = [
+ [
+ "└── (T) Task-1",
+ ]
+ ]
+ self.assertEqual(tools.build_async_tree(result), expected_output)
+
+ def test_single_task_table(self):
+ """Test build_task_table with a single task and no awaits."""
+ result = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ expected_output = [[1, '0x2', 'Task-1', '', '', '', '0x0']]
+ self.assertEqual(tools.build_task_table(result), expected_output)
+
+ def test_cycle_detection(self):
+ """Test build_async_tree raises CycleFoundException for cyclic input."""
+ result = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=3
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=2
+ )
+ ]
+ )
+ ]
+ )
+ ]
+ with self.assertRaises(tools.CycleFoundException) as context:
+ tools.build_async_tree(result)
+ self.assertEqual(context.exception.cycles, [[3, 2, 3]])
+
+ def test_complex_tree(self):
+ """Test build_async_tree with a more complex tree structure."""
+ result = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="Task-3",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=3
+ )
+ ]
+ )
+ ]
+ )
+ ]
+ expected_output = [
+ [
+ "└── (T) Task-1",
+ " └── main",
+ " └── (T) Task-2",
+ " └── main",
+ " └── (T) Task-3",
+ ]
+ ]
+ self.assertEqual(tools.build_async_tree(result), expected_output)
+
+ def test_complex_table(self):
+ """Test build_task_table with a more complex tree structure."""
+ result = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="Task-3",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=3
+ )
+ ]
+ )
+ ]
+ )
+ ]
+ expected_output = [
+ [1, '0x2', 'Task-1', '', '', '', '0x0'],
+ [1, '0x3', 'Task-2', '', 'main', 'Task-1', '0x2'],
+ [1, '0x4', 'Task-3', '', 'main', 'Task-2', '0x3']
+ ]
+ self.assertEqual(tools.build_task_table(result), expected_output)
+
+ def test_deep_coroutine_chain(self):
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=10,
+ task_name="leaf",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("c1", "", 0),
+ FrameInfo("c2", "", 0),
+ FrameInfo("c3", "", 0),
+ FrameInfo("c4", "", 0),
+ FrameInfo("c5", "", 0)
+ ],
+ task_name=11
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=11,
+ task_name="root",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ expected = [
+ [
+ "└── (T) root",
+ " └── c5",
+ " └── c4",
+ " └── c3",
+ " └── c2",
+ " └── c1",
+ " └── (T) leaf",
+ ]
+ ]
+ result = tools.build_async_tree(input_)
+ self.assertEqual(result, expected)
+
+ def test_multiple_cycles_same_node(self):
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="Task-A",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("call1", "", 0)],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=2,
+ task_name="Task-B",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("call2", "", 0)],
+ task_name=3
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="Task-C",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("call3", "", 0)],
+ task_name=1
+ ),
+ CoroInfo(
+ call_stack=[FrameInfo("call4", "", 0)],
+ task_name=2
+ )
+ ]
+ )
+ ]
+ )
+ ]
+ with self.assertRaises(tools.CycleFoundException) as ctx:
+ tools.build_async_tree(input_)
+ cycles = ctx.exception.cycles
+ self.assertTrue(any(set(c) == {1, 2, 3} for c in cycles))
+
+ def test_table_output_format(self):
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="Task-A",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("foo", "", 0)],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=2,
+ task_name="Task-B",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ table = tools.build_task_table(input_)
+ for row in table:
+ self.assertEqual(len(row), 7)
+ self.assertIsInstance(row[0], int) # thread ID
+ self.assertTrue(
+ isinstance(row[1], str) and row[1].startswith("0x")
+ ) # hex task ID
+ self.assertIsInstance(row[2], str) # task name
+ self.assertIsInstance(row[3], str) # coroutine stack
+ self.assertIsInstance(row[4], str) # coroutine chain
+ self.assertIsInstance(row[5], str) # awaiter name
+ self.assertTrue(
+ isinstance(row[6], str) and row[6].startswith("0x")
+ ) # hex awaiter ID
+
+
+class TestAsyncioToolsEdgeCases(unittest.TestCase):
+
+ def test_task_awaits_self(self):
+ """A task directly awaits itself - should raise a cycle."""
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="Self-Awaiter",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("loopback", "", 0)],
+ task_name=1
+ )
+ ]
+ )
+ ]
+ )
+ ]
+ with self.assertRaises(tools.CycleFoundException) as ctx:
+ tools.build_async_tree(input_)
+ self.assertIn([1, 1], ctx.exception.cycles)
+
+ def test_task_with_missing_awaiter_id(self):
+ """Awaiter ID not in task list - should not crash, just show 'Unknown'."""
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="Task-A",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("coro", "", 0)],
+ task_name=999
+ )
+ ]
+ )
+ ]
+ )
+ ]
+ table = tools.build_task_table(input_)
+ self.assertEqual(len(table), 1)
+ self.assertEqual(table[0][5], "Unknown")
+
+ def test_duplicate_coroutine_frames(self):
+ """Same coroutine frame repeated under a parent - should deduplicate."""
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("frameA", "", 0)],
+ task_name=2
+ ),
+ CoroInfo(
+ call_stack=[FrameInfo("frameA", "", 0)],
+ task_name=3
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=2,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="Task-3",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ tree = tools.build_async_tree(input_)
+ # Both children should be under the same coroutine node
+ flat = "\n".join(tree[0])
+ self.assertIn("frameA", flat)
+ self.assertIn("Task-2", flat)
+ self.assertIn("Task-1", flat)
+
+ flat = "\n".join(tree[1])
+ self.assertIn("frameA", flat)
+ self.assertIn("Task-3", flat)
+ self.assertIn("Task-1", flat)
+
+ def test_task_with_no_name(self):
+ """Task with no name in id2name - should still render with fallback."""
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="root",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("f1", "", 0)],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=2,
+ task_name=None,
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ # If the name is None, falling back to its string form should not crash
+ tree = tools.build_async_tree(input_)
+ self.assertIn("(T) None", "\n".join(tree[0]))
+
+ def test_tree_rendering_with_custom_emojis(self):
+ """Pass custom emojis to the tree renderer."""
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="MainTask",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("f1", "", 0),
+ FrameInfo("f2", "", 0)
+ ],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=2,
+ task_name="SubTask",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ tree = tools.build_async_tree(input_, task_emoji="🧵", cor_emoji="🔁")
+ flat = "\n".join(tree[0])
+ self.assertIn("🧵 MainTask", flat)
+ self.assertIn("🔁 f1", flat)
+ self.assertIn("🔁 f2", flat)
+ self.assertIn("🧵 SubTask", flat)
diff --git a/Lib/test/test_audit.py b/Lib/test/test_audit.py
index 2b24b5d7927..077765fcda2 100644
--- a/Lib/test/test_audit.py
+++ b/Lib/test/test_audit.py
@@ -134,7 +134,7 @@ class AuditTest(unittest.TestCase):
self.assertEqual(events[0][0], "socket.gethostname")
self.assertEqual(events[1][0], "socket.__new__")
self.assertEqual(events[2][0], "socket.bind")
- self.assertTrue(events[2][2].endswith("('127.0.0.1', 8080)"))
+ self.assertEndsWith(events[2][2], "('127.0.0.1', 8080)")
def test_gc(self):
returncode, events, stderr = self.run_python("test_gc")
@@ -322,6 +322,14 @@ class AuditTest(unittest.TestCase):
if returncode:
self.fail(stderr)
+ @support.support_remote_exec_only
+ @support.cpython_only
+ def test_sys_remote_exec(self):
+ returncode, events, stderr = self.run_python("test_sys_remote_exec")
+ self.assertTrue(any(["sys.remote_exec" in event for event in events]))
+ self.assertTrue(any(["cpython.remote_debugger_script" in event for event in events]))
+ if returncode:
+ self.fail(stderr)
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_base64.py b/Lib/test/test_base64.py
index 409c8c109e8..ce2e3e3726f 100644
--- a/Lib/test/test_base64.py
+++ b/Lib/test/test_base64.py
@@ -3,8 +3,16 @@ import base64
import binascii
import os
from array import array
+from test.support import cpython_only
from test.support import os_helper
from test.support import script_helper
+from test.support.import_helper import ensure_lazy_imports
+
+
+class LazyImportTest(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("base64", {"re", "getopt"})
class LegacyBase64TestCase(unittest.TestCase):
@@ -804,7 +812,7 @@ class BaseXYTestCase(unittest.TestCase):
self.assertRaises(ValueError, f, 'with non-ascii \xcb')
def test_ErrorHeritage(self):
- self.assertTrue(issubclass(binascii.Error, ValueError))
+ self.assertIsSubclass(binascii.Error, ValueError)
def test_RFC4648_test_cases(self):
# test cases from RFC 4648 section 10
diff --git a/Lib/test/test_baseexception.py b/Lib/test/test_baseexception.py
index e599b02c17d..12d4088842b 100644
--- a/Lib/test/test_baseexception.py
+++ b/Lib/test/test_baseexception.py
@@ -10,13 +10,11 @@ class ExceptionClassTests(unittest.TestCase):
inheritance hierarchy)"""
def test_builtins_new_style(self):
- self.assertTrue(issubclass(Exception, object))
+ self.assertIsSubclass(Exception, object)
def verify_instance_interface(self, ins):
for attr in ("args", "__str__", "__repr__"):
- self.assertTrue(hasattr(ins, attr),
- "%s missing %s attribute" %
- (ins.__class__.__name__, attr))
+ self.assertHasAttr(ins, attr)
def test_inheritance(self):
# Make sure the inheritance hierarchy matches the documentation
@@ -65,7 +63,7 @@ class ExceptionClassTests(unittest.TestCase):
elif last_depth > depth:
while superclasses[-1][0] >= depth:
superclasses.pop()
- self.assertTrue(issubclass(exc, superclasses[-1][1]),
+ self.assertIsSubclass(exc, superclasses[-1][1],
"%s is not a subclass of %s" % (exc.__name__,
superclasses[-1][1].__name__))
try: # Some exceptions require arguments; just skip them
diff --git a/Lib/test/test_binascii.py b/Lib/test/test_binascii.py
index 1f3b6746ce4..7ed7d7c47b6 100644
--- a/Lib/test/test_binascii.py
+++ b/Lib/test/test_binascii.py
@@ -38,13 +38,13 @@ class BinASCIITest(unittest.TestCase):
def test_exceptions(self):
# Check module exceptions
- self.assertTrue(issubclass(binascii.Error, Exception))
- self.assertTrue(issubclass(binascii.Incomplete, Exception))
+ self.assertIsSubclass(binascii.Error, Exception)
+ self.assertIsSubclass(binascii.Incomplete, Exception)
def test_functions(self):
# Check presence of all functions
for name in all_functions:
- self.assertTrue(hasattr(getattr(binascii, name), '__call__'))
+ self.assertHasAttr(getattr(binascii, name), '__call__')
self.assertRaises(TypeError, getattr(binascii, name))
def test_returned_value(self):
diff --git a/Lib/test/test_binop.py b/Lib/test/test_binop.py
index 299af09c498..b224c3d4e60 100644
--- a/Lib/test/test_binop.py
+++ b/Lib/test/test_binop.py
@@ -383,7 +383,7 @@ class OperationOrderTests(unittest.TestCase):
self.assertEqual(op_sequence(le, B, C), ['C.__ge__', 'B.__le__'])
self.assertEqual(op_sequence(le, C, B), ['C.__le__', 'B.__ge__'])
- self.assertTrue(issubclass(V, B))
+ self.assertIsSubclass(V, B)
self.assertEqual(op_sequence(eq, B, V), ['B.__eq__', 'V.__eq__'])
self.assertEqual(op_sequence(le, B, V), ['B.__le__', 'V.__ge__'])
diff --git a/Lib/test/test_buffer.py b/Lib/test/test_buffer.py
index 61921e93e85..19582e75716 100644
--- a/Lib/test/test_buffer.py
+++ b/Lib/test/test_buffer.py
@@ -2879,11 +2879,11 @@ class TestBufferProtocol(unittest.TestCase):
def test_memoryview_repr(self):
m = memoryview(bytearray(9))
r = m.__repr__()
- self.assertTrue(r.startswith("<memory"))
+ self.assertStartsWith(r, "<memory")
m.release()
r = m.__repr__()
- self.assertTrue(r.startswith("<released"))
+ self.assertStartsWith(r, "<released")
def test_memoryview_sequence(self):
diff --git a/Lib/test/test_bufio.py b/Lib/test/test_bufio.py
index dc9a82dc635..cb9cb4d0bc7 100644
--- a/Lib/test/test_bufio.py
+++ b/Lib/test/test_bufio.py
@@ -28,7 +28,7 @@ class BufferSizeTest:
f.write(b"\n")
f.write(s)
f.close()
- f = open(os_helper.TESTFN, "rb")
+ f = self.open(os_helper.TESTFN, "rb")
line = f.readline()
self.assertEqual(line, s + b"\n")
line = f.readline()
diff --git a/Lib/test/test_build_details.py b/Lib/test/test_build_details.py
index 05ce163a337..ba4b8c5aa9b 100644
--- a/Lib/test/test_build_details.py
+++ b/Lib/test/test_build_details.py
@@ -117,12 +117,26 @@ class CPythonBuildDetailsTests(unittest.TestCase, FormatTestsBase):
# Override generic format tests with tests for our specific implementation.
@needs_installed_python
- @unittest.skipIf(is_android or is_apple_mobile, 'Android and iOS run tests via a custom testbed method that changes sys.executable')
+ @unittest.skipIf(
+ is_android or is_apple_mobile,
+ 'Android and iOS run tests via a custom testbed method that changes sys.executable'
+ )
def test_base_interpreter(self):
value = self.key('base_interpreter')
self.assertEqual(os.path.realpath(value), os.path.realpath(sys.executable))
+ @needs_installed_python
+ @unittest.skipIf(
+ is_android or is_apple_mobile,
+ "Android and iOS run tests via a custom testbed method that doesn't ship headers"
+ )
+ def test_c_api(self):
+ value = self.key('c_api')
+ self.assertTrue(os.path.exists(os.path.join(value['headers'], 'Python.h')))
+ version = sysconfig.get_config_var('VERSION')
+ self.assertTrue(os.path.exists(os.path.join(value['pkgconfig_path'], f'python-{version}.pc')))
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py
index 31597a320d4..14fe3355239 100644
--- a/Lib/test/test_builtin.py
+++ b/Lib/test/test_builtin.py
@@ -393,7 +393,7 @@ class BuiltinTest(ComplexesAreIdenticalMixin, unittest.TestCase):
self.assertRaises(ValueError, chr, -2**1000)
def test_cmp(self):
- self.assertTrue(not hasattr(builtins, "cmp"))
+ self.assertNotHasAttr(builtins, "cmp")
def test_compile(self):
compile('print(1)\n', '', 'exec')
@@ -1120,6 +1120,7 @@ class BuiltinTest(ComplexesAreIdenticalMixin, unittest.TestCase):
self.check_iter_pickle(f1, list(f2), proto)
@support.skip_wasi_stack_overflow()
+ @support.skip_emscripten_stack_overflow()
@support.requires_resource('cpu')
def test_filter_dealloc(self):
# Tests recursive deallocation of nested filter objects using the
@@ -2303,7 +2304,7 @@ class BuiltinTest(ComplexesAreIdenticalMixin, unittest.TestCase):
# tests for object.__format__ really belong elsewhere, but
# there's no good place to put them
x = object().__format__('')
- self.assertTrue(x.startswith('<object object at'))
+ self.assertStartsWith(x, '<object object at')
# first argument to object.__format__ must be string
self.assertRaises(TypeError, object().__format__, 3)
@@ -2990,7 +2991,8 @@ class TestType(unittest.TestCase):
def load_tests(loader, tests, pattern):
from doctest import DocTestSuite
- tests.addTest(DocTestSuite(builtins))
+ if sys.float_repr_style == 'short':
+ tests.addTest(DocTestSuite(builtins))
return tests
if __name__ == "__main__":
diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py
index 82d9916e38d..bb0f8aa99da 100644
--- a/Lib/test/test_bytes.py
+++ b/Lib/test/test_bytes.py
@@ -1974,9 +1974,9 @@ class AssortedBytesTest(unittest.TestCase):
@test.support.requires_docstrings
def test_doc(self):
self.assertIsNotNone(bytearray.__doc__)
- self.assertTrue(bytearray.__doc__.startswith("bytearray("), bytearray.__doc__)
+ self.assertStartsWith(bytearray.__doc__, "bytearray(")
self.assertIsNotNone(bytes.__doc__)
- self.assertTrue(bytes.__doc__.startswith("bytes("), bytes.__doc__)
+ self.assertStartsWith(bytes.__doc__, "bytes(")
def test_from_bytearray(self):
sample = bytes(b"Hello world\n\x80\x81\xfe\xff")
@@ -2107,7 +2107,7 @@ class BytesAsStringTest(FixedStringTest, unittest.TestCase):
class SubclassTest:
def test_basic(self):
- self.assertTrue(issubclass(self.type2test, self.basetype))
+ self.assertIsSubclass(self.type2test, self.basetype)
self.assertIsInstance(self.type2test(), self.basetype)
a, b = b"abcd", b"efgh"
@@ -2155,7 +2155,7 @@ class SubclassTest:
self.assertEqual(a.z, b.z)
self.assertEqual(type(a), type(b))
self.assertEqual(type(a.z), type(b.z))
- self.assertFalse(hasattr(b, 'y'))
+ self.assertNotHasAttr(b, 'y')
def test_copy(self):
a = self.type2test(b"abcd")
@@ -2169,7 +2169,7 @@ class SubclassTest:
self.assertEqual(a.z, b.z)
self.assertEqual(type(a), type(b))
self.assertEqual(type(a.z), type(b.z))
- self.assertFalse(hasattr(b, 'y'))
+ self.assertNotHasAttr(b, 'y')
def test_fromhex(self):
b = self.type2test.fromhex('1a2B30')
diff --git a/Lib/test/test_bz2.py b/Lib/test/test_bz2.py
index f32b24b39ba..3b7897b8a88 100644
--- a/Lib/test/test_bz2.py
+++ b/Lib/test/test_bz2.py
@@ -184,7 +184,7 @@ class BZ2FileTest(BaseTest):
with BZ2File(self.filename) as bz2f:
pdata = bz2f.peek()
self.assertNotEqual(len(pdata), 0)
- self.assertTrue(self.TEXT.startswith(pdata))
+ self.assertStartsWith(self.TEXT, pdata)
self.assertEqual(bz2f.read(), self.TEXT)
def testReadInto(self):
@@ -768,7 +768,7 @@ class BZ2FileTest(BaseTest):
with BZ2File(bio) as bz2f:
pdata = bz2f.peek()
self.assertNotEqual(len(pdata), 0)
- self.assertTrue(self.TEXT.startswith(pdata))
+ self.assertStartsWith(self.TEXT, pdata)
self.assertEqual(bz2f.read(), self.TEXT)
def testWriteBytesIO(self):
diff --git a/Lib/test/test_calendar.py b/Lib/test/test_calendar.py
index 073df310bb4..bc39c86b8cf 100644
--- a/Lib/test/test_calendar.py
+++ b/Lib/test/test_calendar.py
@@ -417,7 +417,7 @@ class OutputTestCase(unittest.TestCase):
self.check_htmlcalendar_encoding('utf-8', 'utf-8')
def test_output_htmlcalendar_encoding_default(self):
- self.check_htmlcalendar_encoding(None, sys.getdefaultencoding())
+ self.check_htmlcalendar_encoding(None, 'utf-8')
def test_yeardatescalendar(self):
def shrink(cal):
@@ -987,6 +987,7 @@ class CommandLineTestCase(unittest.TestCase):
self.assertCLIFails(*args)
self.assertCmdFails(*args)
+ @support.force_not_colorized
def test_help(self):
stdout = self.run_cmd_ok('-h')
self.assertIn(b'usage:', stdout)
@@ -1097,7 +1098,7 @@ class CommandLineTestCase(unittest.TestCase):
output = run('--type', 'text', '2004')
self.assertEqual(output, conv(result_2004_text))
output = run('--type', 'html', '2004')
- self.assertEqual(output[:6], b'<?xml ')
+ self.assertStartsWith(output, b'<?xml ')
self.assertIn(b'<title>Calendar for 2004</title>', output)
def test_html_output_current_year(self):
diff --git a/Lib/test/test_call.py b/Lib/test/test_call.py
index 185ae84dc4d..1c73aaafb71 100644
--- a/Lib/test/test_call.py
+++ b/Lib/test/test_call.py
@@ -695,8 +695,8 @@ class TestPEP590(unittest.TestCase):
UnaffectedType2 = _testcapi.make_vectorcall_class(SuperType)
# Aside: Quickly check that the C helper actually made derived types
- self.assertTrue(issubclass(UnaffectedType1, DerivedType))
- self.assertTrue(issubclass(UnaffectedType2, SuperType))
+ self.assertIsSubclass(UnaffectedType1, DerivedType)
+ self.assertIsSubclass(UnaffectedType2, SuperType)
# Initial state: tp_call
self.assertEqual(instance(), "tp_call")
diff --git a/Lib/test/test_capi/test_bytearray.py b/Lib/test/test_capi/test_bytearray.py
index dfa98de9f00..52565ea34c6 100644
--- a/Lib/test/test_capi/test_bytearray.py
+++ b/Lib/test/test_capi/test_bytearray.py
@@ -66,6 +66,7 @@ class CAPITest(unittest.TestCase):
# Test PyByteArray_FromObject()
fromobject = _testlimitedcapi.bytearray_fromobject
+ self.assertEqual(fromobject(b''), bytearray(b''))
self.assertEqual(fromobject(b'abc'), bytearray(b'abc'))
self.assertEqual(fromobject(bytearray(b'abc')), bytearray(b'abc'))
self.assertEqual(fromobject(ByteArraySubclass(b'abc')), bytearray(b'abc'))
@@ -115,6 +116,7 @@ class CAPITest(unittest.TestCase):
self.assertEqual(concat(b'abc', bytearray(b'def')), bytearray(b'abcdef'))
self.assertEqual(concat(bytearray(b'abc'), b''), bytearray(b'abc'))
self.assertEqual(concat(b'', bytearray(b'def')), bytearray(b'def'))
+ self.assertEqual(concat(bytearray(b''), bytearray(b'')), bytearray(b''))
self.assertEqual(concat(memoryview(b'xabcy')[1:4], b'def'),
bytearray(b'abcdef'))
self.assertEqual(concat(b'abc', memoryview(b'xdefy')[1:4]),
@@ -150,6 +152,10 @@ class CAPITest(unittest.TestCase):
self.assertEqual(resize(ba, 0), 0)
self.assertEqual(ba, bytearray())
+ ba = bytearray(b'')
+ self.assertEqual(resize(ba, 0), 0)
+ self.assertEqual(ba, bytearray())
+
ba = ByteArraySubclass(b'abcdef')
self.assertEqual(resize(ba, 3), 0)
self.assertEqual(ba, bytearray(b'abc'))
diff --git a/Lib/test/test_capi/test_bytes.py b/Lib/test/test_capi/test_bytes.py
index 5b61c733815..bc820bd68d9 100644
--- a/Lib/test/test_capi/test_bytes.py
+++ b/Lib/test/test_capi/test_bytes.py
@@ -22,6 +22,7 @@ class CAPITest(unittest.TestCase):
# Test PyBytes_Check()
check = _testlimitedcapi.bytes_check
self.assertTrue(check(b'abc'))
+ self.assertTrue(check(b''))
self.assertFalse(check('abc'))
self.assertFalse(check(bytearray(b'abc')))
self.assertTrue(check(BytesSubclass(b'abc')))
@@ -36,6 +37,7 @@ class CAPITest(unittest.TestCase):
# Test PyBytes_CheckExact()
check = _testlimitedcapi.bytes_checkexact
self.assertTrue(check(b'abc'))
+ self.assertTrue(check(b''))
self.assertFalse(check('abc'))
self.assertFalse(check(bytearray(b'abc')))
self.assertFalse(check(BytesSubclass(b'abc')))
@@ -79,6 +81,7 @@ class CAPITest(unittest.TestCase):
# Test PyBytes_FromObject()
fromobject = _testlimitedcapi.bytes_fromobject
+ self.assertEqual(fromobject(b''), b'')
self.assertEqual(fromobject(b'abc'), b'abc')
self.assertEqual(fromobject(bytearray(b'abc')), b'abc')
self.assertEqual(fromobject(BytesSubclass(b'abc')), b'abc')
@@ -108,6 +111,7 @@ class CAPITest(unittest.TestCase):
self.assertEqual(asstring(b'abc', 4), b'abc\0')
self.assertEqual(asstring(b'abc\0def', 8), b'abc\0def\0')
+ self.assertEqual(asstring(b'', 1), b'\0')
self.assertRaises(TypeError, asstring, 'abc', 0)
self.assertRaises(TypeError, asstring, object(), 0)
@@ -120,6 +124,7 @@ class CAPITest(unittest.TestCase):
self.assertEqual(asstringandsize(b'abc', 4), (b'abc\0', 3))
self.assertEqual(asstringandsize(b'abc\0def', 8), (b'abc\0def\0', 7))
+ self.assertEqual(asstringandsize(b'', 1), (b'\0', 0))
self.assertEqual(asstringandsize_null(b'abc', 4), b'abc\0')
self.assertRaises(ValueError, asstringandsize_null, b'abc\0def', 8)
self.assertRaises(TypeError, asstringandsize, 'abc', 0)
@@ -134,6 +139,7 @@ class CAPITest(unittest.TestCase):
# Test PyBytes_Repr()
bytes_repr = _testlimitedcapi.bytes_repr
+ self.assertEqual(bytes_repr(b'', 0), r"""b''""")
self.assertEqual(bytes_repr(b'''abc''', 0), r"""b'abc'""")
self.assertEqual(bytes_repr(b'''abc''', 1), r"""b'abc'""")
self.assertEqual(bytes_repr(b'''a'b"c"d''', 0), r"""b'a\'b"c"d'""")
@@ -163,6 +169,7 @@ class CAPITest(unittest.TestCase):
self.assertEqual(concat(b'', bytearray(b'def')), b'def')
self.assertEqual(concat(memoryview(b'xabcy')[1:4], b'def'), b'abcdef')
self.assertEqual(concat(b'abc', memoryview(b'xdefy')[1:4]), b'abcdef')
+ self.assertEqual(concat(b'', b''), b'')
self.assertEqual(concat(b'abc', b'def', True), b'abcdef')
self.assertEqual(concat(b'abc', bytearray(b'def'), True), b'abcdef')
@@ -192,6 +199,7 @@ class CAPITest(unittest.TestCase):
"""Test PyBytes_DecodeEscape()"""
decodeescape = _testlimitedcapi.bytes_decodeescape
+ self.assertEqual(decodeescape(b''), b'')
self.assertEqual(decodeescape(b'abc'), b'abc')
self.assertEqual(decodeescape(br'\t\n\r\x0b\x0c\x00\\\'\"'),
b'''\t\n\r\v\f\0\\'"''')
diff --git a/Lib/test/test_capi/test_config.py b/Lib/test/test_capi/test_config.py
index bf351c4defa..04a27de8d84 100644
--- a/Lib/test/test_capi/test_config.py
+++ b/Lib/test/test_capi/test_config.py
@@ -3,7 +3,6 @@ Tests PyConfig_Get() and PyConfig_Set() C API (PEP 741).
"""
import os
import sys
-import sysconfig
import types
import unittest
from test import support
@@ -57,7 +56,7 @@ class CAPITests(unittest.TestCase):
("home", str | None, None),
("thread_inherit_context", int, None),
("context_aware_warnings", int, None),
- ("import_time", bool, None),
+ ("import_time", int, None),
("inspect", bool, None),
("install_signal_handlers", bool, None),
("int_max_str_digits", int, None),
diff --git a/Lib/test/test_capi/test_float.py b/Lib/test/test_capi/test_float.py
index c857959d569..f7efe0d0254 100644
--- a/Lib/test/test_capi/test_float.py
+++ b/Lib/test/test_capi/test_float.py
@@ -183,31 +183,35 @@ class CAPIFloatTest(unittest.TestCase):
def test_pack_unpack_roundtrip_for_nans(self):
pack = _testcapi.float_pack
unpack = _testcapi.float_unpack
- for _ in range(1000):
+
+ for _ in range(10):
for size in (2, 4, 8):
sign = random.randint(0, 1)
- signaling = random.randint(0, 1)
+ if sys.maxsize != 2147483647: # not in 32-bit mode
+ signaling = random.randint(0, 1)
+ else:
+ # Skip sNaNs on x86 (32-bit). The problem is that sNaN
+ # doubles become qNaN doubles just by the C calling
+ # convention; there is no way to preserve sNaN doubles
+ # between C function calls with the current
+ # PyFloat_Pack/Unpack*() API. See also gh-130317 and
+ # e.g. https://developercommunity.visualstudio.com/t/155064
+ signaling = 0
quiet = int(not signaling)
if size == 8:
- payload = random.randint(signaling, 1 << 50)
+ payload = random.randint(signaling, 0x7ffffffffffff)
i = (sign << 63) + (0x7ff << 52) + (quiet << 51) + payload
elif size == 4:
- payload = random.randint(signaling, 1 << 21)
+ payload = random.randint(signaling, 0x3fffff)
i = (sign << 31) + (0xff << 23) + (quiet << 22) + payload
elif size == 2:
- payload = random.randint(signaling, 1 << 8)
+ payload = random.randint(signaling, 0x1ff)
i = (sign << 15) + (0x1f << 10) + (quiet << 9) + payload
data = bytes.fromhex(f'{i:x}')
for endian in (BIG_ENDIAN, LITTLE_ENDIAN):
with self.subTest(data=data, size=size, endian=endian):
data1 = data if endian == BIG_ENDIAN else data[::-1]
value = unpack(data1, endian)
- if signaling and sys.platform == 'win32':
- # On this platform sNaN becomes qNaN when returned
- # from function. That's a known bug, e.g.
- # https://developercommunity.visualstudio.com/t/155064
- # (see also gh-130317).
- value = _testcapi.float_set_snan(value)
data2 = pack(size, value, endian)
self.assertTrue(math.isnan(value))
self.assertEqual(data1, data2)
diff --git a/Lib/test/test_capi/test_import.py b/Lib/test/test_capi/test_import.py
index 25136624ca4..57e0316fda8 100644
--- a/Lib/test/test_capi/test_import.py
+++ b/Lib/test/test_capi/test_import.py
@@ -134,7 +134,7 @@ class ImportTests(unittest.TestCase):
# CRASHES importmodule(NULL)
def test_importmodulenoblock(self):
- # Test deprecated PyImport_ImportModuleNoBlock()
+ # Test deprecated (stable ABI only) PyImport_ImportModuleNoBlock()
importmodulenoblock = _testlimitedcapi.PyImport_ImportModuleNoBlock
with check_warnings(('', DeprecationWarning)):
self.check_import_func(importmodulenoblock)
diff --git a/Lib/test/test_capi/test_misc.py b/Lib/test/test_capi/test_misc.py
index 98dc3b42ef0..ef950f5df04 100644
--- a/Lib/test/test_capi/test_misc.py
+++ b/Lib/test/test_capi/test_misc.py
@@ -306,7 +306,7 @@ class CAPITest(unittest.TestCase):
CURRENT_THREAD_REGEX +
r' File .*, line 6 in <module>\n'
r'\n'
- r'Extension modules: _testcapi, _testinternalcapi \(total: 2\)\n')
+ r'Extension modules: _testcapi \(total: 1\)\n')
else:
# Python built with NDEBUG macro defined:
# test _Py_CheckFunctionResult() instead.
@@ -412,10 +412,14 @@ class CAPITest(unittest.TestCase):
L = MyList((L,))
@support.requires_resource('cpu')
+ @support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_trashcan_python_class1(self):
self.do_test_trashcan_python_class(list)
@support.requires_resource('cpu')
+ @support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_trashcan_python_class2(self):
from _testcapi import MyList
self.do_test_trashcan_python_class(MyList)
diff --git a/Lib/test/test_capi/test_object.py b/Lib/test/test_capi/test_object.py
index 3e8fd91b9a6..d4056727d07 100644
--- a/Lib/test/test_capi/test_object.py
+++ b/Lib/test/test_capi/test_object.py
@@ -1,4 +1,5 @@
import enum
+import sys
import textwrap
import unittest
from test import support
@@ -173,6 +174,16 @@ class EnableDeferredRefcountingTest(unittest.TestCase):
self.assertTrue(_testinternalcapi.has_deferred_refcount(silly_list))
+class IsUniquelyReferencedTest(unittest.TestCase):
+ """Test PyUnstable_Object_IsUniquelyReferenced"""
+ def test_is_uniquely_referenced(self):
+ self.assertTrue(_testcapi.is_uniquely_referenced(object()))
+ self.assertTrue(_testcapi.is_uniquely_referenced([]))
+ # Immortals
+ self.assertFalse(_testcapi.is_uniquely_referenced(()))
+ self.assertFalse(_testcapi.is_uniquely_referenced(42))
+ # CRASHES is_uniquely_referenced(NULL)
+
class CAPITest(unittest.TestCase):
def check_negative_refcount(self, code):
# bpo-35059: Check that Py_DECREF() reports the correct filename
@@ -210,6 +221,7 @@ class CAPITest(unittest.TestCase):
"""
self.check_negative_refcount(code)
+ @support.requires_resource('cpu')
def test_decref_delayed(self):
# gh-130519: Test that _PyObject_XDecRefDelayed() and QSBR code path
# handles destructors that are possibly re-entrant or trigger a GC.
@@ -223,5 +235,17 @@ class CAPITest(unittest.TestCase):
obj = MyObj()
_testinternalcapi.incref_decref_delayed(obj)
+ def test_is_unique_temporary(self):
+ self.assertTrue(_testcapi.pyobject_is_unique_temporary(object()))
+ obj = object()
+ self.assertFalse(_testcapi.pyobject_is_unique_temporary(obj))
+
+ def func(x):
+ # This relies on the LOAD_FAST_BORROW optimization (gh-130704)
+ self.assertEqual(sys.getrefcount(x), 1)
+ self.assertFalse(_testcapi.pyobject_is_unique_temporary(x))
+
+ func(object())
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_capi/test_opt.py b/Lib/test/test_capi/test_opt.py
index 7e0c60d5522..e4c9a463855 100644
--- a/Lib/test/test_capi/test_opt.py
+++ b/Lib/test/test_capi/test_opt.py
@@ -407,12 +407,12 @@ class TestUops(unittest.TestCase):
x = 0
for i in range(m):
for j in MyIter(n):
- x += 1000*i + j
+ x += j
return x
- x = testfunc(TIER2_THRESHOLD, TIER2_THRESHOLD)
+ x = testfunc(TIER2_THRESHOLD, 2)
- self.assertEqual(x, sum(range(TIER2_THRESHOLD)) * TIER2_THRESHOLD * 1001)
+ self.assertEqual(x, sum(range(TIER2_THRESHOLD)) * 2)
ex = get_first_executor(testfunc)
self.assertIsNotNone(ex)
@@ -678,7 +678,7 @@ class TestUopsOptimization(unittest.TestCase):
self.assertLessEqual(len(guard_nos_float_count), 1)
# TODO gh-115506: this assertion may change after propagating constants.
# We'll also need to verify that propagation actually occurs.
- self.assertIn("_BINARY_OP_ADD_FLOAT", uops)
+ self.assertIn("_BINARY_OP_ADD_FLOAT__NO_DECREF_INPUTS", uops)
def test_float_subtract_constant_propagation(self):
def testfunc(n):
@@ -700,7 +700,7 @@ class TestUopsOptimization(unittest.TestCase):
self.assertLessEqual(len(guard_nos_float_count), 1)
# TODO gh-115506: this assertion may change after propagating constants.
# We'll also need to verify that propagation actually occurs.
- self.assertIn("_BINARY_OP_SUBTRACT_FLOAT", uops)
+ self.assertIn("_BINARY_OP_SUBTRACT_FLOAT__NO_DECREF_INPUTS", uops)
def test_float_multiply_constant_propagation(self):
def testfunc(n):
@@ -722,7 +722,7 @@ class TestUopsOptimization(unittest.TestCase):
self.assertLessEqual(len(guard_nos_float_count), 1)
# TODO gh-115506: this assertion may change after propagating constants.
# We'll also need to verify that propagation actually occurs.
- self.assertIn("_BINARY_OP_MULTIPLY_FLOAT", uops)
+ self.assertIn("_BINARY_OP_MULTIPLY_FLOAT__NO_DECREF_INPUTS", uops)
def test_add_unicode_propagation(self):
def testfunc(n):
@@ -1183,6 +1183,17 @@ class TestUopsOptimization(unittest.TestCase):
self.assertIsNotNone(ex)
self.assertIn("_RETURN_GENERATOR", get_opnames(ex))
+ def test_for_iter(self):
+ def testfunc(n):
+ t = 0
+ for i in set(range(n)):
+ t += i
+ return t
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD * (TIER2_THRESHOLD - 1) // 2)
+ self.assertIsNotNone(ex)
+ self.assertIn("_FOR_ITER_TIER_TWO", get_opnames(ex))
+
@unittest.skip("Tracing into generators currently isn't supported.")
def test_for_iter_gen(self):
def gen(n):
@@ -1280,8 +1291,8 @@ class TestUopsOptimization(unittest.TestCase):
self.assertIsNotNone(ex)
self.assertEqual(res, TIER2_THRESHOLD * 6 + 1)
call = opnames.index("_CALL_BUILTIN_FAST")
- load_attr_top = opnames.index("_LOAD_ATTR_NONDESCRIPTOR_WITH_VALUES", 0, call)
- load_attr_bottom = opnames.index("_LOAD_ATTR_NONDESCRIPTOR_WITH_VALUES", call)
+ load_attr_top = opnames.index("_POP_TOP_LOAD_CONST_INLINE_BORROW", 0, call)
+ load_attr_bottom = opnames.index("_POP_TOP_LOAD_CONST_INLINE_BORROW", call)
self.assertEqual(opnames[:load_attr_top].count("_GUARD_TYPE_VERSION"), 1)
self.assertEqual(opnames[call:load_attr_bottom].count("_CHECK_VALIDITY"), 2)
@@ -1303,8 +1314,8 @@ class TestUopsOptimization(unittest.TestCase):
self.assertIsNotNone(ex)
self.assertEqual(res, TIER2_THRESHOLD * 2)
call = opnames.index("_CALL_BUILTIN_FAST_WITH_KEYWORDS")
- load_attr_top = opnames.index("_LOAD_ATTR_NONDESCRIPTOR_WITH_VALUES", 0, call)
- load_attr_bottom = opnames.index("_LOAD_ATTR_NONDESCRIPTOR_WITH_VALUES", call)
+ load_attr_top = opnames.index("_POP_TOP_LOAD_CONST_INLINE_BORROW", 0, call)
+ load_attr_bottom = opnames.index("_POP_TOP_LOAD_CONST_INLINE_BORROW", call)
self.assertEqual(opnames[:load_attr_top].count("_GUARD_TYPE_VERSION"), 1)
self.assertEqual(opnames[call:load_attr_bottom].count("_CHECK_VALIDITY"), 2)
@@ -1370,6 +1381,21 @@ class TestUopsOptimization(unittest.TestCase):
# Removed guard
self.assertNotIn("_CHECK_FUNCTION_EXACT_ARGS", uops)
+ def test_method_guards_removed_or_reduced(self):
+ def testfunc(n):
+ result = 0
+ for i in range(n):
+ result += test_bound_method(i)
+ return result
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, sum(range(TIER2_THRESHOLD)))
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertIn("_PUSH_FRAME", uops)
+ # Strength reduced version
+ self.assertIn("_CHECK_FUNCTION_VERSION_INLINE", uops)
+ self.assertNotIn("_CHECK_METHOD_VERSION", uops)
+
def test_jit_error_pops(self):
"""
Tests that the correct number of pops are inserted into the
@@ -1655,13 +1681,11 @@ class TestUopsOptimization(unittest.TestCase):
self.assertIn("_CONTAINS_OP_DICT", uops)
self.assertNotIn("_TO_BOOL_BOOL", uops)
-
def test_remove_guard_for_known_type_str(self):
def f(n):
for i in range(n):
false = i == TIER2_THRESHOLD
empty = "X"[:false]
- empty += "" # Make JIT realize this is a string.
if empty:
return 1
return 0
@@ -1767,11 +1791,12 @@ class TestUopsOptimization(unittest.TestCase):
self.assertNotIn("_GUARD_TOS_UNICODE", uops)
self.assertIn("_BINARY_OP_ADD_UNICODE", uops)
- def test_call_type_1(self):
+ def test_call_type_1_guards_removed(self):
def testfunc(n):
x = 0
for _ in range(n):
- x += type(42) is int
+ foo = eval('42')
+ x += type(foo) is int
return x
res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
@@ -1782,6 +1807,25 @@ class TestUopsOptimization(unittest.TestCase):
self.assertNotIn("_GUARD_NOS_NULL", uops)
self.assertNotIn("_GUARD_CALLABLE_TYPE_1", uops)
+ def test_call_type_1_known_type(self):
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ x += type(42) is int
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ # When the result of type(...) is known, _CALL_TYPE_1 is replaced with
+ # _POP_CALL_ONE_LOAD_CONST_INLINE_BORROW which is optimized away in
+ # remove_unneeded_uops.
+ self.assertNotIn("_CALL_TYPE_1", uops)
+ self.assertNotIn("_POP_CALL_ONE_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_TOP_LOAD_CONST_INLINE_BORROW", uops)
+
def test_call_type_1_result_is_const(self):
def testfunc(n):
x = 0
@@ -1795,7 +1839,6 @@ class TestUopsOptimization(unittest.TestCase):
self.assertEqual(res, TIER2_THRESHOLD)
self.assertIsNotNone(ex)
uops = get_opnames(ex)
- self.assertIn("_CALL_TYPE_1", uops)
self.assertNotIn("_GUARD_IS_NOT_NONE_POP", uops)
def test_call_str_1(self):
@@ -1919,9 +1962,98 @@ class TestUopsOptimization(unittest.TestCase):
_, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
uops = get_opnames(ex)
+ self.assertNotIn("_GUARD_NOS_NULL", uops)
+ self.assertNotIn("_GUARD_CALLABLE_LEN", uops)
+ self.assertIn("_CALL_LEN", uops)
self.assertNotIn("_GUARD_NOS_INT", uops)
self.assertNotIn("_GUARD_TOS_INT", uops)
+
+ def test_call_len_known_length_small_int(self):
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ t = (1, 2, 3, 4, 5)
+ if len(t) == 5:
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ # When the length is < _PY_NSMALLPOSINTS, the len() call is replaced
+ # with just an inline load.
+ self.assertNotIn("_CALL_LEN", uops)
+ self.assertNotIn("_POP_CALL_ONE_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_TOP_LOAD_CONST_INLINE_BORROW", uops)
+
+ def test_call_len_known_length(self):
+ def testfunc(n):
+ class C:
+ t = tuple(range(300))
+
+ x = 0
+ for _ in range(n):
+ if len(C.t) == 300: # comparison + guard removed
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ # When the length is >= _PY_NSMALLPOSINTS, we cannot replace
+ # the len() call with an inline load, but knowing the exact
+ # length allows us to optimize more code, such as the
+ # conditional in this case.
self.assertIn("_CALL_LEN", uops)
+ self.assertNotIn("_COMPARE_OP_INT", uops)
+ self.assertNotIn("_GUARD_IS_TRUE_POP", uops)
+
+ def test_get_len_with_const_tuple(self):
+ def testfunc(n):
+ x = 0.0
+ for _ in range(n):
+ match (1, 2, 3, 4):
+ case [_, _, _, _]:
+ x += 1.0
+ return x
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(int(res), TIER2_THRESHOLD)
+ uops = get_opnames(ex)
+ self.assertNotIn("_GUARD_NOS_INT", uops)
+ self.assertNotIn("_GET_LEN", uops)
+ self.assertIn("_LOAD_CONST_INLINE_BORROW", uops)
+
+ def test_get_len_with_non_const_tuple(self):
+ def testfunc(n):
+ x = 0.0
+ for _ in range(n):
+ match object(), object():
+ case [_, _]:
+ x += 1.0
+ return x
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(int(res), TIER2_THRESHOLD)
+ uops = get_opnames(ex)
+ self.assertNotIn("_GUARD_NOS_INT", uops)
+ self.assertNotIn("_GET_LEN", uops)
+ self.assertIn("_LOAD_CONST_INLINE_BORROW", uops)
+
+ def test_get_len_with_non_tuple(self):
+ def testfunc(n):
+ x = 0.0
+ for _ in range(n):
+ match [1, 2, 3, 4]:
+ case [_, _, _, _]:
+ x += 1.0
+ return x
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(int(res), TIER2_THRESHOLD)
+ uops = get_opnames(ex)
+ self.assertNotIn("_GUARD_NOS_INT", uops)
+ self.assertIn("_GET_LEN", uops)
def test_binary_op_subscr_tuple_int(self):
def testfunc(n):
@@ -1940,9 +2072,394 @@ class TestUopsOptimization(unittest.TestCase):
self.assertNotIn("_COMPARE_OP_INT", uops)
self.assertNotIn("_GUARD_IS_TRUE_POP", uops)
+ def test_call_isinstance_guards_removed(self):
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ y = isinstance(42, int)
+ if y:
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertNotIn("_CALL_ISINSTANCE", uops)
+ self.assertNotIn("_GUARD_THIRD_NULL", uops)
+ self.assertNotIn("_GUARD_CALLABLE_ISINSTANCE", uops)
+ self.assertNotIn("_POP_TOP_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_ONE_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_TWO_LOAD_CONST_INLINE_BORROW", uops)
+
+ def test_call_list_append(self):
+ def testfunc(n):
+ a = []
+ for i in range(n):
+ a.append(i)
+ return sum(a)
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, sum(range(TIER2_THRESHOLD)))
+ uops = get_opnames(ex)
+ self.assertIn("_CALL_LIST_APPEND", uops)
+ # We should remove these in the future
+ self.assertIn("_GUARD_NOS_LIST", uops)
+ self.assertIn("_GUARD_CALLABLE_LIST_APPEND", uops)
+
+ def test_call_isinstance_is_true(self):
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ y = isinstance(42, int)
+ if y:
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertNotIn("_CALL_ISINSTANCE", uops)
+ self.assertNotIn("_TO_BOOL_BOOL", uops)
+ self.assertNotIn("_GUARD_IS_TRUE_POP", uops)
+ self.assertNotIn("_POP_TOP_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_ONE_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_TWO_LOAD_CONST_INLINE_BORROW", uops)
+
+ def test_call_isinstance_is_false(self):
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ y = isinstance(42, str)
+ if not y:
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertNotIn("_CALL_ISINSTANCE", uops)
+ self.assertNotIn("_TO_BOOL_BOOL", uops)
+ self.assertNotIn("_GUARD_IS_FALSE_POP", uops)
+ self.assertNotIn("_POP_TOP_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_ONE_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_TWO_LOAD_CONST_INLINE_BORROW", uops)
+
+ def test_call_isinstance_subclass(self):
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ y = isinstance(True, int)
+ if y:
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertNotIn("_CALL_ISINSTANCE", uops)
+ self.assertNotIn("_TO_BOOL_BOOL", uops)
+ self.assertNotIn("_GUARD_IS_TRUE_POP", uops)
+ self.assertNotIn("_POP_TOP_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_ONE_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_TWO_LOAD_CONST_INLINE_BORROW", uops)
+
+ def test_call_isinstance_unknown_object(self):
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ # The optimizer doesn't know the return type here:
+ bar = eval("42")
+ # This will only narrow to bool:
+ y = isinstance(bar, int)
+ if y:
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertIn("_CALL_ISINSTANCE", uops)
+ self.assertNotIn("_TO_BOOL_BOOL", uops)
+ self.assertIn("_GUARD_IS_TRUE_POP", uops)
+
+ def test_call_isinstance_tuple_of_classes(self):
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ # A tuple of classes is currently not optimized,
+ # so this is only narrowed to bool:
+ y = isinstance(42, (int, str))
+ if y:
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertIn("_CALL_ISINSTANCE", uops)
+ self.assertNotIn("_TO_BOOL_BOOL", uops)
+ self.assertIn("_GUARD_IS_TRUE_POP", uops)
+
+ def test_call_isinstance_metaclass(self):
+ class EvenNumberMeta(type):
+ def __instancecheck__(self, number):
+ return number % 2 == 0
+
+ class EvenNumber(metaclass=EvenNumberMeta):
+ pass
+
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ # Only narrowed to bool
+ y = isinstance(42, EvenNumber)
+ if y:
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertIn("_CALL_ISINSTANCE", uops)
+ self.assertNotIn("_TO_BOOL_BOOL", uops)
+ self.assertIn("_GUARD_IS_TRUE_POP", uops)
+
+ def test_set_type_version_sets_type(self):
+ class C:
+ A = 1
+
+ def testfunc(n):
+ x = 0
+ c = C()
+ for _ in range(n):
+ x += c.A # Guarded.
+ x += type(c).A # Unguarded!
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, 2 * TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertIn("_GUARD_TYPE_VERSION", uops)
+ self.assertNotIn("_CHECK_ATTR_CLASS", uops)
+
+ def test_load_small_int(self):
+ def testfunc(n):
+ x = 0
+ for i in range(n):
+ x += 1
+ return x
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertNotIn("_LOAD_SMALL_INT", uops)
+ self.assertIn("_LOAD_CONST_INLINE_BORROW", uops)
+
+ def test_cached_attributes(self):
+ class C:
+ A = 1
+ def m(self):
+ return 1
+ class D:
+ __slots__ = ()
+ A = 1
+ def m(self):
+ return 1
+ class E(Exception):
+ def m(self):
+ return 1
+ def f(n):
+ x = 0
+ c = C()
+ d = D()
+ e = E()
+ for _ in range(n):
+ x += C.A # _LOAD_ATTR_CLASS
+ x += c.A # _LOAD_ATTR_NONDESCRIPTOR_WITH_VALUES
+ x += d.A # _LOAD_ATTR_NONDESCRIPTOR_NO_DICT
+ x += c.m() # _LOAD_ATTR_METHOD_WITH_VALUES
+ x += d.m() # _LOAD_ATTR_METHOD_NO_DICT
+ x += e.m() # _LOAD_ATTR_METHOD_LAZY_DICT
+ return x
+
+ res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
+ self.assertEqual(res, 6 * TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertNotIn("_LOAD_ATTR_CLASS", uops)
+ self.assertNotIn("_LOAD_ATTR_NONDESCRIPTOR_WITH_VALUES", uops)
+ self.assertNotIn("_LOAD_ATTR_NONDESCRIPTOR_NO_DICT", uops)
+ self.assertNotIn("_LOAD_ATTR_METHOD_WITH_VALUES", uops)
+ self.assertNotIn("_LOAD_ATTR_METHOD_NO_DICT", uops)
+ self.assertNotIn("_LOAD_ATTR_METHOD_LAZY_DICT", uops)
+
+ def test_float_op_refcount_elimination(self):
+ def testfunc(args):
+ a, b, n = args
+ c = 0.0
+ for _ in range(n):
+ c += a + b
+ return c
+
+ res, ex = self._run_with_optimizer(testfunc, (0.1, 0.1, TIER2_THRESHOLD))
+ self.assertAlmostEqual(res, TIER2_THRESHOLD * (0.1 + 0.1))
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertIn("_BINARY_OP_ADD_FLOAT__NO_DECREF_INPUTS", uops)
+
+ def test_remove_guard_for_slice_list(self):
+ def f(n):
+ for i in range(n):
+ false = i == TIER2_THRESHOLD
+ sliced = [1, 2, 3][:false]
+ if sliced:
+ return 1
+ return 0
+
+ res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
+ self.assertEqual(res, 0)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertIn("_TO_BOOL_LIST", uops)
+ self.assertNotIn("_GUARD_TOS_LIST", uops)
+
+ def test_remove_guard_for_slice_tuple(self):
+ def f(n):
+ for i in range(n):
+ false = i == TIER2_THRESHOLD
+ a, b = (1, 2, 3)[: false + 2]
+
+ _, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertIn("_UNPACK_SEQUENCE_TWO_TUPLE", uops)
+ self.assertNotIn("_GUARD_TOS_TUPLE", uops)
+
+ def test_unary_invert_long_type(self):
+ def testfunc(n):
+ for _ in range(n):
+ a = 9397
+ x = ~a + ~a
+
+ testfunc(TIER2_THRESHOLD)
+
+ ex = get_first_executor(testfunc)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+
+ self.assertNotIn("_GUARD_TOS_INT", uops)
+ self.assertNotIn("_GUARD_NOS_INT", uops)
+
+ def test_attr_promotion_failure(self):
+ # We're not testing for any specific uops here, just
+ # testing that it doesn't crash.
+ script_helper.assert_python_ok('-c', textwrap.dedent("""
+ import _testinternalcapi
+ import _opcode
+ import email
+
+ def get_first_executor(func):
+ code = func.__code__
+ co_code = code.co_code
+ for i in range(0, len(co_code), 2):
+ try:
+ return _opcode.get_executor(code, i)
+ except ValueError:
+ pass
+ return None
+
+ def testfunc(n):
+ for _ in range(n):
+ email.jit_testing = None
+ prompt = email.jit_testing
+ del email.jit_testing
+
+
+ testfunc(_testinternalcapi.TIER2_THRESHOLD)
+ ex = get_first_executor(testfunc)
+ assert ex is not None
+ """))
+
+ def test_pop_top_specialize_none(self):
+ def testfunc(n):
+ for _ in range(n):
+ global_identity(None)
+
+ testfunc(TIER2_THRESHOLD)
+
+ ex = get_first_executor(testfunc)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+
+ self.assertIn("_POP_TOP_NOP", uops)
+
+ def test_pop_top_specialize_int(self):
+ def testfunc(n):
+ for _ in range(n):
+ global_identity(100000)
+
+ testfunc(TIER2_THRESHOLD)
+
+ ex = get_first_executor(testfunc)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+
+ self.assertIn("_POP_TOP_INT", uops)
+
+ def test_pop_top_specialize_float(self):
+ def testfunc(n):
+ for _ in range(n):
+ global_identity(1e6)
+
+ testfunc(TIER2_THRESHOLD)
+
+ ex = get_first_executor(testfunc)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+
+ self.assertIn("_POP_TOP_FLOAT", uops)
+
+ def test_unary_negative_long_float_type(self):
+ def testfunc(n):
+ for _ in range(n):
+ a = 9397
+ f = 9397.0
+ x = -a + -a
+ y = -f + -f
+
+ testfunc(TIER2_THRESHOLD)
+
+ ex = get_first_executor(testfunc)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+
+ self.assertNotIn("_GUARD_TOS_INT", uops)
+ self.assertNotIn("_GUARD_NOS_INT", uops)
+ self.assertNotIn("_GUARD_TOS_FLOAT", uops)
+ self.assertNotIn("_GUARD_NOS_FLOAT", uops)
def global_identity(x):
return x
+class TestObject:
+ def test(self, *args, **kwargs):
+ return args[0]
+
+test_object = TestObject()
+test_bound_method = TestObject.test.__get__(test_object)
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_capi/test_sys.py b/Lib/test/test_capi/test_sys.py
index d3a9b378e77..3793ce2461e 100644
--- a/Lib/test/test_capi/test_sys.py
+++ b/Lib/test/test_capi/test_sys.py
@@ -19,6 +19,68 @@ class CAPITest(unittest.TestCase):
maxDiff = None
+ @unittest.skipIf(_testlimitedcapi is None, 'need _testlimitedcapi module')
+ def test_sys_getattr(self):
+ # Test PySys_GetAttr()
+ sys_getattr = _testlimitedcapi.sys_getattr
+
+ self.assertIs(sys_getattr('stdout'), sys.stdout)
+ with support.swap_attr(sys, '\U0001f40d', 42):
+ self.assertEqual(sys_getattr('\U0001f40d'), 42)
+
+ with self.assertRaisesRegex(RuntimeError, r'lost sys\.nonexistent'):
+ sys_getattr('nonexistent')
+ with self.assertRaisesRegex(RuntimeError, r'lost sys\.\U0001f40d'):
+ sys_getattr('\U0001f40d')
+ self.assertRaises(TypeError, sys_getattr, 1)
+ self.assertRaises(TypeError, sys_getattr, [])
+ # CRASHES sys_getattr(NULL)
+
+ @unittest.skipIf(_testlimitedcapi is None, 'need _testlimitedcapi module')
+ def test_sys_getattrstring(self):
+ # Test PySys_GetAttrString()
+ getattrstring = _testlimitedcapi.sys_getattrstring
+
+ self.assertIs(getattrstring(b'stdout'), sys.stdout)
+ with support.swap_attr(sys, '\U0001f40d', 42):
+ self.assertEqual(getattrstring('\U0001f40d'.encode()), 42)
+
+ with self.assertRaisesRegex(RuntimeError, r'lost sys\.nonexistent'):
+ getattrstring(b'nonexistent')
+ with self.assertRaisesRegex(RuntimeError, r'lost sys\.\U0001f40d'):
+ getattrstring('\U0001f40d'.encode())
+ self.assertRaises(UnicodeDecodeError, getattrstring, b'\xff')
+ # CRASHES getattrstring(NULL)
+
+ @unittest.skipIf(_testlimitedcapi is None, 'need _testlimitedcapi module')
+ def test_sys_getoptionalattr(self):
+ # Test PySys_GetOptionalAttr()
+ getoptionalattr = _testlimitedcapi.sys_getoptionalattr
+
+ self.assertIs(getoptionalattr('stdout'), sys.stdout)
+ with support.swap_attr(sys, '\U0001f40d', 42):
+ self.assertEqual(getoptionalattr('\U0001f40d'), 42)
+
+ self.assertIs(getoptionalattr('nonexistent'), AttributeError)
+ self.assertIs(getoptionalattr('\U0001f40d'), AttributeError)
+ self.assertRaises(TypeError, getoptionalattr, 1)
+ self.assertRaises(TypeError, getoptionalattr, [])
+ # CRASHES getoptionalattr(NULL)
+
+ @unittest.skipIf(_testlimitedcapi is None, 'need _testlimitedcapi module')
+ def test_sys_getoptionalattrstring(self):
+ # Test PySys_GetOptionalAttrString()
+ getoptionalattrstring = _testlimitedcapi.sys_getoptionalattrstring
+
+ self.assertIs(getoptionalattrstring(b'stdout'), sys.stdout)
+ with support.swap_attr(sys, '\U0001f40d', 42):
+ self.assertEqual(getoptionalattrstring('\U0001f40d'.encode()), 42)
+
+ self.assertIs(getoptionalattrstring(b'nonexistent'), AttributeError)
+ self.assertIs(getoptionalattrstring('\U0001f40d'.encode()), AttributeError)
+ self.assertRaises(UnicodeDecodeError, getoptionalattrstring, b'\xff')
+ # CRASHES getoptionalattrstring(NULL)
+
@support.cpython_only
@unittest.skipIf(_testlimitedcapi is None, 'need _testlimitedcapi module')
def test_sys_getobject(self):
@@ -29,7 +91,7 @@ class CAPITest(unittest.TestCase):
with support.swap_attr(sys, '\U0001f40d', 42):
self.assertEqual(getobject('\U0001f40d'.encode()), 42)
- self.assertIs(getobject(b'nonexisting'), AttributeError)
+ self.assertIs(getobject(b'nonexistent'), AttributeError)
with support.catch_unraisable_exception() as cm:
self.assertIs(getobject(b'\xff'), AttributeError)
self.assertEqual(cm.unraisable.exc_type, UnicodeDecodeError)
diff --git a/Lib/test/test_capi/test_type.py b/Lib/test/test_capi/test_type.py
index 7e5d013d737..15fb4a93e2a 100644
--- a/Lib/test/test_capi/test_type.py
+++ b/Lib/test/test_capi/test_type.py
@@ -179,6 +179,22 @@ class TypeTests(unittest.TestCase):
_testcapi.pytype_getbasebytoken(
'not a type', id(self), True, False)
+ def test_get_module_by_def(self):
+ heaptype = _testcapi.create_type_with_token('_testcapi.H', 0)
+ mod = _testcapi.pytype_getmodulebydef(heaptype)
+ self.assertIs(mod, _testcapi)
+
+ class H1(heaptype): pass
+ mod = _testcapi.pytype_getmodulebydef(H1)
+ self.assertIs(mod, _testcapi)
+
+ with self.assertRaises(TypeError):
+ _testcapi.pytype_getmodulebydef(int)
+
+ class H2(int): pass
+ with self.assertRaises(TypeError):
+ _testcapi.pytype_getmodulebydef(H2)
+
def test_freeze(self):
# test PyType_Freeze()
type_freeze = _testcapi.type_freeze
@@ -248,3 +264,13 @@ class TypeTests(unittest.TestCase):
ManualHeapType = _testcapi.ManualHeapType
for i in range(100):
self.assertIsInstance(ManualHeapType(), ManualHeapType)
+
+ def test_extension_managed_dict_type(self):
+ ManagedDictType = _testcapi.ManagedDictType
+ obj = ManagedDictType()
+ obj.foo = 42
+ self.assertEqual(obj.foo, 42)
+ self.assertEqual(obj.__dict__, {'foo': 42})
+ obj.__dict__ = {'bar': 3}
+ self.assertEqual(obj.__dict__, {'bar': 3})
+ self.assertEqual(obj.bar, 3)
diff --git a/Lib/test/test_capi/test_unicode.py b/Lib/test/test_capi/test_unicode.py
index 3408c10f426..6a9c60f3a6d 100644
--- a/Lib/test/test_capi/test_unicode.py
+++ b/Lib/test/test_capi/test_unicode.py
@@ -1739,6 +1739,20 @@ class CAPITest(unittest.TestCase):
# Check that the second call returns the same result
self.assertEqual(getargs_s_hash(s), chr(k).encode() * (i + 1))
+ @support.cpython_only
+ @unittest.skipIf(_testcapi is None, 'need _testcapi module')
+ def test_GET_CACHED_HASH(self):
+ from _testcapi import unicode_GET_CACHED_HASH
+ content_bytes = b'some new string'
+ # avoid parser interning & constant folding
+ obj = str(content_bytes, 'ascii')
+ # impl detail: fresh strings do not have a cached hash
+ self.assertEqual(unicode_GET_CACHED_HASH(obj), -1)
+ # impl detail: adding a string to a dict caches its hash
+ {obj: obj}
+ # impl detail: ASCII string hashes equal those of the corresponding bytes
+ self.assertEqual(unicode_GET_CACHED_HASH(obj), hash(content_bytes))
+
class PyUnicodeWriterTest(unittest.TestCase):
def create_writer(self, size):
@@ -1776,6 +1790,13 @@ class PyUnicodeWriterTest(unittest.TestCase):
self.assertEqual(writer.finish(),
"ascii-latin1=\xE9-euro=\u20AC.")
+ def test_ascii(self):
+ writer = self.create_writer(0)
+ writer.write_ascii(b"Hello ", -1)
+ writer.write_ascii(b"", 0)
+ writer.write_ascii(b"Python! <truncated>", 6)
+ self.assertEqual(writer.finish(), "Hello Python")
+
def test_invalid_utf8(self):
writer = self.create_writer(0)
with self.assertRaises(UnicodeDecodeError):
diff --git a/Lib/test/test_class.py b/Lib/test/test_class.py
index 4c12d43556f..8c7a62a74ba 100644
--- a/Lib/test/test_class.py
+++ b/Lib/test/test_class.py
@@ -652,6 +652,7 @@ class ClassTests(unittest.TestCase):
a = A(hash(A.f)^(-1))
hash(a.f)
+ @cpython_only
def testSetattrWrapperNameIntern(self):
# Issue #25794: __setattr__ should intern the attribute name
class A:
diff --git a/Lib/test/test_clinic.py b/Lib/test/test_clinic.py
index 0c99620e27c..580d54e0eb0 100644
--- a/Lib/test/test_clinic.py
+++ b/Lib/test/test_clinic.py
@@ -238,11 +238,11 @@ class ClinicWholeFileTest(TestCase):
# The generated output will differ for every run, but we can check that
# it starts with the clinic block, we check that it contains all the
# expected fields, and we check that it contains the checksum line.
- self.assertTrue(out.startswith(dedent("""
+ self.assertStartsWith(out, dedent("""
/*[clinic input]
output print 'I told you once.'
[clinic start generated code]*/
- """)))
+ """))
fields = {
"cpp_endif",
"cpp_if",
@@ -259,9 +259,7 @@ class ClinicWholeFileTest(TestCase):
with self.subTest(field=field):
self.assertIn(field, out)
last_line = out.rstrip().split("\n")[-1]
- self.assertTrue(
- last_line.startswith("/*[clinic end generated code: output=")
- )
+ self.assertStartsWith(last_line, "/*[clinic end generated code: output=")
def test_directive_wrong_arg_number(self):
raw = dedent("""
@@ -2705,8 +2703,7 @@ class ClinicExternalTest(TestCase):
# Note, we cannot check the entire fail msg, because the path to
# the tmp file will change for every run.
_, err = self.expect_failure(fn)
- self.assertTrue(err.endswith(fail_msg),
- f"{err!r} does not end with {fail_msg!r}")
+ self.assertEndsWith(err, fail_msg)
# Then, force regeneration; success expected.
out = self.expect_success("-f", fn)
self.assertEqual(out, "")
@@ -2717,8 +2714,7 @@ class ClinicExternalTest(TestCase):
)
with open(fn, encoding='utf-8') as f:
generated = f.read()
- self.assertTrue(generated.endswith(checksum),
- (generated, checksum))
+ self.assertEndsWith(generated, checksum)
def test_cli_make(self):
c_code = dedent("""
@@ -2835,6 +2831,10 @@ class ClinicExternalTest(TestCase):
"size_t",
"slice_index",
"str",
+ "uint16",
+ "uint32",
+ "uint64",
+ "uint8",
"unicode",
"unsigned_char",
"unsigned_int",
@@ -2863,8 +2863,8 @@ class ClinicExternalTest(TestCase):
# param may change (it's a set, thus unordered). So, let's compare the
# start and end of the expected output, and then assert that the
# converters appear lined up in alphabetical order.
- self.assertTrue(out.startswith(prelude), out)
- self.assertTrue(out.endswith(finale), out)
+ self.assertStartsWith(out, prelude)
+ self.assertEndsWith(out, finale)
out = out.removeprefix(prelude)
out = out.removesuffix(finale)
@@ -2872,10 +2872,7 @@ class ClinicExternalTest(TestCase):
for converter, line in zip(expected_converters, lines):
line = line.lstrip()
with self.subTest(converter=converter):
- self.assertTrue(
- line.startswith(converter),
- f"expected converter {converter!r}, got {line!r}"
- )
+ self.assertStartsWith(line, converter)
def test_cli_fail_converters_and_filename(self):
_, err = self.expect_failure("--converters", "test.c")
@@ -2985,7 +2982,7 @@ class ClinicFunctionalTest(unittest.TestCase):
regex = (
fr"Passing( more than)?( [0-9]+)? positional argument(s)? to "
fr"{re.escape(name)}\(\) is deprecated. Parameters? {pnames} will "
- fr"become( a)? keyword-only parameters? in Python 3\.14"
+ fr"become( a)? keyword-only parameters? in Python 3\.37"
)
self.check_depr(regex, fn, *args, **kwds)
@@ -2998,7 +2995,7 @@ class ClinicFunctionalTest(unittest.TestCase):
regex = (
fr"Passing keyword argument{pl} {pnames} to "
fr"{re.escape(name)}\(\) is deprecated. Parameter{pl} {pnames} "
- fr"will become positional-only in Python 3\.14."
+ fr"will become positional-only in Python 3\.37."
)
self.check_depr(regex, fn, *args, **kwds)
@@ -3782,9 +3779,9 @@ class ClinicFunctionalTest(unittest.TestCase):
fn("a", b="b", c="c", d="d", e="e", f="f", g="g", h="h")
errmsg = (
"Passing more than 1 positional argument to depr_star_multi() is deprecated. "
- "Parameter 'b' will become a keyword-only parameter in Python 3.16. "
- "Parameters 'c' and 'd' will become keyword-only parameters in Python 3.15. "
- "Parameters 'e', 'f' and 'g' will become keyword-only parameters in Python 3.14.")
+ "Parameter 'b' will become a keyword-only parameter in Python 3.39. "
+ "Parameters 'c' and 'd' will become keyword-only parameters in Python 3.38. "
+ "Parameters 'e', 'f' and 'g' will become keyword-only parameters in Python 3.37.")
check = partial(self.check_depr, re.escape(errmsg), fn)
check("a", "b", c="c", d="d", e="e", f="f", g="g", h="h")
check("a", "b", "c", d="d", e="e", f="f", g="g", h="h")
@@ -3883,9 +3880,9 @@ class ClinicFunctionalTest(unittest.TestCase):
fn("a", "b", "c", "d", "e", "f", "g", h="h")
errmsg = (
"Passing keyword arguments 'b', 'c', 'd', 'e', 'f' and 'g' to depr_kwd_multi() is deprecated. "
- "Parameter 'b' will become positional-only in Python 3.14. "
- "Parameters 'c' and 'd' will become positional-only in Python 3.15. "
- "Parameters 'e', 'f' and 'g' will become positional-only in Python 3.16.")
+ "Parameter 'b' will become positional-only in Python 3.37. "
+ "Parameters 'c' and 'd' will become positional-only in Python 3.38. "
+ "Parameters 'e', 'f' and 'g' will become positional-only in Python 3.39.")
check = partial(self.check_depr, re.escape(errmsg), fn)
check("a", "b", "c", "d", "e", "f", g="g", h="h")
check("a", "b", "c", "d", "e", f="f", g="g", h="h")
@@ -3900,8 +3897,8 @@ class ClinicFunctionalTest(unittest.TestCase):
self.assertRaises(TypeError, fn, "a", "b", "c", "d", "e", "f", "g")
errmsg = (
"Passing more than 4 positional arguments to depr_multi() is deprecated. "
- "Parameter 'e' will become a keyword-only parameter in Python 3.15. "
- "Parameter 'f' will become a keyword-only parameter in Python 3.14.")
+ "Parameter 'e' will become a keyword-only parameter in Python 3.38. "
+ "Parameter 'f' will become a keyword-only parameter in Python 3.37.")
check = partial(self.check_depr, re.escape(errmsg), fn)
check("a", "b", "c", "d", "e", "f", g="g")
check("a", "b", "c", "d", "e", f="f", g="g")
@@ -3909,8 +3906,8 @@ class ClinicFunctionalTest(unittest.TestCase):
fn("a", "b", "c", d="d", e="e", f="f", g="g")
errmsg = (
"Passing keyword arguments 'b' and 'c' to depr_multi() is deprecated. "
- "Parameter 'b' will become positional-only in Python 3.14. "
- "Parameter 'c' will become positional-only in Python 3.15.")
+ "Parameter 'b' will become positional-only in Python 3.37. "
+ "Parameter 'c' will become positional-only in Python 3.38.")
check = partial(self.check_depr, re.escape(errmsg), fn)
check("a", "b", c="c", d="d", e="e", f="f", g="g")
check("a", b="b", c="c", d="d", e="e", f="f", g="g")
diff --git a/Lib/test/test_cmd.py b/Lib/test/test_cmd.py
index 46ec82b7049..dbfec42fc21 100644
--- a/Lib/test/test_cmd.py
+++ b/Lib/test/test_cmd.py
@@ -11,9 +11,15 @@ import unittest
import io
import textwrap
from test import support
-from test.support.import_helper import import_module
+from test.support.import_helper import ensure_lazy_imports, import_module
from test.support.pty_helper import run_pty
+class LazyImportTest(unittest.TestCase):
+ @support.cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("cmd", {"inspect", "string"})
+
+
class samplecmdclass(cmd.Cmd):
"""
Instance the sampleclass:
@@ -289,6 +295,30 @@ class CmdTestReadline(unittest.TestCase):
self.assertIn(b'ab_completion_test', output)
self.assertIn(b'tab completion success', output)
+ def test_bang_completion_without_do_shell(self):
+ script = textwrap.dedent("""
+ import cmd
+ class simplecmd(cmd.Cmd):
+ def completedefault(self, text, line, begidx, endidx):
+ return ["hello"]
+
+ def default(self, line):
+ if line.replace(" ", "") == "!hello":
+ print('tab completion success')
+ else:
+ print('tab completion failure')
+ return True
+
+ simplecmd().cmdloop()
+ """)
+
+ # Send '! h' or '!h' and complete 'ello' to 'hello'
+ for input in [b"! h\t\n", b"!h\t\n"]:
+ with self.subTest(input=input):
+ output = run_pty(script, input)
+ self.assertIn(b'hello', output)
+ self.assertIn(b'tab completion success', output)
+
def load_tests(loader, tests, pattern):
tests.addTest(doctest.DocTestSuite())
return tests
diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py
index 36f87e259e7..c17d749d4a1 100644
--- a/Lib/test/test_cmd_line.py
+++ b/Lib/test/test_cmd_line.py
@@ -39,7 +39,8 @@ class CmdLineTest(unittest.TestCase):
def verify_valid_flag(self, cmd_line):
rc, out, err = assert_python_ok(cmd_line)
- self.assertTrue(out == b'' or out.endswith(b'\n'))
+ if out != b'':
+ self.assertEndsWith(out, b'\n')
self.assertNotIn(b'Traceback', out)
self.assertNotIn(b'Traceback', err)
return out
@@ -89,8 +90,8 @@ class CmdLineTest(unittest.TestCase):
version = ('Python %d.%d' % sys.version_info[:2]).encode("ascii")
for switch in '-V', '--version', '-VV':
rc, out, err = assert_python_ok(switch)
- self.assertFalse(err.startswith(version))
- self.assertTrue(out.startswith(version))
+ self.assertNotStartsWith(err, version)
+ self.assertStartsWith(out, version)
def test_verbose(self):
# -v causes imports to write to stderr. If the write to
@@ -380,7 +381,7 @@ class CmdLineTest(unittest.TestCase):
p.stdin.flush()
data, rc = _kill_python_and_exit_code(p)
self.assertEqual(rc, 0)
- self.assertTrue(data.startswith(b'x'), data)
+ self.assertStartsWith(data, b'x')
def test_large_PYTHONPATH(self):
path1 = "ABCDE" * 100
@@ -972,10 +973,25 @@ class CmdLineTest(unittest.TestCase):
@unittest.skipUnless(support.MS_WINDOWS, 'Test only applicable on Windows')
def test_python_legacy_windows_stdio(self):
- code = "import sys; print(sys.stdin.encoding, sys.stdout.encoding)"
- expected = 'cp'
- rc, out, err = assert_python_ok('-c', code, PYTHONLEGACYWINDOWSSTDIO='1')
- self.assertIn(expected.encode(), out)
+ # Test that _WindowsConsoleIO is used when PYTHONLEGACYWINDOWSSTDIO
+ # is not set.
+ # We cannot use PIPE because it prevents creating a new console,
+ # so we use the exit code instead.
+ code = "import sys; sys.exit(type(sys.stdout.buffer.raw).__name__ != '_WindowsConsoleIO')"
+ env = os.environ.copy()
+ env["PYTHONLEGACYWINDOWSSTDIO"] = ""
+ p = subprocess.run([sys.executable, "-c", code],
+ creationflags=subprocess.CREATE_NEW_CONSOLE,
+ env=env)
+ self.assertEqual(p.returncode, 0)
+
+ # Then test that FileIO is used when PYTHONLEGACYWINDOWSSTDIO is set.
+ code = "import sys; sys.exit(type(sys.stdout.buffer.raw).__name__ != 'FileIO')"
+ env["PYTHONLEGACYWINDOWSSTDIO"] = "1"
+ p = subprocess.run([sys.executable, "-c", code],
+ creationflags=subprocess.CREATE_NEW_CONSOLE,
+ env=env)
+ self.assertEqual(p.returncode, 0)
@unittest.skipIf("-fsanitize" in sysconfig.get_config_vars().get('PY_CFLAGS', ()),
"PYTHONMALLOCSTATS doesn't work with ASAN")
@@ -1024,7 +1040,7 @@ class CmdLineTest(unittest.TestCase):
stderr=subprocess.PIPE,
text=True)
err_msg = "Unknown option: --unknown-option\nusage: "
- self.assertTrue(proc.stderr.startswith(err_msg), proc.stderr)
+ self.assertStartsWith(proc.stderr, err_msg)
self.assertNotEqual(proc.returncode, 0)
def test_int_max_str_digits(self):
@@ -1158,6 +1174,24 @@ class CmdLineTest(unittest.TestCase):
res = assert_python_ok('-c', code, PYTHON_CPU_COUNT='default')
self.assertEqual(self.res2int(res), (os.cpu_count(), os.process_cpu_count()))
+ def test_import_time(self):
+ # os is not imported at startup
+ code = 'import os; import os'
+
+ for case in 'importtime', 'importtime=1', 'importtime=true':
+ res = assert_python_ok('-X', case, '-c', code)
+ res_err = res.err.decode('utf-8')
+ self.assertRegex(res_err, r'import time: \s*\d+ \| \s*\d+ \| \s*os')
+ self.assertNotRegex(res_err, r'import time: cached\s* \| cached\s* \| os')
+
+ res = assert_python_ok('-X', 'importtime=2', '-c', code)
+ res_err = res.err.decode('utf-8')
+ self.assertRegex(res_err, r'import time: \s*\d+ \| \s*\d+ \| \s*os')
+ self.assertRegex(res_err, r'import time: cached\s* \| cached\s* \| os')
+
+ assert_python_failure('-X', 'importtime=-1', '-c', code)
+ assert_python_failure('-X', 'importtime=3', '-c', code)
+
def res2int(self, res):
out = res.out.strip().decode("utf-8")
return tuple(int(i) for i in out.split())
diff --git a/Lib/test/test_cmd_line_script.py b/Lib/test/test_cmd_line_script.py
index 53dc9b1a7ef..784c45aa96f 100644
--- a/Lib/test/test_cmd_line_script.py
+++ b/Lib/test/test_cmd_line_script.py
@@ -553,9 +553,9 @@ class CmdLineTest(unittest.TestCase):
exitcode, stdout, stderr = assert_python_failure(script_name)
text = stderr.decode('ascii').split('\n')
self.assertEqual(len(text), 5)
- self.assertTrue(text[0].startswith('Traceback'))
- self.assertTrue(text[1].startswith(' File '))
- self.assertTrue(text[3].startswith('NameError'))
+ self.assertStartsWith(text[0], 'Traceback')
+ self.assertStartsWith(text[1], ' File ')
+ self.assertStartsWith(text[3], 'NameError')
def test_non_ascii(self):
# Apple platforms deny the creation of a file with an invalid UTF-8 name.
@@ -708,9 +708,8 @@ class CmdLineTest(unittest.TestCase):
exitcode, stdout, stderr = assert_python_failure(script_name)
text = io.TextIOWrapper(io.BytesIO(stderr), 'ascii').read()
# It used to crash in https://github.com/python/cpython/issues/111132
- self.assertTrue(text.endswith(
- 'SyntaxError: nonlocal declaration not allowed at module level\n',
- ), text)
+ self.assertEndsWith(text,
+ 'SyntaxError: nonlocal declaration not allowed at module level\n')
def test_consistent_sys_path_for_direct_execution(self):
# This test case ensures that the following all give the same
diff --git a/Lib/test/test_code.py b/Lib/test/test_code.py
index 7cf09ee7847..655f5a9be7f 100644
--- a/Lib/test/test_code.py
+++ b/Lib/test/test_code.py
@@ -220,6 +220,7 @@ try:
import _testinternalcapi
except ModuleNotFoundError:
_testinternalcapi = None
+import test._code_definitions as defs
COPY_FREE_VARS = opmap['COPY_FREE_VARS']
@@ -671,9 +672,82 @@ class CodeTest(unittest.TestCase):
VARARGS = CO_FAST_LOCAL | CO_FAST_ARG_VAR | CO_FAST_ARG_POS
VARKWARGS = CO_FAST_LOCAL | CO_FAST_ARG_VAR | CO_FAST_ARG_KW
- import test._code_definitions as defs
funcs = {
+ defs.simple_script: {},
+ defs.complex_script: {
+ 'obj': CO_FAST_LOCAL,
+ 'pickle': CO_FAST_LOCAL,
+ 'spam_minimal': CO_FAST_LOCAL,
+ 'data': CO_FAST_LOCAL,
+ 'res': CO_FAST_LOCAL,
+ },
+ defs.script_with_globals: {
+ 'obj1': CO_FAST_LOCAL,
+ 'obj2': CO_FAST_LOCAL,
+ },
+ defs.script_with_explicit_empty_return: {},
+ defs.script_with_return: {},
defs.spam_minimal: {},
+ defs.spam_with_builtins: {
+ 'x': CO_FAST_LOCAL,
+ 'values': CO_FAST_LOCAL,
+ 'checks': CO_FAST_LOCAL,
+ 'res': CO_FAST_LOCAL,
+ },
+ defs.spam_with_globals_and_builtins: {
+ 'func1': CO_FAST_LOCAL,
+ 'func2': CO_FAST_LOCAL,
+ 'funcs': CO_FAST_LOCAL,
+ 'checks': CO_FAST_LOCAL,
+ 'res': CO_FAST_LOCAL,
+ },
+ defs.spam_with_global_and_attr_same_name: {},
+ defs.spam_full_args: {
+ 'a': POSONLY,
+ 'b': POSONLY,
+ 'c': POSORKW,
+ 'd': POSORKW,
+ 'e': KWONLY,
+ 'f': KWONLY,
+ 'args': VARARGS,
+ 'kwargs': VARKWARGS,
+ },
+ defs.spam_full_args_with_defaults: {
+ 'a': POSONLY,
+ 'b': POSONLY,
+ 'c': POSORKW,
+ 'd': POSORKW,
+ 'e': KWONLY,
+ 'f': KWONLY,
+ 'args': VARARGS,
+ 'kwargs': VARKWARGS,
+ },
+ defs.spam_args_attrs_and_builtins: {
+ 'a': POSONLY,
+ 'b': POSONLY,
+ 'c': POSORKW,
+ 'd': POSORKW,
+ 'e': KWONLY,
+ 'f': KWONLY,
+ 'args': VARARGS,
+ 'kwargs': VARKWARGS,
+ },
+ defs.spam_returns_arg: {
+ 'x': POSORKW,
+ },
+ defs.spam_raises: {},
+ defs.spam_with_inner_not_closure: {
+ 'eggs': CO_FAST_LOCAL,
+ },
+ defs.spam_with_inner_closure: {
+ 'x': CO_FAST_CELL,
+ 'eggs': CO_FAST_LOCAL,
+ },
+ defs.spam_annotated: {
+ 'a': POSORKW,
+ 'b': POSORKW,
+ 'c': POSORKW,
+ },
defs.spam_full: {
'a': POSONLY,
'b': POSONLY,
@@ -777,6 +851,319 @@ class CodeTest(unittest.TestCase):
kinds = _testinternalcapi.get_co_localskinds(func.__code__)
self.assertEqual(kinds, expected)
+ @unittest.skipIf(_testinternalcapi is None, "missing _testinternalcapi")
+ def test_var_counts(self):
+ self.maxDiff = None
+ def new_var_counts(*,
+ posonly=0,
+ posorkw=0,
+ kwonly=0,
+ varargs=0,
+ varkwargs=0,
+ purelocals=0,
+ argcells=0,
+ othercells=0,
+ freevars=0,
+ globalvars=0,
+ attrs=0,
+ unknown=0,
+ ):
+ nargvars = posonly + posorkw + kwonly + varargs + varkwargs
+ nlocals = nargvars + purelocals + othercells
+ if isinstance(globalvars, int):
+ globalvars = {
+ 'total': globalvars,
+ 'numglobal': 0,
+ 'numbuiltin': 0,
+ 'numunknown': globalvars,
+ }
+ else:
+ g_numunknown = 0
+ if isinstance(globalvars, dict):
+ numglobal = globalvars['numglobal']
+ numbuiltin = globalvars['numbuiltin']
+ size = 2
+ if 'numunknown' in globalvars:
+ g_numunknown = globalvars['numunknown']
+ size += 1
+ assert len(globalvars) == size, globalvars
+ else:
+ assert not isinstance(globalvars, str), repr(globalvars)
+ try:
+ numglobal, numbuiltin = globalvars
+ except ValueError:
+ numglobal, numbuiltin, g_numunknown = globalvars
+ globalvars = {
+ 'total': numglobal + numbuiltin + g_numunknown,
+ 'numglobal': numglobal,
+ 'numbuiltin': numbuiltin,
+ 'numunknown': g_numunknown,
+ }
+ unbound = globalvars['total'] + attrs + unknown
+ return {
+ 'total': nlocals + freevars + unbound,
+ 'locals': {
+ 'total': nlocals,
+ 'args': {
+ 'total': nargvars,
+ 'numposonly': posonly,
+ 'numposorkw': posorkw,
+ 'numkwonly': kwonly,
+ 'varargs': varargs,
+ 'varkwargs': varkwargs,
+ },
+ 'numpure': purelocals,
+ 'cells': {
+ 'total': argcells + othercells,
+ 'numargs': argcells,
+ 'numothers': othercells,
+ },
+ 'hidden': {
+ 'total': 0,
+ 'numpure': 0,
+ 'numcells': 0,
+ },
+ },
+ 'numfree': freevars,
+ 'unbound': {
+ 'total': unbound,
+ 'globals': globalvars,
+ 'numattrs': attrs,
+ 'numunknown': unknown,
+ },
+ }
+
+ funcs = {
+ defs.simple_script: new_var_counts(),
+ defs.complex_script: new_var_counts(
+ purelocals=5,
+ globalvars=1,
+ attrs=2,
+ ),
+ defs.script_with_globals: new_var_counts(
+ purelocals=2,
+ globalvars=1,
+ ),
+ defs.script_with_explicit_empty_return: new_var_counts(),
+ defs.script_with_return: new_var_counts(),
+ defs.spam_minimal: new_var_counts(),
+ defs.spam_with_builtins: new_var_counts(
+ purelocals=4,
+ globalvars=4,
+ ),
+ defs.spam_with_globals_and_builtins: new_var_counts(
+ purelocals=5,
+ globalvars=6,
+ ),
+ defs.spam_with_global_and_attr_same_name: new_var_counts(
+ globalvars=2,
+ attrs=1,
+ ),
+ defs.spam_full_args: new_var_counts(
+ posonly=2,
+ posorkw=2,
+ kwonly=2,
+ varargs=1,
+ varkwargs=1,
+ ),
+ defs.spam_full_args_with_defaults: new_var_counts(
+ posonly=2,
+ posorkw=2,
+ kwonly=2,
+ varargs=1,
+ varkwargs=1,
+ ),
+ defs.spam_args_attrs_and_builtins: new_var_counts(
+ posonly=2,
+ posorkw=2,
+ kwonly=2,
+ varargs=1,
+ varkwargs=1,
+ attrs=1,
+ ),
+ defs.spam_returns_arg: new_var_counts(
+ posorkw=1,
+ ),
+ defs.spam_raises: new_var_counts(
+ globalvars=1,
+ ),
+ defs.spam_with_inner_not_closure: new_var_counts(
+ purelocals=1,
+ ),
+ defs.spam_with_inner_closure: new_var_counts(
+ othercells=1,
+ purelocals=1,
+ ),
+ defs.spam_annotated: new_var_counts(
+ posorkw=3,
+ ),
+ defs.spam_full: new_var_counts(
+ posonly=2,
+ posorkw=2,
+ kwonly=2,
+ varargs=1,
+ varkwargs=1,
+ purelocals=4,
+ globalvars=3,
+ attrs=1,
+ ),
+ defs.spam: new_var_counts(
+ posorkw=1,
+ ),
+ defs.spam_N: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ ),
+ defs.spam_C: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ othercells=1,
+ ),
+ defs.spam_NN: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ ),
+ defs.spam_NC: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ othercells=1,
+ ),
+ defs.spam_CN: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ othercells=1,
+ ),
+ defs.spam_CC: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ othercells=1,
+ ),
+ defs.eggs_nested: new_var_counts(
+ posorkw=1,
+ ),
+ defs.eggs_closure: new_var_counts(
+ posorkw=1,
+ freevars=2,
+ ),
+ defs.eggs_nested_N: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ ),
+ defs.eggs_nested_C: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ freevars=2,
+ ),
+ defs.eggs_closure_N: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ freevars=2,
+ ),
+ defs.eggs_closure_C: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ othercells=1,
+ freevars=2,
+ ),
+ defs.ham_nested: new_var_counts(
+ posorkw=1,
+ ),
+ defs.ham_closure: new_var_counts(
+ posorkw=1,
+ freevars=3,
+ ),
+ defs.ham_C_nested: new_var_counts(
+ posorkw=1,
+ ),
+ defs.ham_C_closure: new_var_counts(
+ posorkw=1,
+ freevars=4,
+ ),
+ }
+ assert len(funcs) == len(defs.FUNCTIONS), (len(funcs), len(defs.FUNCTIONS))
+ for func in defs.FUNCTIONS:
+ with self.subTest(func):
+ expected = funcs[func]
+ counts = _testinternalcapi.get_code_var_counts(func.__code__)
+ self.assertEqual(counts, expected)
+
+ func = defs.spam_with_globals_and_builtins
+ with self.subTest(f'{func} code'):
+ expected = new_var_counts(
+ purelocals=5,
+ globalvars=6,
+ )
+ counts = _testinternalcapi.get_code_var_counts(func.__code__)
+ self.assertEqual(counts, expected)
+
+ with self.subTest(f'{func} with own globals and builtins'):
+ expected = new_var_counts(
+ purelocals=5,
+ globalvars=(2, 4),
+ )
+ counts = _testinternalcapi.get_code_var_counts(func)
+ self.assertEqual(counts, expected)
+
+ with self.subTest(f'{func} without globals'):
+ expected = new_var_counts(
+ purelocals=5,
+ globalvars=(0, 4, 2),
+ )
+ counts = _testinternalcapi.get_code_var_counts(func, globalsns={})
+ self.assertEqual(counts, expected)
+
+ with self.subTest(f'{func} without both'):
+ expected = new_var_counts(
+ purelocals=5,
+ globalvars=6,
+ )
+ counts = _testinternalcapi.get_code_var_counts(func, globalsns={},
+ builtinsns={})
+ self.assertEqual(counts, expected)
+
+ with self.subTest(f'{func} without builtins'):
+ expected = new_var_counts(
+ purelocals=5,
+ globalvars=(2, 0, 4),
+ )
+ counts = _testinternalcapi.get_code_var_counts(func, builtinsns={})
+ self.assertEqual(counts, expected)
+
+ @unittest.skipIf(_testinternalcapi is None, "missing _testinternalcapi")
+ def test_stateless(self):
+ self.maxDiff = None
+
+ STATELESS_FUNCTIONS = [
+ *defs.STATELESS_FUNCTIONS,
+ # stateless with defaults
+ defs.spam_full_args_with_defaults,
+ ]
+
+ for func in defs.STATELESS_CODE:
+ with self.subTest((func, '(code)')):
+ _testinternalcapi.verify_stateless_code(func.__code__)
+ for func in STATELESS_FUNCTIONS:
+ with self.subTest((func, '(func)')):
+ _testinternalcapi.verify_stateless_code(func)
+
+ for func in defs.FUNCTIONS:
+ if func not in defs.STATELESS_CODE:
+ with self.subTest((func, '(code)')):
+ with self.assertRaises(Exception):
+ _testinternalcapi.verify_stateless_code(func.__code__)
+
+ if func not in STATELESS_FUNCTIONS:
+ with self.subTest((func, '(func)')):
+ with self.assertRaises(Exception):
+ _testinternalcapi.verify_stateless_code(func)
+
def isinterned(s):
return s is sys.intern(('_' + s + '_')[1:-1])
diff --git a/Lib/test/test_code_module.py b/Lib/test/test_code_module.py
index 57fb130070b..3642b47c2c1 100644
--- a/Lib/test/test_code_module.py
+++ b/Lib/test/test_code_module.py
@@ -133,7 +133,7 @@ class TestInteractiveConsole(unittest.TestCase, MockSys):
output = ''.join(''.join(call[1]) for call in self.stderr.method_calls)
output = output[output.index('(InteractiveConsole)'):]
output = output[output.index('\n') + 1:]
- self.assertTrue(output.startswith('UnicodeEncodeError: '), output)
+ self.assertStartsWith(output, 'UnicodeEncodeError: ')
self.assertIs(self.sysmod.last_type, UnicodeEncodeError)
self.assertIs(type(self.sysmod.last_value), UnicodeEncodeError)
self.assertIsNone(self.sysmod.last_traceback)
diff --git a/Lib/test/test_codeccallbacks.py b/Lib/test/test_codeccallbacks.py
index 86e5e5c1474..65d54d1004d 100644
--- a/Lib/test/test_codeccallbacks.py
+++ b/Lib/test/test_codeccallbacks.py
@@ -1125,7 +1125,7 @@ class CodecCallbackTest(unittest.TestCase):
text = 'abc<def>ghi'*n
text.translate(charmap)
- def test_mutatingdecodehandler(self):
+ def test_mutating_decode_handler(self):
baddata = [
("ascii", b"\xff"),
("utf-7", b"++"),
@@ -1160,6 +1160,42 @@ class CodecCallbackTest(unittest.TestCase):
for (encoding, data) in baddata:
self.assertEqual(data.decode(encoding, "test.mutating"), "\u4242")
+ def test_mutating_decode_handler_unicode_escape(self):
+ decode = codecs.unicode_escape_decode
+ def mutating(exc):
+ if isinstance(exc, UnicodeDecodeError):
+ r = data.get(exc.object[:exc.end])
+ if r is not None:
+ exc.object = r[0] + exc.object[exc.end:]
+ return ('\u0404', r[1])
+ raise AssertionError("don't know how to handle %r" % exc)
+
+ codecs.register_error('test.mutating2', mutating)
+ data = {
+ br'\x0': (b'\\', 0),
+ br'\x3': (b'xxx\\', 3),
+ br'\x5': (b'x\\', 1),
+ }
+ def check(input, expected, msg):
+ with self.assertWarns(DeprecationWarning) as cm:
+ self.assertEqual(decode(input, 'test.mutating2'), (expected, len(input)))
+ self.assertIn(msg, str(cm.warning))
+
+ check(br'\x0n\z', '\u0404\n\\z', r'"\z" is an invalid escape sequence')
+ check(br'\x0n\501', '\u0404\n\u0141', r'"\501" is an invalid octal escape sequence')
+ check(br'\x0z', '\u0404\\z', r'"\z" is an invalid escape sequence')
+
+ check(br'\x3n\zr', '\u0404\n\\zr', r'"\z" is an invalid escape sequence')
+ check(br'\x3zr', '\u0404\\zr', r'"\z" is an invalid escape sequence')
+ check(br'\x3z5', '\u0404\\z5', r'"\z" is an invalid escape sequence')
+ check(memoryview(br'\x3z5x')[:-1], '\u0404\\z5', r'"\z" is an invalid escape sequence')
+ check(memoryview(br'\x3z5xy')[:-2], '\u0404\\z5', r'"\z" is an invalid escape sequence')
+
+ check(br'\x5n\z', '\u0404\n\\z', r'"\z" is an invalid escape sequence')
+ check(br'\x5n\501', '\u0404\n\u0141', r'"\501" is an invalid octal escape sequence')
+ check(br'\x5z', '\u0404\\z', r'"\z" is an invalid escape sequence')
+ check(memoryview(br'\x5zy')[:-1], '\u0404\\z', r'"\z" is an invalid escape sequence')
+
# issue32583
def test_crashing_decode_handler(self):
# better generating one more character to fill the extra space slot
diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py
index e51f7e0ee12..d8666f7290e 100644
--- a/Lib/test/test_codecs.py
+++ b/Lib/test/test_codecs.py
@@ -1,12 +1,15 @@
import codecs
import contextlib
import copy
+import importlib
import io
import pickle
+import os
import sys
import unittest
import encodings
from unittest import mock
+import warnings
from test import support
from test.support import os_helper
@@ -20,13 +23,12 @@ try:
except ImportError:
_testinternalcapi = None
-try:
- import ctypes
-except ImportError:
- ctypes = None
- SIZEOF_WCHAR_T = -1
-else:
- SIZEOF_WCHAR_T = ctypes.sizeof(ctypes.c_wchar)
+
+def codecs_open_no_warn(*args, **kwargs):
+ """Call codecs.open(*args, **kwargs) ignoring DeprecationWarning."""
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ return codecs.open(*args, **kwargs)
def coding_checker(self, coder):
def check(input, expect):
@@ -35,13 +37,13 @@ def coding_checker(self, coder):
# On small versions of Windows like Windows IoT or Windows Nano Server not all codepages are present
def is_code_page_present(cp):
- from ctypes import POINTER, WINFUNCTYPE, WinDLL
+ from ctypes import POINTER, WINFUNCTYPE, WinDLL, Structure
from ctypes.wintypes import BOOL, BYTE, WCHAR, UINT, DWORD
MAX_LEADBYTES = 12 # 5 ranges, 2 bytes ea., 0 term.
MAX_DEFAULTCHAR = 2 # single or double byte
MAX_PATH = 260
- class CPINFOEXW(ctypes.Structure):
+ class CPINFOEXW(Structure):
_fields_ = [("MaxCharSize", UINT),
("DefaultChar", BYTE*MAX_DEFAULTCHAR),
("LeadByte", BYTE*MAX_LEADBYTES),
@@ -719,19 +721,19 @@ class UTF16Test(ReadTest, unittest.TestCase):
self.addCleanup(os_helper.unlink, os_helper.TESTFN)
with open(os_helper.TESTFN, 'wb') as fp:
fp.write(s)
- with codecs.open(os_helper.TESTFN, 'r',
+ with codecs_open_no_warn(os_helper.TESTFN, 'r',
encoding=self.encoding) as reader:
self.assertEqual(reader.read(), s1)
def test_invalid_modes(self):
for mode in ('U', 'rU', 'r+U'):
with self.assertRaises(ValueError) as cm:
- codecs.open(os_helper.TESTFN, mode, encoding=self.encoding)
+ codecs_open_no_warn(os_helper.TESTFN, mode, encoding=self.encoding)
self.assertIn('invalid mode', str(cm.exception))
for mode in ('rt', 'wt', 'at', 'r+t'):
with self.assertRaises(ValueError) as cm:
- codecs.open(os_helper.TESTFN, mode, encoding=self.encoding)
+ codecs_open_no_warn(os_helper.TESTFN, mode, encoding=self.encoding)
self.assertIn("can't have text and binary mode at once",
str(cm.exception))
@@ -1196,23 +1198,39 @@ class EscapeDecodeTest(unittest.TestCase):
check(br"[\1010]", b"[A0]")
check(br"[\x41]", b"[A]")
check(br"[\x410]", b"[A0]")
+
+ def test_warnings(self):
+ decode = codecs.escape_decode
+ check = coding_checker(self, decode)
for i in range(97, 123):
b = bytes([i])
if b not in b'abfnrtvx':
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\%c" is an invalid escape sequence' % i):
check(b"\\" + b, b"\\" + b)
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\%c" is an invalid escape sequence' % (i-32)):
check(b"\\" + b.upper(), b"\\" + b.upper())
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\8" is an invalid escape sequence'):
check(br"\8", b"\\8")
with self.assertWarns(DeprecationWarning):
check(br"\9", b"\\9")
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\\xfa" is an invalid escape sequence') as cm:
check(b"\\\xfa", b"\\\xfa")
for i in range(0o400, 0o1000):
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\%o" is an invalid octal escape sequence' % i):
check(rb'\%o' % i, bytes([i & 0o377]))
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\z" is an invalid escape sequence'):
+ self.assertEqual(decode(br'\x\z', 'ignore'), (b'\\z', 4))
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\501" is an invalid octal escape sequence'):
+ self.assertEqual(decode(br'\x\501', 'ignore'), (b'A', 6))
+
def test_errors(self):
decode = codecs.escape_decode
self.assertRaises(ValueError, decode, br"\x")
@@ -1844,9 +1862,9 @@ class CodecsModuleTest(unittest.TestCase):
def test_open(self):
self.addCleanup(os_helper.unlink, os_helper.TESTFN)
for mode in ('w', 'r', 'r+', 'w+', 'a', 'a+'):
- with self.subTest(mode), \
- codecs.open(os_helper.TESTFN, mode, 'ascii') as file:
- self.assertIsInstance(file, codecs.StreamReaderWriter)
+ with self.subTest(mode), self.assertWarns(DeprecationWarning):
+ with codecs.open(os_helper.TESTFN, mode, 'ascii') as file:
+ self.assertIsInstance(file, codecs.StreamReaderWriter)
def test_undefined(self):
self.assertRaises(UnicodeError, codecs.encode, 'abc', 'undefined')
@@ -1863,7 +1881,7 @@ class CodecsModuleTest(unittest.TestCase):
mock_open = mock.mock_open()
with mock.patch('builtins.open', mock_open) as file:
with self.assertRaises(LookupError):
- codecs.open(os_helper.TESTFN, 'wt', 'invalid-encoding')
+ codecs_open_no_warn(os_helper.TESTFN, 'wt', 'invalid-encoding')
file().close.assert_called()
@@ -2661,24 +2679,40 @@ class UnicodeEscapeTest(ReadTest, unittest.TestCase):
check(br"[\x410]", "[A0]")
check(br"\u20ac", "\u20ac")
check(br"\U0001d120", "\U0001d120")
+
+ def test_decode_warnings(self):
+ decode = codecs.unicode_escape_decode
+ check = coding_checker(self, decode)
for i in range(97, 123):
b = bytes([i])
if b not in b'abfnrtuvx':
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\%c" is an invalid escape sequence' % i):
check(b"\\" + b, "\\" + chr(i))
if b.upper() not in b'UN':
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\%c" is an invalid escape sequence' % (i-32)):
check(b"\\" + b.upper(), "\\" + chr(i-32))
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\8" is an invalid escape sequence'):
check(br"\8", "\\8")
with self.assertWarns(DeprecationWarning):
check(br"\9", "\\9")
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\\xfa" is an invalid escape sequence') as cm:
check(b"\\\xfa", "\\\xfa")
for i in range(0o400, 0o1000):
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\%o" is an invalid octal escape sequence' % i):
check(rb'\%o' % i, chr(i))
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\z" is an invalid escape sequence'):
+ self.assertEqual(decode(br'\x\z', 'ignore'), ('\\z', 4))
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\501" is an invalid octal escape sequence'):
+ self.assertEqual(decode(br'\x\501', 'ignore'), ('\u0141', 6))
+
def test_decode_errors(self):
decode = codecs.unicode_escape_decode
for c, d in (b'x', 2), (b'u', 4), (b'U', 4):
@@ -2883,7 +2917,7 @@ class BomTest(unittest.TestCase):
self.addCleanup(os_helper.unlink, os_helper.TESTFN)
for encoding in tests:
# Check if the BOM is written only once
- with codecs.open(os_helper.TESTFN, 'w+', encoding=encoding) as f:
+ with codecs_open_no_warn(os_helper.TESTFN, 'w+', encoding=encoding) as f:
f.write(data)
f.write(data)
f.seek(0)
@@ -2892,7 +2926,7 @@ class BomTest(unittest.TestCase):
self.assertEqual(f.read(), data * 2)
# Check that the BOM is written after a seek(0)
- with codecs.open(os_helper.TESTFN, 'w+', encoding=encoding) as f:
+ with codecs_open_no_warn(os_helper.TESTFN, 'w+', encoding=encoding) as f:
f.write(data[0])
self.assertNotEqual(f.tell(), 0)
f.seek(0)
@@ -2901,7 +2935,7 @@ class BomTest(unittest.TestCase):
self.assertEqual(f.read(), data)
# (StreamWriter) Check that the BOM is written after a seek(0)
- with codecs.open(os_helper.TESTFN, 'w+', encoding=encoding) as f:
+ with codecs_open_no_warn(os_helper.TESTFN, 'w+', encoding=encoding) as f:
f.writer.write(data[0])
self.assertNotEqual(f.writer.tell(), 0)
f.writer.seek(0)
@@ -2911,7 +2945,7 @@ class BomTest(unittest.TestCase):
# Check that the BOM is not written after a seek() at a position
# different than the start
- with codecs.open(os_helper.TESTFN, 'w+', encoding=encoding) as f:
+ with codecs_open_no_warn(os_helper.TESTFN, 'w+', encoding=encoding) as f:
f.write(data)
f.seek(f.tell())
f.write(data)
@@ -2920,7 +2954,7 @@ class BomTest(unittest.TestCase):
# (StreamWriter) Check that the BOM is not written after a seek()
# at a position different than the start
- with codecs.open(os_helper.TESTFN, 'w+', encoding=encoding) as f:
+ with codecs_open_no_warn(os_helper.TESTFN, 'w+', encoding=encoding) as f:
f.writer.write(data)
f.writer.seek(f.writer.tell())
f.writer.write(data)
@@ -3075,6 +3109,13 @@ class TransformCodecTest(unittest.TestCase):
info = codecs.lookup(alias)
self.assertEqual(info.name, expected_name)
+ def test_alias_modules_exist(self):
+ encodings_dir = os.path.dirname(encodings.__file__)
+ for value in encodings.aliases.aliases.values():
+ codec_mod = f"encodings.{value}"
+ self.assertIsNotNone(importlib.util.find_spec(codec_mod),
+ f"Codec module not found: {codec_mod}")
+
def test_quopri_stateless(self):
# Should encode with quotetabs=True
encoded = codecs.encode(b"space tab\teol \n", "quopri-codec")
@@ -3762,7 +3803,7 @@ class LocaleCodecTest(unittest.TestCase):
with self.assertRaises(RuntimeError) as cm:
self.decode(encoded, errors)
errmsg = str(cm.exception)
- self.assertTrue(errmsg.startswith("decode error: "), errmsg)
+ self.assertStartsWith(errmsg, "decode error: ")
else:
decoded = self.decode(encoded, errors)
self.assertEqual(decoded, expected)
diff --git a/Lib/test/test_codeop.py b/Lib/test/test_codeop.py
index 0eefc22d11b..ed10bd3dcb6 100644
--- a/Lib/test/test_codeop.py
+++ b/Lib/test/test_codeop.py
@@ -322,7 +322,7 @@ class CodeopTests(unittest.TestCase):
dedent("""\
def foo(x,x):
pass
- """), "duplicate argument 'x' in function definition")
+ """), "duplicate parameter 'x' in function definition")
diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py
index 1e93530398b..d9d61e5c205 100644
--- a/Lib/test/test_collections.py
+++ b/Lib/test/test_collections.py
@@ -542,6 +542,8 @@ class TestNamedTuple(unittest.TestCase):
self.assertEqual(Dot(1)._replace(d=999), (999,))
self.assertEqual(Dot(1)._fields, ('d',))
+ @support.requires_resource('cpu')
+ def test_large_size(self):
n = support.exceeds_recursion_limit()
names = list(set(''.join([choice(string.ascii_letters)
for j in range(10)]) for i in range(n)))
diff --git a/Lib/test/test_compileall.py b/Lib/test/test_compileall.py
index a580a240d9f..8384c183dd9 100644
--- a/Lib/test/test_compileall.py
+++ b/Lib/test/test_compileall.py
@@ -316,7 +316,7 @@ class CompileallTestsBase:
self.assertTrue(mods)
for mod in mods:
- self.assertTrue(mod.startswith(self.directory), mod)
+ self.assertStartsWith(mod, self.directory)
modcode = importlib.util.cache_from_source(mod)
modpath = mod[len(self.directory+os.sep):]
_, _, err = script_helper.assert_python_failure(modcode)
diff --git a/Lib/test/test_compiler_assemble.py b/Lib/test/test_compiler_assemble.py
index c4962e35999..99a11e99d56 100644
--- a/Lib/test/test_compiler_assemble.py
+++ b/Lib/test/test_compiler_assemble.py
@@ -146,4 +146,4 @@ class IsolatedAssembleTests(AssemblerTestCase):
L1 to L2 -> L2 [0]
L2 to L3 -> L3 [1] lasti
""")
- self.assertTrue(output.getvalue().endswith(exc_table))
+ self.assertEndsWith(output.getvalue(), exc_table)
diff --git a/Lib/test/test_concurrent_futures/test_future.py b/Lib/test/test_concurrent_futures/test_future.py
index 4066ea1ee4b..06b11a3bacf 100644
--- a/Lib/test/test_concurrent_futures/test_future.py
+++ b/Lib/test/test_concurrent_futures/test_future.py
@@ -6,6 +6,7 @@ from concurrent.futures._base import (
PENDING, RUNNING, CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED, Future)
from test import support
+from test.support import threading_helper
from .util import (
PENDING_FUTURE, RUNNING_FUTURE, CANCELLED_FUTURE,
@@ -282,6 +283,62 @@ class FutureTests(BaseTestCase):
self.assertEqual(f.exception(), e)
+ def test_get_snapshot(self):
+ """Test the _get_snapshot method for atomic state retrieval."""
+ # Test with a pending future
+ f = Future()
+ done, cancelled, result, exception = f._get_snapshot()
+ self.assertFalse(done)
+ self.assertFalse(cancelled)
+ self.assertIsNone(result)
+ self.assertIsNone(exception)
+
+ # Test with a finished future (successful result)
+ f = Future()
+ f.set_result(42)
+ done, cancelled, result, exception = f._get_snapshot()
+ self.assertTrue(done)
+ self.assertFalse(cancelled)
+ self.assertEqual(result, 42)
+ self.assertIsNone(exception)
+
+ # Test with a finished future (exception)
+ f = Future()
+ exc = ValueError("test error")
+ f.set_exception(exc)
+ done, cancelled, result, exception = f._get_snapshot()
+ self.assertTrue(done)
+ self.assertFalse(cancelled)
+ self.assertIsNone(result)
+ self.assertIs(exception, exc)
+
+ # Test with a cancelled future
+ f = Future()
+ f.cancel()
+ done, cancelled, result, exception = f._get_snapshot()
+ self.assertTrue(done)
+ self.assertTrue(cancelled)
+ self.assertIsNone(result)
+ self.assertIsNone(exception)
+
+ # Test concurrent access (basic thread safety check)
+ f = Future()
+ f.set_result(100)
+ results = []
+
+ def get_snapshot():
+ for _ in range(1000):
+ snapshot = f._get_snapshot()
+ results.append(snapshot)
+
+ threads = [threading.Thread(target=get_snapshot) for _ in range(4)]
+ with threading_helper.start_threads(threads):
+ pass
+ # All snapshots should be identical for a finished future
+ expected = (True, False, 100, None)
+ for result in results:
+ self.assertEqual(result, expected)
+
def setUpModule():
setup_module()
diff --git a/Lib/test/test_concurrent_futures/test_init.py b/Lib/test/test_concurrent_futures/test_init.py
index df640929309..6b8484c0d5f 100644
--- a/Lib/test/test_concurrent_futures/test_init.py
+++ b/Lib/test/test_concurrent_futures/test_init.py
@@ -20,6 +20,10 @@ INITIALIZER_STATUS = 'uninitialized'
def init(x):
global INITIALIZER_STATUS
INITIALIZER_STATUS = x
+ # InterpreterPoolInitializerTest.test_initializer fails
+ # if we don't have a LOAD_GLOBAL. (It could be any global.)
+ # We will address this separately.
+ INITIALIZER_STATUS
def get_init_status():
return INITIALIZER_STATUS
diff --git a/Lib/test/test_concurrent_futures/test_interpreter_pool.py b/Lib/test/test_concurrent_futures/test_interpreter_pool.py
index f6c62ae4b20..844dfdd6fc9 100644
--- a/Lib/test/test_concurrent_futures/test_interpreter_pool.py
+++ b/Lib/test/test_concurrent_futures/test_interpreter_pool.py
@@ -2,35 +2,78 @@ import asyncio
import contextlib
import io
import os
-import pickle
+import sys
import time
import unittest
-from concurrent.futures.interpreter import (
- ExecutionFailed, BrokenInterpreterPool,
-)
+from concurrent.futures.interpreter import BrokenInterpreterPool
+from concurrent import interpreters
+from concurrent.interpreters import _queues as queues
import _interpreters
from test import support
+from test.support import os_helper
+from test.support import script_helper
import test.test_asyncio.utils as testasyncio_utils
-from test.support.interpreters import queues
from .executor import ExecutorTest, mul
from .util import BaseTestCase, InterpreterPoolMixin, setup_module
+WINDOWS = sys.platform.startswith('win')
+
+
+@contextlib.contextmanager
+def nonblocking(fd):
+ blocking = os.get_blocking(fd)
+ if blocking:
+ os.set_blocking(fd, False)
+ try:
+ yield
+ finally:
+ if blocking:
+ os.set_blocking(fd, blocking)
+
+
+def read_file_with_timeout(fd, nbytes, timeout):
+ with nonblocking(fd):
+ end = time.time() + timeout
+ try:
+ return os.read(fd, nbytes)
+ except BlockingIOError:
+ pass
+ while time.time() < end:
+ try:
+ return os.read(fd, nbytes)
+ except BlockingIOError:
+ continue
+ else:
+ raise TimeoutError('nothing to read')
+
+
+if not WINDOWS:
+ import select
+ def read_file_with_timeout(fd, nbytes, timeout):
+ r, _, _ = select.select([fd], [], [], timeout)
+ if fd not in r:
+ raise TimeoutError('nothing to read')
+ return os.read(fd, nbytes)
+
+
def noop():
pass
def write_msg(fd, msg):
+ import os
os.write(fd, msg + b'\0')
-def read_msg(fd):
+def read_msg(fd, timeout=10.0):
msg = b''
- while ch := os.read(fd, 1):
- if ch == b'\0':
- return msg
+ ch = read_file_with_timeout(fd, 1, timeout)
+ while ch != b'\0':
msg += ch
+ ch = os.read(fd, 1)
+ return msg
def get_current_name():
@@ -113,6 +156,38 @@ class InterpreterPoolExecutorTest(
self.assertEqual(before, b'\0')
self.assertEqual(after, msg)
+ def test_init_with___main___global(self):
+ # See https://github.com/python/cpython/pull/133957#issuecomment-2927415311.
+ text = """if True:
+ from concurrent.futures import InterpreterPoolExecutor
+
+ INITIALIZER_STATUS = 'uninitialized'
+
+ def init(x):
+ global INITIALIZER_STATUS
+ INITIALIZER_STATUS = x
+ INITIALIZER_STATUS
+
+ def get_init_status():
+ return INITIALIZER_STATUS
+
+ if __name__ == "__main__":
+ exe = InterpreterPoolExecutor(initializer=init,
+ initargs=('initialized',))
+ fut = exe.submit(get_init_status)
+ print(fut.result()) # 'initialized'
+ exe.shutdown(wait=True)
+ print(INITIALIZER_STATUS) # 'uninitialized'
+ """
+ with os_helper.temp_dir() as tempdir:
+ filename = script_helper.make_script(tempdir, 'my-script', text)
+ res = script_helper.assert_python_ok(filename)
+ stdout = res.out.decode('utf-8').strip()
+ self.assertEqual(stdout.splitlines(), [
+ 'initialized',
+ 'uninitialized',
+ ])
+
def test_init_closure(self):
count = 0
def init1():
@@ -121,10 +196,19 @@ class InterpreterPoolExecutorTest(
nonlocal count
count += 1
- with self.assertRaises(pickle.PicklingError):
- self.executor_type(initializer=init1)
- with self.assertRaises(pickle.PicklingError):
- self.executor_type(initializer=init2)
+ with contextlib.redirect_stderr(io.StringIO()) as stderr:
+ with self.executor_type(initializer=init1) as executor:
+ fut = executor.submit(lambda: None)
+ self.assertIn('NotShareableError', stderr.getvalue())
+ with self.assertRaises(BrokenInterpreterPool):
+ fut.result()
+
+ with contextlib.redirect_stderr(io.StringIO()) as stderr:
+ with self.executor_type(initializer=init2) as executor:
+ fut = executor.submit(lambda: None)
+ self.assertIn('NotShareableError', stderr.getvalue())
+ with self.assertRaises(BrokenInterpreterPool):
+ fut.result()
def test_init_instance_method(self):
class Spam:
@@ -132,26 +216,12 @@ class InterpreterPoolExecutorTest(
raise NotImplementedError
spam = Spam()
- with self.assertRaises(pickle.PicklingError):
- self.executor_type(initializer=spam.initializer)
-
- def test_init_shared(self):
- msg = b'eggs'
- r, w = self.pipe()
- script = f"""if True:
- import os
- if __name__ != '__main__':
- import __main__
- spam = __main__.spam
- os.write({w}, spam + b'\\0')
- """
-
- executor = self.executor_type(shared={'spam': msg})
- fut = executor.submit(exec, script)
- fut.result()
- after = read_msg(r)
-
- self.assertEqual(after, msg)
+ with contextlib.redirect_stderr(io.StringIO()) as stderr:
+ with self.executor_type(initializer=spam.initializer) as executor:
+ fut = executor.submit(lambda: None)
+ self.assertIn('NotShareableError', stderr.getvalue())
+ with self.assertRaises(BrokenInterpreterPool):
+ fut.result()
@unittest.expectedFailure
def test_init_exception_in_script(self):
@@ -178,8 +248,6 @@ class InterpreterPoolExecutorTest(
stderr = stderr.getvalue()
self.assertIn('ExecutionFailed: Exception: spam', stderr)
self.assertIn('Uncaught in the interpreter:', stderr)
- self.assertIn('The above exception was the direct cause of the following exception:',
- stderr)
@unittest.expectedFailure
def test_submit_script(self):
@@ -208,10 +276,14 @@ class InterpreterPoolExecutorTest(
return spam
executor = self.executor_type()
- with self.assertRaises(pickle.PicklingError):
- executor.submit(task1)
- with self.assertRaises(pickle.PicklingError):
- executor.submit(task2)
+
+ fut = executor.submit(task1)
+ with self.assertRaises(_interpreters.NotShareableError):
+ fut.result()
+
+ fut = executor.submit(task2)
+ with self.assertRaises(_interpreters.NotShareableError):
+ fut.result()
def test_submit_local_instance(self):
class Spam:
@@ -219,8 +291,9 @@ class InterpreterPoolExecutorTest(
self.value = True
executor = self.executor_type()
- with self.assertRaises(pickle.PicklingError):
- executor.submit(Spam)
+ fut = executor.submit(Spam)
+ with self.assertRaises(_interpreters.NotShareableError):
+ fut.result()
def test_submit_instance_method(self):
class Spam:
@@ -229,8 +302,9 @@ class InterpreterPoolExecutorTest(
spam = Spam()
executor = self.executor_type()
- with self.assertRaises(pickle.PicklingError):
- executor.submit(spam.run)
+ fut = executor.submit(spam.run)
+ with self.assertRaises(_interpreters.NotShareableError):
+ fut.result()
def test_submit_func_globals(self):
executor = self.executor_type()
@@ -242,13 +316,14 @@ class InterpreterPoolExecutorTest(
@unittest.expectedFailure
def test_submit_exception_in_script(self):
+ # Scripts are not supported currently.
fut = self.executor.submit('raise Exception("spam")')
with self.assertRaises(Exception) as captured:
fut.result()
self.assertIs(type(captured.exception), Exception)
self.assertEqual(str(captured.exception), 'spam')
cause = captured.exception.__cause__
- self.assertIs(type(cause), ExecutionFailed)
+ self.assertIs(type(cause), interpreters.ExecutionFailed)
for attr in ('__name__', '__qualname__', '__module__'):
self.assertEqual(getattr(cause.excinfo.type, attr),
getattr(Exception, attr))
@@ -261,7 +336,7 @@ class InterpreterPoolExecutorTest(
self.assertIs(type(captured.exception), Exception)
self.assertEqual(str(captured.exception), 'spam')
cause = captured.exception.__cause__
- self.assertIs(type(cause), ExecutionFailed)
+ self.assertIs(type(cause), interpreters.ExecutionFailed)
for attr in ('__name__', '__qualname__', '__module__'):
self.assertEqual(getattr(cause.excinfo.type, attr),
getattr(Exception, attr))
@@ -269,16 +344,93 @@ class InterpreterPoolExecutorTest(
def test_saturation(self):
blocker = queues.create()
- executor = self.executor_type(4, shared=dict(blocker=blocker))
+ executor = self.executor_type(4)
for i in range(15 * executor._max_workers):
- executor.submit(exec, 'import __main__; __main__.blocker.get()')
- #executor.submit('blocker.get()')
+ executor.submit(blocker.get)
self.assertEqual(len(executor._threads), executor._max_workers)
for i in range(15 * executor._max_workers):
blocker.put_nowait(None)
executor.shutdown(wait=True)
+ def test_blocking(self):
+ # There is no guarantee that a worker will be created for every
+ # submitted task. That's because there's a race between:
+ #
+ # * a new worker thread, created when task A was just submitted,
+ # becoming non-idle when it picks up task A
+ # * task B being added to the queue and a new worker thread being
+ # started only if there are no idle workers at that moment
+ # (the check in ThreadPoolExecutor._adjust_thread_count())
+ #
+ # That means we must not block waiting for *all* tasks to report
+ # "ready" before we unblock the known-ready workers.
+ ready = queues.create()
+ blocker = queues.create()
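+ # Both are cross-interpreter queues, shared with the worker tasks below.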
+
+ def run(taskid, ready, blocker):
+ # There can't be any globals here.
+ ready.put_nowait(taskid)
+ blocker.get() # blocking
+
+ numtasks = 10
+ futures = []
+ with self.executor_type() as executor:
+ # Request the jobs.
+ for i in range(numtasks):
+ fut = executor.submit(run, i, ready, blocker)
+ futures.append(fut)
+ pending = numtasks
+ while pending > 0:
+ # Wait for any to be ready.
+ done = 0
+ for _ in range(pending):
+ try:
+ ready.get(timeout=1) # blocking
+ except interpreters.QueueEmpty:
+ pass
+ else:
+ done += 1
+ pending -= done
+ # Unblock the workers.
+ for _ in range(done):
+ blocker.put_nowait(None)
+
+ def test_blocking_with_limited_workers(self):
+ # This is essentially the same as test_blocking, but here we
+ # explicitly limit the number of workers, rather than relying on
+ # the race described there to limit them implicitly.
+ ready = queues.create()
+ blocker = queues.create()
+
+ def run(taskid, ready, blocker):
+ # There can't be any globals here.
+ ready.put_nowait(taskid)
+ blocker.get() # blocking
+
+ numtasks = 10
+ futures = []
+ with self.executor_type(4) as executor:
+ # Request the jobs.
+ for i in range(numtasks):
+ fut = executor.submit(run, i, ready, blocker)
+ futures.append(fut)
+ pending = numtasks
+ while pending > 0:
+ # Wait for any to be ready.
+ done = 0
+ for _ in range(pending):
+ try:
+ ready.get(timeout=1) # blocking
+ except interpreters.QueueEmpty:
+ pass
+ else:
+ done += 1
+ pending -= done
+ # Unblock the workers.
+ for _ in range(done):
+ blocker.put_nowait(None)
+
@support.requires_gil_enabled("gh-117344: test is flaky without the GIL")
def test_idle_thread_reuse(self):
executor = self.executor_type()
@@ -289,12 +441,21 @@ class InterpreterPoolExecutorTest(
executor.shutdown(wait=True)
def test_pickle_errors_propagate(self):
- # GH-125864: Pickle errors happen before the script tries to execute, so the
- # queue used to wait infinitely.
-
+ # GH-125864: Pickle errors happen before the script tries to execute,
+ # so the queue used to wait infinitely.
fut = self.executor.submit(PickleShenanigans(0))
- with self.assertRaisesRegex(RuntimeError, "gotcha"):
+ expected = interpreters.NotShareableError
+ with self.assertRaisesRegex(expected, 'args not shareable') as cm:
fut.result()
+ self.assertRegex(str(cm.exception.__cause__), 'unpickled')
+
+ def test_no_stale_references(self):
+ # Weak references don't cross between interpreters.
+ raise unittest.SkipTest('not applicable')
+
+ def test_free_reference(self):
+ # Weak references don't cross between interpreters.
+ raise unittest.SkipTest('not applicable')
class AsyncioTest(InterpretersMixin, testasyncio_utils.TestCase):
diff --git a/Lib/test/test_concurrent_futures/test_shutdown.py b/Lib/test/test_concurrent_futures/test_shutdown.py
index 7a4065afd46..99b315b47e2 100644
--- a/Lib/test/test_concurrent_futures/test_shutdown.py
+++ b/Lib/test/test_concurrent_futures/test_shutdown.py
@@ -330,6 +330,64 @@ class ProcessPoolShutdownTest(ExecutorShutdownTest):
# shutdown.
assert all([r == abs(v) for r, v in zip(res, range(-5, 5))])
+ @classmethod
+ def _failing_task_gh_132969(cls, n):
+ raise ValueError("failing task")
+
+ @classmethod
+ def _good_task_gh_132969(cls, n):
+ time.sleep(0.1 * n)
+ return n
+
+ def _run_test_issue_gh_132969(self, max_workers):
+ # max_workers=2 reproduces the exception
+ # max_workers=4 reproduces the exception and then hangs
+
+ # Repro conditions
+ # max_tasks_per_child=1
+ # a task ends abnormally
+ # shutdown(wait=False) is called
+ start_method = self.get_context().get_start_method()
+ if (start_method == "fork" or
+ (start_method == "forkserver" and sys.platform.startswith("win"))):
+ self.skipTest(f"Skipping test for {start_method = }")
+ executor = futures.ProcessPoolExecutor(
+ max_workers=max_workers,
+ max_tasks_per_child=1,
+ mp_context=self.get_context())
+ f1 = executor.submit(ProcessPoolShutdownTest._good_task_gh_132969, 1)
+ f2 = executor.submit(ProcessPoolShutdownTest._failing_task_gh_132969, 2)
+ f3 = executor.submit(ProcessPoolShutdownTest._good_task_gh_132969, 3)
+ result = 0
+ try:
+ result += f1.result()
+ result += f2.result()
+ result += f3.result()
+ except ValueError:
+ # stop processing results upon first exception
+ pass
+
+ # Ensure that the executor cleans up after
+ # shutdown(wait=False) is called.
+ executor_manager_thread = executor._executor_manager_thread
+ executor.shutdown(wait=False)
+ time.sleep(0.2)
+ executor_manager_thread.join()
+ return result
+
+ def test_shutdown_gh_132969_case_1(self):
+ # gh-132969: test that exception "object of type 'NoneType' has no len()"
+ # is not raised when shutdown(wait=False) is called.
+ result = self._run_test_issue_gh_132969(2)
+ self.assertEqual(result, 1)
+
+ def test_shutdown_gh_132969_case_2(self):
+ # gh-132969: test that process does not hang and
+ # exception "object of type 'NoneType' has no len()" is not raised
+ # when shutdown(wait=False) is called.
+ result = self._run_test_issue_gh_132969(4)
+ self.assertEqual(result, 1)
+
create_executor_tests(globals(), ProcessPoolShutdownTest,
executor_mixins=(ProcessPoolForkMixin,
diff --git a/Lib/test/test_configparser.py b/Lib/test/test_configparser.py
index 23904d17d32..e7364e18742 100644
--- a/Lib/test/test_configparser.py
+++ b/Lib/test/test_configparser.py
@@ -986,12 +986,12 @@ class ConfigParserTestCase(BasicTestCase, unittest.TestCase):
def test_defaults_keyword(self):
"""bpo-23835 fix for ConfigParser"""
- cf = self.newconfig(defaults={1: 2.4})
- self.assertEqual(cf[self.default_section]['1'], '2.4')
- self.assertAlmostEqual(cf[self.default_section].getfloat('1'), 2.4)
- cf = self.newconfig(defaults={"A": 5.2})
- self.assertEqual(cf[self.default_section]['a'], '5.2')
- self.assertAlmostEqual(cf[self.default_section].getfloat('a'), 5.2)
+ cf = self.newconfig(defaults={1: 2.5})
+ self.assertEqual(cf[self.default_section]['1'], '2.5')
+ self.assertAlmostEqual(cf[self.default_section].getfloat('1'), 2.5)
+ cf = self.newconfig(defaults={"A": 5.25})
+ self.assertEqual(cf[self.default_section]['a'], '5.25')
+ self.assertAlmostEqual(cf[self.default_section].getfloat('a'), 5.25)
class ConfigParserTestCaseNoInterpolation(BasicTestCase, unittest.TestCase):
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py
index cf651959803..6a3329fa5aa 100644
--- a/Lib/test/test_contextlib.py
+++ b/Lib/test/test_contextlib.py
@@ -48,23 +48,23 @@ class TestAbstractContextManager(unittest.TestCase):
def __exit__(self, exc_type, exc_value, traceback):
return None
- self.assertTrue(issubclass(ManagerFromScratch, AbstractContextManager))
+ self.assertIsSubclass(ManagerFromScratch, AbstractContextManager)
class DefaultEnter(AbstractContextManager):
def __exit__(self, *args):
super().__exit__(*args)
- self.assertTrue(issubclass(DefaultEnter, AbstractContextManager))
+ self.assertIsSubclass(DefaultEnter, AbstractContextManager)
class NoEnter(ManagerFromScratch):
__enter__ = None
- self.assertFalse(issubclass(NoEnter, AbstractContextManager))
+ self.assertNotIsSubclass(NoEnter, AbstractContextManager)
class NoExit(ManagerFromScratch):
__exit__ = None
- self.assertFalse(issubclass(NoExit, AbstractContextManager))
+ self.assertNotIsSubclass(NoExit, AbstractContextManager)
class ContextManagerTestCase(unittest.TestCase):
diff --git a/Lib/test/test_contextlib_async.py b/Lib/test/test_contextlib_async.py
index 7750186e56a..dcd00720379 100644
--- a/Lib/test/test_contextlib_async.py
+++ b/Lib/test/test_contextlib_async.py
@@ -77,23 +77,23 @@ class TestAbstractAsyncContextManager(unittest.TestCase):
async def __aexit__(self, exc_type, exc_value, traceback):
return None
- self.assertTrue(issubclass(ManagerFromScratch, AbstractAsyncContextManager))
+ self.assertIsSubclass(ManagerFromScratch, AbstractAsyncContextManager)
class DefaultEnter(AbstractAsyncContextManager):
async def __aexit__(self, *args):
await super().__aexit__(*args)
- self.assertTrue(issubclass(DefaultEnter, AbstractAsyncContextManager))
+ self.assertIsSubclass(DefaultEnter, AbstractAsyncContextManager)
class NoneAenter(ManagerFromScratch):
__aenter__ = None
- self.assertFalse(issubclass(NoneAenter, AbstractAsyncContextManager))
+ self.assertNotIsSubclass(NoneAenter, AbstractAsyncContextManager)
class NoneAexit(ManagerFromScratch):
__aexit__ = None
- self.assertFalse(issubclass(NoneAexit, AbstractAsyncContextManager))
+ self.assertNotIsSubclass(NoneAexit, AbstractAsyncContextManager)
class AsyncContextManagerTestCase(unittest.TestCase):
diff --git a/Lib/test/test_copy.py b/Lib/test/test_copy.py
index d76341417e9..467ec09d99e 100644
--- a/Lib/test/test_copy.py
+++ b/Lib/test/test_copy.py
@@ -19,7 +19,7 @@ class TestCopy(unittest.TestCase):
def test_exceptions(self):
self.assertIs(copy.Error, copy.error)
- self.assertTrue(issubclass(copy.Error, Exception))
+ self.assertIsSubclass(copy.Error, Exception)
# The copy() method
@@ -372,6 +372,7 @@ class TestCopy(unittest.TestCase):
self.assertIsNot(x[0], y[0])
@support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_deepcopy_reflexive_list(self):
x = []
x.append(x)
@@ -400,6 +401,7 @@ class TestCopy(unittest.TestCase):
self.assertIs(x, y)
@support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_deepcopy_reflexive_tuple(self):
x = ([],)
x[0].append(x)
@@ -418,6 +420,7 @@ class TestCopy(unittest.TestCase):
self.assertIsNot(x["foo"], y["foo"])
@support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_deepcopy_reflexive_dict(self):
x = {}
x['foo'] = x
diff --git a/Lib/test/test_coroutines.py b/Lib/test/test_coroutines.py
index 761cb230277..4755046fe19 100644
--- a/Lib/test/test_coroutines.py
+++ b/Lib/test/test_coroutines.py
@@ -527,7 +527,7 @@ class CoroutineTest(unittest.TestCase):
def test_gen_1(self):
def gen(): yield
- self.assertFalse(hasattr(gen, '__await__'))
+ self.assertNotHasAttr(gen, '__await__')
def test_func_1(self):
async def foo():
diff --git a/Lib/test/test_cprofile.py b/Lib/test/test_cprofile.py
index 192c8eab26e..57e818b1c68 100644
--- a/Lib/test/test_cprofile.py
+++ b/Lib/test/test_cprofile.py
@@ -125,21 +125,22 @@ class CProfileTest(ProfileTest):
"""
gh-106152
generator.throw() should trigger a call in cProfile
- In the any() call below, there should be two entries for the generator:
- * one for the call to __next__ which gets a True and terminates any
- * one when the generator is garbage collected which will effectively
- do a throw.
"""
+
+ def gen():
+ yield
+
pr = self.profilerclass()
pr.enable()
- any(a == 1 for a in (1, 2))
+ g = gen()
+ try:
+ g.throw(SyntaxError)
+ except SyntaxError:
+ pass
pr.disable()
pr.create_stats()
- for func, (cc, nc, _, _, _) in pr.stats.items():
- if func[2] == "<genexpr>":
- self.assertEqual(cc, 1)
- self.assertEqual(nc, 1)
+ self.assertTrue(any("throw" in func[2] for func in pr.stats.keys()))
def test_bad_descriptor(self):
# gh-132250
diff --git a/Lib/test/test_crossinterp.py b/Lib/test/test_crossinterp.py
index 5ebb78b0ea9..2fa0077a09b 100644
--- a/Lib/test/test_crossinterp.py
+++ b/Lib/test/test_crossinterp.py
@@ -1,7 +1,9 @@
+import contextlib
import itertools
import sys
import types
import unittest
+import warnings
from test.support import import_helper
@@ -9,80 +11,500 @@ _testinternalcapi = import_helper.import_module('_testinternalcapi')
_interpreters = import_helper.import_module('_interpreters')
from _interpreters import NotShareableError
-
+from test import _code_definitions as code_defs
from test import _crossinterp_definitions as defs
-BUILTIN_TYPES = [o for _, o in __builtins__.items()
- if isinstance(o, type)]
-EXCEPTION_TYPES = [cls for cls in BUILTIN_TYPES
+@contextlib.contextmanager
+def ignore_byteswarning():
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', category=BytesWarning)
+ yield
+
+
+# builtin types
+
+BUILTINS_TYPES = [o for _, o in __builtins__.items() if isinstance(o, type)]
+EXCEPTION_TYPES = [cls for cls in BUILTINS_TYPES
if issubclass(cls, BaseException)]
OTHER_TYPES = [o for n, o in vars(types).items()
if (isinstance(o, type) and
- n not in ('DynamicClassAttribute', '_GeneratorWrapper'))]
+ n not in ('DynamicClassAttribute', '_GeneratorWrapper'))]
+BUILTIN_TYPES = [
+ *BUILTINS_TYPES,
+ *OTHER_TYPES,
+]
+
+# builtin exceptions
+
+try:
+ raise Exception
+except Exception as exc:
+ CAUGHT = exc
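+# These exception types do not take a single message argument; each lambda
+# builds suitable constructor args from a message string.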
+EXCEPTIONS_WITH_SPECIAL_SIG = {
+ BaseExceptionGroup: (lambda msg: (msg, [CAUGHT])),
+ ExceptionGroup: (lambda msg: (msg, [CAUGHT])),
+ UnicodeError: (lambda msg: (None, msg, None, None, None)),
+ UnicodeEncodeError: (lambda msg: ('utf-8', '', 1, 3, msg)),
+ UnicodeDecodeError: (lambda msg: ('utf-8', b'', 1, 3, msg)),
+ UnicodeTranslateError: (lambda msg: ('', 1, 3, msg)),
+}
+BUILTIN_EXCEPTIONS = [
+ *(cls(*sig('error!')) for cls, sig in EXCEPTIONS_WITH_SPECIAL_SIG.items()),
+ *(cls('error!') for cls in EXCEPTION_TYPES
+ if cls not in EXCEPTIONS_WITH_SPECIAL_SIG),
+]
+
+# other builtin objects
+
+METHOD = defs.SpamOkay().okay
+BUILTIN_METHOD = [].append
+METHOD_DESCRIPTOR_WRAPPER = str.join
+METHOD_WRAPPER = object().__str__
+WRAPPER_DESCRIPTOR = object.__init__
+BUILTIN_WRAPPERS = {
+ METHOD: types.MethodType,
+ BUILTIN_METHOD: types.BuiltinMethodType,
+ dict.__dict__['fromkeys']: types.ClassMethodDescriptorType,
+ types.FunctionType.__code__: types.GetSetDescriptorType,
+ types.FunctionType.__globals__: types.MemberDescriptorType,
+ METHOD_DESCRIPTOR_WRAPPER: types.MethodDescriptorType,
+ METHOD_WRAPPER: types.MethodWrapperType,
+ WRAPPER_DESCRIPTOR: types.WrapperDescriptorType,
+ staticmethod(defs.SpamOkay.okay): None,
+ classmethod(defs.SpamOkay.okay): None,
+ property(defs.SpamOkay.okay): None,
+}
+BUILTIN_FUNCTIONS = [
+ # types.BuiltinFunctionType
+ len,
+ sys.is_finalizing,
+ sys.exit,
+ _testinternalcapi.get_crossinterp_data,
+]
+assert 'emptymod' not in sys.modules
+with import_helper.ready_to_import('emptymod', ''):
+ import emptymod as EMPTYMOD
+MODULES = [
+ sys,
+ defs,
+ unittest,
+ EMPTYMOD,
+]
+OBJECT = object()
+EXCEPTION = Exception()
+LAMBDA = (lambda: None)
+BUILTIN_SIMPLE = [
+ OBJECT,
+ # singletons
+ None,
+ True,
+ False,
+ Ellipsis,
+ NotImplemented,
+ # bytes
+ *(i.to_bytes(2, 'little', signed=True)
+ for i in range(-1, 258)),
+ # str
+ 'hello world',
+ '你好世界',
+ '',
+ # int
+ sys.maxsize + 1,
+ sys.maxsize,
+ -sys.maxsize - 1,
+ -sys.maxsize - 2,
+ *range(-1, 258),
+ 2**1000,
+ # float
+ 0.0,
+ 1.1,
+ -1.0,
+ 0.12345678,
+ -0.12345678,
+]
+TUPLE_EXCEPTION = (0, 1.0, EXCEPTION)
+TUPLE_OBJECT = (0, 1.0, OBJECT)
+TUPLE_NESTED_EXCEPTION = (0, 1.0, (EXCEPTION,))
+TUPLE_NESTED_OBJECT = (0, 1.0, (OBJECT,))
+MEMORYVIEW_EMPTY = memoryview(b'')
+MEMORYVIEW_NOT_EMPTY = memoryview(b'spam'*42)
+MAPPING_PROXY_EMPTY = types.MappingProxyType({})
+BUILTIN_CONTAINERS = [
+ # tuple (flat)
+ (),
+ (1,),
+ ("hello", "world", ),
+ (1, True, "hello"),
+ TUPLE_EXCEPTION,
+ TUPLE_OBJECT,
+ # tuple (nested)
+ ((1,),),
+ ((1, 2), (3, 4)),
+ ((1, 2), (3, 4), (5, 6)),
+ TUPLE_NESTED_EXCEPTION,
+ TUPLE_NESTED_OBJECT,
+ # buffer
+ MEMORYVIEW_EMPTY,
+ MEMORYVIEW_NOT_EMPTY,
+ # list
+ [],
+ [1, 2, 3],
+ [[1], (2,), {3: 4}],
+ # dict
+ {},
+ {1: 7, 2: 8, 3: 9},
+ {1: [1], 2: (2,), 3: {3: 4}},
+ # set
+ set(),
+ {1, 2, 3},
+ {frozenset({1}), (2,)},
+ # frozenset
+ frozenset([]),
+ frozenset({frozenset({1}), (2,)}),
+ # bytearray
+ bytearray(b''),
+ # other
+ MAPPING_PROXY_EMPTY,
+ types.SimpleNamespace(),
+]
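+# Capture a frame and a traceback object in a scratch namespace for BUILTIN_OTHER.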
+ns = {}
+exec("""
+try:
+ raise Exception
+except Exception as exc:
+ TRACEBACK = exc.__traceback__
+ FRAME = TRACEBACK.tb_frame
+""", ns, ns)
+BUILTIN_OTHER = [
+ # types.CellType
+ types.CellType(),
+ # types.FrameType
+ ns['FRAME'],
+ # types.TracebackType
+ ns['TRACEBACK'],
+]
+del ns
+
+# user-defined objects
+
+USER_TOP_INSTANCES = [c(*a) for c, a in defs.TOP_CLASSES.items()]
+USER_NESTED_INSTANCES = [c(*a) for c, a in defs.NESTED_CLASSES.items()]
+USER_INSTANCES = [
+ *USER_TOP_INSTANCES,
+ *USER_NESTED_INSTANCES,
+]
+USER_EXCEPTIONS = [
+ defs.MimimalError('error!'),
+]
+
+# shareable objects
+
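+# These tuples contain objects with identity-based equality, so a
+# roundtripped copy will not compare equal to the original.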
+TUPLES_WITHOUT_EQUALITY = [
+ TUPLE_EXCEPTION,
+ TUPLE_OBJECT,
+ TUPLE_NESTED_EXCEPTION,
+ TUPLE_NESTED_OBJECT,
+]
+_UNSHAREABLE_SIMPLE = [
+ Ellipsis,
+ NotImplemented,
+ OBJECT,
+ sys.maxsize + 1,
+ -sys.maxsize - 2,
+ 2**1000,
+]
+with ignore_byteswarning():
+ _SHAREABLE_SIMPLE = [o for o in BUILTIN_SIMPLE
+ if o not in _UNSHAREABLE_SIMPLE]
+ _SHAREABLE_CONTAINERS = [
+ *(o for o in BUILTIN_CONTAINERS if type(o) is memoryview),
+ *(o for o in BUILTIN_CONTAINERS
+ if type(o) is tuple and o not in TUPLES_WITHOUT_EQUALITY),
+ ]
+ _UNSHAREABLE_CONTAINERS = [o for o in BUILTIN_CONTAINERS
+ if o not in _SHAREABLE_CONTAINERS]
+SHAREABLE = [
+ *_SHAREABLE_SIMPLE,
+ *_SHAREABLE_CONTAINERS,
+]
+NOT_SHAREABLE = [
+ *_UNSHAREABLE_SIMPLE,
+ *_UNSHAREABLE_CONTAINERS,
+ *BUILTIN_TYPES,
+ *BUILTIN_WRAPPERS,
+ *BUILTIN_EXCEPTIONS,
+ *BUILTIN_FUNCTIONS,
+ *MODULES,
+ *BUILTIN_OTHER,
+ # types.CodeType
+ *(f.__code__ for f in defs.FUNCTIONS),
+ *(f.__code__ for f in defs.FUNCTION_LIKE),
+ # types.FunctionType
+ *defs.FUNCTIONS,
+ defs.SpamOkay.okay,
+ LAMBDA,
+ *defs.FUNCTION_LIKE,
+ # coroutines and generators
+ *defs.FUNCTION_LIKE_APPLIED,
+ # user classes
+ *defs.CLASSES,
+ *USER_INSTANCES,
+ # user exceptions
+ *USER_EXCEPTIONS,
+]
+
+# pickleable objects
+
+PICKLEABLE = [
+ *BUILTIN_SIMPLE,
+ *(o for o in BUILTIN_CONTAINERS if o not in [
+ MEMORYVIEW_EMPTY,
+ MEMORYVIEW_NOT_EMPTY,
+ MAPPING_PROXY_EMPTY,
+ ] or type(o) is dict),
+ *BUILTINS_TYPES,
+ *BUILTIN_EXCEPTIONS,
+ *BUILTIN_FUNCTIONS,
+ *defs.TOP_FUNCTIONS,
+ defs.SpamOkay.okay,
+ *defs.FUNCTION_LIKE,
+ *defs.TOP_CLASSES,
+ *USER_TOP_INSTANCES,
+ *USER_EXCEPTIONS,
+ # from OTHER_TYPES
+ types.NoneType,
+ types.EllipsisType,
+ types.NotImplementedType,
+ types.GenericAlias,
+ types.UnionType,
+ types.SimpleNamespace,
+ # from BUILTIN_WRAPPERS
+ METHOD,
+ BUILTIN_METHOD,
+ METHOD_DESCRIPTOR_WRAPPER,
+ METHOD_WRAPPER,
+ WRAPPER_DESCRIPTOR,
+]
+assert not any(isinstance(o, types.MappingProxyType) for o in PICKLEABLE)
+
+
+# helpers
+
+DEFS = defs
+with open(code_defs.__file__) as infile:
+ _code_defs_text = infile.read()
+with open(DEFS.__file__) as infile:
+ _defs_text = infile.read()
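+ # Comment out the 'from ...' imports so DEFS_TEXT can be exec'd standalone
+ # (see load_defs() below).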
+ _defs_text = _defs_text.replace('from ', '# from ')
+DEFS_TEXT = f"""
+#######################################
+# from {code_defs.__file__}
+
+{_code_defs_text}
+
+#######################################
+# from {defs.__file__}
+
+{_defs_text}
+"""
+del infile, _code_defs_text, _defs_text
+
+
+def load_defs(module=None):
+ """Return a new copy of the test._crossinterp_definitions module.
+
+ The module's __name__ matches the "module" arg, which is either
+ a str or a module.
+
+ If the "module" arg is a module then the just-loaded defs are also
+ copied into that module.
+
+ Note that the new module is not added to sys.modules.
+ """
+ if module is None:
+ modname = DEFS.__name__
+ elif isinstance(module, str):
+ modname = module
+ module = None
+ else:
+ modname = module.__name__
+ # Create the new module and populate it.
+ defs = import_helper.create_module(modname)
+ defs.__file__ = DEFS.__file__
+ exec(DEFS_TEXT, defs.__dict__)
+ # Copy the defs into the module arg, if any.
+ if module is not None:
+ for name, value in defs.__dict__.items():
+ if name.startswith('_'):
+ continue
+ assert not hasattr(module, name), (name, getattr(module, name))
+ setattr(module, name, value)
+ return defs
+
+
+@contextlib.contextmanager
+def using___main__():
+ """Make sure __main__ module exists (and clean up after)."""
+ modname = '__main__'
+ if modname not in sys.modules:
+ with import_helper.isolated_modules():
+ yield import_helper.add_module(modname)
+ else:
+ with import_helper.module_restored(modname) as mod:
+ yield mod
+
+
+@contextlib.contextmanager
+def temp_module(modname):
+ """Create the module and add to sys.modules, then remove it after."""
+ assert modname not in sys.modules, (modname,)
+ with import_helper.isolated_modules():
+ yield import_helper.add_module(modname)
+
+
+@contextlib.contextmanager
+def missing_defs_module(modname, *, prep=False):
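+ """Ensure the module is not imported; make it importable if prep=True."""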
+ assert modname not in sys.modules, (modname,)
+ if prep:
+ with import_helper.ready_to_import(modname, DEFS_TEXT):
+ yield modname
+ else:
+ with import_helper.isolated_modules():
+ yield modname
class _GetXIDataTests(unittest.TestCase):
MODE = None
+ def assert_functions_equal(self, func1, func2):
+ assert type(func1) is types.FunctionType, repr(func1)
+ assert type(func2) is types.FunctionType, repr(func2)
+ self.assertEqual(func1.__name__, func2.__name__)
+ self.assertEqual(func1.__code__, func2.__code__)
+ self.assertEqual(func1.__defaults__, func2.__defaults__)
+ self.assertEqual(func1.__kwdefaults__, func2.__kwdefaults__)
+ # We don't worry about __globals__ for now.
+
+ def assert_exc_args_equal(self, exc1, exc2):
+ args1 = exc1.args
+ args2 = exc2.args
+ if isinstance(exc1, ExceptionGroup):
+ self.assertIs(type(args1), type(args2))
+ self.assertEqual(len(args1), 2)
+ self.assertEqual(len(args1), len(args2))
+ self.assertEqual(args1[0], args2[0])
+ group1 = args1[1]
+ group2 = args2[1]
+ self.assertEqual(len(group1), len(group2))
+ for grouped1, grouped2 in zip(group1, group2):
+ # Currently the "extra" attrs are not preserved
+ # (via __reduce__).
+ self.assertIs(type(exc1), type(exc2))
+ self.assert_exc_equal(grouped1, grouped2)
+ else:
+ self.assertEqual(args1, args2)
+
+ def assert_exc_equal(self, exc1, exc2):
+ self.assertIs(type(exc1), type(exc2))
+
+ if type(exc1).__eq__ is not object.__eq__:
+ self.assertEqual(exc1, exc2)
+
+ self.assert_exc_args_equal(exc1, exc2)
+ # XXX For now we do not preserve tracebacks.
+ if exc1.__traceback__ is not None:
+ self.assertEqual(exc1.__traceback__, exc2.__traceback__)
+ self.assertEqual(
+ getattr(exc1, '__notes__', None),
+ getattr(exc2, '__notes__', None),
+ )
+ # We assume there are no cycles.
+ if exc1.__cause__ is None:
+ self.assertIs(exc1.__cause__, exc2.__cause__)
+ else:
+ self.assert_exc_equal(exc1.__cause__, exc2.__cause__)
+ if exc1.__context__ is None:
+ self.assertIs(exc1.__context__, exc2.__context__)
+ else:
+ self.assert_exc_equal(exc1.__context__, exc2.__context__)
+
+ def assert_equal_or_equalish(self, obj, expected):
+ cls = type(expected)
+ if cls.__eq__ is not object.__eq__:
+ self.assertEqual(obj, expected)
+ elif cls is types.FunctionType:
+ self.assert_functions_equal(obj, expected)
+ elif isinstance(expected, BaseException):
+ self.assert_exc_equal(obj, expected)
+ elif cls is types.MethodType:
+ raise NotImplementedError(cls)
+ elif cls is types.BuiltinMethodType:
+ raise NotImplementedError(cls)
+ elif cls is types.MethodWrapperType:
+ raise NotImplementedError(cls)
+ elif cls.__bases__ == (object,):
+ self.assertEqual(obj.__dict__, expected.__dict__)
+ else:
+ raise NotImplementedError(cls)
+
def get_xidata(self, obj, *, mode=None):
mode = self._resolve_mode(mode)
return _testinternalcapi.get_crossinterp_data(obj, mode)
def get_roundtrip(self, obj, *, mode=None):
mode = self._resolve_mode(mode)
- xid =_testinternalcapi.get_crossinterp_data(obj, mode)
+ return self._get_roundtrip(obj, mode)
+
+ def _get_roundtrip(self, obj, mode):
+ xid = _testinternalcapi.get_crossinterp_data(obj, mode)
return _testinternalcapi.restore_crossinterp_data(xid)
- def iter_roundtrip_values(self, values, *, mode=None):
+ def assert_roundtrip_identical(self, values, *, mode=None):
mode = self._resolve_mode(mode)
for obj in values:
- with self.subTest(obj):
- xid = _testinternalcapi.get_crossinterp_data(obj, mode)
- got = _testinternalcapi.restore_crossinterp_data(xid)
- yield obj, got
-
- def assert_roundtrip_identical(self, values, *, mode=None):
- for obj, got in self.iter_roundtrip_values(values, mode=mode):
- # XXX What about between interpreters?
- self.assertIs(got, obj)
+ with self.subTest(repr(obj)):
+ got = self._get_roundtrip(obj, mode)
+ self.assertIs(got, obj)
def assert_roundtrip_equal(self, values, *, mode=None, expecttype=None):
- for obj, got in self.iter_roundtrip_values(values, mode=mode):
- self.assertEqual(got, obj)
- self.assertIs(type(got),
- type(obj) if expecttype is None else expecttype)
-
-# def assert_roundtrip_equal_not_identical(self, values, *,
-# mode=None, expecttype=None):
-# mode = self._resolve_mode(mode)
-# for obj in values:
-# cls = type(obj)
-# with self.subTest(obj):
-# got = self._get_roundtrip(obj, mode)
-# self.assertIsNot(got, obj)
-# self.assertIs(type(got), type(obj))
-# self.assertEqual(got, obj)
-# self.assertIs(type(got),
-# cls if expecttype is None else expecttype)
-#
-# def assert_roundtrip_not_equal(self, values, *, mode=None, expecttype=None):
-# mode = self._resolve_mode(mode)
-# for obj in values:
-# cls = type(obj)
-# with self.subTest(obj):
-# got = self._get_roundtrip(obj, mode)
-# self.assertIsNot(got, obj)
-# self.assertIs(type(got), type(obj))
-# self.assertNotEqual(got, obj)
-# self.assertIs(type(got),
-# cls if expecttype is None else expecttype)
+ mode = self._resolve_mode(mode)
+ for obj in values:
+ with self.subTest(repr(obj)):
+ got = self._get_roundtrip(obj, mode)
+ if got is obj:
+ continue
+ self.assertIs(type(got),
+ type(obj) if expecttype is None else expecttype)
+ self.assert_equal_or_equalish(got, obj)
+
+ def assert_roundtrip_equal_not_identical(self, values, *,
+ mode=None, expecttype=None):
+ mode = self._resolve_mode(mode)
+ for obj in values:
+ with self.subTest(repr(obj)):
+ got = self._get_roundtrip(obj, mode)
+ self.assertIsNot(got, obj)
+ self.assertIs(type(got),
+ type(obj) if expecttype is None else expecttype)
+ self.assert_equal_or_equalish(got, obj)
+
+ def assert_roundtrip_not_equal(self, values, *,
+ mode=None, expecttype=None):
+ mode = self._resolve_mode(mode)
+ for obj in values:
+ with self.subTest(repr(obj)):
+ got = self._get_roundtrip(obj, mode)
+ self.assertIsNot(got, obj)
+ self.assertIs(type(got),
+ type(obj) if expecttype is None else expecttype)
+ self.assertNotEqual(got, obj)
def assert_not_shareable(self, values, exctype=None, *, mode=None):
mode = self._resolve_mode(mode)
for obj in values:
- with self.subTest(obj):
+ with self.subTest(repr(obj)):
with self.assertRaises(NotShareableError) as cm:
_testinternalcapi.get_crossinterp_data(obj, mode)
if exctype is not None:
@@ -95,6 +517,340 @@ class _GetXIDataTests(unittest.TestCase):
return mode
+class PickleTests(_GetXIDataTests):
+
+ MODE = 'pickle'
+
+ def test_shareable(self):
+ with ignore_byteswarning():
+ for obj in SHAREABLE:
+ if obj in PICKLEABLE:
+ self.assert_roundtrip_equal([obj])
+ else:
+ self.assert_not_shareable([obj])
+
+ def test_not_shareable(self):
+ with ignore_byteswarning():
+ for obj in NOT_SHAREABLE:
+ if type(obj) is types.MappingProxyType:
+ self.assert_not_shareable([obj])
+ elif obj in PICKLEABLE:
+ with self.subTest(repr(obj)):
+ # We don't worry about checking the actual value.
+ # The other tests should cover that well enough.
+ got = self.get_roundtrip(obj)
+ self.assertIs(type(got), type(obj))
+ else:
+ self.assert_not_shareable([obj])
+
+ def test_list(self):
+ self.assert_roundtrip_equal_not_identical([
+ [],
+ [1, 2, 3],
+ [[1], (2,), {3: 4}],
+ ])
+
+ def test_dict(self):
+ self.assert_roundtrip_equal_not_identical([
+ {},
+ {1: 7, 2: 8, 3: 9},
+ {1: [1], 2: (2,), 3: {3: 4}},
+ ])
+
+ def test_set(self):
+ self.assert_roundtrip_equal_not_identical([
+ set(),
+ {1, 2, 3},
+ {frozenset({1}), (2,)},
+ ])
+
+ # classes
+
+ def assert_class_defs_same(self, defs):
+ # Unpickle relative to the unchanged original module.
+ self.assert_roundtrip_identical(defs.TOP_CLASSES)
+
+ instances = []
+ for cls, args in defs.TOP_CLASSES.items():
+ if cls in defs.CLASSES_WITHOUT_EQUALITY:
+ continue
+ instances.append(cls(*args))
+ self.assert_roundtrip_equal_not_identical(instances)
+
+ # these don't compare equal
+ instances = []
+ for cls, args in defs.TOP_CLASSES.items():
+ if cls not in defs.CLASSES_WITHOUT_EQUALITY:
+ continue
+ instances.append(cls(*args))
+ self.assert_roundtrip_equal(instances)
+
+ def assert_class_defs_other_pickle(self, defs, mod):
+ # Pickle relative to a different module than the original.
+ for cls in defs.TOP_CLASSES:
+ assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__))
+ self.assert_not_shareable(defs.TOP_CLASSES)
+
+ instances = []
+ for cls, args in defs.TOP_CLASSES.items():
+ instances.append(cls(*args))
+ self.assert_not_shareable(instances)
+
+ def assert_class_defs_other_unpickle(self, defs, mod, *, fail=False):
+ # Unpickle relative to a different module than the original.
+ for cls in defs.TOP_CLASSES:
+ assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__))
+
+ instances = []
+ for cls, args in defs.TOP_CLASSES.items():
+ with self.subTest(repr(cls)):
+ setattr(mod, cls.__name__, cls)
+ xid = self.get_xidata(cls)
+ inst = cls(*args)
+ instxid = self.get_xidata(inst)
+ instances.append(
+ (cls, xid, inst, instxid))
+
+ for cls, xid, inst, instxid in instances:
+ with self.subTest(repr(cls)):
+ delattr(mod, cls.__name__)
+ if fail:
+ with self.assertRaises(NotShareableError):
+ _testinternalcapi.restore_crossinterp_data(xid)
+ continue
+ got = _testinternalcapi.restore_crossinterp_data(xid)
+ self.assertIsNot(got, cls)
+ self.assertNotEqual(got, cls)
+
+ gotcls = got
+ got = _testinternalcapi.restore_crossinterp_data(instxid)
+ self.assertIsNot(got, inst)
+ self.assertIs(type(got), gotcls)
+ if cls in defs.CLASSES_WITHOUT_EQUALITY:
+ self.assertNotEqual(got, inst)
+ elif cls in defs.BUILTIN_SUBCLASSES:
+ self.assertEqual(got, inst)
+ else:
+ self.assertNotEqual(got, inst)
+
+ def assert_class_defs_not_shareable(self, defs):
+ self.assert_not_shareable(defs.TOP_CLASSES)
+
+ instances = []
+ for cls, args in defs.TOP_CLASSES.items():
+ instances.append(cls(*args))
+ self.assert_not_shareable(instances)
+
+ def test_user_class_normal(self):
+ self.assert_class_defs_same(defs)
+
+ def test_user_class_in___main__(self):
+ with using___main__() as mod:
+ defs = load_defs(mod)
+ self.assert_class_defs_same(defs)
+
+ def test_user_class_not_in___main___with_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ assert defs.__file__
+ mod.__file__ = defs.__file__
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_user_class_not_in___main___without_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ defs.__file__ = None
+ mod.__file__ = None
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_user_class_not_in___main___unpickle_with_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ assert defs.__file__
+ mod.__file__ = defs.__file__
+ self.assert_class_defs_other_unpickle(defs, mod)
+
+ def test_user_class_not_in___main___unpickle_without_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ defs.__file__ = None
+ mod.__file__ = None
+ self.assert_class_defs_other_unpickle(defs, mod, fail=True)
+
+ def test_user_class_in_module(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod)
+ self.assert_class_defs_same(defs)
+
+ def test_user_class_not_in_module_with_filename(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod.__name__)
+ assert defs.__file__
+ # For now, we only address this case for __main__.
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_user_class_not_in_module_without_filename(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod.__name__)
+ defs.__file__ = None
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_user_class_module_missing_then_imported(self):
+ with missing_defs_module('__spam__', prep=True) as modname:
+ defs = load_defs(modname)
+ # For now, we only address this case for __main__.
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_user_class_module_missing_not_available(self):
+ with missing_defs_module('__spam__') as modname:
+ defs = load_defs(modname)
+ self.assert_class_defs_not_shareable(defs)
+
+ def test_nested_class(self):
+ eggs = defs.EggsNested()
+ with self.assertRaises(NotShareableError):
+ self.get_roundtrip(eggs)
+
+ # functions
+
+ def assert_func_defs_same(self, defs):
+ # Unpickle relative to the unchanged original module.
+ self.assert_roundtrip_identical(defs.TOP_FUNCTIONS)
+
+ def assert_func_defs_other_pickle(self, defs, mod):
+ # Pickle relative to a different module than the original.
+ for func in defs.TOP_FUNCTIONS:
+ assert not hasattr(mod, func.__name__), (getattr(mod, func.__name__),)
+ self.assert_not_shareable(defs.TOP_FUNCTIONS)
+
+ def assert_func_defs_other_unpickle(self, defs, mod, *, fail=False):
+ # Unpickle relative to a different module than the original.
+ for func in defs.TOP_FUNCTIONS:
+ assert not hasattr(mod, func.__name__), (getattr(mod, func.__name__),)
+
+ captured = []
+ for func in defs.TOP_FUNCTIONS:
+ with self.subTest(func):
+ setattr(mod, func.__name__, func)
+ xid = self.get_xidata(func)
+ captured.append(
+ (func, xid))
+
+ for func, xid in captured:
+ with self.subTest(func):
+ delattr(mod, func.__name__)
+ if fail:
+ with self.assertRaises(NotShareableError):
+ _testinternalcapi.restore_crossinterp_data(xid)
+ continue
+ got = _testinternalcapi.restore_crossinterp_data(xid)
+ self.assertIsNot(got, func)
+ self.assertNotEqual(got, func)
+
+ def assert_func_defs_not_shareable(self, defs):
+ self.assert_not_shareable(defs.TOP_FUNCTIONS)
+
+ def test_user_function_normal(self):
+ self.assert_roundtrip_equal(defs.TOP_FUNCTIONS)
+ self.assert_func_defs_same(defs)
+
+ def test_user_func_in___main__(self):
+ with using___main__() as mod:
+ defs = load_defs(mod)
+ self.assert_func_defs_same(defs)
+
+ def test_user_func_not_in___main___with_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ assert defs.__file__
+ mod.__file__ = defs.__file__
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_user_func_not_in___main___without_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ defs.__file__ = None
+ mod.__file__ = None
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_user_func_not_in___main___unpickle_with_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ assert defs.__file__
+ mod.__file__ = defs.__file__
+ self.assert_func_defs_other_unpickle(defs, mod)
+
+ def test_user_func_not_in___main___unpickle_without_filename(self):
+ with using___main__() as mod:
+ defs = load_defs('__main__')
+ defs.__file__ = None
+ mod.__file__ = None
+ self.assert_func_defs_other_unpickle(defs, mod, fail=True)
+
+ def test_user_func_in_module(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod)
+ self.assert_func_defs_same(defs)
+
+ def test_user_func_not_in_module_with_filename(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod.__name__)
+ assert defs.__file__
+ # For now, we only address this case for __main__.
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_user_func_not_in_module_without_filename(self):
+ with temp_module('__spam__') as mod:
+ defs = load_defs(mod.__name__)
+ defs.__file__ = None
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_user_func_module_missing_then_imported(self):
+ with missing_defs_module('__spam__', prep=True) as modname:
+ defs = load_defs(modname)
+ # For now, we only address this case for __main__.
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_user_func_module_missing_not_available(self):
+ with missing_defs_module('__spam__') as modname:
+ defs = load_defs(modname)
+ self.assert_func_defs_not_shareable(defs)
+
+ def test_nested_function(self):
+ self.assert_not_shareable(defs.NESTED_FUNCTIONS)
+
+ # exceptions
+
+ def test_user_exception_normal(self):
+ self.assert_roundtrip_equal([
+ defs.MimimalError('error!'),
+ ])
+ self.assert_roundtrip_equal_not_identical([
+ defs.RichError('error!', 42),
+ ])
+
+ def test_builtin_exception(self):
+ msg = 'error!'
+ try:
+ raise Exception
+ except Exception as exc:
+ caught = exc
+ special = {
+ BaseExceptionGroup: (msg, [caught]),
+ ExceptionGroup: (msg, [caught]),
+ UnicodeError: (None, msg, None, None, None),
+ UnicodeEncodeError: ('utf-8', '', 1, 3, msg),
+ UnicodeDecodeError: ('utf-8', b'', 1, 3, msg),
+ UnicodeTranslateError: ('', 1, 3, msg),
+ }
+ exceptions = []
+ for cls in EXCEPTION_TYPES:
+ args = special.get(cls) or (msg,)
+ exceptions.append(cls(*args))
+
+ self.assert_roundtrip_equal(exceptions)
+
+
class MarshalTests(_GetXIDataTests):
MODE = 'marshal'
@@ -137,7 +893,7 @@ class MarshalTests(_GetXIDataTests):
'',
])
self.assert_not_shareable([
- object(),
+ OBJECT,
types.SimpleNamespace(),
])
@@ -208,10 +964,7 @@ class MarshalTests(_GetXIDataTests):
shareable = [
StopIteration,
]
- types = [
- *BUILTIN_TYPES,
- *OTHER_TYPES,
- ]
+ types = BUILTIN_TYPES
self.assert_not_shareable(cls for cls in types
if cls not in shareable)
self.assert_roundtrip_identical(cls for cls in types
@@ -286,10 +1039,236 @@ class MarshalTests(_GetXIDataTests):
])
+class CodeTests(_GetXIDataTests):
+
+ MODE = 'code'
+
+ def test_function_code(self):
+ self.assert_roundtrip_equal_not_identical([
+ *(f.__code__ for f in defs.FUNCTIONS),
+ *(f.__code__ for f in defs.FUNCTION_LIKE),
+ ])
+
+ def test_functions(self):
+ self.assert_not_shareable([
+ *defs.FUNCTIONS,
+ *defs.FUNCTION_LIKE,
+ ])
+
+ def test_other_objects(self):
+ self.assert_not_shareable([
+ None,
+ True,
+ False,
+ Ellipsis,
+ NotImplemented,
+ 9999,
+ 'spam',
+ b'spam',
+ (),
+ [],
+ {},
+ object(),
+ ])
+
+
+class ShareableFuncTests(_GetXIDataTests):
+
+ MODE = 'func'
+
+ def test_stateless(self):
+ self.assert_roundtrip_equal([
+ *defs.STATELESS_FUNCTIONS,
+ # Generators can be stateless too.
+ *defs.FUNCTION_LIKE,
+ ])
+
+ def test_not_stateless(self):
+ self.assert_not_shareable([
+ *(f for f in defs.FUNCTIONS
+ if f not in defs.STATELESS_FUNCTIONS),
+ ])
+
+ def test_other_objects(self):
+ self.assert_not_shareable([
+ None,
+ True,
+ False,
+ Ellipsis,
+ NotImplemented,
+ 9999,
+ 'spam',
+ b'spam',
+ (),
+ [],
+ {},
+ object(),
+ ])
+
+
+class PureShareableScriptTests(_GetXIDataTests):
+
+ MODE = 'script-pure'
+
+ VALID_SCRIPTS = [
+ '',
+ 'spam',
+ '# a comment',
+ 'print("spam")',
+ 'raise Exception("spam")',
+ """if True:
+ do_something()
+ """,
+ """if True:
+ def spam(x):
+ return x
+ class Spam:
+ def eggs(self):
+ return 42
+ x = Spam().eggs()
+ raise ValueError(spam(x))
+ """,
+ ]
+ INVALID_SCRIPTS = [
+ ' pass', # IndentationError
+ '----', # SyntaxError
+ """if True:
+ def spam():
+ # no body
+ spam()
+ """, # IndentationError
+ ]
+
+ def test_valid_str(self):
+ self.assert_roundtrip_not_equal([
+ *self.VALID_SCRIPTS,
+ ], expecttype=types.CodeType)
+
+ def test_invalid_str(self):
+ self.assert_not_shareable([
+ *self.INVALID_SCRIPTS,
+ ])
+
+ def test_valid_bytes(self):
+ self.assert_roundtrip_not_equal([
+ *(s.encode('utf8') for s in self.VALID_SCRIPTS),
+ ], expecttype=types.CodeType)
+
+ def test_invalid_bytes(self):
+ self.assert_not_shareable([
+ *(s.encode('utf8') for s in self.INVALID_SCRIPTS),
+ ])
+
+ def test_pure_script_code(self):
+ self.assert_roundtrip_equal_not_identical([
+ *(f.__code__ for f in defs.PURE_SCRIPT_FUNCTIONS),
+ ])
+
+ def test_impure_script_code(self):
+ self.assert_not_shareable([
+ *(f.__code__ for f in defs.SCRIPT_FUNCTIONS
+ if f not in defs.PURE_SCRIPT_FUNCTIONS),
+ ])
+
+ def test_other_code(self):
+ self.assert_not_shareable([
+ *(f.__code__ for f in defs.FUNCTIONS
+ if f not in defs.SCRIPT_FUNCTIONS),
+ *(f.__code__ for f in defs.FUNCTION_LIKE),
+ ])
+
+ def test_pure_script_function(self):
+ self.assert_roundtrip_not_equal([
+ *defs.PURE_SCRIPT_FUNCTIONS,
+ ], expecttype=types.CodeType)
+
+ def test_impure_script_function(self):
+ self.assert_not_shareable([
+ *(f for f in defs.SCRIPT_FUNCTIONS
+ if f not in defs.PURE_SCRIPT_FUNCTIONS),
+ ])
+
+ def test_other_function(self):
+ self.assert_not_shareable([
+ *(f for f in defs.FUNCTIONS
+ if f not in defs.SCRIPT_FUNCTIONS),
+ *defs.FUNCTION_LIKE,
+ ])
+
+ def test_other_objects(self):
+ self.assert_not_shareable([
+ None,
+ True,
+ False,
+ Ellipsis,
+ NotImplemented,
+ (),
+ [],
+ {},
+ object(),
+ ])
+
+
+class ShareableScriptTests(PureShareableScriptTests):
+
+ MODE = 'script'
+
+ def test_impure_script_code(self):
+ self.assert_roundtrip_equal_not_identical([
+ *(f.__code__ for f in defs.SCRIPT_FUNCTIONS
+ if f not in defs.PURE_SCRIPT_FUNCTIONS),
+ ])
+
+ def test_impure_script_function(self):
+ self.assert_roundtrip_not_equal([
+ *(f for f in defs.SCRIPT_FUNCTIONS
+ if f not in defs.PURE_SCRIPT_FUNCTIONS),
+ ], expecttype=types.CodeType)
+
+
+class ShareableFallbackTests(_GetXIDataTests):
+
+ MODE = 'fallback'
+
+ def test_shareable(self):
+ self.assert_roundtrip_equal(SHAREABLE)
+
+ def test_not_shareable(self):
+ okay = [
+ *PICKLEABLE,
+ *defs.STATELESS_FUNCTIONS,
+ LAMBDA,
+ ]
+ ignored = [
+ *TUPLES_WITHOUT_EQUALITY,
+ OBJECT,
+ METHOD,
+ BUILTIN_METHOD,
+ METHOD_WRAPPER,
+ ]
+ with ignore_byteswarning():
+ self.assert_roundtrip_equal([
+ *(o for o in NOT_SHAREABLE
+ if o in okay and o not in ignored
+ and o is not MAPPING_PROXY_EMPTY),
+ ])
+ self.assert_roundtrip_not_equal([
+ *(o for o in NOT_SHAREABLE
+ if o in ignored and o is not MAPPING_PROXY_EMPTY),
+ ])
+ self.assert_not_shareable([
+ *(o for o in NOT_SHAREABLE if o not in okay),
+ MAPPING_PROXY_EMPTY,
+ ])
+
+
class ShareableTypeTests(_GetXIDataTests):
MODE = 'xidata'
+ def test_shareable(self):
+ self.assert_roundtrip_equal(SHAREABLE)
+
def test_singletons(self):
self.assert_roundtrip_identical([
None,
@@ -357,8 +1336,8 @@ class ShareableTypeTests(_GetXIDataTests):
def test_tuples_containing_non_shareable_types(self):
non_shareables = [
- Exception(),
- object(),
+ EXCEPTION,
+ OBJECT,
]
for s in non_shareables:
value = tuple([0, 1.0, s])
@@ -373,21 +1352,31 @@ class ShareableTypeTests(_GetXIDataTests):
# The rest are not shareable.
+ def test_not_shareable(self):
+ self.assert_not_shareable(NOT_SHAREABLE)
+
def test_object(self):
self.assert_not_shareable([
object(),
])
+ def test_code(self):
+ # types.CodeType
+ self.assert_not_shareable([
+ *(f.__code__ for f in defs.FUNCTIONS),
+ *(f.__code__ for f in defs.FUNCTION_LIKE),
+ ])
+
def test_function_object(self):
for func in defs.FUNCTIONS:
assert type(func) is types.FunctionType, func
assert type(defs.SpamOkay.okay) is types.FunctionType, func
- assert type(lambda: None) is types.LambdaType
+ assert type(LAMBDA) is types.LambdaType
self.assert_not_shareable([
*defs.FUNCTIONS,
defs.SpamOkay.okay,
- (lambda: None),
+ LAMBDA,
])
def test_builtin_function(self):
@@ -444,28 +1433,15 @@ class ShareableTypeTests(_GetXIDataTests):
])
def test_class(self):
- self.assert_not_shareable([
- defs.Spam,
- defs.SpamOkay,
- defs.SpamFull,
- defs.SubSpamFull,
- defs.SubTuple,
- defs.EggsNested,
- ])
- self.assert_not_shareable([
- defs.Spam(),
- defs.SpamOkay(),
- defs.SpamFull(1, 2, 3),
- defs.SubSpamFull(1, 2, 3),
- defs.SubTuple([1, 2, 3]),
- defs.EggsNested(),
- ])
+ self.assert_not_shareable(defs.CLASSES)
+
+ instances = []
+ for cls, args in defs.CLASSES.items():
+ instances.append(cls(*args))
+ self.assert_not_shareable(instances)
def test_builtin_type(self):
- self.assert_not_shareable([
- *BUILTIN_TYPES,
- *OTHER_TYPES,
- ])
+ self.assert_not_shareable(BUILTIN_TYPES)
def test_exception(self):
self.assert_not_shareable([
@@ -504,14 +1480,8 @@ class ShareableTypeTests(_GetXIDataTests):
""", ns, ns)
self.assert_not_shareable([
- types.MappingProxyType({}),
+ MAPPING_PROXY_EMPTY,
types.SimpleNamespace(),
- # types.CodeType
- defs.spam_minimal.__code__,
- defs.spam_full.__code__,
- defs.spam_CC.__code__,
- defs.eggs_closure_C.__code__,
- defs.ham_C_closure.__code__,
# types.CellType
types.CellType(),
# types.FrameType
diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py
index 4af8f7f480e..60feab225a1 100644
--- a/Lib/test/test_csv.py
+++ b/Lib/test/test_csv.py
@@ -10,7 +10,8 @@ import csv
import gc
import pickle
from test import support
-from test.support import import_helper, check_disallow_instantiation
+from test.support import cpython_only, import_helper, check_disallow_instantiation
+from test.support.import_helper import ensure_lazy_imports
from itertools import permutations
from textwrap import dedent
from collections import OrderedDict
@@ -1121,19 +1122,22 @@ class TestDialectValidity(unittest.TestCase):
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"quotechar" must be a 1-character string')
+ '"quotechar" must be a unicode character or None, '
+ 'not a string of length 0')
mydialect.quotechar = "''"
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"quotechar" must be a 1-character string')
+ '"quotechar" must be a unicode character or None, '
+ 'not a string of length 2')
mydialect.quotechar = 4
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"quotechar" must be string or None, not int')
+ '"quotechar" must be a unicode character or None, '
+ 'not int')
def test_delimiter(self):
class mydialect(csv.Dialect):
@@ -1150,31 +1154,32 @@ class TestDialectValidity(unittest.TestCase):
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"delimiter" must be a 1-character string')
+ '"delimiter" must be a unicode character, '
+ 'not a string of length 3')
mydialect.delimiter = ""
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"delimiter" must be a 1-character string')
+ '"delimiter" must be a unicode character, not a string of length 0')
mydialect.delimiter = b","
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"delimiter" must be string, not bytes')
+ '"delimiter" must be a unicode character, not bytes')
mydialect.delimiter = 4
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"delimiter" must be string, not int')
+ '"delimiter" must be a unicode character, not int')
mydialect.delimiter = None
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"delimiter" must be string, not NoneType')
+ '"delimiter" must be a unicode character, not NoneType')
def test_escapechar(self):
class mydialect(csv.Dialect):
@@ -1188,20 +1193,32 @@ class TestDialectValidity(unittest.TestCase):
self.assertEqual(d.escapechar, "\\")
mydialect.escapechar = ""
- with self.assertRaisesRegex(csv.Error, '"escapechar" must be a 1-character string'):
+ with self.assertRaises(csv.Error) as cm:
mydialect()
+ self.assertEqual(str(cm.exception),
+ '"escapechar" must be a unicode character or None, '
+ 'not a string of length 0')
mydialect.escapechar = "**"
- with self.assertRaisesRegex(csv.Error, '"escapechar" must be a 1-character string'):
+ with self.assertRaises(csv.Error) as cm:
mydialect()
+ self.assertEqual(str(cm.exception),
+ '"escapechar" must be a unicode character or None, '
+ 'not a string of length 2')
mydialect.escapechar = b"*"
- with self.assertRaisesRegex(csv.Error, '"escapechar" must be string or None, not bytes'):
+ with self.assertRaises(csv.Error) as cm:
mydialect()
+ self.assertEqual(str(cm.exception),
+ '"escapechar" must be a unicode character or None, '
+ 'not bytes')
mydialect.escapechar = 4
- with self.assertRaisesRegex(csv.Error, '"escapechar" must be string or None, not int'):
+ with self.assertRaises(csv.Error) as cm:
mydialect()
+ self.assertEqual(str(cm.exception),
+ '"escapechar" must be a unicode character or None, '
+ 'not int')
def test_lineterminator(self):
class mydialect(csv.Dialect):
@@ -1222,7 +1239,13 @@ class TestDialectValidity(unittest.TestCase):
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"lineterminator" must be a string')
+ '"lineterminator" must be a string, not int')
+
+ mydialect.lineterminator = None
+ with self.assertRaises(csv.Error) as cm:
+ mydialect()
+ self.assertEqual(str(cm.exception),
+ '"lineterminator" must be a string, not NoneType')
def test_invalid_chars(self):
def create_invalid(field_name, value, **kwargs):
@@ -1565,6 +1588,10 @@ class MiscTestCase(unittest.TestCase):
def test__all__(self):
support.check__all__(self, csv, ('csv', '_csv'))
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("csv", {"re"})
+
def test_subclassable(self):
# issue 44089
class Foo(csv.Error): ...
diff --git a/Lib/test/test_ctypes/_support.py b/Lib/test/test_ctypes/_support.py
index 946d654a19a..700657a4e41 100644
--- a/Lib/test/test_ctypes/_support.py
+++ b/Lib/test/test_ctypes/_support.py
@@ -3,7 +3,6 @@
import ctypes
from _ctypes import Structure, Union, _Pointer, Array, _SimpleCData, CFuncPtr
import sys
-from test import support
_CData = Structure.__base__
diff --git a/Lib/test/test_ctypes/test_aligned_structures.py b/Lib/test/test_ctypes/test_aligned_structures.py
index 0c563ab8055..50b4d729b9d 100644
--- a/Lib/test/test_ctypes/test_aligned_structures.py
+++ b/Lib/test/test_ctypes/test_aligned_structures.py
@@ -316,6 +316,7 @@ class TestAlignedStructures(unittest.TestCase, StructCheckMixin):
class Main(sbase):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [
("a", c_ubyte),
("b", Inner),
diff --git a/Lib/test/test_ctypes/test_bitfields.py b/Lib/test/test_ctypes/test_bitfields.py
index dc81e752567..518f838219e 100644
--- a/Lib/test/test_ctypes/test_bitfields.py
+++ b/Lib/test/test_ctypes/test_bitfields.py
@@ -430,6 +430,7 @@ class BitFieldTest(unittest.TestCase, StructCheckMixin):
def test_gh_84039(self):
class Bad(Structure):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [
("a0", c_uint8, 1),
("a1", c_uint8, 1),
@@ -443,9 +444,9 @@ class BitFieldTest(unittest.TestCase, StructCheckMixin):
("b1", c_uint16, 12),
]
-
class GoodA(Structure):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [
("a0", c_uint8, 1),
("a1", c_uint8, 1),
@@ -460,6 +461,7 @@ class BitFieldTest(unittest.TestCase, StructCheckMixin):
class Good(Structure):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [
("a", GoodA),
("b0", c_uint16, 4),
@@ -475,6 +477,7 @@ class BitFieldTest(unittest.TestCase, StructCheckMixin):
def test_gh_73939(self):
class MyStructure(Structure):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [
("P", c_uint16),
("L", c_uint16, 9),
diff --git a/Lib/test/test_ctypes/test_byteswap.py b/Lib/test/test_ctypes/test_byteswap.py
index 072c60d53dd..f14e1aa32e1 100644
--- a/Lib/test/test_ctypes/test_byteswap.py
+++ b/Lib/test/test_ctypes/test_byteswap.py
@@ -1,5 +1,4 @@
import binascii
-import ctypes
import math
import struct
import sys
@@ -232,7 +231,6 @@ class Test(unittest.TestCase, StructCheckMixin):
self.assertEqual(len(data), sizeof(TestStructure))
ptr = POINTER(TestStructure)
s = cast(data, ptr)[0]
- del ctypes._pointer_type_cache[TestStructure]
self.assertEqual(s.point.x, 1)
self.assertEqual(s.point.y, 2)
@@ -270,6 +268,7 @@ class Test(unittest.TestCase, StructCheckMixin):
class S(base):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [("b", c_byte),
("h", c_short),
@@ -297,6 +296,7 @@ class Test(unittest.TestCase, StructCheckMixin):
class S(Structure):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [("b", c_byte),
("h", c_short),
@@ -371,7 +371,6 @@ class Test(unittest.TestCase, StructCheckMixin):
self.assertEqual(len(data), sizeof(TestUnion))
ptr = POINTER(TestUnion)
s = cast(data, ptr)[0]
- del ctypes._pointer_type_cache[TestUnion]
self.assertEqual(s.point.x, 1)
self.assertEqual(s.point.y, 2)
diff --git a/Lib/test/test_ctypes/test_c_simple_type_meta.py b/Lib/test/test_ctypes/test_c_simple_type_meta.py
index 2328611856a..fd261acf497 100644
--- a/Lib/test/test_ctypes/test_c_simple_type_meta.py
+++ b/Lib/test/test_ctypes/test_c_simple_type_meta.py
@@ -1,16 +1,15 @@
import unittest
from test.support import MS_WINDOWS
import ctypes
-from ctypes import POINTER, c_void_p
+from ctypes import POINTER, Structure, c_void_p
-from ._support import PyCSimpleType
+from ._support import PyCSimpleType, PyCPointerType, PyCStructType
-class PyCSimpleTypeAsMetaclassTest(unittest.TestCase):
- def tearDown(self):
- # to not leak references, we must clean _pointer_type_cache
- ctypes._reset_cache()
+def set_non_ctypes_pointer_type(cls, pointer_type):
+ cls.__pointer_type__ = pointer_type
+class PyCSimpleTypeAsMetaclassTest(unittest.TestCase):
def test_creating_pointer_in_dunder_new_1(self):
# Test metaclass whose instances are C types; when the type is
# created it automatically creates a pointer type for itself.
@@ -36,7 +35,7 @@ class PyCSimpleTypeAsMetaclassTest(unittest.TestCase):
else:
ptr_bases = (self, POINTER(bases[0]))
p = p_meta(f"POINTER({self.__name__})", ptr_bases, {})
- ctypes._pointer_type_cache[self] = p
+ set_non_ctypes_pointer_type(self, p)
return self
class p_meta(PyCSimpleType, ct_meta):
@@ -45,20 +44,36 @@ class PyCSimpleTypeAsMetaclassTest(unittest.TestCase):
class PtrBase(c_void_p, metaclass=p_meta):
pass
+ ptr_base_pointer = POINTER(PtrBase)
+
class CtBase(object, metaclass=ct_meta):
pass
+ ct_base_pointer = POINTER(CtBase)
+
class Sub(CtBase):
pass
+ sub_pointer = POINTER(Sub)
+
class Sub2(Sub):
pass
+ sub2_pointer = POINTER(Sub2)
+
+ self.assertIsNot(ptr_base_pointer, ct_base_pointer)
+ self.assertIsNot(ct_base_pointer, sub_pointer)
+ self.assertIsNot(sub_pointer, sub2_pointer)
+
self.assertIsInstance(POINTER(Sub2), p_meta)
self.assertIsSubclass(POINTER(Sub2), Sub2)
self.assertIsSubclass(POINTER(Sub2), POINTER(Sub))
self.assertIsSubclass(POINTER(Sub), POINTER(CtBase))
+ self.assertIs(POINTER(Sub2), sub2_pointer)
+ self.assertIs(POINTER(Sub), sub_pointer)
+ self.assertIs(POINTER(CtBase), ct_base_pointer)
+
def test_creating_pointer_in_dunder_new_2(self):
# A simpler variant of the above, used in `CoClass` of the `comtypes`
# project.
@@ -69,7 +84,7 @@ class PyCSimpleTypeAsMetaclassTest(unittest.TestCase):
if isinstance(self, p_meta):
return self
p = p_meta(f"POINTER({self.__name__})", (self, c_void_p), {})
- ctypes._pointer_type_cache[self] = p
+ set_non_ctypes_pointer_type(self, p)
return self
class p_meta(PyCSimpleType, ct_meta):
@@ -78,15 +93,27 @@ class PyCSimpleTypeAsMetaclassTest(unittest.TestCase):
class Core(object):
pass
+ with self.assertRaisesRegex(TypeError, "must have storage info"):
+ POINTER(Core)
+
class CtBase(Core, metaclass=ct_meta):
pass
+ ct_base_pointer = POINTER(CtBase)
+
class Sub(CtBase):
pass
+ sub_pointer = POINTER(Sub)
+
+ self.assertIsNot(ct_base_pointer, sub_pointer)
+
self.assertIsInstance(POINTER(Sub), p_meta)
self.assertIsSubclass(POINTER(Sub), Sub)
+ self.assertIs(POINTER(Sub), sub_pointer)
+ self.assertIs(POINTER(CtBase), ct_base_pointer)
+
def test_creating_pointer_in_dunder_init_1(self):
class ct_meta(type):
def __init__(self, name, bases, namespace):
@@ -103,7 +130,7 @@ class PyCSimpleTypeAsMetaclassTest(unittest.TestCase):
else:
ptr_bases = (self, POINTER(bases[0]))
p = p_meta(f"POINTER({self.__name__})", ptr_bases, {})
- ctypes._pointer_type_cache[self] = p
+ set_non_ctypes_pointer_type(self, p)
class p_meta(PyCSimpleType, ct_meta):
pass
@@ -111,20 +138,37 @@ class PyCSimpleTypeAsMetaclassTest(unittest.TestCase):
class PtrBase(c_void_p, metaclass=p_meta):
pass
+ ptr_base_pointer = POINTER(PtrBase)
+
class CtBase(object, metaclass=ct_meta):
pass
+ ct_base_pointer = POINTER(CtBase)
+
class Sub(CtBase):
pass
+ sub_pointer = POINTER(Sub)
+
class Sub2(Sub):
pass
+ sub2_pointer = POINTER(Sub2)
+
+ self.assertIsNot(ptr_base_pointer, ct_base_pointer)
+ self.assertIsNot(ct_base_pointer, sub_pointer)
+ self.assertIsNot(sub_pointer, sub2_pointer)
+
self.assertIsInstance(POINTER(Sub2), p_meta)
self.assertIsSubclass(POINTER(Sub2), Sub2)
self.assertIsSubclass(POINTER(Sub2), POINTER(Sub))
self.assertIsSubclass(POINTER(Sub), POINTER(CtBase))
+ self.assertIs(POINTER(PtrBase), ptr_base_pointer)
+ self.assertIs(POINTER(CtBase), ct_base_pointer)
+ self.assertIs(POINTER(Sub), sub_pointer)
+ self.assertIs(POINTER(Sub2), sub2_pointer)
+
def test_creating_pointer_in_dunder_init_2(self):
class ct_meta(type):
def __init__(self, name, bases, namespace):
@@ -135,7 +179,7 @@ class PyCSimpleTypeAsMetaclassTest(unittest.TestCase):
if isinstance(self, p_meta):
return
p = p_meta(f"POINTER({self.__name__})", (self, c_void_p), {})
- ctypes._pointer_type_cache[self] = p
+ set_non_ctypes_pointer_type(self, p)
class p_meta(PyCSimpleType, ct_meta):
pass
@@ -146,12 +190,21 @@ class PyCSimpleTypeAsMetaclassTest(unittest.TestCase):
class CtBase(Core, metaclass=ct_meta):
pass
+ ct_base_pointer = POINTER(CtBase)
+
class Sub(CtBase):
pass
+ sub_pointer = POINTER(Sub)
+
+ self.assertIsNot(ct_base_pointer, sub_pointer)
+
self.assertIsInstance(POINTER(Sub), p_meta)
self.assertIsSubclass(POINTER(Sub), Sub)
+ self.assertIs(POINTER(CtBase), ct_base_pointer)
+ self.assertIs(POINTER(Sub), sub_pointer)
+
def test_bad_type_message(self):
"""Verify the error message that lists all available type codes"""
# (The string is generated at runtime, so this checks the underlying
@@ -168,3 +221,164 @@ class PyCSimpleTypeAsMetaclassTest(unittest.TestCase):
if not MS_WINDOWS:
expected_type_chars.remove('X')
self.assertIn("'" + ''.join(expected_type_chars) + "'", message)
+
+ def test_creating_pointer_in_dunder_init_3(self):
+        """Check that interface subclasses create corresponding internal
+        pointer types that are not the same as the external pointer types.
+ """
+
+ class StructureMeta(PyCStructType):
+ def __new__(cls, name, bases, dct, /, create_pointer_type=True):
+ assert len(bases) == 1, bases
+ return super().__new__(cls, name, bases, dct)
+
+ def __init__(self, name, bases, dct, /, create_pointer_type=True):
+
+ super().__init__(name, bases, dct)
+ if create_pointer_type:
+ p_bases = (POINTER(bases[0]),)
+ ns = {'_type_': self}
+ internal_pointer_type = PointerMeta(f"p{name}", p_bases, ns)
+ assert isinstance(internal_pointer_type, PyCPointerType)
+ assert self.__pointer_type__ is internal_pointer_type
+
+ class PointerMeta(PyCPointerType):
+ def __new__(cls, name, bases, dct):
+ target = dct.get('_type_', None)
+ if target is None:
+
+ # Create corresponding interface type and then set it as target
+ target = StructureMeta(
+ f"_{name}_",
+ (bases[0]._type_,),
+ {},
+ create_pointer_type=False
+ )
+ dct['_type_'] = target
+
+ pointer_type = super().__new__(cls, name, bases, dct)
+ assert not hasattr(target, '__pointer_type__')
+
+ return pointer_type
+
+ def __init__(self, name, bases, dct, /, create_pointer_type=True):
+ target = dct.get('_type_', None)
+ assert not hasattr(target, '__pointer_type__')
+ super().__init__(name, bases, dct)
+ assert target.__pointer_type__ is self
+
+
+ class Interface(Structure, metaclass=StructureMeta, create_pointer_type=False):
+ pass
+
+ class pInterface(POINTER(c_void_p), metaclass=PointerMeta):
+ _type_ = Interface
+
+ class IUnknown(Interface):
+ pass
+
+ class pIUnknown(pInterface):
+ pass
+
+ self.assertTrue(issubclass(POINTER(IUnknown), pInterface))
+
+ self.assertIs(POINTER(Interface), pInterface)
+ self.assertIsNot(POINTER(IUnknown), pIUnknown)
+
+ def test_creating_pointer_in_dunder_init_4(self):
+        """Check that interface subclasses create corresponding internal
+        pointer types that are the same as the external pointer types.
+ """
+ class StructureMeta(PyCStructType):
+ def __new__(cls, name, bases, dct, /, create_pointer_type=True):
+ assert len(bases) == 1, bases
+
+ return super().__new__(cls, name, bases, dct)
+
+ def __init__(self, name, bases, dct, /, create_pointer_type=True):
+
+ super().__init__(name, bases, dct)
+ if create_pointer_type:
+ p_bases = (POINTER(bases[0]),)
+ ns = {'_type_': self}
+ internal_pointer_type = PointerMeta(f"p{name}", p_bases, ns)
+ assert isinstance(internal_pointer_type, PyCPointerType)
+ assert self.__pointer_type__ is internal_pointer_type
+
+ class PointerMeta(PyCPointerType):
+ def __new__(cls, name, bases, dct):
+ target = dct.get('_type_', None)
+ assert target is not None
+ pointer_type = getattr(target, '__pointer_type__', None)
+
+ if pointer_type is None:
+ pointer_type = super().__new__(cls, name, bases, dct)
+
+ return pointer_type
+
+ def __init__(self, name, bases, dct, /, create_pointer_type=True):
+ target = dct.get('_type_', None)
+ if not hasattr(target, '__pointer_type__'):
+ # target.__pointer_type__ was created by super().__new__
+ super().__init__(name, bases, dct)
+
+ assert target.__pointer_type__ is self
+
+
+ class Interface(Structure, metaclass=StructureMeta, create_pointer_type=False):
+ pass
+
+ class pInterface(POINTER(c_void_p), metaclass=PointerMeta):
+ _type_ = Interface
+
+ class IUnknown(Interface):
+ pass
+
+ class pIUnknown(pInterface):
+ _type_ = IUnknown
+
+ self.assertTrue(issubclass(POINTER(IUnknown), pInterface))
+
+ self.assertIs(POINTER(Interface), pInterface)
+ self.assertIs(POINTER(IUnknown), pIUnknown)
+
+ def test_custom_pointer_cache_for_ctypes_type1(self):
+ # Check if PyCPointerType.__init__() caches a pointer type
+ # customized in the metatype's __new__().
+ class PointerMeta(PyCPointerType):
+ def __new__(cls, name, bases, namespace):
+ namespace["_type_"] = C
+ return super().__new__(cls, name, bases, namespace)
+
+ def __init__(self, name, bases, namespace):
+ assert not hasattr(C, '__pointer_type__')
+ super().__init__(name, bases, namespace)
+ assert C.__pointer_type__ is self
+
+ class C(c_void_p): # ctypes type
+ pass
+
+ class P(ctypes._Pointer, metaclass=PointerMeta):
+ pass
+
+ self.assertIs(P._type_, C)
+ self.assertIs(P, POINTER(C))
+
+ def test_custom_pointer_cache_for_ctypes_type2(self):
+ # Check if PyCPointerType.__init__() caches a pointer type
+ # customized in the metatype's __init__().
+ class PointerMeta(PyCPointerType):
+ def __init__(self, name, bases, namespace):
+ self._type_ = namespace["_type_"] = C
+ assert not hasattr(C, '__pointer_type__')
+ super().__init__(name, bases, namespace)
+ assert C.__pointer_type__ is self
+
+ class C(c_void_p): # ctypes type
+ pass
+
+ class P(ctypes._Pointer, metaclass=PointerMeta):
+ pass
+
+ self.assertIs(P._type_, C)
+ self.assertIs(P, POINTER(C))
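
The test additions above all exercise the same mechanism: POINTER() now records the
pointer type on the target class itself, in the __pointer_type__ attribute, instead of
keying it into ctypes._pointer_type_cache. A minimal sketch of that behavior, assuming
only what the assertions above already check:

    from ctypes import POINTER, Structure, c_int

    class Point(Structure):
        _fields_ = [("x", c_int), ("y", c_int)]

    lp_point = POINTER(Point)              # creates LP_Point and caches it
    assert Point.__pointer_type__ is lp_point
    assert POINTER(Point) is lp_point      # later calls reuse the cached type

    del Point.__pointer_type__             # dropping the attribute empties the cache,
    assert POINTER(Point) is not lp_point  # so a fresh pointer type is created
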
diff --git a/Lib/test/test_ctypes/test_generated_structs.py b/Lib/test/test_ctypes/test_generated_structs.py
index 9a8102219d8..1cb46a82701 100644
--- a/Lib/test/test_ctypes/test_generated_structs.py
+++ b/Lib/test/test_ctypes/test_generated_structs.py
@@ -10,7 +10,7 @@ Run this module to regenerate the files:
"""
import unittest
-from test.support import import_helper, verbose
+from test.support import import_helper
import re
from dataclasses import dataclass
from functools import cached_property
@@ -125,18 +125,21 @@ class Nested(Structure):
class Packed1(Structure):
_fields_ = [('a', c_int8), ('b', c_int64)]
_pack_ = 1
+ _layout_ = 'ms'
@register()
class Packed2(Structure):
_fields_ = [('a', c_int8), ('b', c_int64)]
_pack_ = 2
+ _layout_ = 'ms'
@register()
class Packed3(Structure):
_fields_ = [('a', c_int8), ('b', c_int64)]
_pack_ = 4
+ _layout_ = 'ms'
@register()
@@ -155,6 +158,7 @@ class Packed4(Structure):
_fields_ = [('a', c_int8), ('b', c_int64)]
_pack_ = 8
+ _layout_ = 'ms'
@register()
class X86_32EdgeCase(Structure):
@@ -366,6 +370,7 @@ class Example_gh_95496(Structure):
@register()
class Example_gh_84039_bad(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("a0", c_uint8, 1),
("a1", c_uint8, 1),
("a2", c_uint8, 1),
@@ -380,6 +385,7 @@ class Example_gh_84039_bad(Structure):
@register()
class Example_gh_84039_good_a(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("a0", c_uint8, 1),
("a1", c_uint8, 1),
("a2", c_uint8, 1),
@@ -392,6 +398,7 @@ class Example_gh_84039_good_a(Structure):
@register()
class Example_gh_84039_good(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("a", Example_gh_84039_good_a),
("b0", c_uint16, 4),
("b1", c_uint16, 12)]
@@ -399,6 +406,7 @@ class Example_gh_84039_good(Structure):
@register()
class Example_gh_73939(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("P", c_uint16),
("L", c_uint16, 9),
("Pro", c_uint16, 1),
@@ -419,6 +427,7 @@ class Example_gh_86098(Structure):
@register()
class Example_gh_86098_pack(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("a", c_uint8, 8),
("b", c_uint8, 8),
("c", c_uint32, 16)]
@@ -528,7 +537,7 @@ def dump_ctype(tp, struct_or_union_tag='', variable_name='', semi=''):
pushes.append(f'#pragma pack(push, {pack})')
pops.append(f'#pragma pack(pop)')
layout = getattr(tp, '_layout_', None)
- if layout == 'ms' or pack:
+ if layout == 'ms':
# The 'ms_struct' attribute only works on x86 and PowerPC
requires.add(
'defined(MS_WIN32) || ('
diff --git a/Lib/test/test_ctypes/test_incomplete.py b/Lib/test/test_ctypes/test_incomplete.py
index 9f859793d88..3189fcd1bd1 100644
--- a/Lib/test/test_ctypes/test_incomplete.py
+++ b/Lib/test/test_ctypes/test_incomplete.py
@@ -1,24 +1,28 @@
import ctypes
import unittest
-import warnings
from ctypes import Structure, POINTER, pointer, c_char_p
+# String-based "incomplete pointers" were implemented in ctypes 0.6.3 (2003, when
+# ctypes was an external project). They were made obsolete by the current
+# incomplete *types* (setting `_fields_` late) in 0.9.5 (2005).
+# ctypes was added to Python 2.5 (2006), without this feature being mentioned in the docs.
-# The incomplete pointer example from the tutorial
+# This tests the incomplete pointer example from the old tutorial
+# (https://svn.python.org/projects/ctypes/tags/release_0_6_3/ctypes/docs/tutorial.stx)
class TestSetPointerType(unittest.TestCase):
def tearDown(self):
- # to not leak references, we must clean _pointer_type_cache
- ctypes._reset_cache()
+ ctypes._pointer_type_cache_fallback.clear()
def test_incomplete_example(self):
- lpcell = POINTER("cell")
+ with self.assertWarns(DeprecationWarning):
+ lpcell = POINTER("cell")
class cell(Structure):
_fields_ = [("name", c_char_p),
("next", lpcell)]
- with warnings.catch_warnings():
- warnings.simplefilter('ignore', DeprecationWarning)
- ctypes.SetPointerType(lpcell, cell)
+ lpcell.set_type(cell)
+
+ self.assertIs(POINTER(cell), lpcell)
c1 = cell()
c1.name = b"foo"
@@ -37,13 +41,14 @@ class TestSetPointerType(unittest.TestCase):
self.assertEqual(result, [b"foo", b"bar"] * 4)
def test_deprecation(self):
- lpcell = POINTER("cell")
+ with self.assertWarns(DeprecationWarning):
+ lpcell = POINTER("cell")
class cell(Structure):
_fields_ = [("name", c_char_p),
("next", lpcell)]
- with self.assertWarns(DeprecationWarning):
- ctypes.SetPointerType(lpcell, cell)
+ lpcell.set_type(cell)
+ self.assertIs(POINTER(cell), lpcell)
if __name__ == '__main__':
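
For contrast with the deprecated string-based POINTER("cell") form tested above, this is
the current incomplete-type idiom the comment refers to (declare the class first, assign
_fields_ afterwards); a minimal sketch:

    from ctypes import Structure, POINTER, c_char_p

    class cell(Structure):
        pass                                   # incomplete: no _fields_ yet

    cell._fields_ = [("name", c_char_p),
                     ("next", POINTER(cell))]  # self-referential field, set late
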
diff --git a/Lib/test/test_ctypes/test_keeprefs.py b/Lib/test/test_ctypes/test_keeprefs.py
index 23b03b64b4a..5602460d5ff 100644
--- a/Lib/test/test_ctypes/test_keeprefs.py
+++ b/Lib/test/test_ctypes/test_keeprefs.py
@@ -1,6 +1,5 @@
import unittest
-from ctypes import (Structure, POINTER, pointer, _pointer_type_cache,
- c_char_p, c_int)
+from ctypes import (Structure, POINTER, pointer, c_char_p, c_int)
class SimpleTestCase(unittest.TestCase):
@@ -115,10 +114,6 @@ class PointerToStructure(unittest.TestCase):
r.a[0].x = 42
r.a[0].y = 99
- # to avoid leaking when tests are run several times
- # clean up the types left in the cache.
- del _pointer_type_cache[POINT]
-
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_ctypes/test_parameters.py b/Lib/test/test_ctypes/test_parameters.py
index f89521cf8b3..46f8ff93efa 100644
--- a/Lib/test/test_ctypes/test_parameters.py
+++ b/Lib/test/test_ctypes/test_parameters.py
@@ -1,3 +1,4 @@
+import sys
import unittest
import test.support
from ctypes import (CDLL, PyDLL, ArgumentError,
@@ -240,7 +241,8 @@ class SimpleTypesTestCase(unittest.TestCase):
self.assertRegex(repr(c_ulonglong.from_param(20000)), r"^<cparam '[LIQ]' \(20000\)>$")
self.assertEqual(repr(c_float.from_param(1.5)), "<cparam 'f' (1.5)>")
self.assertEqual(repr(c_double.from_param(1.5)), "<cparam 'd' (1.5)>")
- self.assertEqual(repr(c_double.from_param(1e300)), "<cparam 'd' (1e+300)>")
+ if sys.float_repr_style == 'short':
+ self.assertEqual(repr(c_double.from_param(1e300)), "<cparam 'd' (1e+300)>")
self.assertRegex(repr(c_longdouble.from_param(1.5)), r"^<cparam ('d' \(1.5\)|'g' at 0x[A-Fa-f0-9]+)>$")
self.assertRegex(repr(c_char_p.from_param(b'hihi')), r"^<cparam 'z' \(0x[A-Fa-f0-9]+\)>$")
self.assertRegex(repr(c_wchar_p.from_param('hihi')), r"^<cparam 'Z' \(0x[A-Fa-f0-9]+\)>$")
diff --git a/Lib/test/test_ctypes/test_pep3118.py b/Lib/test/test_ctypes/test_pep3118.py
index 06b2ccecade..11a0744f5a8 100644
--- a/Lib/test/test_ctypes/test_pep3118.py
+++ b/Lib/test/test_ctypes/test_pep3118.py
@@ -81,6 +81,7 @@ class Point(Structure):
class PackedPoint(Structure):
_pack_ = 2
+ _layout_ = 'ms'
_fields_ = [("x", c_long), ("y", c_long)]
class PointMidPad(Structure):
@@ -88,6 +89,7 @@ class PointMidPad(Structure):
class PackedPointMidPad(Structure):
_pack_ = 2
+ _layout_ = 'ms'
_fields_ = [("x", c_byte), ("y", c_uint64)]
class PointEndPad(Structure):
@@ -95,6 +97,7 @@ class PointEndPad(Structure):
class PackedPointEndPad(Structure):
_pack_ = 2
+ _layout_ = 'ms'
_fields_ = [("x", c_uint64), ("y", c_byte)]
class Point2(Structure):
diff --git a/Lib/test/test_ctypes/test_pointers.py b/Lib/test/test_ctypes/test_pointers.py
index fc558e10ba4..a8d243a45de 100644
--- a/Lib/test/test_ctypes/test_pointers.py
+++ b/Lib/test/test_ctypes/test_pointers.py
@@ -1,15 +1,18 @@
import array
import ctypes
+import gc
import sys
import unittest
from ctypes import (CDLL, CFUNCTYPE, Structure,
- POINTER, pointer, _Pointer, _pointer_type_cache,
+ POINTER, pointer, _Pointer,
byref, sizeof,
c_void_p, c_char_p,
c_byte, c_ubyte, c_short, c_ushort, c_int, c_uint,
c_long, c_ulong, c_longlong, c_ulonglong,
c_float, c_double)
+from ctypes import _pointer_type_cache, _pointer_type_cache_fallback
from test.support import import_helper
+from weakref import WeakSet
_ctypes_test = import_helper.import_module("_ctypes_test")
from ._support import (_CData, PyCPointerType, Py_TPFLAGS_DISALLOW_INSTANTIATION,
Py_TPFLAGS_IMMUTABLETYPE)
@@ -22,6 +25,9 @@ python_types = [int, int, int, int, int, int,
class PointersTestCase(unittest.TestCase):
+ def tearDown(self):
+ _pointer_type_cache_fallback.clear()
+
def test_inheritance_hierarchy(self):
self.assertEqual(_Pointer.mro(), [_Pointer, _CData, object])
@@ -127,6 +133,14 @@ class PointersTestCase(unittest.TestCase):
addr = a.buffer_info()[0]
p = POINTER(POINTER(c_int))
+ def test_pointer_from_pointer(self):
+ p1 = POINTER(c_int)
+ p2 = POINTER(p1)
+
+ self.assertIsNot(p1, p2)
+ self.assertIs(p1.__pointer_type__, p2)
+ self.assertIs(p2._type_, p1)
+
def test_other(self):
class Table(Structure):
_fields_ = [("a", c_int),
@@ -141,8 +155,6 @@ class PointersTestCase(unittest.TestCase):
pt.contents.c = 33
- del _pointer_type_cache[Table]
-
def test_basic(self):
p = pointer(c_int(42))
# Although a pointer can be indexed, it has no length
@@ -175,6 +187,7 @@ class PointersTestCase(unittest.TestCase):
q = pointer(y)
pp[0] = q # <==
self.assertEqual(p[0], 6)
+
def test_c_void_p(self):
# http://sourceforge.net/tracker/?func=detail&aid=1518190&group_id=5470&atid=105470
if sizeof(c_void_p) == 4:
@@ -193,6 +206,30 @@ class PointersTestCase(unittest.TestCase):
self.assertRaises(TypeError, c_void_p, 3.14) # make sure floats are NOT accepted
self.assertRaises(TypeError, c_void_p, object()) # nor other objects
+ def test_read_null_pointer(self):
+ null_ptr = POINTER(c_int)()
+ with self.assertRaisesRegex(ValueError, "NULL pointer access"):
+ null_ptr[0]
+
+ def test_write_null_pointer(self):
+ null_ptr = POINTER(c_int)()
+ with self.assertRaisesRegex(ValueError, "NULL pointer access"):
+ null_ptr[0] = 1
+
+ def test_set_pointer_to_null_and_read(self):
+ class Bar(Structure):
+ _fields_ = [("values", POINTER(c_int))]
+
+ bar = Bar()
+ bar.values = (c_int * 3)(1, 2, 3)
+
+ values = [bar.values[0], bar.values[1], bar.values[2]]
+ self.assertEqual(values, [1, 2, 3])
+
+ bar.values = None
+ with self.assertRaisesRegex(ValueError, "NULL pointer access"):
+ bar.values[0]
+
def test_pointers_bool(self):
# NULL pointers have a boolean False value, non-NULL pointers True.
self.assertEqual(bool(POINTER(c_int)()), False)
@@ -210,20 +247,231 @@ class PointersTestCase(unittest.TestCase):
LargeNamedType = type('T' * 2 ** 25, (Structure,), {})
self.assertTrue(POINTER(LargeNamedType))
- # to not leak references, we must clean _pointer_type_cache
- del _pointer_type_cache[LargeNamedType]
-
def test_pointer_type_str_name(self):
large_string = 'T' * 2 ** 25
- P = POINTER(large_string)
+ with self.assertWarns(DeprecationWarning):
+ P = POINTER(large_string)
self.assertTrue(P)
- # to not leak references, we must clean _pointer_type_cache
- del _pointer_type_cache[id(P)]
-
def test_abstract(self):
self.assertRaises(TypeError, _Pointer.set_type, 42)
+ def test_pointer_types_equal(self):
+ t1 = POINTER(c_int)
+ t2 = POINTER(c_int)
+
+ self.assertIs(t1, t2)
+
+ p1 = t1(c_int(1))
+ p2 = pointer(c_int(1))
+
+ self.assertIsInstance(p1, t1)
+ self.assertIsInstance(p2, t1)
+
+ self.assertIs(type(p1), t1)
+ self.assertIs(type(p2), t1)
+
+ def test_incomplete_pointer_types_still_equal(self):
+ with self.assertWarns(DeprecationWarning):
+ t1 = POINTER("LP_C")
+ with self.assertWarns(DeprecationWarning):
+ t2 = POINTER("LP_C")
+
+ self.assertIs(t1, t2)
+
+ def test_incomplete_pointer_types_cannot_instantiate(self):
+ with self.assertWarns(DeprecationWarning):
+ t1 = POINTER("LP_C")
+ with self.assertRaisesRegex(TypeError, "has no _type_"):
+ t1()
+
+ def test_pointer_set_type_twice(self):
+ t1 = POINTER(c_int)
+ self.assertIs(c_int.__pointer_type__, t1)
+ self.assertIs(t1._type_, c_int)
+
+ t1.set_type(c_int)
+ self.assertIs(c_int.__pointer_type__, t1)
+ self.assertIs(t1._type_, c_int)
+
+ def test_pointer_set_wrong_type(self):
+ int_ptr = POINTER(c_int)
+ float_ptr = POINTER(c_float)
+ try:
+ class C(c_int):
+ pass
+
+ t1 = POINTER(c_int)
+ t2 = POINTER(c_float)
+ t1.set_type(c_float)
+ self.assertEqual(t1(c_float(1.5))[0], 1.5)
+ self.assertIs(t1._type_, c_float)
+ self.assertIs(c_int.__pointer_type__, t1)
+ self.assertIs(c_float.__pointer_type__, float_ptr)
+
+ t1.set_type(C)
+ self.assertEqual(t1(C(123))[0].value, 123)
+ self.assertIs(c_int.__pointer_type__, t1)
+ self.assertIs(c_float.__pointer_type__, float_ptr)
+ finally:
+ POINTER(c_int).set_type(c_int)
+ self.assertIs(POINTER(c_int), int_ptr)
+ self.assertIs(POINTER(c_int)._type_, c_int)
+ self.assertIs(c_int.__pointer_type__, int_ptr)
+
+ def test_pointer_not_ctypes_type(self):
+ with self.assertRaisesRegex(TypeError, "must have storage info"):
+ POINTER(int)
+
+ with self.assertRaisesRegex(TypeError, "must have storage info"):
+ pointer(int)
+
+ with self.assertRaisesRegex(TypeError, "must have storage info"):
+ pointer(int(1))
+
+ def test_pointer_set_python_type(self):
+ p1 = POINTER(c_int)
+ with self.assertRaisesRegex(TypeError, "must have storage info"):
+ p1.set_type(int)
+
+ def test_pointer_type_attribute_is_none(self):
+ class Cls(Structure):
+ _fields_ = (
+ ('a', c_int),
+ ('b', c_float),
+ )
+
+ with self.assertRaisesRegex(AttributeError, ".Cls'> has no attribute '__pointer_type__'"):
+ Cls.__pointer_type__
+
+ p = POINTER(Cls)
+ self.assertIs(Cls.__pointer_type__, p)
+
+ def test_arbitrary_pointer_type_attribute(self):
+ class Cls(Structure):
+ _fields_ = (
+ ('a', c_int),
+ ('b', c_float),
+ )
+
+ garbage = 'garbage'
+
+ P = POINTER(Cls)
+ self.assertIs(Cls.__pointer_type__, P)
+ Cls.__pointer_type__ = garbage
+ self.assertIs(Cls.__pointer_type__, garbage)
+ self.assertIs(POINTER(Cls), garbage)
+ self.assertIs(P._type_, Cls)
+
+ instance = Cls(1, 2.0)
+ pointer = P(instance)
+ self.assertEqual(pointer[0].a, 1)
+ self.assertEqual(pointer[0].b, 2)
+
+ del Cls.__pointer_type__
+
+ NewP = POINTER(Cls)
+ self.assertIsNot(NewP, P)
+ self.assertIs(Cls.__pointer_type__, NewP)
+ self.assertIs(P._type_, Cls)
+
+ def test_pointer_types_factory(self):
+ """Shouldn't leak"""
+ def factory():
+ class Cls(Structure):
+ _fields_ = (
+ ('a', c_int),
+ ('b', c_float),
+ )
+
+ return Cls
+
+ ws_typ = WeakSet()
+ ws_ptr = WeakSet()
+ for _ in range(10):
+ typ = factory()
+ ptr = POINTER(typ)
+
+ ws_typ.add(typ)
+ ws_ptr.add(ptr)
+
+ typ = None
+ ptr = None
+
+ gc.collect()
+
+ self.assertEqual(len(ws_typ), 0, ws_typ)
+ self.assertEqual(len(ws_ptr), 0, ws_ptr)
+
+
+class PointerTypeCacheTestCase(unittest.TestCase):
+ # dummy tests to check warnings and base behavior
+ def tearDown(self):
+ _pointer_type_cache_fallback.clear()
+
+ def test_deprecated_cache_with_not_ctypes_type(self):
+ class C:
+ pass
+
+ with self.assertWarns(DeprecationWarning):
+ P = POINTER("C")
+
+ with self.assertWarns(DeprecationWarning):
+ self.assertIs(_pointer_type_cache["C"], P)
+
+ with self.assertWarns(DeprecationWarning):
+ _pointer_type_cache[C] = P
+ self.assertIs(C.__pointer_type__, P)
+ with self.assertWarns(DeprecationWarning):
+ self.assertIs(_pointer_type_cache[C], P)
+
+ def test_deprecated_cache_with_ints(self):
+ with self.assertWarns(DeprecationWarning):
+ _pointer_type_cache[123] = 456
+
+ with self.assertWarns(DeprecationWarning):
+ self.assertEqual(_pointer_type_cache[123], 456)
+
+ def test_deprecated_cache_with_ctypes_type(self):
+ class C(Structure):
+ _fields_ = [("a", c_int),
+ ("b", c_int),
+ ("c", c_int)]
+
+ P1 = POINTER(C)
+ with self.assertWarns(DeprecationWarning):
+ P2 = POINTER("C")
+
+ with self.assertWarns(DeprecationWarning):
+ _pointer_type_cache[C] = P2
+
+ self.assertIs(C.__pointer_type__, P2)
+ self.assertIsNot(C.__pointer_type__, P1)
+
+ with self.assertWarns(DeprecationWarning):
+ self.assertIs(_pointer_type_cache[C], P2)
+
+ with self.assertWarns(DeprecationWarning):
+ self.assertIs(_pointer_type_cache.get(C), P2)
+
+ def test_get_not_registered(self):
+ with self.assertWarns(DeprecationWarning):
+ self.assertIsNone(_pointer_type_cache.get(str))
+
+ with self.assertWarns(DeprecationWarning):
+ self.assertIsNone(_pointer_type_cache.get(str, None))
+
+ def test_repeated_set_type(self):
+ # Regression test for gh-133290
+ class C(Structure):
+ _fields_ = [('a', c_int)]
+ ptr = POINTER(C)
+        # Read _type_ several times to warm up the cache
+ for i in range(5):
+ self.assertIs(ptr._type_, C)
+ ptr.set_type(c_int)
+ self.assertIs(ptr._type_, c_int)
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_ctypes/test_structunion.py b/Lib/test/test_ctypes/test_structunion.py
index 8d8b7e5e995..5b21d48d99c 100644
--- a/Lib/test/test_ctypes/test_structunion.py
+++ b/Lib/test/test_ctypes/test_structunion.py
@@ -11,6 +11,8 @@ from ._support import (_CData, PyCStructType, UnionType,
Py_TPFLAGS_DISALLOW_INSTANTIATION,
Py_TPFLAGS_IMMUTABLETYPE)
from struct import calcsize
+import contextlib
+from test.support import MS_WINDOWS
class StructUnionTestBase:
@@ -335,6 +337,22 @@ class StructUnionTestBase:
self.assertIn("from_address", dir(type(self.cls)))
self.assertIn("in_dll", dir(type(self.cls)))
+ def test_pack_layout_switch(self):
+ # Setting _pack_ implicitly sets default layout to MSVC;
+ # this is deprecated on non-Windows platforms.
+ if MS_WINDOWS:
+ warn_context = contextlib.nullcontext()
+ else:
+ warn_context = self.assertWarns(DeprecationWarning)
+ with warn_context:
+ class X(self.cls):
+ _pack_ = 1
+ # _layout_ missing
+ _fields_ = [('a', c_int8, 1), ('b', c_int16, 2)]
+
+ # Check MSVC layout (bitfields of different types aren't combined)
+ self.check_sizeof(X, struct_size=3, union_size=2)
+
class StructureTestCase(unittest.TestCase, StructUnionTestBase):
cls = Structure
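
The _layout_ = "ms" additions throughout these test files follow the rule that
test_pack_layout_switch checks above: using _pack_ without an explicit _layout_ now
emits a DeprecationWarning on non-Windows platforms, so packed structures are expected
to spell the layout out. A sketch of the explicit form, using the 9-byte case from the
test_structures.py hunk below:

    from ctypes import Structure, c_byte, c_longlong, sizeof

    class X(Structure):
        _layout_ = "ms"       # request the MSVC-compatible layout explicitly
        _pack_ = 1
        _fields_ = [("a", c_byte),
                    ("b", c_longlong)]

    assert sizeof(X) == 9     # 1-byte field + 8-byte field, no padding
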
diff --git a/Lib/test/test_ctypes/test_structures.py b/Lib/test/test_ctypes/test_structures.py
index bd7aba6376d..92d4851d739 100644
--- a/Lib/test/test_ctypes/test_structures.py
+++ b/Lib/test/test_ctypes/test_structures.py
@@ -25,6 +25,7 @@ class StructureTestCase(unittest.TestCase, StructCheckMixin):
_fields_ = [("a", c_byte),
("b", c_longlong)]
_pack_ = 1
+ _layout_ = 'ms'
self.check_struct(X)
self.assertEqual(sizeof(X), 9)
@@ -34,6 +35,7 @@ class StructureTestCase(unittest.TestCase, StructCheckMixin):
_fields_ = [("a", c_byte),
("b", c_longlong)]
_pack_ = 2
+ _layout_ = 'ms'
self.check_struct(X)
self.assertEqual(sizeof(X), 10)
self.assertEqual(X.b.offset, 2)
@@ -45,6 +47,7 @@ class StructureTestCase(unittest.TestCase, StructCheckMixin):
_fields_ = [("a", c_byte),
("b", c_longlong)]
_pack_ = 4
+ _layout_ = 'ms'
self.check_struct(X)
self.assertEqual(sizeof(X), min(4, longlong_align) + longlong_size)
self.assertEqual(X.b.offset, min(4, longlong_align))
@@ -53,27 +56,33 @@ class StructureTestCase(unittest.TestCase, StructCheckMixin):
_fields_ = [("a", c_byte),
("b", c_longlong)]
_pack_ = 8
+ _layout_ = 'ms'
self.check_struct(X)
self.assertEqual(sizeof(X), min(8, longlong_align) + longlong_size)
self.assertEqual(X.b.offset, min(8, longlong_align))
-
- d = {"_fields_": [("a", "b"),
- ("b", "q")],
- "_pack_": -1}
- self.assertRaises(ValueError, type(Structure), "X", (Structure,), d)
+ with self.assertRaises(ValueError):
+ class X(Structure):
+ _fields_ = [("a", "b"), ("b", "q")]
+ _pack_ = -1
+ _layout_ = "ms"
@support.cpython_only
def test_packed_c_limits(self):
# Issue 15989
import _testcapi
- d = {"_fields_": [("a", c_byte)],
- "_pack_": _testcapi.INT_MAX + 1}
- self.assertRaises(ValueError, type(Structure), "X", (Structure,), d)
- d = {"_fields_": [("a", c_byte)],
- "_pack_": _testcapi.UINT_MAX + 2}
- self.assertRaises(ValueError, type(Structure), "X", (Structure,), d)
+ with self.assertRaises(ValueError):
+ class X(Structure):
+ _fields_ = [("a", c_byte)]
+ _pack_ = _testcapi.INT_MAX + 1
+ _layout_ = "ms"
+
+ with self.assertRaises(ValueError):
+ class X(Structure):
+ _fields_ = [("a", c_byte)]
+ _pack_ = _testcapi.UINT_MAX + 2
+ _layout_ = "ms"
def test_initializers(self):
class Person(Structure):
@@ -685,6 +694,30 @@ class StructureTestCase(unittest.TestCase, StructCheckMixin):
self.assertEqual(ctx.exception.args[0], 'item 1 in _argtypes_ passes '
'a union by value, which is unsupported.')
+ def test_do_not_share_pointer_type_cache_via_stginfo_clone(self):
+ # This test case calls PyCStgInfo_clone()
+ # for the Mid and Vector class definitions
+        # and checks that the pointer type cache is not shared
+ # between subclasses.
+ class Base(Structure):
+ _fields_ = [('y', c_double),
+ ('x', c_double)]
+ base_ptr = POINTER(Base)
+
+ class Mid(Base):
+ pass
+ Mid._fields_ = []
+ mid_ptr = POINTER(Mid)
+
+ class Vector(Mid):
+ pass
+
+ vector_ptr = POINTER(Vector)
+
+ self.assertIsNot(base_ptr, mid_ptr)
+ self.assertIsNot(base_ptr, vector_ptr)
+ self.assertIsNot(mid_ptr, vector_ptr)
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_ctypes/test_unaligned_structures.py b/Lib/test/test_ctypes/test_unaligned_structures.py
index 58a00597ef5..b5fb4c0df77 100644
--- a/Lib/test/test_ctypes/test_unaligned_structures.py
+++ b/Lib/test/test_ctypes/test_unaligned_structures.py
@@ -19,10 +19,12 @@ for typ in [c_short, c_int, c_long, c_longlong,
c_ushort, c_uint, c_ulong, c_ulonglong]:
class X(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("pad", c_byte),
("value", typ)]
class Y(SwappedStructure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("pad", c_byte),
("value", typ)]
structures.append(X)
diff --git a/Lib/test/test_ctypes/test_values.py b/Lib/test/test_ctypes/test_values.py
index 1e209797606..8d1ee25ace5 100644
--- a/Lib/test/test_ctypes/test_values.py
+++ b/Lib/test/test_ctypes/test_values.py
@@ -7,7 +7,6 @@ import importlib.util
import sys
import unittest
from ctypes import (Structure, CDLL, POINTER, pythonapi,
- _pointer_type_cache,
c_ubyte, c_char_p, c_int)
from test.support import import_helper, thread_unsafe
@@ -98,8 +97,6 @@ class PythonValuesTestCase(unittest.TestCase):
"_PyImport_FrozenBootstrap example "
"in Doc/library/ctypes.rst may be out of date")
- del _pointer_type_cache[struct_frozen]
-
def test_undefined(self):
self.assertRaises(ValueError, c_int.in_dll, pythonapi,
"Undefined_Symbol")
diff --git a/Lib/test/test_ctypes/test_win32.py b/Lib/test/test_ctypes/test_win32.py
index 54b47dc28fb..7d513322190 100644
--- a/Lib/test/test_ctypes/test_win32.py
+++ b/Lib/test/test_ctypes/test_win32.py
@@ -5,7 +5,6 @@ import errno
import sys
import unittest
from ctypes import (CDLL, Structure, POINTER, pointer, sizeof, byref,
- _pointer_type_cache,
c_void_p, c_char, c_int, c_long)
from test import support
from test.support import import_helper
@@ -145,8 +144,8 @@ class Structures(unittest.TestCase):
self.assertEqual(ret.top, top.value)
self.assertEqual(ret.bottom, bottom.value)
- # to not leak references, we must clean _pointer_type_cache
- del _pointer_type_cache[RECT]
+ self.assertIs(PointInRect.argtypes[0], ReturnRect.argtypes[2])
+ self.assertIs(PointInRect.argtypes[0], ReturnRect.argtypes[5])
if __name__ == '__main__':
diff --git a/Lib/test/test_curses.py b/Lib/test/test_curses.py
index 6fe0e7fd4b7..d5ca7f2ca1a 100644
--- a/Lib/test/test_curses.py
+++ b/Lib/test/test_curses.py
@@ -8,7 +8,8 @@ import unittest
from unittest.mock import MagicMock
from test.support import (requires, verbose, SaveSignals, cpython_only,
- check_disallow_instantiation, MISSING_C_DOCSTRINGS)
+ check_disallow_instantiation, MISSING_C_DOCSTRINGS,
+ gc_collect)
from test.support.import_helper import import_module
# Optionally test curses module. This currently requires that the
@@ -51,12 +52,6 @@ def requires_colors(test):
term = os.environ.get('TERM')
SHORT_MAX = 0x7fff
-DEFAULT_PAIR_CONTENTS = [
- (curses.COLOR_WHITE, curses.COLOR_BLACK),
- (0, 0),
- (-1, -1),
- (15, 0), # for xterm-256color (15 is for BRIGHT WHITE)
-]
# If newterm was supported we could use it instead of initscr and not exit
@unittest.skipIf(not term or term == 'unknown',
@@ -135,6 +130,9 @@ class TestCurses(unittest.TestCase):
curses.use_env(False)
curses.use_env(True)
+ def test_error(self):
+ self.assertIsSubclass(curses.error, Exception)
+
def test_create_windows(self):
win = curses.newwin(5, 10)
self.assertEqual(win.getbegyx(), (0, 0))
@@ -187,6 +185,14 @@ class TestCurses(unittest.TestCase):
self.assertEqual(win3.getparyx(), (2, 1))
self.assertEqual(win3.getmaxyx(), (6, 11))
+ def test_subwindows_references(self):
+ win = curses.newwin(5, 10)
+ win2 = win.subwin(3, 7)
+ del win
+ gc_collect()
+ del win2
+ gc_collect()
+
def test_move_cursor(self):
stdscr = self.stdscr
win = stdscr.subwin(10, 15, 2, 5)
@@ -948,8 +954,6 @@ class TestCurses(unittest.TestCase):
@requires_colors
def test_pair_content(self):
- if not hasattr(curses, 'use_default_colors'):
- self.assertIn(curses.pair_content(0), DEFAULT_PAIR_CONTENTS)
curses.pair_content(0)
maxpair = self.get_pair_limit() - 1
if maxpair > 0:
@@ -994,13 +998,27 @@ class TestCurses(unittest.TestCase):
@requires_curses_func('use_default_colors')
@requires_colors
def test_use_default_colors(self):
- old = curses.pair_content(0)
try:
curses.use_default_colors()
except curses.error:
self.skipTest('cannot change color (use_default_colors() failed)')
self.assertEqual(curses.pair_content(0), (-1, -1))
- self.assertIn(old, DEFAULT_PAIR_CONTENTS)
+
+ @requires_curses_func('assume_default_colors')
+ @requires_colors
+ def test_assume_default_colors(self):
+ try:
+ curses.assume_default_colors(-1, -1)
+ except curses.error:
+ self.skipTest('cannot change color (assume_default_colors() failed)')
+ self.assertEqual(curses.pair_content(0), (-1, -1))
+ curses.assume_default_colors(curses.COLOR_YELLOW, curses.COLOR_BLUE)
+ self.assertEqual(curses.pair_content(0), (curses.COLOR_YELLOW, curses.COLOR_BLUE))
+ curses.assume_default_colors(curses.COLOR_RED, -1)
+ self.assertEqual(curses.pair_content(0), (curses.COLOR_RED, -1))
+ curses.assume_default_colors(-1, curses.COLOR_GREEN)
+ self.assertEqual(curses.pair_content(0), (-1, curses.COLOR_GREEN))
+ curses.assume_default_colors(-1, -1)
def test_keyname(self):
# TODO: key_name()
@@ -1242,7 +1260,7 @@ class TestAscii(unittest.TestCase):
def test_controlnames(self):
for name in curses.ascii.controlnames:
- self.assertTrue(hasattr(curses.ascii, name), name)
+ self.assertHasAttr(curses.ascii, name)
def test_ctypes(self):
def check(func, expected):
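
The new test_assume_default_colors case above exercises curses.assume_default_colors(),
which redefines what color pair 0 means. A minimal usage sketch, assuming a terminal and
curses build that accept default colors (-1):

    import curses

    def main(stdscr):
        # Pair 0 normally means "white on black"; remap it explicitly.
        curses.assume_default_colors(curses.COLOR_YELLOW, curses.COLOR_BLUE)
        assert curses.pair_content(0) == (curses.COLOR_YELLOW, curses.COLOR_BLUE)
        curses.assume_default_colors(-1, -1)  # back to the terminal's defaults

    curses.wrapper(main)
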
diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py
index 99fefb57fd0..e98a8f284ce 100644
--- a/Lib/test/test_dataclasses/__init__.py
+++ b/Lib/test/test_dataclasses/__init__.py
@@ -5,6 +5,7 @@
from dataclasses import *
import abc
+import annotationlib
import io
import pickle
import inspect
@@ -12,6 +13,7 @@ import builtins
import types
import weakref
import traceback
+import sys
import textwrap
import unittest
from unittest.mock import Mock
@@ -25,6 +27,7 @@ import typing # Needed for the string "typing.ClassVar[int]" to work as an
import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
from test import support
+from test.support import import_helper
# Just any custom exception we can catch.
class CustomError(Exception): pass
@@ -117,7 +120,7 @@ class TestCase(unittest.TestCase):
for param in inspect.signature(dataclass).parameters:
if param == 'cls':
continue
- self.assertTrue(hasattr(Some.__dataclass_params__, param), msg=param)
+ self.assertHasAttr(Some.__dataclass_params__, param)
def test_named_init_params(self):
@dataclass
@@ -668,7 +671,7 @@ class TestCase(unittest.TestCase):
self.assertEqual(the_fields[0].name, 'x')
self.assertEqual(the_fields[0].type, int)
- self.assertFalse(hasattr(C, 'x'))
+ self.assertNotHasAttr(C, 'x')
self.assertTrue (the_fields[0].init)
self.assertTrue (the_fields[0].repr)
self.assertEqual(the_fields[1].name, 'y')
@@ -678,7 +681,7 @@ class TestCase(unittest.TestCase):
self.assertTrue (the_fields[1].repr)
self.assertEqual(the_fields[2].name, 'z')
self.assertEqual(the_fields[2].type, str)
- self.assertFalse(hasattr(C, 'z'))
+ self.assertNotHasAttr(C, 'z')
self.assertTrue (the_fields[2].init)
self.assertFalse(the_fields[2].repr)
@@ -729,8 +732,8 @@ class TestCase(unittest.TestCase):
z: object = default
t: int = field(default=100)
- self.assertFalse(hasattr(C, 'x'))
- self.assertFalse(hasattr(C, 'y'))
+ self.assertNotHasAttr(C, 'x')
+ self.assertNotHasAttr(C, 'y')
self.assertIs (C.z, default)
self.assertEqual(C.t, 100)
@@ -2909,10 +2912,10 @@ class TestFrozen(unittest.TestCase):
pass
c = C()
- self.assertFalse(hasattr(c, 'i'))
+ self.assertNotHasAttr(c, 'i')
with self.assertRaises(FrozenInstanceError):
c.i = 5
- self.assertFalse(hasattr(c, 'i'))
+ self.assertNotHasAttr(c, 'i')
with self.assertRaises(FrozenInstanceError):
del c.i
@@ -3141,7 +3144,7 @@ class TestFrozen(unittest.TestCase):
del s.y
self.assertEqual(s.y, 10)
del s.cached
- self.assertFalse(hasattr(s, 'cached'))
+ self.assertNotHasAttr(s, 'cached')
with self.assertRaises(AttributeError) as cm:
del s.cached
self.assertNotIsInstance(cm.exception, FrozenInstanceError)
@@ -3155,12 +3158,12 @@ class TestFrozen(unittest.TestCase):
pass
s = S()
- self.assertFalse(hasattr(s, 'x'))
+ self.assertNotHasAttr(s, 'x')
s.x = 5
self.assertEqual(s.x, 5)
del s.x
- self.assertFalse(hasattr(s, 'x'))
+ self.assertNotHasAttr(s, 'x')
with self.assertRaises(AttributeError) as cm:
del s.x
self.assertNotIsInstance(cm.exception, FrozenInstanceError)
@@ -3390,8 +3393,8 @@ class TestSlots(unittest.TestCase):
B = dataclass(A, slots=True)
self.assertIsNot(A, B)
- self.assertFalse(hasattr(A, "__slots__"))
- self.assertTrue(hasattr(B, "__slots__"))
+ self.assertNotHasAttr(A, "__slots__")
+ self.assertHasAttr(B, "__slots__")
# Can't be local to test_frozen_pickle.
@dataclass(frozen=True, slots=True)
@@ -3754,7 +3757,6 @@ class TestSlots(unittest.TestCase):
@support.cpython_only
def test_dataclass_slot_dict_ctype(self):
# https://github.com/python/cpython/issues/123935
- from test.support import import_helper
# Skips test if `_testcapi` is not present:
_testcapi = import_helper.import_module('_testcapi')
@@ -4246,16 +4248,56 @@ class TestMakeDataclass(unittest.TestCase):
C = make_dataclass('Point', ['x', 'y', 'z'])
c = C(1, 2, 3)
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
- self.assertEqual(C.__annotations__, {'x': 'typing.Any',
- 'y': 'typing.Any',
- 'z': 'typing.Any'})
+ self.assertEqual(C.__annotations__, {'x': typing.Any,
+ 'y': typing.Any,
+ 'z': typing.Any})
C = make_dataclass('Point', ['x', ('y', int), 'z'])
c = C(1, 2, 3)
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
- self.assertEqual(C.__annotations__, {'x': 'typing.Any',
+ self.assertEqual(C.__annotations__, {'x': typing.Any,
'y': int,
- 'z': 'typing.Any'})
+ 'z': typing.Any})
+
+ def test_no_types_get_annotations(self):
+ C = make_dataclass('C', ['x', ('y', int), 'z'])
+
+ self.assertEqual(
+ annotationlib.get_annotations(C, format=annotationlib.Format.VALUE),
+ {'x': typing.Any, 'y': int, 'z': typing.Any},
+ )
+ self.assertEqual(
+ annotationlib.get_annotations(
+ C, format=annotationlib.Format.FORWARDREF),
+ {'x': typing.Any, 'y': int, 'z': typing.Any},
+ )
+ self.assertEqual(
+ annotationlib.get_annotations(
+ C, format=annotationlib.Format.STRING),
+ {'x': 'typing.Any', 'y': 'int', 'z': 'typing.Any'},
+ )
+
+ def test_no_types_no_typing_import(self):
+ with import_helper.CleanImport('typing'):
+ self.assertNotIn('typing', sys.modules)
+ C = make_dataclass('C', ['x', ('y', int)])
+
+ self.assertNotIn('typing', sys.modules)
+ self.assertEqual(
+ C.__annotate__(annotationlib.Format.FORWARDREF),
+ {
+ 'x': annotationlib.ForwardRef('Any', module='typing'),
+ 'y': int,
+ },
+ )
+ self.assertNotIn('typing', sys.modules)
+
+ for field in fields(C):
+ if field.name == "x":
+ self.assertEqual(field.type, annotationlib.ForwardRef('Any', module='typing'))
+ else:
+ self.assertEqual(field.name, "y")
+ self.assertIs(field.type, int)
def test_module_attr(self):
self.assertEqual(ByMakeDataClass.__module__, __name__)
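
The make_dataclass changes above reflect that untyped fields are now annotated with the
typing.Any object (resolved lazily, without importing typing up front) rather than with
the string 'typing.Any'. A small sketch of what the new tests observe, assuming the
annotationlib module is available:

    import annotationlib
    import typing
    from dataclasses import make_dataclass

    C = make_dataclass('C', ['x', ('y', int)])

    # The default VALUE format yields the real typing.Any object.
    assert annotationlib.get_annotations(C) == {'x': typing.Any, 'y': int}

    # The STRING format renders the annotations back to source text.
    strings = annotationlib.get_annotations(C, format=annotationlib.Format.STRING)
    assert strings == {'x': 'typing.Any', 'y': 'int'}
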
diff --git a/Lib/test/test_dbm.py b/Lib/test/test_dbm.py
index 4be7c5649da..7e8d78b8940 100644
--- a/Lib/test/test_dbm.py
+++ b/Lib/test/test_dbm.py
@@ -66,7 +66,7 @@ class AnyDBMTestCase:
return keys
def test_error(self):
- self.assertTrue(issubclass(self.module.error, OSError))
+ self.assertIsSubclass(self.module.error, OSError)
def test_anydbm_not_existing(self):
self.assertRaises(dbm.error, dbm.open, _fname)
@@ -135,6 +135,67 @@ class AnyDBMTestCase:
assert(f[key] == b"Python:")
f.close()
+ def test_anydbm_readonly_reorganize(self):
+ self.init_db()
+ with dbm.open(_fname, 'r') as d:
+ # Early stopping.
+ if not hasattr(d, 'reorganize'):
+                self.skipTest("method reorganize not available in this dbm submodule")
+
+ self.assertRaises(dbm.error, lambda: d.reorganize())
+
+ def test_anydbm_reorganize_not_changed_content(self):
+ self.init_db()
+ with dbm.open(_fname, 'c') as d:
+ # Early stopping.
+ if not hasattr(d, 'reorganize'):
+                self.skipTest("method reorganize not available in this dbm submodule")
+
+ keys_before = sorted(d.keys())
+ values_before = [d[k] for k in keys_before]
+ d.reorganize()
+ keys_after = sorted(d.keys())
+ values_after = [d[k] for k in keys_before]
+ self.assertEqual(keys_before, keys_after)
+ self.assertEqual(values_before, values_after)
+
+ def test_anydbm_reorganize_decreased_size(self):
+
+ def _calculate_db_size(db_path):
+ if os.path.isfile(db_path):
+ return os.path.getsize(db_path)
+ total_size = 0
+ for root, _, filenames in os.walk(db_path):
+ for filename in filenames:
+ file_path = os.path.join(root, filename)
+ total_size += os.path.getsize(file_path)
+ return total_size
+
+        # This test requires relatively large databases to reliably show a difference in size before and after reorganizing.
+ with dbm.open(_fname, 'n') as f:
+ # Early stopping.
+ if not hasattr(f, 'reorganize'):
+                self.skipTest("method reorganize not available in this dbm submodule")
+
+ for k in self._dict:
+ f[k.encode('ascii')] = self._dict[k] * 100000
+ db_keys = list(f.keys())
+
+        # Make sure to calculate the size of the database only after the file is closed, so that its contents are flushed to disk.
+ size_before = _calculate_db_size(os.path.dirname(_fname))
+
+ # Delete some elements from the start of the database.
+ keys_to_delete = db_keys[:len(db_keys) // 2]
+ with dbm.open(_fname, 'c') as f:
+ for k in keys_to_delete:
+ del f[k]
+ f.reorganize()
+
+        # Make sure to calculate the size of the database only after the file is closed, so that its contents are flushed to disk.
+ size_after = _calculate_db_size(os.path.dirname(_fname))
+
+ self.assertLess(size_after, size_before)
+
def test_open_with_bytes(self):
dbm.open(os.fsencode(_fname), "c").close()
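
The reorganize() tests added above boil down to a simple contract: backends that provide
the method compact the database in place, refuse to do so in read-only mode, and leave
the stored data unchanged. A hedged usage sketch (the filename is illustrative):

    import dbm

    with dbm.open('example.db', 'c') as db:
        for i in range(1000):
            db[str(i).encode()] = b'x' * 100
        for i in range(500):
            del db[str(i).encode()]
        if hasattr(db, 'reorganize'):   # not every dbm backend provides it
            db.reorganize()             # reclaim the space freed by the deletions
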
diff --git a/Lib/test/test_dbm_gnu.py b/Lib/test/test_dbm_gnu.py
index 66268c42a30..e0b988b7b95 100644
--- a/Lib/test/test_dbm_gnu.py
+++ b/Lib/test/test_dbm_gnu.py
@@ -74,12 +74,12 @@ class TestGdbm(unittest.TestCase):
# Test the flag parameter open() by trying all supported flag modes.
all = set(gdbm.open_flags)
# Test standard flags (presumably "crwn").
- modes = all - set('fsu')
+ modes = all - set('fsum')
for mode in sorted(modes): # put "c" mode first
self.g = gdbm.open(filename, mode)
self.g.close()
- # Test additional flags (presumably "fsu").
+ # Test additional flags (presumably "fsum").
flags = all - set('crwn')
for mode in modes:
for flag in flags:
@@ -217,6 +217,29 @@ class TestGdbm(unittest.TestCase):
create_empty_file(os.path.join(d, 'test'))
self.assertRaises(gdbm.error, gdbm.open, filename, 'r')
+ @unittest.skipUnless('m' in gdbm.open_flags, "requires 'm' in open_flags")
+ def test_nommap_no_crash(self):
+ self.g = g = gdbm.open(filename, 'nm')
+ os.truncate(filename, 0)
+
+ g.get(b'a', b'c')
+ g.keys()
+ g.firstkey()
+ g.nextkey(b'a')
+ with self.assertRaises(KeyError):
+ g[b'a']
+ with self.assertRaises(gdbm.error):
+ len(g)
+
+ with self.assertRaises(gdbm.error):
+ g[b'a'] = b'c'
+ with self.assertRaises(gdbm.error):
+ del g[b'a']
+ with self.assertRaises(gdbm.error):
+ g.setdefault(b'a', b'c')
+ with self.assertRaises(gdbm.error):
+ g.reorganize()
+
if __name__ == '__main__':
unittest.main()
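
The dbm.gnu hunk adds an 'm' open flag (open without memory mapping) next to the existing
'f', 's' and 'u' modifiers, and test_nommap_no_crash checks that operating on a truncated
file then raises gdbm.error instead of crashing. A minimal sketch of the flag itself,
guarded the same way the test is (the filename is illustrative):

    import dbm.gnu as gdbm

    if 'm' in gdbm.open_flags:                        # only if this build supports it
        with gdbm.open('example.gdbm', 'cm') as db:   # create, without mmap
            db[b'key'] = b'value'
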
diff --git a/Lib/test/test_dbm_sqlite3.py b/Lib/test/test_dbm_sqlite3.py
index 2e1f2d32924..9216da8a63f 100644
--- a/Lib/test/test_dbm_sqlite3.py
+++ b/Lib/test/test_dbm_sqlite3.py
@@ -36,7 +36,7 @@ class URI(unittest.TestCase):
)
for path, normalized in dataset:
with self.subTest(path=path, normalized=normalized):
- self.assertTrue(_normalize_uri(path).endswith(normalized))
+ self.assertEndsWith(_normalize_uri(path), normalized)
@unittest.skipUnless(sys.platform == "win32", "requires Windows")
def test_uri_windows(self):
@@ -55,7 +55,7 @@ class URI(unittest.TestCase):
with self.subTest(path=path, normalized=normalized):
if not Path(path).is_absolute():
self.skipTest(f"skipping relative path: {path!r}")
- self.assertTrue(_normalize_uri(path).endswith(normalized))
+ self.assertEndsWith(_normalize_uri(path), normalized)
class ReadOnly(_SQLiteDbmTests):
diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py
index 9e298401dc3..ef64b878805 100644
--- a/Lib/test/test_decimal.py
+++ b/Lib/test/test_decimal.py
@@ -28,7 +28,6 @@ import logging
import math
import os, sys
import operator
-import warnings
import pickle, copy
import unittest
import numbers
@@ -982,6 +981,7 @@ class FormatTest:
('.0f', '0e-2', '0'),
('.0f', '3.14159265', '3'),
('.1f', '3.14159265', '3.1'),
+ ('.01f', '3.14159265', '3.1'), # leading zero in precision
('.4f', '3.14159265', '3.1416'),
('.6f', '3.14159265', '3.141593'),
('.7f', '3.14159265', '3.1415926'), # round-half-even!
@@ -1067,6 +1067,7 @@ class FormatTest:
('8,', '123456', ' 123,456'),
('08,', '123456', '0,123,456'), # special case: extra 0 needed
('+08,', '123456', '+123,456'), # but not if there's a sign
+ ('008,', '123456', '0,123,456'), # leading zero in width
(' 08,', '123456', ' 123,456'),
('08,', '-123456', '-123,456'),
('+09,', '123456', '+0,123,456'),
diff --git a/Lib/test/test_deque.py b/Lib/test/test_deque.py
index 4679f297fd7..4e1a489205a 100644
--- a/Lib/test/test_deque.py
+++ b/Lib/test/test_deque.py
@@ -838,7 +838,7 @@ class TestSubclass(unittest.TestCase):
self.assertEqual(list(d), list(e))
self.assertEqual(e.x, d.x)
self.assertEqual(e.z, d.z)
- self.assertFalse(hasattr(e, 'y'))
+ self.assertNotHasAttr(e, 'y')
def test_pickle_recursive(self):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py
index 8e9d44a583c..f6ec2cf5ce8 100644
--- a/Lib/test/test_descr.py
+++ b/Lib/test/test_descr.py
@@ -409,7 +409,7 @@ class ClassPropertiesAndMethods(unittest.TestCase):
def test_python_dicts(self):
# Testing Python subclass of dict...
- self.assertTrue(issubclass(dict, dict))
+ self.assertIsSubclass(dict, dict)
self.assertIsInstance({}, dict)
d = dict()
self.assertEqual(d, {})
@@ -433,7 +433,7 @@ class ClassPropertiesAndMethods(unittest.TestCase):
self.state = state
def getstate(self):
return self.state
- self.assertTrue(issubclass(C, dict))
+ self.assertIsSubclass(C, dict)
a1 = C(12)
self.assertEqual(a1.state, 12)
a2 = C(foo=1, bar=2)
@@ -1048,15 +1048,15 @@ class ClassPropertiesAndMethods(unittest.TestCase):
m = types.ModuleType("m")
self.assertTrue(m.__class__ is types.ModuleType)
- self.assertFalse(hasattr(m, "a"))
+ self.assertNotHasAttr(m, "a")
m.__class__ = SubType
self.assertTrue(m.__class__ is SubType)
- self.assertTrue(hasattr(m, "a"))
+ self.assertHasAttr(m, "a")
m.__class__ = types.ModuleType
self.assertTrue(m.__class__ is types.ModuleType)
- self.assertFalse(hasattr(m, "a"))
+ self.assertNotHasAttr(m, "a")
# Make sure that builtin immutable objects don't support __class__
# assignment, because the object instances may be interned.
@@ -1589,7 +1589,7 @@ class ClassPropertiesAndMethods(unittest.TestCase):
cm = classmethod(f)
cm_dict = {'__doc__': (
"f docstring"
- if support.HAVE_DOCSTRINGS
+ if support.HAVE_PY_DOCSTRINGS
else None
),
'__module__': __name__,
@@ -1780,7 +1780,7 @@ class ClassPropertiesAndMethods(unittest.TestCase):
class E: # *not* subclassing from C
foo = C.foo
self.assertEqual(E().foo.__func__, C.foo) # i.e., unbound
- self.assertTrue(repr(C.foo.__get__(C())).startswith("<bound method "))
+ self.assertStartsWith(repr(C.foo.__get__(C())), "<bound method ")
def test_compattr(self):
# Testing computed attributes...
@@ -2058,7 +2058,7 @@ class ClassPropertiesAndMethods(unittest.TestCase):
class E(object):
foo = C.foo
self.assertEqual(E().foo.__func__, C.foo) # i.e., unbound
- self.assertTrue(repr(C.foo.__get__(C(1))).startswith("<bound method "))
+ self.assertStartsWith(repr(C.foo.__get__(C(1))), "<bound method ")
@support.impl_detail("testing error message from implementation")
def test_methods_in_c(self):
@@ -3943,6 +3943,7 @@ class ClassPropertiesAndMethods(unittest.TestCase):
del C.__del__
@unittest.skipIf(support.is_emscripten, "Seems to works in Pyodide?")
+ @support.skip_wasi_stack_overflow()
def test_slots_trash(self):
# Testing slot trash...
# Deallocating deeply nested slotted trash caused stack overflows
@@ -4113,6 +4114,34 @@ class ClassPropertiesAndMethods(unittest.TestCase):
else:
self.fail("shouldn't be able to create inheritance cycles")
+ def test_assign_bases_many_subclasses(self):
+ # This is intended to check that typeobject.c:queue_slot_update() can
+ # handle updating many subclasses when a slot method is re-assigned.
+ class A:
+ x = 'hello'
+ def __call__(self):
+ return 123
+ def __getitem__(self, index):
+ return None
+
+ class X:
+ x = 'bye'
+
+ class B(A):
+ pass
+
+ subclasses = []
+ for i in range(1000):
+ sc = type(f'Sub{i}', (B,), {})
+ subclasses.append(sc)
+
+ self.assertEqual(subclasses[0]()(), 123)
+ self.assertEqual(subclasses[0]().x, 'hello')
+ B.__bases__ = (X,)
+ with self.assertRaises(TypeError):
+ subclasses[0]()()
+ self.assertEqual(subclasses[0]().x, 'bye')
+
def test_builtin_bases(self):
# Make sure all the builtin types can have their base queried without
# segfaulting. See issue #5787.
@@ -4523,6 +4552,7 @@ class ClassPropertiesAndMethods(unittest.TestCase):
del o
@support.skip_wasi_stack_overflow()
+ @support.skip_emscripten_stack_overflow()
@support.requires_resource('cpu')
def test_wrapper_segfault(self):
# SF 927248: deeply nested wrappers could cause stack overflow
@@ -4867,6 +4897,7 @@ class ClassPropertiesAndMethods(unittest.TestCase):
deque.append(thing, thing)
@support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_repr_as_str(self):
# Issue #11603: crash or infinite loop when rebinding __str__ as
# __repr__.
@@ -5194,8 +5225,8 @@ class DictProxyTests(unittest.TestCase):
# We can't blindly compare with the repr of another dict as ordering
# of keys and values is arbitrary and may differ.
r = repr(self.C.__dict__)
- self.assertTrue(r.startswith('mappingproxy('), r)
- self.assertTrue(r.endswith(')'), r)
+ self.assertStartsWith(r, 'mappingproxy(')
+ self.assertEndsWith(r, ')')
for k, v in self.C.__dict__.items():
self.assertIn('{!r}: {!r}'.format(k, v), r)
diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py
index 9485ef2889f..60c62430370 100644
--- a/Lib/test/test_dict.py
+++ b/Lib/test/test_dict.py
@@ -266,6 +266,63 @@ class DictTest(unittest.TestCase):
self.assertRaises(ValueError, {}.update, [(1, 2, 3)])
+ def test_update_type_error(self):
+ with self.assertRaises(TypeError) as cm:
+ {}.update([object() for _ in range(3)])
+
+ self.assertEqual(str(cm.exception), "object is not iterable")
+ self.assertEqual(
+ cm.exception.__notes__,
+ ['Cannot convert dictionary update sequence element #0 to a sequence'],
+ )
+
+ def badgen():
+ yield "key"
+ raise TypeError("oops")
+ yield "value"
+
+ with self.assertRaises(TypeError) as cm:
+ dict([badgen() for _ in range(3)])
+
+ self.assertEqual(str(cm.exception), "oops")
+ self.assertEqual(
+ cm.exception.__notes__,
+ ['Cannot convert dictionary update sequence element #0 to a sequence'],
+ )
+
+ def test_update_shared_keys(self):
+ class MyClass: pass
+
+ # Subclass str to enable us to create an object during the
+ # dict.update() call.
+ class MyStr(str):
+ def __hash__(self):
+ return super().__hash__()
+
+ def __eq__(self, other):
+ # Create an object that shares the same PyDictKeysObject as
+ # obj.__dict__.
+ obj2 = MyClass()
+ obj2.a = "a"
+ obj2.b = "b"
+ obj2.c = "c"
+ return super().__eq__(other)
+
+ obj = MyClass()
+ obj.a = "a"
+ obj.b = "b"
+
+ x = {}
+ x[MyStr("a")] = MyStr("a")
+
+ # gh-132617: this previously raised "dict mutated during update" error
+ x.update(obj.__dict__)
+
+ self.assertEqual(x, {
+ MyStr("a"): "a",
+ "b": "b",
+ })
+
def test_fromkeys(self):
self.assertEqual(dict.fromkeys('abc'), {'a':None, 'b':None, 'c':None})
d = {}
@@ -313,17 +370,34 @@ class DictTest(unittest.TestCase):
self.assertRaises(Exc, baddict2.fromkeys, [1])
# test fast path for dictionary inputs
+ res = dict(zip(range(6), [0]*6))
d = dict(zip(range(6), range(6)))
- self.assertEqual(dict.fromkeys(d, 0), dict(zip(range(6), [0]*6)))
-
+ self.assertEqual(dict.fromkeys(d, 0), res)
+ # test fast path for set inputs
+ d = set(range(6))
+ self.assertEqual(dict.fromkeys(d, 0), res)
+ # test slow path for other iterable inputs
+ d = list(range(6))
+ self.assertEqual(dict.fromkeys(d, 0), res)
+
+ # test fast path when object's constructor returns large non-empty dict
class baddict3(dict):
def __new__(cls):
return d
- d = {i : i for i in range(10)}
+ d = {i : i for i in range(1000)}
res = d.copy()
res.update(a=None, b=None, c=None)
self.assertEqual(baddict3.fromkeys({"a", "b", "c"}), res)
+ # test slow path when object is a proper subclass of dict
+ class baddict4(dict):
+ def __init__(self):
+ dict.__init__(self, d)
+ d = {i : i for i in range(1000)}
+ res = d.copy()
+ res.update(a=None, b=None, c=None)
+ self.assertEqual(baddict4.fromkeys({"a", "b", "c"}), res)
+
def test_copy(self):
d = {1: 1, 2: 2, 3: 3}
self.assertIsNot(d.copy(), d)
@@ -745,8 +819,8 @@ class DictTest(unittest.TestCase):
def test_missing(self):
# Make sure dict doesn't have a __missing__ method
- self.assertFalse(hasattr(dict, "__missing__"))
- self.assertFalse(hasattr({}, "__missing__"))
+ self.assertNotHasAttr(dict, "__missing__")
+ self.assertNotHasAttr({}, "__missing__")
# Test several cases:
# (D) subclass defines __missing__ method returning a value
# (E) subclass defines __missing__ method raising RuntimeError
@@ -997,10 +1071,8 @@ class DictTest(unittest.TestCase):
a = C()
a.x = 1
d = a.__dict__
- before_resize = sys.getsizeof(d)
d[2] = 2 # split table is resized to a generic combined table
- self.assertGreater(sys.getsizeof(d), before_resize)
self.assertEqual(list(d), ['x', 2])
def test_iterator_pickling(self):
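
test_update_type_error above documents a behavior worth calling out: when dict.update()
fails on a sequence element, the element index is reported as a __notes__ entry on the
exception rather than folded into the message. A quick sketch of what the test asserts:

    try:
        {}.update([object()])   # element #0 is not a (key, value) pair
    except TypeError as exc:
        print(exc)              # object is not iterable
        print(exc.__notes__)
        # ['Cannot convert dictionary update sequence element #0 to a sequence']
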
diff --git a/Lib/test/test_difflib.py b/Lib/test/test_difflib.py
index 9e217249be7..6ac584a08d1 100644
--- a/Lib/test/test_difflib.py
+++ b/Lib/test/test_difflib.py
@@ -255,21 +255,21 @@ class TestSFpatches(unittest.TestCase):
html_diff = difflib.HtmlDiff()
output = html_diff.make_file(patch914575_from1.splitlines(),
patch914575_to1.splitlines())
- self.assertIn('content="text/html; charset=utf-8"', output)
+ self.assertIn('charset="utf-8"', output)
def test_make_file_iso88591_charset(self):
html_diff = difflib.HtmlDiff()
output = html_diff.make_file(patch914575_from1.splitlines(),
patch914575_to1.splitlines(),
charset='iso-8859-1')
- self.assertIn('content="text/html; charset=iso-8859-1"', output)
+ self.assertIn('charset="iso-8859-1"', output)
def test_make_file_usascii_charset_with_nonascii_input(self):
html_diff = difflib.HtmlDiff()
output = html_diff.make_file(patch914575_nonascii_from1.splitlines(),
patch914575_nonascii_to1.splitlines(),
charset='us-ascii')
- self.assertIn('content="text/html; charset=us-ascii"', output)
+ self.assertIn('charset="us-ascii"', output)
self.assertIn('&#305;mpl&#305;c&#305;t', output)
class TestDiffer(unittest.TestCase):
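For reference, and not part of the patch itself: a minimal sketch of what the updated assertions look for. HtmlDiff.make_file() now emits an HTML5-style meta charset declaration, so the tests match on charset="utf-8" rather than the old http-equiv form.

    # Sketch only; the charset keyword argument is unchanged, only the emitted markup differs.
    import difflib
    html = difflib.HtmlDiff().make_file(["a\n"], ["b\n"], charset="utf-8")
    assert 'charset="utf-8"' in html  # new-style <meta charset="utf-8"> declaration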
diff --git a/Lib/test/test_difflib_expect.html b/Lib/test/test_difflib_expect.html
index 9f33a9e9c9c..2346a6f9f8d 100644
--- a/Lib/test/test_difflib_expect.html
+++ b/Lib/test/test_difflib_expect.html
@@ -1,22 +1,42 @@
-<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
- "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
-
-<html>
-
+<!DOCTYPE html>
+<html lang="en">
<head>
- <meta http-equiv="Content-Type"
- content="text/html; charset=utf-8" />
- <title></title>
- <style type="text/css">
+ <meta charset="utf-8">
+ <meta name="viewport" content="width=device-width, initial-scale=1">
+ <title>Diff comparison</title>
+ <style>
:root {color-scheme: light dark}
- table.diff {font-family: Menlo, Consolas, Monaco, Liberation Mono, Lucida Console, monospace; border:medium}
- .diff_header {background-color:#e0e0e0}
- td.diff_header {text-align:right}
- .diff_next {background-color:#c0c0c0}
+ table.diff {
+ font-family: Menlo, Consolas, Monaco, Liberation Mono, Lucida Console, monospace;
+ border: medium;
+ }
+ .diff_header {
+ background-color: #e0e0e0;
+ font-weight: bold;
+ }
+ td.diff_header {
+ text-align: right;
+ padding: 0 8px;
+ }
+ .diff_next {
+ background-color: #c0c0c0;
+ padding: 4px 0;
+ }
.diff_add {background-color:palegreen}
.diff_chg {background-color:#ffff77}
.diff_sub {background-color:#ffaaaa}
+ table.diff[summary="Legends"] {
+ margin-top: 20px;
+ border: 1px solid #ccc;
+ }
+ table.diff[summary="Legends"] th {
+ background-color: #e0e0e0;
+ padding: 4px 8px;
+ }
+ table.diff[summary="Legends"] td {
+ padding: 4px 8px;
+ }
@media (prefers-color-scheme: dark) {
.diff_header {background-color:#666}
@@ -24,6 +44,8 @@
.diff_add {background-color:darkgreen}
.diff_chg {background-color:#847415}
.diff_sub {background-color:darkred}
+ table.diff[summary="Legends"] {border-color:#555}
+ table.diff[summary="Legends"] th{background-color:#666}
}
</style>
</head>
diff --git a/Lib/test/test_dis.py b/Lib/test/test_dis.py
index f2586fcee57..355990ed58e 100644
--- a/Lib/test/test_dis.py
+++ b/Lib/test/test_dis.py
@@ -606,7 +606,7 @@ dis_asyncwith = """\
POP_TOP
L1: RESUME 0
-%4d LOAD_FAST_BORROW 0 (c)
+%4d LOAD_FAST 0 (c)
COPY 1
LOAD_SPECIAL 3 (__aexit__)
SWAP 2
@@ -851,7 +851,7 @@ Disassembly of <code object <genexpr> at 0x..., file "%s", line %d>:
%4d RETURN_GENERATOR
POP_TOP
L1: RESUME 0
- LOAD_FAST_BORROW 0 (.0)
+ LOAD_FAST 0 (.0)
GET_ITER
L2: FOR_ITER 14 (to L3)
STORE_FAST 1 (z)
@@ -902,7 +902,7 @@ dis_loop_test_quickened_code = """\
%3d RESUME_CHECK 0
%3d BUILD_LIST 0
- LOAD_CONST_MORTAL 2 ((1, 2, 3))
+ LOAD_CONST 2 ((1, 2, 3))
LIST_EXTEND 1
LOAD_SMALL_INT 3
BINARY_OP 5 (*)
@@ -918,7 +918,7 @@ dis_loop_test_quickened_code = """\
%3d L2: END_FOR
POP_ITER
- LOAD_CONST_IMMORTAL 1 (None)
+ LOAD_CONST 1 (None)
RETURN_VALUE
""" % (loop_test.__code__.co_firstlineno,
loop_test.__code__.co_firstlineno + 1,
@@ -1304,7 +1304,7 @@ class DisTests(DisTestBase):
load_attr_quicken = """\
0 RESUME_CHECK 0
- 1 LOAD_CONST_IMMORTAL 0 ('a')
+ 1 LOAD_CONST 0 ('a')
LOAD_ATTR_SLOT 0 (__class__)
RETURN_VALUE
"""
@@ -1336,7 +1336,7 @@ class DisTests(DisTestBase):
# Loop can trigger a quicken where the loop is located
self.code_quicken(loop_test)
got = self.get_disassembly(loop_test, adaptive=True)
- jit = import_helper.import_module("_testinternalcapi").jit_enabled()
+ jit = sys._jit.is_enabled()
expected = dis_loop_test_quickened_code.format("JIT" if jit else "NO_JIT")
self.do_disassembly_compare(got, expected)
@@ -1821,7 +1821,7 @@ expected_opinfo_jumpy = [
make_inst(opname='LOAD_SMALL_INT', arg=10, argval=10, argrepr='', offset=12, start_offset=12, starts_line=False, line_number=3),
make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=14, start_offset=14, starts_line=False, line_number=3, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
make_inst(opname='GET_ITER', arg=None, argval=None, argrepr='', offset=22, start_offset=22, starts_line=False, line_number=3),
- make_inst(opname='FOR_ITER', arg=32, argval=92, argrepr='to L4', offset=24, start_offset=24, starts_line=False, line_number=3, label=1, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='FOR_ITER', arg=33, argval=94, argrepr='to L4', offset=24, start_offset=24, starts_line=False, line_number=3, label=1, cache_info=[('counter', 1, b'\x00\x00')]),
make_inst(opname='STORE_FAST', arg=0, argval='i', argrepr='i', offset=28, start_offset=28, starts_line=False, line_number=3),
make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=30, start_offset=30, starts_line=True, line_number=4, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=40, start_offset=40, starts_line=False, line_number=4),
@@ -1840,110 +1840,111 @@ expected_opinfo_jumpy = [
make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=82, start_offset=82, starts_line=False, line_number=7),
make_inst(opname='JUMP_BACKWARD', arg=32, argval=24, argrepr='to L1', offset=84, start_offset=84, starts_line=False, line_number=7, cache_info=[('counter', 1, b'\x00\x00')]),
make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=88, start_offset=88, starts_line=True, line_number=8, label=3),
- make_inst(opname='JUMP_FORWARD', arg=13, argval=118, argrepr='to L5', offset=90, start_offset=90, starts_line=False, line_number=8),
- make_inst(opname='END_FOR', arg=None, argval=None, argrepr='', offset=92, start_offset=92, starts_line=True, line_number=3, label=4),
- make_inst(opname='POP_ITER', arg=None, argval=None, argrepr='', offset=94, start_offset=94, starts_line=False, line_number=3),
- make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=96, start_offset=96, starts_line=True, line_number=10, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
- make_inst(opname='LOAD_CONST', arg=1, argval='I can haz else clause?', argrepr="'I can haz else clause?'", offset=106, start_offset=106, starts_line=False, line_number=10),
- make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=108, start_offset=108, starts_line=False, line_number=10, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=116, start_offset=116, starts_line=False, line_number=10),
- make_inst(opname='LOAD_FAST_CHECK', arg=0, argval='i', argrepr='i', offset=118, start_offset=118, starts_line=True, line_number=11, label=5),
- make_inst(opname='TO_BOOL', arg=None, argval=None, argrepr='', offset=120, start_offset=120, starts_line=False, line_number=11, cache_info=[('counter', 1, b'\x00\x00'), ('version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_JUMP_IF_FALSE', arg=40, argval=212, argrepr='to L8', offset=128, start_offset=128, starts_line=False, line_number=11, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=132, start_offset=132, starts_line=False, line_number=11),
- make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=134, start_offset=134, starts_line=True, line_number=12, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
- make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=144, start_offset=144, starts_line=False, line_number=12),
- make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=146, start_offset=146, starts_line=False, line_number=12, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=154, start_offset=154, starts_line=False, line_number=12),
- make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=156, start_offset=156, starts_line=True, line_number=13),
- make_inst(opname='LOAD_SMALL_INT', arg=1, argval=1, argrepr='', offset=158, start_offset=158, starts_line=False, line_number=13),
- make_inst(opname='BINARY_OP', arg=23, argval=23, argrepr='-=', offset=160, start_offset=160, starts_line=False, line_number=13, cache_info=[('counter', 1, b'\x00\x00'), ('descr', 4, b'\x00\x00\x00\x00\x00\x00\x00\x00')]),
- make_inst(opname='STORE_FAST', arg=0, argval='i', argrepr='i', offset=172, start_offset=172, starts_line=False, line_number=13),
- make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=174, start_offset=174, starts_line=True, line_number=14),
- make_inst(opname='LOAD_SMALL_INT', arg=6, argval=6, argrepr='', offset=176, start_offset=176, starts_line=False, line_number=14),
- make_inst(opname='COMPARE_OP', arg=148, argval='>', argrepr='bool(>)', offset=178, start_offset=178, starts_line=False, line_number=14, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='POP_JUMP_IF_FALSE', arg=3, argval=192, argrepr='to L6', offset=182, start_offset=182, starts_line=False, line_number=14, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=186, start_offset=186, starts_line=False, line_number=14),
- make_inst(opname='JUMP_BACKWARD', arg=37, argval=118, argrepr='to L5', offset=188, start_offset=188, starts_line=True, line_number=15, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=192, start_offset=192, starts_line=True, line_number=16, label=6),
- make_inst(opname='LOAD_SMALL_INT', arg=4, argval=4, argrepr='', offset=194, start_offset=194, starts_line=False, line_number=16),
- make_inst(opname='COMPARE_OP', arg=18, argval='<', argrepr='bool(<)', offset=196, start_offset=196, starts_line=False, line_number=16, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='POP_JUMP_IF_TRUE', arg=3, argval=210, argrepr='to L7', offset=200, start_offset=200, starts_line=False, line_number=16, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=204, start_offset=204, starts_line=False, line_number=16),
- make_inst(opname='JUMP_BACKWARD', arg=46, argval=118, argrepr='to L5', offset=206, start_offset=206, starts_line=False, line_number=16, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='JUMP_FORWARD', arg=11, argval=234, argrepr='to L9', offset=210, start_offset=210, starts_line=True, line_number=17, label=7),
- make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=212, start_offset=212, starts_line=True, line_number=19, label=8, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
- make_inst(opname='LOAD_CONST', arg=2, argval='Who let lolcatz into this test suite?', argrepr="'Who let lolcatz into this test suite?'", offset=222, start_offset=222, starts_line=False, line_number=19),
- make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=224, start_offset=224, starts_line=False, line_number=19, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=232, start_offset=232, starts_line=False, line_number=19),
- make_inst(opname='NOP', arg=None, argval=None, argrepr='', offset=234, start_offset=234, starts_line=True, line_number=20, label=9),
- make_inst(opname='LOAD_SMALL_INT', arg=1, argval=1, argrepr='', offset=236, start_offset=236, starts_line=True, line_number=21),
- make_inst(opname='LOAD_SMALL_INT', arg=0, argval=0, argrepr='', offset=238, start_offset=238, starts_line=False, line_number=21),
- make_inst(opname='BINARY_OP', arg=11, argval=11, argrepr='/', offset=240, start_offset=240, starts_line=False, line_number=21, cache_info=[('counter', 1, b'\x00\x00'), ('descr', 4, b'\x00\x00\x00\x00\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=252, start_offset=252, starts_line=False, line_number=21),
- make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=254, start_offset=254, starts_line=True, line_number=25),
- make_inst(opname='COPY', arg=1, argval=1, argrepr='', offset=256, start_offset=256, starts_line=False, line_number=25),
- make_inst(opname='LOAD_SPECIAL', arg=1, argval=1, argrepr='__exit__', offset=258, start_offset=258, starts_line=False, line_number=25),
- make_inst(opname='SWAP', arg=2, argval=2, argrepr='', offset=260, start_offset=260, starts_line=False, line_number=25),
- make_inst(opname='SWAP', arg=3, argval=3, argrepr='', offset=262, start_offset=262, starts_line=False, line_number=25),
- make_inst(opname='LOAD_SPECIAL', arg=0, argval=0, argrepr='__enter__', offset=264, start_offset=264, starts_line=False, line_number=25),
- make_inst(opname='CALL', arg=0, argval=0, argrepr='', offset=266, start_offset=266, starts_line=False, line_number=25, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='STORE_FAST', arg=1, argval='dodgy', argrepr='dodgy', offset=274, start_offset=274, starts_line=False, line_number=25),
- make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=276, start_offset=276, starts_line=True, line_number=26, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
- make_inst(opname='LOAD_CONST', arg=3, argval='Never reach this', argrepr="'Never reach this'", offset=286, start_offset=286, starts_line=False, line_number=26),
- make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=288, start_offset=288, starts_line=False, line_number=26, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=296, start_offset=296, starts_line=False, line_number=26),
- make_inst(opname='LOAD_CONST', arg=4, argval=None, argrepr='None', offset=298, start_offset=298, starts_line=True, line_number=25),
- make_inst(opname='LOAD_CONST', arg=4, argval=None, argrepr='None', offset=300, start_offset=300, starts_line=False, line_number=25),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=90, start_offset=90, starts_line=False, line_number=8),
+ make_inst(opname='JUMP_FORWARD', arg=13, argval=120, argrepr='to L5', offset=92, start_offset=92, starts_line=False, line_number=8),
+ make_inst(opname='END_FOR', arg=None, argval=None, argrepr='', offset=94, start_offset=94, starts_line=True, line_number=3, label=4),
+ make_inst(opname='POP_ITER', arg=None, argval=None, argrepr='', offset=96, start_offset=96, starts_line=False, line_number=3),
+ make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=98, start_offset=98, starts_line=True, line_number=10, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
+ make_inst(opname='LOAD_CONST', arg=1, argval='I can haz else clause?', argrepr="'I can haz else clause?'", offset=108, start_offset=108, starts_line=False, line_number=10),
+ make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=110, start_offset=110, starts_line=False, line_number=10, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=118, start_offset=118, starts_line=False, line_number=10),
+ make_inst(opname='LOAD_FAST_CHECK', arg=0, argval='i', argrepr='i', offset=120, start_offset=120, starts_line=True, line_number=11, label=5),
+ make_inst(opname='TO_BOOL', arg=None, argval=None, argrepr='', offset=122, start_offset=122, starts_line=False, line_number=11, cache_info=[('counter', 1, b'\x00\x00'), ('version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_JUMP_IF_FALSE', arg=40, argval=214, argrepr='to L8', offset=130, start_offset=130, starts_line=False, line_number=11, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=134, start_offset=134, starts_line=False, line_number=11),
+ make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=136, start_offset=136, starts_line=True, line_number=12, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
+ make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=146, start_offset=146, starts_line=False, line_number=12),
+ make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=148, start_offset=148, starts_line=False, line_number=12, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=156, start_offset=156, starts_line=False, line_number=12),
+ make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=158, start_offset=158, starts_line=True, line_number=13),
+ make_inst(opname='LOAD_SMALL_INT', arg=1, argval=1, argrepr='', offset=160, start_offset=160, starts_line=False, line_number=13),
+ make_inst(opname='BINARY_OP', arg=23, argval=23, argrepr='-=', offset=162, start_offset=162, starts_line=False, line_number=13, cache_info=[('counter', 1, b'\x00\x00'), ('descr', 4, b'\x00\x00\x00\x00\x00\x00\x00\x00')]),
+ make_inst(opname='STORE_FAST', arg=0, argval='i', argrepr='i', offset=174, start_offset=174, starts_line=False, line_number=13),
+ make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=176, start_offset=176, starts_line=True, line_number=14),
+ make_inst(opname='LOAD_SMALL_INT', arg=6, argval=6, argrepr='', offset=178, start_offset=178, starts_line=False, line_number=14),
+ make_inst(opname='COMPARE_OP', arg=148, argval='>', argrepr='bool(>)', offset=180, start_offset=180, starts_line=False, line_number=14, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='POP_JUMP_IF_FALSE', arg=3, argval=194, argrepr='to L6', offset=184, start_offset=184, starts_line=False, line_number=14, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=188, start_offset=188, starts_line=False, line_number=14),
+ make_inst(opname='JUMP_BACKWARD', arg=37, argval=120, argrepr='to L5', offset=190, start_offset=190, starts_line=True, line_number=15, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=194, start_offset=194, starts_line=True, line_number=16, label=6),
+ make_inst(opname='LOAD_SMALL_INT', arg=4, argval=4, argrepr='', offset=196, start_offset=196, starts_line=False, line_number=16),
+ make_inst(opname='COMPARE_OP', arg=18, argval='<', argrepr='bool(<)', offset=198, start_offset=198, starts_line=False, line_number=16, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='POP_JUMP_IF_TRUE', arg=3, argval=212, argrepr='to L7', offset=202, start_offset=202, starts_line=False, line_number=16, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=206, start_offset=206, starts_line=False, line_number=16),
+ make_inst(opname='JUMP_BACKWARD', arg=46, argval=120, argrepr='to L5', offset=208, start_offset=208, starts_line=False, line_number=16, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='JUMP_FORWARD', arg=11, argval=236, argrepr='to L9', offset=212, start_offset=212, starts_line=True, line_number=17, label=7),
+ make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=214, start_offset=214, starts_line=True, line_number=19, label=8, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
+ make_inst(opname='LOAD_CONST', arg=2, argval='Who let lolcatz into this test suite?', argrepr="'Who let lolcatz into this test suite?'", offset=224, start_offset=224, starts_line=False, line_number=19),
+ make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=226, start_offset=226, starts_line=False, line_number=19, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=234, start_offset=234, starts_line=False, line_number=19),
+ make_inst(opname='NOP', arg=None, argval=None, argrepr='', offset=236, start_offset=236, starts_line=True, line_number=20, label=9),
+ make_inst(opname='LOAD_SMALL_INT', arg=1, argval=1, argrepr='', offset=238, start_offset=238, starts_line=True, line_number=21),
+ make_inst(opname='LOAD_SMALL_INT', arg=0, argval=0, argrepr='', offset=240, start_offset=240, starts_line=False, line_number=21),
+ make_inst(opname='BINARY_OP', arg=11, argval=11, argrepr='/', offset=242, start_offset=242, starts_line=False, line_number=21, cache_info=[('counter', 1, b'\x00\x00'), ('descr', 4, b'\x00\x00\x00\x00\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=254, start_offset=254, starts_line=False, line_number=21),
+ make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=256, start_offset=256, starts_line=True, line_number=25),
+ make_inst(opname='COPY', arg=1, argval=1, argrepr='', offset=258, start_offset=258, starts_line=False, line_number=25),
+ make_inst(opname='LOAD_SPECIAL', arg=1, argval=1, argrepr='__exit__', offset=260, start_offset=260, starts_line=False, line_number=25),
+ make_inst(opname='SWAP', arg=2, argval=2, argrepr='', offset=262, start_offset=262, starts_line=False, line_number=25),
+ make_inst(opname='SWAP', arg=3, argval=3, argrepr='', offset=264, start_offset=264, starts_line=False, line_number=25),
+ make_inst(opname='LOAD_SPECIAL', arg=0, argval=0, argrepr='__enter__', offset=266, start_offset=266, starts_line=False, line_number=25),
+ make_inst(opname='CALL', arg=0, argval=0, argrepr='', offset=268, start_offset=268, starts_line=False, line_number=25, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='STORE_FAST', arg=1, argval='dodgy', argrepr='dodgy', offset=276, start_offset=276, starts_line=False, line_number=25),
+ make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=278, start_offset=278, starts_line=True, line_number=26, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
+ make_inst(opname='LOAD_CONST', arg=3, argval='Never reach this', argrepr="'Never reach this'", offset=288, start_offset=288, starts_line=False, line_number=26),
+ make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=290, start_offset=290, starts_line=False, line_number=26, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=298, start_offset=298, starts_line=False, line_number=26),
+ make_inst(opname='LOAD_CONST', arg=4, argval=None, argrepr='None', offset=300, start_offset=300, starts_line=True, line_number=25),
make_inst(opname='LOAD_CONST', arg=4, argval=None, argrepr='None', offset=302, start_offset=302, starts_line=False, line_number=25),
- make_inst(opname='CALL', arg=3, argval=3, argrepr='', offset=304, start_offset=304, starts_line=False, line_number=25, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=312, start_offset=312, starts_line=False, line_number=25),
- make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=314, start_offset=314, starts_line=True, line_number=28, label=10, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
- make_inst(opname='LOAD_CONST', arg=6, argval="OK, now we're done", argrepr='"OK, now we\'re done"', offset=324, start_offset=324, starts_line=False, line_number=28),
- make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=326, start_offset=326, starts_line=False, line_number=28, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=334, start_offset=334, starts_line=False, line_number=28),
- make_inst(opname='LOAD_CONST', arg=4, argval=None, argrepr='None', offset=336, start_offset=336, starts_line=False, line_number=28),
- make_inst(opname='RETURN_VALUE', arg=None, argval=None, argrepr='', offset=338, start_offset=338, starts_line=False, line_number=28),
- make_inst(opname='PUSH_EXC_INFO', arg=None, argval=None, argrepr='', offset=340, start_offset=340, starts_line=True, line_number=25),
- make_inst(opname='WITH_EXCEPT_START', arg=None, argval=None, argrepr='', offset=342, start_offset=342, starts_line=False, line_number=25),
- make_inst(opname='TO_BOOL', arg=None, argval=None, argrepr='', offset=344, start_offset=344, starts_line=False, line_number=25, cache_info=[('counter', 1, b'\x00\x00'), ('version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_JUMP_IF_TRUE', arg=2, argval=360, argrepr='to L11', offset=352, start_offset=352, starts_line=False, line_number=25, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=356, start_offset=356, starts_line=False, line_number=25),
- make_inst(opname='RERAISE', arg=2, argval=2, argrepr='', offset=358, start_offset=358, starts_line=False, line_number=25),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=360, start_offset=360, starts_line=False, line_number=25, label=11),
- make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=362, start_offset=362, starts_line=False, line_number=25),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=364, start_offset=364, starts_line=False, line_number=25),
+ make_inst(opname='LOAD_CONST', arg=4, argval=None, argrepr='None', offset=304, start_offset=304, starts_line=False, line_number=25),
+ make_inst(opname='CALL', arg=3, argval=3, argrepr='', offset=306, start_offset=306, starts_line=False, line_number=25, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=314, start_offset=314, starts_line=False, line_number=25),
+ make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=316, start_offset=316, starts_line=True, line_number=28, label=10, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
+ make_inst(opname='LOAD_CONST', arg=6, argval="OK, now we're done", argrepr='"OK, now we\'re done"', offset=326, start_offset=326, starts_line=False, line_number=28),
+ make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=328, start_offset=328, starts_line=False, line_number=28, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=336, start_offset=336, starts_line=False, line_number=28),
+ make_inst(opname='LOAD_CONST', arg=4, argval=None, argrepr='None', offset=338, start_offset=338, starts_line=False, line_number=28),
+ make_inst(opname='RETURN_VALUE', arg=None, argval=None, argrepr='', offset=340, start_offset=340, starts_line=False, line_number=28),
+ make_inst(opname='PUSH_EXC_INFO', arg=None, argval=None, argrepr='', offset=342, start_offset=342, starts_line=True, line_number=25),
+ make_inst(opname='WITH_EXCEPT_START', arg=None, argval=None, argrepr='', offset=344, start_offset=344, starts_line=False, line_number=25),
+ make_inst(opname='TO_BOOL', arg=None, argval=None, argrepr='', offset=346, start_offset=346, starts_line=False, line_number=25, cache_info=[('counter', 1, b'\x00\x00'), ('version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_JUMP_IF_TRUE', arg=2, argval=362, argrepr='to L11', offset=354, start_offset=354, starts_line=False, line_number=25, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=358, start_offset=358, starts_line=False, line_number=25),
+ make_inst(opname='RERAISE', arg=2, argval=2, argrepr='', offset=360, start_offset=360, starts_line=False, line_number=25),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=362, start_offset=362, starts_line=False, line_number=25, label=11),
+ make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=364, start_offset=364, starts_line=False, line_number=25),
make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=366, start_offset=366, starts_line=False, line_number=25),
make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=368, start_offset=368, starts_line=False, line_number=25),
- make_inst(opname='JUMP_BACKWARD_NO_INTERRUPT', arg=29, argval=314, argrepr='to L10', offset=370, start_offset=370, starts_line=False, line_number=25),
- make_inst(opname='COPY', arg=3, argval=3, argrepr='', offset=372, start_offset=372, starts_line=True, line_number=None),
- make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=374, start_offset=374, starts_line=False, line_number=None),
- make_inst(opname='RERAISE', arg=1, argval=1, argrepr='', offset=376, start_offset=376, starts_line=False, line_number=None),
- make_inst(opname='PUSH_EXC_INFO', arg=None, argval=None, argrepr='', offset=378, start_offset=378, starts_line=False, line_number=None),
- make_inst(opname='LOAD_GLOBAL', arg=4, argval='ZeroDivisionError', argrepr='ZeroDivisionError', offset=380, start_offset=380, starts_line=True, line_number=22, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
- make_inst(opname='CHECK_EXC_MATCH', arg=None, argval=None, argrepr='', offset=390, start_offset=390, starts_line=False, line_number=22),
- make_inst(opname='POP_JUMP_IF_FALSE', arg=15, argval=426, argrepr='to L12', offset=392, start_offset=392, starts_line=False, line_number=22, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=396, start_offset=396, starts_line=False, line_number=22),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=398, start_offset=398, starts_line=False, line_number=22),
- make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=400, start_offset=400, starts_line=True, line_number=23, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
- make_inst(opname='LOAD_CONST', arg=5, argval='Here we go, here we go, here we go...', argrepr="'Here we go, here we go, here we go...'", offset=410, start_offset=410, starts_line=False, line_number=23),
- make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=412, start_offset=412, starts_line=False, line_number=23, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=420, start_offset=420, starts_line=False, line_number=23),
- make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=422, start_offset=422, starts_line=False, line_number=23),
- make_inst(opname='JUMP_BACKWARD_NO_INTERRUPT', arg=56, argval=314, argrepr='to L10', offset=424, start_offset=424, starts_line=False, line_number=23),
- make_inst(opname='RERAISE', arg=0, argval=0, argrepr='', offset=426, start_offset=426, starts_line=True, line_number=22, label=12),
- make_inst(opname='COPY', arg=3, argval=3, argrepr='', offset=428, start_offset=428, starts_line=True, line_number=None),
- make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=430, start_offset=430, starts_line=False, line_number=None),
- make_inst(opname='RERAISE', arg=1, argval=1, argrepr='', offset=432, start_offset=432, starts_line=False, line_number=None),
- make_inst(opname='PUSH_EXC_INFO', arg=None, argval=None, argrepr='', offset=434, start_offset=434, starts_line=False, line_number=None),
- make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=436, start_offset=436, starts_line=True, line_number=28, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
- make_inst(opname='LOAD_CONST', arg=6, argval="OK, now we're done", argrepr='"OK, now we\'re done"', offset=446, start_offset=446, starts_line=False, line_number=28),
- make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=448, start_offset=448, starts_line=False, line_number=28, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=456, start_offset=456, starts_line=False, line_number=28),
- make_inst(opname='RERAISE', arg=0, argval=0, argrepr='', offset=458, start_offset=458, starts_line=False, line_number=28),
- make_inst(opname='COPY', arg=3, argval=3, argrepr='', offset=460, start_offset=460, starts_line=True, line_number=None),
- make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=462, start_offset=462, starts_line=False, line_number=None),
- make_inst(opname='RERAISE', arg=1, argval=1, argrepr='', offset=464, start_offset=464, starts_line=False, line_number=None),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=370, start_offset=370, starts_line=False, line_number=25),
+ make_inst(opname='JUMP_BACKWARD_NO_INTERRUPT', arg=29, argval=316, argrepr='to L10', offset=372, start_offset=372, starts_line=False, line_number=25),
+ make_inst(opname='COPY', arg=3, argval=3, argrepr='', offset=374, start_offset=374, starts_line=True, line_number=None),
+ make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=376, start_offset=376, starts_line=False, line_number=None),
+ make_inst(opname='RERAISE', arg=1, argval=1, argrepr='', offset=378, start_offset=378, starts_line=False, line_number=None),
+ make_inst(opname='PUSH_EXC_INFO', arg=None, argval=None, argrepr='', offset=380, start_offset=380, starts_line=False, line_number=None),
+ make_inst(opname='LOAD_GLOBAL', arg=4, argval='ZeroDivisionError', argrepr='ZeroDivisionError', offset=382, start_offset=382, starts_line=True, line_number=22, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
+ make_inst(opname='CHECK_EXC_MATCH', arg=None, argval=None, argrepr='', offset=392, start_offset=392, starts_line=False, line_number=22),
+ make_inst(opname='POP_JUMP_IF_FALSE', arg=15, argval=428, argrepr='to L12', offset=394, start_offset=394, starts_line=False, line_number=22, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=398, start_offset=398, starts_line=False, line_number=22),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=400, start_offset=400, starts_line=False, line_number=22),
+ make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=402, start_offset=402, starts_line=True, line_number=23, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
+ make_inst(opname='LOAD_CONST', arg=5, argval='Here we go, here we go, here we go...', argrepr="'Here we go, here we go, here we go...'", offset=412, start_offset=412, starts_line=False, line_number=23),
+ make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=414, start_offset=414, starts_line=False, line_number=23, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=422, start_offset=422, starts_line=False, line_number=23),
+ make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=424, start_offset=424, starts_line=False, line_number=23),
+ make_inst(opname='JUMP_BACKWARD_NO_INTERRUPT', arg=56, argval=316, argrepr='to L10', offset=426, start_offset=426, starts_line=False, line_number=23),
+ make_inst(opname='RERAISE', arg=0, argval=0, argrepr='', offset=428, start_offset=428, starts_line=True, line_number=22, label=12),
+ make_inst(opname='COPY', arg=3, argval=3, argrepr='', offset=430, start_offset=430, starts_line=True, line_number=None),
+ make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=432, start_offset=432, starts_line=False, line_number=None),
+ make_inst(opname='RERAISE', arg=1, argval=1, argrepr='', offset=434, start_offset=434, starts_line=False, line_number=None),
+ make_inst(opname='PUSH_EXC_INFO', arg=None, argval=None, argrepr='', offset=436, start_offset=436, starts_line=False, line_number=None),
+ make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=438, start_offset=438, starts_line=True, line_number=28, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
+ make_inst(opname='LOAD_CONST', arg=6, argval="OK, now we're done", argrepr='"OK, now we\'re done"', offset=448, start_offset=448, starts_line=False, line_number=28),
+ make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=450, start_offset=450, starts_line=False, line_number=28, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=458, start_offset=458, starts_line=False, line_number=28),
+ make_inst(opname='RERAISE', arg=0, argval=0, argrepr='', offset=460, start_offset=460, starts_line=False, line_number=28),
+ make_inst(opname='COPY', arg=3, argval=3, argrepr='', offset=462, start_offset=462, starts_line=True, line_number=None),
+ make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=464, start_offset=464, starts_line=False, line_number=None),
+ make_inst(opname='RERAISE', arg=1, argval=1, argrepr='', offset=466, start_offset=466, starts_line=False, line_number=None),
]
# One last piece of inspect fodder to check the default line number handling
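For reference, and not part of the patch itself: the updated test queries the JIT state through the sys._jit introspection namespace instead of importing _testinternalcapi. A minimal sketch, assuming an interpreter that exposes sys._jit:

    import sys
    # True when the experimental JIT is compiled in and enabled for this process.
    print("JIT" if sys._jit.is_enabled() else "NO_JIT")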
diff --git a/Lib/test/test_doctest/sample_doctest_errors.py b/Lib/test/test_doctest/sample_doctest_errors.py
new file mode 100644
index 00000000000..4a6f07af2d4
--- /dev/null
+++ b/Lib/test/test_doctest/sample_doctest_errors.py
@@ -0,0 +1,46 @@
+"""This is a sample module used for testing doctest.
+
+This module includes various scenarios involving errors.
+
+>>> 2 + 2
+5
+>>> 1/0
+1
+"""
+
+def g():
+ [][0] # line 12
+
+def errors():
+ """
+ >>> 2 + 2
+ 5
+ >>> 1/0
+ 1
+ >>> def f():
+ ... 2 + '2'
+ ...
+ >>> f()
+ 1
+ >>> g()
+ 1
+ """
+
+def syntax_error():
+ """
+ >>> 2+*3
+ 5
+ """
+
+__test__ = {
+ 'bad': """
+ >>> 2 + 2
+ 5
+ >>> 1/0
+ 1
+ """,
+}
+
+def test_suite():
+ import doctest
+ return doctest.DocTestSuite()
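For reference, and not part of the patch itself: a minimal sketch of how the sample module above is consumed by the new doctest tests further down. With this change, exceptions raised by examples are reported as unittest errors rather than failures, so the suite built from this module splits its results accordingly.

    import doctest
    import unittest
    import test.test_doctest.sample_doctest_errors as mod  # the new sample module above
    result = doctest.DocTestSuite(mod).run(unittest.TestResult())
    print(result)  # expected per test_DocTestSuite_errors: run=4 errors=6 failures=3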
diff --git a/Lib/test/test_doctest/test_doctest.py b/Lib/test/test_doctest/test_doctest.py
index a4a49298bab..72763d4a013 100644
--- a/Lib/test/test_doctest/test_doctest.py
+++ b/Lib/test/test_doctest/test_doctest.py
@@ -2267,14 +2267,24 @@ def test_DocTestSuite():
>>> import unittest
>>> import test.test_doctest.sample_doctest
>>> suite = doctest.DocTestSuite(test.test_doctest.sample_doctest)
- >>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=4>
+ >>> result = suite.run(unittest.TestResult())
+ >>> result
+ <unittest.result.TestResult run=9 errors=2 failures=2>
+ >>> for tst, _ in result.failures:
+ ... print(tst)
+ bad (test.test_doctest.sample_doctest.__test__) [0]
+ foo (test.test_doctest.sample_doctest) [0]
+ >>> for tst, _ in result.errors:
+ ... print(tst)
+ test_silly_setup (test.test_doctest.sample_doctest) [1]
+ y_is_one (test.test_doctest.sample_doctest) [0]
We can also supply the module by name:
>>> suite = doctest.DocTestSuite('test.test_doctest.sample_doctest')
- >>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=4>
+ >>> result = suite.run(unittest.TestResult())
+ >>> result
+ <unittest.result.TestResult run=9 errors=2 failures=2>
The module need not contain any doctest examples:
@@ -2296,13 +2306,26 @@ def test_DocTestSuite():
>>> result
<unittest.result.TestResult run=6 errors=0 failures=2>
>>> len(result.skipped)
- 2
+ 7
+ >>> for tst, _ in result.skipped:
+ ... print(tst)
+ double_skip (test.test_doctest.sample_doctest_skip) [0]
+ double_skip (test.test_doctest.sample_doctest_skip) [1]
+ double_skip (test.test_doctest.sample_doctest_skip)
+ partial_skip_fail (test.test_doctest.sample_doctest_skip) [0]
+ partial_skip_pass (test.test_doctest.sample_doctest_skip) [0]
+ single_skip (test.test_doctest.sample_doctest_skip) [0]
+ single_skip (test.test_doctest.sample_doctest_skip)
+ >>> for tst, _ in result.failures:
+ ... print(tst)
+ no_skip_fail (test.test_doctest.sample_doctest_skip) [0]
+ partial_skip_fail (test.test_doctest.sample_doctest_skip) [1]
We can use the current module:
>>> suite = test.test_doctest.sample_doctest.test_suite()
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=4>
+ <unittest.result.TestResult run=9 errors=2 failures=2>
We can also provide a DocTestFinder:
@@ -2310,7 +2333,7 @@ def test_DocTestSuite():
>>> suite = doctest.DocTestSuite('test.test_doctest.sample_doctest',
... test_finder=finder)
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=4>
+ <unittest.result.TestResult run=9 errors=2 failures=2>
The DocTestFinder need not return any tests:
@@ -2326,7 +2349,7 @@ def test_DocTestSuite():
>>> suite = doctest.DocTestSuite('test.test_doctest.sample_doctest', globs={})
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=5>
+ <unittest.result.TestResult run=9 errors=3 failures=2>
Alternatively, we can provide extra globals. Here we'll make an
error go away by providing an extra global variable:
@@ -2334,7 +2357,7 @@ def test_DocTestSuite():
>>> suite = doctest.DocTestSuite('test.test_doctest.sample_doctest',
... extraglobs={'y': 1})
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=3>
+ <unittest.result.TestResult run=9 errors=1 failures=2>
You can pass option flags. Here we'll cause an extra error
by disabling the blank-line feature:
@@ -2342,7 +2365,7 @@ def test_DocTestSuite():
>>> suite = doctest.DocTestSuite('test.test_doctest.sample_doctest',
... optionflags=doctest.DONT_ACCEPT_BLANKLINE)
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=5>
+ <unittest.result.TestResult run=9 errors=2 failures=3>
You can supply setUp and tearDown functions:
@@ -2359,7 +2382,7 @@ def test_DocTestSuite():
>>> suite = doctest.DocTestSuite('test.test_doctest.sample_doctest',
... setUp=setUp, tearDown=tearDown)
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=3>
+ <unittest.result.TestResult run=9 errors=1 failures=2>
But the tearDown restores sanity:
@@ -2377,13 +2400,115 @@ def test_DocTestSuite():
>>> suite = doctest.DocTestSuite('test.test_doctest.sample_doctest', setUp=setUp)
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=3>
+ <unittest.result.TestResult run=9 errors=1 failures=2>
Here, we didn't need to use a tearDown function because we
modified the test globals, which are a copy of the
sample_doctest module dictionary. The test globals are
automatically cleared for us after a test.
- """
+ """
+
+def test_DocTestSuite_errors():
+ """Tests for error reporting in DocTestSuite.
+
+ >>> import unittest
+ >>> import test.test_doctest.sample_doctest_errors as mod
+ >>> suite = doctest.DocTestSuite(mod)
+ >>> result = suite.run(unittest.TestResult())
+ >>> result
+ <unittest.result.TestResult run=4 errors=6 failures=3>
+ >>> print(result.failures[0][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line 5, in test.test_doctest.sample_doctest_errors
+ >...>> 2 + 2
+ AssertionError: Failed example:
+ 2 + 2
+ Expected:
+ 5
+ Got:
+ 4
+ <BLANKLINE>
+ >>> print(result.failures[1][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line None, in test.test_doctest.sample_doctest_errors.__test__.bad
+ AssertionError: Failed example:
+ 2 + 2
+ Expected:
+ 5
+ Got:
+ 4
+ <BLANKLINE>
+ >>> print(result.failures[2][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line 16, in test.test_doctest.sample_doctest_errors.errors
+ >...>> 2 + 2
+ AssertionError: Failed example:
+ 2 + 2
+ Expected:
+ 5
+ Got:
+ 4
+ <BLANKLINE>
+ >>> print(result.errors[0][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line 7, in test.test_doctest.sample_doctest_errors
+ >...>> 1/0
+ File "<doctest test.test_doctest.sample_doctest_errors[1]>", line 1, in <module>
+ 1/0
+ ~^~
+ ZeroDivisionError: division by zero
+ <BLANKLINE>
+ >>> print(result.errors[1][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line None, in test.test_doctest.sample_doctest_errors.__test__.bad
+ File "<doctest test.test_doctest.sample_doctest_errors.__test__.bad[1]>", line 1, in <module>
+ 1/0
+ ~^~
+ ZeroDivisionError: division by zero
+ <BLANKLINE>
+ >>> print(result.errors[2][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line 18, in test.test_doctest.sample_doctest_errors.errors
+ >...>> 1/0
+ File "<doctest test.test_doctest.sample_doctest_errors.errors[1]>", line 1, in <module>
+ 1/0
+ ~^~
+ ZeroDivisionError: division by zero
+ <BLANKLINE>
+ >>> print(result.errors[3][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line 23, in test.test_doctest.sample_doctest_errors.errors
+ >...>> f()
+ File "<doctest test.test_doctest.sample_doctest_errors.errors[3]>", line 1, in <module>
+ f()
+ ~^^
+ File "<doctest test.test_doctest.sample_doctest_errors.errors[2]>", line 2, in f
+ 2 + '2'
+ ~~^~~~~
+ TypeError: ...
+ <BLANKLINE>
+ >>> print(result.errors[4][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line 25, in test.test_doctest.sample_doctest_errors.errors
+ >...>> g()
+ File "<doctest test.test_doctest.sample_doctest_errors.errors[4]>", line 1, in <module>
+ g()
+ ~^^
+ File "...sample_doctest_errors.py", line 12, in g
+ [][0] # line 12
+ ~~^^^
+ IndexError: list index out of range
+ <BLANKLINE>
+ >>> print(result.errors[5][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line 31, in test.test_doctest.sample_doctest_errors.syntax_error
+ >...>> 2+*3
+ File "<doctest test.test_doctest.sample_doctest_errors.syntax_error[0]>", line 1
+ 2+*3
+ ^
+ SyntaxError: invalid syntax
+ <BLANKLINE>
+ """
def test_DocFileSuite():
"""We can test tests found in text files using a DocFileSuite.
@@ -2396,7 +2521,7 @@ def test_DocFileSuite():
... 'test_doctest2.txt',
... 'test_doctest4.txt')
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=3 errors=0 failures=2>
+ <unittest.result.TestResult run=3 errors=2 failures=0>
The test files are looked for in the directory containing the
calling module. A package keyword argument can be provided to
@@ -2408,14 +2533,14 @@ def test_DocFileSuite():
... 'test_doctest4.txt',
... package='test.test_doctest')
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=3 errors=0 failures=2>
+ <unittest.result.TestResult run=3 errors=2 failures=0>
'/' should be used as a path separator. It will be converted
to a native separator at run time:
>>> suite = doctest.DocFileSuite('../test_doctest/test_doctest.txt')
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=1 errors=0 failures=1>
+ <unittest.result.TestResult run=1 errors=1 failures=0>
If DocFileSuite is used from an interactive session, then files
are resolved relative to the directory of sys.argv[0]:
@@ -2441,7 +2566,7 @@ def test_DocFileSuite():
>>> suite = doctest.DocFileSuite(test_file, module_relative=False)
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=1 errors=0 failures=1>
+ <unittest.result.TestResult run=1 errors=1 failures=0>
It is an error to specify `package` when `module_relative=False`:
@@ -2455,12 +2580,19 @@ def test_DocFileSuite():
>>> suite = doctest.DocFileSuite('test_doctest.txt',
... 'test_doctest4.txt',
- ... 'test_doctest_skip.txt')
+ ... 'test_doctest_skip.txt',
+ ... 'test_doctest_skip2.txt')
>>> result = suite.run(unittest.TestResult())
>>> result
- <unittest.result.TestResult run=3 errors=0 failures=1>
- >>> len(result.skipped)
- 1
+ <unittest.result.TestResult run=4 errors=1 failures=0>
+ >>> len(result.skipped)
+ 4
+ >>> for tst, _ in result.skipped: # doctest: +ELLIPSIS
+ ... print('=', tst)
+ = ...test_doctest_skip.txt [0]
+ = ...test_doctest_skip.txt [1]
+ = ...test_doctest_skip.txt
+ = ...test_doctest_skip2.txt [0]
You can specify initial global variables:
@@ -2469,7 +2601,7 @@ def test_DocFileSuite():
... 'test_doctest4.txt',
... globs={'favorite_color': 'blue'})
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=3 errors=0 failures=1>
+ <unittest.result.TestResult run=3 errors=1 failures=0>
In this case, we supplied a missing favorite color. You can
provide doctest options:
@@ -2480,7 +2612,7 @@ def test_DocFileSuite():
... optionflags=doctest.DONT_ACCEPT_BLANKLINE,
... globs={'favorite_color': 'blue'})
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=3 errors=0 failures=2>
+ <unittest.result.TestResult run=3 errors=1 failures=1>
And, you can provide setUp and tearDown functions:
@@ -2499,7 +2631,7 @@ def test_DocFileSuite():
... 'test_doctest4.txt',
... setUp=setUp, tearDown=tearDown)
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=3 errors=0 failures=1>
+ <unittest.result.TestResult run=3 errors=1 failures=0>
But the tearDown restores sanity:
@@ -2541,9 +2673,60 @@ def test_DocFileSuite():
... 'test_doctest4.txt',
... encoding='utf-8')
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=3 errors=0 failures=2>
+ <unittest.result.TestResult run=3 errors=2 failures=0>
+ """
- """
+def test_DocFileSuite_errors():
+ """Tests for error reporting in DocTestSuite.
+
+ >>> import unittest
+ >>> suite = doctest.DocFileSuite('test_doctest_errors.txt')
+ >>> result = suite.run(unittest.TestResult())
+ >>> result
+ <unittest.result.TestResult run=1 errors=3 failures=1>
+ >>> print(result.failures[0][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...test_doctest_errors.txt", line 4, in test_doctest_errors.txt
+ >...>> 2 + 2
+ AssertionError: Failed example:
+ 2 + 2
+ Expected:
+ 5
+ Got:
+ 4
+ <BLANKLINE>
+ >>> print(result.errors[0][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...test_doctest_errors.txt", line 6, in test_doctest_errors.txt
+ >...>> 1/0
+ File "<doctest test_doctest_errors.txt[1]>", line 1, in <module>
+ 1/0
+ ~^~
+ ZeroDivisionError: division by zero
+ <BLANKLINE>
+ >>> print(result.errors[1][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...test_doctest_errors.txt", line 11, in test_doctest_errors.txt
+ >...>> f()
+ File "<doctest test_doctest_errors.txt[3]>", line 1, in <module>
+ f()
+ ~^^
+ File "<doctest test_doctest_errors.txt[2]>", line 2, in f
+ 2 + '2'
+ ~~^~~~~
+ TypeError: ...
+ <BLANKLINE>
+ >>> print(result.errors[2][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...test_doctest_errors.txt", line 13, in test_doctest_errors.txt
+ >...>> 2+*3
+ File "<doctest test_doctest_errors.txt[4]>", line 1
+ 2+*3
+ ^
+ SyntaxError: invalid syntax
+ <BLANKLINE>
+
+ """
def test_trailing_space_in_test():
"""
@@ -2612,14 +2795,26 @@ def test_unittest_reportflags():
... optionflags=doctest.DONT_ACCEPT_BLANKLINE)
>>> import unittest
>>> result = suite.run(unittest.TestResult())
+ >>> result
+ <unittest.result.TestResult run=1 errors=1 failures=1>
>>> print(result.failures[0][1]) # doctest: +ELLIPSIS
- Traceback ...
- Failed example:
- favorite_color
- ...
- Failed example:
+ Traceback (most recent call last):
+ File ...
+ >...>> if 1:
+ AssertionError: Failed example:
if 1:
- ...
+ print('a')
+ print()
+ print('b')
+ Expected:
+ a
+ <BLANKLINE>
+ b
+ Got:
+ a
+ <BLANKLINE>
+ b
+ <BLANKLINE>
Note that we see both failures displayed.
@@ -2628,16 +2823,8 @@ def test_unittest_reportflags():
Now, when we run the test:
- >>> result = suite.run(unittest.TestResult())
- >>> print(result.failures[0][1]) # doctest: +ELLIPSIS
- Traceback ...
- Failed example:
- favorite_color
- Exception raised:
- ...
- NameError: name 'favorite_color' is not defined
- <BLANKLINE>
- <BLANKLINE>
+ >>> suite.run(unittest.TestResult())
+ <unittest.result.TestResult run=1 errors=1 failures=0>
We get only the first failure.
@@ -2647,19 +2834,20 @@ def test_unittest_reportflags():
>>> suite = doctest.DocFileSuite('test_doctest.txt',
... optionflags=doctest.DONT_ACCEPT_BLANKLINE | doctest.REPORT_NDIFF)
- Then the default eporting options are ignored:
+ Then the default reporting options are ignored:
>>> result = suite.run(unittest.TestResult())
+ >>> result
+ <unittest.result.TestResult run=1 errors=1 failures=1>
     *NOTE*: These doctests are intentionally not placed in a raw string to depict
the trailing whitespace using `\x20` in the diff below.
>>> print(result.failures[0][1]) # doctest: +ELLIPSIS
Traceback ...
- Failed example:
- favorite_color
- ...
- Failed example:
+ File ...
+ >...>> if 1:
+ AssertionError: Failed example:
if 1:
print('a')
print()
@@ -2670,7 +2858,6 @@ def test_unittest_reportflags():
+\x20
b
<BLANKLINE>
- <BLANKLINE>
Test runners can restore the formatting flags after they run:
@@ -2860,6 +3047,57 @@ Test the verbose output:
>>> _colorize.COLORIZE = save_colorize
"""
+def test_testfile_errors(): r"""
+Tests for error reporting in the testfile() function.
+
+ >>> doctest.testfile('test_doctest_errors.txt', verbose=False) # doctest: +ELLIPSIS
+ **********************************************************************
+ File "...test_doctest_errors.txt", line 4, in test_doctest_errors.txt
+ Failed example:
+ 2 + 2
+ Expected:
+ 5
+ Got:
+ 4
+ **********************************************************************
+ File "...test_doctest_errors.txt", line 6, in test_doctest_errors.txt
+ Failed example:
+ 1/0
+ Exception raised:
+ Traceback (most recent call last):
+ File "<doctest test_doctest_errors.txt[1]>", line 1, in <module>
+ 1/0
+ ~^~
+ ZeroDivisionError: division by zero
+ **********************************************************************
+ File "...test_doctest_errors.txt", line 11, in test_doctest_errors.txt
+ Failed example:
+ f()
+ Exception raised:
+ Traceback (most recent call last):
+ File "<doctest test_doctest_errors.txt[3]>", line 1, in <module>
+ f()
+ ~^^
+ File "<doctest test_doctest_errors.txt[2]>", line 2, in f
+ 2 + '2'
+ ~~^~~~~
+ TypeError: ...
+ **********************************************************************
+ File "...test_doctest_errors.txt", line 13, in test_doctest_errors.txt
+ Failed example:
+ 2+*3
+ Exception raised:
+ File "<doctest test_doctest_errors.txt[4]>", line 1
+ 2+*3
+ ^
+ SyntaxError: invalid syntax
+ **********************************************************************
+ 1 item had failures:
+ 4 of 5 in test_doctest_errors.txt
+ ***Test Failed*** 4 failures.
+ TestResults(failed=4, attempted=5)
+"""
+
class TestImporter(importlib.abc.MetaPathFinder):
def find_spec(self, fullname, path, target=None):
@@ -2990,6 +3228,110 @@ out of the binary module.
TestResults(failed=0, attempted=0)
"""
+def test_testmod_errors(): r"""
+Tests for error reporting in the testmod() function.
+
+ >>> import test.test_doctest.sample_doctest_errors as mod
+ >>> doctest.testmod(mod, verbose=False) # doctest: +ELLIPSIS
+ **********************************************************************
+ File "...sample_doctest_errors.py", line 5, in test.test_doctest.sample_doctest_errors
+ Failed example:
+ 2 + 2
+ Expected:
+ 5
+ Got:
+ 4
+ **********************************************************************
+ File "...sample_doctest_errors.py", line 7, in test.test_doctest.sample_doctest_errors
+ Failed example:
+ 1/0
+ Exception raised:
+ Traceback (most recent call last):
+ File "<doctest test.test_doctest.sample_doctest_errors[1]>", line 1, in <module>
+ 1/0
+ ~^~
+ ZeroDivisionError: division by zero
+ **********************************************************************
+ File "...sample_doctest_errors.py", line ?, in test.test_doctest.sample_doctest_errors.__test__.bad
+ Failed example:
+ 2 + 2
+ Expected:
+ 5
+ Got:
+ 4
+ **********************************************************************
+ File "...sample_doctest_errors.py", line ?, in test.test_doctest.sample_doctest_errors.__test__.bad
+ Failed example:
+ 1/0
+ Exception raised:
+ Traceback (most recent call last):
+ File "<doctest test.test_doctest.sample_doctest_errors.__test__.bad[1]>", line 1, in <module>
+ 1/0
+ ~^~
+ ZeroDivisionError: division by zero
+ **********************************************************************
+ File "...sample_doctest_errors.py", line 16, in test.test_doctest.sample_doctest_errors.errors
+ Failed example:
+ 2 + 2
+ Expected:
+ 5
+ Got:
+ 4
+ **********************************************************************
+ File "...sample_doctest_errors.py", line 18, in test.test_doctest.sample_doctest_errors.errors
+ Failed example:
+ 1/0
+ Exception raised:
+ Traceback (most recent call last):
+ File "<doctest test.test_doctest.sample_doctest_errors.errors[1]>", line 1, in <module>
+ 1/0
+ ~^~
+ ZeroDivisionError: division by zero
+ **********************************************************************
+ File "...sample_doctest_errors.py", line 23, in test.test_doctest.sample_doctest_errors.errors
+ Failed example:
+ f()
+ Exception raised:
+ Traceback (most recent call last):
+ File "<doctest test.test_doctest.sample_doctest_errors.errors[3]>", line 1, in <module>
+ f()
+ ~^^
+ File "<doctest test.test_doctest.sample_doctest_errors.errors[2]>", line 2, in f
+ 2 + '2'
+ ~~^~~~~
+ TypeError: ...
+ **********************************************************************
+ File "...sample_doctest_errors.py", line 25, in test.test_doctest.sample_doctest_errors.errors
+ Failed example:
+ g()
+ Exception raised:
+ Traceback (most recent call last):
+ File "<doctest test.test_doctest.sample_doctest_errors.errors[4]>", line 1, in <module>
+ g()
+ ~^^
+ File "...sample_doctest_errors.py", line 12, in g
+ [][0] # line 12
+ ~~^^^
+ IndexError: list index out of range
+ **********************************************************************
+ File "...sample_doctest_errors.py", line 31, in test.test_doctest.sample_doctest_errors.syntax_error
+ Failed example:
+ 2+*3
+ Exception raised:
+ File "<doctest test.test_doctest.sample_doctest_errors.syntax_error[0]>", line 1
+ 2+*3
+ ^
+ SyntaxError: invalid syntax
+ **********************************************************************
+ 4 items had failures:
+ 2 of 2 in test.test_doctest.sample_doctest_errors
+ 2 of 2 in test.test_doctest.sample_doctest_errors.__test__.bad
+ 4 of 5 in test.test_doctest.sample_doctest_errors.errors
+ 1 of 1 in test.test_doctest.sample_doctest_errors.syntax_error
+ ***Test Failed*** 9 failures.
+ TestResults(failed=9, attempted=10)
+"""
+
try:
os.fsencode("foo-bär@baz.py")
supports_unicode = True
@@ -3021,11 +3363,6 @@ Check doctest with a non-ascii filename:
raise Exception('clé')
Exception raised:
Traceback (most recent call last):
- File ...
- exec(compile(example.source, filename, "single",
- ~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- compileflags, True), test.globs)
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<doctest foo-bär@baz[0]>", line 1, in <module>
raise Exception('clé')
Exception: clé
@@ -3318,9 +3655,9 @@ def test_run_doctestsuite_multiple_times():
>>> import test.test_doctest.sample_doctest
>>> suite = doctest.DocTestSuite(test.test_doctest.sample_doctest)
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=4>
+ <unittest.result.TestResult run=9 errors=2 failures=2>
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=4>
+ <unittest.result.TestResult run=9 errors=2 failures=2>
"""
diff --git a/Lib/test/test_doctest/test_doctest_errors.txt b/Lib/test/test_doctest/test_doctest_errors.txt
new file mode 100644
index 00000000000..93c3c106e60
--- /dev/null
+++ b/Lib/test/test_doctest/test_doctest_errors.txt
@@ -0,0 +1,14 @@
+This is a sample doctest in a text file, in which all examples fail
+or raise an exception.
+
+ >>> 2 + 2
+ 5
+ >>> 1/0
+ 1
+ >>> def f():
+ ... 2 + '2'
+ ...
+ >>> f()
+ 1
+ >>> 2+*3
+ 5
diff --git a/Lib/test/test_doctest/test_doctest_skip.txt b/Lib/test/test_doctest/test_doctest_skip.txt
index f340e2b8141..06c23d06e60 100644
--- a/Lib/test/test_doctest/test_doctest_skip.txt
+++ b/Lib/test/test_doctest/test_doctest_skip.txt
@@ -2,3 +2,5 @@ This is a sample doctest in a text file, in which all examples are skipped.
>>> 2 + 2 # doctest: +SKIP
5
+ >>> 2 + 2 # doctest: +SKIP
+ 4
diff --git a/Lib/test/test_doctest/test_doctest_skip2.txt b/Lib/test/test_doctest/test_doctest_skip2.txt
new file mode 100644
index 00000000000..85e4938c346
--- /dev/null
+++ b/Lib/test/test_doctest/test_doctest_skip2.txt
@@ -0,0 +1,6 @@
+This is a sample doctest in a text file, in which some examples are skipped.
+
+ >>> 2 + 2 # doctest: +SKIP
+ 5
+ >>> 2 + 2
+ 4
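A small sketch of the +SKIP directive that both text files above rely on: a skipped example is parsed but never executed, so even an incorrect expected value cannot fail.

    import doctest

    def demo():
        """
        >>> 2 + 2   # doctest: +SKIP
        5
        >>> 2 + 2
        4
        """

    # Only the second example runs; the skipped one is never executed.
    print(doctest.testmod(verbose=False))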
diff --git a/Lib/test/test_dynamicclassattribute.py b/Lib/test/test_dynamicclassattribute.py
index 9f694d9eb46..b19be33c72f 100644
--- a/Lib/test/test_dynamicclassattribute.py
+++ b/Lib/test/test_dynamicclassattribute.py
@@ -104,8 +104,8 @@ class PropertyTests(unittest.TestCase):
self.assertEqual(base.spam, 10)
self.assertEqual(base._spam, 10)
delattr(base, "spam")
- self.assertTrue(not hasattr(base, "spam"))
- self.assertTrue(not hasattr(base, "_spam"))
+ self.assertNotHasAttr(base, "spam")
+ self.assertNotHasAttr(base, "_spam")
base.spam = 20
self.assertEqual(base.spam, 20)
self.assertEqual(base._spam, 20)
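The replacements above use the newer attribute assertions; a minimal sketch, assuming a unittest.TestCase that provides assertHasAttr/assertNotHasAttr (as the tests in this diff do):

    import unittest

    class AttrAssertions(unittest.TestCase):
        def test_attrs(self):
            class Holder:
                pass

            obj = Holder()
            obj.spam = 10
            # Fails with a message naming the attribute, unlike a bare
            # assertTrue(hasattr(...)).
            self.assertHasAttr(obj, "spam")
            del obj.spam
            self.assertNotHasAttr(obj, "spam")

    if __name__ == "__main__":
        unittest.main()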
diff --git a/Lib/test/test_email/test__header_value_parser.py b/Lib/test/test_email/test__header_value_parser.py
index ac12c3b2306..179e236ecdf 100644
--- a/Lib/test/test_email/test__header_value_parser.py
+++ b/Lib/test/test_email/test__header_value_parser.py
@@ -463,6 +463,19 @@ class TestParser(TestParserMixin, TestEmailBase):
[errors.NonPrintableDefect], ')')
self.assertEqual(ptext.defects[0].non_printables[0], '\x00')
+ def test_get_qp_ctext_close_paren_only(self):
+ self._test_get_x(parser.get_qp_ctext,
+ ')', '', ' ', [], ')')
+
+ def test_get_qp_ctext_open_paren_only(self):
+ self._test_get_x(parser.get_qp_ctext,
+ '(', '', ' ', [], '(')
+
+ def test_get_qp_ctext_no_end_char(self):
+ self._test_get_x(parser.get_qp_ctext,
+ '', '', ' ', [], '')
+
+
# get_qcontent
def test_get_qcontent_only(self):
@@ -503,6 +516,14 @@ class TestParser(TestParserMixin, TestEmailBase):
[errors.NonPrintableDefect], '"')
self.assertEqual(ptext.defects[0].non_printables[0], '\x00')
+ def test_get_qcontent_empty(self):
+ self._test_get_x(parser.get_qcontent,
+ '"', '', '', [], '"')
+
+ def test_get_qcontent_no_end_char(self):
+ self._test_get_x(parser.get_qcontent,
+ '', '', '', [], '')
+
# get_atext
def test_get_atext_only(self):
@@ -1283,6 +1304,18 @@ class TestParser(TestParserMixin, TestEmailBase):
self._test_get_x(parser.get_dtext,
'foo[bar', 'foo', 'foo', [], '[bar')
+ def test_get_dtext_open_bracket_only(self):
+ self._test_get_x(parser.get_dtext,
+ '[', '', '', [], '[')
+
+ def test_get_dtext_close_bracket_only(self):
+ self._test_get_x(parser.get_dtext,
+ ']', '', '', [], ']')
+
+ def test_get_dtext_empty(self):
+ self._test_get_x(parser.get_dtext,
+ '', '', '', [], '')
+
# get_domain_literal
def test_get_domain_literal_only(self):
@@ -2458,6 +2491,38 @@ class TestParser(TestParserMixin, TestEmailBase):
self.assertEqual(address.all_mailboxes[0].domain, 'example.com')
self.assertEqual(address.all_mailboxes[0].addr_spec, '"example example"@example.com')
+ def test_get_address_with_invalid_domain(self):
+ address = self._test_get_x(parser.get_address,
+ '<T@[',
+ '<T@[]>',
+ '<T@[]>',
+ [errors.InvalidHeaderDefect, # missing trailing '>' on angle-addr
+ errors.InvalidHeaderDefect, # end of input inside domain-literal
+ ],
+ '')
+ self.assertEqual(address.token_type, 'address')
+ self.assertEqual(len(address.mailboxes), 0)
+ self.assertEqual(len(address.all_mailboxes), 1)
+ self.assertEqual(address.all_mailboxes[0].domain, '[]')
+ self.assertEqual(address.all_mailboxes[0].local_part, 'T')
+ self.assertEqual(address.all_mailboxes[0].token_type, 'invalid-mailbox')
+ self.assertEqual(address[0].token_type, 'invalid-mailbox')
+
+ address = self._test_get_x(parser.get_address,
+ '!an??:=m==fr2@[C',
+ '!an??:=m==fr2@[C];',
+ '!an??:=m==fr2@[C];',
+ [errors.InvalidHeaderDefect, # end of header in group
+ errors.InvalidHeaderDefect, # end of input inside domain-literal
+ ],
+ '')
+ self.assertEqual(address.token_type, 'address')
+ self.assertEqual(len(address.mailboxes), 0)
+ self.assertEqual(len(address.all_mailboxes), 1)
+ self.assertEqual(address.all_mailboxes[0].domain, '[C]')
+ self.assertEqual(address.all_mailboxes[0].local_part, '=m==fr2')
+ self.assertEqual(address.all_mailboxes[0].token_type, 'invalid-mailbox')
+ self.assertEqual(address[0].token_type, 'group')
# get_address_list
@@ -2732,6 +2797,19 @@ class TestParser(TestParserMixin, TestEmailBase):
)
self.assertEqual(message_id.token_type, 'message-id')
+ def test_parse_message_id_with_invalid_domain(self):
+ message_id = self._test_parse_x(
+ parser.parse_message_id,
+ "<T@[",
+ "<T@[]>",
+ "<T@[]>",
+ [errors.ObsoleteHeaderDefect] + [errors.InvalidHeaderDefect] * 2,
+ [],
+ )
+ self.assertEqual(message_id.token_type, 'message-id')
+ self.assertEqual(str(message_id.all_defects[-1]),
+ "end of input inside domain-literal")
+
def test_parse_message_id_with_remaining(self):
message_id = self._test_parse_x(
parser.parse_message_id,
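A rough sketch of the private email._header_value_parser helpers exercised above; the (token, unparsed_rest) return shape is inferred from the tests in this diff:

    from email import _header_value_parser as parser

    # Truncated domain-literal: defects are recorded instead of raising.
    token, rest = parser.get_address("<T@[")
    print(str(token))                                    # reconstructed value
    print([type(d).__name__ for d in token.all_defects])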
diff --git a/Lib/test/test_email/test_email.py b/Lib/test/test_email/test_email.py
index 7b14305f997..b8116d073a2 100644
--- a/Lib/test/test_email/test_email.py
+++ b/Lib/test/test_email/test_email.py
@@ -389,6 +389,24 @@ class TestMessageAPI(TestEmailBase):
msg = email.message_from_string("Content-Type: blarg; baz; boo\n")
self.assertEqual(msg.get_param('baz'), '')
+ def test_continuation_sorting_part_order(self):
+ msg = email.message_from_string(
+ "Content-Disposition: attachment; "
+ "filename*=\"ignored\"; "
+ "filename*0*=\"utf-8''foo%20\"; "
+ "filename*1*=\"bar.txt\"\n"
+ )
+ filename = msg.get_filename()
+ self.assertEqual(filename, 'foo bar.txt')
+
+ def test_sorting_no_continuations(self):
+ msg = email.message_from_string(
+ "Content-Disposition: attachment; "
+ "filename*=\"bar.txt\"; "
+ )
+ filename = msg.get_filename()
+ self.assertEqual(filename, 'bar.txt')
+
def test_missing_filename(self):
msg = email.message_from_string("From: foo\n")
self.assertEqual(msg.get_filename(), None)
@@ -2550,6 +2568,18 @@ Re: =?mac-iceland?q?r=8Aksm=9Arg=8Cs?= baz foo bar =?mac-iceland?q?r=8Aksm?=
self.assertEqual(str(make_header(decode_header(s))),
'"Müller T" <T.Mueller@xxx.com>')
+ def test_unencoded_ascii(self):
+ # bpo-22833/gh-67022: returns [(str, None)] rather than [(bytes, None)]
+ s = 'header without encoded words'
+ self.assertEqual(decode_header(s),
+ [('header without encoded words', None)])
+
+ def test_unencoded_utf8(self):
+ # bpo-22833/gh-67022: returns [(str, None)] rather than [(bytes, None)]
+ s = 'header with unexpected non ASCII caract\xe8res'
+ self.assertEqual(decode_header(s),
+ [('header with unexpected non ASCII caract\xe8res', None)])
+
# Test the MIMEMessage class
class TestMIMEMessage(TestEmailBase):
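A short sketch of the decode_header() behaviour pinned down by test_unencoded_ascii and test_unencoded_utf8 above:

    from email.header import decode_header

    # Headers without encoded words come back as (str, None) pairs,
    # not (bytes, None).
    print(decode_header("header without encoded words"))
    # [('header without encoded words', None)]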
diff --git a/Lib/test/test_email/test_utils.py b/Lib/test/test_email/test_utils.py
index 4e6201e13c8..c9d09098b50 100644
--- a/Lib/test/test_email/test_utils.py
+++ b/Lib/test/test_email/test_utils.py
@@ -4,6 +4,16 @@ import test.support
import time
import unittest
+from test.support import cpython_only
+from test.support.import_helper import ensure_lazy_imports
+
+
+class TestImportTime(unittest.TestCase):
+
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("email.utils", {"random", "socket"})
+
class DateTimeTests(unittest.TestCase):
diff --git a/Lib/test/test_embed.py b/Lib/test/test_embed.py
index e06e684408c..89f4aebe28f 100644
--- a/Lib/test/test_embed.py
+++ b/Lib/test/test_embed.py
@@ -296,7 +296,7 @@ class EmbeddingTests(EmbeddingTestsMixin, unittest.TestCase):
if MS_WINDOWS:
expected_path = self.test_exe
else:
- expected_path = os.path.join(os.getcwd(), "spam")
+ expected_path = os.path.join(os.getcwd(), "_testembed")
expected_output = f"sys.executable: {expected_path}\n"
self.assertIn(expected_output, out)
self.assertEqual(err, '')
@@ -585,7 +585,7 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase):
'faulthandler': False,
'tracemalloc': 0,
'perf_profiling': 0,
- 'import_time': False,
+ 'import_time': 0,
'thread_inherit_context': DEFAULT_THREAD_INHERIT_CONTEXT,
'context_aware_warnings': DEFAULT_CONTEXT_AWARE_WARNINGS,
'code_debug_ranges': True,
@@ -969,7 +969,6 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase):
'utf8_mode': True,
}
config = {
- 'program_name': './globalvar',
'site_import': False,
'bytes_warning': True,
'warnoptions': ['default::BytesWarning'],
@@ -998,7 +997,7 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase):
'hash_seed': 123,
'tracemalloc': 2,
'perf_profiling': 0,
- 'import_time': True,
+ 'import_time': 2,
'code_debug_ranges': False,
'show_ref_count': True,
'malloc_stats': True,
@@ -1064,7 +1063,7 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase):
'use_hash_seed': True,
'hash_seed': 42,
'tracemalloc': 2,
- 'import_time': True,
+ 'import_time': 1,
'code_debug_ranges': False,
'malloc_stats': True,
'inspect': True,
@@ -1100,7 +1099,7 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase):
'use_hash_seed': True,
'hash_seed': 42,
'tracemalloc': 2,
- 'import_time': True,
+ 'import_time': 1,
'code_debug_ranges': False,
'malloc_stats': True,
'inspect': True,
@@ -1916,6 +1915,10 @@ class AuditingTests(EmbeddingTestsMixin, unittest.TestCase):
self.run_embedded_interpreter("test_get_incomplete_frame")
+ def test_gilstate_after_finalization(self):
+ self.run_embedded_interpreter("test_gilstate_after_finalization")
+
+
class MiscTests(EmbeddingTestsMixin, unittest.TestCase):
def test_unicode_id_init(self):
# bpo-42882: Test that _PyUnicode_FromId() works
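The config changes above make import_time an integer level rather than a boolean; a hedged sketch, assuming level 1 matches the old behaviour and level 2 also reports modules that were already imported:

    import subprocess
    import sys

    # The per-import timing report is written to stderr.
    proc = subprocess.run(
        [sys.executable, "-X", "importtime=2", "-c", "import json"],
        capture_output=True,
        text=True,
    )
    print(proc.stderr.splitlines()[:3])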
diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py
index 68cedc666a5..bbc7630fa83 100644
--- a/Lib/test/test_enum.py
+++ b/Lib/test/test_enum.py
@@ -19,7 +19,8 @@ from io import StringIO
from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL
from test import support
from test.support import ALWAYS_EQ, REPO_ROOT
-from test.support import threading_helper
+from test.support import threading_helper, cpython_only
+from test.support.import_helper import ensure_lazy_imports
from datetime import timedelta
python_version = sys.version_info[:2]
@@ -35,7 +36,7 @@ def load_tests(loader, tests, ignore):
optionflags=doctest.ELLIPSIS|doctest.NORMALIZE_WHITESPACE,
))
howto_tests = os.path.join(REPO_ROOT, 'Doc/howto/enum.rst')
- if os.path.exists(howto_tests):
+ if os.path.exists(howto_tests) and sys.float_repr_style == 'short':
tests.addTests(doctest.DocFileSuite(
howto_tests,
module_relative=False,
@@ -433,9 +434,9 @@ class _EnumTests:
def spam(cls):
pass
#
- self.assertTrue(hasattr(Season, 'spam'))
+ self.assertHasAttr(Season, 'spam')
del Season.spam
- self.assertFalse(hasattr(Season, 'spam'))
+ self.assertNotHasAttr(Season, 'spam')
#
with self.assertRaises(AttributeError):
del Season.SPRING
@@ -2651,12 +2652,12 @@ class TestSpecial(unittest.TestCase):
OneDay = day_1
OneWeek = week_1
OneMonth = month_1
- self.assertFalse(hasattr(Period, '_ignore_'))
- self.assertFalse(hasattr(Period, 'Period'))
- self.assertFalse(hasattr(Period, 'i'))
- self.assertTrue(isinstance(Period.day_1, timedelta))
- self.assertTrue(Period.month_1 is Period.day_30)
- self.assertTrue(Period.week_4 is Period.day_28)
+ self.assertNotHasAttr(Period, '_ignore_')
+ self.assertNotHasAttr(Period, 'Period')
+ self.assertNotHasAttr(Period, 'i')
+ self.assertIsInstance(Period.day_1, timedelta)
+ self.assertIs(Period.month_1, Period.day_30)
+ self.assertIs(Period.week_4, Period.day_28)
def test_nonhash_value(self):
class AutoNumberInAList(Enum):
@@ -2876,7 +2877,7 @@ class TestSpecial(unittest.TestCase):
self.assertEqual(str(ReformedColor.BLUE), 'blue')
self.assertEqual(ReformedColor.RED.behavior(), 'booyah')
self.assertEqual(ConfusedColor.RED.social(), "what's up?")
- self.assertTrue(issubclass(ReformedColor, int))
+ self.assertIsSubclass(ReformedColor, int)
def test_multiple_inherited_mixin(self):
@unique
@@ -5288,6 +5289,10 @@ class MiscTestCase(unittest.TestCase):
def test__all__(self):
support.check__all__(self, enum, not_exported={'bin', 'show_flag_values'})
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("enum", {"functools", "warnings", "inspect", "re"})
+
def test_doc_1(self):
class Single(Enum):
ONE = 1
diff --git a/Lib/test/test_errno.py b/Lib/test/test_errno.py
index 5c437e9ccea..e7f185c6b1a 100644
--- a/Lib/test/test_errno.py
+++ b/Lib/test/test_errno.py
@@ -12,14 +12,12 @@ class ErrnoAttributeTests(unittest.TestCase):
def test_for_improper_attributes(self):
# No unexpected attributes should be on the module.
for error_code in std_c_errors:
- self.assertTrue(hasattr(errno, error_code),
- "errno is missing %s" % error_code)
+ self.assertHasAttr(errno, error_code)
def test_using_errorcode(self):
# Every key value in errno.errorcode should be on the module.
for value in errno.errorcode.values():
- self.assertTrue(hasattr(errno, value),
- 'no %s attr in errno' % value)
+ self.assertHasAttr(errno, value)
class ErrorcodeTests(unittest.TestCase):
diff --git a/Lib/test/test_exception_group.py b/Lib/test/test_exception_group.py
index 92bbf791764..5df2c41c6b5 100644
--- a/Lib/test/test_exception_group.py
+++ b/Lib/test/test_exception_group.py
@@ -1,13 +1,13 @@
import collections.abc
import types
import unittest
-from test.support import skip_emscripten_stack_overflow, exceeds_recursion_limit
+from test.support import skip_emscripten_stack_overflow, skip_wasi_stack_overflow, exceeds_recursion_limit
class TestExceptionGroupTypeHierarchy(unittest.TestCase):
def test_exception_group_types(self):
- self.assertTrue(issubclass(ExceptionGroup, Exception))
- self.assertTrue(issubclass(ExceptionGroup, BaseExceptionGroup))
- self.assertTrue(issubclass(BaseExceptionGroup, BaseException))
+ self.assertIsSubclass(ExceptionGroup, Exception)
+ self.assertIsSubclass(ExceptionGroup, BaseExceptionGroup)
+ self.assertIsSubclass(BaseExceptionGroup, BaseException)
def test_exception_is_not_generic_type(self):
with self.assertRaisesRegex(TypeError, 'Exception'):
@@ -465,12 +465,14 @@ class DeepRecursionInSplitAndSubgroup(unittest.TestCase):
return e
@skip_emscripten_stack_overflow()
+ @skip_wasi_stack_overflow()
def test_deep_split(self):
e = self.make_deep_eg()
with self.assertRaises(RecursionError):
e.split(TypeError)
@skip_emscripten_stack_overflow()
+ @skip_wasi_stack_overflow()
def test_deep_subgroup(self):
e = self.make_deep_eg()
with self.assertRaises(RecursionError):
@@ -812,8 +814,8 @@ class NestedExceptionGroupSplitTest(ExceptionGroupSplitTestBase):
eg = ExceptionGroup("eg", [ValueError(1), TypeError(2)])
eg.__notes__ = 123
match, rest = eg.split(TypeError)
- self.assertFalse(hasattr(match, '__notes__'))
- self.assertFalse(hasattr(rest, '__notes__'))
+ self.assertNotHasAttr(match, '__notes__')
+ self.assertNotHasAttr(rest, '__notes__')
def test_drive_invalid_return_value(self):
class MyEg(ExceptionGroup):
diff --git a/Lib/test/test_exceptions.py b/Lib/test/test_exceptions.py
index d177e3dc0f5..57d0656487d 100644
--- a/Lib/test/test_exceptions.py
+++ b/Lib/test/test_exceptions.py
@@ -357,7 +357,7 @@ class ExceptionTests(unittest.TestCase):
except TypeError as err:
co = err.__traceback__.tb_frame.f_code
self.assertEqual(co.co_name, "test_capi1")
- self.assertTrue(co.co_filename.endswith('test_exceptions.py'))
+ self.assertEndsWith(co.co_filename, 'test_exceptions.py')
else:
self.fail("Expected exception")
@@ -369,7 +369,7 @@ class ExceptionTests(unittest.TestCase):
tb = err.__traceback__.tb_next
co = tb.tb_frame.f_code
self.assertEqual(co.co_name, "__init__")
- self.assertTrue(co.co_filename.endswith('test_exceptions.py'))
+ self.assertEndsWith(co.co_filename, 'test_exceptions.py')
co2 = tb.tb_frame.f_back.f_code
self.assertEqual(co2.co_name, "test_capi2")
else:
@@ -598,7 +598,7 @@ class ExceptionTests(unittest.TestCase):
def test_notes(self):
for e in [BaseException(1), Exception(2), ValueError(3)]:
with self.subTest(e=e):
- self.assertFalse(hasattr(e, '__notes__'))
+ self.assertNotHasAttr(e, '__notes__')
e.add_note("My Note")
self.assertEqual(e.__notes__, ["My Note"])
@@ -610,7 +610,7 @@ class ExceptionTests(unittest.TestCase):
self.assertEqual(e.__notes__, ["My Note", "Your Note"])
del e.__notes__
- self.assertFalse(hasattr(e, '__notes__'))
+ self.assertNotHasAttr(e, '__notes__')
e.add_note("Our Note")
self.assertEqual(e.__notes__, ["Our Note"])
@@ -1429,6 +1429,7 @@ class ExceptionTests(unittest.TestCase):
self.assertIn("maximum recursion depth exceeded", str(exc))
@support.skip_wasi_stack_overflow()
+ @support.skip_emscripten_stack_overflow()
@cpython_only
@support.requires_resource('cpu')
def test_trashcan_recursion(self):
@@ -1444,6 +1445,7 @@ class ExceptionTests(unittest.TestCase):
foo()
support.gc_collect()
+ @support.skip_emscripten_stack_overflow()
@cpython_only
def test_recursion_normalizing_exception(self):
import_module("_testinternalcapi")
@@ -1521,6 +1523,7 @@ class ExceptionTests(unittest.TestCase):
self.assertIn(b'Done.', out)
+ @support.skip_emscripten_stack_overflow()
def test_recursion_in_except_handler(self):
def set_relative_recursion_limit(n):
@@ -1626,7 +1629,7 @@ class ExceptionTests(unittest.TestCase):
# test basic usage of PyErr_NewException
error1 = _testcapi.make_exception_with_doc("_testcapi.error1")
self.assertIs(type(error1), type)
- self.assertTrue(issubclass(error1, Exception))
+ self.assertIsSubclass(error1, Exception)
self.assertIsNone(error1.__doc__)
# test with given docstring
@@ -1636,21 +1639,21 @@ class ExceptionTests(unittest.TestCase):
# test with explicit base (without docstring)
error3 = _testcapi.make_exception_with_doc("_testcapi.error3",
base=error2)
- self.assertTrue(issubclass(error3, error2))
+ self.assertIsSubclass(error3, error2)
# test with explicit base tuple
class C(object):
pass
error4 = _testcapi.make_exception_with_doc("_testcapi.error4", doc4,
(error3, C))
- self.assertTrue(issubclass(error4, error3))
- self.assertTrue(issubclass(error4, C))
+ self.assertIsSubclass(error4, error3)
+ self.assertIsSubclass(error4, C)
self.assertEqual(error4.__doc__, doc4)
# test with explicit dictionary
error5 = _testcapi.make_exception_with_doc("_testcapi.error5", "",
error4, {'a': 1})
- self.assertTrue(issubclass(error5, error4))
+ self.assertIsSubclass(error5, error4)
self.assertEqual(error5.a, 1)
self.assertEqual(error5.__doc__, "")
@@ -1743,7 +1746,7 @@ class ExceptionTests(unittest.TestCase):
self.assertIn("<exception str() failed>", report)
else:
self.assertIn("test message", report)
- self.assertTrue(report.endswith("\n"))
+ self.assertEndsWith(report, "\n")
@cpython_only
# Python built with Py_TRACE_REFS fail with a fatal error in
diff --git a/Lib/test/test_external_inspection.py b/Lib/test/test_external_inspection.py
index aa05db972f0..0f31c225e68 100644
--- a/Lib/test/test_external_inspection.py
+++ b/Lib/test/test_external_inspection.py
@@ -4,7 +4,10 @@ import textwrap
import importlib
import sys
import socket
-from test.support import os_helper, SHORT_TIMEOUT, busy_retry
+import threading
+from asyncio import staggered, taskgroups, base_events, tasks
+from unittest.mock import ANY
+from test.support import os_helper, SHORT_TIMEOUT, busy_retry, requires_gil_enabled
from test.support.script_helper import make_script
from test.support.socket_helper import find_unused_port
@@ -13,33 +16,60 @@ import subprocess
PROCESS_VM_READV_SUPPORTED = False
try:
- from _testexternalinspection import PROCESS_VM_READV_SUPPORTED
- from _testexternalinspection import get_stack_trace
- from _testexternalinspection import get_async_stack_trace
- from _testexternalinspection import get_all_awaited_by
+ from _remote_debugging import PROCESS_VM_READV_SUPPORTED
+ from _remote_debugging import RemoteUnwinder
+ from _remote_debugging import FrameInfo, CoroInfo, TaskInfo
except ImportError:
raise unittest.SkipTest(
- "Test only runs when _testexternalinspection is available")
+ "Test only runs when _remote_debugging is available"
+ )
+
def _make_test_script(script_dir, script_basename, source):
to_return = make_script(script_dir, script_basename, source)
importlib.invalidate_caches()
return to_return
-skip_if_not_supported = unittest.skipIf((sys.platform != "darwin"
- and sys.platform != "linux"
- and sys.platform != "win32"),
- "Test only runs on Linux, Windows and MacOS")
+
+skip_if_not_supported = unittest.skipIf(
+ (
+ sys.platform != "darwin"
+ and sys.platform != "linux"
+ and sys.platform != "win32"
+ ),
+ "Test only runs on Linux, Windows and MacOS",
+)
+
+
+def get_stack_trace(pid):
+ unwinder = RemoteUnwinder(pid, all_threads=True, debug=True)
+ return unwinder.get_stack_trace()
+
+
+def get_async_stack_trace(pid):
+ unwinder = RemoteUnwinder(pid, debug=True)
+ return unwinder.get_async_stack_trace()
+
+
+def get_all_awaited_by(pid):
+ unwinder = RemoteUnwinder(pid, debug=True)
+ return unwinder.get_all_awaited_by()
+
+
class TestGetStackTrace(unittest.TestCase):
+ maxDiff = None
@skip_if_not_supported
- @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
- "Test only runs on Linux with process_vm_readv support")
+ @unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+ )
def test_remote_stack_trace(self):
# Spawn a process with some realistic Python code
port = find_unused_port()
- script = textwrap.dedent(f"""\
- import time, sys, socket
+ script = textwrap.dedent(
+ f"""\
+ import time, sys, socket, threading
# Connect to the test process
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('localhost', {port}))
@@ -48,15 +78,18 @@ class TestGetStackTrace(unittest.TestCase):
for x in range(100):
if x == 50:
baz()
+
def baz():
foo()
def foo():
- sock.sendall(b"ready")
- time.sleep(1000)
+ sock.sendall(b"ready:thread\\n"); time.sleep(10_000) # same line number
- bar()
- """)
+ t = threading.Thread(target=bar)
+ t.start()
+ sock.sendall(b"ready:main\\n"); t.join() # same line number
+ """
+ )
stack_trace = None
with os_helper.temp_dir() as work_dir:
script_dir = os.path.join(work_dir, "script_pkg")
@@ -65,21 +98,27 @@ class TestGetStackTrace(unittest.TestCase):
# Create a socket server to communicate with the target process
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(('localhost', port))
+ server_socket.bind(("localhost", port))
server_socket.settimeout(SHORT_TIMEOUT)
server_socket.listen(1)
- script_name = _make_test_script(script_dir, 'script', script)
+ script_name = _make_test_script(script_dir, "script", script)
client_socket = None
try:
p = subprocess.Popen([sys.executable, script_name])
client_socket, _ = server_socket.accept()
server_socket.close()
- response = client_socket.recv(1024)
- self.assertEqual(response, b"ready")
+ response = b""
+ while (
+ b"ready:main" not in response
+ or b"ready:thread" not in response
+ ):
+ response += client_socket.recv(1024)
stack_trace = get_stack_trace(p.pid)
except PermissionError:
- self.skipTest("Insufficient permissions to read the stack trace")
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
+ )
finally:
if client_socket is not None:
client_socket.close()
@@ -87,22 +126,34 @@ class TestGetStackTrace(unittest.TestCase):
p.terminate()
p.wait(timeout=SHORT_TIMEOUT)
-
- expected_stack_trace = [
- 'foo',
- 'baz',
- 'bar',
- '<module>'
+ thread_expected_stack_trace = [
+ FrameInfo([script_name, 15, "foo"]),
+ FrameInfo([script_name, 12, "baz"]),
+ FrameInfo([script_name, 9, "bar"]),
+ FrameInfo([threading.__file__, ANY, "Thread.run"]),
]
- self.assertEqual(stack_trace, expected_stack_trace)
+        # It is possible that there are more threads, so we check that the
+        # expected stack traces are in the result (looking at you, Windows!)
+ self.assertIn((ANY, thread_expected_stack_trace), stack_trace)
+
+ # Check that the main thread stack trace is in the result
+ frame = FrameInfo([script_name, 19, "<module>"])
+ for _, stack in stack_trace:
+ if frame in stack:
+ break
+ else:
+ self.fail("Main thread stack trace not found in result")
@skip_if_not_supported
- @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
- "Test only runs on Linux with process_vm_readv support")
+ @unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+ )
def test_async_remote_stack_trace(self):
# Spawn a process with some realistic Python code
port = find_unused_port()
- script = textwrap.dedent(f"""\
+ script = textwrap.dedent(
+ f"""\
import asyncio
import time
import sys
@@ -112,8 +163,7 @@ class TestGetStackTrace(unittest.TestCase):
sock.connect(('localhost', {port}))
def c5():
- sock.sendall(b"ready")
- time.sleep(10000)
+ sock.sendall(b"ready"); time.sleep(10_000) # same line number
async def c4():
await asyncio.sleep(0)
@@ -142,7 +192,8 @@ class TestGetStackTrace(unittest.TestCase):
return loop
asyncio.run(main(), loop_factory={{TASK_FACTORY}})
- """)
+ """
+ )
stack_trace = None
for task_factory_variant in "asyncio.new_event_loop", "new_eager_loop":
with (
@@ -151,19 +202,23 @@ class TestGetStackTrace(unittest.TestCase):
):
script_dir = os.path.join(work_dir, "script_pkg")
os.mkdir(script_dir)
- server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(('localhost', port))
+ server_socket = socket.socket(
+ socket.AF_INET, socket.SOCK_STREAM
+ )
+ server_socket.setsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEADDR, 1
+ )
+ server_socket.bind(("localhost", port))
server_socket.settimeout(SHORT_TIMEOUT)
server_socket.listen(1)
script_name = _make_test_script(
- script_dir, 'script',
- script.format(TASK_FACTORY=task_factory_variant))
+ script_dir,
+ "script",
+ script.format(TASK_FACTORY=task_factory_variant),
+ )
client_socket = None
try:
- p = subprocess.Popen(
- [sys.executable, script_name]
- )
+ p = subprocess.Popen([sys.executable, script_name])
client_socket, _ = server_socket.accept()
server_socket.close()
response = client_socket.recv(1024)
@@ -171,7 +226,8 @@ class TestGetStackTrace(unittest.TestCase):
stack_trace = get_async_stack_trace(p.pid)
except PermissionError:
self.skipTest(
- "Insufficient permissions to read the stack trace")
+ "Insufficient permissions to read the stack trace"
+ )
finally:
if client_socket is not None:
client_socket.close()
@@ -182,25 +238,63 @@ class TestGetStackTrace(unittest.TestCase):
# sets are unordered, so we want to sort "awaited_by"s
stack_trace[2].sort(key=lambda x: x[1])
- root_task = "Task-1"
expected_stack_trace = [
- ["c5", "c4", "c3", "c2"],
+ [
+ FrameInfo([script_name, 10, "c5"]),
+ FrameInfo([script_name, 14, "c4"]),
+ FrameInfo([script_name, 17, "c3"]),
+ FrameInfo([script_name, 20, "c2"]),
+ ],
"c2_root",
[
- [["main"], root_task, []],
- [["c1"], "sub_main_1", [[["main"], root_task, []]]],
- [["c1"], "sub_main_2", [[["main"], root_task, []]]],
+ CoroInfo(
+ [
+ [
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup._aexit",
+ ]
+ ),
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup.__aexit__",
+ ]
+ ),
+ FrameInfo([script_name, 26, "main"]),
+ ],
+ "Task-1",
+ ]
+ ),
+ CoroInfo(
+ [
+ [FrameInfo([script_name, 23, "c1"])],
+ "sub_main_1",
+ ]
+ ),
+ CoroInfo(
+ [
+ [FrameInfo([script_name, 23, "c1"])],
+ "sub_main_2",
+ ]
+ ),
],
]
self.assertEqual(stack_trace, expected_stack_trace)
@skip_if_not_supported
- @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
- "Test only runs on Linux with process_vm_readv support")
+ @unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+ )
def test_asyncgen_remote_stack_trace(self):
# Spawn a process with some realistic Python code
port = find_unused_port()
- script = textwrap.dedent(f"""\
+ script = textwrap.dedent(
+ f"""\
import asyncio
import time
import sys
@@ -210,8 +304,7 @@ class TestGetStackTrace(unittest.TestCase):
sock.connect(('localhost', {port}))
async def gen_nested_call():
- sock.sendall(b"ready")
- time.sleep(10000)
+ sock.sendall(b"ready"); time.sleep(10_000) # same line number
async def gen():
for num in range(2):
@@ -224,7 +317,8 @@ class TestGetStackTrace(unittest.TestCase):
pass
asyncio.run(main())
- """)
+ """
+ )
stack_trace = None
with os_helper.temp_dir() as work_dir:
script_dir = os.path.join(work_dir, "script_pkg")
@@ -232,10 +326,10 @@ class TestGetStackTrace(unittest.TestCase):
# Create a socket server to communicate with the target process
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(('localhost', port))
+ server_socket.bind(("localhost", port))
server_socket.settimeout(SHORT_TIMEOUT)
server_socket.listen(1)
- script_name = _make_test_script(script_dir, 'script', script)
+ script_name = _make_test_script(script_dir, "script", script)
client_socket = None
try:
p = subprocess.Popen([sys.executable, script_name])
@@ -245,7 +339,9 @@ class TestGetStackTrace(unittest.TestCase):
self.assertEqual(response, b"ready")
stack_trace = get_async_stack_trace(p.pid)
except PermissionError:
- self.skipTest("Insufficient permissions to read the stack trace")
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
+ )
finally:
if client_socket is not None:
client_socket.close()
@@ -257,17 +353,26 @@ class TestGetStackTrace(unittest.TestCase):
stack_trace[2].sort(key=lambda x: x[1])
expected_stack_trace = [
- ['gen_nested_call', 'gen', 'main'], 'Task-1', []
+ [
+ FrameInfo([script_name, 10, "gen_nested_call"]),
+ FrameInfo([script_name, 16, "gen"]),
+ FrameInfo([script_name, 19, "main"]),
+ ],
+ "Task-1",
+ [],
]
self.assertEqual(stack_trace, expected_stack_trace)
@skip_if_not_supported
- @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
- "Test only runs on Linux with process_vm_readv support")
+ @unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+ )
def test_async_gather_remote_stack_trace(self):
# Spawn a process with some realistic Python code
port = find_unused_port()
- script = textwrap.dedent(f"""\
+ script = textwrap.dedent(
+ f"""\
import asyncio
import time
import sys
@@ -278,8 +383,7 @@ class TestGetStackTrace(unittest.TestCase):
async def deep():
await asyncio.sleep(0)
- sock.sendall(b"ready")
- time.sleep(10000)
+ sock.sendall(b"ready"); time.sleep(10_000) # same line number
async def c1():
await asyncio.sleep(0)
@@ -292,7 +396,8 @@ class TestGetStackTrace(unittest.TestCase):
await asyncio.gather(c1(), c2())
asyncio.run(main())
- """)
+ """
+ )
stack_trace = None
with os_helper.temp_dir() as work_dir:
script_dir = os.path.join(work_dir, "script_pkg")
@@ -300,10 +405,10 @@ class TestGetStackTrace(unittest.TestCase):
# Create a socket server to communicate with the target process
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(('localhost', port))
+ server_socket.bind(("localhost", port))
server_socket.settimeout(SHORT_TIMEOUT)
server_socket.listen(1)
- script_name = _make_test_script(script_dir, 'script', script)
+ script_name = _make_test_script(script_dir, "script", script)
client_socket = None
try:
p = subprocess.Popen([sys.executable, script_name])
@@ -314,7 +419,8 @@ class TestGetStackTrace(unittest.TestCase):
stack_trace = get_async_stack_trace(p.pid)
except PermissionError:
self.skipTest(
- "Insufficient permissions to read the stack trace")
+ "Insufficient permissions to read the stack trace"
+ )
finally:
if client_socket is not None:
client_socket.close()
@@ -325,18 +431,26 @@ class TestGetStackTrace(unittest.TestCase):
# sets are unordered, so we want to sort "awaited_by"s
stack_trace[2].sort(key=lambda x: x[1])
- expected_stack_trace = [
- ['deep', 'c1'], 'Task-2', [[['main'], 'Task-1', []]]
+ expected_stack_trace = [
+ [
+ FrameInfo([script_name, 11, "deep"]),
+ FrameInfo([script_name, 15, "c1"]),
+ ],
+ "Task-2",
+ [CoroInfo([[FrameInfo([script_name, 21, "main"])], "Task-1"])],
]
self.assertEqual(stack_trace, expected_stack_trace)
@skip_if_not_supported
- @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
- "Test only runs on Linux with process_vm_readv support")
+ @unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+ )
def test_async_staggered_race_remote_stack_trace(self):
# Spawn a process with some realistic Python code
port = find_unused_port()
- script = textwrap.dedent(f"""\
+ script = textwrap.dedent(
+ f"""\
import asyncio.staggered
import time
import sys
@@ -347,15 +461,14 @@ class TestGetStackTrace(unittest.TestCase):
async def deep():
await asyncio.sleep(0)
- sock.sendall(b"ready")
- time.sleep(10000)
+ sock.sendall(b"ready"); time.sleep(10_000) # same line number
async def c1():
await asyncio.sleep(0)
await deep()
async def c2():
- await asyncio.sleep(10000)
+ await asyncio.sleep(10_000)
async def main():
await asyncio.staggered.staggered_race(
@@ -364,7 +477,8 @@ class TestGetStackTrace(unittest.TestCase):
)
asyncio.run(main())
- """)
+ """
+ )
stack_trace = None
with os_helper.temp_dir() as work_dir:
script_dir = os.path.join(work_dir, "script_pkg")
@@ -372,10 +486,10 @@ class TestGetStackTrace(unittest.TestCase):
# Create a socket server to communicate with the target process
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(('localhost', port))
+ server_socket.bind(("localhost", port))
server_socket.settimeout(SHORT_TIMEOUT)
server_socket.listen(1)
- script_name = _make_test_script(script_dir, 'script', script)
+ script_name = _make_test_script(script_dir, "script", script)
client_socket = None
try:
p = subprocess.Popen([sys.executable, script_name])
@@ -386,7 +500,8 @@ class TestGetStackTrace(unittest.TestCase):
stack_trace = get_async_stack_trace(p.pid)
except PermissionError:
self.skipTest(
- "Insufficient permissions to read the stack trace")
+ "Insufficient permissions to read the stack trace"
+ )
finally:
if client_socket is not None:
client_socket.close()
@@ -396,18 +511,44 @@ class TestGetStackTrace(unittest.TestCase):
# sets are unordered, so we want to sort "awaited_by"s
stack_trace[2].sort(key=lambda x: x[1])
-
- expected_stack_trace = [
- ['deep', 'c1', 'run_one_coro'], 'Task-2', [[['main'], 'Task-1', []]]
+ expected_stack_trace = [
+ [
+ FrameInfo([script_name, 11, "deep"]),
+ FrameInfo([script_name, 15, "c1"]),
+ FrameInfo(
+ [
+ staggered.__file__,
+ ANY,
+ "staggered_race.<locals>.run_one_coro",
+ ]
+ ),
+ ],
+ "Task-2",
+ [
+ CoroInfo(
+ [
+ [
+ FrameInfo(
+ [staggered.__file__, ANY, "staggered_race"]
+ ),
+ FrameInfo([script_name, 21, "main"]),
+ ],
+ "Task-1",
+ ]
+ )
+ ],
]
self.assertEqual(stack_trace, expected_stack_trace)
@skip_if_not_supported
- @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
- "Test only runs on Linux with process_vm_readv support")
+ @unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+ )
def test_async_global_awaited_by(self):
port = find_unused_port()
- script = textwrap.dedent(f"""\
+ script = textwrap.dedent(
+ f"""\
import asyncio
import os
import random
@@ -443,6 +584,8 @@ class TestGetStackTrace(unittest.TestCase):
assert message == data.decode()
writer.close()
await writer.wait_closed()
+ # Signal we are ready to sleep
+ sock.sendall(b"ready")
await asyncio.sleep(SHORT_TIMEOUT)
async def echo_client_spam(server):
@@ -452,8 +595,10 @@ class TestGetStackTrace(unittest.TestCase):
random.shuffle(msg)
tg.create_task(echo_client("".join(msg)))
await asyncio.sleep(0)
- # at least a 1000 tasks created
- sock.sendall(b"ready")
+            # at least 1000 tasks are created. Each task will signal
+            # when it is ready, to avoid the race caused by the fact that
+            # tasks are waited on in tg.__exit__ and we cannot signal when
+            # that happens otherwise
# at this point all client tasks completed without assertion errors
# let's wrap up the test
server.close()
@@ -468,7 +613,8 @@ class TestGetStackTrace(unittest.TestCase):
tg.create_task(echo_client_spam(server), name="echo client spam")
asyncio.run(main())
- """)
+ """
+ )
stack_trace = None
with os_helper.temp_dir() as work_dir:
script_dir = os.path.join(work_dir, "script_pkg")
@@ -476,17 +622,19 @@ class TestGetStackTrace(unittest.TestCase):
# Create a socket server to communicate with the target process
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(('localhost', port))
+ server_socket.bind(("localhost", port))
server_socket.settimeout(SHORT_TIMEOUT)
server_socket.listen(1)
- script_name = _make_test_script(script_dir, 'script', script)
+ script_name = _make_test_script(script_dir, "script", script)
client_socket = None
try:
p = subprocess.Popen([sys.executable, script_name])
client_socket, _ = server_socket.accept()
server_socket.close()
- response = client_socket.recv(1024)
- self.assertEqual(response, b"ready")
+ for _ in range(1000):
+ expected_response = b"ready"
+ response = client_socket.recv(len(expected_response))
+ self.assertEqual(response, expected_response)
for _ in busy_retry(SHORT_TIMEOUT):
try:
all_awaited_by = get_all_awaited_by(p.pid)
@@ -497,7 +645,9 @@ class TestGetStackTrace(unittest.TestCase):
msg = str(re)
if msg.startswith("Task list appears corrupted"):
continue
- elif msg.startswith("Invalid linked list structure reading remote memory"):
+ elif msg.startswith(
+ "Invalid linked list structure reading remote memory"
+ ):
continue
elif msg.startswith("Unknown error reading memory"):
continue
@@ -516,22 +666,174 @@ class TestGetStackTrace(unittest.TestCase):
# expected: at least 1000 pending tasks
self.assertGreaterEqual(len(entries), 1000)
# the first three tasks stem from the code structure
- self.assertIn(('Task-1', []), entries)
- self.assertIn(('server task', [[['main'], 'Task-1', []]]), entries)
- self.assertIn(('echo client spam', [[['main'], 'Task-1', []]]), entries)
+ main_stack = [
+ FrameInfo([taskgroups.__file__, ANY, "TaskGroup._aexit"]),
+ FrameInfo(
+ [taskgroups.__file__, ANY, "TaskGroup.__aexit__"]
+ ),
+ FrameInfo([script_name, 60, "main"]),
+ ]
+ self.assertIn(
+ TaskInfo(
+ [ANY, "Task-1", [CoroInfo([main_stack, ANY])], []]
+ ),
+ entries,
+ )
+ self.assertIn(
+ TaskInfo(
+ [
+ ANY,
+ "server task",
+ [
+ CoroInfo(
+ [
+ [
+ FrameInfo(
+ [
+ base_events.__file__,
+ ANY,
+ "Server.serve_forever",
+ ]
+ )
+ ],
+ ANY,
+ ]
+ )
+ ],
+ [
+ CoroInfo(
+ [
+ [
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup._aexit",
+ ]
+ ),
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup.__aexit__",
+ ]
+ ),
+ FrameInfo(
+ [script_name, ANY, "main"]
+ ),
+ ],
+ ANY,
+ ]
+ )
+ ],
+ ]
+ ),
+ entries,
+ )
+ self.assertIn(
+ TaskInfo(
+ [
+ ANY,
+ "Task-4",
+ [
+ CoroInfo(
+ [
+ [
+ FrameInfo(
+ [tasks.__file__, ANY, "sleep"]
+ ),
+ FrameInfo(
+ [
+ script_name,
+ 38,
+ "echo_client",
+ ]
+ ),
+ ],
+ ANY,
+ ]
+ )
+ ],
+ [
+ CoroInfo(
+ [
+ [
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup._aexit",
+ ]
+ ),
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup.__aexit__",
+ ]
+ ),
+ FrameInfo(
+ [
+ script_name,
+ 41,
+ "echo_client_spam",
+ ]
+ ),
+ ],
+ ANY,
+ ]
+ )
+ ],
+ ]
+ ),
+ entries,
+ )
- expected_stack = [[['echo_client_spam'], 'echo client spam', [[['main'], 'Task-1', []]]]]
- tasks_with_stack = [task for task in entries if task[1] == expected_stack]
- self.assertGreaterEqual(len(tasks_with_stack), 1000)
+ expected_awaited_by = [
+ CoroInfo(
+ [
+ [
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup._aexit",
+ ]
+ ),
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup.__aexit__",
+ ]
+ ),
+ FrameInfo(
+ [script_name, 41, "echo_client_spam"]
+ ),
+ ],
+ ANY,
+ ]
+ )
+ ]
+ tasks_with_awaited = [
+ task
+ for task in entries
+ if task.awaited_by == expected_awaited_by
+ ]
+ self.assertGreaterEqual(len(tasks_with_awaited), 1000)
                    # the final task will have some random number, but it should for
                    # sure be one of the echo client spam horde (on Windows this is not true
                    # for some reason)
if sys.platform != "win32":
- self.assertEqual([[['echo_client_spam'], 'echo client spam', [[['main'], 'Task-1', []]]]], entries[-1][1])
+ self.assertEqual(
+ tasks_with_awaited[-1].awaited_by,
+ entries[-1].awaited_by,
+ )
except PermissionError:
self.skipTest(
- "Insufficient permissions to read the stack trace")
+ "Insufficient permissions to read the stack trace"
+ )
finally:
if client_socket is not None:
client_socket.close()
@@ -540,12 +842,160 @@ class TestGetStackTrace(unittest.TestCase):
p.wait(timeout=SHORT_TIMEOUT)
@skip_if_not_supported
- @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
- "Test only runs on Linux with process_vm_readv support")
+ @unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+ )
def test_self_trace(self):
stack_trace = get_stack_trace(os.getpid())
- print(stack_trace)
- self.assertEqual(stack_trace[0], "test_self_trace")
+        # It is possible that there are more threads, so we check that the
+        # expected stack traces are in the result (looking at you, Windows!)
+        this_thread_stack = None
+        for thread_id, stack in stack_trace:
+            if thread_id == threading.get_native_id():
+                this_thread_stack = stack
+                break
+        self.assertIsNotNone(this_thread_stack)
+ self.assertEqual(
+ stack[:2],
+ [
+ FrameInfo(
+ [
+ __file__,
+ get_stack_trace.__code__.co_firstlineno + 2,
+ "get_stack_trace",
+ ]
+ ),
+ FrameInfo(
+ [
+ __file__,
+ self.test_self_trace.__code__.co_firstlineno + 6,
+ "TestGetStackTrace.test_self_trace",
+ ]
+ ),
+ ],
+ )
+
+ @skip_if_not_supported
+ @unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+ )
+ @requires_gil_enabled("Free threaded builds don't have an 'active thread'")
+ def test_only_active_thread(self):
+ # Test that only_active_thread parameter works correctly
+ port = find_unused_port()
+ script = textwrap.dedent(
+ f"""\
+ import time, sys, socket, threading
+
+ # Connect to the test process
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.connect(('localhost', {port}))
+
+ def worker_thread(name, barrier, ready_event):
+ barrier.wait() # Synchronize thread start
+ ready_event.wait() # Wait for main thread signal
+ # Sleep to keep thread alive
+ time.sleep(10_000)
+
+ def main_work():
+ # Do busy work to hold the GIL
+ sock.sendall(b"working\\n")
+ count = 0
+ while count < 100000000:
+ count += 1
+ if count % 10000000 == 0:
+ pass # Keep main thread busy
+ sock.sendall(b"done\\n")
+
+ # Create synchronization primitives
+ num_threads = 3
+ barrier = threading.Barrier(num_threads + 1) # +1 for main thread
+ ready_event = threading.Event()
+
+ # Start worker threads
+ threads = []
+ for i in range(num_threads):
+ t = threading.Thread(target=worker_thread, args=(f"Worker-{{i}}", barrier, ready_event))
+ t.start()
+ threads.append(t)
+
+ # Wait for all threads to be ready
+ barrier.wait()
+
+ # Signal ready to parent process
+ sock.sendall(b"ready\\n")
+
+ # Signal threads to start waiting
+ ready_event.set()
+
+ # Give threads time to start sleeping
+ time.sleep(0.1)
+
+ # Now do busy work to hold the GIL
+ main_work()
+ """
+ )
+
+ with os_helper.temp_dir() as work_dir:
+ script_dir = os.path.join(work_dir, "script_pkg")
+ os.mkdir(script_dir)
+
+ # Create a socket server to communicate with the target process
+ server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ server_socket.bind(("localhost", port))
+ server_socket.settimeout(SHORT_TIMEOUT)
+ server_socket.listen(1)
+
+ script_name = _make_test_script(script_dir, "script", script)
+ client_socket = None
+ try:
+ p = subprocess.Popen([sys.executable, script_name])
+ client_socket, _ = server_socket.accept()
+ server_socket.close()
+
+ # Wait for ready signal
+ response = b""
+ while b"ready" not in response:
+ response += client_socket.recv(1024)
+
+ # Wait for the main thread to start its busy work
+ while b"working" not in response:
+ response += client_socket.recv(1024)
+
+ # Get stack trace with all threads
+ unwinder_all = RemoteUnwinder(p.pid, all_threads=True)
+ all_traces = unwinder_all.get_stack_trace()
+
+ # Get stack trace with only GIL holder
+ unwinder_gil = RemoteUnwinder(p.pid, only_active_thread=True)
+ gil_traces = unwinder_gil.get_stack_trace()
+
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
+ )
+ finally:
+ if client_socket is not None:
+ client_socket.close()
+ p.kill()
+ p.terminate()
+ p.wait(timeout=SHORT_TIMEOUT)
+
+ # Verify we got multiple threads in all_traces
+ self.assertGreater(len(all_traces), 1, "Should have multiple threads")
+
+ # Verify we got exactly one thread in gil_traces
+ self.assertEqual(len(gil_traces), 1, "Should have exactly one GIL holder")
+
+ # The GIL holder should be in the all_traces list
+ gil_thread_id = gil_traces[0][0]
+ all_thread_ids = [trace[0] for trace in all_traces]
+ self.assertIn(gil_thread_id, all_thread_ids,
+ "GIL holder should be among all threads")
+
if __name__ == "__main__":
unittest.main()
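A minimal sketch of the _remote_debugging API the rewritten tests use; the module is private and unstable, and the constructor arguments are taken from this diff:

    import os

    from _remote_debugging import RemoteUnwinder

    # Inspect our own process; each entry pairs a thread id with its frames.
    unwinder = RemoteUnwinder(os.getpid(), all_threads=True)
    for thread_id, frames in unwinder.get_stack_trace():
        top = frames[0]        # FrameInfo([filename, lineno, funcname])
        print(thread_id, top)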
diff --git a/Lib/test/test_faulthandler.py b/Lib/test/test_faulthandler.py
index 371c63adce9..2fb963f52e5 100644
--- a/Lib/test/test_faulthandler.py
+++ b/Lib/test/test_faulthandler.py
@@ -166,29 +166,6 @@ class FaultHandlerTests(unittest.TestCase):
fatal_error = 'Windows fatal exception: %s' % name_regex
self.check_error(code, line_number, fatal_error, **kw)
- @unittest.skipIf(sys.platform.startswith('aix'),
- "the first page of memory is a mapped read-only on AIX")
- def test_read_null(self):
- if not MS_WINDOWS:
- self.check_fatal_error("""
- import faulthandler
- faulthandler.enable()
- faulthandler._read_null()
- """,
- 3,
- # Issue #12700: Read NULL raises SIGILL on Mac OS X Lion
- '(?:Segmentation fault'
- '|Bus error'
- '|Illegal instruction)')
- else:
- self.check_windows_exception("""
- import faulthandler
- faulthandler.enable()
- faulthandler._read_null()
- """,
- 3,
- 'access violation')
-
@skip_segfault_on_android
def test_sigsegv(self):
self.check_fatal_error("""
diff --git a/Lib/test/test_fcntl.py b/Lib/test/test_fcntl.py
index b84c98ef3a2..7140a7b4f29 100644
--- a/Lib/test/test_fcntl.py
+++ b/Lib/test/test_fcntl.py
@@ -11,7 +11,7 @@ from test.support import (
cpython_only, get_pagesize, is_apple, requires_subprocess, verbose
)
from test.support.import_helper import import_module
-from test.support.os_helper import TESTFN, unlink
+from test.support.os_helper import TESTFN, unlink, make_bad_fd
# Skip test if no fcntl module.
@@ -228,6 +228,63 @@ class TestFcntl(unittest.TestCase):
os.close(test_pipe_r)
os.close(test_pipe_w)
+ def _check_fcntl_not_mutate_len(self, nbytes=None):
+ self.f = open(TESTFN, 'wb')
+ buf = struct.pack('ii', fcntl.F_OWNER_PID, os.getpid())
+ if nbytes is not None:
+ buf += b' ' * (nbytes - len(buf))
+ else:
+ nbytes = len(buf)
+ save_buf = bytes(buf)
+ r = fcntl.fcntl(self.f, fcntl.F_SETOWN_EX, buf)
+ self.assertIsInstance(r, bytes)
+ self.assertEqual(len(r), len(save_buf))
+ self.assertEqual(buf, save_buf)
+ type, pid = memoryview(r).cast('i')[:2]
+ self.assertEqual(type, fcntl.F_OWNER_PID)
+ self.assertEqual(pid, os.getpid())
+
+ buf = b' ' * nbytes
+ r = fcntl.fcntl(self.f, fcntl.F_GETOWN_EX, buf)
+ self.assertIsInstance(r, bytes)
+ self.assertEqual(len(r), len(save_buf))
+ self.assertEqual(buf, b' ' * nbytes)
+ type, pid = memoryview(r).cast('i')[:2]
+ self.assertEqual(type, fcntl.F_OWNER_PID)
+ self.assertEqual(pid, os.getpid())
+
+ buf = memoryview(b' ' * nbytes)
+ r = fcntl.fcntl(self.f, fcntl.F_GETOWN_EX, buf)
+ self.assertIsInstance(r, bytes)
+ self.assertEqual(len(r), len(save_buf))
+ self.assertEqual(bytes(buf), b' ' * nbytes)
+ type, pid = memoryview(r).cast('i')[:2]
+ self.assertEqual(type, fcntl.F_OWNER_PID)
+ self.assertEqual(pid, os.getpid())
+
+ @unittest.skipUnless(
+ hasattr(fcntl, "F_SETOWN_EX") and hasattr(fcntl, "F_GETOWN_EX"),
+ "requires F_SETOWN_EX and F_GETOWN_EX")
+ def test_fcntl_small_buffer(self):
+ self._check_fcntl_not_mutate_len()
+
+ @unittest.skipUnless(
+ hasattr(fcntl, "F_SETOWN_EX") and hasattr(fcntl, "F_GETOWN_EX"),
+ "requires F_SETOWN_EX and F_GETOWN_EX")
+ def test_fcntl_large_buffer(self):
+ self._check_fcntl_not_mutate_len(2024)
+
+ @unittest.skipUnless(hasattr(fcntl, 'F_DUPFD'), 'need fcntl.F_DUPFD')
+ def test_bad_fd(self):
+ # gh-134744: Test error handling
+ fd = make_bad_fd()
+ with self.assertRaises(OSError):
+ fcntl.fcntl(fd, fcntl.F_DUPFD, 0)
+ with self.assertRaises(OSError):
+ fcntl.fcntl(fd, fcntl.F_DUPFD, b'\0' * 10)
+ with self.assertRaises(OSError):
+ fcntl.fcntl(fd, fcntl.F_DUPFD, b'\0' * 2048)
+
if __name__ == '__main__':
unittest.main()
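A hedged sketch of the behaviour _check_fcntl_not_mutate_len() verifies above: on platforms that expose F_SETOWN_EX/F_GETOWN_EX, fcntl() with a bytes argument returns a new buffer of the same length and leaves the input unchanged (the file name below is hypothetical):

    import fcntl
    import os
    import struct

    with open("scratch.bin", "wb") as f:            # hypothetical scratch file
        buf = struct.pack("ii", fcntl.F_OWNER_PID, os.getpid())
        result = fcntl.fcntl(f, fcntl.F_SETOWN_EX, buf)
        assert isinstance(result, bytes)
        assert len(result) == len(buf)
        assert buf == struct.pack("ii", fcntl.F_OWNER_PID, os.getpid())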
diff --git a/Lib/test/test_fileinput.py b/Lib/test/test_fileinput.py
index b340ef7ed16..6524baabe7f 100644
--- a/Lib/test/test_fileinput.py
+++ b/Lib/test/test_fileinput.py
@@ -245,7 +245,7 @@ class FileInputTests(BaseTests, unittest.TestCase):
orig_stdin = sys.stdin
try:
sys.stdin = BytesIO(b'spam, bacon, sausage, and spam')
- self.assertFalse(hasattr(sys.stdin, 'buffer'))
+ self.assertNotHasAttr(sys.stdin, 'buffer')
fi = FileInput(files=['-'], mode='rb')
lines = list(fi)
self.assertEqual(lines, [b'spam, bacon, sausage, and spam'])
diff --git a/Lib/test/test_fileio.py b/Lib/test/test_fileio.py
index 5a0f033ebb8..e3d54f6315a 100644
--- a/Lib/test/test_fileio.py
+++ b/Lib/test/test_fileio.py
@@ -591,7 +591,7 @@ class OtherFileTests:
try:
f.write(b"abc")
f.close()
- with open(TESTFN_ASCII, "rb") as f:
+ with self.open(TESTFN_ASCII, "rb") as f:
self.assertEqual(f.read(), b"abc")
finally:
os.unlink(TESTFN_ASCII)
@@ -608,7 +608,7 @@ class OtherFileTests:
try:
f.write(b"abc")
f.close()
- with open(TESTFN_UNICODE, "rb") as f:
+ with self.open(TESTFN_UNICODE, "rb") as f:
self.assertEqual(f.read(), b"abc")
finally:
os.unlink(TESTFN_UNICODE)
@@ -692,13 +692,13 @@ class OtherFileTests:
def testAppend(self):
try:
- f = open(TESTFN, 'wb')
+ f = self.FileIO(TESTFN, 'wb')
f.write(b'spam')
f.close()
- f = open(TESTFN, 'ab')
+ f = self.FileIO(TESTFN, 'ab')
f.write(b'eggs')
f.close()
- f = open(TESTFN, 'rb')
+ f = self.FileIO(TESTFN, 'rb')
d = f.read()
f.close()
self.assertEqual(d, b'spameggs')
@@ -734,6 +734,7 @@ class OtherFileTests:
class COtherFileTests(OtherFileTests, unittest.TestCase):
FileIO = _io.FileIO
modulename = '_io'
+ open = _io.open
@cpython_only
def testInvalidFd_overflow(self):
@@ -755,6 +756,7 @@ class COtherFileTests(OtherFileTests, unittest.TestCase):
class PyOtherFileTests(OtherFileTests, unittest.TestCase):
FileIO = _pyio.FileIO
modulename = '_pyio'
+ open = _pyio.open
def test_open_code(self):
# Check that the default behaviour of open_code matches
diff --git a/Lib/test/test_float.py b/Lib/test/test_float.py
index 237d7b5d35e..00518abcb11 100644
--- a/Lib/test/test_float.py
+++ b/Lib/test/test_float.py
@@ -795,6 +795,8 @@ class FormatTestCase(unittest.TestCase):
self.assertRaises(ValueError, format, x, '.6,n')
@support.requires_IEEE_754
+ @unittest.skipUnless(sys.float_repr_style == 'short',
+ "applies only when using short float repr style")
def test_format_testfile(self):
with open(format_testfile, encoding="utf-8") as testfile:
for line in testfile:
diff --git a/Lib/test/test_fnmatch.py b/Lib/test/test_fnmatch.py
index d4163cfe782..5daaf3b3fdd 100644
--- a/Lib/test/test_fnmatch.py
+++ b/Lib/test/test_fnmatch.py
@@ -218,24 +218,24 @@ class TranslateTestCase(unittest.TestCase):
def test_translate(self):
import re
- self.assertEqual(translate('*'), r'(?s:.*)\Z')
- self.assertEqual(translate('?'), r'(?s:.)\Z')
- self.assertEqual(translate('a?b*'), r'(?s:a.b.*)\Z')
- self.assertEqual(translate('[abc]'), r'(?s:[abc])\Z')
- self.assertEqual(translate('[]]'), r'(?s:[]])\Z')
- self.assertEqual(translate('[!x]'), r'(?s:[^x])\Z')
- self.assertEqual(translate('[^x]'), r'(?s:[\^x])\Z')
- self.assertEqual(translate('[x'), r'(?s:\[x)\Z')
+ self.assertEqual(translate('*'), r'(?s:.*)\z')
+ self.assertEqual(translate('?'), r'(?s:.)\z')
+ self.assertEqual(translate('a?b*'), r'(?s:a.b.*)\z')
+ self.assertEqual(translate('[abc]'), r'(?s:[abc])\z')
+ self.assertEqual(translate('[]]'), r'(?s:[]])\z')
+ self.assertEqual(translate('[!x]'), r'(?s:[^x])\z')
+ self.assertEqual(translate('[^x]'), r'(?s:[\^x])\z')
+ self.assertEqual(translate('[x'), r'(?s:\[x)\z')
# from the docs
- self.assertEqual(translate('*.txt'), r'(?s:.*\.txt)\Z')
+ self.assertEqual(translate('*.txt'), r'(?s:.*\.txt)\z')
# squash consecutive stars
- self.assertEqual(translate('*********'), r'(?s:.*)\Z')
- self.assertEqual(translate('A*********'), r'(?s:A.*)\Z')
- self.assertEqual(translate('*********A'), r'(?s:.*A)\Z')
- self.assertEqual(translate('A*********?[?]?'), r'(?s:A.*.[?].)\Z')
+ self.assertEqual(translate('*********'), r'(?s:.*)\z')
+ self.assertEqual(translate('A*********'), r'(?s:A.*)\z')
+ self.assertEqual(translate('*********A'), r'(?s:.*A)\z')
+ self.assertEqual(translate('A*********?[?]?'), r'(?s:A.*.[?].)\z')
# fancy translation to prevent exponential-time match failure
t = translate('**a*a****a')
- self.assertEqual(t, r'(?s:(?>.*?a)(?>.*?a).*a)\Z')
+ self.assertEqual(t, r'(?s:(?>.*?a)(?>.*?a).*a)\z')
# and try pasting multiple translate results - it's an undocumented
# feature that this works
r1 = translate('**a**a**a*')
@@ -249,27 +249,27 @@ class TranslateTestCase(unittest.TestCase):
def test_translate_wildcards(self):
for pattern, expect in [
- ('ab*', r'(?s:ab.*)\Z'),
- ('ab*cd', r'(?s:ab.*cd)\Z'),
- ('ab*cd*', r'(?s:ab(?>.*?cd).*)\Z'),
- ('ab*cd*12', r'(?s:ab(?>.*?cd).*12)\Z'),
- ('ab*cd*12*', r'(?s:ab(?>.*?cd)(?>.*?12).*)\Z'),
- ('ab*cd*12*34', r'(?s:ab(?>.*?cd)(?>.*?12).*34)\Z'),
- ('ab*cd*12*34*', r'(?s:ab(?>.*?cd)(?>.*?12)(?>.*?34).*)\Z'),
+ ('ab*', r'(?s:ab.*)\z'),
+ ('ab*cd', r'(?s:ab.*cd)\z'),
+ ('ab*cd*', r'(?s:ab(?>.*?cd).*)\z'),
+ ('ab*cd*12', r'(?s:ab(?>.*?cd).*12)\z'),
+ ('ab*cd*12*', r'(?s:ab(?>.*?cd)(?>.*?12).*)\z'),
+ ('ab*cd*12*34', r'(?s:ab(?>.*?cd)(?>.*?12).*34)\z'),
+ ('ab*cd*12*34*', r'(?s:ab(?>.*?cd)(?>.*?12)(?>.*?34).*)\z'),
]:
with self.subTest(pattern):
translated = translate(pattern)
self.assertEqual(translated, expect, pattern)
for pattern, expect in [
- ('*ab', r'(?s:.*ab)\Z'),
- ('*ab*', r'(?s:(?>.*?ab).*)\Z'),
- ('*ab*cd', r'(?s:(?>.*?ab).*cd)\Z'),
- ('*ab*cd*', r'(?s:(?>.*?ab)(?>.*?cd).*)\Z'),
- ('*ab*cd*12', r'(?s:(?>.*?ab)(?>.*?cd).*12)\Z'),
- ('*ab*cd*12*', r'(?s:(?>.*?ab)(?>.*?cd)(?>.*?12).*)\Z'),
- ('*ab*cd*12*34', r'(?s:(?>.*?ab)(?>.*?cd)(?>.*?12).*34)\Z'),
- ('*ab*cd*12*34*', r'(?s:(?>.*?ab)(?>.*?cd)(?>.*?12)(?>.*?34).*)\Z'),
+ ('*ab', r'(?s:.*ab)\z'),
+ ('*ab*', r'(?s:(?>.*?ab).*)\z'),
+ ('*ab*cd', r'(?s:(?>.*?ab).*cd)\z'),
+ ('*ab*cd*', r'(?s:(?>.*?ab)(?>.*?cd).*)\z'),
+ ('*ab*cd*12', r'(?s:(?>.*?ab)(?>.*?cd).*12)\z'),
+ ('*ab*cd*12*', r'(?s:(?>.*?ab)(?>.*?cd)(?>.*?12).*)\z'),
+ ('*ab*cd*12*34', r'(?s:(?>.*?ab)(?>.*?cd)(?>.*?12).*34)\z'),
+ ('*ab*cd*12*34*', r'(?s:(?>.*?ab)(?>.*?cd)(?>.*?12)(?>.*?34).*)\z'),
]:
with self.subTest(pattern):
translated = translate(pattern)
@@ -277,28 +277,28 @@ class TranslateTestCase(unittest.TestCase):
def test_translate_expressions(self):
for pattern, expect in [
- ('[', r'(?s:\[)\Z'),
- ('[!', r'(?s:\[!)\Z'),
- ('[]', r'(?s:\[\])\Z'),
- ('[abc', r'(?s:\[abc)\Z'),
- ('[!abc', r'(?s:\[!abc)\Z'),
- ('[abc]', r'(?s:[abc])\Z'),
- ('[!abc]', r'(?s:[^abc])\Z'),
- ('[!abc][!def]', r'(?s:[^abc][^def])\Z'),
+ ('[', r'(?s:\[)\z'),
+ ('[!', r'(?s:\[!)\z'),
+ ('[]', r'(?s:\[\])\z'),
+ ('[abc', r'(?s:\[abc)\z'),
+ ('[!abc', r'(?s:\[!abc)\z'),
+ ('[abc]', r'(?s:[abc])\z'),
+ ('[!abc]', r'(?s:[^abc])\z'),
+ ('[!abc][!def]', r'(?s:[^abc][^def])\z'),
# with [[
- ('[[', r'(?s:\[\[)\Z'),
- ('[[a', r'(?s:\[\[a)\Z'),
- ('[[]', r'(?s:[\[])\Z'),
- ('[[]a', r'(?s:[\[]a)\Z'),
- ('[[]]', r'(?s:[\[]\])\Z'),
- ('[[]a]', r'(?s:[\[]a\])\Z'),
- ('[[a]', r'(?s:[\[a])\Z'),
- ('[[a]]', r'(?s:[\[a]\])\Z'),
- ('[[a]b', r'(?s:[\[a]b)\Z'),
+ ('[[', r'(?s:\[\[)\z'),
+ ('[[a', r'(?s:\[\[a)\z'),
+ ('[[]', r'(?s:[\[])\z'),
+ ('[[]a', r'(?s:[\[]a)\z'),
+ ('[[]]', r'(?s:[\[]\])\z'),
+ ('[[]a]', r'(?s:[\[]a\])\z'),
+ ('[[a]', r'(?s:[\[a])\z'),
+ ('[[a]]', r'(?s:[\[a]\])\z'),
+ ('[[a]b', r'(?s:[\[a]b)\z'),
# backslashes
- ('[\\', r'(?s:\[\\)\Z'),
- (r'[\]', r'(?s:[\\])\Z'),
- (r'[\\]', r'(?s:[\\\\])\Z'),
+ ('[\\', r'(?s:\[\\)\z'),
+ (r'[\]', r'(?s:[\\])\z'),
+ (r'[\\]', r'(?s:[\\\\])\z'),
]:
with self.subTest(pattern):
translated = translate(pattern)
diff --git a/Lib/test/test_format.py b/Lib/test/test_format.py
index c7cc32e0949..1f626d87fa6 100644
--- a/Lib/test/test_format.py
+++ b/Lib/test/test_format.py
@@ -346,12 +346,12 @@ class FormatTest(unittest.TestCase):
testcommon(b"%s", memoryview(b"abc"), b"abc")
# %a will give the equivalent of
# repr(some_obj).encode('ascii', 'backslashreplace')
- testcommon(b"%a", 3.14, b"3.14")
+ testcommon(b"%a", 3.25, b"3.25")
testcommon(b"%a", b"ghi", b"b'ghi'")
testcommon(b"%a", "jkl", b"'jkl'")
testcommon(b"%a", "\u0544", b"'\\u0544'")
# %r is an alias for %a
- testcommon(b"%r", 3.14, b"3.14")
+ testcommon(b"%r", 3.25, b"3.25")
testcommon(b"%r", b"ghi", b"b'ghi'")
testcommon(b"%r", "jkl", b"'jkl'")
testcommon(b"%r", "\u0544", b"'\\u0544'")
@@ -407,19 +407,19 @@ class FormatTest(unittest.TestCase):
self.assertEqual(format("abc", "\u2007<5"), "abc\u2007\u2007")
self.assertEqual(format(123, "\u2007<5"), "123\u2007\u2007")
- self.assertEqual(format(12.3, "\u2007<6"), "12.3\u2007\u2007")
+ self.assertEqual(format(12.5, "\u2007<6"), "12.5\u2007\u2007")
self.assertEqual(format(0j, "\u2007<4"), "0j\u2007\u2007")
self.assertEqual(format(1+2j, "\u2007<8"), "(1+2j)\u2007\u2007")
self.assertEqual(format("abc", "\u2007>5"), "\u2007\u2007abc")
self.assertEqual(format(123, "\u2007>5"), "\u2007\u2007123")
- self.assertEqual(format(12.3, "\u2007>6"), "\u2007\u200712.3")
+ self.assertEqual(format(12.5, "\u2007>6"), "\u2007\u200712.5")
self.assertEqual(format(1+2j, "\u2007>8"), "\u2007\u2007(1+2j)")
self.assertEqual(format(0j, "\u2007>4"), "\u2007\u20070j")
self.assertEqual(format("abc", "\u2007^5"), "\u2007abc\u2007")
self.assertEqual(format(123, "\u2007^5"), "\u2007123\u2007")
- self.assertEqual(format(12.3, "\u2007^6"), "\u200712.3\u2007")
+ self.assertEqual(format(12.5, "\u2007^6"), "\u200712.5\u2007")
self.assertEqual(format(1+2j, "\u2007^8"), "\u2007(1+2j)\u2007")
self.assertEqual(format(0j, "\u2007^4"), "\u20070j\u2007")
diff --git a/Lib/test/test_fractions.py b/Lib/test/test_fractions.py
index 84faa636064..d1d2739856c 100644
--- a/Lib/test/test_fractions.py
+++ b/Lib/test/test_fractions.py
@@ -1,7 +1,7 @@
"""Tests for Lib/fractions.py."""
from decimal import Decimal
-from test.support import requires_IEEE_754
+from test.support import requires_IEEE_754, adjust_int_max_str_digits
import math
import numbers
import operator
@@ -395,12 +395,14 @@ class FractionTest(unittest.TestCase):
def testFromString(self):
self.assertEqual((5, 1), _components(F("5")))
+ self.assertEqual((5, 1), _components(F("005")))
self.assertEqual((3, 2), _components(F("3/2")))
self.assertEqual((3, 2), _components(F("3 / 2")))
self.assertEqual((3, 2), _components(F(" \n +3/2")))
self.assertEqual((-3, 2), _components(F("-3/2 ")))
- self.assertEqual((13, 2), _components(F(" 013/02 \n ")))
+ self.assertEqual((13, 2), _components(F(" 0013/002 \n ")))
self.assertEqual((16, 5), _components(F(" 3.2 ")))
+ self.assertEqual((16, 5), _components(F("003.2")))
self.assertEqual((-16, 5), _components(F(" -3.2 ")))
self.assertEqual((-3, 1), _components(F(" -3. ")))
self.assertEqual((3, 5), _components(F(" .6 ")))
@@ -419,116 +421,102 @@ class FractionTest(unittest.TestCase):
self.assertRaisesMessage(
ZeroDivisionError, "Fraction(3, 0)",
F, "3/0")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '3/'",
- F, "3/")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '/2'",
- F, "/2")
- self.assertRaisesMessage(
- # Denominators don't need a sign.
- ValueError, "Invalid literal for Fraction: '3/+2'",
- F, "3/+2")
- self.assertRaisesMessage(
- # Imitate float's parsing.
- ValueError, "Invalid literal for Fraction: '+ 3/2'",
- F, "+ 3/2")
- self.assertRaisesMessage(
- # Avoid treating '.' as a regex special character.
- ValueError, "Invalid literal for Fraction: '3a2'",
- F, "3a2")
- self.assertRaisesMessage(
- # Don't accept combinations of decimals and rationals.
- ValueError, "Invalid literal for Fraction: '3/7.2'",
- F, "3/7.2")
- self.assertRaisesMessage(
- # Don't accept combinations of decimals and rationals.
- ValueError, "Invalid literal for Fraction: '3.2/7'",
- F, "3.2/7")
- self.assertRaisesMessage(
- # Allow 3. and .3, but not .
- ValueError, "Invalid literal for Fraction: '.'",
- F, ".")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '_'",
- F, "_")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '_1'",
- F, "_1")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1__2'",
- F, "1__2")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '/_'",
- F, "/_")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1_/'",
- F, "1_/")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '_1/'",
- F, "_1/")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1__2/'",
- F, "1__2/")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1/_'",
- F, "1/_")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1/_1'",
- F, "1/_1")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1/1__2'",
- F, "1/1__2")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1._111'",
- F, "1._111")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1.1__1'",
- F, "1.1__1")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1.1e+_1'",
- F, "1.1e+_1")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1.1e+1__1'",
- F, "1.1e+1__1")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '123.dd'",
- F, "123.dd")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '123.5_dd'",
- F, "123.5_dd")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: 'dd.5'",
- F, "dd.5")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '7_dd'",
- F, "7_dd")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1/dd'",
- F, "1/dd")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1/123_dd'",
- F, "1/123_dd")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '789edd'",
- F, "789edd")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '789e2_dd'",
- F, "789e2_dd")
+
+ def check_invalid(s):
+ msg = "Invalid literal for Fraction: " + repr(s)
+ self.assertRaisesMessage(ValueError, msg, F, s)
+
+ check_invalid("3/")
+ check_invalid("/2")
+ # Denominators don't need a sign.
+ check_invalid("3/+2")
+ check_invalid("3/-2")
+ # Imitate float's parsing.
+ check_invalid("+ 3/2")
+ check_invalid("- 3/2")
+ # Avoid treating '.' as a regex special character.
+ check_invalid("3a2")
+ # Don't accept combinations of decimals and rationals.
+ check_invalid("3/7.2")
+ check_invalid("3.2/7")
+ # No space around dot.
+ check_invalid("3 .2")
+ check_invalid("3. 2")
+ # No space around e.
+ check_invalid("3.2 e1")
+ check_invalid("3.2e 1")
+ # Fractional part doesn't need a sign.
+ check_invalid("3.+2")
+ check_invalid("3.-2")
+ # Only accept base 10.
+ check_invalid("0x10")
+ check_invalid("0x10/1")
+ check_invalid("1/0x10")
+ check_invalid("0x10.")
+ check_invalid("0x10.1")
+ check_invalid("1.0x10")
+ check_invalid("1.0e0x10")
+ # Only accept decimal digits.
+ check_invalid("³")
+ check_invalid("³/2")
+ check_invalid("3/²")
+ check_invalid("³.2")
+ check_invalid("3.²")
+ check_invalid("3.2e²")
+ check_invalid("¼")
+ # Allow 3. and .3, but not .
+ check_invalid(".")
+ check_invalid("_")
+ check_invalid("_1")
+ check_invalid("1__2")
+ check_invalid("/_")
+ check_invalid("1_/")
+ check_invalid("_1/")
+ check_invalid("1__2/")
+ check_invalid("1/_")
+ check_invalid("1/_1")
+ check_invalid("1/1__2")
+ check_invalid("1._111")
+ check_invalid("1.1__1")
+ check_invalid("1.1e+_1")
+ check_invalid("1.1e+1__1")
+ check_invalid("123.dd")
+ check_invalid("123.5_dd")
+ check_invalid("dd.5")
+ check_invalid("7_dd")
+ check_invalid("1/dd")
+ check_invalid("1/123_dd")
+ check_invalid("789edd")
+ check_invalid("789e2_dd")
# Test catastrophic backtracking.
val = "9"*50 + "_"
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '" + val + "'",
- F, val)
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1/" + val + "'",
- F, "1/" + val)
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1." + val + "'",
- F, "1." + val)
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1.1+e" + val + "'",
- F, "1.1+e" + val)
+ check_invalid(val)
+ check_invalid("1/" + val)
+ check_invalid("1." + val)
+ check_invalid("." + val)
+ check_invalid("1.1+e" + val)
+ check_invalid("1.1e" + val)
+
+ def test_limit_int(self):
+ maxdigits = 5000
+ with adjust_int_max_str_digits(maxdigits):
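+ # Strings with up to `maxdigits` digits still parse; one extra digit must raise ValueError.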
+ msg = 'Exceeds the limit'
+ val = '1' * maxdigits
+ num = (10**maxdigits - 1)//9
+ self.assertEqual((num, 1), _components(F(val)))
+ self.assertRaisesRegex(ValueError, msg, F, val + '1')
+ self.assertEqual((num, 2), _components(F(val + '/2')))
+ self.assertRaisesRegex(ValueError, msg, F, val + '1/2')
+ self.assertEqual((1, num), _components(F('1/' + val)))
+ self.assertRaisesRegex(ValueError, msg, F, '1/1' + val)
+ self.assertEqual(((10**(maxdigits+1) - 1)//9, 10**maxdigits),
+ _components(F('1.' + val)))
+ self.assertRaisesRegex(ValueError, msg, F, '1.1' + val)
+ self.assertEqual((num, 10**maxdigits), _components(F('.' + val)))
+ self.assertRaisesRegex(ValueError, msg, F, '.1' + val)
+ self.assertRaisesRegex(ValueError, msg, F, '1.1e1' + val)
+ self.assertEqual((11, 10), _components(F('1.1e' + '0' * maxdigits)))
+ self.assertRaisesRegex(ValueError, msg, F, '1.1e' + '0' * (maxdigits+1))
def testImmutable(self):
r = F(7, 3)
@@ -1530,6 +1518,8 @@ class FractionTest(unittest.TestCase):
(F(51, 1000), '.1f', '0.1'),
(F(149, 1000), '.1f', '0.1'),
(F(151, 1000), '.1f', '0.2'),
+ (F(22, 7), '.02f', '3.14'), # issue gh-130662
+ (F(22, 7), '005.02f', '03.14'),
]
for fraction, spec, expected in testcases:
with self.subTest(fraction=fraction, spec=spec):
@@ -1628,12 +1618,6 @@ class FractionTest(unittest.TestCase):
'=010%',
'>00.2f',
'>00f',
- # Too many zeros - minimum width should not have leading zeros
- '006f',
- # Leading zeros in precision
- '.010f',
- '.02f',
- '.000f',
# Missing precision
'.e',
'.f',
diff --git a/Lib/test/test_free_threading/test_dict.py b/Lib/test/test_free_threading/test_dict.py
index 476cc3178d8..5d5d4e226ca 100644
--- a/Lib/test/test_free_threading/test_dict.py
+++ b/Lib/test/test_free_threading/test_dict.py
@@ -228,6 +228,22 @@ class TestDict(TestCase):
self.assertEqual(count, 0)
+ def test_racing_object_get_set_dict(self):
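+ # One thread repeatedly swaps e.__dict__ while another reads it; the test
+ # only checks that the race does not crash under free threading.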
+ e = Exception()
+
+ def writer():
+ for i in range(10000):
+ e.__dict__ = {1:2}
+
+ def reader():
+ for i in range(10000):
+ e.__dict__
+
+ t1 = Thread(target=writer)
+ t2 = Thread(target=reader)
+
+ with threading_helper.start_threads([t1, t2]):
+ pass
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_free_threading/test_functools.py b/Lib/test/test_free_threading/test_functools.py
new file mode 100644
index 00000000000..a442fe056ce
--- /dev/null
+++ b/Lib/test/test_free_threading/test_functools.py
@@ -0,0 +1,75 @@
+import random
+import unittest
+
+from functools import lru_cache
+from threading import Barrier, Thread
+
+from test.support import threading_helper
+
+@threading_helper.requires_working_threading()
+class TestLRUCache(unittest.TestCase):
+
+ def _test_concurrent_operations(self, maxsize):
+ num_threads = 10
+ b = Barrier(num_threads)
+ @lru_cache(maxsize=maxsize)
+ def func(arg=0):
+ return object()
+
+
+ def thread_func():
+ b.wait()
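+ # Per thread: ~80% cached calls, ~10% cache_info() reads, ~10% cache_clear() calls.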
+ for i in range(1000):
+ r = random.randint(0, 1000)
+ if i < 800:
+ func(i)
+ elif i < 900:
+ func.cache_info()
+ else:
+ func.cache_clear()
+
+ threads = []
+ for i in range(num_threads):
+ t = Thread(target=thread_func)
+ threads.append(t)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ def test_concurrent_operations_unbounded(self):
+ self._test_concurrent_operations(maxsize=None)
+
+ def test_concurrent_operations_bounded(self):
+ self._test_concurrent_operations(maxsize=128)
+
+ def _test_reentrant_cache_clear(self, maxsize):
+ num_threads = 10
+ b = Barrier(num_threads)
+ @lru_cache(maxsize=maxsize)
+ def func(arg=0):
+ func.cache_clear()
+ return object()
+
+
+ def thread_func():
+ b.wait()
+ for i in range(1000):
+ func(random.randint(0, 10000))
+
+ threads = []
+ for i in range(num_threads):
+ t = Thread(target=thread_func)
+ threads.append(t)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ def test_reentrant_cache_clear_unbounded(self):
+ self._test_reentrant_cache_clear(maxsize=None)
+
+ def test_reentrant_cache_clear_bounded(self):
+ self._test_reentrant_cache_clear(maxsize=128)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/test/test_free_threading/test_generators.py b/Lib/test/test_free_threading/test_generators.py
new file mode 100644
index 00000000000..d01675eb38b
--- /dev/null
+++ b/Lib/test/test_free_threading/test_generators.py
@@ -0,0 +1,51 @@
+import concurrent.futures
+import unittest
+from threading import Barrier
+from unittest import TestCase
+import random
+import time
+
+from test.support import threading_helper, Py_GIL_DISABLED
+
+threading_helper.requires_working_threading(module=True)
+
+
+def random_sleep():
+ delay_us = random.randint(50, 100)
+ time.sleep(delay_us * 1e-6)
+
+def random_string():
+ return ''.join(random.choice('0123456789ABCDEF') for _ in range(10))
+
+def set_gen_name(g, b):
+ b.wait()
+ random_sleep()
+ g.__name__ = random_string()
+ return g.__name__
+
+def set_gen_qualname(g, b):
+ b.wait()
+ random_sleep()
+ g.__qualname__ = random_string()
+ return g.__qualname__
+
+
+@unittest.skipUnless(Py_GIL_DISABLED, "Enable only in FT build")
+class TestFTGenerators(TestCase):
+ NUM_THREADS = 4
+
+ def concurrent_write_with_func(self, func):
+ gen = (x for x in range(42))
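+ # All workers rename the same generator object; each write must still read back as a 10-character string.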
+ for j in range(1000):
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.NUM_THREADS) as executor:
+ b = Barrier(self.NUM_THREADS)
+ futures = {executor.submit(func, gen, b): i for i in range(self.NUM_THREADS)}
+ for fut in concurrent.futures.as_completed(futures):
+ gen_name = fut.result()
+ self.assertEqual(len(gen_name), 10)
+
+ def test_concurrent_write(self):
+ with self.subTest(func=set_gen_name):
+ self.concurrent_write_with_func(func=set_gen_name)
+ with self.subTest(func=set_gen_qualname):
+ self.concurrent_write_with_func(func=set_gen_qualname)
diff --git a/Lib/test/test_free_threading/test_heapq.py b/Lib/test/test_free_threading/test_heapq.py
new file mode 100644
index 00000000000..ee7adfb2b78
--- /dev/null
+++ b/Lib/test/test_free_threading/test_heapq.py
@@ -0,0 +1,267 @@
+import unittest
+
+import heapq
+
+from enum import Enum
+from threading import Thread, Barrier, Lock
+from random import shuffle, randint
+
+from test.support import threading_helper
+from test import test_heapq
+
+
+NTHREADS = 10
+OBJECT_COUNT = 5_000
+
+
+class Heap(Enum):
+ MIN = 1
+ MAX = 2
+
+
+@threading_helper.requires_working_threading()
+class TestHeapq(unittest.TestCase):
+ def setUp(self):
+ self.test_heapq = test_heapq.TestHeapPython()
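+ # Reuse the heap invariant checkers from the sequential test_heapq suite.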
+
+ def test_racing_heapify(self):
+ heap = list(range(OBJECT_COUNT))
+ shuffle(heap)
+
+ self.run_concurrently(
+ worker_func=heapq.heapify, args=(heap,), nthreads=NTHREADS
+ )
+ self.test_heapq.check_invariant(heap)
+
+ def test_racing_heappush(self):
+ heap = []
+
+ def heappush_func(heap):
+ for item in reversed(range(OBJECT_COUNT)):
+ heapq.heappush(heap, item)
+
+ self.run_concurrently(
+ worker_func=heappush_func, args=(heap,), nthreads=NTHREADS
+ )
+ self.test_heapq.check_invariant(heap)
+
+ def test_racing_heappop(self):
+ heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
+
+ # Each thread pops (OBJECT_COUNT / NTHREADS) items
+ self.assertEqual(OBJECT_COUNT % NTHREADS, 0)
+ per_thread_pop_count = OBJECT_COUNT // NTHREADS
+
+ def heappop_func(heap, pop_count):
+ local_list = []
+ for _ in range(pop_count):
+ item = heapq.heappop(heap)
+ local_list.append(item)
+
+ # Each local list should be sorted
+ self.assertTrue(self.is_sorted_ascending(local_list))
+
+ self.run_concurrently(
+ worker_func=heappop_func,
+ args=(heap, per_thread_pop_count),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(heap), 0)
+
+ def test_racing_heappushpop(self):
+ heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
+ pushpop_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
+
+ def heappushpop_func(heap, pushpop_items):
+ for item in pushpop_items:
+ popped_item = heapq.heappushpop(heap, item)
+ self.assertTrue(popped_item <= item)
+
+ self.run_concurrently(
+ worker_func=heappushpop_func,
+ args=(heap, pushpop_items),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(heap), OBJECT_COUNT)
+ self.test_heapq.check_invariant(heap)
+
+ def test_racing_heapreplace(self):
+ heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
+ replace_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
+
+ def heapreplace_func(heap, replace_items):
+ for item in replace_items:
+ heapq.heapreplace(heap, item)
+
+ self.run_concurrently(
+ worker_func=heapreplace_func,
+ args=(heap, replace_items),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(heap), OBJECT_COUNT)
+ self.test_heapq.check_invariant(heap)
+
+ def test_racing_heapify_max(self):
+ max_heap = list(range(OBJECT_COUNT))
+ shuffle(max_heap)
+
+ self.run_concurrently(
+ worker_func=heapq.heapify_max, args=(max_heap,), nthreads=NTHREADS
+ )
+ self.test_heapq.check_max_invariant(max_heap)
+
+ def test_racing_heappush_max(self):
+ max_heap = []
+
+ def heappush_max_func(max_heap):
+ for item in range(OBJECT_COUNT):
+ heapq.heappush_max(max_heap, item)
+
+ self.run_concurrently(
+ worker_func=heappush_max_func, args=(max_heap,), nthreads=NTHREADS
+ )
+ self.test_heapq.check_max_invariant(max_heap)
+
+ def test_racing_heappop_max(self):
+ max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
+
+ # Each thread pops (OBJECT_COUNT / NTHREADS) items
+ self.assertEqual(OBJECT_COUNT % NTHREADS, 0)
+ per_thread_pop_count = OBJECT_COUNT // NTHREADS
+
+ def heappop_max_func(max_heap, pop_count):
+ local_list = []
+ for _ in range(pop_count):
+ item = heapq.heappop_max(max_heap)
+ local_list.append(item)
+
+ # Each local list should be sorted
+ self.assertTrue(self.is_sorted_descending(local_list))
+
+ self.run_concurrently(
+ worker_func=heappop_max_func,
+ args=(max_heap, per_thread_pop_count),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(max_heap), 0)
+
+ def test_racing_heappushpop_max(self):
+ max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
+ pushpop_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
+
+ def heappushpop_max_func(max_heap, pushpop_items):
+ for item in pushpop_items:
+ popped_item = heapq.heappushpop_max(max_heap, item)
+ self.assertTrue(popped_item >= item)
+
+ self.run_concurrently(
+ worker_func=heappushpop_max_func,
+ args=(max_heap, pushpop_items),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(max_heap), OBJECT_COUNT)
+ self.test_heapq.check_max_invariant(max_heap)
+
+ def test_racing_heapreplace_max(self):
+ max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
+ replace_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
+
+ def heapreplace_max_func(max_heap, replace_items):
+ for item in replace_items:
+ heapq.heapreplace_max(max_heap, item)
+
+ self.run_concurrently(
+ worker_func=heapreplace_max_func,
+ args=(max_heap, replace_items),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(max_heap), OBJECT_COUNT)
+ self.test_heapq.check_max_invariant(max_heap)
+
+ def test_lock_free_list_read(self):
+ n, n_threads = 1_000, 10
+ l = []
+ barrier = Barrier(n_threads * 2)
+
+ count = 0
+ lock = Lock()
+
+ def worker():
+ with lock:
+ nonlocal count
+ x = count
+ count += 1
+
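+ # Odd-numbered workers mutate the heap; even-numbered workers read l[0] without a lock.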
+ barrier.wait()
+ for i in range(n):
+ if x % 2:
+ heapq.heappush(l, 1)
+ heapq.heappop(l)
+ else:
+ try:
+ l[0]
+ except IndexError:
+ pass
+
+ self.run_concurrently(worker, (), n_threads * 2)
+
+ @staticmethod
+ def is_sorted_ascending(lst):
+ """
+ Check if the list is sorted in ascending order (non-decreasing).
+ """
+ return all(lst[i - 1] <= lst[i] for i in range(1, len(lst)))
+
+ @staticmethod
+ def is_sorted_descending(lst):
+ """
+ Check if the list is sorted in descending order (non-increasing).
+ """
+ return all(lst[i - 1] >= lst[i] for i in range(1, len(lst)))
+
+ @staticmethod
+ def create_heap(size, heap_kind):
+ """
+ Create a min/max heap where elements are in the range [0, size - 1] and
+ shuffled before heapify.
+ """
+ heap = list(range(size))
+ shuffle(heap)
+ if heap_kind == Heap.MIN:
+ heapq.heapify(heap)
+ else:
+ heapq.heapify_max(heap)
+
+ return heap
+
+ @staticmethod
+ def create_random_list(a, b, size):
+ """
+ Create a list of random numbers between a and b (inclusive).
+ """
+ return [randint(a, b) for _ in range(size)]
+
+ def run_concurrently(self, worker_func, args, nthreads):
+ """
+ Run the worker function concurrently in multiple threads.
+ """
+ barrier = Barrier(nthreads)
+
+ def wrapper_func(*args):
+ # Wait for all threads to reach this point before proceeding.
+ barrier.wait()
+ worker_func(*args)
+
+ with threading_helper.catch_threading_exception() as cm:
+ workers = (
+ Thread(target=wrapper_func, args=args) for _ in range(nthreads)
+ )
+ with threading_helper.start_threads(workers):
+ pass
+
+ # Worker threads should not raise any exceptions
+ self.assertIsNone(cm.exc_value)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/test/test_free_threading/test_io.py b/Lib/test/test_free_threading/test_io.py
new file mode 100644
index 00000000000..f9bec740ddf
--- /dev/null
+++ b/Lib/test/test_free_threading/test_io.py
@@ -0,0 +1,109 @@
+import threading
+from unittest import TestCase
+from test.support import threading_helper
+from random import randint
+from io import BytesIO
+from sys import getsizeof
+
+
+class TestBytesIO(TestCase):
+ # Test pretty much everything that can break under free-threading.
+ # Non-deterministic, but at least one of these operations will fail if
+ # the BytesIO object is not thread-safe under free threading.
+
+ def check(self, funcs, *args):
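+ # Run one thread per function; the shared barrier releases them together so
+ # every operation hits the same BytesIO object at roughly the same time.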
+ barrier = threading.Barrier(len(funcs))
+ threads = []
+
+ for func in funcs:
+ thread = threading.Thread(target=func, args=(barrier, *args))
+
+ threads.append(thread)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ @threading_helper.requires_working_threading()
+ @threading_helper.reap_threads
+ def test_free_threading(self):
+ """Test for segfaults and aborts."""
+
+ def write(barrier, b, *ignore):
+ barrier.wait()
+ try: b.write(b'0' * randint(100, 1000))
+ except ValueError: pass # ignore write fail to closed file
+
+ def writelines(barrier, b, *ignore):
+ barrier.wait()
+ b.write(b'0\n' * randint(100, 1000))
+
+ def truncate(barrier, b, *ignore):
+ barrier.wait()
+ try: b.truncate(0)
+ except BufferError: pass # ignore truncate fail while the buffer is exported
+
+ def read(barrier, b, *ignore):
+ barrier.wait()
+ b.read()
+
+ def read1(barrier, b, *ignore):
+ barrier.wait()
+ b.read1()
+
+ def readline(barrier, b, *ignore):
+ barrier.wait()
+ b.readline()
+
+ def readlines(barrier, b, *ignore):
+ barrier.wait()
+ b.readlines()
+
+ def readinto(barrier, b, into, *ignore):
+ barrier.wait()
+ b.readinto(into)
+
+ def close(barrier, b, *ignore):
+ barrier.wait()
+ b.close()
+
+ def getvalue(barrier, b, *ignore):
+ barrier.wait()
+ b.getvalue()
+
+ def getbuffer(barrier, b, *ignore):
+ barrier.wait()
+ b.getbuffer()
+
+ def iter(barrier, b, *ignore):
+ barrier.wait()
+ list(b)
+
+ def getstate(barrier, b, *ignore):
+ barrier.wait()
+ b.__getstate__()
+
+ def setstate(barrier, b, st, *ignore):
+ barrier.wait()
+ b.__setstate__(st)
+
+ def sizeof(barrier, b, *ignore):
+ barrier.wait()
+ getsizeof(b)
+
+ self.check([write] * 10, BytesIO())
+ self.check([writelines] * 10, BytesIO())
+ self.check([write] * 10 + [truncate] * 10, BytesIO())
+ self.check([truncate] + [read] * 10, BytesIO(b'0\n'*204800))
+ self.check([truncate] + [read1] * 10, BytesIO(b'0\n'*204800))
+ self.check([truncate] + [readline] * 10, BytesIO(b'0\n'*20480))
+ self.check([truncate] + [readlines] * 10, BytesIO(b'0\n'*20480))
+ self.check([truncate] + [readinto] * 10, BytesIO(b'0\n'*204800), bytearray(b'0\n'*204800))
+ self.check([close] + [write] * 10, BytesIO())
+ self.check([truncate] + [getvalue] * 10, BytesIO(b'0\n'*204800))
+ self.check([truncate] + [getbuffer] * 10, BytesIO(b'0\n'*204800))
+ self.check([truncate] + [iter] * 10, BytesIO(b'0\n'*20480))
+ self.check([truncate] + [getstate] * 10, BytesIO(b'0\n'*204800))
+ self.check([truncate] + [setstate] * 10, BytesIO(b'0\n'*204800), (b'123', 0, None))
+ self.check([truncate] + [sizeof] * 10, BytesIO(b'0\n'*204800))
+
+ # no tests for seek or tell because they don't break anything
diff --git a/Lib/test/test_free_threading/test_itertools.py b/Lib/test/test_free_threading/test_itertools.py
new file mode 100644
index 00000000000..9d366041917
--- /dev/null
+++ b/Lib/test/test_free_threading/test_itertools.py
@@ -0,0 +1,95 @@
+import unittest
+from threading import Thread, Barrier
+from itertools import batched, chain, cycle
+from test.support import threading_helper
+
+
+threading_helper.requires_working_threading(module=True)
+
+class ItertoolsThreading(unittest.TestCase):
+
+ @threading_helper.reap_threads
+ def test_batched(self):
+ number_of_threads = 10
+ number_of_iterations = 20
+ barrier = Barrier(number_of_threads)
+ def work(it):
+ barrier.wait()
+ while True:
+ try:
+ next(it)
+ except StopIteration:
+ break
+
+ data = tuple(range(1000))
+ for it in range(number_of_iterations):
+ batch_iterator = batched(data, 2)
+ worker_threads = []
+ for ii in range(number_of_threads):
+ worker_threads.append(
+ Thread(target=work, args=[batch_iterator]))
+
+ with threading_helper.start_threads(worker_threads):
+ pass
+
+ barrier.reset()
+
+ @threading_helper.reap_threads
+ def test_cycle(self):
+ number_of_threads = 6
+ number_of_iterations = 10
+ number_of_cycles = 400
+
+ barrier = Barrier(number_of_threads)
+ def work(it):
+ barrier.wait()
+ for _ in range(number_of_cycles):
+ try:
+ next(it)
+ except StopIteration:
+ pass
+
+ data = (1, 2, 3, 4)
+ for it in range(number_of_iterations):
+ cycle_iterator = cycle(data)
+ worker_threads = []
+ for ii in range(number_of_threads):
+ worker_threads.append(
+ Thread(target=work, args=[cycle_iterator]))
+
+ with threading_helper.start_threads(worker_threads):
+ pass
+
+ barrier.reset()
+
+ @threading_helper.reap_threads
+ def test_chain(self):
+ number_of_threads = 6
+ number_of_iterations = 20
+
+ barrier = Barrier(number_of_threads)
+ def work(it):
+ barrier.wait()
+ while True:
+ try:
+ next(it)
+ except StopIteration:
+ break
+
+ data = [(1, )] * 200
+ for it in range(number_of_iterations):
+ chain_iterator = chain(*data)
+ worker_threads = []
+ for ii in range(number_of_threads):
+ worker_threads.append(
+ Thread(target=work, args=[chain_iterator]))
+
+ with threading_helper.start_threads(worker_threads):
+ pass
+
+ barrier.reset()
+
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/test/test_free_threading/test_itertools_batched.py b/Lib/test/test_free_threading/test_itertools_batched.py
deleted file mode 100644
index a754b4f9ea9..00000000000
--- a/Lib/test/test_free_threading/test_itertools_batched.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import unittest
-from threading import Thread, Barrier
-from itertools import batched
-from test.support import threading_helper
-
-
-threading_helper.requires_working_threading(module=True)
-
-class EnumerateThreading(unittest.TestCase):
-
- @threading_helper.reap_threads
- def test_threading(self):
- number_of_threads = 10
- number_of_iterations = 20
- barrier = Barrier(number_of_threads)
- def work(it):
- barrier.wait()
- while True:
- try:
- _ = next(it)
- except StopIteration:
- break
-
- data = tuple(range(1000))
- for it in range(number_of_iterations):
- batch_iterator = batched(data, 2)
- worker_threads = []
- for ii in range(number_of_threads):
- worker_threads.append(
- Thread(target=work, args=[batch_iterator]))
-
- with threading_helper.start_threads(worker_threads):
- pass
-
- barrier.reset()
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/Lib/test/test_free_threading/test_itertools_combinatoric.py b/Lib/test/test_free_threading/test_itertools_combinatoric.py
new file mode 100644
index 00000000000..5b3b88deedd
--- /dev/null
+++ b/Lib/test/test_free_threading/test_itertools_combinatoric.py
@@ -0,0 +1,51 @@
+import unittest
+from threading import Thread, Barrier
+from itertools import combinations, product
+from test.support import threading_helper
+
+
+threading_helper.requires_working_threading(module=True)
+
+def test_concurrent_iteration(iterator, number_of_threads):
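+# Shared helper (not itself a test): every worker thread drains the same iterator until it is exhausted.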
+ barrier = Barrier(number_of_threads)
+ def iterator_worker(it):
+ barrier.wait()
+ while True:
+ try:
+ _ = next(it)
+ except StopIteration:
+ return
+
+ worker_threads = []
+ for ii in range(number_of_threads):
+ worker_threads.append(
+ Thread(target=iterator_worker, args=[iterator]))
+
+ with threading_helper.start_threads(worker_threads):
+ pass
+
+ barrier.reset()
+
+class ItertoolsThreading(unittest.TestCase):
+
+ @threading_helper.reap_threads
+ def test_combinations(self):
+ number_of_threads = 10
+ number_of_iterations = 24
+
+ for it in range(number_of_iterations):
+ iterator = combinations((1, 2, 3, 4, 5), 2)
+ test_concurrent_iteration(iterator, number_of_threads)
+
+ @threading_helper.reap_threads
+ def test_product(self):
+ number_of_threads = 10
+ number_of_iterations = 24
+
+ for it in range(number_of_iterations):
+ iterator = product((1, 2, 3, 4, 5), (10, 20, 30))
+ test_concurrent_iteration(iterator, number_of_threads)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/test/test_fstring.py b/Lib/test/test_fstring.py
index e75e7db378c..58a30c8e6ac 100644
--- a/Lib/test/test_fstring.py
+++ b/Lib/test/test_fstring.py
@@ -1304,7 +1304,7 @@ x = (
"Bf''",
"BF''",]
double_quote_cases = [case.replace("'", '"') for case in single_quote_cases]
- self.assertAllRaise(SyntaxError, 'invalid syntax',
+ self.assertAllRaise(SyntaxError, 'prefixes are incompatible',
single_quote_cases + double_quote_cases)
def test_leading_trailing_spaces(self):
@@ -1336,9 +1336,9 @@ x = (
def test_conversions(self):
self.assertEqual(f'{3.14:10.10}', ' 3.14')
- self.assertEqual(f'{3.14!s:10.10}', '3.14 ')
- self.assertEqual(f'{3.14!r:10.10}', '3.14 ')
- self.assertEqual(f'{3.14!a:10.10}', '3.14 ')
+ self.assertEqual(f'{1.25!s:10.10}', '1.25 ')
+ self.assertEqual(f'{1.25!r:10.10}', '1.25 ')
+ self.assertEqual(f'{1.25!a:10.10}', '1.25 ')
self.assertEqual(f'{"a"}', 'a')
self.assertEqual(f'{"a"!r}', "'a'")
@@ -1347,7 +1347,7 @@ x = (
# Conversions can have trailing whitespace after them since it
# does not provide any significance
self.assertEqual(f"{3!s }", "3")
- self.assertEqual(f'{3.14!s :10.10}', '3.14 ')
+ self.assertEqual(f'{1.25!s :10.10}', '1.25 ')
# Not a conversion.
self.assertEqual(f'{"a!r"}', "a!r")
@@ -1358,7 +1358,6 @@ x = (
self.assertAllRaise(SyntaxError, "f-string: expecting '}'",
["f'{3!'",
"f'{3!s'",
- "f'{3!g'",
])
self.assertAllRaise(SyntaxError, 'f-string: missing conversion character',
@@ -1381,7 +1380,7 @@ x = (
for conv in ' s', ' s ':
self.assertAllRaise(SyntaxError,
"f-string: conversion type must come right after the"
- " exclamanation mark",
+ " exclamation mark",
["f'{3!" + conv + "}'"])
self.assertAllRaise(SyntaxError,
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index 4794a7465f0..f7e09fd771e 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -21,8 +21,10 @@ from weakref import proxy
import contextlib
from inspect import Signature
+from test.support import ALWAYS_EQ
from test.support import import_helper
from test.support import threading_helper
+from test.support import cpython_only
from test.support import EqualToForwardRef
import functools
@@ -63,6 +65,14 @@ class BadTuple(tuple):
class MyDict(dict):
pass
+class TestImportTime(unittest.TestCase):
+
+ @cpython_only
+ def test_lazy_import(self):
+ import_helper.ensure_lazy_imports(
+ "functools", {"os", "weakref", "typing", "annotationlib", "warnings"}
+ )
+
class TestPartial:
@@ -235,6 +245,13 @@ class TestPartial:
actual_args, actual_kwds = p('x', 'y')
self.assertEqual(actual_args, ('x', 0, 'y', 1))
self.assertEqual(actual_kwds, {})
+ # The check uses `is`, not `==`,
+ # so ALWAYS_EQ isn't treated as Placeholder
+ p = self.partial(capture, ALWAYS_EQ)
+ actual_args, actual_kwds = p()
+ self.assertEqual(len(actual_args), 1)
+ self.assertIs(actual_args[0], ALWAYS_EQ)
+ self.assertEqual(actual_kwds, {})
def test_placeholders_optimization(self):
PH = self.module.Placeholder
@@ -251,6 +268,17 @@ class TestPartial:
self.assertEqual(p2.args, (PH, 0))
self.assertEqual(p2(1), ((1, 0), {}))
+ def test_placeholders_kw_restriction(self):
+ PH = self.module.Placeholder
+ with self.assertRaisesRegex(TypeError, "Placeholder"):
+ self.partial(capture, a=PH)
+ # Passes, since the check uses `is` and not `==`
+ p = self.partial(capture, a=ALWAYS_EQ)
+ actual_args, actual_kwds = p()
+ self.assertEqual(actual_args, ())
+ self.assertEqual(len(actual_kwds), 1)
+ self.assertIs(actual_kwds['a'], ALWAYS_EQ)
+
def test_construct_placeholder_singleton(self):
PH = self.module.Placeholder
tp = type(PH)
@@ -2952,7 +2980,7 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(meth.__qualname__, prefix + meth.__name__)
self.assertEqual(meth.__doc__,
('My function docstring'
- if support.HAVE_DOCSTRINGS
+ if support.HAVE_PY_DOCSTRINGS
else None))
self.assertEqual(meth.__annotations__['arg'], int)
@@ -3107,7 +3135,7 @@ class TestSingleDispatch(unittest.TestCase):
with self.subTest(meth=meth):
self.assertEqual(meth.__doc__,
('My function docstring'
- if support.HAVE_DOCSTRINGS
+ if support.HAVE_PY_DOCSTRINGS
else None))
self.assertEqual(meth.__annotations__['arg'], int)
@@ -3584,7 +3612,7 @@ class TestCachedProperty(unittest.TestCase):
def test_doc(self):
self.assertEqual(CachedCostItem.cost.__doc__,
("The cost of the item."
- if support.HAVE_DOCSTRINGS
+ if support.HAVE_PY_DOCSTRINGS
else None))
def test_module(self):
diff --git a/Lib/test/test_future_stmt/test_future.py b/Lib/test/test_future_stmt/test_future.py
index 42c6cb3fefa..71f1e616116 100644
--- a/Lib/test/test_future_stmt/test_future.py
+++ b/Lib/test/test_future_stmt/test_future.py
@@ -422,6 +422,11 @@ class AnnotationsFutureTestCase(unittest.TestCase):
eq('(((a)))', 'a')
eq('(((a, b)))', '(a, b)')
eq("1 + 2 + 3")
+ eq("t''")
+ eq("t'{a + b}'")
+ eq("t'{a!s}'")
+ eq("t'{a:b}'")
+ eq("t'{a:b=}'")
def test_fstring_debug_annotations(self):
# f-strings with '=' don't round trip very well, so set the expected
diff --git a/Lib/test/test_gc.py b/Lib/test/test_gc.py
index b5140057a69..b4cbfb6d774 100644
--- a/Lib/test/test_gc.py
+++ b/Lib/test/test_gc.py
@@ -7,7 +7,7 @@ from test.support import (verbose, refcount_test,
Py_GIL_DISABLED)
from test.support.import_helper import import_module
from test.support.os_helper import temp_dir, TESTFN, unlink
-from test.support.script_helper import assert_python_ok, make_script
+from test.support.script_helper import assert_python_ok, make_script, run_test_script
from test.support import threading_helper, gc_threshold
import gc
@@ -300,7 +300,7 @@ class GCTests(unittest.TestCase):
# We're mostly just checking that this doesn't crash.
rc, stdout, stderr = assert_python_ok("-c", code)
self.assertEqual(rc, 0)
- self.assertRegex(stdout, rb"""\A\s*func=<function at \S+>\s*\Z""")
+ self.assertRegex(stdout, rb"""\A\s*func=<function at \S+>\s*\z""")
self.assertFalse(stderr)
@refcount_test
@@ -914,7 +914,7 @@ class GCTests(unittest.TestCase):
gc.collect()
self.assertEqual(len(Lazarus.resurrected_instances), 1)
instance = Lazarus.resurrected_instances.pop()
- self.assertTrue(hasattr(instance, "cargo"))
+ self.assertHasAttr(instance, "cargo")
self.assertEqual(id(instance.cargo), cargo_id)
gc.collect()
@@ -1127,64 +1127,14 @@ class GCTests(unittest.TestCase):
class IncrementalGCTests(unittest.TestCase):
-
- def setUp(self):
- # Reenable GC as it is disabled module-wide
- gc.enable()
-
- def tearDown(self):
- gc.disable()
-
@unittest.skipIf(_testinternalcapi is None, "requires _testinternalcapi")
@requires_gil_enabled("Free threading does not support incremental GC")
- # Use small increments to emulate longer running process in a shorter time
- @gc_threshold(200, 10)
def test_incremental_gc_handles_fast_cycle_creation(self):
-
- class LinkedList:
-
- #Use slots to reduce number of implicit objects
- __slots__ = "next", "prev", "surprise"
-
- def __init__(self, next=None, prev=None):
- self.next = next
- if next is not None:
- next.prev = self
- self.prev = prev
- if prev is not None:
- prev.next = self
-
- def make_ll(depth):
- head = LinkedList()
- for i in range(depth):
- head = LinkedList(head, head.prev)
- return head
-
- head = make_ll(1000)
- count = 1000
-
- # There will be some objects we aren't counting,
- # e.g. the gc stats dicts. This test checks
- # that the counts don't grow, so we try to
- # correct for the uncounted objects
- # This is just an estimate.
- CORRECTION = 20
-
- enabled = gc.isenabled()
- gc.enable()
- olds = []
- initial_heap_size = _testinternalcapi.get_tracked_heap_size()
- for i in range(20_000):
- newhead = make_ll(20)
- count += 20
- newhead.surprise = head
- olds.append(newhead)
- if len(olds) == 20:
- new_objects = _testinternalcapi.get_tracked_heap_size() - initial_heap_size
- self.assertLess(new_objects, 27_000, f"Heap growing. Reached limit after {i} iterations")
- del olds[:]
- if not enabled:
- gc.disable()
+ # Run this test in a fresh process. The number of alive objects (which can
+ # be from unit tests run before this one) can influence how quickly cyclic
+ # garbage is found.
+ script = support.findfile("_test_gc_fast_cycles.py")
+ run_test_script(script)
class GCCallbackTests(unittest.TestCase):
diff --git a/Lib/test/test_generated_cases.py b/Lib/test/test_generated_cases.py
index 5b120f28131..eb01328b6ea 100644
--- a/Lib/test/test_generated_cases.py
+++ b/Lib/test/test_generated_cases.py
@@ -1,11 +1,9 @@
import contextlib
import os
-import re
import sys
import tempfile
import unittest
-from io import StringIO
from test import support
from test import test_tools
@@ -31,12 +29,11 @@ skip_if_different_mount_drives()
test_tools.skip_if_missing("cases_generator")
with test_tools.imports_under_tool("cases_generator"):
- from analyzer import analyze_forest, StackItem
+ from analyzer import StackItem
from cwriter import CWriter
import parser
from stack import Local, Stack
import tier1_generator
- import opcode_metadata_generator
import optimizer_generator
@@ -59,14 +56,14 @@ class TestEffects(unittest.TestCase):
def test_effect_sizes(self):
stack = Stack()
inputs = [
- x := StackItem("x", None, "1"),
- y := StackItem("y", None, "oparg"),
- z := StackItem("z", None, "oparg*2"),
+ x := StackItem("x", "1"),
+ y := StackItem("y", "oparg"),
+ z := StackItem("z", "oparg*2"),
]
outputs = [
- StackItem("x", None, "1"),
- StackItem("b", None, "oparg*4"),
- StackItem("c", None, "1"),
+ StackItem("x", "1"),
+ StackItem("b", "oparg*4"),
+ StackItem("c", "1"),
]
null = CWriter.null()
stack.pop(z, null)
@@ -419,7 +416,7 @@ class TestGeneratedCases(unittest.TestCase):
def test_error_if_plain(self):
input = """
inst(OP, (--)) {
- ERROR_IF(cond, label);
+ ERROR_IF(cond);
}
"""
output = """
@@ -432,7 +429,7 @@ class TestGeneratedCases(unittest.TestCase):
next_instr += 1;
INSTRUCTION_STATS(OP);
if (cond) {
- JUMP_TO_LABEL(label);
+ JUMP_TO_LABEL(error);
}
DISPATCH();
}
@@ -442,7 +439,7 @@ class TestGeneratedCases(unittest.TestCase):
def test_error_if_plain_with_comment(self):
input = """
inst(OP, (--)) {
- ERROR_IF(cond, label); // Comment is ok
+ ERROR_IF(cond); // Comment is ok
}
"""
output = """
@@ -455,7 +452,7 @@ class TestGeneratedCases(unittest.TestCase):
next_instr += 1;
INSTRUCTION_STATS(OP);
if (cond) {
- JUMP_TO_LABEL(label);
+ JUMP_TO_LABEL(error);
}
DISPATCH();
}
@@ -467,7 +464,7 @@ class TestGeneratedCases(unittest.TestCase):
inst(OP, (left, right -- res)) {
SPAM(left, right);
INPUTS_DEAD();
- ERROR_IF(cond, label);
+ ERROR_IF(cond);
res = 0;
}
"""
@@ -487,7 +484,7 @@ class TestGeneratedCases(unittest.TestCase):
left = stack_pointer[-2];
SPAM(left, right);
if (cond) {
- JUMP_TO_LABEL(pop_2_label);
+ JUMP_TO_LABEL(pop_2_error);
}
res = 0;
stack_pointer[-2] = res;
@@ -503,7 +500,7 @@ class TestGeneratedCases(unittest.TestCase):
inst(OP, (left, right -- res)) {
res = SPAM(left, right);
INPUTS_DEAD();
- ERROR_IF(cond, label);
+ ERROR_IF(cond);
}
"""
output = """
@@ -522,7 +519,7 @@ class TestGeneratedCases(unittest.TestCase):
left = stack_pointer[-2];
res = SPAM(left, right);
if (cond) {
- JUMP_TO_LABEL(pop_2_label);
+ JUMP_TO_LABEL(pop_2_error);
}
stack_pointer[-2] = res;
stack_pointer += -1;
@@ -903,7 +900,7 @@ class TestGeneratedCases(unittest.TestCase):
inst(OP, (extra, values[oparg] --)) {
DEAD(extra);
DEAD(values);
- ERROR_IF(oparg == 0, somewhere);
+ ERROR_IF(oparg == 0);
}
"""
output = """
@@ -922,7 +919,7 @@ class TestGeneratedCases(unittest.TestCase):
if (oparg == 0) {
stack_pointer += -1 - oparg;
assert(WITHIN_STACK_BOUNDS());
- JUMP_TO_LABEL(somewhere);
+ JUMP_TO_LABEL(error);
}
stack_pointer += -1 - oparg;
assert(WITHIN_STACK_BOUNDS());
@@ -1106,32 +1103,6 @@ class TestGeneratedCases(unittest.TestCase):
"""
self.run_cases_test(input, output)
- def test_pointer_to_stackref(self):
- input = """
- inst(OP, (arg: _PyStackRef * -- out)) {
- out = *arg;
- DEAD(arg);
- }
- """
- output = """
- TARGET(OP) {
- #if Py_TAIL_CALL_INTERP
- int opcode = OP;
- (void)(opcode);
- #endif
- frame->instr_ptr = next_instr;
- next_instr += 1;
- INSTRUCTION_STATS(OP);
- _PyStackRef *arg;
- _PyStackRef out;
- arg = (_PyStackRef *)stack_pointer[-1].bits;
- out = *arg;
- stack_pointer[-1] = out;
- DISPATCH();
- }
- """
- self.run_cases_test(input, output)
-
def test_unused_cached_value(self):
input = """
op(FIRST, (arg1 -- out)) {
@@ -1319,7 +1290,7 @@ class TestGeneratedCases(unittest.TestCase):
op(THIRD, (j, k --)) {
INPUTS_DEAD(); // Mark j and k as used
- ERROR_IF(cond, error);
+ ERROR_IF(cond);
}
macro(TEST) = FIRST + SECOND + THIRD;
@@ -1369,7 +1340,7 @@ class TestGeneratedCases(unittest.TestCase):
op(SECOND, (a -- a, b)) {
b = 1;
- ERROR_IF(cond, error);
+ ERROR_IF(cond);
}
macro(TEST) = FIRST + SECOND;
@@ -1414,10 +1385,10 @@ class TestGeneratedCases(unittest.TestCase):
input = """
inst(OP1, ( --)) {
- ERROR_IF(true, here);
+ ERROR_IF(true);
}
inst(OP2, ( --)) {
- ERROR_IF(1, there);
+ ERROR_IF(1);
}
"""
output = """
@@ -1429,7 +1400,7 @@ class TestGeneratedCases(unittest.TestCase):
frame->instr_ptr = next_instr;
next_instr += 1;
INSTRUCTION_STATS(OP1);
- JUMP_TO_LABEL(here);
+ JUMP_TO_LABEL(error);
}
TARGET(OP2) {
@@ -1440,7 +1411,7 @@ class TestGeneratedCases(unittest.TestCase):
frame->instr_ptr = next_instr;
next_instr += 1;
INSTRUCTION_STATS(OP2);
- JUMP_TO_LABEL(there);
+ JUMP_TO_LABEL(error);
}
"""
self.run_cases_test(input, output)
@@ -1716,7 +1687,7 @@ class TestGeneratedCases(unittest.TestCase):
input = """
inst(OP, ( -- )) {
- ERROR_IF(escaping_call(), error);
+ ERROR_IF(escaping_call());
}
"""
with self.assertRaises(SyntaxError):
@@ -2005,8 +1976,8 @@ class TestGeneratedAbstractCases(unittest.TestCase):
"""
output = """
case OP: {
- JitOptSymbol *arg1;
- JitOptSymbol *out;
+ JitOptRef arg1;
+ JitOptRef out;
arg1 = stack_pointer[-1];
out = EGGS(arg1);
stack_pointer[-1] = out;
@@ -2014,7 +1985,7 @@ class TestGeneratedAbstractCases(unittest.TestCase):
}
case OP2: {
- JitOptSymbol *out;
+ JitOptRef out;
out = sym_new_not_null(ctx);
stack_pointer[-1] = out;
break;
@@ -2039,14 +2010,14 @@ class TestGeneratedAbstractCases(unittest.TestCase):
"""
output = """
case OP: {
- JitOptSymbol *out;
+ JitOptRef out;
out = sym_new_not_null(ctx);
stack_pointer[-1] = out;
break;
}
case OP2: {
- JitOptSymbol *out;
+ JitOptRef out;
out = NULL;
stack_pointer[-1] = out;
break;
@@ -2069,6 +2040,386 @@ class TestGeneratedAbstractCases(unittest.TestCase):
with self.assertRaisesRegex(AssertionError, "All abstract uops"):
self.run_cases_test(input, input2, output)
+ def test_validate_uop_input_length_mismatch(self):
+ input = """
+ op(OP, (arg1 -- out)) {
+ SPAM();
+ }
+ """
+ input2 = """
+ op(OP, (arg1, arg2 -- out)) {
+ }
+ """
+ output = """
+ """
+ with self.assertRaisesRegex(SyntaxError,
+ "Must have the same number of inputs"):
+ self.run_cases_test(input, input2, output)
+
+ def test_validate_uop_output_length_mismatch(self):
+ input = """
+ op(OP, (arg1 -- out)) {
+ SPAM();
+ }
+ """
+ input2 = """
+ op(OP, (arg1 -- out1, out2)) {
+ }
+ """
+ output = """
+ """
+ with self.assertRaisesRegex(SyntaxError,
+ "Must have the same number of outputs"):
+ self.run_cases_test(input, input2, output)
+
+ def test_validate_uop_input_name_mismatch(self):
+ input = """
+ op(OP, (foo -- out)) {
+ SPAM();
+ }
+ """
+ input2 = """
+ op(OP, (bar -- out)) {
+ }
+ """
+ output = """
+ """
+ with self.assertRaisesRegex(SyntaxError,
+ "Inputs must have equal names"):
+ self.run_cases_test(input, input2, output)
+
+ def test_validate_uop_output_name_mismatch(self):
+ input = """
+ op(OP, (arg1 -- foo)) {
+ SPAM();
+ }
+ """
+ input2 = """
+ op(OP, (arg1 -- bar)) {
+ }
+ """
+ output = """
+ """
+ with self.assertRaisesRegex(SyntaxError,
+ "Outputs must have equal names"):
+ self.run_cases_test(input, input2, output)
+
+ def test_validate_uop_unused_input(self):
+ input = """
+ op(OP, (unused -- )) {
+ }
+ """
+ input2 = """
+ op(OP, (foo -- )) {
+ }
+ """
+ output = """
+ case OP: {
+ stack_pointer += -1;
+ assert(WITHIN_STACK_BOUNDS());
+ break;
+ }
+ """
+ self.run_cases_test(input, input2, output)
+
+ input = """
+ op(OP, (foo -- )) {
+ }
+ """
+ input2 = """
+ op(OP, (unused -- )) {
+ }
+ """
+ output = """
+ case OP: {
+ stack_pointer += -1;
+ assert(WITHIN_STACK_BOUNDS());
+ break;
+ }
+ """
+ self.run_cases_test(input, input2, output)
+
+ def test_validate_uop_unused_output(self):
+ input = """
+ op(OP, ( -- unused)) {
+ }
+ """
+ input2 = """
+ op(OP, ( -- foo)) {
+ foo = NULL;
+ }
+ """
+ output = """
+ case OP: {
+ JitOptRef foo;
+ foo = NULL;
+ stack_pointer[0] = foo;
+ stack_pointer += 1;
+ assert(WITHIN_STACK_BOUNDS());
+ break;
+ }
+ """
+ self.run_cases_test(input, input2, output)
+
+ input = """
+ op(OP, ( -- foo)) {
+ foo = NULL;
+ }
+ """
+ input2 = """
+ op(OP, ( -- unused)) {
+ }
+ """
+ output = """
+ case OP: {
+ stack_pointer += 1;
+ assert(WITHIN_STACK_BOUNDS());
+ break;
+ }
+ """
+ self.run_cases_test(input, input2, output)
+
+ def test_validate_uop_input_size_mismatch(self):
+ input = """
+ op(OP, (arg1[2] -- )) {
+ }
+ """
+ input2 = """
+ op(OP, (arg1[4] -- )) {
+ }
+ """
+ output = """
+ """
+ with self.assertRaisesRegex(SyntaxError,
+ "Inputs must have equal sizes"):
+ self.run_cases_test(input, input2, output)
+
+ def test_validate_uop_output_size_mismatch(self):
+ input = """
+ op(OP, ( -- out[2])) {
+ }
+ """
+ input2 = """
+ op(OP, ( -- out[4])) {
+ }
+ """
+ output = """
+ """
+ with self.assertRaisesRegex(SyntaxError,
+ "Outputs must have equal sizes"):
+ self.run_cases_test(input, input2, output)
+
+ def test_validate_uop_unused_size_mismatch(self):
+ input = """
+ op(OP, (foo[2] -- )) {
+ }
+ """
+ input2 = """
+ op(OP, (unused[4] -- )) {
+ }
+ """
+ output = """
+ """
+ with self.assertRaisesRegex(SyntaxError,
+ "Inputs must have equal sizes"):
+ self.run_cases_test(input, input2, output)
+
+ def test_pure_uop_body_copied_in(self):
+ # Note: any non-escaping call works.
+ # In this case, we use PyStackRef_IsNone.
+ input = """
+ pure op(OP, (foo -- res)) {
+ res = PyStackRef_IsNone(foo);
+ }
+ """
+ input2 = """
+ op(OP, (foo -- res)) {
+ REPLACE_OPCODE_IF_EVALUATES_PURE(foo);
+ res = sym_new_known(ctx, foo);
+ }
+ """
+ output = """
+ case OP: {
+ JitOptRef foo;
+ JitOptRef res;
+ foo = stack_pointer[-1];
+ if (
+ sym_is_safe_const(ctx, foo)
+ ) {
+ JitOptRef foo_sym = foo;
+ _PyStackRef foo = sym_get_const_as_stackref(ctx, foo_sym);
+ _PyStackRef res_stackref;
+ /* Start of uop copied from bytecodes for constant evaluation */
+ res_stackref = PyStackRef_IsNone(foo);
+ /* End of uop copied from bytecodes for constant evaluation */
+ res = sym_new_const_steal(ctx, PyStackRef_AsPyObjectSteal(res_stackref));
+ stack_pointer[-1] = res;
+ break;
+ }
+ res = sym_new_known(ctx, foo);
+ stack_pointer[-1] = res;
+ break;
+ }
+ """
+ self.run_cases_test(input, input2, output)
+
+ def test_pure_uop_body_copied_in_deopt(self):
+ # Note: any non-escaping call works.
+ # In this case, we use PyStackRef_IsNone.
+ input = """
+ pure op(OP, (foo -- res)) {
+ DEOPT_IF(PyStackRef_IsNull(foo));
+ res = foo;
+ }
+ """
+ input2 = """
+ op(OP, (foo -- res)) {
+ REPLACE_OPCODE_IF_EVALUATES_PURE(foo);
+ res = foo;
+ }
+ """
+ output = """
+ case OP: {
+ JitOptRef foo;
+ JitOptRef res;
+ foo = stack_pointer[-1];
+ if (
+ sym_is_safe_const(ctx, foo)
+ ) {
+ JitOptRef foo_sym = foo;
+ _PyStackRef foo = sym_get_const_as_stackref(ctx, foo_sym);
+ _PyStackRef res_stackref;
+ /* Start of uop copied from bytecodes for constant evaluation */
+ if (PyStackRef_IsNull(foo)) {
+ ctx->done = true;
+ break;
+ }
+ res_stackref = foo;
+ /* End of uop copied from bytecodes for constant evaluation */
+ res = sym_new_const_steal(ctx, PyStackRef_AsPyObjectSteal(res_stackref));
+ stack_pointer[-1] = res;
+ break;
+ }
+ res = foo;
+ stack_pointer[-1] = res;
+ break;
+ }
+ """
+ self.run_cases_test(input, input2, output)
+
+ def test_pure_uop_body_copied_in_error_if(self):
+ # Note: any non-escaping call works.
+ # In this case, we use PyStackRef_IsNone.
+ input = """
+ pure op(OP, (foo -- res)) {
+ ERROR_IF(PyStackRef_IsNull(foo));
+ res = foo;
+ }
+ """
+ input2 = """
+ op(OP, (foo -- res)) {
+ REPLACE_OPCODE_IF_EVALUATES_PURE(foo);
+ res = foo;
+ }
+ """
+ output = """
+ case OP: {
+ JitOptRef foo;
+ JitOptRef res;
+ foo = stack_pointer[-1];
+ if (
+ sym_is_safe_const(ctx, foo)
+ ) {
+ JitOptRef foo_sym = foo;
+ _PyStackRef foo = sym_get_const_as_stackref(ctx, foo_sym);
+ _PyStackRef res_stackref;
+ /* Start of uop copied from bytecodes for constant evaluation */
+ if (PyStackRef_IsNull(foo)) {
+ goto error;
+ }
+ res_stackref = foo;
+ /* End of uop copied from bytecodes for constant evaluation */
+ res = sym_new_const_steal(ctx, PyStackRef_AsPyObjectSteal(res_stackref));
+ stack_pointer[-1] = res;
+ break;
+ }
+ res = foo;
+ stack_pointer[-1] = res;
+ break;
+ }
+ """
+ self.run_cases_test(input, input2, output)
+
+
+ def test_replace_opcode_uop_body_copied_in_complex(self):
+ input = """
+ pure op(OP, (foo -- res)) {
+ if (foo) {
+ res = PyStackRef_IsNone(foo);
+ }
+ else {
+ res = 1;
+ }
+ }
+ """
+ input2 = """
+ op(OP, (foo -- res)) {
+ REPLACE_OPCODE_IF_EVALUATES_PURE(foo);
+ res = sym_new_known(ctx, foo);
+ }
+ """
+ output = """
+ case OP: {
+ JitOptRef foo;
+ JitOptRef res;
+ foo = stack_pointer[-1];
+ if (
+ sym_is_safe_const(ctx, foo)
+ ) {
+ JitOptRef foo_sym = foo;
+ _PyStackRef foo = sym_get_const_as_stackref(ctx, foo_sym);
+ _PyStackRef res_stackref;
+ /* Start of uop copied from bytecodes for constant evaluation */
+ if (foo) {
+ res_stackref = PyStackRef_IsNone(foo);
+ }
+ else {
+ res_stackref = 1;
+ }
+ /* End of uop copied from bytecodes for constant evaluation */
+ res = sym_new_const_steal(ctx, PyStackRef_AsPyObjectSteal(res_stackref));
+ stack_pointer[-1] = res;
+ break;
+ }
+ res = sym_new_known(ctx, foo);
+ stack_pointer[-1] = res;
+ break;
+ }
+ """
+ self.run_cases_test(input, input2, output)
+
+ def test_replace_opcode_uop_reject_array_effects(self):
+ input = """
+ pure op(OP, (foo[2] -- res)) {
+ if (foo) {
+ res = PyStackRef_IsNone(foo);
+ }
+ else {
+ res = 1;
+ }
+ }
+ """
+ input2 = """
+ op(OP, (foo[2] -- res)) {
+ REPLACE_OPCODE_IF_EVALUATES_PURE(foo);
+ res = sym_new_unknown(ctx);
+ }
+ """
+ output = """
+ """
+ with self.assertRaisesRegex(SyntaxError,
+ "Pure evaluation cannot take array-like inputs"):
+ self.run_cases_test(input, input2, output)
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_genericalias.py b/Lib/test/test_genericalias.py
index 5c13897b8d9..7601cb00ff6 100644
--- a/Lib/test/test_genericalias.py
+++ b/Lib/test/test_genericalias.py
@@ -61,6 +61,7 @@ try:
from tkinter import Event
except ImportError:
Event = None
+from string.templatelib import Template, Interpolation
from typing import TypeVar
T = TypeVar('T')
@@ -137,7 +138,12 @@ class BaseTest(unittest.TestCase):
Future, _WorkItem,
Morsel,
DictReader, DictWriter,
- array]
+ array,
+ staticmethod,
+ classmethod,
+ Template,
+ Interpolation,
+ ]
if ctypes is not None:
generic_types.extend((ctypes.Array, ctypes.LibraryLoader, ctypes.py_object))
if ValueProxy is not None:
@@ -230,13 +236,13 @@ class BaseTest(unittest.TestCase):
self.assertEqual(repr(x2), 'tuple[*tuple[int, str]]')
x3 = tuple[*tuple[int, ...]]
self.assertEqual(repr(x3), 'tuple[*tuple[int, ...]]')
- self.assertTrue(repr(MyList[int]).endswith('.BaseTest.test_repr.<locals>.MyList[int]'))
+ self.assertEndsWith(repr(MyList[int]), '.BaseTest.test_repr.<locals>.MyList[int]')
self.assertEqual(repr(list[str]()), '[]') # instances should keep their normal repr
# gh-105488
- self.assertTrue(repr(MyGeneric[int]).endswith('MyGeneric[int]'))
- self.assertTrue(repr(MyGeneric[[]]).endswith('MyGeneric[[]]'))
- self.assertTrue(repr(MyGeneric[[int, str]]).endswith('MyGeneric[[int, str]]'))
+ self.assertEndsWith(repr(MyGeneric[int]), 'MyGeneric[int]')
+ self.assertEndsWith(repr(MyGeneric[[]]), 'MyGeneric[[]]')
+ self.assertEndsWith(repr(MyGeneric[[int, str]]), 'MyGeneric[[int, str]]')
def test_exposed_type(self):
import types
@@ -356,7 +362,7 @@ class BaseTest(unittest.TestCase):
def test_issubclass(self):
class L(list): ...
- self.assertTrue(issubclass(L, list))
+ self.assertIsSubclass(L, list)
with self.assertRaises(TypeError):
issubclass(L, list[str])
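
For reference, an illustrative sketch of the subscription behaviour the extended generic_types list above exercises (assuming Python 3.14+, where string.templatelib is available; staticmethod/classmethod subscription is taken from the updated list rather than verified here):

    from string.templatelib import Template

    # Parameterizing these classes yields a types.GenericAlias, just like list[int].
    alias = Template[str]
    print(alias.__origin__ is Template)   # True
    print(alias.__args__)                 # (<class 'str'>,)
    print(staticmethod[int])              # expected repr: staticmethod[int]
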
diff --git a/Lib/test/test_genericpath.py b/Lib/test/test_genericpath.py
index 6c3abe602f5..16c3268fefb 100644
--- a/Lib/test/test_genericpath.py
+++ b/Lib/test/test_genericpath.py
@@ -8,7 +8,7 @@ import sys
import unittest
import warnings
from test.support import (
- is_apple, is_emscripten, os_helper, warnings_helper
+ is_apple, os_helper, warnings_helper
)
from test.support.script_helper import assert_python_ok
from test.support.os_helper import FakePath
@@ -92,8 +92,8 @@ class GenericTest:
for s1 in testlist:
for s2 in testlist:
p = commonprefix([s1, s2])
- self.assertTrue(s1.startswith(p))
- self.assertTrue(s2.startswith(p))
+ self.assertStartsWith(s1, p)
+ self.assertStartsWith(s2, p)
if s1 != s2:
n = len(p)
self.assertNotEqual(s1[n:n+1], s2[n:n+1])
diff --git a/Lib/test/test_getpass.py b/Lib/test/test_getpass.py
index 80dda2caaa3..ab36535a1cf 100644
--- a/Lib/test/test_getpass.py
+++ b/Lib/test/test_getpass.py
@@ -161,6 +161,45 @@ class UnixGetpassTest(unittest.TestCase):
self.assertIn('Warning', stderr.getvalue())
self.assertIn('Password:', stderr.getvalue())
+ def test_echo_char_replaces_input_with_asterisks(self):
+ mock_result = '*************'
+ with mock.patch('os.open') as os_open, \
+ mock.patch('io.FileIO'), \
+ mock.patch('io.TextIOWrapper') as textio, \
+ mock.patch('termios.tcgetattr'), \
+ mock.patch('termios.tcsetattr'), \
+ mock.patch('getpass._raw_input') as mock_input:
+ os_open.return_value = 3
+ mock_input.return_value = mock_result
+
+ result = getpass.unix_getpass(echo_char='*')
+ mock_input.assert_called_once_with('Password: ', textio(),
+ input=textio(), echo_char='*')
+ self.assertEqual(result, mock_result)
+
+ def test_raw_input_with_echo_char(self):
+ passwd = 'my1pa$$word!'
+ mock_input = StringIO(f'{passwd}\n')
+ mock_output = StringIO()
+ with mock.patch('sys.stdin', mock_input), \
+ mock.patch('sys.stdout', mock_output):
+ result = getpass._raw_input('Password: ', mock_output, mock_input,
+ '*')
+ self.assertEqual(result, passwd)
+ self.assertEqual('Password: ************', mock_output.getvalue())
+
+ def test_control_chars_with_echo_char(self):
+ passwd = 'pass\twd\b'
+ expect_result = 'pass\tw'
+ mock_input = StringIO(f'{passwd}\n')
+ mock_output = StringIO()
+ with mock.patch('sys.stdin', mock_input), \
+ mock.patch('sys.stdout', mock_output):
+ result = getpass._raw_input('Password: ', mock_output, mock_input,
+ '*')
+ self.assertEqual(result, expect_result)
+ self.assertEqual('Password: *******\x08 \x08', mock_output.getvalue())
+
if __name__ == "__main__":
unittest.main()
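
An illustrative usage sketch of the echo_char behaviour tested above (assuming getpass.getpass() exposes the same echo_char parameter as unix_getpass(); requires a real TTY):

    import getpass

    # Each typed character is echoed as '*'; a backspace erases one echoed
    # character, as exercised by test_control_chars_with_echo_char above.
    password = getpass.getpass('Password: ', echo_char='*')
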
diff --git a/Lib/test/test_gettext.py b/Lib/test/test_gettext.py
index 61bbd0dba43..33b7d75e3ff 100644
--- a/Lib/test/test_gettext.py
+++ b/Lib/test/test_gettext.py
@@ -6,7 +6,8 @@ import unittest.mock
from functools import partial
from test import support
-from test.support import os_helper
+from test.support import cpython_only, os_helper
+from test.support.import_helper import ensure_lazy_imports
# TODO:
@@ -115,6 +116,23 @@ GNU_MO_DATA_CORRUPT = base64.b64encode(bytes([
0x62, 0x61, 0x72, 0x00, # Message data
]))
+
+GNU_MO_DATA_BIG_ENDIAN = base64.b64encode(bytes([
+ 0x95, 0x04, 0x12, 0xDE, # Magic
+ 0x00, 0x00, 0x00, 0x00, # Version
+ 0x00, 0x00, 0x00, 0x01, # Message count
+ 0x00, 0x00, 0x00, 0x1C, # Message offset
+ 0x00, 0x00, 0x00, 0x24, # Translation offset
+ 0x00, 0x00, 0x00, 0x00, # Hash table size
+ 0x00, 0x00, 0x00, 0x2C, # Hash table offset
+ 0x00, 0x00, 0x00, 0x03, # 1st message length
+ 0x00, 0x00, 0x00, 0x2C, # 1st message offset
+ 0x00, 0x00, 0x00, 0x03, # 1st trans length
+ 0x00, 0x00, 0x00, 0x30, # 1st trans offset
+ 0x66, 0x6F, 0x6F, 0x00, # Message data
+ 0x62, 0x61, 0x72, 0x00, # Message data
+]))
+
UMO_DATA = b'''\
3hIElQAAAAADAAAAHAAAADQAAAAAAAAAAAAAAAAAAABMAAAABAAAAE0AAAAQAAAAUgAAAA8BAABj
AAAABAAAAHMBAAAWAAAAeAEAAABhYsOeAG15Y29udGV4dMOeBGFiw54AUHJvamVjdC1JZC1WZXJz
@@ -142,6 +160,7 @@ MOFILE_BAD_MAGIC_NUMBER = os.path.join(LOCALEDIR, 'gettext_bad_magic_number.mo')
MOFILE_BAD_MAJOR_VERSION = os.path.join(LOCALEDIR, 'gettext_bad_major_version.mo')
MOFILE_BAD_MINOR_VERSION = os.path.join(LOCALEDIR, 'gettext_bad_minor_version.mo')
MOFILE_CORRUPT = os.path.join(LOCALEDIR, 'gettext_corrupt.mo')
+MOFILE_BIG_ENDIAN = os.path.join(LOCALEDIR, 'gettext_big_endian.mo')
UMOFILE = os.path.join(LOCALEDIR, 'ugettext.mo')
MMOFILE = os.path.join(LOCALEDIR, 'metadata.mo')
@@ -168,6 +187,8 @@ class GettextBaseTest(unittest.TestCase):
fp.write(base64.decodebytes(GNU_MO_DATA_BAD_MINOR_VERSION))
with open(MOFILE_CORRUPT, 'wb') as fp:
fp.write(base64.decodebytes(GNU_MO_DATA_CORRUPT))
+ with open(MOFILE_BIG_ENDIAN, 'wb') as fp:
+ fp.write(base64.decodebytes(GNU_MO_DATA_BIG_ENDIAN))
with open(UMOFILE, 'wb') as fp:
fp.write(base64.decodebytes(UMO_DATA))
with open(MMOFILE, 'wb') as fp:
@@ -293,6 +314,12 @@ class GettextTestCase2(GettextBaseTest):
self.assertEqual(exception.strerror, "File is corrupt")
self.assertEqual(exception.filename, MOFILE_CORRUPT)
+ def test_big_endian_file(self):
+ with open(MOFILE_BIG_ENDIAN, 'rb') as fp:
+ t = gettext.GNUTranslations(fp)
+
+ self.assertEqual(t.gettext('foo'), 'bar')
+
def test_some_translations(self):
eq = self.assertEqual
# test some translations
@@ -905,6 +932,10 @@ class MiscTestCase(unittest.TestCase):
support.check__all__(self, gettext,
not_exported={'c2py', 'ENOENT'})
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("gettext", {"re", "warnings", "locale"})
+
if __name__ == '__main__':
unittest.main()
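
The new big-endian fixture differs from the little-endian ones only in byte order; a short sketch of how the magic number distinguishes the two (the hex constants mirror gettext.GNUTranslations.LE_MAGIC and BE_MAGIC):

    import struct

    # 'gettext_big_endian.mo' is the fixture written by GettextBaseTest above.
    with open('gettext_big_endian.mo', 'rb') as fp:
        magic = struct.unpack('<I', fp.read(4))[0]

    LE_MAGIC, BE_MAGIC = 0x950412de, 0xde120495
    byte_order = '<' if magic == LE_MAGIC else '>'   # big-endian files read back as BE_MAGIC
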
diff --git a/Lib/test/test_glob.py b/Lib/test/test_glob.py
index 6e5fc2939c6..d0ed5129253 100644
--- a/Lib/test/test_glob.py
+++ b/Lib/test/test_glob.py
@@ -459,59 +459,59 @@ class GlobTests(unittest.TestCase):
def test_translate(self):
def fn(pat):
return glob.translate(pat, seps='/')
- self.assertEqual(fn('foo'), r'(?s:foo)\Z')
- self.assertEqual(fn('foo/bar'), r'(?s:foo/bar)\Z')
- self.assertEqual(fn('*'), r'(?s:[^/.][^/]*)\Z')
- self.assertEqual(fn('?'), r'(?s:(?!\.)[^/])\Z')
- self.assertEqual(fn('a*'), r'(?s:a[^/]*)\Z')
- self.assertEqual(fn('*a'), r'(?s:(?!\.)[^/]*a)\Z')
- self.assertEqual(fn('.*'), r'(?s:\.[^/]*)\Z')
- self.assertEqual(fn('?aa'), r'(?s:(?!\.)[^/]aa)\Z')
- self.assertEqual(fn('aa?'), r'(?s:aa[^/])\Z')
- self.assertEqual(fn('aa[ab]'), r'(?s:aa[ab])\Z')
- self.assertEqual(fn('**'), r'(?s:(?!\.)[^/]*)\Z')
- self.assertEqual(fn('***'), r'(?s:(?!\.)[^/]*)\Z')
- self.assertEqual(fn('a**'), r'(?s:a[^/]*)\Z')
- self.assertEqual(fn('**b'), r'(?s:(?!\.)[^/]*b)\Z')
+ self.assertEqual(fn('foo'), r'(?s:foo)\z')
+ self.assertEqual(fn('foo/bar'), r'(?s:foo/bar)\z')
+ self.assertEqual(fn('*'), r'(?s:[^/.][^/]*)\z')
+ self.assertEqual(fn('?'), r'(?s:(?!\.)[^/])\z')
+ self.assertEqual(fn('a*'), r'(?s:a[^/]*)\z')
+ self.assertEqual(fn('*a'), r'(?s:(?!\.)[^/]*a)\z')
+ self.assertEqual(fn('.*'), r'(?s:\.[^/]*)\z')
+ self.assertEqual(fn('?aa'), r'(?s:(?!\.)[^/]aa)\z')
+ self.assertEqual(fn('aa?'), r'(?s:aa[^/])\z')
+ self.assertEqual(fn('aa[ab]'), r'(?s:aa[ab])\z')
+ self.assertEqual(fn('**'), r'(?s:(?!\.)[^/]*)\z')
+ self.assertEqual(fn('***'), r'(?s:(?!\.)[^/]*)\z')
+ self.assertEqual(fn('a**'), r'(?s:a[^/]*)\z')
+ self.assertEqual(fn('**b'), r'(?s:(?!\.)[^/]*b)\z')
self.assertEqual(fn('/**/*/*.*/**'),
- r'(?s:/(?!\.)[^/]*/[^/.][^/]*/(?!\.)[^/]*\.[^/]*/(?!\.)[^/]*)\Z')
+ r'(?s:/(?!\.)[^/]*/[^/.][^/]*/(?!\.)[^/]*\.[^/]*/(?!\.)[^/]*)\z')
def test_translate_include_hidden(self):
def fn(pat):
return glob.translate(pat, include_hidden=True, seps='/')
- self.assertEqual(fn('foo'), r'(?s:foo)\Z')
- self.assertEqual(fn('foo/bar'), r'(?s:foo/bar)\Z')
- self.assertEqual(fn('*'), r'(?s:[^/]+)\Z')
- self.assertEqual(fn('?'), r'(?s:[^/])\Z')
- self.assertEqual(fn('a*'), r'(?s:a[^/]*)\Z')
- self.assertEqual(fn('*a'), r'(?s:[^/]*a)\Z')
- self.assertEqual(fn('.*'), r'(?s:\.[^/]*)\Z')
- self.assertEqual(fn('?aa'), r'(?s:[^/]aa)\Z')
- self.assertEqual(fn('aa?'), r'(?s:aa[^/])\Z')
- self.assertEqual(fn('aa[ab]'), r'(?s:aa[ab])\Z')
- self.assertEqual(fn('**'), r'(?s:[^/]*)\Z')
- self.assertEqual(fn('***'), r'(?s:[^/]*)\Z')
- self.assertEqual(fn('a**'), r'(?s:a[^/]*)\Z')
- self.assertEqual(fn('**b'), r'(?s:[^/]*b)\Z')
- self.assertEqual(fn('/**/*/*.*/**'), r'(?s:/[^/]*/[^/]+/[^/]*\.[^/]*/[^/]*)\Z')
+ self.assertEqual(fn('foo'), r'(?s:foo)\z')
+ self.assertEqual(fn('foo/bar'), r'(?s:foo/bar)\z')
+ self.assertEqual(fn('*'), r'(?s:[^/]+)\z')
+ self.assertEqual(fn('?'), r'(?s:[^/])\z')
+ self.assertEqual(fn('a*'), r'(?s:a[^/]*)\z')
+ self.assertEqual(fn('*a'), r'(?s:[^/]*a)\z')
+ self.assertEqual(fn('.*'), r'(?s:\.[^/]*)\z')
+ self.assertEqual(fn('?aa'), r'(?s:[^/]aa)\z')
+ self.assertEqual(fn('aa?'), r'(?s:aa[^/])\z')
+ self.assertEqual(fn('aa[ab]'), r'(?s:aa[ab])\z')
+ self.assertEqual(fn('**'), r'(?s:[^/]*)\z')
+ self.assertEqual(fn('***'), r'(?s:[^/]*)\z')
+ self.assertEqual(fn('a**'), r'(?s:a[^/]*)\z')
+ self.assertEqual(fn('**b'), r'(?s:[^/]*b)\z')
+ self.assertEqual(fn('/**/*/*.*/**'), r'(?s:/[^/]*/[^/]+/[^/]*\.[^/]*/[^/]*)\z')
def test_translate_recursive(self):
def fn(pat):
return glob.translate(pat, recursive=True, include_hidden=True, seps='/')
- self.assertEqual(fn('*'), r'(?s:[^/]+)\Z')
- self.assertEqual(fn('?'), r'(?s:[^/])\Z')
- self.assertEqual(fn('**'), r'(?s:.*)\Z')
- self.assertEqual(fn('**/**'), r'(?s:.*)\Z')
- self.assertEqual(fn('***'), r'(?s:[^/]*)\Z')
- self.assertEqual(fn('a**'), r'(?s:a[^/]*)\Z')
- self.assertEqual(fn('**b'), r'(?s:[^/]*b)\Z')
- self.assertEqual(fn('/**/*/*.*/**'), r'(?s:/(?:.+/)?[^/]+/[^/]*\.[^/]*/.*)\Z')
+ self.assertEqual(fn('*'), r'(?s:[^/]+)\z')
+ self.assertEqual(fn('?'), r'(?s:[^/])\z')
+ self.assertEqual(fn('**'), r'(?s:.*)\z')
+ self.assertEqual(fn('**/**'), r'(?s:.*)\z')
+ self.assertEqual(fn('***'), r'(?s:[^/]*)\z')
+ self.assertEqual(fn('a**'), r'(?s:a[^/]*)\z')
+ self.assertEqual(fn('**b'), r'(?s:[^/]*b)\z')
+ self.assertEqual(fn('/**/*/*.*/**'), r'(?s:/(?:.+/)?[^/]+/[^/]*\.[^/]*/.*)\z')
def test_translate_seps(self):
def fn(pat):
return glob.translate(pat, recursive=True, include_hidden=True, seps=['/', '\\'])
- self.assertEqual(fn('foo/bar\\baz'), r'(?s:foo[/\\]bar[/\\]baz)\Z')
- self.assertEqual(fn('**/*'), r'(?s:(?:.+[/\\])?[^/\\]+)\Z')
+ self.assertEqual(fn('foo/bar\\baz'), r'(?s:foo[/\\]bar[/\\]baz)\z')
+ self.assertEqual(fn('**/*'), r'(?s:(?:.+[/\\])?[^/\\]+)\z')
if __name__ == "__main__":
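
The expected patterns above switch from \Z to the new \z end-of-string anchor; a small sketch of how a translated pattern is typically consumed (seps='/' as in these tests; the exact regex text may differ slightly):

    import glob
    import re

    pattern = glob.translate('docs/*.rst', seps='/')
    matcher = re.compile(pattern).match            # pattern is already anchored by \z
    print(bool(matcher('docs/intro.rst')))         # True
    print(bool(matcher('docs/sub/intro.rst')))     # False: '*' does not cross '/'
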
diff --git a/Lib/test/test_grammar.py b/Lib/test/test_grammar.py
index 35cd6984267..7f5d48b9c63 100644
--- a/Lib/test/test_grammar.py
+++ b/Lib/test/test_grammar.py
@@ -1,7 +1,7 @@
# Python test set -- part 1, grammar.
# This just tests whether the parser accepts them all.
-from test.support import check_syntax_error
+from test.support import check_syntax_error, skip_wasi_stack_overflow
from test.support import import_helper
import annotationlib
import inspect
@@ -216,6 +216,27 @@ the \'lazy\' dog.\n\
'
self.assertEqual(x, y)
+ def test_string_prefixes(self):
+ def check(s):
+ parsed = eval(s)
+ self.assertIs(type(parsed), str)
+ self.assertGreater(len(parsed), 0)
+
+ check("u'abc'")
+ check("r'abc\t'")
+ check("rf'abc\a {1 + 1}'")
+ check("fr'abc\a {1 + 1}'")
+
+ def test_bytes_prefixes(self):
+ def check(s):
+ parsed = eval(s)
+ self.assertIs(type(parsed), bytes)
+ self.assertGreater(len(parsed), 0)
+
+ check("b'abc'")
+ check("br'abc\t'")
+ check("rb'abc\a'")
+
def test_ellipsis(self):
x = ...
self.assertTrue(x is Ellipsis)
@@ -228,6 +249,18 @@ the \'lazy\' dog.\n\
compile(s, "<test>", "exec")
self.assertIn("was never closed", str(cm.exception))
+ @skip_wasi_stack_overflow()
+ def test_max_level(self):
+ # Macro defined in Parser/lexer/state.h
+ MAXLEVEL = 200
+
+ result = eval("(" * MAXLEVEL + ")" * MAXLEVEL)
+ self.assertEqual(result, ())
+
+ with self.assertRaises(SyntaxError) as cm:
+ eval("(" * (MAXLEVEL + 1) + ")" * (MAXLEVEL + 1))
+ self.assertStartsWith(str(cm.exception), 'too many nested parentheses')
+
var_annot_global: int # a global annotated is necessary for test_var_annot
@@ -1507,6 +1540,8 @@ class GrammarTests(unittest.TestCase):
check('[None (3, 4)]')
check('[True (3, 4)]')
check('[... (3, 4)]')
+ check('[t"{x}" (3, 4)]')
+ check('[t"x={x}" (3, 4)]')
msg=r'is not subscriptable; perhaps you missed a comma\?'
check('[{1, 2} [i, j]]')
@@ -1529,6 +1564,8 @@ class GrammarTests(unittest.TestCase):
check('[f"x={x}" [i, j]]')
check('["abc" [i, j]]')
check('[b"abc" [i, j]]')
+ check('[t"{x}" [i, j]]')
+ check('[t"x={x}" [i, j]]')
msg=r'indices must be integers or slices, not tuple;'
check('[[1, 2] [3, 4]]')
@@ -1549,6 +1586,8 @@ class GrammarTests(unittest.TestCase):
check('[[1, 2] [f"{x}"]]')
check('[[1, 2] [f"x={x}"]]')
check('[[1, 2] ["abc"]]')
+ check('[[1, 2] [t"{x}"]]')
+ check('[[1, 2] [t"x={x}"]]')
msg=r'indices must be integers or slices, not'
check('[[1, 2] [b"abc"]]')
check('[[1, 2] [12.3]]')
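
The new t-string checks above route template literals through the same "not callable" / "not subscriptable" error paths as other literals; a minimal sketch of the underlying behaviour (assuming Python 3.14+ template strings):

    from string.templatelib import Template

    x = 42
    tmpl = t"x={x}"
    print(isinstance(tmpl, Template))   # True: a t-string evaluates to a Template, not a str
    # Like other literals, a Template is neither callable nor subscriptable, e.g.
    #   [t"{x}" (3, 4)]  ->  TypeError: ... is not callable; perhaps you forgot a comma?
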
diff --git a/Lib/test/test_gzip.py b/Lib/test/test_gzip.py
index fa5de7c190e..a12ff5662a7 100644
--- a/Lib/test/test_gzip.py
+++ b/Lib/test/test_gzip.py
@@ -9,7 +9,6 @@ import os
import struct
import sys
import unittest
-import warnings
from subprocess import PIPE, Popen
from test.support import catch_unraisable_exception
from test.support import import_helper
@@ -331,13 +330,13 @@ class TestGzip(BaseTest):
def test_1647484(self):
for mode in ('wb', 'rb'):
with gzip.GzipFile(self.filename, mode) as f:
- self.assertTrue(hasattr(f, "name"))
+ self.assertHasAttr(f, "name")
self.assertEqual(f.name, self.filename)
def test_paddedfile_getattr(self):
self.test_write()
with gzip.GzipFile(self.filename, 'rb') as f:
- self.assertTrue(hasattr(f.fileobj, "name"))
+ self.assertHasAttr(f.fileobj, "name")
self.assertEqual(f.fileobj.name, self.filename)
def test_mtime(self):
@@ -345,7 +344,7 @@ class TestGzip(BaseTest):
with gzip.GzipFile(self.filename, 'w', mtime = mtime) as fWrite:
fWrite.write(data1)
with gzip.GzipFile(self.filename) as fRead:
- self.assertTrue(hasattr(fRead, 'mtime'))
+ self.assertHasAttr(fRead, 'mtime')
self.assertIsNone(fRead.mtime)
dataRead = fRead.read()
self.assertEqual(dataRead, data1)
@@ -460,7 +459,7 @@ class TestGzip(BaseTest):
self.assertEqual(d, data1 * 50, "Incorrect data in file")
def test_gzip_BadGzipFile_exception(self):
- self.assertTrue(issubclass(gzip.BadGzipFile, OSError))
+ self.assertIsSubclass(gzip.BadGzipFile, OSError)
def test_bad_gzip_file(self):
with open(self.filename, 'wb') as file:
diff --git a/Lib/test/test_hashlib.py b/Lib/test/test_hashlib.py
index 5e3356a02f3..5bad483ae9d 100644
--- a/Lib/test/test_hashlib.py
+++ b/Lib/test/test_hashlib.py
@@ -12,12 +12,12 @@ import io
import itertools
import logging
import os
+import re
import sys
import sysconfig
import tempfile
import threading
import unittest
-import warnings
from test import support
from test.support import _4G, bigmemtest
from test.support import hashlib_helper
@@ -98,6 +98,14 @@ def read_vectors(hash_name):
yield parts
+DEPRECATED_STRING_PARAMETER = re.escape(
+ "the 'string' keyword parameter is deprecated since "
+ "Python 3.15 and slated for removal in Python 3.19; "
+ "use the 'data' keyword parameter or pass the data "
+ "to hash as a positional argument instead"
+)
+
+
class HashLibTestCase(unittest.TestCase):
supported_hash_names = ( 'md5', 'MD5', 'sha1', 'SHA1',
'sha224', 'SHA224', 'sha256', 'SHA256',
@@ -141,19 +149,18 @@ class HashLibTestCase(unittest.TestCase):
# of hashlib.new given the algorithm name.
for algorithm, constructors in self.constructors_to_test.items():
constructors.add(getattr(hashlib, algorithm))
- def _test_algorithm_via_hashlib_new(data=None, _alg=algorithm, **kwargs):
- if data is None:
- return hashlib.new(_alg, **kwargs)
- return hashlib.new(_alg, data, **kwargs)
- constructors.add(_test_algorithm_via_hashlib_new)
+ def c(*args, __algorithm_name=algorithm, **kwargs):
+ return hashlib.new(__algorithm_name, *args, **kwargs)
+ c.__name__ = f'do_test_algorithm_via_hashlib_new_{algorithm}'
+ constructors.add(c)
_hashlib = self._conditional_import_module('_hashlib')
self._hashlib = _hashlib
if _hashlib:
# These algorithms should always be present when this module
# is compiled. If not, something was compiled wrong.
- self.assertTrue(hasattr(_hashlib, 'openssl_md5'))
- self.assertTrue(hasattr(_hashlib, 'openssl_sha1'))
+ self.assertHasAttr(_hashlib, 'openssl_md5')
+ self.assertHasAttr(_hashlib, 'openssl_sha1')
for algorithm, constructors in self.constructors_to_test.items():
constructor = getattr(_hashlib, 'openssl_'+algorithm, None)
if constructor:
@@ -201,6 +208,11 @@ class HashLibTestCase(unittest.TestCase):
return itertools.chain.from_iterable(constructors)
@property
+ def shake_constructors(self):
+ for shake_name in self.shakes:
+ yield from self.constructors_to_test.get(shake_name, ())
+
+ @property
def is_fips_mode(self):
return get_fips_mode()
@@ -250,6 +262,85 @@ class HashLibTestCase(unittest.TestCase):
self._hashlib.new("md5", usedforsecurity=False)
self._hashlib.openssl_md5(usedforsecurity=False)
+ @unittest.skipIf(get_fips_mode(), "skip in FIPS mode")
+ def test_clinic_signature(self):
+ for constructor in self.hash_constructors:
+ with self.subTest(constructor.__name__):
+ constructor(b'')
+ constructor(data=b'')
+ with self.assertWarnsRegex(DeprecationWarning,
+ DEPRECATED_STRING_PARAMETER):
+ constructor(string=b'')
+
+ digest_name = constructor(b'').name
+ with self.subTest(digest_name):
+ hashlib.new(digest_name, b'')
+ hashlib.new(digest_name, data=b'')
+ with self.assertWarnsRegex(DeprecationWarning,
+ DEPRECATED_STRING_PARAMETER):
+ hashlib.new(digest_name, string=b'')
+ # Make sure that _hashlib contains the constructor
+ # to test when using a combination of libcrypto and
+ # interned hash implementations.
+ if self._hashlib and digest_name in self._hashlib._constructors:
+ self._hashlib.new(digest_name, b'')
+ self._hashlib.new(digest_name, data=b'')
+ with self.assertWarnsRegex(DeprecationWarning,
+ DEPRECATED_STRING_PARAMETER):
+ self._hashlib.new(digest_name, string=b'')
+
+ @unittest.skipIf(get_fips_mode(), "skip in FIPS mode")
+ def test_clinic_signature_errors(self):
+ nomsg = b''
+ mymsg = b'msg'
+ conflicting_call = re.escape(
+ "'data' and 'string' are mutually exclusive "
+ "and support for 'string' keyword parameter "
+ "is slated for removal in a future version."
+ )
+ duplicated_param = re.escape("given by name ('data') and position")
+ unexpected_param = re.escape("got an unexpected keyword argument '_'")
+ for args, kwds, errmsg in [
+ # Reject duplicated arguments before unknown keyword arguments.
+ ((nomsg,), dict(data=nomsg, _=nomsg), duplicated_param),
+ ((mymsg,), dict(data=nomsg, _=nomsg), duplicated_param),
+ # Reject duplicated arguments before conflicting ones.
+ *itertools.product(
+ [[nomsg], [mymsg]],
+ [dict(data=nomsg), dict(data=nomsg, string=nomsg)],
+ [duplicated_param]
+ ),
+ # Reject unknown keyword arguments before conflicting ones.
+ *itertools.product(
+ [()],
+ [
+ dict(_=None),
+ dict(data=nomsg, _=None),
+ dict(string=nomsg, _=None),
+ dict(string=nomsg, data=nomsg, _=None),
+ ],
+ [unexpected_param]
+ ),
+ ((nomsg,), dict(_=None), unexpected_param),
+ ((mymsg,), dict(_=None), unexpected_param),
+ # Reject conflicting arguments.
+ [(nomsg,), dict(string=nomsg), conflicting_call],
+ [(mymsg,), dict(string=nomsg), conflicting_call],
+ [(), dict(data=nomsg, string=nomsg), conflicting_call],
+ ]:
+ for constructor in self.hash_constructors:
+ digest_name = constructor(b'').name
+ with self.subTest(constructor.__name__, args=args, kwds=kwds):
+ with self.assertRaisesRegex(TypeError, errmsg):
+ constructor(*args, **kwds)
+ with self.subTest(digest_name, args=args, kwds=kwds):
+ with self.assertRaisesRegex(TypeError, errmsg):
+ hashlib.new(digest_name, *args, **kwds)
+ if (self._hashlib and
+ digest_name in self._hashlib._constructors):
+ with self.assertRaisesRegex(TypeError, errmsg):
+ self._hashlib.new(digest_name, *args, **kwds)
+
def test_unknown_hash(self):
self.assertRaises(ValueError, hashlib.new, 'spam spam spam spam spam')
self.assertRaises(TypeError, hashlib.new, 1)
@@ -284,6 +375,16 @@ class HashLibTestCase(unittest.TestCase):
self.assertIs(constructor, _md5.md5)
self.assertEqual(sorted(builtin_constructor_cache), ['MD5', 'md5'])
+ def test_copy(self):
+ for cons in self.hash_constructors:
+ h1 = cons(os.urandom(16), usedforsecurity=False)
+ h2 = h1.copy()
+ self.assertIs(type(h1), type(h2))
+ self.assertEqual(h1.name, h2.name)
+ size = (16,) if h1.name in self.shakes else ()
+ self.assertEqual(h1.digest(*size), h2.digest(*size))
+ self.assertEqual(h1.hexdigest(*size), h2.hexdigest(*size))
+
def test_hexdigest(self):
for cons in self.hash_constructors:
h = cons(usedforsecurity=False)
@@ -294,21 +395,50 @@ class HashLibTestCase(unittest.TestCase):
self.assertIsInstance(h.digest(), bytes)
self.assertEqual(hexstr(h.digest()), h.hexdigest())
- def test_digest_length_overflow(self):
- # See issue #34922
- large_sizes = (2**29, 2**32-10, 2**32+10, 2**61, 2**64-10, 2**64+10)
- for cons in self.hash_constructors:
- h = cons(usedforsecurity=False)
- if h.name not in self.shakes:
- continue
- if HASH is not None and isinstance(h, HASH):
- # _hashopenssl's take a size_t
- continue
- for digest in h.digest, h.hexdigest:
- self.assertRaises(ValueError, digest, -10)
- for length in large_sizes:
- with self.assertRaises((ValueError, OverflowError)):
- digest(length)
+ def test_shakes_zero_digest_length(self):
+ for constructor in self.shake_constructors:
+ with self.subTest(constructor=constructor):
+ h = constructor(b'abcdef', usedforsecurity=False)
+ self.assertEqual(h.digest(0), b'')
+ self.assertEqual(h.hexdigest(0), '')
+
+ def test_shakes_invalid_digest_length(self):
+ # See https://github.com/python/cpython/issues/79103.
+ for constructor in self.shake_constructors:
+ with self.subTest(constructor=constructor):
+ h = constructor(usedforsecurity=False)
+ # Note: digest() and hexdigest() take a signed input and
+ # raise if it is negative; the rationale is that internally
+ # we use PyBytes_FromStringAndSize() and _Py_strhex(),
+ # which both take a Py_ssize_t.
+ for negative_size in (-1, -10, -(1 << 31), -sys.maxsize):
+ self.assertRaises(ValueError, h.digest, negative_size)
+ self.assertRaises(ValueError, h.hexdigest, negative_size)
+
+ def test_shakes_overflow_digest_length(self):
+ # See https://github.com/python/cpython/issues/135759.
+
+ exc_types = (OverflowError, ValueError)
+ # HACL* accepts an 'uint32_t' while OpenSSL accepts a 'size_t'.
+ openssl_overflown_sizes = (sys.maxsize + 1, 2 * sys.maxsize)
+ # https://github.com/python/cpython/issues/79103 restricts
+ # the accepted built-in lengths to 2 ** 29, even if OpenSSL
+ # accepts such lengths.
+ builtin_overflown_sizes = openssl_overflown_sizes + (
+ 2 ** 29, 2 ** 32 - 10, 2 ** 32, 2 ** 32 + 10,
+ 2 ** 61, 2 ** 64 - 10, 2 ** 64, 2 ** 64 + 10,
+ )
+
+ for constructor in self.shake_constructors:
+ with self.subTest(constructor=constructor):
+ h = constructor(usedforsecurity=False)
+ if HASH is not None and isinstance(h, HASH):
+ overflown_sizes = openssl_overflown_sizes
+ else:
+ overflown_sizes = builtin_overflown_sizes
+ for invalid_size in overflown_sizes:
+ self.assertRaises(exc_types, h.digest, invalid_size)
+ self.assertRaises(exc_types, h.hexdigest, invalid_size)
def test_name_attribute(self):
for cons in self.hash_constructors:
@@ -719,8 +849,6 @@ class HashLibTestCase(unittest.TestCase):
self.assertRaises(ValueError, constructor, node_offset=-1)
self.assertRaises(OverflowError, constructor, node_offset=max_offset+1)
- self.assertRaises(TypeError, constructor, data=b'')
- self.assertRaises(TypeError, constructor, string=b'')
self.assertRaises(TypeError, constructor, '')
constructor(
@@ -929,49 +1057,67 @@ class HashLibTestCase(unittest.TestCase):
def test_sha256_gil(self):
gil_minsize = hashlib_helper.find_gil_minsize(['_sha2', '_hashlib'])
+ data = b'1' + b'#' * gil_minsize + b'1'
+ expected = hashlib.sha256(data).hexdigest()
+
m = hashlib.sha256()
m.update(b'1')
m.update(b'#' * gil_minsize)
m.update(b'1')
- self.assertEqual(
- m.hexdigest(),
- '1cfceca95989f51f658e3f3ffe7f1cd43726c9e088c13ee10b46f57cef135b94'
- )
+ self.assertEqual(m.hexdigest(), expected)
- m = hashlib.sha256(b'1' + b'#' * gil_minsize + b'1')
- self.assertEqual(
- m.hexdigest(),
- '1cfceca95989f51f658e3f3ffe7f1cd43726c9e088c13ee10b46f57cef135b94'
- )
+ @threading_helper.reap_threads
+ @threading_helper.requires_working_threading()
+ def test_threaded_hashing_fast(self):
+ # Same as test_threaded_hashing_slow() but only tests some functions
+ # since otherwise test_hashlib.py becomes too slow during development.
+ for name in ['md5', 'sha1', 'sha256', 'sha3_256', 'blake2s']:
+ if constructor := getattr(hashlib, name, None):
+ with self.subTest(name):
+ self.do_test_threaded_hashing(constructor, is_shake=False)
+ if shake_128 := getattr(hashlib, 'shake_128', None):
+ self.do_test_threaded_hashing(shake_128, is_shake=True)
+ @requires_resource('cpu')
@threading_helper.reap_threads
@threading_helper.requires_working_threading()
- def test_threaded_hashing(self):
+ def test_threaded_hashing_slow(self):
+ for algorithm, constructors in self.constructors_to_test.items():
+ is_shake = algorithm in self.shakes
+ for constructor in constructors:
+ with self.subTest(constructor.__name__, is_shake=is_shake):
+ self.do_test_threaded_hashing(constructor, is_shake)
+
+ def do_test_threaded_hashing(self, constructor, is_shake):
# Updating the same hash object from several threads at once
# using data chunk sizes containing the same byte sequences.
#
# If the internal locks are working to prevent multiple
# updates on the same object from running at once, the resulting
# hash will be the same as doing it single threaded upfront.
- hasher = hashlib.sha1()
- num_threads = 5
- smallest_data = b'swineflu'
- data = smallest_data * 200000
- expected_hash = hashlib.sha1(data*num_threads).hexdigest()
-
- def hash_in_chunks(chunk_size):
- index = 0
- while index < len(data):
- hasher.update(data[index:index + chunk_size])
- index += chunk_size
+
+ # The data to hash has length s|M|q^N and the chunk size for the i-th
+ # thread is s|M|q^(N-i), where N is the number of threads, M is a fixed
+ # message of small length, and s >= 1 and q >= 2 are small integers.
+ smallest_size, num_threads, s, q = 8, 5, 2, 10
+
+ smallest_data = os.urandom(smallest_size)
+ data = s * smallest_data * (q ** num_threads)
+
+ h1 = constructor(usedforsecurity=False)
+ h2 = constructor(data * num_threads, usedforsecurity=False)
+
+ def update(chunk_size):
+ for index in range(0, len(data), chunk_size):
+ h1.update(data[index:index + chunk_size])
threads = []
- for threadnum in range(num_threads):
- chunk_size = len(data) // (10 ** threadnum)
+ for thread_num in range(num_threads):
+ # chunk_size = len(data) // (q ** thread_num)
+ chunk_size = s * smallest_size * q ** (num_threads - thread_num)
self.assertGreater(chunk_size, 0)
- self.assertEqual(chunk_size % len(smallest_data), 0)
- thread = threading.Thread(target=hash_in_chunks,
- args=(chunk_size,))
+ self.assertEqual(chunk_size % smallest_size, 0)
+ thread = threading.Thread(target=update, args=(chunk_size,))
threads.append(thread)
for thread in threads:
@@ -979,7 +1125,10 @@ class HashLibTestCase(unittest.TestCase):
for thread in threads:
thread.join()
- self.assertEqual(expected_hash, hasher.hexdigest())
+ if is_shake:
+ self.assertEqual(h1.hexdigest(16), h2.hexdigest(16))
+ else:
+ self.assertEqual(h1.hexdigest(), h2.hexdigest())
def test_get_fips_mode(self):
fips_mode = self.is_fips_mode
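
The clinic-signature tests above pin down the 'string' keyword deprecation; in terms of plain calls, the accepted spellings are (sketch, using sha256 as a representative constructor):

    import hashlib

    hashlib.sha256(b'payload')           # positional data (preferred)
    hashlib.sha256(data=b'payload')      # 'data' keyword
    hashlib.sha256(string=b'payload')    # still accepted, but emits the DeprecationWarning
                                         # whose text is DEPRECATED_STRING_PARAMETER above
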
diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py
index 1aa8e4e2897..d6623fee9bb 100644
--- a/Lib/test/test_heapq.py
+++ b/Lib/test/test_heapq.py
@@ -13,8 +13,9 @@ c_heapq = import_helper.import_fresh_module('heapq', fresh=['_heapq'])
# _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when
# _heapq is imported, so check them there
-func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace',
- '_heappop_max', '_heapreplace_max', '_heapify_max']
+func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace']
+# Add max-heap variants
+func_names += [func + '_max' for func in func_names]
class TestModules(TestCase):
def test_py_functions(self):
@@ -24,7 +25,7 @@ class TestModules(TestCase):
@skipUnless(c_heapq, 'requires _heapq')
def test_c_functions(self):
for fname in func_names:
- self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
+ self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq', fname)
def load_tests(loader, tests, ignore):
@@ -74,6 +75,34 @@ class TestHeap:
except AttributeError:
pass
+ def test_max_push_pop(self):
+ # 1) Push 256 random numbers and pop them off, verifying all's OK.
+ heap = []
+ data = []
+ self.check_max_invariant(heap)
+ for i in range(256):
+ item = random.random()
+ data.append(item)
+ self.module.heappush_max(heap, item)
+ self.check_max_invariant(heap)
+ results = []
+ while heap:
+ item = self.module.heappop_max(heap)
+ self.check_max_invariant(heap)
+ results.append(item)
+ data_sorted = data[:]
+ data_sorted.sort(reverse=True)
+
+ self.assertEqual(data_sorted, results)
+ # 2) Check that the invariant holds for a sorted array
+ self.check_max_invariant(results)
+
+ self.assertRaises(TypeError, self.module.heappush_max, [])
+
+ exc_types = (AttributeError, TypeError)
+ self.assertRaises(exc_types, self.module.heappush_max, None, None)
+ self.assertRaises(exc_types, self.module.heappop_max, None)
+
def check_invariant(self, heap):
# Check the heap invariant.
for pos, item in enumerate(heap):
@@ -81,6 +110,11 @@ class TestHeap:
parentpos = (pos-1) >> 1
self.assertTrue(heap[parentpos] <= item)
+ def check_max_invariant(self, heap):
+ for pos, item in enumerate(heap[1:], start=1):
+ parentpos = (pos - 1) >> 1
+ self.assertGreaterEqual(heap[parentpos], item)
+
def test_heapify(self):
for size in list(range(30)) + [20000]:
heap = [random.random() for dummy in range(size)]
@@ -89,6 +123,14 @@ class TestHeap:
self.assertRaises(TypeError, self.module.heapify, None)
+ def test_heapify_max(self):
+ for size in list(range(30)) + [20000]:
+ heap = [random.random() for dummy in range(size)]
+ self.module.heapify_max(heap)
+ self.check_max_invariant(heap)
+
+ self.assertRaises(TypeError, self.module.heapify_max, None)
+
def test_naive_nbest(self):
data = [random.randrange(2000) for i in range(1000)]
heap = []
@@ -109,10 +151,7 @@ class TestHeap:
def test_nbest(self):
# Less-naive "N-best" algorithm, much faster (if len(data) is big
- # enough <wink>) than sorting all of data. However, if we had a max
- # heap instead of a min heap, it could go faster still via
- # heapify'ing all of data (linear time), then doing 10 heappops
- # (10 log-time steps).
+ # enough <wink>) than sorting all of data.
data = [random.randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify(heap)
@@ -125,6 +164,17 @@ class TestHeap:
self.assertRaises(TypeError, self.module.heapreplace, None, None)
self.assertRaises(IndexError, self.module.heapreplace, [], None)
+ def test_nbest_maxheap(self):
+ # With a max heap instead of a min heap, the "N-best" algorithm can
+ # go even faster still via heapify'ing all of data (linear time), then
+ # doing 10 heappops (10 log-time steps).
+ data = [random.randrange(2000) for i in range(1000)]
+ heap = data[:]
+ self.module.heapify_max(heap)
+ result = [self.module.heappop_max(heap) for _ in range(10)]
+ result.reverse()
+ self.assertEqual(result, sorted(data)[-10:])
+
def test_nbest_with_pushpop(self):
data = [random.randrange(2000) for i in range(1000)]
heap = data[:10]
@@ -134,6 +184,62 @@ class TestHeap:
self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
self.assertEqual(self.module.heappushpop([], 'x'), 'x')
+ def test_naive_nworst(self):
+ # Max-heap variant of "test_naive_nbest"
+ data = [random.randrange(2000) for i in range(1000)]
+ heap = []
+ for item in data:
+ self.module.heappush_max(heap, item)
+ if len(heap) > 10:
+ self.module.heappop_max(heap)
+ heap.sort()
+ expected = sorted(data)[:10]
+ self.assertEqual(heap, expected)
+
+ def heapiter_max(self, heap):
+ # An iterator returning a max-heap's elements, largest-first.
+ try:
+ while 1:
+ yield self.module.heappop_max(heap)
+ except IndexError:
+ pass
+
+ def test_nworst(self):
+ # Max-heap variant of "test_nbest"
+ data = [random.randrange(2000) for i in range(1000)]
+ heap = data[:10]
+ self.module.heapify_max(heap)
+ for item in data[10:]:
+ if item < heap[0]: # this gets rarer the longer we run
+ self.module.heapreplace_max(heap, item)
+ expected = sorted(data, reverse=True)[-10:]
+ self.assertEqual(list(self.heapiter_max(heap)), expected)
+
+ self.assertRaises(TypeError, self.module.heapreplace_max, None)
+ self.assertRaises(TypeError, self.module.heapreplace_max, None, None)
+ self.assertRaises(IndexError, self.module.heapreplace_max, [], None)
+
+ def test_nworst_minheap(self):
+ # Min-heap variant of "test_nbest_maxheap"
+ data = [random.randrange(2000) for i in range(1000)]
+ heap = data[:]
+ self.module.heapify(heap)
+ result = [self.module.heappop(heap) for _ in range(10)]
+ result.reverse()
+ expected = sorted(data, reverse=True)[-10:]
+ self.assertEqual(result, expected)
+
+ def test_nworst_with_pushpop(self):
+ # Max-heap variant of "test_nbest_with_pushpop"
+ data = [random.randrange(2000) for i in range(1000)]
+ heap = data[:10]
+ self.module.heapify_max(heap)
+ for item in data[10:]:
+ self.module.heappushpop_max(heap, item)
+ expected = sorted(data, reverse=True)[-10:]
+ self.assertEqual(list(self.heapiter_max(heap)), expected)
+ self.assertEqual(self.module.heappushpop_max([], 'x'), 'x')
+
def test_heappushpop(self):
h = []
x = self.module.heappushpop(h, 10)
@@ -153,12 +259,31 @@ class TestHeap:
x = self.module.heappushpop(h, 11)
self.assertEqual((h, x), ([11], 10))
+ def test_heappushpop_max(self):
+ h = []
+ x = self.module.heappushpop_max(h, 10)
+ self.assertTupleEqual((h, x), ([], 10))
+
+ h = [10]
+ x = self.module.heappushpop_max(h, 10.0)
+ self.assertTupleEqual((h, x), ([10], 10.0))
+ self.assertIsInstance(h[0], int)
+ self.assertIsInstance(x, float)
+
+ h = [10]
+ x = self.module.heappushpop_max(h, 11)
+ self.assertTupleEqual((h, x), ([10], 11))
+
+ h = [10]
+ x = self.module.heappushpop_max(h, 9)
+ self.assertTupleEqual((h, x), ([9], 10))
+
def test_heappop_max(self):
- # _heapop_max has an optimization for one-item lists which isn't
+ # heappop_max has an optimization for one-item lists which isn't
# covered in other tests, so test that case explicitly here
h = [3, 2]
- self.assertEqual(self.module._heappop_max(h), 3)
- self.assertEqual(self.module._heappop_max(h), 2)
+ self.assertEqual(self.module.heappop_max(h), 3)
+ self.assertEqual(self.module.heappop_max(h), 2)
def test_heapsort(self):
# Exercise everything with repeated heapsort checks
@@ -175,6 +300,20 @@ class TestHeap:
heap_sorted = [self.module.heappop(heap) for i in range(size)]
self.assertEqual(heap_sorted, sorted(data))
+ def test_heapsort_max(self):
+ for trial in range(100):
+ size = random.randrange(50)
+ data = [random.randrange(25) for i in range(size)]
+ if trial & 1: # Half of the time, use heapify_max
+ heap = data[:]
+ self.module.heapify_max(heap)
+ else: # The rest of the time, use heappush_max
+ heap = []
+ for item in data:
+ self.module.heappush_max(heap, item)
+ heap_sorted = [self.module.heappop_max(heap) for i in range(size)]
+ self.assertEqual(heap_sorted, sorted(data, reverse=True))
+
def test_merge(self):
inputs = []
for i in range(random.randrange(25)):
@@ -377,16 +516,20 @@ class SideEffectLT:
class TestErrorHandling:
def test_non_sequence(self):
- for f in (self.module.heapify, self.module.heappop):
+ for f in (self.module.heapify, self.module.heappop,
+ self.module.heapify_max, self.module.heappop_max):
self.assertRaises((TypeError, AttributeError), f, 10)
for f in (self.module.heappush, self.module.heapreplace,
+ self.module.heappush_max, self.module.heapreplace_max,
self.module.nlargest, self.module.nsmallest):
self.assertRaises((TypeError, AttributeError), f, 10, 10)
def test_len_only(self):
- for f in (self.module.heapify, self.module.heappop):
+ for f in (self.module.heapify, self.module.heappop,
+ self.module.heapify_max, self.module.heappop_max):
self.assertRaises((TypeError, AttributeError), f, LenOnly())
- for f in (self.module.heappush, self.module.heapreplace):
+ for f in (self.module.heappush, self.module.heapreplace,
+ self.module.heappush_max, self.module.heapreplace_max):
self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10)
for f in (self.module.nlargest, self.module.nsmallest):
self.assertRaises(TypeError, f, 2, LenOnly())
@@ -395,7 +538,8 @@ class TestErrorHandling:
seq = [CmpErr(), CmpErr(), CmpErr()]
for f in (self.module.heapify, self.module.heappop):
self.assertRaises(ZeroDivisionError, f, seq)
- for f in (self.module.heappush, self.module.heapreplace):
+ for f in (self.module.heappush, self.module.heapreplace,
+ self.module.heappush_max, self.module.heapreplace_max):
self.assertRaises(ZeroDivisionError, f, seq, 10)
for f in (self.module.nlargest, self.module.nsmallest):
self.assertRaises(ZeroDivisionError, f, 2, seq)
@@ -403,6 +547,8 @@ class TestErrorHandling:
def test_arg_parsing(self):
for f in (self.module.heapify, self.module.heappop,
self.module.heappush, self.module.heapreplace,
+ self.module.heapify_max, self.module.heappop_max,
+ self.module.heappush_max, self.module.heapreplace_max,
self.module.nlargest, self.module.nsmallest):
self.assertRaises((TypeError, AttributeError), f, 10)
@@ -424,6 +570,10 @@ class TestErrorHandling:
# Python version raises IndexError, C version RuntimeError
with self.assertRaises((IndexError, RuntimeError)):
self.module.heappush(heap, SideEffectLT(5, heap))
+ heap = []
+ heap.extend(SideEffectLT(i, heap) for i in range(200))
+ with self.assertRaises((IndexError, RuntimeError)):
+ self.module.heappush_max(heap, SideEffectLT(5, heap))
def test_heappop_mutating_heap(self):
heap = []
@@ -431,8 +581,12 @@ class TestErrorHandling:
# Python version raises IndexError, C version RuntimeError
with self.assertRaises((IndexError, RuntimeError)):
self.module.heappop(heap)
+ heap = []
+ heap.extend(SideEffectLT(i, heap) for i in range(200))
+ with self.assertRaises((IndexError, RuntimeError)):
+ self.module.heappop_max(heap)
- def test_comparison_operator_modifiying_heap(self):
+ def test_comparison_operator_modifying_heap(self):
# See bpo-39421: Strong references need to be taken
# when comparing objects as they can alter the heap
class EvilClass(int):
@@ -444,7 +598,7 @@ class TestErrorHandling:
self.module.heappush(heap, EvilClass(0))
self.assertRaises(IndexError, self.module.heappushpop, heap, 1)
- def test_comparison_operator_modifiying_heap_two_heaps(self):
+ def test_comparison_operator_modifying_heap_two_heaps(self):
class h(int):
def __lt__(self, o):
@@ -464,6 +618,17 @@ class TestErrorHandling:
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))
+ list1, list2 = [], []
+
+ self.module.heappush_max(list1, h(0))
+ self.module.heappush_max(list2, g(0))
+ self.module.heappush_max(list1, g(1))
+ self.module.heappush_max(list2, h(1))
+
+ self.assertRaises((IndexError, RuntimeError), self.module.heappush_max, list1, g(1))
+ self.assertRaises((IndexError, RuntimeError), self.module.heappush_max, list2, h(1))
+
+
class TestErrorHandlingPython(TestErrorHandling, TestCase):
module = py_heapq
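
The new tests above cover the now-public max-heap counterparts of the classic heapq functions; a minimal usage sketch:

    import heapq

    data = [5, 1, 9, 3, 7]
    heapq.heapify_max(data)                    # largest element moves to data[0]
    top = heapq.heappop_max(data)              # 9
    heapq.heappush_max(data, 4)
    prev = heapq.heapreplace_max(data, 2)      # pops the current maximum, then pushes 2
    print(top, prev, data[0])
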
diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py
index 70c79437722..ff6e1bce0ef 100644
--- a/Lib/test/test_hmac.py
+++ b/Lib/test/test_hmac.py
@@ -1,3 +1,21 @@
+"""Test suite for HMAC.
+
+Python provides three different implementations of HMAC:
+
+- OpenSSL HMAC using OpenSSL hash functions.
+- HACL* HMAC using HACL* hash functions.
+- Generic Python HMAC using user-defined hash functions.
+
+The generic Python HMAC implementation can use OpenSSL callables
+or names, HACL* named hash functions, or arbitrary objects
+implementing the PEP 247 interface.
+
+In the first two cases, Python HMAC wraps a C HMAC object (either OpenSSL-
+or HACL*-based). As a last resort, HMAC falls back to a pure Python
+implementation. It is nevertheless interesting to test the pure Python
+implementation against the OpenSSL and HACL* hash functions.
+"""
+
import binascii
import functools
import hmac
@@ -10,6 +28,12 @@ import unittest.mock as mock
import warnings
from _operator import _compare_digest as operator_compare_digest
from test.support import check_disallow_instantiation
+from test.support.hashlib_helper import (
+ BuiltinHashFunctionsTrait,
+ HashFunctionsTrait,
+ NamedHashFunctionsTrait,
+ OpenSSLHashFunctionsTrait,
+)
from test.support.import_helper import import_fresh_module, import_module
try:
@@ -382,50 +406,7 @@ class BuiltinAssertersMixin(ThroughBuiltinAPIMixin, AssertersMixin):
pass
-class HashFunctionsTrait:
- """Trait class for 'hashfunc' in hmac_new() and hmac_digest()."""
-
- ALGORITHMS = [
- 'md5', 'sha1',
- 'sha224', 'sha256', 'sha384', 'sha512',
- 'sha3_224', 'sha3_256', 'sha3_384', 'sha3_512',
- ]
-
- # By default, a missing algorithm skips the test that uses it.
- _ = property(lambda self: self.skipTest("missing hash function"))
- md5 = sha1 = _
- sha224 = sha256 = sha384 = sha512 = _
- sha3_224 = sha3_256 = sha3_384 = sha3_512 = _
- del _
-
-
-class WithOpenSSLHashFunctions(HashFunctionsTrait):
- """Test a HMAC implementation with an OpenSSL-based callable 'hashfunc'."""
-
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
-
- for name in cls.ALGORITHMS:
- @property
- @hashlib_helper.requires_openssl_hashdigest(name)
- def func(self, *, __name=name): # __name needed to bind 'name'
- return getattr(_hashlib, f'openssl_{__name}')
- setattr(cls, name, func)
-
-
-class WithNamedHashFunctions(HashFunctionsTrait):
- """Test a HMAC implementation with a named 'hashfunc'."""
-
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
-
- for name in cls.ALGORITHMS:
- setattr(cls, name, name)
-
-
-class RFCTestCaseMixin(AssertersMixin, HashFunctionsTrait):
+class RFCTestCaseMixin(HashFunctionsTrait, AssertersMixin):
"""Test HMAC implementations against RFC 2202/4231 and NIST test vectors.
- Test vectors for MD5 and SHA-1 are taken from RFC 2202.
@@ -739,26 +720,83 @@ class RFCTestCaseMixin(AssertersMixin, HashFunctionsTrait):
)
-class PyRFCTestCase(ThroughObjectMixin, PyAssertersMixin,
- WithOpenSSLHashFunctions, RFCTestCaseMixin,
- unittest.TestCase):
+class PurePythonInitHMAC(PyModuleMixin, HashFunctionsTrait):
+
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ for meth in ['_init_openssl_hmac', '_init_builtin_hmac']:
+ fn = getattr(cls.hmac.HMAC, meth)
+ cm = mock.patch.object(cls.hmac.HMAC, meth, autospec=True, wraps=fn)
+ cls.enterClassContext(cm)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.hmac.HMAC._init_openssl_hmac.assert_not_called()
+ cls.hmac.HMAC._init_builtin_hmac.assert_not_called()
+ # Do not assert that HMAC._init_old() has been called, as it's tricky
+ # to determine whether a test for a specific hash function has been
+ # executed or not. On regular builds it will be called, but if a
+ # hash function is not available, it's hard to tell for which
+ # tests we should check that HMAC._init_old() was called.
+ super().tearDownClass()
+
+
+class PyRFCOpenSSLTestCase(ThroughObjectMixin,
+ PyAssertersMixin,
+ OpenSSLHashFunctionsTrait,
+ RFCTestCaseMixin,
+ PurePythonInitHMAC,
+ unittest.TestCase):
"""Python implementation of HMAC using hmac.HMAC().
- The underlying hash functions are OpenSSL-based.
+ The underlying hash functions are OpenSSL-based but
+ _init_old() is used instead of _init_openssl_hmac().
"""
-class PyDotNewRFCTestCase(ThroughModuleAPIMixin, PyAssertersMixin,
- WithOpenSSLHashFunctions, RFCTestCaseMixin,
- unittest.TestCase):
+class PyRFCBuiltinTestCase(ThroughObjectMixin,
+ PyAssertersMixin,
+ BuiltinHashFunctionsTrait,
+ RFCTestCaseMixin,
+ PurePythonInitHMAC,
+ unittest.TestCase):
+ """Python implementation of HMAC using hmac.HMAC().
+
+ The underlying hash functions are HACL*-based but
+ _init_old() is used instead of _init_builtin_hmac().
+ """
+
+
+class PyDotNewOpenSSLRFCTestCase(ThroughModuleAPIMixin,
+ PyAssertersMixin,
+ OpenSSLHashFunctionsTrait,
+ RFCTestCaseMixin,
+ PurePythonInitHMAC,
+ unittest.TestCase):
+ """Python implementation of HMAC using hmac.new().
+
+ The underlying hash functions are OpenSSL-based but
+ _init_old() is used instead of _init_openssl_hmac().
+ """
+
+
+class PyDotNewBuiltinRFCTestCase(ThroughModuleAPIMixin,
+ PyAssertersMixin,
+ BuiltinHashFunctionsTrait,
+ RFCTestCaseMixin,
+ PurePythonInitHMAC,
+ unittest.TestCase):
"""Python implementation of HMAC using hmac.new().
- The underlying hash functions are OpenSSL-based.
+ The underlying hash functions are HACL*-based but
+ _init_old() is used instead of _init_builtin_hmac().
"""
class OpenSSLRFCTestCase(OpenSSLAssertersMixin,
- WithOpenSSLHashFunctions, RFCTestCaseMixin,
+ OpenSSLHashFunctionsTrait,
+ RFCTestCaseMixin,
unittest.TestCase):
"""OpenSSL implementation of HMAC.
@@ -767,7 +805,8 @@ class OpenSSLRFCTestCase(OpenSSLAssertersMixin,
class BuiltinRFCTestCase(BuiltinAssertersMixin,
- WithNamedHashFunctions, RFCTestCaseMixin,
+ NamedHashFunctionsTrait,
+ RFCTestCaseMixin,
unittest.TestCase):
"""Built-in HACL* implementation of HMAC.
@@ -784,12 +823,6 @@ class BuiltinRFCTestCase(BuiltinAssertersMixin,
self.check_hmac_hexdigest(key, msg, hexdigest, digest_size, func)
-# TODO(picnixz): once we have a HACL* HMAC, we should also test the Python
-# implementation of HMAC with a HACL*-based hash function. For now, we only
-# test it partially via the '_sha2' module, but for completeness we could
-# also test the RFC test vectors against all possible implementations.
-
-
class DigestModTestCaseMixin(CreatorMixin, DigestMixin):
"""Tests for the 'digestmod' parameter for hmac_new() and hmac_digest()."""
diff --git a/Lib/test/test_htmlparser.py b/Lib/test/test_htmlparser.py
index b42a611c62c..65a4bee72b9 100644
--- a/Lib/test/test_htmlparser.py
+++ b/Lib/test/test_htmlparser.py
@@ -5,6 +5,7 @@ import pprint
import unittest
from unittest.mock import patch
+from test import support
class EventCollector(html.parser.HTMLParser):
@@ -317,6 +318,16 @@ text
("endtag", element_lower)],
collector=Collector(convert_charrefs=False))
+ def test_EOF_in_cdata(self):
+ content = """<!-- not a comment --> &not-an-entity-ref;
+ <a href="" /> </p><p> <span></span></style>
+ '</script' + '>'"""
+ s = f'<script>{content}'
+ self._run_check(s, [
+ ("starttag", 'script', []),
+ ("data", content)
+ ])
+
def test_comments(self):
html = ("<!-- I'm a valid comment -->"
'<!--me too!-->'
@@ -348,18 +359,16 @@ text
collector = lambda: EventCollectorCharrefs()
self.assertTrue(collector().convert_charrefs)
charrefs = ['&quot;', '&#34;', '&#x22;', '&quot', '&#34', '&#x22']
- # check charrefs in the middle of the text/attributes
- expected = [('starttag', 'a', [('href', 'foo"zar')]),
- ('data', 'a"z'), ('endtag', 'a')]
+ # check charrefs in the middle of the text
+ expected = [('starttag', 'a', []), ('data', 'a"z'), ('endtag', 'a')]
for charref in charrefs:
- self._run_check('<a href="foo{0}zar">a{0}z</a>'.format(charref),
+ self._run_check('<a>a{0}z</a>'.format(charref),
expected, collector=collector())
- # check charrefs at the beginning/end of the text/attributes
- expected = [('data', '"'),
- ('starttag', 'a', [('x', '"'), ('y', '"X'), ('z', 'X"')]),
+ # check charrefs at the beginning/end of the text
+ expected = [('data', '"'), ('starttag', 'a', []),
('data', '"'), ('endtag', 'a'), ('data', '"')]
for charref in charrefs:
- self._run_check('{0}<a x="{0}" y="{0}X" z="X{0}">'
+ self._run_check('{0}<a>'
'{0}</a>{0}'.format(charref),
expected, collector=collector())
# check charrefs in <script>/<style> elements
@@ -382,6 +391,35 @@ text
self._run_check('no charrefs here', [('data', 'no charrefs here')],
collector=collector())
+ def test_convert_charrefs_in_attribute_values(self):
+ # default value for convert_charrefs is now True
+ collector = lambda: EventCollectorCharrefs()
+ self.assertTrue(collector().convert_charrefs)
+
+ # always unescape terminated entity refs, numeric and hex char refs:
+ # - regardless whether they are at start, middle, end of attribute
+ # - or followed by alphanumeric, non-alphanumeric, or equals char
+ charrefs = ['&cent;', '&#xa2;', '&#xa2', '&#162;', '&#162']
+ expected = [('starttag', 'a',
+ [('x', '¢'), ('x', 'z¢'), ('x', '¢z'),
+ ('x', 'z¢z'), ('x', '¢ z'), ('x', '¢=z')]),
+ ('endtag', 'a')]
+ for charref in charrefs:
+ self._run_check('<a x="{0}" x="z{0}" x="{0}z" '
+ ' x="z{0}z" x="{0} z" x="{0}=z"></a>'
+ .format(charref), expected, collector=collector())
+
+ # only unescape unterminated entity matches if they are not followed by
+ # an alphanumeric or an equals sign
+ charref = '&cent'
+ expected = [('starttag', 'a',
+ [('x', '¢'), ('x', 'z¢'), ('x', '&centz'),
+ ('x', 'z&centz'), ('x', '¢ z'), ('x', '&cent=z')]),
+ ('endtag', 'a')]
+ self._run_check('<a x="{0}" x="z{0}" x="{0}z" '
+ ' x="z{0}z" x="{0} z" x="{0}=z"></a>'
+ .format(charref), expected, collector=collector())
+
# the remaining tests were for the "tolerant" parser (which is now
# the default), and check various kind of broken markup
def test_tolerant_parsing(self):
@@ -393,28 +431,34 @@ text
('data', '<'),
('starttag', 'bc<', [('a', None)]),
('endtag', 'html'),
- ('data', '\n<img src="URL>'),
- ('comment', '/img'),
- ('endtag', 'html<')])
+ ('data', '\n')])
def test_starttag_junk_chars(self):
+ self._run_check("<", [('data', '<')])
+ self._run_check("<>", [('data', '<>')])
+ self._run_check("< >", [('data', '< >')])
+ self._run_check("< ", [('data', '< ')])
self._run_check("</>", [])
+ self._run_check("<$>", [('data', '<$>')])
self._run_check("</$>", [('comment', '$')])
self._run_check("</", [('data', '</')])
- self._run_check("</a", [('data', '</a')])
+ self._run_check("</a", [])
+ self._run_check("</ a>", [('endtag', 'a')])
+ self._run_check("</ a", [('comment', ' a')])
self._run_check("<a<a>", [('starttag', 'a<a', [])])
self._run_check("</a<a>", [('endtag', 'a<a')])
- self._run_check("<!", [('data', '<!')])
- self._run_check("<a", [('data', '<a')])
- self._run_check("<a foo='bar'", [('data', "<a foo='bar'")])
- self._run_check("<a foo='bar", [('data', "<a foo='bar")])
- self._run_check("<a foo='>'", [('data', "<a foo='>'")])
- self._run_check("<a foo='>", [('data', "<a foo='>")])
+ self._run_check("<!", [('comment', '')])
+ self._run_check("<a", [])
+ self._run_check("<a foo='bar'", [])
+ self._run_check("<a foo='bar", [])
+ self._run_check("<a foo='>'", [])
+ self._run_check("<a foo='>", [])
self._run_check("<a$>", [('starttag', 'a$', [])])
self._run_check("<a$b>", [('starttag', 'a$b', [])])
self._run_check("<a$b/>", [('startendtag', 'a$b', [])])
self._run_check("<a$b >", [('starttag', 'a$b', [])])
self._run_check("<a$b />", [('startendtag', 'a$b', [])])
+ self._run_check("</a$b>", [('endtag', 'a$b')])
def test_slashes_in_starttag(self):
self._run_check('<a foo="var"/>', [('startendtag', 'a', [('foo', 'var')])])
@@ -539,52 +583,129 @@ text
for html, expected in data:
self._run_check(html, expected)
- def test_broken_comments(self):
- html = ('<! not really a comment >'
+ def test_eof_in_comments(self):
+ data = [
+ ('<!--', [('comment', '')]),
+ ('<!---', [('comment', '')]),
+ ('<!----', [('comment', '')]),
+ ('<!-----', [('comment', '-')]),
+ ('<!------', [('comment', '--')]),
+ ('<!----!', [('comment', '')]),
+ ('<!---!', [('comment', '-!')]),
+ ('<!---!>', [('comment', '-!>')]),
+ ('<!--foo', [('comment', 'foo')]),
+ ('<!--foo-', [('comment', 'foo')]),
+ ('<!--foo--', [('comment', 'foo')]),
+ ('<!--foo--!', [('comment', 'foo')]),
+ ('<!--<!--', [('comment', '<!')]),
+ ('<!--<!--!', [('comment', '<!')]),
+ ]
+ for html, expected in data:
+ self._run_check(html, expected)
+
+ def test_eof_in_declarations(self):
+ data = [
+ ('<!', [('comment', '')]),
+ ('<!-', [('comment', '-')]),
+ ('<![', [('comment', '[')]),
+ ('<![CDATA[', [('unknown decl', 'CDATA[')]),
+ ('<![CDATA[x', [('unknown decl', 'CDATA[x')]),
+ ('<![CDATA[x]', [('unknown decl', 'CDATA[x]')]),
+ ('<![CDATA[x]]', [('unknown decl', 'CDATA[x]]')]),
+ ('<!DOCTYPE', [('decl', 'DOCTYPE')]),
+ ('<!DOCTYPE ', [('decl', 'DOCTYPE ')]),
+ ('<!DOCTYPE html', [('decl', 'DOCTYPE html')]),
+ ('<!DOCTYPE html ', [('decl', 'DOCTYPE html ')]),
+ ('<!DOCTYPE html PUBLIC', [('decl', 'DOCTYPE html PUBLIC')]),
+ ('<!DOCTYPE html PUBLIC "foo', [('decl', 'DOCTYPE html PUBLIC "foo')]),
+ ('<!DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01//EN" "foo',
+ [('decl', 'DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01//EN" "foo')]),
+ ]
+ for html, expected in data:
+ self._run_check(html, expected)
+
+ def test_bogus_comments(self):
+ html = ('<!ELEMENT br EMPTY>'
+ '<! not really a comment >'
'<! not a comment either -->'
'<! -- close enough -->'
'<!><!<-- this was an empty comment>'
- '<!!! another bogus comment !!!>')
+ '<!!! another bogus comment !!!>'
+ # see #32876
+ '<![with square brackets]!>'
+ '<![\nmultiline\nbogusness\n]!>'
+ '<![more brackets]-[and a hyphen]!>'
+ '<![cdata[should be uppercase]]>'
+ '<![CDATA [whitespaces are not ignored]]>'
+ '<![CDATA]]>' # required '[' after CDATA
+ )
expected = [
+ ('comment', 'ELEMENT br EMPTY'),
('comment', ' not really a comment '),
('comment', ' not a comment either --'),
('comment', ' -- close enough --'),
('comment', ''),
('comment', '<-- this was an empty comment'),
('comment', '!! another bogus comment !!!'),
+ ('comment', '[with square brackets]!'),
+ ('comment', '[\nmultiline\nbogusness\n]!'),
+ ('comment', '[more brackets]-[and a hyphen]!'),
+ ('comment', '[cdata[should be uppercase]]'),
+ ('comment', '[CDATA [whitespaces are not ignored]]'),
+ ('comment', '[CDATA]]'),
]
self._run_check(html, expected)
def test_broken_condcoms(self):
# these condcoms are missing the '--' after '<!' and before the '>'
+ # and they are considered bogus comments according to
+ # "8.2.4.42. Markup declaration open state"
html = ('<![if !(IE)]>broken condcom<![endif]>'
'<![if ! IE]><link href="favicon.tiff"/><![endif]>'
'<![if !IE 6]><img src="firefox.png" /><![endif]>'
'<![if !ie 6]><b>foo</b><![endif]>'
'<![if (!IE)|(lt IE 9)]><img src="mammoth.bmp" /><![endif]>')
- # According to the HTML5 specs sections "8.2.4.44 Bogus comment state"
- # and "8.2.4.45 Markup declaration open state", comment tokens should
- # be emitted instead of 'unknown decl', but calling unknown_decl
- # provides more flexibility.
- # See also Lib/_markupbase.py:parse_declaration
expected = [
- ('unknown decl', 'if !(IE)'),
+ ('comment', '[if !(IE)]'),
('data', 'broken condcom'),
- ('unknown decl', 'endif'),
- ('unknown decl', 'if ! IE'),
+ ('comment', '[endif]'),
+ ('comment', '[if ! IE]'),
('startendtag', 'link', [('href', 'favicon.tiff')]),
- ('unknown decl', 'endif'),
- ('unknown decl', 'if !IE 6'),
+ ('comment', '[endif]'),
+ ('comment', '[if !IE 6]'),
('startendtag', 'img', [('src', 'firefox.png')]),
- ('unknown decl', 'endif'),
- ('unknown decl', 'if !ie 6'),
+ ('comment', '[endif]'),
+ ('comment', '[if !ie 6]'),
('starttag', 'b', []),
('data', 'foo'),
('endtag', 'b'),
- ('unknown decl', 'endif'),
- ('unknown decl', 'if (!IE)|(lt IE 9)'),
+ ('comment', '[endif]'),
+ ('comment', '[if (!IE)|(lt IE 9)]'),
('startendtag', 'img', [('src', 'mammoth.bmp')]),
- ('unknown decl', 'endif')
+ ('comment', '[endif]')
+ ]
+ self._run_check(html, expected)
+
+ def test_cdata_declarations(self):
+ # More tests should be added. See also "8.2.4.42. Markup
+ # declaration open state", "8.2.4.69. CDATA section state",
+ # and issue 32876
+ html = ('<![CDATA[just some plain text]]>')
+ expected = [('unknown decl', 'CDATA[just some plain text')]
+ self._run_check(html, expected)
+
+ def test_cdata_declarations_multiline(self):
+ html = ('<code><![CDATA['
+ ' if (a < b && a > b) {'
+ ' printf("[<marquee>How?</marquee>]");'
+ ' }'
+ ']]></code>')
+ expected = [
+ ('starttag', 'code', []),
+ ('unknown decl',
+ 'CDATA[ if (a < b && a > b) { '
+ 'printf("[<marquee>How?</marquee>]"); }'),
+ ('endtag', 'code')
]
self._run_check(html, expected)
@@ -600,6 +721,26 @@ text
('endtag', 'a'), ('data', ' bar & baz')]
)
+ @support.requires_resource('cpu')
+ def test_eof_no_quadratic_complexity(self):
+ # Each of these examples used to take about an hour.
+ # Now they take a fraction of a second.
+ def check(source):
+ parser = html.parser.HTMLParser()
+ parser.feed(source)
+ parser.close()
+ n = 120_000
+ check("<a " * n)
+ check("<a a=" * n)
+ check("</a " * 14 * n)
+ check("</a a=" * 11 * n)
+ check("<!--" * 4 * n)
+ check("<!" * 60 * n)
+ check("<?" * 19 * n)
+ check("</$" * 15 * n)
+ check("<![CDATA[" * 9 * n)
+ check("<!doctype" * 35 * n)
+
class AttributesTestCase(TestCaseBase):
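
For reference, a minimal standalone sketch (not part of the patch) of the EOF behaviour the new html.parser tests above assert: truncated markup fed to HTMLParser is flushed as comment/declaration events when close() is called. The collector class and sample input are illustrative; on an interpreter without this patch the emitted events may differ.

from html.parser import HTMLParser

class EventCollector(HTMLParser):
    def __init__(self):
        super().__init__(convert_charrefs=True)
        self.events = []
    def handle_comment(self, data):
        self.events.append(('comment', data))
    def unknown_decl(self, data):
        self.events.append(('unknown decl', data))

parser = EventCollector()
parser.feed('<!--foo')   # truncated comment, no closing '-->'
parser.close()           # EOF flushes it; the patched parser reports ('comment', 'foo')
print(parser.events)
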
diff --git a/Lib/test/test_http_cookiejar.py b/Lib/test/test_http_cookiejar.py
index 6bc33b15ec3..04cb440cd4c 100644
--- a/Lib/test/test_http_cookiejar.py
+++ b/Lib/test/test_http_cookiejar.py
@@ -4,6 +4,7 @@ import os
import stat
import sys
import re
+from test import support
from test.support import os_helper
from test.support import warnings_helper
import time
@@ -105,8 +106,7 @@ class DateTimeTests(unittest.TestCase):
self.assertEqual(http2time(s.lower()), test_t, s.lower())
self.assertEqual(http2time(s.upper()), test_t, s.upper())
- def test_http2time_garbage(self):
- for test in [
+ @support.subTests('test', [
'',
'Garbage',
'Mandag 16. September 1996',
@@ -121,10 +121,9 @@ class DateTimeTests(unittest.TestCase):
'08-01-3697739',
'09 Feb 19942632 22:23:32 GMT',
'Wed, 09 Feb 1994834 22:23:32 GMT',
- ]:
- self.assertIsNone(http2time(test),
- "http2time(%s) is not None\n"
- "http2time(test) %s" % (test, http2time(test)))
+ ])
+ def test_http2time_garbage(self, test):
+ self.assertIsNone(http2time(test))
def test_http2time_redos_regression_actually_completes(self):
# LOOSE_HTTP_DATE_RE was vulnerable to malicious input which caused catastrophic backtracking (REDoS).
@@ -149,9 +148,7 @@ class DateTimeTests(unittest.TestCase):
self.assertEqual(parse_date("1994-02-03 19:45:29 +0530"),
(1994, 2, 3, 14, 15, 29))
- def test_iso2time_formats(self):
- # test iso2time for supported dates.
- tests = [
+ @support.subTests('s', [
'1994-02-03 00:00:00 -0000', # ISO 8601 format
'1994-02-03 00:00:00 +0000', # ISO 8601 format
'1994-02-03 00:00:00', # zone is optional
@@ -164,16 +161,15 @@ class DateTimeTests(unittest.TestCase):
# A few tests with extra space at various places
' 1994-02-03 ',
' 1994-02-03T00:00:00 ',
- ]
-
+ ])
+ def test_iso2time_formats(self, s):
+ # test iso2time for supported dates.
test_t = 760233600 # assume broken POSIX counting of seconds
- for s in tests:
- self.assertEqual(iso2time(s), test_t, s)
- self.assertEqual(iso2time(s.lower()), test_t, s.lower())
- self.assertEqual(iso2time(s.upper()), test_t, s.upper())
+ self.assertEqual(iso2time(s), test_t, s)
+ self.assertEqual(iso2time(s.lower()), test_t, s.lower())
+ self.assertEqual(iso2time(s.upper()), test_t, s.upper())
- def test_iso2time_garbage(self):
- for test in [
+ @support.subTests('test', [
'',
'Garbage',
'Thursday, 03-Feb-94 00:00:00 GMT',
@@ -186,9 +182,9 @@ class DateTimeTests(unittest.TestCase):
'01-01-1980 00:00:62',
'01-01-1980T00:00:62',
'19800101T250000Z',
- ]:
- self.assertIsNone(iso2time(test),
- "iso2time(%r)" % test)
+ ])
+ def test_iso2time_garbage(self, test):
+ self.assertIsNone(iso2time(test))
def test_iso2time_performance_regression(self):
# If ISO_DATE_RE regresses to quadratic complexity, this test will take a very long time to succeed.
@@ -199,24 +195,23 @@ class DateTimeTests(unittest.TestCase):
class HeaderTests(unittest.TestCase):
- def test_parse_ns_headers(self):
- # quotes should be stripped
- expected = [[('foo', 'bar'), ('expires', 2209069412), ('version', '0')]]
- for hdr in [
+ @support.subTests('hdr', [
'foo=bar; expires=01 Jan 2040 22:23:32 GMT',
'foo=bar; expires="01 Jan 2040 22:23:32 GMT"',
- ]:
- self.assertEqual(parse_ns_headers([hdr]), expected)
-
- def test_parse_ns_headers_version(self):
-
+ ])
+ def test_parse_ns_headers(self, hdr):
# quotes should be stripped
- expected = [[('foo', 'bar'), ('version', '1')]]
- for hdr in [
+ expected = [[('foo', 'bar'), ('expires', 2209069412), ('version', '0')]]
+ self.assertEqual(parse_ns_headers([hdr]), expected)
+
+ @support.subTests('hdr', [
'foo=bar; version="1"',
'foo=bar; Version="1"',
- ]:
- self.assertEqual(parse_ns_headers([hdr]), expected)
+ ])
+ def test_parse_ns_headers_version(self, hdr):
+ # quotes should be stripped
+ expected = [[('foo', 'bar'), ('version', '1')]]
+ self.assertEqual(parse_ns_headers([hdr]), expected)
def test_parse_ns_headers_special_names(self):
# names such as 'expires' are not special in first name=value pair
@@ -226,8 +221,7 @@ class HeaderTests(unittest.TestCase):
expected = [[("expires", "01 Jan 2040 22:23:32 GMT"), ("version", "0")]]
self.assertEqual(parse_ns_headers([hdr]), expected)
- def test_join_header_words(self):
- for src, expected in [
+ @support.subTests('src,expected', [
([[("foo", None), ("bar", "baz")]], "foo; bar=baz"),
(([]), ""),
(([[]]), ""),
@@ -237,12 +231,11 @@ class HeaderTests(unittest.TestCase):
'n; foo="foo;_", bar=foo_bar'),
([[("n", "m"), ("foo", None)], [("bar", "foo_bar")]],
'n=m; foo, bar=foo_bar'),
- ]:
- with self.subTest(src=src):
- self.assertEqual(join_header_words(src), expected)
+ ])
+ def test_join_header_words(self, src, expected):
+ self.assertEqual(join_header_words(src), expected)
- def test_split_header_words(self):
- tests = [
+ @support.subTests('arg,expect', [
("foo", [[("foo", None)]]),
("foo=bar", [[("foo", "bar")]]),
(" foo ", [[("foo", None)]]),
@@ -259,24 +252,22 @@ class HeaderTests(unittest.TestCase):
(r'foo; bar=baz, spam=, foo="\,\;\"", bar= ',
[[("foo", None), ("bar", "baz")],
[("spam", "")], [("foo", ',;"')], [("bar", "")]]),
- ]
-
- for arg, expect in tests:
- try:
- result = split_header_words([arg])
- except:
- import traceback, io
- f = io.StringIO()
- traceback.print_exc(None, f)
- result = "(error -- traceback follows)\n\n%s" % f.getvalue()
- self.assertEqual(result, expect, """
+ ])
+ def test_split_header_words(self, arg, expect):
+ try:
+ result = split_header_words([arg])
+ except:
+ import traceback, io
+ f = io.StringIO()
+ traceback.print_exc(None, f)
+ result = "(error -- traceback follows)\n\n%s" % f.getvalue()
+ self.assertEqual(result, expect, """
When parsing: '%s'
Expected: '%s'
Got: '%s'
""" % (arg, expect, result))
- def test_roundtrip(self):
- tests = [
+ @support.subTests('arg,expect', [
("foo", "foo"),
("foo=bar", "foo=bar"),
(" foo ", "foo"),
@@ -309,12 +300,11 @@ Got: '%s'
('n; foo="foo;_", bar="foo,_"',
'n; foo="foo;_", bar="foo,_"'),
- ]
-
- for arg, expect in tests:
- input = split_header_words([arg])
- res = join_header_words(input)
- self.assertEqual(res, expect, """
+ ])
+ def test_roundtrip(self, arg, expect):
+ input = split_header_words([arg])
+ res = join_header_words(input)
+ self.assertEqual(res, expect, """
When parsing: '%s'
Expected: '%s'
Got: '%s'
@@ -516,14 +506,7 @@ class CookieTests(unittest.TestCase):
## just the 7 special TLD's listed in their spec. And folks rely on
## that...
- def test_domain_return_ok(self):
- # test optimization: .domain_return_ok() should filter out most
- # domains in the CookieJar before we try to access them (because that
- # may require disk access -- in particular, with MSIECookieJar)
- # This is only a rough check for performance reasons, so it's not too
- # critical as long as it's sufficiently liberal.
- pol = DefaultCookiePolicy()
- for url, domain, ok in [
+ @support.subTests('url,domain,ok', [
("http://foo.bar.com/", "blah.com", False),
("http://foo.bar.com/", "rhubarb.blah.com", False),
("http://foo.bar.com/", "rhubarb.foo.bar.com", False),
@@ -543,11 +526,18 @@ class CookieTests(unittest.TestCase):
("http://foo/", ".local", True),
("http://barfoo.com", ".foo.com", False),
("http://barfoo.com", "foo.com", False),
- ]:
- request = urllib.request.Request(url)
- r = pol.domain_return_ok(domain, request)
- if ok: self.assertTrue(r)
- else: self.assertFalse(r)
+ ])
+ def test_domain_return_ok(self, url, domain, ok):
+ # test optimization: .domain_return_ok() should filter out most
+ # domains in the CookieJar before we try to access them (because that
+ # may require disk access -- in particular, with MSIECookieJar)
+ # This is only a rough check for performance reasons, so it's not too
+ # critical as long as it's sufficiently liberal.
+ pol = DefaultCookiePolicy()
+ request = urllib.request.Request(url)
+ r = pol.domain_return_ok(domain, request)
+ if ok: self.assertTrue(r)
+ else: self.assertFalse(r)
def test_missing_value(self):
# missing = sign in Cookie: header is regarded by Mozilla as a missing
@@ -581,10 +571,7 @@ class CookieTests(unittest.TestCase):
self.assertEqual(interact_netscape(c, "http://www.acme.com/foo/"),
'"spam"; eggs')
- def test_rfc2109_handling(self):
- # RFC 2109 cookies are handled as RFC 2965 or Netscape cookies,
- # dependent on policy settings
- for rfc2109_as_netscape, rfc2965, version in [
+ @support.subTests('rfc2109_as_netscape,rfc2965,version', [
# default according to rfc2965 if not explicitly specified
(None, False, 0),
(None, True, 1),
@@ -593,24 +580,27 @@ class CookieTests(unittest.TestCase):
(False, True, 1),
(True, False, 0),
(True, True, 0),
- ]:
- policy = DefaultCookiePolicy(
- rfc2109_as_netscape=rfc2109_as_netscape,
- rfc2965=rfc2965)
- c = CookieJar(policy)
- interact_netscape(c, "http://www.example.com/", "ni=ni; Version=1")
- try:
- cookie = c._cookies["www.example.com"]["/"]["ni"]
- except KeyError:
- self.assertIsNone(version) # didn't expect a stored cookie
- else:
- self.assertEqual(cookie.version, version)
- # 2965 cookies are unaffected
- interact_2965(c, "http://www.example.com/",
- "foo=bar; Version=1")
- if rfc2965:
- cookie2965 = c._cookies["www.example.com"]["/"]["foo"]
- self.assertEqual(cookie2965.version, 1)
+ ])
+ def test_rfc2109_handling(self, rfc2109_as_netscape, rfc2965, version):
+ # RFC 2109 cookies are handled as RFC 2965 or Netscape cookies,
+ # dependent on policy settings
+ policy = DefaultCookiePolicy(
+ rfc2109_as_netscape=rfc2109_as_netscape,
+ rfc2965=rfc2965)
+ c = CookieJar(policy)
+ interact_netscape(c, "http://www.example.com/", "ni=ni; Version=1")
+ try:
+ cookie = c._cookies["www.example.com"]["/"]["ni"]
+ except KeyError:
+ self.assertIsNone(version) # didn't expect a stored cookie
+ else:
+ self.assertEqual(cookie.version, version)
+ # 2965 cookies are unaffected
+ interact_2965(c, "http://www.example.com/",
+ "foo=bar; Version=1")
+ if rfc2965:
+ cookie2965 = c._cookies["www.example.com"]["/"]["foo"]
+ self.assertEqual(cookie2965.version, 1)
def test_ns_parser(self):
c = CookieJar()
@@ -778,8 +768,7 @@ class CookieTests(unittest.TestCase):
# Cookie is sent back to the same URI.
self.assertEqual(interact_netscape(cj, uri), value)
- def test_escape_path(self):
- cases = [
+ @support.subTests('arg,result', [
# quoted safe
("/foo%2f/bar", "/foo%2F/bar"),
("/foo%2F/bar", "/foo%2F/bar"),
@@ -799,9 +788,9 @@ class CookieTests(unittest.TestCase):
("/foo/bar\u00fc", "/foo/bar%C3%BC"), # UTF-8 encoded
# unicode
("/foo/bar\uabcd", "/foo/bar%EA%AF%8D"), # UTF-8 encoded
- ]
- for arg, result in cases:
- self.assertEqual(escape_path(arg), result)
+ ])
+ def test_escape_path(self, arg, result):
+ self.assertEqual(escape_path(arg), result)
def test_request_path(self):
# with parameters
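
For reference, a minimal sketch of the parametrization pattern the test_http_cookiejar hunks above apply: a for-loop inside one test method becomes a @support.subTests decorator, and the method receives each parameter set as arguments. The decorator is used exactly as it appears in this diff (it assumes the test.support version from this branch); the test body below is made up for illustration.

import unittest
from test import support

class ExampleTests(unittest.TestCase):
    # Each (value, expected) pair runs as its own subtest, replacing the
    # old "for value, expected in cases:" loop inside a single test method.
    @support.subTests('value,expected', [
        ('1', 1),
        ('0x10', 16),
        ('-3', -3),
    ])
    def test_int_parsing(self, value, expected):
        self.assertEqual(int(value, 0), expected)

if __name__ == '__main__':
    unittest.main()
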
diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py
index 2cafa4e45a1..2548a7c5f29 100644
--- a/Lib/test/test_httpservers.py
+++ b/Lib/test/test_httpservers.py
@@ -3,16 +3,16 @@
Written by Cody A.W. Somerville <cody-somerville@ubuntu.com>,
Josip Dzolonga, and Michael Otteneder for the 2007/08 GHOP contest.
"""
-from collections import OrderedDict
+
from http.server import BaseHTTPRequestHandler, HTTPServer, HTTPSServer, \
- SimpleHTTPRequestHandler, CGIHTTPRequestHandler
+ SimpleHTTPRequestHandler
from http import server, HTTPStatus
+import contextlib
import os
import socket
import sys
import re
-import base64
import ntpath
import pathlib
import shutil
@@ -21,6 +21,7 @@ import email.utils
import html
import http, http.client
import urllib.parse
+import urllib.request
import tempfile
import time
import datetime
@@ -31,8 +32,10 @@ from io import BytesIO, StringIO
import unittest
from test import support
from test.support import (
- is_apple, import_helper, os_helper, requires_subprocess, threading_helper
+ is_apple, import_helper, os_helper, threading_helper
)
+from test.support.script_helper import kill_python, spawn_python
+from test.support.socket_helper import find_unused_port
try:
import ssl
@@ -522,42 +525,120 @@ class SimpleHTTPServerTestCase(BaseTestCase):
reader.close()
return body
+ def check_list_dir_dirname(self, dirname, quotedname=None):
+ fullpath = os.path.join(self.tempdir, dirname)
+ try:
+ os.mkdir(os.path.join(self.tempdir, dirname))
+ except (OSError, UnicodeEncodeError):
+ self.skipTest(f'Can not create directory {dirname!a} '
+ f'on current file system')
+
+ if quotedname is None:
+ quotedname = urllib.parse.quote(dirname, errors='surrogatepass')
+ response = self.request(self.base_url + '/' + quotedname + '/')
+ body = self.check_status_and_reason(response, HTTPStatus.OK)
+ displaypath = html.escape(f'{self.base_url}/{dirname}/', quote=False)
+ enc = sys.getfilesystemencoding()
+ prefix = f'listing for {displaypath}</'.encode(enc, 'surrogateescape')
+ self.assertIn(prefix + b'title>', body)
+ self.assertIn(prefix + b'h1>', body)
+
+ def check_list_dir_filename(self, filename):
+ fullpath = os.path.join(self.tempdir, filename)
+ content = ascii(fullpath).encode() + (os_helper.TESTFN_UNDECODABLE or b'\xff')
+ try:
+ with open(fullpath, 'wb') as f:
+ f.write(content)
+ except OSError:
+ self.skipTest(f'Can not create file {filename!a} '
+ f'on current file system')
+
+ response = self.request(self.base_url + '/')
+ body = self.check_status_and_reason(response, HTTPStatus.OK)
+ quotedname = urllib.parse.quote(filename, errors='surrogatepass')
+ enc = response.headers.get_content_charset()
+ self.assertIsNotNone(enc)
+ self.assertIn((f'href="{quotedname}"').encode('ascii'), body)
+ displayname = html.escape(filename, quote=False)
+ self.assertIn(f'>{displayname}<'.encode(enc, 'surrogateescape'), body)
+
+ response = self.request(self.base_url + '/' + quotedname)
+ self.check_status_and_reason(response, HTTPStatus.OK, data=content)
+
+ @unittest.skipUnless(os_helper.TESTFN_NONASCII,
+ 'need os_helper.TESTFN_NONASCII')
+ def test_list_dir_nonascii_dirname(self):
+ dirname = os_helper.TESTFN_NONASCII + '.dir'
+ self.check_list_dir_dirname(dirname)
+
+ @unittest.skipUnless(os_helper.TESTFN_NONASCII,
+ 'need os_helper.TESTFN_NONASCII')
+ def test_list_dir_nonascii_filename(self):
+ filename = os_helper.TESTFN_NONASCII + '.txt'
+ self.check_list_dir_filename(filename)
+
@unittest.skipIf(is_apple,
'undecodable name cannot always be decoded on Apple platforms')
@unittest.skipIf(sys.platform == 'win32',
'undecodable name cannot be decoded on win32')
@unittest.skipUnless(os_helper.TESTFN_UNDECODABLE,
'need os_helper.TESTFN_UNDECODABLE')
- def test_undecodable_filename(self):
- enc = sys.getfilesystemencoding()
- filename = os.fsdecode(os_helper.TESTFN_UNDECODABLE) + '.txt'
- with open(os.path.join(self.tempdir, filename), 'wb') as f:
- f.write(os_helper.TESTFN_UNDECODABLE)
- response = self.request(self.base_url + '/')
- if is_apple:
- # On Apple platforms the HFS+ filesystem replaces bytes that
- # aren't valid UTF-8 into a percent-encoded value.
- for name in os.listdir(self.tempdir):
- if name != 'test': # Ignore a filename created in setUp().
- filename = name
- break
- body = self.check_status_and_reason(response, HTTPStatus.OK)
- quotedname = urllib.parse.quote(filename, errors='surrogatepass')
- self.assertIn(('href="%s"' % quotedname)
- .encode(enc, 'surrogateescape'), body)
- self.assertIn(('>%s<' % html.escape(filename, quote=False))
- .encode(enc, 'surrogateescape'), body)
- response = self.request(self.base_url + '/' + quotedname)
- self.check_status_and_reason(response, HTTPStatus.OK,
- data=os_helper.TESTFN_UNDECODABLE)
+ def test_list_dir_undecodable_dirname(self):
+ dirname = os.fsdecode(os_helper.TESTFN_UNDECODABLE) + '.dir'
+ self.check_list_dir_dirname(dirname)
- def test_undecodable_parameter(self):
- # sanity check using a valid parameter
+ @unittest.skipIf(is_apple,
+ 'undecodable name cannot always be decoded on Apple platforms')
+ @unittest.skipIf(sys.platform == 'win32',
+ 'undecodable name cannot be decoded on win32')
+ @unittest.skipUnless(os_helper.TESTFN_UNDECODABLE,
+ 'need os_helper.TESTFN_UNDECODABLE')
+ def test_list_dir_undecodable_filename(self):
+ filename = os.fsdecode(os_helper.TESTFN_UNDECODABLE) + '.txt'
+ self.check_list_dir_filename(filename)
+
+ def test_list_dir_undecodable_dirname2(self):
+ dirname = '\ufffd.dir'
+ self.check_list_dir_dirname(dirname, quotedname='%ff.dir')
+
+ @unittest.skipUnless(os_helper.TESTFN_UNENCODABLE,
+ 'need os_helper.TESTFN_UNENCODABLE')
+ def test_list_dir_unencodable_dirname(self):
+ dirname = os_helper.TESTFN_UNENCODABLE + '.dir'
+ self.check_list_dir_dirname(dirname)
+
+ @unittest.skipUnless(os_helper.TESTFN_UNENCODABLE,
+ 'need os_helper.TESTFN_UNENCODABLE')
+ def test_list_dir_unencodable_filename(self):
+ filename = os_helper.TESTFN_UNENCODABLE + '.txt'
+ self.check_list_dir_filename(filename)
+
+ def test_list_dir_escape_dirname(self):
+ # Characters that need special treatment in URL or HTML.
+ for name in ('q?', 'f#', '&amp;', '&amp', '<i>', '"dq"', "'sq'",
+ '%A4', '%E2%82%AC'):
+ with self.subTest(name=name):
+ dirname = name + '.dir'
+ self.check_list_dir_dirname(dirname,
+ quotedname=urllib.parse.quote(dirname, safe='&<>\'"'))
+
+ def test_list_dir_escape_filename(self):
+ # Characters that need special treatment in URL or HTML.
+ for name in ('q?', 'f#', '&amp;', '&amp', '<i>', '"dq"', "'sq'",
+ '%A4', '%E2%82%AC'):
+ with self.subTest(name=name):
+ filename = name + '.txt'
+ self.check_list_dir_filename(filename)
+ os_helper.unlink(os.path.join(self.tempdir, filename))
+
+ def test_list_dir_with_query_and_fragment(self):
+ prefix = f'listing for {self.base_url}/</'.encode('latin1')
+ response = self.request(self.base_url + '/#123').read()
+ self.assertIn(prefix + b'title>', response)
+ self.assertIn(prefix + b'h1>', response)
response = self.request(self.base_url + '/?x=123').read()
- self.assertRegex(response, rf'listing for {self.base_url}/\?x=123'.encode('latin1'))
- # now the bogus encoding
- response = self.request(self.base_url + '/?x=%bb').read()
- self.assertRegex(response, rf'listing for {self.base_url}/\?x=\xef\xbf\xbd'.encode('latin1'))
+ self.assertIn(prefix + b'title>', response)
+ self.assertIn(prefix + b'h1>', response)
def test_get_dir_redirect_location_domain_injection_bug(self):
"""Ensure //evil.co/..%2f../../X does not put //evil.co/ in Location.
@@ -615,10 +696,19 @@ class SimpleHTTPServerTestCase(BaseTestCase):
# check for trailing "/" which should return 404. See Issue17324
response = self.request(self.base_url + '/test/')
self.check_status_and_reason(response, HTTPStatus.NOT_FOUND)
+ response = self.request(self.base_url + '/test%2f')
+ self.check_status_and_reason(response, HTTPStatus.NOT_FOUND)
+ response = self.request(self.base_url + '/test%2F')
+ self.check_status_and_reason(response, HTTPStatus.NOT_FOUND)
response = self.request(self.base_url + '/')
self.check_status_and_reason(response, HTTPStatus.OK)
+ response = self.request(self.base_url + '%2f')
+ self.check_status_and_reason(response, HTTPStatus.OK)
+ response = self.request(self.base_url + '%2F')
+ self.check_status_and_reason(response, HTTPStatus.OK)
response = self.request(self.base_url)
self.check_status_and_reason(response, HTTPStatus.MOVED_PERMANENTLY)
+ self.assertEqual(response.getheader("Location"), self.base_url + "/")
self.assertEqual(response.getheader("Content-Length"), "0")
response = self.request(self.base_url + '/?hi=2')
self.check_status_and_reason(response, HTTPStatus.OK)
@@ -724,6 +814,8 @@ class SimpleHTTPServerTestCase(BaseTestCase):
self.check_status_and_reason(response, HTTPStatus.OK)
response = self.request(self.tempdir_name)
self.check_status_and_reason(response, HTTPStatus.MOVED_PERMANENTLY)
+ self.assertEqual(response.getheader("Location"),
+ self.tempdir_name + "/")
response = self.request(self.tempdir_name + '/?hi=2')
self.check_status_and_reason(response, HTTPStatus.OK)
response = self.request(self.tempdir_name + '?hi=1')
@@ -731,350 +823,6 @@ class SimpleHTTPServerTestCase(BaseTestCase):
self.assertEqual(response.getheader("Location"),
self.tempdir_name + "/?hi=1")
- def test_html_escape_filename(self):
- filename = '<test&>.txt'
- fullpath = os.path.join(self.tempdir, filename)
-
- try:
- open(fullpath, 'wb').close()
- except OSError:
- raise unittest.SkipTest('Can not create file %s on current file '
- 'system' % filename)
-
- try:
- response = self.request(self.base_url + '/')
- body = self.check_status_and_reason(response, HTTPStatus.OK)
- enc = response.headers.get_content_charset()
- finally:
- os.unlink(fullpath) # avoid affecting test_undecodable_filename
-
- self.assertIsNotNone(enc)
- html_text = '>%s<' % html.escape(filename, quote=False)
- self.assertIn(html_text.encode(enc), body)
-
-
-cgi_file1 = """\
-#!%s
-
-print("Content-type: text/html")
-print()
-print("Hello World")
-"""
-
-cgi_file2 = """\
-#!%s
-import os
-import sys
-import urllib.parse
-
-print("Content-type: text/html")
-print()
-
-content_length = int(os.environ["CONTENT_LENGTH"])
-query_string = sys.stdin.buffer.read(content_length)
-params = {key.decode("utf-8"): val.decode("utf-8")
- for key, val in urllib.parse.parse_qsl(query_string)}
-
-print("%%s, %%s, %%s" %% (params["spam"], params["eggs"], params["bacon"]))
-"""
-
-cgi_file4 = """\
-#!%s
-import os
-
-print("Content-type: text/html")
-print()
-
-print(os.environ["%s"])
-"""
-
-cgi_file6 = """\
-#!%s
-import os
-
-print("X-ambv: was here")
-print("Content-type: text/html")
-print()
-print("<pre>")
-for k, v in os.environ.items():
- try:
- k.encode('ascii')
- v.encode('ascii')
- except UnicodeEncodeError:
- continue # see: BPO-44647
- print(f"{k}={v}")
-print("</pre>")
-"""
-
-
-@unittest.skipIf(hasattr(os, 'geteuid') and os.geteuid() == 0,
- "This test can't be run reliably as root (issue #13308).")
-@requires_subprocess()
-class CGIHTTPServerTestCase(BaseTestCase):
- class request_handler(NoLogRequestHandler, CGIHTTPRequestHandler):
- _test_case_self = None # populated by each setUp() method call.
-
- def __init__(self, *args, **kwargs):
- with self._test_case_self.assertWarnsRegex(
- DeprecationWarning,
- r'http\.server\.CGIHTTPRequestHandler'):
- # This context also happens to catch and silence the
- # threading DeprecationWarning from os.fork().
- super().__init__(*args, **kwargs)
-
- linesep = os.linesep.encode('ascii')
-
- def setUp(self):
- self.request_handler._test_case_self = self # practical, but yuck.
- BaseTestCase.setUp(self)
- self.cwd = os.getcwd()
- self.parent_dir = tempfile.mkdtemp()
- self.cgi_dir = os.path.join(self.parent_dir, 'cgi-bin')
- self.cgi_child_dir = os.path.join(self.cgi_dir, 'child-dir')
- self.sub_dir_1 = os.path.join(self.parent_dir, 'sub')
- self.sub_dir_2 = os.path.join(self.sub_dir_1, 'dir')
- self.cgi_dir_in_sub_dir = os.path.join(self.sub_dir_2, 'cgi-bin')
- os.mkdir(self.cgi_dir)
- os.mkdir(self.cgi_child_dir)
- os.mkdir(self.sub_dir_1)
- os.mkdir(self.sub_dir_2)
- os.mkdir(self.cgi_dir_in_sub_dir)
- self.nocgi_path = None
- self.file1_path = None
- self.file2_path = None
- self.file3_path = None
- self.file4_path = None
- self.file5_path = None
-
- # The shebang line should be pure ASCII: use symlink if possible.
- # See issue #7668.
- self._pythonexe_symlink = None
- if os_helper.can_symlink():
- self.pythonexe = os.path.join(self.parent_dir, 'python')
- self._pythonexe_symlink = support.PythonSymlink(self.pythonexe).__enter__()
- else:
- self.pythonexe = sys.executable
-
- try:
- # The python executable path is written as the first line of the
- # CGI Python script. The encoding cookie cannot be used, and so the
- # path should be encodable to the default script encoding (utf-8)
- self.pythonexe.encode('utf-8')
- except UnicodeEncodeError:
- self.tearDown()
- self.skipTest("Python executable path is not encodable to utf-8")
-
- self.nocgi_path = os.path.join(self.parent_dir, 'nocgi.py')
- with open(self.nocgi_path, 'w', encoding='utf-8') as fp:
- fp.write(cgi_file1 % self.pythonexe)
- os.chmod(self.nocgi_path, 0o777)
-
- self.file1_path = os.path.join(self.cgi_dir, 'file1.py')
- with open(self.file1_path, 'w', encoding='utf-8') as file1:
- file1.write(cgi_file1 % self.pythonexe)
- os.chmod(self.file1_path, 0o777)
-
- self.file2_path = os.path.join(self.cgi_dir, 'file2.py')
- with open(self.file2_path, 'w', encoding='utf-8') as file2:
- file2.write(cgi_file2 % self.pythonexe)
- os.chmod(self.file2_path, 0o777)
-
- self.file3_path = os.path.join(self.cgi_child_dir, 'file3.py')
- with open(self.file3_path, 'w', encoding='utf-8') as file3:
- file3.write(cgi_file1 % self.pythonexe)
- os.chmod(self.file3_path, 0o777)
-
- self.file4_path = os.path.join(self.cgi_dir, 'file4.py')
- with open(self.file4_path, 'w', encoding='utf-8') as file4:
- file4.write(cgi_file4 % (self.pythonexe, 'QUERY_STRING'))
- os.chmod(self.file4_path, 0o777)
-
- self.file5_path = os.path.join(self.cgi_dir_in_sub_dir, 'file5.py')
- with open(self.file5_path, 'w', encoding='utf-8') as file5:
- file5.write(cgi_file1 % self.pythonexe)
- os.chmod(self.file5_path, 0o777)
-
- self.file6_path = os.path.join(self.cgi_dir, 'file6.py')
- with open(self.file6_path, 'w', encoding='utf-8') as file6:
- file6.write(cgi_file6 % self.pythonexe)
- os.chmod(self.file6_path, 0o777)
-
- os.chdir(self.parent_dir)
-
- def tearDown(self):
- self.request_handler._test_case_self = None
- try:
- os.chdir(self.cwd)
- if self._pythonexe_symlink:
- self._pythonexe_symlink.__exit__(None, None, None)
- if self.nocgi_path:
- os.remove(self.nocgi_path)
- if self.file1_path:
- os.remove(self.file1_path)
- if self.file2_path:
- os.remove(self.file2_path)
- if self.file3_path:
- os.remove(self.file3_path)
- if self.file4_path:
- os.remove(self.file4_path)
- if self.file5_path:
- os.remove(self.file5_path)
- if self.file6_path:
- os.remove(self.file6_path)
- os.rmdir(self.cgi_child_dir)
- os.rmdir(self.cgi_dir)
- os.rmdir(self.cgi_dir_in_sub_dir)
- os.rmdir(self.sub_dir_2)
- os.rmdir(self.sub_dir_1)
- # The 'gmon.out' file can be written in the current working
- # directory if C-level code profiling with gprof is enabled.
- os_helper.unlink(os.path.join(self.parent_dir, 'gmon.out'))
- os.rmdir(self.parent_dir)
- finally:
- BaseTestCase.tearDown(self)
-
- def test_url_collapse_path(self):
- # verify tail is the last portion and head is the rest on proper urls
- test_vectors = {
- '': '//',
- '..': IndexError,
- '/.//..': IndexError,
- '/': '//',
- '//': '//',
- '/\\': '//\\',
- '/.//': '//',
- 'cgi-bin/file1.py': '/cgi-bin/file1.py',
- '/cgi-bin/file1.py': '/cgi-bin/file1.py',
- 'a': '//a',
- '/a': '//a',
- '//a': '//a',
- './a': '//a',
- './C:/': '/C:/',
- '/a/b': '/a/b',
- '/a/b/': '/a/b/',
- '/a/b/.': '/a/b/',
- '/a/b/c/..': '/a/b/',
- '/a/b/c/../d': '/a/b/d',
- '/a/b/c/../d/e/../f': '/a/b/d/f',
- '/a/b/c/../d/e/../../f': '/a/b/f',
- '/a/b/c/../d/e/.././././..//f': '/a/b/f',
- '../a/b/c/../d/e/.././././..//f': IndexError,
- '/a/b/c/../d/e/../../../f': '/a/f',
- '/a/b/c/../d/e/../../../../f': '//f',
- '/a/b/c/../d/e/../../../../../f': IndexError,
- '/a/b/c/../d/e/../../../../f/..': '//',
- '/a/b/c/../d/e/../../../../f/../.': '//',
- }
- for path, expected in test_vectors.items():
- if isinstance(expected, type) and issubclass(expected, Exception):
- self.assertRaises(expected,
- server._url_collapse_path, path)
- else:
- actual = server._url_collapse_path(path)
- self.assertEqual(expected, actual,
- msg='path = %r\nGot: %r\nWanted: %r' %
- (path, actual, expected))
-
- def test_headers_and_content(self):
- res = self.request('/cgi-bin/file1.py')
- self.assertEqual(
- (res.read(), res.getheader('Content-type'), res.status),
- (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK))
-
- def test_issue19435(self):
- res = self.request('///////////nocgi.py/../cgi-bin/nothere.sh')
- self.assertEqual(res.status, HTTPStatus.NOT_FOUND)
-
- def test_post(self):
- params = urllib.parse.urlencode(
- {'spam' : 1, 'eggs' : 'python', 'bacon' : 123456})
- headers = {'Content-type' : 'application/x-www-form-urlencoded'}
- res = self.request('/cgi-bin/file2.py', 'POST', params, headers)
-
- self.assertEqual(res.read(), b'1, python, 123456' + self.linesep)
-
- def test_invaliduri(self):
- res = self.request('/cgi-bin/invalid')
- res.read()
- self.assertEqual(res.status, HTTPStatus.NOT_FOUND)
-
- def test_authorization(self):
- headers = {b'Authorization' : b'Basic ' +
- base64.b64encode(b'username:pass')}
- res = self.request('/cgi-bin/file1.py', 'GET', headers=headers)
- self.assertEqual(
- (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK),
- (res.read(), res.getheader('Content-type'), res.status))
-
- def test_no_leading_slash(self):
- # http://bugs.python.org/issue2254
- res = self.request('cgi-bin/file1.py')
- self.assertEqual(
- (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK),
- (res.read(), res.getheader('Content-type'), res.status))
-
- def test_os_environ_is_not_altered(self):
- signature = "Test CGI Server"
- os.environ['SERVER_SOFTWARE'] = signature
- res = self.request('/cgi-bin/file1.py')
- self.assertEqual(
- (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK),
- (res.read(), res.getheader('Content-type'), res.status))
- self.assertEqual(os.environ['SERVER_SOFTWARE'], signature)
-
- def test_urlquote_decoding_in_cgi_check(self):
- res = self.request('/cgi-bin%2ffile1.py')
- self.assertEqual(
- (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK),
- (res.read(), res.getheader('Content-type'), res.status))
-
- def test_nested_cgi_path_issue21323(self):
- res = self.request('/cgi-bin/child-dir/file3.py')
- self.assertEqual(
- (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK),
- (res.read(), res.getheader('Content-type'), res.status))
-
- def test_query_with_multiple_question_mark(self):
- res = self.request('/cgi-bin/file4.py?a=b?c=d')
- self.assertEqual(
- (b'a=b?c=d' + self.linesep, 'text/html', HTTPStatus.OK),
- (res.read(), res.getheader('Content-type'), res.status))
-
- def test_query_with_continuous_slashes(self):
- res = self.request('/cgi-bin/file4.py?k=aa%2F%2Fbb&//q//p//=//a//b//')
- self.assertEqual(
- (b'k=aa%2F%2Fbb&//q//p//=//a//b//' + self.linesep,
- 'text/html', HTTPStatus.OK),
- (res.read(), res.getheader('Content-type'), res.status))
-
- def test_cgi_path_in_sub_directories(self):
- try:
- CGIHTTPRequestHandler.cgi_directories.append('/sub/dir/cgi-bin')
- res = self.request('/sub/dir/cgi-bin/file5.py')
- self.assertEqual(
- (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK),
- (res.read(), res.getheader('Content-type'), res.status))
- finally:
- CGIHTTPRequestHandler.cgi_directories.remove('/sub/dir/cgi-bin')
-
- def test_accept(self):
- browser_accept = \
- 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8'
- tests = (
- ((('Accept', browser_accept),), browser_accept),
- ((), ''),
- # Hack case to get two values for the one header
- ((('Accept', 'text/html'), ('ACCEPT', 'text/plain')),
- 'text/html,text/plain'),
- )
- for headers, expected in tests:
- headers = OrderedDict(headers)
- with self.subTest(headers):
- res = self.request('/cgi-bin/file6.py', 'GET', headers=headers)
- self.assertEqual(http.HTTPStatus.OK, res.status)
- expected = f"HTTP_ACCEPT={expected}".encode('ascii')
- self.assertIn(expected, res.read())
-
class SocketlessRequestHandler(SimpleHTTPRequestHandler):
def __init__(self, directory=None):
@@ -1095,6 +843,7 @@ class SocketlessRequestHandler(SimpleHTTPRequestHandler):
def log_message(self, format, *args):
pass
+
class RejectingSocketlessRequestHandler(SocketlessRequestHandler):
def handle_expect_100(self):
self.send_error(HTTPStatus.EXPECTATION_FAILED)
@@ -1536,6 +1285,243 @@ class ScriptTestCase(unittest.TestCase):
self.assertEqual(mock_server.address_family, socket.AF_INET)
+class CommandLineTestCase(unittest.TestCase):
+ default_port = 8000
+ default_bind = None
+ default_protocol = 'HTTP/1.0'
+ default_handler = SimpleHTTPRequestHandler
+ default_server = unittest.mock.ANY
+ tls_cert = certdata_file('ssl_cert.pem')
+ tls_key = certdata_file('ssl_key.pem')
+ tls_password = 'somepass'
+ tls_cert_options = ['--tls-cert']
+ tls_key_options = ['--tls-key']
+ tls_password_options = ['--tls-password-file']
+ args = {
+ 'HandlerClass': default_handler,
+ 'ServerClass': default_server,
+ 'protocol': default_protocol,
+ 'port': default_port,
+ 'bind': default_bind,
+ 'tls_cert': None,
+ 'tls_key': None,
+ 'tls_password': None,
+ }
+
+ def setUp(self):
+ super().setUp()
+ self.tls_password_file = tempfile.mktemp()
+ with open(self.tls_password_file, 'wb') as f:
+ f.write(self.tls_password.encode())
+ self.addCleanup(os_helper.unlink, self.tls_password_file)
+
+ def invoke_httpd(self, *args, stdout=None, stderr=None):
+ stdout = StringIO() if stdout is None else stdout
+ stderr = StringIO() if stderr is None else stderr
+ with contextlib.redirect_stdout(stdout), \
+ contextlib.redirect_stderr(stderr):
+ server._main(args)
+ return stdout.getvalue(), stderr.getvalue()
+
+ @mock.patch('http.server.test')
+ def test_port_flag(self, mock_func):
+ ports = [8000, 65535]
+ for port in ports:
+ with self.subTest(port=port):
+ self.invoke_httpd(str(port))
+ call_args = self.args | dict(port=port)
+ mock_func.assert_called_once_with(**call_args)
+ mock_func.reset_mock()
+
+ @mock.patch('http.server.test')
+ def test_directory_flag(self, mock_func):
+ options = ['-d', '--directory']
+ directories = ['.', '/foo', '\\bar', '/',
+ 'C:\\', 'C:\\foo', 'C:\\bar',
+ '/home/user', './foo/foo2', 'D:\\foo\\bar']
+ for flag in options:
+ for directory in directories:
+ with self.subTest(flag=flag, directory=directory):
+ self.invoke_httpd(flag, directory)
+ mock_func.assert_called_once_with(**self.args)
+ mock_func.reset_mock()
+
+ @mock.patch('http.server.test')
+ def test_bind_flag(self, mock_func):
+ options = ['-b', '--bind']
+ bind_addresses = ['localhost', '127.0.0.1', '::1',
+ '0.0.0.0', '8.8.8.8']
+ for flag in options:
+ for bind_address in bind_addresses:
+ with self.subTest(flag=flag, bind_address=bind_address):
+ self.invoke_httpd(flag, bind_address)
+ call_args = self.args | dict(bind=bind_address)
+ mock_func.assert_called_once_with(**call_args)
+ mock_func.reset_mock()
+
+ @mock.patch('http.server.test')
+ def test_protocol_flag(self, mock_func):
+ options = ['-p', '--protocol']
+ protocols = ['HTTP/1.0', 'HTTP/1.1', 'HTTP/2.0', 'HTTP/3.0']
+ for flag in options:
+ for protocol in protocols:
+ with self.subTest(flag=flag, protocol=protocol):
+ self.invoke_httpd(flag, protocol)
+ call_args = self.args | dict(protocol=protocol)
+ mock_func.assert_called_once_with(**call_args)
+ mock_func.reset_mock()
+
+ @unittest.skipIf(ssl is None, "requires ssl")
+ @mock.patch('http.server.test')
+ def test_tls_cert_and_key_flags(self, mock_func):
+ for tls_cert_option in self.tls_cert_options:
+ for tls_key_option in self.tls_key_options:
+ self.invoke_httpd(tls_cert_option, self.tls_cert,
+ tls_key_option, self.tls_key)
+ call_args = self.args | {
+ 'tls_cert': self.tls_cert,
+ 'tls_key': self.tls_key,
+ }
+ mock_func.assert_called_once_with(**call_args)
+ mock_func.reset_mock()
+
+ @unittest.skipIf(ssl is None, "requires ssl")
+ @mock.patch('http.server.test')
+ def test_tls_cert_and_key_and_password_flags(self, mock_func):
+ for tls_cert_option in self.tls_cert_options:
+ for tls_key_option in self.tls_key_options:
+ for tls_password_option in self.tls_password_options:
+ self.invoke_httpd(tls_cert_option,
+ self.tls_cert,
+ tls_key_option,
+ self.tls_key,
+ tls_password_option,
+ self.tls_password_file)
+ call_args = self.args | {
+ 'tls_cert': self.tls_cert,
+ 'tls_key': self.tls_key,
+ 'tls_password': self.tls_password,
+ }
+ mock_func.assert_called_once_with(**call_args)
+ mock_func.reset_mock()
+
+ @unittest.skipIf(ssl is None, "requires ssl")
+ @mock.patch('http.server.test')
+ def test_missing_tls_cert_flag(self, mock_func):
+ for tls_key_option in self.tls_key_options:
+ with self.assertRaises(SystemExit):
+ self.invoke_httpd(tls_key_option, self.tls_key)
+ mock_func.reset_mock()
+
+ for tls_password_option in self.tls_password_options:
+ with self.assertRaises(SystemExit):
+ self.invoke_httpd(tls_password_option, self.tls_password)
+ mock_func.reset_mock()
+
+ @unittest.skipIf(ssl is None, "requires ssl")
+ @mock.patch('http.server.test')
+ def test_invalid_password_file(self, mock_func):
+ non_existent_file = 'non_existent_file'
+ for tls_password_option in self.tls_password_options:
+ for tls_cert_option in self.tls_cert_options:
+ with self.assertRaises(SystemExit):
+ self.invoke_httpd(tls_cert_option,
+ self.tls_cert,
+ tls_password_option,
+ non_existent_file)
+
+ @mock.patch('http.server.test')
+ def test_no_arguments(self, mock_func):
+ self.invoke_httpd()
+ mock_func.assert_called_once_with(**self.args)
+ mock_func.reset_mock()
+
+ @mock.patch('http.server.test')
+ def test_help_flag(self, _):
+ options = ['-h', '--help']
+ for option in options:
+ stdout, stderr = StringIO(), StringIO()
+ with self.assertRaises(SystemExit):
+ self.invoke_httpd(option, stdout=stdout, stderr=stderr)
+ self.assertIn('usage', stdout.getvalue())
+ self.assertEqual(stderr.getvalue(), '')
+
+ @mock.patch('http.server.test')
+ def test_unknown_flag(self, _):
+ stdout, stderr = StringIO(), StringIO()
+ with self.assertRaises(SystemExit):
+ self.invoke_httpd('--unknown-flag', stdout=stdout, stderr=stderr)
+ self.assertEqual(stdout.getvalue(), '')
+ self.assertIn('error', stderr.getvalue())
+
+
+class CommandLineRunTimeTestCase(unittest.TestCase):
+ served_data = os.urandom(32)
+ served_filename = 'served_filename'
+ tls_cert = certdata_file('ssl_cert.pem')
+ tls_key = certdata_file('ssl_key.pem')
+ tls_password = b'somepass'
+ tls_password_file = 'ssl_key_password'
+
+ def setUp(self):
+ super().setUp()
+ server_dir_context = os_helper.temp_cwd()
+ server_dir = self.enterContext(server_dir_context)
+ with open(self.served_filename, 'wb') as f:
+ f.write(self.served_data)
+ with open(self.tls_password_file, 'wb') as f:
+ f.write(self.tls_password)
+
+ def fetch_file(self, path, context=None):
+ req = urllib.request.Request(path, method='GET')
+ with urllib.request.urlopen(req, context=context) as res:
+ return res.read()
+
+ def parse_cli_output(self, output):
+ match = re.search(r'Serving (HTTP|HTTPS) on (.+) port (\d+)', output)
+ if match is None:
+ return None, None, None
+ return match.group(1).lower(), match.group(2), int(match.group(3))
+
+ def wait_for_server(self, proc, protocol, bind, port):
+ """Check that the server has been successfully started."""
+ line = proc.stdout.readline().strip()
+ if support.verbose:
+ print()
+ print('python -m http.server: ', line)
+ return self.parse_cli_output(line) == (protocol, bind, port)
+
+ def test_http_client(self):
+ bind, port = '127.0.0.1', find_unused_port()
+ proc = spawn_python('-u', '-m', 'http.server', str(port), '-b', bind,
+ bufsize=1, text=True)
+ self.addCleanup(kill_python, proc)
+ self.addCleanup(proc.terminate)
+ self.assertTrue(self.wait_for_server(proc, 'http', bind, port))
+ res = self.fetch_file(f'http://{bind}:{port}/{self.served_filename}')
+ self.assertEqual(res, self.served_data)
+
+ @unittest.skipIf(ssl is None, "requires ssl")
+ def test_https_client(self):
+ context = ssl.create_default_context()
+ # allow self-signed certificates
+ context.check_hostname = False
+ context.verify_mode = ssl.CERT_NONE
+
+ bind, port = '127.0.0.1', find_unused_port()
+ proc = spawn_python('-u', '-m', 'http.server', str(port), '-b', bind,
+ '--tls-cert', self.tls_cert,
+ '--tls-key', self.tls_key,
+ '--tls-password-file', self.tls_password_file,
+ bufsize=1, text=True)
+ self.addCleanup(kill_python, proc)
+ self.addCleanup(proc.terminate)
+ self.assertTrue(self.wait_for_server(proc, 'https', bind, port))
+ url = f'https://{bind}:{port}/{self.served_filename}'
+ res = self.fetch_file(url, context=context)
+ self.assertEqual(res, self.served_data)
+
+
def setUpModule():
unittest.addModuleCleanup(os.chdir, os.getcwd())
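
For reference, a rough standalone equivalent of what CommandLineRunTimeTestCase above does: spawn `python -m http.server`, wait for its "Serving HTTP on ... port ..." banner, then fetch over HTTP. The fixed port and the bare directory-listing request are illustrative simplifications; the test picks an unused port and serves a known file instead.

import re
import subprocess
import sys
import urllib.request

port = 8000  # assumed free; the test above uses find_unused_port() instead
proc = subprocess.Popen(
    [sys.executable, '-u', '-m', 'http.server', str(port), '-b', '127.0.0.1'],
    stdout=subprocess.PIPE, text=True, bufsize=1)
try:
    banner = proc.stdout.readline().strip()
    assert re.search(r'Serving HTTP on .+ port \d+', banner), banner
    with urllib.request.urlopen(f'http://127.0.0.1:{port}/') as res:
        print(res.status)  # 200: directory listing of the current directory
finally:
    proc.terminate()
    proc.wait()
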
diff --git a/Lib/test/test_idle.py b/Lib/test/test_idle.py
index 3d8b7ecc0ec..ebf572ac5ca 100644
--- a/Lib/test/test_idle.py
+++ b/Lib/test/test_idle.py
@@ -16,7 +16,7 @@ idlelib.testing = True
# Unittest.main and test.libregrtest.runtest.runtest_inner
# call load_tests, when present here, to discover tests to run.
-from idlelib.idle_test import load_tests
+from idlelib.idle_test import load_tests # noqa: F401
if __name__ == '__main__':
tk.NoDefaultRoot()
diff --git a/Lib/test/test_import/__init__.py b/Lib/test/test_import/__init__.py
index b5f4645847a..6e34094c5aa 100644
--- a/Lib/test/test_import/__init__.py
+++ b/Lib/test/test_import/__init__.py
@@ -1001,7 +1001,7 @@ from not_a_module import symbol
expected_error = error + (
rb" \(consider renaming '.*numpy.py' if it has the "
- rb"same name as a library you intended to import\)\s+\Z"
+ rb"same name as a library you intended to import\)\s+\z"
)
popen = script_helper.spawn_python(os.path.join(tmp, "numpy.py"))
@@ -1022,14 +1022,14 @@ from not_a_module import symbol
f.write("this_script_does_not_attempt_to_import_numpy = True")
expected_error = (
- rb"AttributeError: module 'numpy' has no attribute 'attr'\s+\Z"
+ rb"AttributeError: module 'numpy' has no attribute 'attr'\s+\z"
)
popen = script_helper.spawn_python('-c', 'import numpy; numpy.attr', cwd=tmp)
stdout, stderr = popen.communicate()
self.assertRegex(stdout, expected_error)
expected_error = (
- rb"ImportError: cannot import name 'attr' from 'numpy' \(.*\)\s+\Z"
+ rb"ImportError: cannot import name 'attr' from 'numpy' \(.*\)\s+\z"
)
popen = script_helper.spawn_python('-c', 'from numpy import attr', cwd=tmp)
stdout, stderr = popen.communicate()
diff --git a/Lib/test/test_importlib/import_/test_relative_imports.py b/Lib/test/test_importlib/import_/test_relative_imports.py
index e535d119763..1549cbe96ce 100644
--- a/Lib/test/test_importlib/import_/test_relative_imports.py
+++ b/Lib/test/test_importlib/import_/test_relative_imports.py
@@ -223,6 +223,21 @@ class RelativeImports:
self.__import__('sys', {'__package__': '', '__spec__': None},
level=1)
+ def test_malicious_relative_import(self):
+ # https://github.com/python/cpython/issues/134100
+ # Test to make sure UAF bug with error msg doesn't come back to life
+ import sys
+ loooong = "".ljust(0x23000, "b")
+ name = f"a.{loooong}.c"
+
+ with util.uncache(name):
+ sys.modules[name] = {}
+ with self.assertRaisesRegex(
+ KeyError,
+ r"'a\.b+' not in sys\.modules as expected"
+ ):
+ __import__(f"{loooong}.c", {"__package__": "a"}, level=1)
+
(Frozen_RelativeImports,
Source_RelativeImports
diff --git a/Lib/test/test_importlib/test_locks.py b/Lib/test/test_importlib/test_locks.py
index befac5d62b0..655e5881a15 100644
--- a/Lib/test/test_importlib/test_locks.py
+++ b/Lib/test/test_importlib/test_locks.py
@@ -34,6 +34,7 @@ class ModuleLockAsRLockTests:
# lock status in repr unsupported
test_repr = None
test_locked_repr = None
+ test_repr_count = None
def tearDown(self):
for splitinit in init.values():
diff --git a/Lib/test/test_importlib/test_threaded_import.py b/Lib/test/test_importlib/test_threaded_import.py
index 9af1e4d505c..f78dc399720 100644
--- a/Lib/test/test_importlib/test_threaded_import.py
+++ b/Lib/test/test_importlib/test_threaded_import.py
@@ -135,10 +135,12 @@ class ThreadedImportTests(unittest.TestCase):
if verbose:
print("OK.")
- def test_parallel_module_init(self):
+ @support.bigmemtest(size=50, memuse=76*2**20, dry_run=False)
+ def test_parallel_module_init(self, size):
self.check_parallel_module_init()
- def test_parallel_meta_path(self):
+ @support.bigmemtest(size=50, memuse=76*2**20, dry_run=False)
+ def test_parallel_meta_path(self, size):
finder = Finder()
sys.meta_path.insert(0, finder)
try:
@@ -148,7 +150,8 @@ class ThreadedImportTests(unittest.TestCase):
finally:
sys.meta_path.remove(finder)
- def test_parallel_path_hooks(self):
+ @support.bigmemtest(size=50, memuse=76*2**20, dry_run=False)
+ def test_parallel_path_hooks(self, size):
# Here the Finder instance is only used to check concurrent calls
# to path_hook().
finder = Finder()
@@ -242,13 +245,15 @@ class ThreadedImportTests(unittest.TestCase):
__import__(TESTFN)
del sys.modules[TESTFN]
- def test_concurrent_futures_circular_import(self):
+ @support.bigmemtest(size=1, memuse=1.8*2**30, dry_run=False)
+ def test_concurrent_futures_circular_import(self, size):
# Regression test for bpo-43515
fn = os.path.join(os.path.dirname(__file__),
'partial', 'cfimport.py')
script_helper.assert_python_ok(fn)
- def test_multiprocessing_pool_circular_import(self):
+ @support.bigmemtest(size=1, memuse=1.8*2**30, dry_run=False)
+ def test_multiprocessing_pool_circular_import(self, size):
# Regression test for bpo-41567
fn = os.path.join(os.path.dirname(__file__),
'partial', 'pool_in_threads.py')
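
For reference, a sketch of the decorator pattern the test_threaded_import hunks above adopt: support.bigmemtest declares the memory the test needs (size units of memuse bytes each) and, with dry_run=False, skips the test unless the regrtest memory limit (-M) allows it; the decorated method then receives the negotiated size. The test body below is illustrative only.

import unittest
from test import support

class BigAllocationTests(unittest.TestCase):
    @support.bigmemtest(size=50, memuse=76 * 2**20, dry_run=False)
    def test_many_allocations(self, size):
        # 'size' is supplied by the decorator; this body is made up.
        chunks = [bytearray(2**20) for _ in range(size)]
        self.assertEqual(len(chunks), size)

if __name__ == '__main__':
    unittest.main()
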
diff --git a/Lib/test/test_inspect/test_inspect.py b/Lib/test/test_inspect/test_inspect.py
index 06f0ca36f97..79eb103224b 100644
--- a/Lib/test/test_inspect/test_inspect.py
+++ b/Lib/test/test_inspect/test_inspect.py
@@ -786,12 +786,12 @@ class TestRetrievingSourceCode(GetSourceBase):
def test_getfile_builtin_module(self):
with self.assertRaises(TypeError) as e:
inspect.getfile(sys)
- self.assertTrue(str(e.exception).startswith('<module'))
+ self.assertStartsWith(str(e.exception), '<module')
def test_getfile_builtin_class(self):
with self.assertRaises(TypeError) as e:
inspect.getfile(int)
- self.assertTrue(str(e.exception).startswith('<class'))
+ self.assertStartsWith(str(e.exception), '<class')
def test_getfile_builtin_function_or_method(self):
with self.assertRaises(TypeError) as e_abs:
@@ -2949,7 +2949,7 @@ class TestSignatureObject(unittest.TestCase):
pass
sig = inspect.signature(test)
- self.assertTrue(repr(sig).startswith('<Signature'))
+ self.assertStartsWith(repr(sig), '<Signature')
self.assertTrue('(po, /, pk' in repr(sig))
# We need two functions, because it is impossible to represent
@@ -2958,7 +2958,7 @@ class TestSignatureObject(unittest.TestCase):
pass
sig2 = inspect.signature(test2)
- self.assertTrue(repr(sig2).startswith('<Signature'))
+ self.assertStartsWith(repr(sig2), '<Signature')
self.assertTrue('(pod=42, /)' in repr(sig2))
po = sig.parameters['po']
@@ -3412,9 +3412,10 @@ class TestSignatureObject(unittest.TestCase):
int))
def test_signature_on_classmethod(self):
- self.assertEqual(self.signature(classmethod),
- ((('function', ..., ..., "positional_only"),),
- ...))
+ if not support.MISSING_C_DOCSTRINGS:
+ self.assertEqual(self.signature(classmethod),
+ ((('function', ..., ..., "positional_only"),),
+ ...))
class Test:
@classmethod
@@ -3434,9 +3435,10 @@ class TestSignatureObject(unittest.TestCase):
...))
def test_signature_on_staticmethod(self):
- self.assertEqual(self.signature(staticmethod),
- ((('function', ..., ..., "positional_only"),),
- ...))
+ if not support.MISSING_C_DOCSTRINGS:
+ self.assertEqual(self.signature(staticmethod),
+ ((('function', ..., ..., "positional_only"),),
+ ...))
class Test:
@staticmethod
@@ -3845,7 +3847,6 @@ class TestSignatureObject(unittest.TestCase):
('b', ..., ..., "positional_or_keyword")),
...))
-
def test_signature_on_class(self):
class C:
def __init__(self, a):
@@ -3954,9 +3955,10 @@ class TestSignatureObject(unittest.TestCase):
self.assertEqual(C(3), 8)
self.assertEqual(C(3, 7), 1)
- # BUG: Returns '<Signature (b)>'
- with self.assertRaises(AssertionError):
- self.assertEqual(self.signature(C), self.signature((0).__pow__))
+ if not support.MISSING_C_DOCSTRINGS:
+ # BUG: Returns '<Signature (b)>'
+ with self.assertRaises(AssertionError):
+ self.assertEqual(self.signature(C), self.signature((0).__pow__))
class CM(type):
def __new__(mcls, name, bases, dct, *, foo=1):
@@ -4019,6 +4021,45 @@ class TestSignatureObject(unittest.TestCase):
('bar', 2, ..., "keyword_only")),
...))
+ def test_signature_on_class_with_decorated_new(self):
+ def identity(func):
+ @functools.wraps(func)
+ def wrapped(*args, **kwargs):
+ return func(*args, **kwargs)
+ return wrapped
+
+ class Foo:
+ @identity
+ def __new__(cls, a, b):
+ pass
+
+ self.assertEqual(self.signature(Foo),
+ ((('a', ..., ..., "positional_or_keyword"),
+ ('b', ..., ..., "positional_or_keyword")),
+ ...))
+
+ self.assertEqual(self.signature(Foo.__new__),
+ ((('cls', ..., ..., "positional_or_keyword"),
+ ('a', ..., ..., "positional_or_keyword"),
+ ('b', ..., ..., "positional_or_keyword")),
+ ...))
+
+ class Bar:
+ __new__ = identity(object.__new__)
+
+ varargs_signature = (
+ (('args', ..., ..., 'var_positional'),
+ ('kwargs', ..., ..., 'var_keyword')),
+ ...,
+ )
+
+ self.assertEqual(self.signature(Bar), ((), ...))
+ self.assertEqual(self.signature(Bar.__new__), varargs_signature)
+ self.assertEqual(self.signature(Bar, follow_wrapped=False),
+ varargs_signature)
+ self.assertEqual(self.signature(Bar.__new__, follow_wrapped=False),
+ varargs_signature)
+
def test_signature_on_class_with_init(self):
class C:
def __init__(self, b):
@@ -4352,7 +4393,8 @@ class TestSignatureObject(unittest.TestCase):
__call__ = (2).__pow__
self.assertEqual(C()(3), 8)
- self.assertEqual(self.signature(C()), self.signature((0).__pow__))
+ if not support.MISSING_C_DOCSTRINGS:
+ self.assertEqual(self.signature(C()), self.signature((0).__pow__))
with self.subTest('ClassMethodDescriptorType'):
class C(dict):
@@ -4361,7 +4403,8 @@ class TestSignatureObject(unittest.TestCase):
res = C()([1, 2], 3)
self.assertEqual(res, {1: 3, 2: 3})
self.assertEqual(type(res), C)
- self.assertEqual(self.signature(C()), self.signature(dict.fromkeys))
+ if not support.MISSING_C_DOCSTRINGS:
+ self.assertEqual(self.signature(C()), self.signature(dict.fromkeys))
with self.subTest('MethodDescriptorType'):
class C(str):
@@ -4375,7 +4418,8 @@ class TestSignatureObject(unittest.TestCase):
__call__ = int.__pow__
self.assertEqual(C(2)(3), 8)
- self.assertEqual(self.signature(C()), self.signature((0).__pow__))
+ if not support.MISSING_C_DOCSTRINGS:
+ self.assertEqual(self.signature(C()), self.signature((0).__pow__))
with self.subTest('MemberDescriptorType'):
class C:
@@ -4393,7 +4437,8 @@ class TestSignatureObject(unittest.TestCase):
def __call__(self, *args, **kwargs):
pass
- self.assertEqual(self.signature(C), ((), ...))
+ if not support.MISSING_C_DOCSTRINGS:
+ self.assertEqual(self.signature(C), ((), ...))
self.assertEqual(self.signature(C()),
((('a', ..., ..., "positional_only"),
('b', ..., ..., "positional_or_keyword"),
@@ -4952,6 +4997,37 @@ class TestSignatureObject(unittest.TestCase):
with self.assertRaisesRegex(NameError, "undefined"):
signature_func(ida.f)
+ def test_signature_deferred_annotations(self):
+ def f(x: undef):
+ pass
+
+ class C:
+ x: undef
+
+ def __init__(self, x: undef):
+ self.x = x
+
+ sig = inspect.signature(f, annotation_format=Format.FORWARDREF)
+ self.assertEqual(list(sig.parameters), ['x'])
+ sig = inspect.signature(C, annotation_format=Format.FORWARDREF)
+ self.assertEqual(list(sig.parameters), ['x'])
+
+ class CallableWrapper:
+ def __init__(self, func):
+ self.func = func
+ self.__annotate__ = func.__annotate__
+
+ def __call__(self, *args, **kwargs):
+ return self.func(*args, **kwargs)
+
+ @property
+ def __annotations__(self):
+ return self.__annotate__(Format.VALUE)
+
+ cw = CallableWrapper(f)
+ sig = inspect.signature(cw, annotation_format=Format.FORWARDREF)
+ self.assertEqual(list(sig.parameters), ['args', 'kwargs'])
+
def test_signature_none_annotation(self):
class funclike:
# Has to be callable, and have correct
@@ -5057,7 +5133,7 @@ class TestParameterObject(unittest.TestCase):
with self.assertRaisesRegex(ValueError, 'cannot have default values'):
p.replace(kind=inspect.Parameter.VAR_POSITIONAL)
- self.assertTrue(repr(p).startswith('<Parameter'))
+ self.assertStartsWith(repr(p), '<Parameter')
self.assertTrue('"a=42"' in repr(p))
def test_signature_parameter_hashable(self):
@@ -5801,7 +5877,7 @@ class TestSignatureDefinitions(unittest.TestCase):
def test_os_module_has_signatures(self):
unsupported_signature = {'chmod', 'utime'}
unsupported_signature |= {name for name in
- ['get_terminal_size', 'posix_spawn', 'posix_spawnp',
+ ['get_terminal_size', 'link', 'posix_spawn', 'posix_spawnp',
'register_at_fork', 'startfile']
if hasattr(os, name)}
self._test_module_has_signatures(os, unsupported_signature=unsupported_signature)
@@ -6101,12 +6177,14 @@ class TestRepl(unittest.TestCase):
object.
"""
+ # TODO(picnixz): refactor this as it's used by test_repl.py
+
# To run the REPL without using a terminal, spawn python with the command
# line option '-i' and the process name set to '<stdin>'.
# The directory of argv[0] must match the directory of the Python
# executable for the Popen() call to python to succeed as the directory
- # path may be used by Py_GetPath() to build the default module search
- # path.
+ # path may be used by PyConfig_Get("module_search_paths") to build the
+ # default module search path.
stdin_fname = os.path.join(os.path.dirname(sys.executable), "<stdin>")
cmd_line = [stdin_fname, '-E', '-i']
cmd_line.extend(args)
diff --git a/Lib/test/test_int.py b/Lib/test/test_int.py
index 245528ce57a..7a7cb73f673 100644
--- a/Lib/test/test_int.py
+++ b/Lib/test/test_int.py
@@ -836,7 +836,7 @@ class PyLongModuleTests(unittest.TestCase):
n = hibit | getrandbits(bits - 1)
assert n.bit_length() == bits
sn = str(n)
- self.assertFalse(sn.startswith('0'))
+ self.assertNotStartsWith(sn, '0')
self.assertEqual(n, int(sn))
bits <<= 1
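
For reference, the assertion migration applied in the test_inspect and test_int hunks above, shown standalone: assertStartsWith/assertNotStartsWith are assumed to be available on unittest.TestCase in this branch, replacing assertTrue(...startswith(...)) and giving clearer failure messages. The sample strings are illustrative.

import unittest

class PrefixAssertionTests(unittest.TestCase):
    def test_prefixes(self):
        sig_repr = '<Signature (po, /, pk)>'
        self.assertStartsWith(sig_repr, '<Signature')   # new helper
        self.assertNotStartsWith('12345', '0')
        # pre-migration spelling of the same check:
        self.assertTrue(sig_repr.startswith('<Signature'))

if __name__ == '__main__':
    unittest.main()
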
diff --git a/Lib/test/test_interpreters/test_api.py b/Lib/test/test_interpreters/test_api.py
index 66c7afce88f..0ee4582b5d1 100644
--- a/Lib/test/test_interpreters/test_api.py
+++ b/Lib/test/test_interpreters/test_api.py
@@ -1,18 +1,23 @@
+import contextlib
import os
import pickle
+import sys
from textwrap import dedent
import threading
import types
import unittest
from test import support
+from test.support import os_helper
+from test.support import script_helper
from test.support import import_helper
# Raise SkipTest if subinterpreters not supported.
_interpreters = import_helper.import_module('_interpreters')
+from concurrent import interpreters
from test.support import Py_GIL_DISABLED
-from test.support import interpreters
from test.support import force_not_colorized
-from test.support.interpreters import (
+import test._crossinterp_definitions as defs
+from concurrent.interpreters import (
InterpreterError, InterpreterNotFoundError, ExecutionFailed,
)
from .utils import (
@@ -29,6 +34,59 @@ WHENCE_STR_XI = 'cross-interpreter C-API'
WHENCE_STR_STDLIB = '_interpreters module'
+def is_pickleable(obj):
+ try:
+ pickle.dumps(obj)
+ except Exception:
+ return False
+ return True
+
+
+@contextlib.contextmanager
+def defined_in___main__(name, script, *, remove=False):
+ import __main__ as mainmod
+ mainns = vars(mainmod)
+ assert name not in mainns
+ exec(script, mainns, mainns)
+ if remove:
+ yield mainns.pop(name)
+ else:
+ try:
+ yield mainns[name]
+ finally:
+ mainns.pop(name, None)
+
+
+def build_excinfo(exctype, msg=None, formatted=None, errdisplay=None):
+ if isinstance(exctype, type):
+ assert issubclass(exctype, BaseException), exctype
+ exctype = types.SimpleNamespace(
+ __name__=exctype.__name__,
+ __qualname__=exctype.__qualname__,
+ __module__=exctype.__module__,
+ )
+ elif isinstance(exctype, str):
+ module, _, name = exctype.rpartition('.')
+ if not module and name in __builtins__:
+ module = 'builtins'
+ exctype = types.SimpleNamespace(
+ __name__=name,
+ __qualname__=exctype,
+ __module__=module or None,
+ )
+ else:
+ assert isinstance(exctype, types.SimpleNamespace)
+ assert msg is None or isinstance(msg, str), msg
+ assert formatted is None or isinstance(formatted, str), formatted
+ assert errdisplay is None or isinstance(errdisplay, str), errdisplay
+ return types.SimpleNamespace(
+ type=exctype,
+ msg=msg,
+ formatted=formatted,
+ errdisplay=errdisplay,
+ )
+
+
class ModuleTests(TestBase):
def test_queue_aliases(self):
@@ -75,7 +133,7 @@ class CreateTests(TestBase):
main, = interpreters.list_all()
interp = interpreters.create()
out = _run_output(interp, dedent("""
- from test.support import interpreters
+ from concurrent import interpreters
interp = interpreters.create()
print(interp.id)
"""))
@@ -138,7 +196,7 @@ class GetCurrentTests(TestBase):
main = interpreters.get_main()
interp = interpreters.create()
out = _run_output(interp, dedent("""
- from test.support import interpreters
+ from concurrent import interpreters
cur = interpreters.get_current()
print(cur.id)
"""))
@@ -155,7 +213,7 @@ class GetCurrentTests(TestBase):
with self.subTest('subinterpreter'):
interp = interpreters.create()
out = _run_output(interp, dedent("""
- from test.support import interpreters
+ from concurrent import interpreters
cur = interpreters.get_current()
print(id(cur))
cur = interpreters.get_current()
@@ -167,7 +225,7 @@ class GetCurrentTests(TestBase):
with self.subTest('per-interpreter'):
interp = interpreters.create()
out = _run_output(interp, dedent("""
- from test.support import interpreters
+ from concurrent import interpreters
cur = interpreters.get_current()
print(id(cur))
"""))
@@ -524,7 +582,7 @@ class TestInterpreterClose(TestBase):
main, = interpreters.list_all()
interp = interpreters.create()
out = _run_output(interp, dedent(f"""
- from test.support import interpreters
+ from concurrent import interpreters
interp = interpreters.Interpreter({interp.id})
try:
interp.close()
@@ -541,7 +599,7 @@ class TestInterpreterClose(TestBase):
self.assertEqual(set(interpreters.list_all()),
{main, interp1, interp2})
interp1.exec(dedent(f"""
- from test.support import interpreters
+ from concurrent import interpreters
interp2 = interpreters.Interpreter({interp2.id})
interp2.close()
interp3 = interpreters.create()
@@ -748,7 +806,7 @@ class TestInterpreterExec(TestBase):
ham()
""")
scriptfile = self.make_script('script.py', tempdir, text="""
- from test.support import interpreters
+ from concurrent import interpreters
def script():
import spam
@@ -769,7 +827,7 @@ class TestInterpreterExec(TestBase):
~~~~~~~~~~~^^^^^^^^
{interpmod_line.strip()}
raise ExecutionFailed(excinfo)
- test.support.interpreters.ExecutionFailed: RuntimeError: uh-oh!
+ concurrent.interpreters.ExecutionFailed: RuntimeError: uh-oh!
Uncaught in the interpreter:
@@ -839,9 +897,16 @@ class TestInterpreterExec(TestBase):
interp.exec(10)
def test_bytes_for_script(self):
+ r, w = self.pipe()
+ RAN = b'R'
+ DONE = b'D'
interp = interpreters.create()
- with self.assertRaises(TypeError):
- interp.exec(b'print("spam")')
+ interp.exec(f"""if True:
+ import os
+ os.write({w}, {RAN!r})
+ """)
+ os.write(w, DONE)
+ self.assertEqual(os.read(r, 1), RAN)
def test_with_background_threads_still_running(self):
r_interp, w_interp = self.pipe()
@@ -879,28 +944,46 @@ class TestInterpreterExec(TestBase):
with self.assertRaisesRegex(InterpreterError, 'unrecognized'):
interp.exec('raise Exception("it worked!")')
+ def test_list_comprehension(self):
+ # gh-135450: List comprehensions caused an assertion failure
+ # in _PyCode_CheckNoExternalState()
+ import string
+ r_interp, w_interp = self.pipe()
+
+ interp = interpreters.create()
+ interp.exec(f"""if True:
+ import os
+ comp = [str(i) for i in range(10)]
+ os.write({w_interp}, ''.join(comp).encode())
+ """)
+ self.assertEqual(os.read(r_interp, 10).decode(), string.digits)
+ interp.close()
+
+
# test__interpreters covers the remaining
# Interpreter.exec() behavior.
-def call_func_noop():
- pass
+call_func_noop = defs.spam_minimal
+call_func_ident = defs.spam_returns_arg
+call_func_failure = defs.spam_raises
def call_func_return_shareable():
return (1, None)
-def call_func_return_not_shareable():
- return [1, 2, 3]
+def call_func_return_stateless_func():
+ return (lambda x: x)
-def call_func_failure():
- raise Exception('spam!')
+def call_func_return_pickleable():
+ return [1, 2, 3]
-def call_func_ident(value):
- return value
+def call_func_return_unpickleable():
+ x = 42
+ return (lambda: x)
def get_call_func_closure(value):
@@ -909,6 +992,11 @@ def get_call_func_closure(value):
return call_func_closure
+def call_func_exec_wrapper(script, ns):
+ res = exec(script, ns, ns)
+ return res, ns, id(ns)
+
+
class Spam:
@staticmethod
@@ -1005,86 +1093,663 @@ class TestInterpreterCall(TestBase):
# - preserves info (e.g. SyntaxError)
# - matching error display
- def test_call(self):
+ @contextlib.contextmanager
+ def assert_fails(self, expected):
+ with self.assertRaises(ExecutionFailed) as cm:
+ yield cm
+ uncaught = cm.exception.excinfo
+ self.assertEqual(uncaught.type.__name__, expected.__name__)
+
+ def assert_fails_not_shareable(self):
+ return self.assert_fails(interpreters.NotShareableError)
+
+ def assert_code_equal(self, code1, code2):
+ if code1 == code2:
+ return
+ self.assertEqual(code1.co_name, code2.co_name)
+ self.assertEqual(code1.co_flags, code2.co_flags)
+ self.assertEqual(code1.co_consts, code2.co_consts)
+ self.assertEqual(code1.co_varnames, code2.co_varnames)
+ self.assertEqual(code1.co_cellvars, code2.co_cellvars)
+ self.assertEqual(code1.co_freevars, code2.co_freevars)
+ self.assertEqual(code1.co_names, code2.co_names)
+ self.assertEqual(
+ _testinternalcapi.get_code_var_counts(code1),
+ _testinternalcapi.get_code_var_counts(code2),
+ )
+ self.assertEqual(code1.co_code, code2.co_code)
+
+ def assert_funcs_equal(self, func1, func2):
+ if func1 == func2:
+ return
+ self.assertIs(type(func1), type(func2))
+ self.assertEqual(func1.__name__, func2.__name__)
+ self.assertEqual(func1.__defaults__, func2.__defaults__)
+ self.assertEqual(func1.__kwdefaults__, func2.__kwdefaults__)
+ self.assertEqual(func1.__closure__, func2.__closure__)
+ self.assert_code_equal(func1.__code__, func2.__code__)
+ self.assertEqual(
+ _testinternalcapi.get_code_var_counts(func1),
+ _testinternalcapi.get_code_var_counts(func2),
+ )
+
+ def assert_exceptions_equal(self, exc1, exc2):
+ assert isinstance(exc1, Exception)
+ assert isinstance(exc2, Exception)
+ if exc1 == exc2:
+ return
+ self.assertIs(type(exc1), type(exc2))
+ self.assertEqual(exc1.args, exc2.args)
+
+ def test_stateless_funcs(self):
interp = interpreters.create()
- for i, (callable, args, kwargs) in enumerate([
- (call_func_noop, (), {}),
- (call_func_return_shareable, (), {}),
- (call_func_return_not_shareable, (), {}),
- (Spam.noop, (), {}),
+ func = call_func_noop
+ with self.subTest('no args, no return'):
+ res = interp.call(func)
+ self.assertIsNone(res)
+
+ func = call_func_return_shareable
+ with self.subTest('no args, returns shareable'):
+ res = interp.call(func)
+ self.assertEqual(res, (1, None))
+
+ func = call_func_return_stateless_func
+ expected = (lambda x: x)
+ with self.subTest('no args, returns stateless func'):
+ res = interp.call(func)
+ self.assert_funcs_equal(res, expected)
+
+ func = call_func_return_pickleable
+ with self.subTest('no args, returns pickleable'):
+ res = interp.call(func)
+ self.assertEqual(res, [1, 2, 3])
+
+ func = call_func_return_unpickleable
+ with self.subTest('no args, returns unpickleable'):
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(func)
+
+ def test_stateless_func_returns_arg(self):
+ interp = interpreters.create()
+
+ for arg in [
+ None,
+ 10,
+ 'spam!',
+ b'spam!',
+ (1, 2, 'spam!'),
+ memoryview(b'spam!'),
+ ]:
+ with self.subTest(f'shareable {arg!r}'):
+ assert _interpreters.is_shareable(arg)
+ res = interp.call(defs.spam_returns_arg, arg)
+ self.assertEqual(res, arg)
+
+ for arg in defs.STATELESS_FUNCTIONS:
+ with self.subTest(f'stateless func {arg!r}'):
+ res = interp.call(defs.spam_returns_arg, arg)
+ self.assert_funcs_equal(res, arg)
+
+ for arg in defs.TOP_FUNCTIONS:
+ if arg in defs.STATELESS_FUNCTIONS:
+ continue
+ with self.subTest(f'stateful func {arg!r}'):
+ res = interp.call(defs.spam_returns_arg, arg)
+ self.assert_funcs_equal(res, arg)
+ assert is_pickleable(arg)
+
+ for arg in [
+ Ellipsis,
+ NotImplemented,
+ object(),
+ 2**1000,
+ [1, 2, 3],
+ {'a': 1, 'b': 2},
+ types.SimpleNamespace(x=42),
+ # builtin types
+ object,
+ type,
+ Exception,
+ ModuleNotFoundError,
+ # builtin exceptions
+ Exception('uh-oh!'),
+ ModuleNotFoundError('mymodule'),
+ # builtin functions
+ len,
+ sys.exit,
+ # user classes
+ *defs.TOP_CLASSES,
+ *(c(*a) for c, a in defs.TOP_CLASSES.items()
+ if c not in defs.CLASSES_WITHOUT_EQUALITY),
+ ]:
+ with self.subTest(f'pickleable {arg!r}'):
+ res = interp.call(defs.spam_returns_arg, arg)
+ if type(arg) is object:
+ self.assertIs(type(res), object)
+ elif isinstance(arg, BaseException):
+ self.assert_exceptions_equal(res, arg)
+ else:
+ self.assertEqual(res, arg)
+ assert is_pickleable(arg)
+
+ for arg in [
+ types.MappingProxyType({}),
+ *(f for f in defs.NESTED_FUNCTIONS
+ if f not in defs.STATELESS_FUNCTIONS),
+ ]:
+ with self.subTest(f'unpickleable {arg!r}'):
+ assert not _interpreters.is_shareable(arg)
+ assert not is_pickleable(arg)
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(defs.spam_returns_arg, arg)
+
+ def test_full_args(self):
+ interp = interpreters.create()
+ expected = (1, 2, 3, 4, 5, 6, ('?',), {'g': 7, 'h': 8})
+ func = defs.spam_full_args
+ res = interp.call(func, 1, 2, 3, 4, '?', e=5, f=6, g=7, h=8)
+ self.assertEqual(res, expected)
+
+ def test_full_defaults(self):
+ # pickleable, but not stateless
+ interp = interpreters.create()
+ expected = (-1, -2, -3, -4, -5, -6, (), {'g': 8, 'h': 9})
+ res = interp.call(defs.spam_full_args_with_defaults, g=8, h=9)
+ self.assertEqual(res, expected)
+
+ def test_modified_arg(self):
+ interp = interpreters.create()
+ script = dedent("""
+ a = 7
+ b = 2
+ c = a ** b
+ """)
+ ns = {}
+ expected = {'a': 7, 'b': 2, 'c': 49}
+ res = interp.call(call_func_exec_wrapper, script, ns)
+ obj, resns, resid = res
+ del resns['__builtins__']
+ self.assertIsNone(obj)
+ self.assertEqual(ns, {})
+ self.assertEqual(resns, expected)
+ self.assertNotEqual(resid, id(ns))
+ self.assertNotEqual(resid, id(resns))
+
+ def test_func_in___main___valid(self):
+ # pickleable, already there
+
+ with os_helper.temp_dir() as tempdir:
+ def new_mod(name, text):
+ script_helper.make_script(tempdir, name, dedent(text))
+
+ def run(text):
+ name = 'myscript'
+ text = dedent(f"""
+ import sys
+ sys.path.insert(0, {tempdir!r})
+
+ """) + dedent(text)
+ filename = script_helper.make_script(tempdir, name, text)
+ res = script_helper.assert_python_ok(filename)
+ return res.out.decode('utf-8').strip()
+
+ # no module indirection
+ with self.subTest('no indirection'):
+ text = run(f"""
+ from concurrent import interpreters
+
+ def spam():
+ # This is a global var...
+ return __name__
+
+ if __name__ == '__main__':
+ interp = interpreters.create()
+ res = interp.call(spam)
+ print(res)
+ """)
+ self.assertEqual(text, '<fake __main__>')
+
+ # indirect as func, direct interp
+ new_mod('mymod', f"""
+ def run(interp, func):
+ return interp.call(func)
+ """)
+ with self.subTest('indirect as func, direct interp'):
+ text = run(f"""
+ from concurrent import interpreters
+ import mymod
+
+ def spam():
+ # This is a global var...
+ return __name__
+
+ if __name__ == '__main__':
+ interp = interpreters.create()
+ res = mymod.run(interp, spam)
+ print(res)
+ """)
+ self.assertEqual(text, '<fake __main__>')
+
+ # indirect as func, indirect interp
+ new_mod('mymod', f"""
+ from concurrent import interpreters
+ def run(func):
+ interp = interpreters.create()
+ return interp.call(func)
+ """)
+ with self.subTest('indirect as func, indirect interp'):
+ text = run(f"""
+ import mymod
+
+ def spam():
+ # This is a global var...
+ return __name__
+
+ if __name__ == '__main__':
+ res = mymod.run(spam)
+ print(res)
+ """)
+ self.assertEqual(text, '<fake __main__>')
+
+ def test_func_in___main___invalid(self):
+ interp = interpreters.create()
+
+ funcname = f'{__name__.replace(".", "_")}_spam_okay'
+ script = dedent(f"""
+ def {funcname}():
+ # This is a global var...
+ return __name__
+ """)
+
+ with self.subTest('pickleable, added dynamically'):
+ with defined_in___main__(funcname, script) as arg:
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(defs.spam_returns_arg, arg)
+
+ with self.subTest('lying about __main__'):
+ with defined_in___main__(funcname, script, remove=True) as arg:
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(defs.spam_returns_arg, arg)
+
+ def test_func_in___main___hidden(self):
+ # When a top-level function that uses global variables is called
+ # through Interpreter.call(), it will be pickled, sent over,
+ # and unpickled. That requires that it be found in the other
+ # interpreter's __main__ module. However, the original script
+ # that defined the function is only run in the main interpreter,
+ # so pickle.loads() would normally fail.
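+ # (Pickle normally stores such a function by reference -- roughly
+ # the module name plus the qualified name, not the code object --
+ # which is why the lookup in the other __main__ matters. This is
+ # background on pickle's usual behavior, not something this test
+ # asserts directly.)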
+ #
+ # We work around this by running the script in the other
+ # interpreter. However, this is a one-off solution for the sake
+ # of unpickling, so we avoid modifying that interpreter's
+ # __main__ module by running the script in a hidden module.
+ #
+ # In this test we verify that the function runs with the hidden
+ # module as its __globals__ when called in the other interpreter,
+ # and that the interpreter's __main__ module is unaffected.
+ text = dedent("""
+ eggs = True
+
+ def spam(*, explicit=False):
+ if explicit:
+ import __main__
+ ns = __main__.__dict__
+ else:
+ # For now we have to have a LOAD_GLOBAL in the
+ # function in order for globals() to actually return
+ # spam.__globals__. Maybe it doesn't go through pickle?
+ # XXX We will fix this later.
+ spam
+ ns = globals()
+
+ func = ns.get('spam')
+ return [
+ id(ns),
+ ns.get('__name__'),
+ ns.get('__file__'),
+ id(func),
+ None if func is None else repr(func),
+ ns.get('eggs'),
+ ns.get('ham'),
+ ]
+
+ if __name__ == "__main__":
+ from concurrent import interpreters
+ interp = interpreters.create()
+
+ ham = True
+ print([
+ [
+ spam(explicit=True),
+ spam(),
+ ],
+ [
+ interp.call(spam, explicit=True),
+ interp.call(spam),
+ ],
+ ])
+ """)
+ with os_helper.temp_dir() as tempdir:
+ filename = script_helper.make_script(tempdir, 'my-script', text)
+ res = script_helper.assert_python_ok(filename)
+ stdout = res.out.decode('utf-8').strip()
+ local, remote = eval(stdout)
+
+ # In the main interpreter.
+ main, unpickled = local
+ nsid, _, _, funcid, func, _, _ = main
+ self.assertEqual(main, [
+ nsid,
+ '__main__',
+ filename,
+ funcid,
+ func,
+ True,
+ True,
+ ])
+ self.assertIsNot(func, None)
+ self.assertRegex(func, '^<function spam at 0x.*>$')
+ self.assertEqual(unpickled, main)
+
+ # In the subinterpreter.
+ main, unpickled = remote
+ nsid1, _, _, funcid1, _, _, _ = main
+ self.assertEqual(main, [
+ nsid1,
+ '__main__',
+ None,
+ funcid1,
+ None,
+ None,
+ None,
+ ])
+ nsid2, _, _, funcid2, func, _, _ = unpickled
+ self.assertEqual(unpickled, [
+ nsid2,
+ '<fake __main__>',
+ filename,
+ funcid2,
+ func,
+ True,
+ None,
+ ])
+ self.assertIsNot(func, None)
+ self.assertRegex(func, '^<function spam at 0x.*>$')
+ self.assertNotEqual(nsid2, nsid1)
+ self.assertNotEqual(funcid2, funcid1)
+
+ def test_func_in___main___uses_globals(self):
+ # See the note in test_func_in___main___hidden about pickle
+ # and the __main__ module.
+ #
+ # Additionally, the solution to that problem must provide
+ # for global variables on which a pickled function might rely.
+ #
+ # To check that, we run a script that has two global functions
+ # and a global variable in the __main__ module. One of the
+ # functions sets the global variable and the other returns
+ # the value.
+ #
+ # The script calls those functions multiple times in another
+ # interpreter, to verify the following:
+ #
+ # * the global variable is properly initialized
+ # * the global variable retains state between calls
+ # * the setter modifies that persistent variable
+ # * the getter uses the variable
+ # * the calls in the other interpreter do not modify
+ # the main interpreter
+ # * those calls don't modify the interpreter's __main__ module
+ # * the functions and variable do not actually show up in the
+ # other interpreter's __main__ module
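+ #
+ # Concretely, mirroring the expectations checked below: the three
+ # get_count() calls should report 0, then 1 after inc(), then 4
+ # after inc(3), while `count` in the main interpreter stays 0 and
+ # eval("'count' in vars()") in the other interpreter stays False.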
+ text = dedent("""
+ count = 0
+
+ def inc(x=1):
+ global count
+ count += x
+
+ def get_count():
+ return count
+
+ if __name__ == "__main__":
+ counts = []
+ results = [count, counts]
+
+ from concurrent import interpreters
+ interp = interpreters.create()
+
+ val = interp.call(get_count)
+ counts.append(val)
+
+ interp.call(inc)
+ val = interp.call(get_count)
+ counts.append(val)
+
+ interp.call(inc, 3)
+ val = interp.call(get_count)
+ counts.append(val)
+
+ results.append(count)
+
+ modified = {name: interp.call(eval, f'{name!r} in vars()')
+ for name in ('count', 'inc', 'get_count')}
+ results.append(modified)
+
+ print(results)
+ """)
+ with os_helper.temp_dir() as tempdir:
+ filename = script_helper.make_script(tempdir, 'my-script', text)
+ res = script_helper.assert_python_ok(filename)
+ stdout = res.out.decode('utf-8').strip()
+ before, counts, after, modified = eval(stdout)
+ self.assertEqual(modified, {
+ 'count': False,
+ 'inc': False,
+ 'get_count': False,
+ })
+ self.assertEqual(before, 0)
+ self.assertEqual(after, 0)
+ self.assertEqual(counts, [0, 1, 4])
+
+ def test_raises(self):
+ interp = interpreters.create()
+ with self.assertRaises(ExecutionFailed):
+ interp.call(call_func_failure)
+
+ with self.assert_fails(ValueError):
+ interp.call(call_func_complex, '???', exc=ValueError('spam'))
+
+ def test_call_valid(self):
+ interp = interpreters.create()
+
+ for i, (callable, args, kwargs, expected) in enumerate([
+ (call_func_noop, (), {}, None),
+ (call_func_ident, ('spamspamspam',), {}, 'spamspamspam'),
+ (call_func_return_shareable, (), {}, (1, None)),
+ (call_func_return_pickleable, (), {}, [1, 2, 3]),
+ (Spam.noop, (), {}, None),
+ (Spam.from_values, (), {}, Spam(())),
+ (Spam.from_values, (1, 2, 3), {}, Spam((1, 2, 3))),
+ (Spam, ('???',), {}, Spam('???')),
+ (Spam(101), (), {}, (101, (), {})),
+ (Spam(10101).run, (), {}, (10101, (), {})),
+ (call_func_complex, ('ident', 'spam'), {}, 'spam'),
+ (call_func_complex, ('full-ident', 'spam'), {}, ('spam', (), {})),
+ (call_func_complex, ('full-ident', 'spam', 'ham'), {'eggs': '!!!'},
+ ('spam', ('ham',), {'eggs': '!!!'})),
+ (call_func_complex, ('globals',), {}, __name__),
+ (call_func_complex, ('interpid',), {}, interp.id),
+ (call_func_complex, ('custom', 'spam!'), {}, Spam('spam!')),
]):
with self.subTest(f'success case #{i+1}'):
- res = interp.call(callable)
- self.assertIs(res, None)
+ res = interp.call(callable, *args, **kwargs)
+ self.assertEqual(res, expected)
+
+ def test_call_invalid(self):
+ interp = interpreters.create()
+
+ func = get_call_func_closure
+ with self.subTest(func):
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(func, 42)
+
+ func = get_call_func_closure(42)
+ with self.subTest(func):
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(func)
+
+ func = call_func_complex
+ op = 'closure'
+ with self.subTest(f'{func} ({op})'):
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(func, op, value='~~~')
+
+ op = 'custom-inner'
+ with self.subTest(f'{func} ({op})'):
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(func, op, 'eggs!')
+
+ def test_callable_requires_frame(self):
+ # There are various functions that require a current frame.
+ interp = interpreters.create()
+ for call, expected in [
+ ((eval, '[1, 2, 3]'),
+ [1, 2, 3]),
+ ((eval, 'sum([1, 2, 3])'),
+ 6),
+ ((exec, '...'),
+ None),
+ ]:
+ with self.subTest(str(call)):
+ res = interp.call(*call)
+ self.assertEqual(res, expected)
+
+ result_not_pickleable = [
+ globals,
+ locals,
+ vars,
+ ]
+ for func, expectedtype in {
+ globals: dict,
+ locals: dict,
+ vars: dict,
+ dir: list,
+ }.items():
+ with self.subTest(str(func)):
+ if func in result_not_pickleable:
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(func)
+ else:
+ res = interp.call(func)
+ self.assertIsInstance(res, expectedtype)
+ self.assertIn('__builtins__', res)
+
+ def test_globals_from_builtins(self):
+ # The builtins exec(), eval(), globals(), locals(), vars(),
+ # and dir() each run relative to the target interpreter's
+ # __main__ module when called directly. However,
+ # globals(), locals(), and vars() don't work when called
+ # directly, so we don't check them.
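+ # For example, interp.call(eval, '__name__') below is expected to
+ # evaluate in the subinterpreter's __main__ and return '__main__'.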
+ from _frozen_importlib import BuiltinImporter
+ interp = interpreters.create()
+
+ names = interp.call(dir)
+ self.assertEqual(names, [
+ '__builtins__',
+ '__doc__',
+ '__loader__',
+ '__name__',
+ '__package__',
+ '__spec__',
+ ])
+
+ values = {name: interp.call(eval, name)
+ for name in names if name != '__builtins__'}
+ self.assertEqual(values, {
+ '__name__': '__main__',
+ '__doc__': None,
+ '__spec__': None, # It wasn't imported, so no module spec?
+ '__package__': None,
+ '__loader__': BuiltinImporter,
+ })
+ with self.assertRaises(ExecutionFailed):
+ interp.call(eval, 'spam'),
+
+ interp.call(exec, f'assert dir() == {names}')
+
+ # Update the interpreter's __main__.
+ interp.prepare_main(spam=42)
+ expected = names + ['spam']
+
+ names = interp.call(dir)
+ self.assertEqual(names, expected)
+
+ value = interp.call(eval, 'spam')
+ self.assertEqual(value, 42)
+
+ interp.call(exec, f'assert dir() == {expected}, dir()')
+
+ def test_globals_from_stateless_func(self):
+ # A stateless func, which doesn't depend on any globals,
+ # doesn't go through pickle, so it runs in __main__.
+ def set_global(name, value):
+ globals()[name] = value
+
+ def get_global(name):
+ return globals().get(name)
+
+ interp = interpreters.create()
+
+ modname = interp.call(get_global, '__name__')
+ self.assertEqual(modname, '__main__')
+
+ res = interp.call(get_global, 'spam')
+ self.assertIsNone(res)
+
+ interp.exec('spam = True')
+ res = interp.call(get_global, 'spam')
+ self.assertTrue(res)
+
+ interp.call(set_global, 'spam', 42)
+ res = interp.call(get_global, 'spam')
+ self.assertEqual(res, 42)
+
+ interp.exec('assert spam == 42, repr(spam)')
+
+ def test_call_in_thread(self):
+ interp = interpreters.create()
for i, (callable, args, kwargs) in enumerate([
- (call_func_ident, ('spamspamspam',), {}),
- (get_call_func_closure, (42,), {}),
- (get_call_func_closure(42), (), {}),
+ (call_func_noop, (), {}),
+ (call_func_return_shareable, (), {}),
+ (call_func_return_pickleable, (), {}),
(Spam.from_values, (), {}),
(Spam.from_values, (1, 2, 3), {}),
- (Spam, ('???'), {}),
(Spam(101), (), {}),
(Spam(10101).run, (), {}),
+ (Spam.noop, (), {}),
(call_func_complex, ('ident', 'spam'), {}),
(call_func_complex, ('full-ident', 'spam'), {}),
(call_func_complex, ('full-ident', 'spam', 'ham'), {'eggs': '!!!'}),
(call_func_complex, ('globals',), {}),
(call_func_complex, ('interpid',), {}),
- (call_func_complex, ('closure',), {'value': '~~~'}),
(call_func_complex, ('custom', 'spam!'), {}),
- (call_func_complex, ('custom-inner', 'eggs!'), {}),
- (call_func_complex, ('???',), {'exc': ValueError('spam')}),
- ]):
- with self.subTest(f'invalid case #{i+1}'):
- with self.assertRaises(Exception):
- if args or kwargs:
- raise Exception((args, kwargs))
- interp.call(callable)
-
- with self.assertRaises(ExecutionFailed):
- interp.call(call_func_failure)
-
- def test_call_in_thread(self):
- interp = interpreters.create()
-
- for i, (callable, args, kwargs) in enumerate([
- (call_func_noop, (), {}),
- (call_func_return_shareable, (), {}),
- (call_func_return_not_shareable, (), {}),
- (Spam.noop, (), {}),
]):
with self.subTest(f'success case #{i+1}'):
with self.captured_thread_exception() as ctx:
- t = interp.call_in_thread(callable)
+ t = interp.call_in_thread(callable, *args, **kwargs)
t.join()
self.assertIsNone(ctx.caught)
for i, (callable, args, kwargs) in enumerate([
- (call_func_ident, ('spamspamspam',), {}),
(get_call_func_closure, (42,), {}),
(get_call_func_closure(42), (), {}),
- (Spam.from_values, (), {}),
- (Spam.from_values, (1, 2, 3), {}),
- (Spam, ('???'), {}),
- (Spam(101), (), {}),
- (Spam(10101).run, (), {}),
- (call_func_complex, ('ident', 'spam'), {}),
- (call_func_complex, ('full-ident', 'spam'), {}),
- (call_func_complex, ('full-ident', 'spam', 'ham'), {'eggs': '!!!'}),
- (call_func_complex, ('globals',), {}),
- (call_func_complex, ('interpid',), {}),
- (call_func_complex, ('closure',), {'value': '~~~'}),
- (call_func_complex, ('custom', 'spam!'), {}),
- (call_func_complex, ('custom-inner', 'eggs!'), {}),
- (call_func_complex, ('???',), {'exc': ValueError('spam')}),
]):
with self.subTest(f'invalid case #{i+1}'):
- if args or kwargs:
- continue
with self.captured_thread_exception() as ctx:
- t = interp.call_in_thread(callable)
+ t = interp.call_in_thread(callable, *args, **kwargs)
t.join()
self.assertIsNotNone(ctx.caught)
@@ -1452,6 +2117,14 @@ class LowLevelTests(TestBase):
self.assertFalse(
self.interp_exists(interpid))
+ with self.subTest('basic C-API'):
+ interpid = _testinternalcapi.create_interpreter()
+ self.assertTrue(
+ self.interp_exists(interpid))
+ _testinternalcapi.destroy_interpreter(interpid, basic=True)
+ self.assertFalse(
+ self.interp_exists(interpid))
+
def test_get_config(self):
# This test overlaps with
# test.test_capi.test_misc.InterpreterConfigTests.
@@ -1585,18 +2258,14 @@ class LowLevelTests(TestBase):
with results:
exc = _interpreters.exec(interpid, script)
out = results.stdout()
- self.assertEqual(out, '')
- self.assert_ns_equal(exc, types.SimpleNamespace(
- type=types.SimpleNamespace(
- __name__='Exception',
- __qualname__='Exception',
- __module__='builtins',
- ),
- msg='uh-oh!',
+ expected = build_excinfo(
+ Exception, 'uh-oh!',
# We check these in other tests.
formatted=exc.formatted,
errdisplay=exc.errdisplay,
- ))
+ )
+ self.assertEqual(out, '')
+ self.assert_ns_equal(exc, expected)
with self.subTest('from C-API'):
with self.interpreter_from_capi() as interpid:
@@ -1608,25 +2277,50 @@ class LowLevelTests(TestBase):
self.assertEqual(exc.msg, 'it worked!')
def test_call(self):
- with self.subTest('no args'):
- interpid = _interpreters.create()
- exc = _interpreters.call(interpid, call_func_return_shareable)
- self.assertIs(exc, None)
+ interpid = _interpreters.create()
+
+ # Here we focus on basic args and return values.
+ # See TestInterpreterCall for full operational coverage,
+ # including supported callables.
+
+ with self.subTest('no args, return None'):
+ func = defs.spam_minimal
+ res, exc = _interpreters.call(interpid, func)
+ self.assertIsNone(exc)
+ self.assertIsNone(res)
+
+ with self.subTest('empty args, return None'):
+ func = defs.spam_minimal
+ res, exc = _interpreters.call(interpid, func, (), {})
+ self.assertIsNone(exc)
+ self.assertIsNone(res)
+
+ with self.subTest('no args, return non-None'):
+ func = defs.script_with_return
+ res, exc = _interpreters.call(interpid, func)
+ self.assertIsNone(exc)
+ self.assertIs(res, True)
+
+ with self.subTest('full args, return non-None'):
+ expected = (1, 2, 3, 4, 5, 6, (7, 8), {'g': 9, 'h': 0})
+ func = defs.spam_full_args
+ args = (1, 2, 3, 4, 7, 8)
+ kwargs = dict(e=5, f=6, g=9, h=0)
+ res, exc = _interpreters.call(interpid, func, args, kwargs)
+ self.assertIsNone(exc)
+ self.assertEqual(res, expected)
with self.subTest('uncaught exception'):
- interpid = _interpreters.create()
- exc = _interpreters.call(interpid, call_func_failure)
- self.assertEqual(exc, types.SimpleNamespace(
- type=types.SimpleNamespace(
- __name__='Exception',
- __qualname__='Exception',
- __module__='builtins',
- ),
- msg='spam!',
+ func = defs.spam_raises
+ res, exc = _interpreters.call(interpid, func)
+ expected = build_excinfo(
+ Exception, 'spam!',
# We check these in other tests.
formatted=exc.formatted,
errdisplay=exc.errdisplay,
- ))
+ )
+ self.assertIsNone(res)
+ self.assertEqual(exc, expected)
@requires_test_modules
def test_set___main___attrs(self):
diff --git a/Lib/test/test_interpreters/test_channels.py b/Lib/test/test_interpreters/test_channels.py
index eada18f99d0..109ddf34453 100644
--- a/Lib/test/test_interpreters/test_channels.py
+++ b/Lib/test/test_interpreters/test_channels.py
@@ -8,8 +8,8 @@ import time
from test.support import import_helper
# Raise SkipTest if subinterpreters not supported.
_channels = import_helper.import_module('_interpchannels')
-from test.support import interpreters
-from test.support.interpreters import channels
+from concurrent import interpreters
+from test.support import channels
from .utils import _run_output, TestBase
@@ -171,7 +171,7 @@ class TestSendRecv(TestBase):
def test_send_recv_same_interpreter(self):
interp = interpreters.create()
interp.exec(dedent("""
- from test.support.interpreters import channels
+ from test.support import channels
r, s = channels.create()
orig = b'spam'
s.send_nowait(orig)
@@ -244,7 +244,7 @@ class TestSendRecv(TestBase):
def test_send_recv_nowait_same_interpreter(self):
interp = interpreters.create()
interp.exec(dedent("""
- from test.support.interpreters import channels
+ from test.support import channels
r, s = channels.create()
orig = b'spam'
s.send_nowait(orig)
@@ -377,17 +377,17 @@ class TestSendRecv(TestBase):
if not unbound:
extraargs = ''
elif unbound is channels.UNBOUND:
- extraargs = ', unbound=channels.UNBOUND'
+ extraargs = ', unbounditems=channels.UNBOUND'
elif unbound is channels.UNBOUND_ERROR:
- extraargs = ', unbound=channels.UNBOUND_ERROR'
+ extraargs = ', unbounditems=channels.UNBOUND_ERROR'
elif unbound is channels.UNBOUND_REMOVE:
- extraargs = ', unbound=channels.UNBOUND_REMOVE'
+ extraargs = ', unbounditems=channels.UNBOUND_REMOVE'
else:
raise NotImplementedError(repr(unbound))
interp = interpreters.create()
_run_output(interp, dedent(f"""
- from test.support.interpreters import channels
+ from test.support import channels
sch = channels.SendChannel({sch.id})
obj1 = b'spam'
obj2 = b'eggs'
@@ -454,11 +454,11 @@ class TestSendRecv(TestBase):
with self.assertRaises(channels.ChannelEmptyError):
rch.recv_nowait()
- sch.send_nowait(b'ham', unbound=channels.UNBOUND_REMOVE)
+ sch.send_nowait(b'ham', unbounditems=channels.UNBOUND_REMOVE)
self.assertEqual(_channels.get_count(rch.id), 1)
interp = common(rch, sch, channels.UNBOUND_REMOVE, 1)
self.assertEqual(_channels.get_count(rch.id), 3)
- sch.send_nowait(42, unbound=channels.UNBOUND_REMOVE)
+ sch.send_nowait(42, unbounditems=channels.UNBOUND_REMOVE)
self.assertEqual(_channels.get_count(rch.id), 4)
del interp
self.assertEqual(_channels.get_count(rch.id), 2)
@@ -482,13 +482,13 @@ class TestSendRecv(TestBase):
self.assertEqual(_channels.get_count(rch.id), 0)
_run_output(interp, dedent(f"""
- from test.support.interpreters import channels
+ from test.support import channels
sch = channels.SendChannel({sch.id})
- sch.send_nowait(1, unbound=channels.UNBOUND)
- sch.send_nowait(2, unbound=channels.UNBOUND_ERROR)
+ sch.send_nowait(1, unbounditems=channels.UNBOUND)
+ sch.send_nowait(2, unbounditems=channels.UNBOUND_ERROR)
sch.send_nowait(3)
- sch.send_nowait(4, unbound=channels.UNBOUND_REMOVE)
- sch.send_nowait(5, unbound=channels.UNBOUND)
+ sch.send_nowait(4, unbounditems=channels.UNBOUND_REMOVE)
+ sch.send_nowait(5, unbounditems=channels.UNBOUND)
"""))
self.assertEqual(_channels.get_count(rch.id), 5)
@@ -518,15 +518,15 @@ class TestSendRecv(TestBase):
sch.send_nowait(1)
_run_output(interp1, dedent(f"""
- from test.support.interpreters import channels
+ from test.support import channels
rch = channels.RecvChannel({rch.id})
sch = channels.SendChannel({sch.id})
obj1 = rch.recv()
- sch.send_nowait(2, unbound=channels.UNBOUND)
- sch.send_nowait(obj1, unbound=channels.UNBOUND_REMOVE)
+ sch.send_nowait(2, unbounditems=channels.UNBOUND)
+ sch.send_nowait(obj1, unbounditems=channels.UNBOUND_REMOVE)
"""))
_run_output(interp2, dedent(f"""
- from test.support.interpreters import channels
+ from test.support import channels
rch = channels.RecvChannel({rch.id})
sch = channels.SendChannel({sch.id})
obj2 = rch.recv()
@@ -535,21 +535,21 @@ class TestSendRecv(TestBase):
self.assertEqual(_channels.get_count(rch.id), 0)
sch.send_nowait(3)
_run_output(interp1, dedent("""
- sch.send_nowait(4, unbound=channels.UNBOUND)
+ sch.send_nowait(4, unbounditems=channels.UNBOUND)
# interp closed here
- sch.send_nowait(5, unbound=channels.UNBOUND_REMOVE)
- sch.send_nowait(6, unbound=channels.UNBOUND)
+ sch.send_nowait(5, unbounditems=channels.UNBOUND_REMOVE)
+ sch.send_nowait(6, unbounditems=channels.UNBOUND)
"""))
_run_output(interp2, dedent("""
- sch.send_nowait(7, unbound=channels.UNBOUND_ERROR)
+ sch.send_nowait(7, unbounditems=channels.UNBOUND_ERROR)
# interp closed here
- sch.send_nowait(obj1, unbound=channels.UNBOUND_ERROR)
- sch.send_nowait(obj2, unbound=channels.UNBOUND_REMOVE)
- sch.send_nowait(8, unbound=channels.UNBOUND)
+ sch.send_nowait(obj1, unbounditems=channels.UNBOUND_ERROR)
+ sch.send_nowait(obj2, unbounditems=channels.UNBOUND_REMOVE)
+ sch.send_nowait(8, unbounditems=channels.UNBOUND)
"""))
_run_output(interp1, dedent("""
- sch.send_nowait(9, unbound=channels.UNBOUND_REMOVE)
- sch.send_nowait(10, unbound=channels.UNBOUND)
+ sch.send_nowait(9, unbounditems=channels.UNBOUND_REMOVE)
+ sch.send_nowait(10, unbounditems=channels.UNBOUND)
"""))
self.assertEqual(_channels.get_count(rch.id), 10)
diff --git a/Lib/test/test_interpreters/test_lifecycle.py b/Lib/test/test_interpreters/test_lifecycle.py
index ac24f6568ac..15537ac6cc8 100644
--- a/Lib/test/test_interpreters/test_lifecycle.py
+++ b/Lib/test/test_interpreters/test_lifecycle.py
@@ -119,7 +119,7 @@ class StartupTests(TestBase):
# The main interpreter's sys.path[0] should be used by subinterpreters.
script = '''
import sys
- from test.support import interpreters
+ from concurrent import interpreters
orig = sys.path[0]
@@ -170,7 +170,7 @@ class FinalizationTests(TestBase):
# is reported, even when subinterpreters get cleaned up at the end.
import subprocess
argv = [sys.executable, '-c', '''if True:
- from test.support import interpreters
+ from concurrent import interpreters
interp = interpreters.create()
raise Exception
''']
diff --git a/Lib/test/test_interpreters/test_queues.py b/Lib/test/test_interpreters/test_queues.py
index 18f83d097eb..cb17340f581 100644
--- a/Lib/test/test_interpreters/test_queues.py
+++ b/Lib/test/test_interpreters/test_queues.py
@@ -7,8 +7,8 @@ import unittest
from test.support import import_helper, Py_DEBUG
# Raise SkipTest if subinterpreters not supported.
_queues = import_helper.import_module('_interpqueues')
-from test.support import interpreters
-from test.support.interpreters import queues, _crossinterp
+from concurrent import interpreters
+from concurrent.interpreters import _queues as queues, _crossinterp
from .utils import _run_output, TestBase as _TestBase
@@ -42,7 +42,7 @@ class LowLevelTests(TestBase):
importlib.reload(queues)
def test_create_destroy(self):
- qid = _queues.create(2, 0, REPLACE)
+ qid = _queues.create(2, REPLACE, -1)
_queues.destroy(qid)
self.assertEqual(get_num_queues(), 0)
with self.assertRaises(queues.QueueNotFoundError):
@@ -56,7 +56,7 @@ class LowLevelTests(TestBase):
'-c',
dedent(f"""
import {_queues.__name__} as _queues
- _queues.create(2, 0, {REPLACE})
+ _queues.create(2, {REPLACE}, -1)
"""),
)
self.assertEqual(stdout, '')
@@ -67,13 +67,13 @@ class LowLevelTests(TestBase):
def test_bind_release(self):
with self.subTest('typical'):
- qid = _queues.create(2, 0, REPLACE)
+ qid = _queues.create(2, REPLACE, -1)
_queues.bind(qid)
_queues.release(qid)
self.assertEqual(get_num_queues(), 0)
with self.subTest('bind too much'):
- qid = _queues.create(2, 0, REPLACE)
+ qid = _queues.create(2, REPLACE, -1)
_queues.bind(qid)
_queues.bind(qid)
_queues.release(qid)
@@ -81,7 +81,7 @@ class LowLevelTests(TestBase):
self.assertEqual(get_num_queues(), 0)
with self.subTest('nested'):
- qid = _queues.create(2, 0, REPLACE)
+ qid = _queues.create(2, REPLACE, -1)
_queues.bind(qid)
_queues.bind(qid)
_queues.release(qid)
@@ -89,7 +89,7 @@ class LowLevelTests(TestBase):
self.assertEqual(get_num_queues(), 0)
with self.subTest('release without binding'):
- qid = _queues.create(2, 0, REPLACE)
+ qid = _queues.create(2, REPLACE, -1)
with self.assertRaises(queues.QueueError):
_queues.release(qid)
@@ -126,19 +126,19 @@ class QueueTests(TestBase):
interp = interpreters.create()
interp.exec(dedent(f"""
- from test.support.interpreters import queues
+ from concurrent.interpreters import _queues as queues
queue1 = queues.Queue({queue1.id})
"""));
with self.subTest('same interpreter'):
queue2 = queues.create()
- queue1.put(queue2, syncobj=True)
+ queue1.put(queue2)
queue3 = queue1.get()
self.assertIs(queue3, queue2)
with self.subTest('from current interpreter'):
queue4 = queues.create()
- queue1.put(queue4, syncobj=True)
+ queue1.put(queue4)
out = _run_output(interp, dedent("""
queue4 = queue1.get()
print(queue4.id)
@@ -149,7 +149,7 @@ class QueueTests(TestBase):
with self.subTest('from subinterpreter'):
out = _run_output(interp, dedent("""
queue5 = queues.create()
- queue1.put(queue5, syncobj=True)
+ queue1.put(queue5)
print(queue5.id)
"""))
qid = int(out)
@@ -198,7 +198,7 @@ class TestQueueOps(TestBase):
def test_empty(self):
queue = queues.create()
before = queue.empty()
- queue.put(None, syncobj=True)
+ queue.put(None)
during = queue.empty()
queue.get()
after = queue.empty()
@@ -208,18 +208,64 @@ class TestQueueOps(TestBase):
self.assertIs(after, True)
def test_full(self):
- expected = [False, False, False, True, False, False, False]
- actual = []
- queue = queues.create(3)
- for _ in range(3):
- actual.append(queue.full())
- queue.put(None, syncobj=True)
- actual.append(queue.full())
- for _ in range(3):
- queue.get()
- actual.append(queue.full())
+ for maxsize in [1, 3, 11]:
+ with self.subTest(f'maxsize={maxsize}'):
+ num_to_add = maxsize
+ expected = [False] * (num_to_add * 2 + 3)
+ expected[maxsize] = True
+ expected[maxsize + 1] = True
+
+ queue = queues.create(maxsize)
+ actual = []
+ empty = [queue.empty()]
+
+ for _ in range(num_to_add):
+ actual.append(queue.full())
+ queue.put_nowait(None)
+ actual.append(queue.full())
+ with self.assertRaises(queues.QueueFull):
+ queue.put_nowait(None)
+ empty.append(queue.empty())
+
+ for _ in range(num_to_add):
+ actual.append(queue.full())
+ queue.get_nowait()
+ actual.append(queue.full())
+ with self.assertRaises(queues.QueueEmpty):
+ queue.get_nowait()
+ actual.append(queue.full())
+ empty.append(queue.empty())
- self.assertEqual(actual, expected)
+ self.assertEqual(actual, expected)
+ self.assertEqual(empty, [True, False, True])
+
+ # no max size
+ for args in [(), (0,), (-1,), (-10,)]:
+ with self.subTest(f'maxsize={args[0]}' if args else '<default>'):
+ num_to_add = 13
+ expected = [False] * (num_to_add * 2 + 3)
+
+ queue = queues.create(*args)
+ actual = []
+ empty = [queue.empty()]
+
+ for _ in range(num_to_add):
+ actual.append(queue.full())
+ queue.put_nowait(None)
+ actual.append(queue.full())
+ empty.append(queue.empty())
+
+ for _ in range(num_to_add):
+ actual.append(queue.full())
+ queue.get_nowait()
+ actual.append(queue.full())
+ with self.assertRaises(queues.QueueEmpty):
+ queue.get_nowait()
+ actual.append(queue.full())
+ empty.append(queue.empty())
+
+ self.assertEqual(actual, expected)
+ self.assertEqual(empty, [True, False, True])
def test_qsize(self):
expected = [0, 1, 2, 3, 2, 3, 2, 1, 0, 1, 0]
@@ -227,16 +273,16 @@ class TestQueueOps(TestBase):
queue = queues.create()
for _ in range(3):
actual.append(queue.qsize())
- queue.put(None, syncobj=True)
+ queue.put(None)
actual.append(queue.qsize())
queue.get()
actual.append(queue.qsize())
- queue.put(None, syncobj=True)
+ queue.put(None)
actual.append(queue.qsize())
for _ in range(3):
queue.get()
actual.append(queue.qsize())
- queue.put(None, syncobj=True)
+ queue.put(None)
actual.append(queue.qsize())
queue.get()
actual.append(queue.qsize())
@@ -245,70 +291,32 @@ class TestQueueOps(TestBase):
def test_put_get_main(self):
expected = list(range(20))
- for syncobj in (True, False):
- kwds = dict(syncobj=syncobj)
- with self.subTest(f'syncobj={syncobj}'):
- queue = queues.create()
- for i in range(20):
- queue.put(i, **kwds)
- actual = [queue.get() for _ in range(20)]
+ queue = queues.create()
+ for i in range(20):
+ queue.put(i)
+ actual = [queue.get() for _ in range(20)]
- self.assertEqual(actual, expected)
+ self.assertEqual(actual, expected)
def test_put_timeout(self):
- for syncobj in (True, False):
- kwds = dict(syncobj=syncobj)
- with self.subTest(f'syncobj={syncobj}'):
- queue = queues.create(2)
- queue.put(None, **kwds)
- queue.put(None, **kwds)
- with self.assertRaises(queues.QueueFull):
- queue.put(None, timeout=0.1, **kwds)
- queue.get()
- queue.put(None, **kwds)
+ queue = queues.create(2)
+ queue.put(None)
+ queue.put(None)
+ with self.assertRaises(queues.QueueFull):
+ queue.put(None, timeout=0.1)
+ queue.get()
+ queue.put(None)
def test_put_nowait(self):
- for syncobj in (True, False):
- kwds = dict(syncobj=syncobj)
- with self.subTest(f'syncobj={syncobj}'):
- queue = queues.create(2)
- queue.put_nowait(None, **kwds)
- queue.put_nowait(None, **kwds)
- with self.assertRaises(queues.QueueFull):
- queue.put_nowait(None, **kwds)
- queue.get()
- queue.put_nowait(None, **kwds)
-
- def test_put_syncobj(self):
- for obj in [
- None,
- True,
- 10,
- 'spam',
- b'spam',
- (0, 'a'),
- ]:
- with self.subTest(repr(obj)):
- queue = queues.create()
-
- queue.put(obj, syncobj=True)
- obj2 = queue.get()
- self.assertEqual(obj2, obj)
-
- queue.put(obj, syncobj=True)
- obj2 = queue.get_nowait()
- self.assertEqual(obj2, obj)
-
- for obj in [
- [1, 2, 3],
- {'a': 13, 'b': 17},
- ]:
- with self.subTest(repr(obj)):
- queue = queues.create()
- with self.assertRaises(interpreters.NotShareableError):
- queue.put(obj, syncobj=True)
+ queue = queues.create(2)
+ queue.put_nowait(None)
+ queue.put_nowait(None)
+ with self.assertRaises(queues.QueueFull):
+ queue.put_nowait(None)
+ queue.get()
+ queue.put_nowait(None)
- def test_put_not_syncobj(self):
+ def test_put_full_fallback(self):
for obj in [
None,
True,
@@ -323,11 +331,11 @@ class TestQueueOps(TestBase):
with self.subTest(repr(obj)):
queue = queues.create()
- queue.put(obj, syncobj=False)
+ queue.put(obj)
obj2 = queue.get()
self.assertEqual(obj2, obj)
- queue.put(obj, syncobj=False)
+ queue.put(obj)
obj2 = queue.get_nowait()
self.assertEqual(obj2, obj)
@@ -341,24 +349,9 @@ class TestQueueOps(TestBase):
with self.assertRaises(queues.QueueEmpty):
queue.get_nowait()
- def test_put_get_default_syncobj(self):
- expected = list(range(20))
- queue = queues.create(syncobj=True)
- for methname in ('get', 'get_nowait'):
- with self.subTest(f'{methname}()'):
- get = getattr(queue, methname)
- for i in range(20):
- queue.put(i)
- actual = [get() for _ in range(20)]
- self.assertEqual(actual, expected)
-
- obj = [1, 2, 3] # lists are not shareable
- with self.assertRaises(interpreters.NotShareableError):
- queue.put(obj)
-
- def test_put_get_default_not_syncobj(self):
+ def test_put_get_full_fallback(self):
expected = list(range(20))
- queue = queues.create(syncobj=False)
+ queue = queues.create()
for methname in ('get', 'get_nowait'):
with self.subTest(f'{methname}()'):
get = getattr(queue, methname)
@@ -377,14 +370,14 @@ class TestQueueOps(TestBase):
def test_put_get_same_interpreter(self):
interp = interpreters.create()
interp.exec(dedent("""
- from test.support.interpreters import queues
+ from concurrent.interpreters import _queues as queues
queue = queues.create()
"""))
for methname in ('get', 'get_nowait'):
with self.subTest(f'{methname}()'):
interp.exec(dedent(f"""
orig = b'spam'
- queue.put(orig, syncobj=True)
+ queue.put(orig)
obj = queue.{methname}()
assert obj == orig, 'expected: obj == orig'
assert obj is not orig, 'expected: obj is not orig'
@@ -399,12 +392,12 @@ class TestQueueOps(TestBase):
for methname in ('get', 'get_nowait'):
with self.subTest(f'{methname}()'):
obj1 = b'spam'
- queue1.put(obj1, syncobj=True)
+ queue1.put(obj1)
out = _run_output(
interp,
dedent(f"""
- from test.support.interpreters import queues
+ from concurrent.interpreters import _queues as queues
queue1 = queues.Queue({queue1.id})
queue2 = queues.Queue({queue2.id})
assert queue1.qsize() == 1, 'expected: queue1.qsize() == 1'
@@ -416,7 +409,7 @@ class TestQueueOps(TestBase):
obj2 = b'eggs'
print(id(obj2))
assert queue2.qsize() == 0, 'expected: queue2.qsize() == 0'
- queue2.put(obj2, syncobj=True)
+ queue2.put(obj2)
assert queue2.qsize() == 1, 'expected: queue2.qsize() == 1'
"""))
self.assertEqual(len(queues.list_all()), 2)
@@ -433,22 +426,22 @@ class TestQueueOps(TestBase):
if not unbound:
extraargs = ''
elif unbound is queues.UNBOUND:
- extraargs = ', unbound=queues.UNBOUND'
+ extraargs = ', unbounditems=queues.UNBOUND'
elif unbound is queues.UNBOUND_ERROR:
- extraargs = ', unbound=queues.UNBOUND_ERROR'
+ extraargs = ', unbounditems=queues.UNBOUND_ERROR'
elif unbound is queues.UNBOUND_REMOVE:
- extraargs = ', unbound=queues.UNBOUND_REMOVE'
+ extraargs = ', unbounditems=queues.UNBOUND_REMOVE'
else:
raise NotImplementedError(repr(unbound))
interp = interpreters.create()
_run_output(interp, dedent(f"""
- from test.support.interpreters import queues
+ from concurrent.interpreters import _queues as queues
queue = queues.Queue({queue.id})
obj1 = b'spam'
obj2 = b'eggs'
- queue.put(obj1, syncobj=True{extraargs})
- queue.put(obj2, syncobj=True{extraargs})
+ queue.put(obj1{extraargs})
+ queue.put(obj2{extraargs})
"""))
self.assertEqual(queue.qsize(), presize + 2)
@@ -501,11 +494,11 @@ class TestQueueOps(TestBase):
with self.assertRaises(queues.QueueEmpty):
queue.get_nowait()
- queue.put(b'ham', unbound=queues.UNBOUND_REMOVE)
+ queue.put(b'ham', unbounditems=queues.UNBOUND_REMOVE)
self.assertEqual(queue.qsize(), 1)
interp = common(queue, queues.UNBOUND_REMOVE, 1)
self.assertEqual(queue.qsize(), 3)
- queue.put(42, unbound=queues.UNBOUND_REMOVE)
+ queue.put(42, unbounditems=queues.UNBOUND_REMOVE)
self.assertEqual(queue.qsize(), 4)
del interp
self.assertEqual(queue.qsize(), 2)
@@ -521,13 +514,13 @@ class TestQueueOps(TestBase):
queue = queues.create()
interp = interpreters.create()
_run_output(interp, dedent(f"""
- from test.support.interpreters import queues
+ from concurrent.interpreters import _queues as queues
queue = queues.Queue({queue.id})
- queue.put(1, syncobj=True, unbound=queues.UNBOUND)
- queue.put(2, syncobj=True, unbound=queues.UNBOUND_ERROR)
- queue.put(3, syncobj=True)
- queue.put(4, syncobj=True, unbound=queues.UNBOUND_REMOVE)
- queue.put(5, syncobj=True, unbound=queues.UNBOUND)
+ queue.put(1, unbounditems=queues.UNBOUND)
+ queue.put(2, unbounditems=queues.UNBOUND_ERROR)
+ queue.put(3)
+ queue.put(4, unbounditems=queues.UNBOUND_REMOVE)
+ queue.put(5, unbounditems=queues.UNBOUND)
"""))
self.assertEqual(queue.qsize(), 5)
@@ -555,16 +548,16 @@ class TestQueueOps(TestBase):
interp1 = interpreters.create()
interp2 = interpreters.create()
- queue.put(1, syncobj=True)
+ queue.put(1)
_run_output(interp1, dedent(f"""
- from test.support.interpreters import queues
+ from concurrent.interpreters import _queues as queues
queue = queues.Queue({queue.id})
obj1 = queue.get()
- queue.put(2, syncobj=True, unbound=queues.UNBOUND)
- queue.put(obj1, syncobj=True, unbound=queues.UNBOUND_REMOVE)
+ queue.put(2, unbounditems=queues.UNBOUND)
+ queue.put(obj1, unbounditems=queues.UNBOUND_REMOVE)
"""))
_run_output(interp2, dedent(f"""
- from test.support.interpreters import queues
+ from concurrent.interpreters import _queues as queues
queue = queues.Queue({queue.id})
obj2 = queue.get()
obj1 = queue.get()
@@ -572,21 +565,21 @@ class TestQueueOps(TestBase):
self.assertEqual(queue.qsize(), 0)
queue.put(3)
_run_output(interp1, dedent("""
- queue.put(4, syncobj=True, unbound=queues.UNBOUND)
+ queue.put(4, unbounditems=queues.UNBOUND)
# interp closed here
- queue.put(5, syncobj=True, unbound=queues.UNBOUND_REMOVE)
- queue.put(6, syncobj=True, unbound=queues.UNBOUND)
+ queue.put(5, unbounditems=queues.UNBOUND_REMOVE)
+ queue.put(6, unbounditems=queues.UNBOUND)
"""))
_run_output(interp2, dedent("""
- queue.put(7, syncobj=True, unbound=queues.UNBOUND_ERROR)
+ queue.put(7, unbounditems=queues.UNBOUND_ERROR)
# interp closed here
- queue.put(obj1, syncobj=True, unbound=queues.UNBOUND_ERROR)
- queue.put(obj2, syncobj=True, unbound=queues.UNBOUND_REMOVE)
- queue.put(8, syncobj=True, unbound=queues.UNBOUND)
+ queue.put(obj1, unbounditems=queues.UNBOUND_ERROR)
+ queue.put(obj2, unbounditems=queues.UNBOUND_REMOVE)
+ queue.put(8, unbounditems=queues.UNBOUND)
"""))
_run_output(interp1, dedent("""
- queue.put(9, syncobj=True, unbound=queues.UNBOUND_REMOVE)
- queue.put(10, syncobj=True, unbound=queues.UNBOUND)
+ queue.put(9, unbounditems=queues.UNBOUND_REMOVE)
+ queue.put(10, unbounditems=queues.UNBOUND)
"""))
self.assertEqual(queue.qsize(), 10)
@@ -642,12 +635,12 @@ class TestQueueOps(TestBase):
break
except queues.QueueEmpty:
continue
- queue2.put(obj, syncobj=True)
+ queue2.put(obj)
t = threading.Thread(target=f)
t.start()
orig = b'spam'
- queue1.put(orig, syncobj=True)
+ queue1.put(orig)
obj = queue2.get()
t.join()
diff --git a/Lib/test/test_interpreters/test_stress.py b/Lib/test/test_interpreters/test_stress.py
index 56bfc172199..e25e67a0d4f 100644
--- a/Lib/test/test_interpreters/test_stress.py
+++ b/Lib/test/test_interpreters/test_stress.py
@@ -6,7 +6,7 @@ from test.support import import_helper
from test.support import threading_helper
# Raise SkipTest if subinterpreters not supported.
import_helper.import_module('_interpreters')
-from test.support import interpreters
+from concurrent import interpreters
from .utils import TestBase
@@ -21,21 +21,29 @@ class StressTests(TestBase):
for _ in range(100):
interp = interpreters.create()
alive.append(interp)
+ del alive
+ support.gc_collect()
- @support.requires_resource('cpu')
- @threading_helper.requires_working_threading()
- def test_create_many_threaded(self):
+ @support.bigmemtest(size=200, memuse=32*2**20, dry_run=False)
+ def test_create_many_threaded(self, size):
alive = []
+ start = threading.Event()
def task():
+ # try to create all interpreters simultaneously
+ if not start.wait(support.SHORT_TIMEOUT):
+ raise TimeoutError
interp = interpreters.create()
alive.append(interp)
- threads = (threading.Thread(target=task) for _ in range(200))
+ threads = [threading.Thread(target=task) for _ in range(size)]
with threading_helper.start_threads(threads):
- pass
+ start.set()
+ del alive
+ support.gc_collect()
- @support.requires_resource('cpu')
@threading_helper.requires_working_threading()
- def test_many_threads_running_interp_in_other_interp(self):
+ @support.bigmemtest(size=200, memuse=34*2**20, dry_run=False)
+ def test_many_threads_running_interp_in_other_interp(self, size):
+ start = threading.Event()
interp = interpreters.create()
script = f"""if True:
@@ -47,6 +55,9 @@ class StressTests(TestBase):
interp = interpreters.create()
alreadyrunning = (f'{interpreters.InterpreterError}: '
'interpreter already running')
+ # try to run all interpreters simultaneously
+ if not start.wait(support.SHORT_TIMEOUT):
+ raise TimeoutError
success = False
while not success:
try:
@@ -58,9 +69,10 @@ class StressTests(TestBase):
else:
success = True
- threads = (threading.Thread(target=run) for _ in range(200))
+ threads = [threading.Thread(target=run) for _ in range(size)]
with threading_helper.start_threads(threads):
- pass
+ start.set()
+ support.gc_collect()
if __name__ == '__main__':
diff --git a/Lib/test/test_interpreters/utils.py b/Lib/test/test_interpreters/utils.py
index fc4ad662e03..ae09aa457b4 100644
--- a/Lib/test/test_interpreters/utils.py
+++ b/Lib/test/test_interpreters/utils.py
@@ -12,7 +12,6 @@ from textwrap import dedent
import threading
import types
import unittest
-import warnings
from test import support
@@ -22,7 +21,7 @@ try:
import _interpreters
except ImportError as exc:
raise unittest.SkipTest(str(exc))
-from test.support import interpreters
+from concurrent import interpreters
try:
diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py
index 545643aa455..0c921ffbc25 100644
--- a/Lib/test/test_io.py
+++ b/Lib/test/test_io.py
@@ -572,7 +572,7 @@ class IOTest(unittest.TestCase):
for [test, abilities] in tests:
with self.subTest(test):
if test == pipe_writer and not threading_helper.can_start_thread:
- skipTest()
+ self.skipTest("Need threads")
with test() as obj:
do_test(test, obj, abilities)
@@ -902,7 +902,7 @@ class IOTest(unittest.TestCase):
self.BytesIO()
)
for obj in test:
- self.assertTrue(hasattr(obj, "__dict__"))
+ self.assertHasAttr(obj, "__dict__")
def test_opener(self):
with self.open(os_helper.TESTFN, "w", encoding="utf-8") as f:
@@ -918,7 +918,7 @@ class IOTest(unittest.TestCase):
def badopener(fname, flags):
return -1
with self.assertRaises(ValueError) as cm:
- open('non-existent', 'r', opener=badopener)
+ self.open('non-existent', 'r', opener=badopener)
self.assertEqual(str(cm.exception), 'opener returned -1')
def test_bad_opener_other_negative(self):
@@ -926,7 +926,7 @@ class IOTest(unittest.TestCase):
def badopener(fname, flags):
return -2
with self.assertRaises(ValueError) as cm:
- open('non-existent', 'r', opener=badopener)
+ self.open('non-existent', 'r', opener=badopener)
self.assertEqual(str(cm.exception), 'opener returned -2')
def test_opener_invalid_fd(self):
@@ -1062,6 +1062,37 @@ class IOTest(unittest.TestCase):
# Silence destructor error
R.flush = lambda self: None
+ @threading_helper.requires_working_threading()
+ def test_write_readline_races(self):
+ # gh-134908: Concurrent iteration over a file caused races
+ thread_count = 2
+ write_count = 100
+ read_count = 100
+
+ def writer(file, barrier):
+ barrier.wait()
+ for _ in range(write_count):
+ file.write("x")
+
+ def reader(file, barrier):
+ barrier.wait()
+ for _ in range(read_count):
+ for line in file:
+ self.assertEqual(line, "")
+
+ with self.open(os_helper.TESTFN, "w+") as f:
+ barrier = threading.Barrier(thread_count + 1)
+ reader = threading.Thread(target=reader, args=(f, barrier))
+ writers = [threading.Thread(target=writer, args=(f, barrier))
+ for _ in range(thread_count)]
+ with threading_helper.catch_threading_exception() as cm:
+ with threading_helper.start_threads(writers + [reader]):
+ pass
+ self.assertIsNone(cm.exc_type)
+
+ self.assertEqual(os.stat(os_helper.TESTFN).st_size,
+ write_count * thread_count)
+
class CIOTest(IOTest):
@@ -1117,7 +1148,7 @@ class TestIOCTypes(unittest.TestCase):
def check_subs(types, base):
for tp in types:
with self.subTest(tp=tp, base=base):
- self.assertTrue(issubclass(tp, base))
+ self.assertIsSubclass(tp, base)
def recursive_check(d):
for k, v in d.items():
@@ -1373,6 +1404,28 @@ class CommonBufferedTests:
with self.assertRaises(AttributeError):
buf.raw = x
+ def test_pickling_subclass(self):
+ global MyBufferedIO
+ class MyBufferedIO(self.tp):
+ def __init__(self, raw, tag):
+ super().__init__(raw)
+ self.tag = tag
+ def __getstate__(self):
+ return self.tag, self.raw.getvalue()
+ def __setstate__(slf, state):
+ tag, value = state
+ slf.__init__(self.BytesIO(value), tag)
+
+ raw = self.BytesIO(b'data')
+ buf = MyBufferedIO(raw, tag='ham')
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ with self.subTest(protocol=proto):
+ pickled = pickle.dumps(buf, proto)
+ newbuf = pickle.loads(pickled)
+ self.assertEqual(newbuf.raw.getvalue(), b'data')
+ self.assertEqual(newbuf.tag, 'ham')
+ del MyBufferedIO
+
class SizeofTest:
@@ -1848,7 +1901,7 @@ class BufferedWriterTest(unittest.TestCase, CommonBufferedTests):
flushed = b"".join(writer._write_stack)
# At least (total - 8) bytes were implicitly flushed, perhaps more
# depending on the implementation.
- self.assertTrue(flushed.startswith(contents[:-8]), flushed)
+ self.assertStartsWith(flushed, contents[:-8])
def check_writes(self, intermediate_func):
# Lots of writes, test the flushed output is as expected.
@@ -1918,7 +1971,7 @@ class BufferedWriterTest(unittest.TestCase, CommonBufferedTests):
self.assertEqual(bufio.write(b"ABCDEFGHI"), 9)
s = raw.pop_written()
# Previously buffered bytes were flushed
- self.assertTrue(s.startswith(b"01234567A"), s)
+ self.assertStartsWith(s, b"01234567A")
def test_write_and_rewind(self):
raw = self.BytesIO()
@@ -2214,7 +2267,7 @@ class BufferedRWPairTest(unittest.TestCase):
def test_peek(self):
pair = self.tp(self.BytesIO(b"abcdef"), self.MockRawIO())
- self.assertTrue(pair.peek(3).startswith(b"abc"))
+ self.assertStartsWith(pair.peek(3), b"abc")
self.assertEqual(pair.read(3), b"abc")
def test_readable(self):
@@ -3950,6 +4003,28 @@ class TextIOWrapperTest(unittest.TestCase):
f.write(res)
self.assertEqual(res + f.readline(), 'foo\nbar\n')
+ def test_pickling_subclass(self):
+ global MyTextIO
+ class MyTextIO(self.TextIOWrapper):
+ def __init__(self, raw, tag):
+ super().__init__(raw)
+ self.tag = tag
+ def __getstate__(self):
+ return self.tag, self.buffer.getvalue()
+ def __setstate__(slf, state):
+ tag, value = state
+ slf.__init__(self.BytesIO(value), tag)
+
+ raw = self.BytesIO(b'data')
+ txt = MyTextIO(raw, 'ham')
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ with self.subTest(protocol=proto):
+ pickled = pickle.dumps(txt, proto)
+ newtxt = pickle.loads(pickled)
+ self.assertEqual(newtxt.buffer.getvalue(), b'data')
+ self.assertEqual(newtxt.tag, 'ham')
+ del MyTextIO
+
@unittest.skipUnless(hasattr(os, "pipe"), "requires os.pipe()")
def test_read_non_blocking(self):
import os
@@ -4373,7 +4448,7 @@ class MiscIOTest(unittest.TestCase):
self._check_abc_inheritance(io)
def _check_warn_on_dealloc(self, *args, **kwargs):
- f = open(*args, **kwargs)
+ f = self.open(*args, **kwargs)
r = repr(f)
with self.assertWarns(ResourceWarning) as cm:
f = None
@@ -4402,7 +4477,7 @@ class MiscIOTest(unittest.TestCase):
r, w = os.pipe()
fds += r, w
with warnings_helper.check_no_resource_warning(self):
- open(r, *args, closefd=False, **kwargs)
+ self.open(r, *args, closefd=False, **kwargs)
@unittest.skipUnless(hasattr(os, "pipe"), "requires os.pipe()")
def test_warn_on_dealloc_fd(self):
@@ -4574,10 +4649,8 @@ class MiscIOTest(unittest.TestCase):
proc = assert_python_ok('-X', 'warn_default_encoding', '-c', code)
warnings = proc.err.splitlines()
self.assertEqual(len(warnings), 2)
- self.assertTrue(
- warnings[0].startswith(b"<string>:5: EncodingWarning: "))
- self.assertTrue(
- warnings[1].startswith(b"<string>:8: EncodingWarning: "))
+ self.assertStartsWith(warnings[0], b"<string>:5: EncodingWarning: ")
+ self.assertStartsWith(warnings[1], b"<string>:8: EncodingWarning: ")
def test_text_encoding(self):
# PEP 597, bpo-47000. io.text_encoding() returns "locale" or "utf-8"
@@ -4790,7 +4863,7 @@ class SignalsTest(unittest.TestCase):
os.read(r, len(data) * 100)
exc = cm.exception
if isinstance(exc, RuntimeError):
- self.assertTrue(str(exc).startswith("reentrant call"), str(exc))
+ self.assertStartsWith(str(exc), "reentrant call")
finally:
signal.alarm(0)
wio.close()
diff --git a/Lib/test/test_ioctl.py b/Lib/test/test_ioctl.py
index 7a986048bda..277d2fc99ea 100644
--- a/Lib/test/test_ioctl.py
+++ b/Lib/test/test_ioctl.py
@@ -5,7 +5,7 @@ import sys
import threading
import unittest
from test import support
-from test.support import threading_helper
+from test.support import os_helper, threading_helper
from test.support.import_helper import import_module
fcntl = import_module('fcntl')
termios = import_module('termios')
@@ -127,9 +127,8 @@ class IoctlTestsTty(unittest.TestCase):
self._check_ioctl_not_mutate_len(1024)
def test_ioctl_mutate_2048(self):
- # Test with a larger buffer, just for the record.
self._check_ioctl_mutate_len(2048)
- self.assertRaises(ValueError, self._check_ioctl_not_mutate_len, 2048)
+ self._check_ioctl_not_mutate_len(1024)
@unittest.skipUnless(hasattr(os, 'openpty'), "need os.openpty()")
@@ -202,6 +201,17 @@ class IoctlTestsPty(unittest.TestCase):
new_winsz = struct.unpack("HHHH", result)
self.assertEqual(new_winsz[:2], (20, 40))
+ @unittest.skipUnless(hasattr(fcntl, 'FICLONE'), 'need fcntl.FICLONE')
+ def test_bad_fd(self):
+ # gh-134744: Test error handling
+ fd = os_helper.make_bad_fd()
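+ # The fd is invalid, so every argument form below (int arg, small buffer,
+ # larger buffer) should fail with OSError rather than crash.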
+ with self.assertRaises(OSError):
+ fcntl.ioctl(fd, fcntl.FICLONE, fd)
+ with self.assertRaises(OSError):
+ fcntl.ioctl(fd, fcntl.FICLONE, b'\0' * 10)
+ with self.assertRaises(OSError):
+ fcntl.ioctl(fd, fcntl.FICLONE, b'\0' * 2048)
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_ipaddress.py b/Lib/test/test_ipaddress.py
index d04012d1afd..db1c38243e2 100644
--- a/Lib/test/test_ipaddress.py
+++ b/Lib/test/test_ipaddress.py
@@ -397,6 +397,19 @@ class AddressTestCase_v6(BaseTestCase, CommonTestMixin_v6):
# A trailing IPv4 address is two parts
assertBadSplit("10:9:8:7:6:5:4:3:42.42.42.42%scope")
+ def test_bad_address_split_v6_too_long(self):
+ def assertBadSplit(addr):
+ msg = r"At most 45 characters expected in '%s"
+ with self.assertAddressError(msg, re.escape(addr[:45])):
+ ipaddress.IPv6Address(addr)
+
+ # Long IPv6 address
+ long_addr = ("0:" * 10000) + "0"
+ assertBadSplit(long_addr)
+ assertBadSplit(long_addr + "%zoneid")
+ assertBadSplit(long_addr + ":255.255.255.255")
+ assertBadSplit(long_addr + ":ffff:255.255.255.255")
+
def test_bad_address_split_v6_too_many_parts(self):
def assertBadSplit(addr):
msg = "Exactly 8 parts expected without '::' in %r"
@@ -2178,6 +2191,11 @@ class IpaddrUnitTest(unittest.TestCase):
self.assertEqual(ipaddress.ip_address('FFFF::192.0.2.1'),
ipaddress.ip_address('FFFF::c000:201'))
+ self.assertEqual(ipaddress.ip_address('0000:0000:0000:0000:0000:FFFF:192.168.255.255'),
+ ipaddress.ip_address('::ffff:c0a8:ffff'))
+ self.assertEqual(ipaddress.ip_address('FFFF:0000:0000:0000:0000:0000:192.168.255.255'),
+ ipaddress.ip_address('ffff::c0a8:ffff'))
+
self.assertEqual(ipaddress.ip_address('::FFFF:192.0.2.1%scope'),
ipaddress.ip_address('::FFFF:c000:201%scope'))
self.assertEqual(ipaddress.ip_address('FFFF::192.0.2.1%scope'),
@@ -2190,6 +2208,10 @@ class IpaddrUnitTest(unittest.TestCase):
ipaddress.ip_address('::FFFF:c000:201%scope'))
self.assertNotEqual(ipaddress.ip_address('FFFF::192.0.2.1'),
ipaddress.ip_address('FFFF::c000:201%scope'))
+ self.assertEqual(ipaddress.ip_address('0000:0000:0000:0000:0000:FFFF:192.168.255.255%scope'),
+ ipaddress.ip_address('::ffff:c0a8:ffff%scope'))
+ self.assertEqual(ipaddress.ip_address('FFFF:0000:0000:0000:0000:0000:192.168.255.255%scope'),
+ ipaddress.ip_address('ffff::c0a8:ffff%scope'))
def testIPVersion(self):
self.assertEqual(ipaddress.IPv4Address.version, 4)
@@ -2599,6 +2621,10 @@ class IpaddrUnitTest(unittest.TestCase):
'::7:6:5:4:3:2:0': '0:7:6:5:4:3:2:0/128',
'7:6:5:4:3:2:1::': '7:6:5:4:3:2:1:0/128',
'0:6:5:4:3:2:1::': '0:6:5:4:3:2:1:0/128',
+ '0000:0000:0000:0000:0000:0000:255.255.255.255': '::ffff:ffff/128',
+ '0000:0000:0000:0000:0000:ffff:255.255.255.255': '::ffff:255.255.255.255/128',
+ 'ffff:ffff:ffff:ffff:ffff:ffff:255.255.255.255':
+ 'ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/128',
}
for uncompressed, compressed in list(test_addresses.items()):
self.assertEqual(compressed, str(ipaddress.IPv6Interface(
@@ -2762,6 +2788,34 @@ class IpaddrUnitTest(unittest.TestCase):
ipv6_address2 = ipaddress.IPv6Interface("2001:658:22a:cafe:200:0:0:2")
self.assertNotEqual(ipv6_address1.__hash__(), ipv6_address2.__hash__())
+ # issue 134062 Hash collisions in IPv4Network and IPv6Network
+ def testNetworkV4HashCollisions(self):
+ self.assertNotEqual(
+ ipaddress.IPv4Network("192.168.1.255/32").__hash__(),
+ ipaddress.IPv4Network("192.168.1.0/24").__hash__()
+ )
+ self.assertNotEqual(
+ ipaddress.IPv4Network("172.24.255.0/24").__hash__(),
+ ipaddress.IPv4Network("172.24.0.0/16").__hash__()
+ )
+ self.assertNotEqual(
+ ipaddress.IPv4Network("192.168.1.87/32").__hash__(),
+ ipaddress.IPv4Network("192.168.1.86/31").__hash__()
+ )
+
+ # issue 134062 Hash collisions in IPv4Network and IPv6Network
+ def testNetworkV6HashCollisions(self):
+ self.assertNotEqual(
+ ipaddress.IPv6Network("fe80::/64").__hash__(),
+ ipaddress.IPv6Network("fe80::ffff:ffff:ffff:0/112").__hash__()
+ )
+ self.assertNotEqual(
+ ipaddress.IPv4Network("10.0.0.0/8").__hash__(),
+ ipaddress.IPv6Network(
+ "ffff:ffff:ffff:ffff:ffff:ffff:aff:0/112"
+ ).__hash__()
+ )
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_isinstance.py b/Lib/test/test_isinstance.py
index daad00e8643..f440fc28ee7 100644
--- a/Lib/test/test_isinstance.py
+++ b/Lib/test/test_isinstance.py
@@ -318,6 +318,7 @@ class TestIsInstanceIsSubclass(unittest.TestCase):
self.assertRaises(RecursionError, isinstance, 1, X())
@support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_infinite_recursion_via_bases_tuple(self):
"""Regression test for bpo-30570."""
class Failure(object):
@@ -328,6 +329,7 @@ class TestIsInstanceIsSubclass(unittest.TestCase):
issubclass(Failure(), int)
@support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_infinite_cycle_in_bases(self):
"""Regression test for bpo-30570."""
class X:
diff --git a/Lib/test/test_iter.py b/Lib/test/test_iter.py
index 1b9f3cf7624..18e4b676c53 100644
--- a/Lib/test/test_iter.py
+++ b/Lib/test/test_iter.py
@@ -1147,7 +1147,7 @@ class TestCase(unittest.TestCase):
def test_exception_locations(self):
# The location of an exception raised from __init__ or
- # __next__ should should be the iterator expression
+ # __next__ should be the iterator expression
def init_raises():
try:
diff --git a/Lib/test/test_json/test_dump.py b/Lib/test/test_json/test_dump.py
index 13b40020781..39470754003 100644
--- a/Lib/test/test_json/test_dump.py
+++ b/Lib/test/test_json/test_dump.py
@@ -22,6 +22,14 @@ class TestDump:
self.assertIn('valid_key', o)
self.assertNotIn(b'invalid_key', o)
+ def test_dump_skipkeys_indent_empty(self):
+ v = {b'invalid_key': False}
+ self.assertEqual(self.json.dumps(v, skipkeys=True, indent=4), '{}')
+
+ def test_skipkeys_indent(self):
+ v = {b'invalid_key': False, 'valid_key': True}
+ self.assertEqual(self.json.dumps(v, skipkeys=True, indent=4), '{\n "valid_key": true\n}')
+
def test_encode_truefalse(self):
self.assertEqual(self.dumps(
{True: False, False: True}, sort_keys=True),
diff --git a/Lib/test/test_json/test_fail.py b/Lib/test/test_json/test_fail.py
index 7c1696cc66d..79c44af2fbf 100644
--- a/Lib/test/test_json/test_fail.py
+++ b/Lib/test/test_json/test_fail.py
@@ -102,7 +102,7 @@ class TestFail:
with self.assertRaisesRegex(TypeError,
'Object of type module is not JSON serializable') as cm:
self.dumps(sys)
- self.assertFalse(hasattr(cm.exception, '__notes__'))
+ self.assertNotHasAttr(cm.exception, '__notes__')
with self.assertRaises(TypeError) as cm:
self.dumps([1, [2, 3, sys]])
diff --git a/Lib/test/test_json/test_recursion.py b/Lib/test/test_json/test_recursion.py
index d82093f3895..5d7b56ff9ad 100644
--- a/Lib/test/test_json/test_recursion.py
+++ b/Lib/test/test_json/test_recursion.py
@@ -69,6 +69,7 @@ class TestRecursion:
@support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_highly_nested_objects_decoding(self):
very_deep = 200000
# test that loading highly-nested objects doesn't segfault when C
@@ -85,6 +86,7 @@ class TestRecursion:
@support.skip_wasi_stack_overflow()
@support.skip_emscripten_stack_overflow()
+ @support.requires_resource('cpu')
def test_highly_nested_objects_encoding(self):
# See #12051
l, d = [], {}
@@ -98,6 +100,7 @@ class TestRecursion:
self.dumps(d)
@support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_endless_recursion(self):
# See #12051
class EndlessJSONEncoder(self.json.JSONEncoder):
diff --git a/Lib/test/test_json/test_tool.py b/Lib/test/test_json/test_tool.py
index ba9c42f758e..30f9bb33316 100644
--- a/Lib/test/test_json/test_tool.py
+++ b/Lib/test/test_json/test_tool.py
@@ -6,9 +6,11 @@ import unittest
import subprocess
from test import support
-from test.support import force_not_colorized, os_helper
+from test.support import force_colorized, force_not_colorized, os_helper
from test.support.script_helper import assert_python_ok
+from _colorize import get_theme
+
@support.requires_subprocess()
class TestMain(unittest.TestCase):
@@ -158,7 +160,7 @@ class TestMain(unittest.TestCase):
rc, out, err = assert_python_ok('-m', self.module, '-h',
PYTHON_COLORS='0')
self.assertEqual(rc, 0)
- self.assertTrue(out.startswith(b'usage: '))
+ self.assertStartsWith(out, b'usage: ')
self.assertEqual(err, b'')
def test_sort_keys_flag(self):
@@ -246,34 +248,39 @@ class TestMain(unittest.TestCase):
proc.communicate(b'"{}"')
self.assertEqual(proc.returncode, errno.EPIPE)
+ @force_colorized
def test_colors(self):
infile = os_helper.TESTFN
self.addCleanup(os.remove, infile)
+ t = get_theme().syntax
+ ob = "{"
+ cb = "}"
+
cases = (
- ('{}', b'{}'),
- ('[]', b'[]'),
- ('null', b'\x1b[1;36mnull\x1b[0m'),
- ('true', b'\x1b[1;36mtrue\x1b[0m'),
- ('false', b'\x1b[1;36mfalse\x1b[0m'),
- ('NaN', b'NaN'),
- ('Infinity', b'Infinity'),
- ('-Infinity', b'-Infinity'),
- ('"foo"', b'\x1b[1;32m"foo"\x1b[0m'),
- (r'" \"foo\" "', b'\x1b[1;32m" \\"foo\\" "\x1b[0m'),
- ('"α"', b'\x1b[1;32m"\\u03b1"\x1b[0m'),
- ('123', b'123'),
- ('-1.2345e+23', b'-1.2345e+23'),
+ ('{}', '{}'),
+ ('[]', '[]'),
+ ('null', f'{t.keyword}null{t.reset}'),
+ ('true', f'{t.keyword}true{t.reset}'),
+ ('false', f'{t.keyword}false{t.reset}'),
+ ('NaN', f'{t.number}NaN{t.reset}'),
+ ('Infinity', f'{t.number}Infinity{t.reset}'),
+ ('-Infinity', f'{t.number}-Infinity{t.reset}'),
+ ('"foo"', f'{t.string}"foo"{t.reset}'),
+ (r'" \"foo\" "', f'{t.string}" \\"foo\\" "{t.reset}'),
+ ('"α"', f'{t.string}"\\u03b1"{t.reset}'),
+ ('123', f'{t.number}123{t.reset}'),
+ ('-1.25e+23', f'{t.number}-1.25e+23{t.reset}'),
(r'{"\\": ""}',
- b'''\
-{
- \x1b[94m"\\\\"\x1b[0m: \x1b[1;32m""\x1b[0m
-}'''),
+ f'''\
+{ob}
+ {t.definition}"\\\\"{t.reset}: {t.string}""{t.reset}
+{cb}'''),
(r'{"\\\\": ""}',
- b'''\
-{
- \x1b[94m"\\\\\\\\"\x1b[0m: \x1b[1;32m""\x1b[0m
-}'''),
+ f'''\
+{ob}
+ {t.definition}"\\\\\\\\"{t.reset}: {t.string}""{t.reset}
+{cb}'''),
('''\
{
"foo": "bar",
@@ -281,30 +288,32 @@ class TestMain(unittest.TestCase):
"qux": [true, false, null],
"xyz": [NaN, -Infinity, Infinity]
}''',
- b'''\
-{
- \x1b[94m"foo"\x1b[0m: \x1b[1;32m"bar"\x1b[0m,
- \x1b[94m"baz"\x1b[0m: 1234,
- \x1b[94m"qux"\x1b[0m: [
- \x1b[1;36mtrue\x1b[0m,
- \x1b[1;36mfalse\x1b[0m,
- \x1b[1;36mnull\x1b[0m
+ f'''\
+{ob}
+ {t.definition}"foo"{t.reset}: {t.string}"bar"{t.reset},
+ {t.definition}"baz"{t.reset}: {t.number}1234{t.reset},
+ {t.definition}"qux"{t.reset}: [
+ {t.keyword}true{t.reset},
+ {t.keyword}false{t.reset},
+ {t.keyword}null{t.reset}
],
- \x1b[94m"xyz"\x1b[0m: [
- NaN,
- -Infinity,
- Infinity
+ {t.definition}"xyz"{t.reset}: [
+ {t.number}NaN{t.reset},
+ {t.number}-Infinity{t.reset},
+ {t.number}Infinity{t.reset}
]
-}'''),
+{cb}'''),
)
for input_, expected in cases:
with self.subTest(input=input_):
with open(infile, "w", encoding="utf-8") as fp:
fp.write(input_)
- _, stdout, _ = assert_python_ok('-m', self.module, infile,
- PYTHON_COLORS='1')
- stdout = stdout.replace(b'\r\n', b'\n') # normalize line endings
+ _, stdout_b, _ = assert_python_ok(
+ '-m', self.module, infile, FORCE_COLOR='1', __isolated='1'
+ )
+ stdout = stdout_b.decode()
+ stdout = stdout.replace('\r\n', '\n') # normalize line endings
stdout = stdout.strip()
self.assertEqual(stdout, expected)
diff --git a/Lib/test/test_launcher.py b/Lib/test/test_launcher.py
index 173fc743cf6..caa1603c78e 100644
--- a/Lib/test/test_launcher.py
+++ b/Lib/test/test_launcher.py
@@ -443,7 +443,7 @@ class TestLauncher(unittest.TestCase, RunPyMixin):
except subprocess.CalledProcessError:
raise unittest.SkipTest("requires at least one Python 3.x install")
self.assertEqual("PythonCore", data["env.company"])
- self.assertTrue(data["env.tag"].startswith("3."), data["env.tag"])
+ self.assertStartsWith(data["env.tag"], "3.")
def test_search_major_3_32(self):
try:
@@ -453,8 +453,8 @@ class TestLauncher(unittest.TestCase, RunPyMixin):
raise unittest.SkipTest("requires at least one 32-bit Python 3.x install")
raise
self.assertEqual("PythonCore", data["env.company"])
- self.assertTrue(data["env.tag"].startswith("3."), data["env.tag"])
- self.assertTrue(data["env.tag"].endswith("-32"), data["env.tag"])
+ self.assertStartsWith(data["env.tag"], "3.")
+ self.assertEndsWith(data["env.tag"], "-32")
def test_search_major_2(self):
try:
@@ -463,7 +463,7 @@ class TestLauncher(unittest.TestCase, RunPyMixin):
if not is_installed("2.7"):
raise unittest.SkipTest("requires at least one Python 2.x install")
self.assertEqual("PythonCore", data["env.company"])
- self.assertTrue(data["env.tag"].startswith("2."), data["env.tag"])
+ self.assertStartsWith(data["env.tag"], "2.")
def test_py_default(self):
with self.py_ini(TEST_PY_DEFAULTS):
diff --git a/Lib/test/test_linecache.py b/Lib/test/test_linecache.py
index e4aa41ebb43..02f65338428 100644
--- a/Lib/test/test_linecache.py
+++ b/Lib/test/test_linecache.py
@@ -4,10 +4,12 @@ import linecache
import unittest
import os.path
import tempfile
+import threading
import tokenize
from importlib.machinery import ModuleSpec
from test import support
from test.support import os_helper
+from test.support import threading_helper
from test.support.script_helper import assert_python_ok
@@ -374,5 +376,40 @@ class LineCacheInvalidationTests(unittest.TestCase):
self.assertIn(self.unchanged_file, linecache.cache)
+class MultiThreadingTest(unittest.TestCase):
+ @threading_helper.reap_threads
+ @threading_helper.requires_working_threading()
+ def test_read_write_safety(self):
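+ # Hammer linecache.getline() from many threads at once to exercise
+ # concurrent reads and writes of the shared cache.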
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ filenames = []
+ for i in range(10):
+ name = os.path.join(tmpdirname, f"test_{i}.py")
+ with open(name, "w") as h:
+ h.write("import time\n")
+ h.write("import system\n")
+ filenames.append(name)
+
+ def linecache_get_line(b):
+ b.wait()
+ for _ in range(100):
+ for name in filenames:
+ linecache.getline(name, 1)
+
+ def check(funcs):
+ barrier = threading.Barrier(len(funcs))
+ threads = []
+
+ for func in funcs:
+ thread = threading.Thread(target=func, args=(barrier,))
+
+ threads.append(thread)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ check([linecache_get_line] * 20)
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_list.py b/Lib/test/test_list.py
index 6894fba2ad1..223f34fb696 100644
--- a/Lib/test/test_list.py
+++ b/Lib/test/test_list.py
@@ -365,5 +365,20 @@ class ListTest(list_tests.CommonTest):
rc, _, _ = assert_python_ok("-c", code)
self.assertEqual(rc, 0)
+ def test_list_overwrite_local(self):
+ """Test that overwriting the last reference to the
+ iterable doesn't prematurely free the iterable"""
+
+ def foo(x):
+ self.assertEqual(sys.getrefcount(x), 1)
+ r = 0
+ for i in x:
+ r += i
+ x = None
+ return r
+
+ self.assertEqual(foo(list(range(10))), 45)
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_listcomps.py b/Lib/test/test_listcomps.py
index cffdeeacc5d..70148dc30fc 100644
--- a/Lib/test/test_listcomps.py
+++ b/Lib/test/test_listcomps.py
@@ -716,7 +716,7 @@ class ListComprehensionTest(unittest.TestCase):
def test_exception_locations(self):
# The location of an exception raised from __init__ or
- # __next__ should should be the iterator expression
+ # __next__ should be the iterator expression
def init_raises():
try:
diff --git a/Lib/test/test_locale.py b/Lib/test/test_locale.py
index 528ceef5281..55b502e52ca 100644
--- a/Lib/test/test_locale.py
+++ b/Lib/test/test_locale.py
@@ -1,13 +1,18 @@
from decimal import Decimal
-from test.support import verbose, is_android, linked_to_musl, os_helper
+from test.support import cpython_only, verbose, is_android, linked_to_musl, os_helper
from test.support.warnings_helper import check_warnings
-from test.support.import_helper import import_fresh_module
+from test.support.import_helper import ensure_lazy_imports, import_fresh_module
from unittest import mock
import unittest
import locale
import sys
import codecs
+class LazyImportTest(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("locale", {"re", "warnings"})
+
class BaseLocalizedTest(unittest.TestCase):
#
@@ -382,6 +387,10 @@ class NormalizeTest(unittest.TestCase):
self.check('c', 'C')
self.check('posix', 'C')
+ def test_c_utf8(self):
+ self.check('c.utf8', 'C.UTF-8')
+ self.check('C.UTF-8', 'C.UTF-8')
+
def test_english(self):
self.check('en', 'en_US.ISO8859-1')
self.check('EN', 'en_US.ISO8859-1')
diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py
index de9108288a7..3819965ed2c 100644
--- a/Lib/test/test_logging.py
+++ b/Lib/test/test_logging.py
@@ -61,7 +61,7 @@ import warnings
import weakref
from http.server import HTTPServer, BaseHTTPRequestHandler
-from unittest.mock import patch
+from unittest.mock import call, Mock, patch
from urllib.parse import urlparse, parse_qs
from socketserver import (ThreadingUDPServer, DatagramRequestHandler,
ThreadingTCPServer, StreamRequestHandler)
@@ -1036,7 +1036,7 @@ class TestTCPServer(ControlMixin, ThreadingTCPServer):
"""
allow_reuse_address = True
- allow_reuse_port = True
+ allow_reuse_port = False
def __init__(self, addr, handler, poll_interval=0.5,
bind_and_activate=True):
@@ -5572,12 +5572,19 @@ class BasicConfigTest(unittest.TestCase):
assertRaises = self.assertRaises
handlers = [logging.StreamHandler()]
stream = sys.stderr
+ formatter = logging.Formatter()
assertRaises(ValueError, logging.basicConfig, filename='test.log',
stream=stream)
assertRaises(ValueError, logging.basicConfig, filename='test.log',
handlers=handlers)
assertRaises(ValueError, logging.basicConfig, stream=stream,
handlers=handlers)
+ assertRaises(ValueError, logging.basicConfig, formatter=formatter,
+ format='%(message)s')
+ assertRaises(ValueError, logging.basicConfig, formatter=formatter,
+ datefmt='%H:%M:%S')
+ assertRaises(ValueError, logging.basicConfig, formatter=formatter,
+ style='%')
# Issue 23207: test for invalid kwargs
assertRaises(ValueError, logging.basicConfig, loglevel=logging.INFO)
# Should pop both filename and filemode even if filename is None
@@ -5712,6 +5719,20 @@ class BasicConfigTest(unittest.TestCase):
# didn't write anything due to the encoding error
self.assertEqual(data, r'')
+ def test_formatter_given(self):
+ mock_formatter = Mock()
+ mock_handler = Mock(formatter=None)
+ with patch("logging.Formatter") as mock_formatter_init:
+ logging.basicConfig(formatter=mock_formatter, handlers=[mock_handler])
+ self.assertEqual(mock_handler.setFormatter.call_args_list, [call(mock_formatter)])
+ self.assertEqual(mock_formatter_init.call_count, 0)
+
+ def test_formatter_not_given(self):
+ mock_handler = Mock(formatter=None)
+ with patch("logging.Formatter") as mock_formatter_init:
+ logging.basicConfig(handlers=[mock_handler])
+ self.assertEqual(mock_formatter_init.call_count, 1)
+
@support.requires_working_socket()
def test_log_taskName(self):
async def log_record():
@@ -6740,7 +6761,7 @@ class TimedRotatingFileHandlerTest(BaseFileTest):
rotator = rotators[i]
candidates = rotator.getFilesToDelete()
self.assertEqual(len(candidates), n_files - backupCount, candidates)
- matcher = re.compile(r"^\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\Z")
+ matcher = re.compile(r"^\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\z")
for c in candidates:
d, fn = os.path.split(c)
self.assertStartsWith(fn, prefix+'.')
diff --git a/Lib/test/test_lzma.py b/Lib/test/test_lzma.py
index 9ffb93e797d..e93c3c37354 100644
--- a/Lib/test/test_lzma.py
+++ b/Lib/test/test_lzma.py
@@ -1025,12 +1025,12 @@ class FileTestCase(unittest.TestCase):
with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
result = f.peek()
self.assertGreater(len(result), 0)
- self.assertTrue(INPUT.startswith(result))
+ self.assertStartsWith(INPUT, result)
self.assertEqual(f.read(), INPUT)
with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
result = f.peek(10)
self.assertGreater(len(result), 0)
- self.assertTrue(INPUT.startswith(result))
+ self.assertStartsWith(INPUT, result)
self.assertEqual(f.read(), INPUT)
def test_peek_bad_args(self):
diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py
index 913a60bf9e0..46cb54647b1 100644
--- a/Lib/test/test_math.py
+++ b/Lib/test/test_math.py
@@ -475,6 +475,19 @@ class MathTests(unittest.TestCase):
# similarly, copysign(2., NAN) could be 2. or -2.
self.assertEqual(abs(math.copysign(2., NAN)), 2.)
+ def test_signbit(self):
+ self.assertRaises(TypeError, math.signbit)
+ self.assertRaises(TypeError, math.signbit, '1.0')
+
+ # C11, §7.12.3.6 requires signbit() to return a nonzero value
+ # if and only if the sign of its argument value is negative,
+ # but in practice, we are only interested in a boolean value.
+ self.assertIsInstance(math.signbit(1.0), bool)
+
+ for arg in [0., 1., INF, NAN]:
+ self.assertFalse(math.signbit(arg))
+ self.assertTrue(math.signbit(-arg))
+
def testCos(self):
self.assertRaises(TypeError, math.cos)
self.ftest('cos(-pi/2)', math.cos(-math.pi/2), 0, abs_tol=math.ulp(1))
@@ -1214,6 +1227,12 @@ class MathTests(unittest.TestCase):
self.assertEqual(math.ldexp(NINF, n), NINF)
self.assertTrue(math.isnan(math.ldexp(NAN, n)))
+ @requires_IEEE_754
+ def testLdexp_denormal(self):
+ # Denormal output incorrectly rounded (truncated)
+ # on some Windows.
+ self.assertEqual(math.ldexp(6993274598585239, -1126), 1e-323)
+
def testLog(self):
self.assertRaises(TypeError, math.log)
self.assertRaises(TypeError, math.log, 1, 2, 3)
@@ -1381,7 +1400,6 @@ class MathTests(unittest.TestCase):
args = ((-5, -5, 10), (1.5, 4611686018427387904, 2305843009213693952))
self.assertEqual(sumprod(*args), 0.0)
-
@requires_IEEE_754
@unittest.skipIf(HAVE_DOUBLE_ROUNDING,
"sumprod() accuracy not guaranteed on machines with double rounding")
@@ -1967,6 +1985,28 @@ class MathTests(unittest.TestCase):
self.assertFalse(math.isfinite(float("inf")))
self.assertFalse(math.isfinite(float("-inf")))
+ def testIsnormal(self):
+ self.assertTrue(math.isnormal(1.25))
+ self.assertTrue(math.isnormal(-1.0))
+ self.assertFalse(math.isnormal(0.0))
+ self.assertFalse(math.isnormal(-0.0))
+ self.assertFalse(math.isnormal(INF))
+ self.assertFalse(math.isnormal(NINF))
+ self.assertFalse(math.isnormal(NAN))
+ self.assertFalse(math.isnormal(FLOAT_MIN/2))
+ self.assertFalse(math.isnormal(-FLOAT_MIN/2))
+
+ def testIssubnormal(self):
+ self.assertFalse(math.issubnormal(1.25))
+ self.assertFalse(math.issubnormal(-1.0))
+ self.assertFalse(math.issubnormal(0.0))
+ self.assertFalse(math.issubnormal(-0.0))
+ self.assertFalse(math.issubnormal(INF))
+ self.assertFalse(math.issubnormal(NINF))
+ self.assertFalse(math.issubnormal(NAN))
+ self.assertTrue(math.issubnormal(FLOAT_MIN/2))
+ self.assertTrue(math.issubnormal(-FLOAT_MIN/2))
+
def testIsnan(self):
self.assertTrue(math.isnan(float("nan")))
self.assertTrue(math.isnan(float("-nan")))
@@ -2458,7 +2498,6 @@ class MathTests(unittest.TestCase):
with self.assertRaises(ValueError):
math.nextafter(1.0, INF, steps=-1)
-
@requires_IEEE_754
def test_ulp(self):
self.assertEqual(math.ulp(1.0), sys.float_info.epsilon)
diff --git a/Lib/test/test_memoryio.py b/Lib/test/test_memoryio.py
index 95629ed862d..63998a86c45 100644
--- a/Lib/test/test_memoryio.py
+++ b/Lib/test/test_memoryio.py
@@ -265,8 +265,8 @@ class MemoryTestMixin:
memio = self.ioclass(buf * 10)
self.assertEqual(iter(memio), memio)
- self.assertTrue(hasattr(memio, '__iter__'))
- self.assertTrue(hasattr(memio, '__next__'))
+ self.assertHasAttr(memio, '__iter__')
+ self.assertHasAttr(memio, '__next__')
i = 0
for line in memio:
self.assertEqual(line, buf)
diff --git a/Lib/test/test_memoryview.py b/Lib/test/test_memoryview.py
index 61b068c630c..64f440f180b 100644
--- a/Lib/test/test_memoryview.py
+++ b/Lib/test/test_memoryview.py
@@ -743,19 +743,21 @@ class RacingTest(unittest.TestCase):
from multiprocessing.managers import SharedMemoryManager
except ImportError:
self.skipTest("Test requires multiprocessing")
- from threading import Thread
+ from threading import Thread, Event
- n = 100
+ start = Event()
with SharedMemoryManager() as smm:
obj = smm.ShareableList(range(100))
- threads = []
- for _ in range(n):
- # Issue gh-127085, the `ShareableList.count` is just a convenient way to mess the `exports`
- # counter of `memoryview`, this issue has no direct relation with `ShareableList`.
- threads.append(Thread(target=obj.count, args=(1,)))
-
+ def test():
+ # Issue gh-127085: `ShareableList.count` is just a convenient
+ # way to mess with the `exports` counter of `memoryview`;
+ # the issue has no direct relation to `ShareableList`.
+ start.wait(support.SHORT_TIMEOUT)
+ for i in range(10):
+ obj.count(1)
+ threads = [Thread(target=test) for _ in range(10)]
with threading_helper.start_threads(threads):
- pass
+ start.set()
del obj
diff --git a/Lib/test/test_mimetypes.py b/Lib/test/test_mimetypes.py
index dad5dbde7cd..fb57d5e5544 100644
--- a/Lib/test/test_mimetypes.py
+++ b/Lib/test/test_mimetypes.py
@@ -6,7 +6,8 @@ import sys
import unittest.mock
from platform import win32_edition
from test import support
-from test.support import os_helper
+from test.support import cpython_only, force_not_colorized, os_helper
+from test.support.import_helper import ensure_lazy_imports
try:
import _winapi
@@ -435,8 +436,13 @@ class MiscTestCase(unittest.TestCase):
def test__all__(self):
support.check__all__(self, mimetypes)
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("mimetypes", {"os", "posixpath", "urllib.parse", "argparse"})
+
class CommandLineTest(unittest.TestCase):
+ @force_not_colorized
def test_parse_args(self):
args, help_text = mimetypes._parse_args("-h")
self.assertTrue(help_text.startswith("usage: "))
diff --git a/Lib/test/test_minidom.py b/Lib/test/test_minidom.py
index 6679c0a4fbe..4f25e9c2a03 100644
--- a/Lib/test/test_minidom.py
+++ b/Lib/test/test_minidom.py
@@ -102,41 +102,38 @@ class MinidomTest(unittest.TestCase):
elem = root.childNodes[0]
nelem = dom.createElement("element")
root.insertBefore(nelem, elem)
- self.confirm(len(root.childNodes) == 2
- and root.childNodes.length == 2
- and root.childNodes[0] is nelem
- and root.childNodes.item(0) is nelem
- and root.childNodes[1] is elem
- and root.childNodes.item(1) is elem
- and root.firstChild is nelem
- and root.lastChild is elem
- and root.toxml() == "<doc><element/><foo/></doc>"
- , "testInsertBefore -- node properly placed in tree")
+ self.assertEqual(len(root.childNodes), 2)
+ self.assertEqual(root.childNodes.length, 2)
+ self.assertIs(root.childNodes[0], nelem)
+ self.assertIs(root.childNodes.item(0), nelem)
+ self.assertIs(root.childNodes[1], elem)
+ self.assertIs(root.childNodes.item(1), elem)
+ self.assertIs(root.firstChild, nelem)
+ self.assertIs(root.lastChild, elem)
+ self.assertEqual(root.toxml(), "<doc><element/><foo/></doc>")
nelem = dom.createElement("element")
root.insertBefore(nelem, None)
- self.confirm(len(root.childNodes) == 3
- and root.childNodes.length == 3
- and root.childNodes[1] is elem
- and root.childNodes.item(1) is elem
- and root.childNodes[2] is nelem
- and root.childNodes.item(2) is nelem
- and root.lastChild is nelem
- and nelem.previousSibling is elem
- and root.toxml() == "<doc><element/><foo/><element/></doc>"
- , "testInsertBefore -- node properly placed in tree")
+ self.assertEqual(len(root.childNodes), 3)
+ self.assertEqual(root.childNodes.length, 3)
+ self.assertIs(root.childNodes[1], elem)
+ self.assertIs(root.childNodes.item(1), elem)
+ self.assertIs(root.childNodes[2], nelem)
+ self.assertIs(root.childNodes.item(2), nelem)
+ self.assertIs(root.lastChild, nelem)
+ self.assertIs(nelem.previousSibling, elem)
+ self.assertEqual(root.toxml(), "<doc><element/><foo/><element/></doc>")
nelem2 = dom.createElement("bar")
root.insertBefore(nelem2, nelem)
- self.confirm(len(root.childNodes) == 4
- and root.childNodes.length == 4
- and root.childNodes[2] is nelem2
- and root.childNodes.item(2) is nelem2
- and root.childNodes[3] is nelem
- and root.childNodes.item(3) is nelem
- and nelem2.nextSibling is nelem
- and nelem.previousSibling is nelem2
- and root.toxml() ==
- "<doc><element/><foo/><bar/><element/></doc>"
- , "testInsertBefore -- node properly placed in tree")
+ self.assertEqual(len(root.childNodes), 4)
+ self.assertEqual(root.childNodes.length, 4)
+ self.assertIs(root.childNodes[2], nelem2)
+ self.assertIs(root.childNodes.item(2), nelem2)
+ self.assertIs(root.childNodes[3], nelem)
+ self.assertIs(root.childNodes.item(3), nelem)
+ self.assertIs(nelem2.nextSibling, nelem)
+ self.assertIs(nelem.previousSibling, nelem2)
+ self.assertEqual(root.toxml(),
+ "<doc><element/><foo/><bar/><element/></doc>")
dom.unlink()
def _create_fragment_test_nodes(self):
@@ -342,8 +339,8 @@ class MinidomTest(unittest.TestCase):
self.assertRaises(xml.dom.NotFoundErr, child.removeAttributeNode,
None)
self.assertIs(node, child.removeAttributeNode(node))
- self.confirm(len(child.attributes) == 0
- and child.getAttributeNode("spam") is None)
+ self.assertEqual(len(child.attributes), 0)
+ self.assertIsNone(child.getAttributeNode("spam"))
dom2 = Document()
child2 = dom2.appendChild(dom2.createElement("foo"))
node2 = child2.getAttributeNode("spam")
@@ -366,33 +363,34 @@ class MinidomTest(unittest.TestCase):
# Set this attribute to be an ID and make sure that doesn't change
# when changing the value:
el.setIdAttribute("spam")
- self.confirm(len(el.attributes) == 1
- and el.attributes["spam"].value == "bam"
- and el.attributes["spam"].nodeValue == "bam"
- and el.getAttribute("spam") == "bam"
- and el.getAttributeNode("spam").isId)
+ self.assertEqual(len(el.attributes), 1)
+ self.assertEqual(el.attributes["spam"].value, "bam")
+ self.assertEqual(el.attributes["spam"].nodeValue, "bam")
+ self.assertEqual(el.getAttribute("spam"), "bam")
+ self.assertTrue(el.getAttributeNode("spam").isId)
el.attributes["spam"] = "ham"
- self.confirm(len(el.attributes) == 1
- and el.attributes["spam"].value == "ham"
- and el.attributes["spam"].nodeValue == "ham"
- and el.getAttribute("spam") == "ham"
- and el.attributes["spam"].isId)
+ self.assertEqual(len(el.attributes), 1)
+ self.assertEqual(el.attributes["spam"].value, "ham")
+ self.assertEqual(el.attributes["spam"].nodeValue, "ham")
+ self.assertEqual(el.getAttribute("spam"), "ham")
+ self.assertTrue(el.attributes["spam"].isId)
el.setAttribute("spam2", "bam")
- self.confirm(len(el.attributes) == 2
- and el.attributes["spam"].value == "ham"
- and el.attributes["spam"].nodeValue == "ham"
- and el.getAttribute("spam") == "ham"
- and el.attributes["spam2"].value == "bam"
- and el.attributes["spam2"].nodeValue == "bam"
- and el.getAttribute("spam2") == "bam")
+ self.assertEqual(len(el.attributes), 2)
+ self.assertEqual(el.attributes["spam"].value, "ham")
+ self.assertEqual(el.attributes["spam"].nodeValue, "ham")
+ self.assertEqual(el.getAttribute("spam"), "ham")
+ self.assertEqual(el.attributes["spam2"].value, "bam")
+ self.assertEqual(el.attributes["spam2"].nodeValue, "bam")
+ self.assertEqual(el.getAttribute("spam2"), "bam")
el.attributes["spam2"] = "bam2"
- self.confirm(len(el.attributes) == 2
- and el.attributes["spam"].value == "ham"
- and el.attributes["spam"].nodeValue == "ham"
- and el.getAttribute("spam") == "ham"
- and el.attributes["spam2"].value == "bam2"
- and el.attributes["spam2"].nodeValue == "bam2"
- and el.getAttribute("spam2") == "bam2")
+
+ self.assertEqual(len(el.attributes), 2)
+ self.assertEqual(el.attributes["spam"].value, "ham")
+ self.assertEqual(el.attributes["spam"].nodeValue, "ham")
+ self.assertEqual(el.getAttribute("spam"), "ham")
+ self.assertEqual(el.attributes["spam2"].value, "bam2")
+ self.assertEqual(el.attributes["spam2"].nodeValue, "bam2")
+ self.assertEqual(el.getAttribute("spam2"), "bam2")
dom.unlink()
def testGetAttrList(self):
@@ -448,12 +446,12 @@ class MinidomTest(unittest.TestCase):
dom = parseString(d)
elems = dom.getElementsByTagNameNS("http://pyxml.sf.net/minidom",
"myelem")
- self.confirm(len(elems) == 1
- and elems[0].namespaceURI == "http://pyxml.sf.net/minidom"
- and elems[0].localName == "myelem"
- and elems[0].prefix == "minidom"
- and elems[0].tagName == "minidom:myelem"
- and elems[0].nodeName == "minidom:myelem")
+ self.assertEqual(len(elems), 1)
+ self.assertEqual(elems[0].namespaceURI, "http://pyxml.sf.net/minidom")
+ self.assertEqual(elems[0].localName, "myelem")
+ self.assertEqual(elems[0].prefix, "minidom")
+ self.assertEqual(elems[0].tagName, "minidom:myelem")
+ self.assertEqual(elems[0].nodeName, "minidom:myelem")
dom.unlink()
def get_empty_nodelist_from_elements_by_tagName_ns_helper(self, doc, nsuri,
@@ -602,17 +600,17 @@ class MinidomTest(unittest.TestCase):
def testProcessingInstruction(self):
dom = parseString('<e><?mypi \t\n data \t\n ?></e>')
pi = dom.documentElement.firstChild
- self.confirm(pi.target == "mypi"
- and pi.data == "data \t\n "
- and pi.nodeName == "mypi"
- and pi.nodeType == Node.PROCESSING_INSTRUCTION_NODE
- and pi.attributes is None
- and not pi.hasChildNodes()
- and len(pi.childNodes) == 0
- and pi.firstChild is None
- and pi.lastChild is None
- and pi.localName is None
- and pi.namespaceURI == xml.dom.EMPTY_NAMESPACE)
+ self.assertEqual(pi.target, "mypi")
+ self.assertEqual(pi.data, "data \t\n ")
+ self.assertEqual(pi.nodeName, "mypi")
+ self.assertEqual(pi.nodeType, Node.PROCESSING_INSTRUCTION_NODE)
+ self.assertIsNone(pi.attributes)
+ self.assertFalse(pi.hasChildNodes())
+ self.assertEqual(len(pi.childNodes), 0)
+ self.assertIsNone(pi.firstChild)
+ self.assertIsNone(pi.lastChild)
+ self.assertIsNone(pi.localName)
+ self.assertEqual(pi.namespaceURI, xml.dom.EMPTY_NAMESPACE)
def testProcessingInstructionRepr(self):
dom = parseString('<e><?mypi \t\n data \t\n ?></e>')
@@ -718,19 +716,16 @@ class MinidomTest(unittest.TestCase):
keys2 = list(attrs2.keys())
keys1.sort()
keys2.sort()
- self.assertEqual(keys1, keys2,
- "clone of element has same attribute keys")
+ self.assertEqual(keys1, keys2)
for i in range(len(keys1)):
a1 = attrs1.item(i)
a2 = attrs2.item(i)
- self.confirm(a1 is not a2
- and a1.value == a2.value
- and a1.nodeValue == a2.nodeValue
- and a1.namespaceURI == a2.namespaceURI
- and a1.localName == a2.localName
- , "clone of attribute node has proper attribute values")
- self.assertIs(a2.ownerElement, e2,
- "clone of attribute node correctly owned")
+ self.assertIsNot(a1, a2)
+ self.assertEqual(a1.value, a2.value)
+ self.assertEqual(a1.nodeValue, a2.nodeValue)
+ self.assertEqual(a1.namespaceURI, a2.namespaceURI)
+ self.assertEqual(a1.localName, a2.localName)
+ self.assertIs(a2.ownerElement, e2)
def _setupCloneElement(self, deep):
dom = parseString("<doc attr='value'><foo/></doc>")
@@ -746,20 +741,19 @@ class MinidomTest(unittest.TestCase):
def testCloneElementShallow(self):
dom, clone = self._setupCloneElement(0)
- self.confirm(len(clone.childNodes) == 0
- and clone.childNodes.length == 0
- and clone.parentNode is None
- and clone.toxml() == '<doc attr="value"/>'
- , "testCloneElementShallow")
+ self.assertEqual(len(clone.childNodes), 0)
+ self.assertEqual(clone.childNodes.length, 0)
+ self.assertIsNone(clone.parentNode)
+ self.assertEqual(clone.toxml(), '<doc attr="value"/>')
+
dom.unlink()
def testCloneElementDeep(self):
dom, clone = self._setupCloneElement(1)
- self.confirm(len(clone.childNodes) == 1
- and clone.childNodes.length == 1
- and clone.parentNode is None
- and clone.toxml() == '<doc attr="value"><foo/></doc>'
- , "testCloneElementDeep")
+ self.assertEqual(len(clone.childNodes), 1)
+ self.assertEqual(clone.childNodes.length, 1)
+ self.assertIsNone(clone.parentNode)
+ self.assertEqual(clone.toxml(), '<doc attr="value"><foo/></doc>')
dom.unlink()
def testCloneDocumentShallow(self):
diff --git a/Lib/test/test_monitoring.py b/Lib/test/test_monitoring.py
index 263e4e6f394..a932ac80117 100644
--- a/Lib/test/test_monitoring.py
+++ b/Lib/test/test_monitoring.py
@@ -2157,6 +2157,21 @@ class TestRegressions(MonitoringTestBase, unittest.TestCase):
sys.monitoring.restart_events()
sys.monitoring.set_events(0, 0)
+ def test_134879(self):
+ # gh-134789
+ # Specialized FOR_ITER not incrementing index
+ def foo():
+ t = 0
+ for i in [1,2,3,4]:
+ t += i
+ self.assertEqual(t, 10)
+
+ sys.monitoring.use_tool_id(0, "test")
+ self.addCleanup(sys.monitoring.free_tool_id, 0)
+ sys.monitoring.set_local_events(0, foo.__code__, E.BRANCH_LEFT | E.BRANCH_RIGHT)
+ foo()
+ sys.monitoring.set_local_events(0, foo.__code__, 0)
+
class TestOptimizer(MonitoringTestBase, unittest.TestCase):
diff --git a/Lib/test/test_multibytecodec.py b/Lib/test/test_multibytecodec.py
index 1b55f1e70b3..d7a233377bd 100644
--- a/Lib/test/test_multibytecodec.py
+++ b/Lib/test/test_multibytecodec.py
@@ -314,7 +314,8 @@ class Test_StreamReader(unittest.TestCase):
f.write(b'\xa1')
finally:
f.close()
- f = codecs.open(TESTFN, encoding='cp949')
+ with self.assertWarns(DeprecationWarning):
+ f = codecs.open(TESTFN, encoding='cp949')
try:
self.assertRaises(UnicodeDecodeError, f.read, 2)
finally:
diff --git a/Lib/test/test_netrc.py b/Lib/test/test_netrc.py
index 81e11a293cc..9d720f62710 100644
--- a/Lib/test/test_netrc.py
+++ b/Lib/test/test_netrc.py
@@ -1,11 +1,7 @@
import netrc, os, unittest, sys, textwrap
+from test import support
from test.support import os_helper
-try:
- import pwd
-except ImportError:
- pwd = None
-
temp_filename = os_helper.TESTFN
class NetrcTestCase(unittest.TestCase):
@@ -269,9 +265,14 @@ class NetrcTestCase(unittest.TestCase):
machine bar.domain.com login foo password pass
""", '#pass')
+ @unittest.skipUnless(support.is_wasi, 'WASI only test')
+ def test_security_on_WASI(self):
+ self.assertFalse(netrc._can_security_check())
+ self.assertEqual(netrc._getpwuid(0), 'uid 0')
+ self.assertEqual(netrc._getpwuid(123456), 'uid 123456')
@unittest.skipUnless(os.name == 'posix', 'POSIX only test')
- @unittest.skipIf(pwd is None, 'security check requires pwd module')
+ @unittest.skipUnless(hasattr(os, 'getuid'), "os.getuid is required")
@os_helper.skip_unless_working_chmod
def test_security(self):
# This test is incomplete since we are normally not run as root and
diff --git a/Lib/test/test_ntpath.py b/Lib/test/test_ntpath.py
index c10387b58e3..22f6403d482 100644
--- a/Lib/test/test_ntpath.py
+++ b/Lib/test/test_ntpath.py
@@ -6,8 +6,9 @@ import subprocess
import sys
import unittest
import warnings
-from test.support import cpython_only, os_helper
-from test.support import TestFailed, is_emscripten
+from ntpath import ALLOW_MISSING
+from test import support
+from test.support import TestFailed, cpython_only, os_helper
from test.support.os_helper import FakePath
from test import test_genericpath
from tempfile import TemporaryFile
@@ -77,6 +78,10 @@ def tester(fn, wantResult):
%(str(fn), str(wantResult), repr(gotResult)))
+def _parameterize(*parameters):
+ return support.subTests('kwargs', parameters, _do_cleanups=True)
+
+
class NtpathTestCase(unittest.TestCase):
def assertPathEqual(self, path1, path2):
if path1 == path2 or _norm(path1) == _norm(path2):
@@ -124,6 +129,22 @@ class TestNtpath(NtpathTestCase):
tester('ntpath.splitdrive("//?/UNC/server/share/dir")',
("//?/UNC/server/share", "/dir"))
+ def test_splitdrive_invalid_paths(self):
+ splitdrive = ntpath.splitdrive
+ self.assertEqual(splitdrive('\\\\ser\x00ver\\sha\x00re\\di\x00r'),
+ ('\\\\ser\x00ver\\sha\x00re', '\\di\x00r'))
+ self.assertEqual(splitdrive(b'\\\\ser\x00ver\\sha\x00re\\di\x00r'),
+ (b'\\\\ser\x00ver\\sha\x00re', b'\\di\x00r'))
+ self.assertEqual(splitdrive("\\\\\udfff\\\udffe\\\udffd"),
+ ('\\\\\udfff\\\udffe', '\\\udffd'))
+ if sys.platform == 'win32':
+ self.assertRaises(UnicodeDecodeError, splitdrive, b'\\\\\xff\\share\\dir')
+ self.assertRaises(UnicodeDecodeError, splitdrive, b'\\\\server\\\xff\\dir')
+ self.assertRaises(UnicodeDecodeError, splitdrive, b'\\\\server\\share\\\xff')
+ else:
+ self.assertEqual(splitdrive(b'\\\\\xff\\\xfe\\\xfd'),
+ (b'\\\\\xff\\\xfe', b'\\\xfd'))
+
def test_splitroot(self):
tester("ntpath.splitroot('')", ('', '', ''))
tester("ntpath.splitroot('foo')", ('', '', 'foo'))
@@ -214,6 +235,22 @@ class TestNtpath(NtpathTestCase):
tester('ntpath.splitroot(" :/foo")', (" :", "/", "foo"))
tester('ntpath.splitroot("/:/foo")', ("", "/", ":/foo"))
+ def test_splitroot_invalid_paths(self):
+ splitroot = ntpath.splitroot
+ self.assertEqual(splitroot('\\\\ser\x00ver\\sha\x00re\\di\x00r'),
+ ('\\\\ser\x00ver\\sha\x00re', '\\', 'di\x00r'))
+ self.assertEqual(splitroot(b'\\\\ser\x00ver\\sha\x00re\\di\x00r'),
+ (b'\\\\ser\x00ver\\sha\x00re', b'\\', b'di\x00r'))
+ self.assertEqual(splitroot("\\\\\udfff\\\udffe\\\udffd"),
+ ('\\\\\udfff\\\udffe', '\\', '\udffd'))
+ if sys.platform == 'win32':
+ self.assertRaises(UnicodeDecodeError, splitroot, b'\\\\\xff\\share\\dir')
+ self.assertRaises(UnicodeDecodeError, splitroot, b'\\\\server\\\xff\\dir')
+ self.assertRaises(UnicodeDecodeError, splitroot, b'\\\\server\\share\\\xff')
+ else:
+ self.assertEqual(splitroot(b'\\\\\xff\\\xfe\\\xfd'),
+ (b'\\\\\xff\\\xfe', b'\\', b'\xfd'))
+
def test_split(self):
tester('ntpath.split("c:\\foo\\bar")', ('c:\\foo', 'bar'))
tester('ntpath.split("\\\\conky\\mountpoint\\foo\\bar")',
@@ -226,6 +263,21 @@ class TestNtpath(NtpathTestCase):
tester('ntpath.split("c:/")', ('c:/', ''))
tester('ntpath.split("//conky/mountpoint/")', ('//conky/mountpoint/', ''))
+ def test_split_invalid_paths(self):
+ split = ntpath.split
+ self.assertEqual(split('c:\\fo\x00o\\ba\x00r'),
+ ('c:\\fo\x00o', 'ba\x00r'))
+ self.assertEqual(split(b'c:\\fo\x00o\\ba\x00r'),
+ (b'c:\\fo\x00o', b'ba\x00r'))
+ self.assertEqual(split('c:\\\udfff\\\udffe'),
+ ('c:\\\udfff', '\udffe'))
+ if sys.platform == 'win32':
+ self.assertRaises(UnicodeDecodeError, split, b'c:\\\xff\\bar')
+ self.assertRaises(UnicodeDecodeError, split, b'c:\\foo\\\xff')
+ else:
+ self.assertEqual(split(b'c:\\\xff\\\xfe'),
+ (b'c:\\\xff', b'\xfe'))
+
def test_isabs(self):
tester('ntpath.isabs("foo\\bar")', 0)
tester('ntpath.isabs("foo/bar")', 0)
@@ -333,6 +385,30 @@ class TestNtpath(NtpathTestCase):
tester("ntpath.join('D:a', './c:b')", 'D:a\\.\\c:b')
tester("ntpath.join('D:/a', './c:b')", 'D:\\a\\.\\c:b')
+ def test_normcase(self):
+ normcase = ntpath.normcase
+ self.assertEqual(normcase(''), '')
+ self.assertEqual(normcase(b''), b'')
+ self.assertEqual(normcase('ABC'), 'abc')
+ self.assertEqual(normcase(b'ABC'), b'abc')
+ self.assertEqual(normcase('\xc4\u0141\u03a8'), '\xe4\u0142\u03c8')
+ expected = '\u03c9\u2126' if sys.platform == 'win32' else '\u03c9\u03c9'
+ self.assertEqual(normcase('\u03a9\u2126'), expected)
+ if sys.platform == 'win32' or sys.getfilesystemencoding() == 'utf-8':
+ self.assertEqual(normcase('\xc4\u0141\u03a8'.encode()),
+ '\xe4\u0142\u03c8'.encode())
+ self.assertEqual(normcase('\u03a9\u2126'.encode()),
+ expected.encode())
+
+ def test_normcase_invalid_paths(self):
+ normcase = ntpath.normcase
+ self.assertEqual(normcase('abc\x00def'), 'abc\x00def')
+ self.assertEqual(normcase(b'abc\x00def'), b'abc\x00def')
+ self.assertEqual(normcase('\udfff'), '\udfff')
+ if sys.platform == 'win32':
+ path = b'ABC' + bytes(range(128, 256))
+ self.assertEqual(normcase(path), path.lower())
+
def test_normpath(self):
tester("ntpath.normpath('A//////././//.//B')", r'A\B')
tester("ntpath.normpath('A/./B')", r'A\B')
@@ -381,6 +457,21 @@ class TestNtpath(NtpathTestCase):
tester("ntpath.normpath('\\\\')", '\\\\')
tester("ntpath.normpath('//?/UNC/server/share/..')", '\\\\?\\UNC\\server\\share\\')
+ def test_normpath_invalid_paths(self):
+ normpath = ntpath.normpath
+ self.assertEqual(normpath('fo\x00o'), 'fo\x00o')
+ self.assertEqual(normpath(b'fo\x00o'), b'fo\x00o')
+ self.assertEqual(normpath('fo\x00o\\..\\bar'), 'bar')
+ self.assertEqual(normpath(b'fo\x00o\\..\\bar'), b'bar')
+ self.assertEqual(normpath('\udfff'), '\udfff')
+ self.assertEqual(normpath('\udfff\\..\\foo'), 'foo')
+ if sys.platform == 'win32':
+ self.assertRaises(UnicodeDecodeError, normpath, b'\xff')
+ self.assertRaises(UnicodeDecodeError, normpath, b'\xff\\..\\foo')
+ else:
+ self.assertEqual(normpath(b'\xff'), b'\xff')
+ self.assertEqual(normpath(b'\xff\\..\\foo'), b'foo')
+
def test_realpath_curdir(self):
expected = ntpath.normpath(os.getcwd())
tester("ntpath.realpath('.')", expected)
@@ -389,6 +480,27 @@ class TestNtpath(NtpathTestCase):
tester("ntpath.realpath('.\\.')", expected)
tester("ntpath.realpath('\\'.join(['.'] * 100))", expected)
+ def test_realpath_curdir_strict(self):
+ expected = ntpath.normpath(os.getcwd())
+ tester("ntpath.realpath('.', strict=True)", expected)
+ tester("ntpath.realpath('./.', strict=True)", expected)
+ tester("ntpath.realpath('/'.join(['.'] * 100), strict=True)", expected)
+ tester("ntpath.realpath('.\\.', strict=True)", expected)
+ tester("ntpath.realpath('\\'.join(['.'] * 100), strict=True)", expected)
+
+ def test_realpath_curdir_missing_ok(self):
+ expected = ntpath.normpath(os.getcwd())
+ tester("ntpath.realpath('.', strict=ALLOW_MISSING)",
+ expected)
+ tester("ntpath.realpath('./.', strict=ALLOW_MISSING)",
+ expected)
+ tester("ntpath.realpath('/'.join(['.'] * 100), strict=ALLOW_MISSING)",
+ expected)
+ tester("ntpath.realpath('.\\.', strict=ALLOW_MISSING)",
+ expected)
+ tester("ntpath.realpath('\\'.join(['.'] * 100), strict=ALLOW_MISSING)",
+ expected)
+
def test_realpath_pardir(self):
expected = ntpath.normpath(os.getcwd())
tester("ntpath.realpath('..')", ntpath.dirname(expected))
@@ -401,28 +513,59 @@ class TestNtpath(NtpathTestCase):
tester("ntpath.realpath('\\'.join(['..'] * 50))",
ntpath.splitdrive(expected)[0] + '\\')
+ def test_realpath_pardir_strict(self):
+ expected = ntpath.normpath(os.getcwd())
+ tester("ntpath.realpath('..', strict=True)", ntpath.dirname(expected))
+ tester("ntpath.realpath('../..', strict=True)",
+ ntpath.dirname(ntpath.dirname(expected)))
+ tester("ntpath.realpath('/'.join(['..'] * 50), strict=True)",
+ ntpath.splitdrive(expected)[0] + '\\')
+ tester("ntpath.realpath('..\\..', strict=True)",
+ ntpath.dirname(ntpath.dirname(expected)))
+ tester("ntpath.realpath('\\'.join(['..'] * 50), strict=True)",
+ ntpath.splitdrive(expected)[0] + '\\')
+
+ def test_realpath_pardir_missing_ok(self):
+ expected = ntpath.normpath(os.getcwd())
+ tester("ntpath.realpath('..', strict=ALLOW_MISSING)",
+ ntpath.dirname(expected))
+ tester("ntpath.realpath('../..', strict=ALLOW_MISSING)",
+ ntpath.dirname(ntpath.dirname(expected)))
+ tester("ntpath.realpath('/'.join(['..'] * 50), strict=ALLOW_MISSING)",
+ ntpath.splitdrive(expected)[0] + '\\')
+ tester("ntpath.realpath('..\\..', strict=ALLOW_MISSING)",
+ ntpath.dirname(ntpath.dirname(expected)))
+ tester("ntpath.realpath('\\'.join(['..'] * 50), strict=ALLOW_MISSING)",
+ ntpath.splitdrive(expected)[0] + '\\')
+
@os_helper.skip_unless_symlink
@unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
- def test_realpath_basic(self):
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_basic(self, kwargs):
ABSTFN = ntpath.abspath(os_helper.TESTFN)
open(ABSTFN, "wb").close()
self.addCleanup(os_helper.unlink, ABSTFN)
self.addCleanup(os_helper.unlink, ABSTFN + "1")
os.symlink(ABSTFN, ABSTFN + "1")
- self.assertPathEqual(ntpath.realpath(ABSTFN + "1"), ABSTFN)
- self.assertPathEqual(ntpath.realpath(os.fsencode(ABSTFN + "1")),
+ self.assertPathEqual(ntpath.realpath(ABSTFN + "1", **kwargs), ABSTFN)
+ self.assertPathEqual(ntpath.realpath(os.fsencode(ABSTFN + "1"), **kwargs),
os.fsencode(ABSTFN))
# gh-88013: call ntpath.realpath with binary drive name may raise a
# TypeError. The drive should not exist to reproduce the bug.
drives = {f"{c}:\\" for c in string.ascii_uppercase} - set(os.listdrives())
d = drives.pop().encode()
- self.assertEqual(ntpath.realpath(d), d)
+ self.assertEqual(ntpath.realpath(d, strict=False), d)
# gh-106242: Embedded nulls and non-strict fallback to abspath
- self.assertEqual(ABSTFN + "\0spam",
- ntpath.realpath(os_helper.TESTFN + "\0spam", strict=False))
+ if kwargs:
+ with self.assertRaises(OSError):
+ ntpath.realpath(os_helper.TESTFN + "\0spam",
+ **kwargs)
+ else:
+ self.assertEqual(ABSTFN + "\0spam",
+ ntpath.realpath(os_helper.TESTFN + "\0spam", **kwargs))
@os_helper.skip_unless_symlink
@unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
@@ -434,19 +577,77 @@ class TestNtpath(NtpathTestCase):
self.addCleanup(os_helper.unlink, ABSTFN)
self.assertRaises(FileNotFoundError, ntpath.realpath, ABSTFN, strict=True)
self.assertRaises(FileNotFoundError, ntpath.realpath, ABSTFN + "2", strict=True)
+
+ @unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
+ def test_realpath_invalid_paths(self):
+ realpath = ntpath.realpath
+ ABSTFN = ntpath.abspath(os_helper.TESTFN)
+ ABSTFNb = os.fsencode(ABSTFN)
+ path = ABSTFN + '\x00'
+ # gh-106242: Embedded nulls and non-strict fallback to abspath
+ self.assertEqual(realpath(path, strict=False), path)
# gh-106242: Embedded nulls should raise OSError (not ValueError)
- self.assertRaises(OSError, ntpath.realpath, ABSTFN + "\0spam", strict=True)
+ self.assertRaises(OSError, realpath, path, strict=True)
+ self.assertRaises(OSError, realpath, path, strict=ALLOW_MISSING)
+ path = ABSTFNb + b'\x00'
+ self.assertEqual(realpath(path, strict=False), path)
+ self.assertRaises(OSError, realpath, path, strict=True)
+ self.assertRaises(OSError, realpath, path, strict=ALLOW_MISSING)
+ path = ABSTFN + '\\nonexistent\\x\x00'
+ self.assertEqual(realpath(path, strict=False), path)
+ self.assertRaises(OSError, realpath, path, strict=True)
+ self.assertRaises(OSError, realpath, path, strict=ALLOW_MISSING)
+ path = ABSTFNb + b'\\nonexistent\\x\x00'
+ self.assertEqual(realpath(path, strict=False), path)
+ self.assertRaises(OSError, realpath, path, strict=True)
+ self.assertRaises(OSError, realpath, path, strict=ALLOW_MISSING)
+ path = ABSTFN + '\x00\\..'
+ self.assertEqual(realpath(path, strict=False), os.getcwd())
+ self.assertEqual(realpath(path, strict=True), os.getcwd())
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), os.getcwd())
+ path = ABSTFNb + b'\x00\\..'
+ self.assertEqual(realpath(path, strict=False), os.getcwdb())
+ self.assertEqual(realpath(path, strict=True), os.getcwdb())
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), os.getcwdb())
+ path = ABSTFN + '\\nonexistent\\x\x00\\..'
+ self.assertEqual(realpath(path, strict=False), ABSTFN + '\\nonexistent')
+ self.assertRaises(OSError, realpath, path, strict=True)
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), ABSTFN + '\\nonexistent')
+ path = ABSTFNb + b'\\nonexistent\\x\x00\\..'
+ self.assertEqual(realpath(path, strict=False), ABSTFNb + b'\\nonexistent')
+ self.assertRaises(OSError, realpath, path, strict=True)
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), ABSTFNb + b'\\nonexistent')
+
+ @unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_invalid_unicode_paths(self, kwargs):
+ realpath = ntpath.realpath
+ ABSTFN = ntpath.abspath(os_helper.TESTFN)
+ ABSTFNb = os.fsencode(ABSTFN)
+ path = ABSTFNb + b'\xff'
+ self.assertRaises(UnicodeDecodeError, realpath, path, **kwargs)
+ self.assertRaises(UnicodeDecodeError, realpath, path, **kwargs)
+ path = ABSTFNb + b'\\nonexistent\\\xff'
+ self.assertRaises(UnicodeDecodeError, realpath, path, **kwargs)
+ self.assertRaises(UnicodeDecodeError, realpath, path, **kwargs)
+ path = ABSTFNb + b'\xff\\..'
+ self.assertRaises(UnicodeDecodeError, realpath, path, **kwargs)
+ self.assertRaises(UnicodeDecodeError, realpath, path, **kwargs)
+ path = ABSTFNb + b'\\nonexistent\\\xff\\..'
+ self.assertRaises(UnicodeDecodeError, realpath, path, **kwargs)
+ self.assertRaises(UnicodeDecodeError, realpath, path, **kwargs)
@os_helper.skip_unless_symlink
@unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
- def test_realpath_relative(self):
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_relative(self, kwargs):
ABSTFN = ntpath.abspath(os_helper.TESTFN)
open(ABSTFN, "wb").close()
self.addCleanup(os_helper.unlink, ABSTFN)
self.addCleanup(os_helper.unlink, ABSTFN + "1")
os.symlink(ABSTFN, ntpath.relpath(ABSTFN + "1"))
- self.assertPathEqual(ntpath.realpath(ABSTFN + "1"), ABSTFN)
+ self.assertPathEqual(ntpath.realpath(ABSTFN + "1", **kwargs), ABSTFN)
@os_helper.skip_unless_symlink
@unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
@@ -598,7 +799,62 @@ class TestNtpath(NtpathTestCase):
@os_helper.skip_unless_symlink
@unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
- def test_realpath_symlink_prefix(self):
+ def test_realpath_symlink_loops_raise(self):
+ # Symlink loops raise OSError in ALLOW_MISSING mode
+ ABSTFN = ntpath.abspath(os_helper.TESTFN)
+ self.addCleanup(os_helper.unlink, ABSTFN)
+ self.addCleanup(os_helper.unlink, ABSTFN + "1")
+ self.addCleanup(os_helper.unlink, ABSTFN + "2")
+ self.addCleanup(os_helper.unlink, ABSTFN + "y")
+ self.addCleanup(os_helper.unlink, ABSTFN + "c")
+ self.addCleanup(os_helper.unlink, ABSTFN + "a")
+ self.addCleanup(os_helper.unlink, ABSTFN + "x")
+
+ os.symlink(ABSTFN, ABSTFN)
+ self.assertRaises(OSError, ntpath.realpath, ABSTFN, strict=ALLOW_MISSING)
+
+ os.symlink(ABSTFN + "1", ABSTFN + "2")
+ os.symlink(ABSTFN + "2", ABSTFN + "1")
+ self.assertRaises(OSError, ntpath.realpath, ABSTFN + "1",
+ strict=ALLOW_MISSING)
+ self.assertRaises(OSError, ntpath.realpath, ABSTFN + "2",
+ strict=ALLOW_MISSING)
+ self.assertRaises(OSError, ntpath.realpath, ABSTFN + "1\\x",
+ strict=ALLOW_MISSING)
+
+ # Windows eliminates '..' components before resolving links;
+ # realpath is not expected to raise if this removes the loop.
+ self.assertPathEqual(ntpath.realpath(ABSTFN + "1\\.."),
+ ntpath.dirname(ABSTFN))
+ self.assertPathEqual(ntpath.realpath(ABSTFN + "1\\..\\x"),
+ ntpath.dirname(ABSTFN) + "\\x")
+
+ os.symlink(ABSTFN + "x", ABSTFN + "y")
+ self.assertPathEqual(ntpath.realpath(ABSTFN + "1\\..\\"
+ + ntpath.basename(ABSTFN) + "y"),
+ ABSTFN + "x")
+ self.assertRaises(
+ OSError, ntpath.realpath,
+ ABSTFN + "1\\..\\" + ntpath.basename(ABSTFN) + "1",
+ strict=ALLOW_MISSING)
+
+ os.symlink(ntpath.basename(ABSTFN) + "a\\b", ABSTFN + "a")
+ self.assertRaises(OSError, ntpath.realpath, ABSTFN + "a",
+ strict=ALLOW_MISSING)
+
+ os.symlink("..\\" + ntpath.basename(ntpath.dirname(ABSTFN))
+ + "\\" + ntpath.basename(ABSTFN) + "c", ABSTFN + "c")
+ self.assertRaises(OSError, ntpath.realpath, ABSTFN + "c",
+ strict=ALLOW_MISSING)
+
+ # Test using relative path as well.
+ self.assertRaises(OSError, ntpath.realpath, ntpath.basename(ABSTFN),
+ strict=ALLOW_MISSING)
+
+ @os_helper.skip_unless_symlink
+ @unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_symlink_prefix(self, kwargs):
ABSTFN = ntpath.abspath(os_helper.TESTFN)
self.addCleanup(os_helper.unlink, ABSTFN + "3")
self.addCleanup(os_helper.unlink, "\\\\?\\" + ABSTFN + "3.")
@@ -613,9 +869,9 @@ class TestNtpath(NtpathTestCase):
f.write(b'1')
os.symlink("\\\\?\\" + ABSTFN + "3.", ABSTFN + "3.link")
- self.assertPathEqual(ntpath.realpath(ABSTFN + "3link"),
+ self.assertPathEqual(ntpath.realpath(ABSTFN + "3link", **kwargs),
ABSTFN + "3")
- self.assertPathEqual(ntpath.realpath(ABSTFN + "3.link"),
+ self.assertPathEqual(ntpath.realpath(ABSTFN + "3.link", **kwargs),
"\\\\?\\" + ABSTFN + "3.")
# Resolved paths should be usable to open target files
@@ -625,14 +881,17 @@ class TestNtpath(NtpathTestCase):
self.assertEqual(f.read(), b'1')
# When the prefix is included, it is not stripped
- self.assertPathEqual(ntpath.realpath("\\\\?\\" + ABSTFN + "3link"),
+ self.assertPathEqual(ntpath.realpath("\\\\?\\" + ABSTFN + "3link", **kwargs),
"\\\\?\\" + ABSTFN + "3")
- self.assertPathEqual(ntpath.realpath("\\\\?\\" + ABSTFN + "3.link"),
+ self.assertPathEqual(ntpath.realpath("\\\\?\\" + ABSTFN + "3.link", **kwargs),
"\\\\?\\" + ABSTFN + "3.")
@unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
def test_realpath_nul(self):
tester("ntpath.realpath('NUL')", r'\\.\NUL')
+ tester("ntpath.realpath('NUL', strict=False)", r'\\.\NUL')
+ tester("ntpath.realpath('NUL', strict=True)", r'\\.\NUL')
+ tester("ntpath.realpath('NUL', strict=ALLOW_MISSING)", r'\\.\NUL')
@unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
@unittest.skipUnless(HAVE_GETSHORTPATHNAME, 'need _getshortpathname')
@@ -656,12 +915,20 @@ class TestNtpath(NtpathTestCase):
self.assertPathEqual(test_file_long, ntpath.realpath(test_file_short))
- with os_helper.change_cwd(test_dir_long):
- self.assertPathEqual(test_file_long, ntpath.realpath("file.txt"))
- with os_helper.change_cwd(test_dir_long.lower()):
- self.assertPathEqual(test_file_long, ntpath.realpath("file.txt"))
- with os_helper.change_cwd(test_dir_short):
- self.assertPathEqual(test_file_long, ntpath.realpath("file.txt"))
+ for kwargs in {}, {'strict': True}, {'strict': ALLOW_MISSING}:
+ with self.subTest(**kwargs):
+ with os_helper.change_cwd(test_dir_long):
+ self.assertPathEqual(
+ test_file_long,
+ ntpath.realpath("file.txt", **kwargs))
+ with os_helper.change_cwd(test_dir_long.lower()):
+ self.assertPathEqual(
+ test_file_long,
+ ntpath.realpath("file.txt", **kwargs))
+ with os_helper.change_cwd(test_dir_short):
+ self.assertPathEqual(
+ test_file_long,
+ ntpath.realpath("file.txt", **kwargs))
@unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
def test_realpath_permission(self):
@@ -682,12 +949,15 @@ class TestNtpath(NtpathTestCase):
# Automatic generation of short names may be disabled on
# NTFS volumes for the sake of performance.
# They're not supported at all on ReFS and exFAT.
- subprocess.run(
+ p = subprocess.run(
# Try to set the short name manually.
['fsutil.exe', 'file', 'setShortName', test_file, 'LONGFI~1.TXT'],
creationflags=subprocess.DETACHED_PROCESS
)
+ if p.returncode:
+ raise unittest.SkipTest('failed to set short name')
+
try:
self.assertPathEqual(test_file, ntpath.realpath(test_file_short))
except AssertionError:
@@ -812,8 +1082,6 @@ class TestNtpath(NtpathTestCase):
tester('ntpath.abspath("C:/nul")', "\\\\.\\nul")
tester('ntpath.abspath("C:\\nul")', "\\\\.\\nul")
self.assertTrue(ntpath.isabs(ntpath.abspath("C:spam")))
- self.assertEqual(ntpath.abspath("C:\x00"), ntpath.join(ntpath.abspath("C:"), "\x00"))
- self.assertEqual(ntpath.abspath("\x00:spam"), "\x00:\\spam")
tester('ntpath.abspath("//..")', "\\\\")
tester('ntpath.abspath("//../")', "\\\\..\\")
tester('ntpath.abspath("//../..")', "\\\\..\\")
@@ -847,6 +1115,26 @@ class TestNtpath(NtpathTestCase):
drive, _ = ntpath.splitdrive(cwd_dir)
tester('ntpath.abspath("/abc/")', drive + "\\abc")
+ def test_abspath_invalid_paths(self):
+ abspath = ntpath.abspath
+ if sys.platform == 'win32':
+ self.assertEqual(abspath("C:\x00"), ntpath.join(abspath("C:"), "\x00"))
+ self.assertEqual(abspath(b"C:\x00"), ntpath.join(abspath(b"C:"), b"\x00"))
+ self.assertEqual(abspath("\x00:spam"), "\x00:\\spam")
+ self.assertEqual(abspath(b"\x00:spam"), b"\x00:\\spam")
+ self.assertEqual(abspath('c:\\fo\x00o'), 'c:\\fo\x00o')
+ self.assertEqual(abspath(b'c:\\fo\x00o'), b'c:\\fo\x00o')
+ self.assertEqual(abspath('c:\\fo\x00o\\..\\bar'), 'c:\\bar')
+ self.assertEqual(abspath(b'c:\\fo\x00o\\..\\bar'), b'c:\\bar')
+ self.assertEqual(abspath('c:\\\udfff'), 'c:\\\udfff')
+ self.assertEqual(abspath('c:\\\udfff\\..\\foo'), 'c:\\foo')
+ if sys.platform == 'win32':
+ self.assertRaises(UnicodeDecodeError, abspath, b'c:\\\xff')
+ self.assertRaises(UnicodeDecodeError, abspath, b'c:\\\xff\\..\\foo')
+ else:
+ self.assertEqual(abspath(b'c:\\\xff'), b'c:\\\xff')
+ self.assertEqual(abspath(b'c:\\\xff\\..\\foo'), b'c:\\foo')
+
def test_relpath(self):
tester('ntpath.relpath("a")', 'a')
tester('ntpath.relpath(ntpath.abspath("a"))', 'a')
@@ -989,6 +1277,18 @@ class TestNtpath(NtpathTestCase):
self.assertTrue(ntpath.ismount(b"\\\\localhost\\c$"))
self.assertTrue(ntpath.ismount(b"\\\\localhost\\c$\\"))
+ def test_ismount_invalid_paths(self):
+ ismount = ntpath.ismount
+ self.assertFalse(ismount("c:\\\udfff"))
+ if sys.platform == 'win32':
+ self.assertRaises(ValueError, ismount, "c:\\\x00")
+ self.assertRaises(ValueError, ismount, b"c:\\\x00")
+ self.assertRaises(UnicodeDecodeError, ismount, b"c:\\\xff")
+ else:
+ self.assertFalse(ismount("c:\\\x00"))
+ self.assertFalse(ismount(b"c:\\\x00"))
+ self.assertFalse(ismount(b"c:\\\xff"))
+
def test_isreserved(self):
self.assertFalse(ntpath.isreserved(''))
self.assertFalse(ntpath.isreserved('.'))
@@ -1095,6 +1395,13 @@ class TestNtpath(NtpathTestCase):
self.assertFalse(ntpath.isjunction('tmpdir'))
self.assertPathEqual(ntpath.realpath('testjunc'), ntpath.realpath('tmpdir'))
+ def test_isfile_invalid_paths(self):
+ isfile = ntpath.isfile
+ self.assertIs(isfile('/tmp\udfffabcds'), False)
+ self.assertIs(isfile(b'/tmp\xffabcds'), False)
+ self.assertIs(isfile('/tmp\x00abcds'), False)
+ self.assertIs(isfile(b'/tmp\x00abcds'), False)
+
@unittest.skipIf(sys.platform != 'win32', "drive letters are a windows concept")
def test_isfile_driveletter(self):
drive = os.environ.get('SystemDrive')
@@ -1195,9 +1502,6 @@ class PathLikeTests(NtpathTestCase):
def test_path_normcase(self):
self._check_function(self.path.normcase)
- if sys.platform == 'win32':
- self.assertEqual(ntpath.normcase('\u03a9\u2126'), 'ωΩ')
- self.assertEqual(ntpath.normcase('abc\x00def'), 'abc\x00def')
def test_path_isabs(self):
self._check_function(self.path.isabs)
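Note: the realpath hunks above exercise three resolution modes. A minimal sketch of the distinction, using the ALLOW_MISSING constant these tests import (behaviour paraphrased from the assertions above, not a normative spec; ntpath.realpath accepts the same keyword):

    from posixpath import realpath, ALLOW_MISSING

    missing = "/no/such/dir/child"
    realpath(missing, strict=False)           # resolution errors ignored; remainder joined lexically
    realpath(missing, strict=ALLOW_MISSING)   # missing components tolerated; loops and other OSErrors still raise
    try:
        realpath(missing, strict=True)        # every component must exist
    except OSError:
        pass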
diff --git a/Lib/test/test_opcache.py b/Lib/test/test_opcache.py
index 21d7e62833c..30baa090486 100644
--- a/Lib/test/test_opcache.py
+++ b/Lib/test/test_opcache.py
@@ -560,6 +560,13 @@ class TestCallCache(TestBase):
with self.assertRaises(TypeError):
instantiate()
+ def test_recursion_check_for_general_calls(self):
+ def test(default=None):
+ return test()
+
+ with self.assertRaises(RecursionError):
+ test()
+
def make_deferred_ref_count_obj():
"""Create an object that uses deferred reference counting.
@@ -1803,20 +1810,6 @@ class TestSpecializer(TestBase):
self.assert_specialized(compare_op_str, "COMPARE_OP_STR")
self.assert_no_opcode(compare_op_str, "COMPARE_OP")
- @cpython_only
- @requires_specialization_ft
- def test_load_const(self):
- def load_const():
- def unused(): pass
- # Currently, the empty tuple is immortal, and the otherwise
- # unused nested function's code object is mortal. This test will
- # have to use different values if either of that changes.
- return ()
-
- load_const()
- self.assert_specialized(load_const, "LOAD_CONST_IMMORTAL")
- self.assert_specialized(load_const, "LOAD_CONST_MORTAL")
- self.assert_no_opcode(load_const, "LOAD_CONST")
@cpython_only
@requires_specialization_ft
diff --git a/Lib/test/test_operator.py b/Lib/test/test_operator.py
index 1757824580e..1f89986c777 100644
--- a/Lib/test/test_operator.py
+++ b/Lib/test/test_operator.py
@@ -636,6 +636,7 @@ class OperatorTestCase:
if dunder:
self.assertIs(dunder, orig)
+ @support.requires_docstrings
def test_attrgetter_signature(self):
operator = self.module
sig = inspect.signature(operator.attrgetter)
@@ -643,6 +644,7 @@ class OperatorTestCase:
sig = inspect.signature(operator.attrgetter('x', 'z', 'y'))
self.assertEqual(str(sig), '(obj, /)')
+ @support.requires_docstrings
def test_itemgetter_signature(self):
operator = self.module
sig = inspect.signature(operator.itemgetter)
@@ -650,6 +652,7 @@ class OperatorTestCase:
sig = inspect.signature(operator.itemgetter(2, 3, 5))
self.assertEqual(str(sig), '(obj, /)')
+ @support.requires_docstrings
def test_methodcaller_signature(self):
operator = self.module
sig = inspect.signature(operator.methodcaller)
diff --git a/Lib/test/test_optparse.py b/Lib/test/test_optparse.py
index 8655a0537a5..e476e472780 100644
--- a/Lib/test/test_optparse.py
+++ b/Lib/test/test_optparse.py
@@ -14,8 +14,9 @@ import unittest
from io import StringIO
from test import support
-from test.support import os_helper
+from test.support import cpython_only, os_helper
from test.support.i18n_helper import TestTranslationsBase, update_translation_snapshots
+from test.support.import_helper import ensure_lazy_imports
import optparse
from optparse import make_option, Option, \
@@ -614,9 +615,9 @@ Options:
self.parser.add_option(
"-p", "--prob",
help="blow up with probability PROB [default: %default]")
- self.parser.set_defaults(prob=0.43)
+ self.parser.set_defaults(prob=0.25)
expected_help = self.help_prefix + \
- " -p PROB, --prob=PROB blow up with probability PROB [default: 0.43]\n"
+ " -p PROB, --prob=PROB blow up with probability PROB [default: 0.25]\n"
self.assertHelp(self.parser, expected_help)
def test_alt_expand(self):
@@ -1655,6 +1656,10 @@ class MiscTestCase(unittest.TestCase):
not_exported = {'check_builtin', 'AmbiguousOptionError', 'NO_DEFAULT'}
support.check__all__(self, optparse, not_exported=not_exported)
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("optparse", {"textwrap"})
+
class TestTranslations(TestTranslationsBase):
def test_translations(self):
diff --git a/Lib/test/test_ordered_dict.py b/Lib/test/test_ordered_dict.py
index 9f131a9110d..4204a6a47d2 100644
--- a/Lib/test/test_ordered_dict.py
+++ b/Lib/test/test_ordered_dict.py
@@ -147,7 +147,7 @@ class OrderedDictTests:
def test_abc(self):
OrderedDict = self.OrderedDict
self.assertIsInstance(OrderedDict(), MutableMapping)
- self.assertTrue(issubclass(OrderedDict, MutableMapping))
+ self.assertIsSubclass(OrderedDict, MutableMapping)
def test_clear(self):
OrderedDict = self.OrderedDict
@@ -314,14 +314,14 @@ class OrderedDictTests:
check(dup)
self.assertIs(dup.x, od.x)
self.assertIs(dup.z, od.z)
- self.assertFalse(hasattr(dup, 'y'))
+ self.assertNotHasAttr(dup, 'y')
dup = copy.deepcopy(od)
check(dup)
self.assertEqual(dup.x, od.x)
self.assertIsNot(dup.x, od.x)
self.assertEqual(dup.z, od.z)
self.assertIsNot(dup.z, od.z)
- self.assertFalse(hasattr(dup, 'y'))
+ self.assertNotHasAttr(dup, 'y')
# pickle directly pulls the module, so we have to fake it
with replaced_module('collections', self.module):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
@@ -330,7 +330,7 @@ class OrderedDictTests:
check(dup)
self.assertEqual(dup.x, od.x)
self.assertEqual(dup.z, od.z)
- self.assertFalse(hasattr(dup, 'y'))
+ self.assertNotHasAttr(dup, 'y')
check(eval(repr(od)))
update_test = OrderedDict()
update_test.update(od)
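Note: many hunks in this patch replace assertTrue()/assertFalse() composites with dedicated unittest helpers. A rough sketch of the equivalences, assuming the helper names available in the unittest these tests target:

    import unittest

    class HelperDemo(unittest.TestCase):
        def test_equivalences(self):
            # each helper mirrors the assertTrue form it replaces, with a clearer failure message
            self.assertIsSubclass(bool, int)        # ~ assertTrue(issubclass(bool, int))
            self.assertHasAttr([], "append")        # ~ assertTrue(hasattr([], "append"))
            self.assertNotHasAttr(object(), "x")    # ~ assertFalse(hasattr(object(), "x"))
            self.assertStartsWith("spam", "sp")     # ~ assertTrue("spam".startswith("sp"))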
diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py
index 333179a71e3..5217037ae9d 100644
--- a/Lib/test/test_os.py
+++ b/Lib/test/test_os.py
@@ -13,7 +13,6 @@ import itertools
import locale
import os
import pickle
-import platform
import select
import selectors
import shutil
@@ -818,7 +817,7 @@ class StatAttributeTests(unittest.TestCase):
self.assertEqual(ctx.exception.errno, errno.EBADF)
def check_file_attributes(self, result):
- self.assertTrue(hasattr(result, 'st_file_attributes'))
+ self.assertHasAttr(result, 'st_file_attributes')
self.assertTrue(isinstance(result.st_file_attributes, int))
self.assertTrue(0 <= result.st_file_attributes <= 0xFFFFFFFF)
@@ -1919,6 +1918,10 @@ class MakedirTests(unittest.TestCase):
support.is_wasi,
"WASI's umask is a stub."
)
+ @unittest.skipIf(
+ support.is_emscripten,
+ "TODO: Fails in buildbot; see #135783"
+ )
def test_mode(self):
with os_helper.temp_umask(0o002):
base = os_helper.TESTFN
@@ -2181,7 +2184,7 @@ class GetRandomTests(unittest.TestCase):
self.assertEqual(empty, b'')
def test_getrandom_random(self):
- self.assertTrue(hasattr(os, 'GRND_RANDOM'))
+ self.assertHasAttr(os, 'GRND_RANDOM')
# Don't test os.getrandom(1, os.GRND_RANDOM) to not consume the rare
# resource /dev/random
@@ -4291,13 +4294,8 @@ class EventfdTests(unittest.TestCase):
@unittest.skipIf(sys.platform == "android", "gh-124873: Test is flaky on Android")
@support.requires_linux_version(2, 6, 30)
class TimerfdTests(unittest.TestCase):
- # 1 ms accuracy is reliably achievable on every platform except Android
- # emulators, where we allow 10 ms (gh-108277).
- if sys.platform == "android" and platform.android_ver().is_emulator:
- CLOCK_RES_PLACES = 2
- else:
- CLOCK_RES_PLACES = 3
-
+ # gh-126112: Use 10 ms to tolerate slow buildbots
+ CLOCK_RES_PLACES = 2 # 10 ms
CLOCK_RES = 10 ** -CLOCK_RES_PLACES
CLOCK_RES_NS = 10 ** (9 - CLOCK_RES_PLACES)
@@ -5431,8 +5429,8 @@ class TestPEP519(unittest.TestCase):
def test_pathlike(self):
self.assertEqual('#feelthegil', self.fspath(FakePath('#feelthegil')))
- self.assertTrue(issubclass(FakePath, os.PathLike))
- self.assertTrue(isinstance(FakePath('x'), os.PathLike))
+ self.assertIsSubclass(FakePath, os.PathLike)
+ self.assertIsInstance(FakePath('x'), os.PathLike)
def test_garbage_in_exception_out(self):
vapor = type('blah', (), {})
@@ -5458,8 +5456,8 @@ class TestPEP519(unittest.TestCase):
# true on abstract implementation.
class A(os.PathLike):
pass
- self.assertFalse(issubclass(FakePath, A))
- self.assertTrue(issubclass(FakePath, os.PathLike))
+ self.assertNotIsSubclass(FakePath, A)
+ self.assertIsSubclass(FakePath, os.PathLike)
def test_pathlike_class_getitem(self):
self.assertIsInstance(os.PathLike[bytes], types.GenericAlias)
@@ -5469,7 +5467,7 @@ class TestPEP519(unittest.TestCase):
__slots__ = ()
def __fspath__(self):
return ''
- self.assertFalse(hasattr(A(), '__dict__'))
+ self.assertNotHasAttr(A(), '__dict__')
def test_fspath_set_to_None(self):
class Foo:
diff --git a/Lib/test/test_pathlib/support/lexical_path.py b/Lib/test/test_pathlib/support/lexical_path.py
index f29a521af9b..fd7fbf283a6 100644
--- a/Lib/test/test_pathlib/support/lexical_path.py
+++ b/Lib/test/test_pathlib/support/lexical_path.py
@@ -9,9 +9,10 @@ import posixpath
from . import is_pypi
if is_pypi:
- from pathlib_abc import _JoinablePath
+ from pathlib_abc import vfspath, _JoinablePath
else:
from pathlib.types import _JoinablePath
+ from pathlib._os import vfspath
class LexicalPath(_JoinablePath):
@@ -22,20 +23,20 @@ class LexicalPath(_JoinablePath):
self._segments = pathsegments
def __hash__(self):
- return hash(str(self))
+ return hash(vfspath(self))
def __eq__(self, other):
if not isinstance(other, LexicalPath):
return NotImplemented
- return str(self) == str(other)
+ return vfspath(self) == vfspath(other)
- def __str__(self):
+ def __vfspath__(self):
if not self._segments:
return ''
return self.parser.join(*self._segments)
def __repr__(self):
- return f'{type(self).__name__}({str(self)!r})'
+ return f'{type(self).__name__}({vfspath(self)!r})'
def with_segments(self, *pathsegments):
return type(self)(*pathsegments)
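Note: the support classes above switch from __str__ to a __vfspath__ protocol read through vfspath(). A minimal sketch of the pattern, assuming pathlib._os.vfspath simply dispatches to __vfspath__ as these tests rely on (MiniPath is illustrative only):

    from pathlib._os import vfspath   # pathlib_abc on PyPI exposes the same helper

    class MiniPath:
        def __init__(self, *segments):
            self._segments = segments

        def __vfspath__(self):
            # the "virtual" path string, decoupled from str()/repr()
            return "/".join(self._segments)

    assert vfspath(MiniPath("a", "b")) == "a/b"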
diff --git a/Lib/test/test_pathlib/support/local_path.py b/Lib/test/test_pathlib/support/local_path.py
index d481fd45ead..c1423c545bf 100644
--- a/Lib/test/test_pathlib/support/local_path.py
+++ b/Lib/test/test_pathlib/support/local_path.py
@@ -97,7 +97,7 @@ class LocalPathInfo(PathInfo):
__slots__ = ('_path', '_exists', '_is_dir', '_is_file', '_is_symlink')
def __init__(self, path):
- self._path = str(path)
+ self._path = os.fspath(path)
self._exists = None
self._is_dir = None
self._is_file = None
@@ -139,14 +139,12 @@ class ReadableLocalPath(_ReadablePath, LexicalPath):
Simple implementation of a ReadablePath class for local filesystem paths.
"""
__slots__ = ('info',)
+ __fspath__ = LexicalPath.__vfspath__
def __init__(self, *pathsegments):
super().__init__(*pathsegments)
self.info = LocalPathInfo(self)
- def __fspath__(self):
- return str(self)
-
def __open_rb__(self, buffering=-1):
return open(self, 'rb')
@@ -163,9 +161,7 @@ class WritableLocalPath(_WritablePath, LexicalPath):
"""
__slots__ = ()
-
- def __fspath__(self):
- return str(self)
+ __fspath__ = LexicalPath.__vfspath__
def __open_wb__(self, buffering=-1):
return open(self, 'wb')
diff --git a/Lib/test/test_pathlib/support/zip_path.py b/Lib/test/test_pathlib/support/zip_path.py
index 2905260c9df..21e1d07423a 100644
--- a/Lib/test/test_pathlib/support/zip_path.py
+++ b/Lib/test/test_pathlib/support/zip_path.py
@@ -16,9 +16,10 @@ from stat import S_IFMT, S_ISDIR, S_ISREG, S_ISLNK
from . import is_pypi
if is_pypi:
- from pathlib_abc import PathInfo, _ReadablePath, _WritablePath
+ from pathlib_abc import vfspath, PathInfo, _ReadablePath, _WritablePath
else:
from pathlib.types import PathInfo, _ReadablePath, _WritablePath
+ from pathlib._os import vfspath
class ZipPathGround:
@@ -34,16 +35,16 @@ class ZipPathGround:
root.zip_file.close()
def create_file(self, path, data=b''):
- path.zip_file.writestr(str(path), data)
+ path.zip_file.writestr(vfspath(path), data)
def create_dir(self, path):
- zip_info = zipfile.ZipInfo(str(path) + '/')
+ zip_info = zipfile.ZipInfo(vfspath(path) + '/')
zip_info.external_attr |= stat.S_IFDIR << 16
zip_info.external_attr |= stat.FILE_ATTRIBUTE_DIRECTORY
path.zip_file.writestr(zip_info, '')
def create_symlink(self, path, target):
- zip_info = zipfile.ZipInfo(str(path))
+ zip_info = zipfile.ZipInfo(vfspath(path))
zip_info.external_attr = stat.S_IFLNK << 16
path.zip_file.writestr(zip_info, target.encode())
@@ -62,28 +63,28 @@ class ZipPathGround:
self.create_symlink(p.joinpath('brokenLinkLoop'), 'brokenLinkLoop')
def readtext(self, p):
- with p.zip_file.open(str(p), 'r') as f:
+ with p.zip_file.open(vfspath(p), 'r') as f:
f = io.TextIOWrapper(f, encoding='utf-8')
return f.read()
def readbytes(self, p):
- with p.zip_file.open(str(p), 'r') as f:
+ with p.zip_file.open(vfspath(p), 'r') as f:
return f.read()
readlink = readtext
def isdir(self, p):
- path_str = str(p) + "/"
+ path_str = vfspath(p) + "/"
return path_str in p.zip_file.NameToInfo
def isfile(self, p):
- info = p.zip_file.NameToInfo.get(str(p))
+ info = p.zip_file.NameToInfo.get(vfspath(p))
if info is None:
return False
return not stat.S_ISLNK(info.external_attr >> 16)
def islink(self, p):
- info = p.zip_file.NameToInfo.get(str(p))
+ info = p.zip_file.NameToInfo.get(vfspath(p))
if info is None:
return False
return stat.S_ISLNK(info.external_attr >> 16)
@@ -240,20 +241,20 @@ class ReadableZipPath(_ReadablePath):
zip_file.filelist = ZipFileList(zip_file)
def __hash__(self):
- return hash((str(self), self.zip_file))
+ return hash((vfspath(self), self.zip_file))
def __eq__(self, other):
if not isinstance(other, ReadableZipPath):
return NotImplemented
- return str(self) == str(other) and self.zip_file is other.zip_file
+ return vfspath(self) == vfspath(other) and self.zip_file is other.zip_file
- def __str__(self):
+ def __vfspath__(self):
if not self._segments:
return ''
return self.parser.join(*self._segments)
def __repr__(self):
- return f'{type(self).__name__}({str(self)!r}, zip_file={self.zip_file!r})'
+ return f'{type(self).__name__}({vfspath(self)!r}, zip_file={self.zip_file!r})'
def with_segments(self, *pathsegments):
return type(self)(*pathsegments, zip_file=self.zip_file)
@@ -261,7 +262,7 @@ class ReadableZipPath(_ReadablePath):
@property
def info(self):
tree = self.zip_file.filelist.tree
- return tree.resolve(str(self), follow_symlinks=False)
+ return tree.resolve(vfspath(self), follow_symlinks=False)
def __open_rb__(self, buffering=-1):
info = self.info.resolve()
@@ -301,36 +302,36 @@ class WritableZipPath(_WritablePath):
self.zip_file = zip_file
def __hash__(self):
- return hash((str(self), self.zip_file))
+ return hash((vfspath(self), self.zip_file))
def __eq__(self, other):
if not isinstance(other, WritableZipPath):
return NotImplemented
- return str(self) == str(other) and self.zip_file is other.zip_file
+ return vfspath(self) == vfspath(other) and self.zip_file is other.zip_file
- def __str__(self):
+ def __vfspath__(self):
if not self._segments:
return ''
return self.parser.join(*self._segments)
def __repr__(self):
- return f'{type(self).__name__}({str(self)!r}, zip_file={self.zip_file!r})'
+ return f'{type(self).__name__}({vfspath(self)!r}, zip_file={self.zip_file!r})'
def with_segments(self, *pathsegments):
return type(self)(*pathsegments, zip_file=self.zip_file)
def __open_wb__(self, buffering=-1):
- return self.zip_file.open(str(self), 'w')
+ return self.zip_file.open(vfspath(self), 'w')
def mkdir(self, mode=0o777):
- zinfo = zipfile.ZipInfo(str(self) + '/')
+ zinfo = zipfile.ZipInfo(vfspath(self) + '/')
zinfo.external_attr |= stat.S_IFDIR << 16
zinfo.external_attr |= stat.FILE_ATTRIBUTE_DIRECTORY
self.zip_file.writestr(zinfo, '')
def symlink_to(self, target, target_is_directory=False):
- zinfo = zipfile.ZipInfo(str(self))
+ zinfo = zipfile.ZipInfo(vfspath(self))
zinfo.external_attr = stat.S_IFLNK << 16
if target_is_directory:
zinfo.external_attr |= 0x10
- self.zip_file.writestr(zinfo, str(target))
+ self.zip_file.writestr(zinfo, vfspath(target))
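Note: the zip-backed path ground above fakes symlinks by storing the link target as the entry body and flagging the entry in the Unix mode bits. A standalone sketch of that trick with plain zipfile (file names are illustrative):

    import stat
    import zipfile

    with zipfile.ZipFile("demo.zip", "w") as zf:
        info = zipfile.ZipInfo("link")
        info.external_attr = stat.S_IFLNK << 16   # upper 16 bits carry the Unix st_mode
        zf.writestr(info, "target.txt")           # entry payload is the symlink target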
diff --git a/Lib/test/test_pathlib/test_join_windows.py b/Lib/test/test_pathlib/test_join_windows.py
index 2cc634f25ef..f30c80605f7 100644
--- a/Lib/test/test_pathlib/test_join_windows.py
+++ b/Lib/test/test_pathlib/test_join_windows.py
@@ -8,6 +8,11 @@ import unittest
from .support import is_pypi
from .support.lexical_path import LexicalWindowsPath
+if is_pypi:
+ from pathlib_abc import vfspath
+else:
+ from pathlib._os import vfspath
+
class JoinTestBase:
def test_join(self):
@@ -70,17 +75,17 @@ class JoinTestBase:
self.assertEqual(p / './dd:s', P(r'C:/a/b\./dd:s'))
self.assertEqual(p / 'E:d:s', P('E:d:s'))
- def test_str(self):
+ def test_vfspath(self):
p = self.cls(r'a\b\c')
- self.assertEqual(str(p), 'a\\b\\c')
+ self.assertEqual(vfspath(p), 'a\\b\\c')
p = self.cls(r'c:\a\b\c')
- self.assertEqual(str(p), 'c:\\a\\b\\c')
+ self.assertEqual(vfspath(p), 'c:\\a\\b\\c')
p = self.cls('\\\\a\\b\\')
- self.assertEqual(str(p), '\\\\a\\b\\')
+ self.assertEqual(vfspath(p), '\\\\a\\b\\')
p = self.cls(r'\\a\b\c')
- self.assertEqual(str(p), '\\\\a\\b\\c')
+ self.assertEqual(vfspath(p), '\\\\a\\b\\c')
p = self.cls(r'\\a\b\c\d')
- self.assertEqual(str(p), '\\\\a\\b\\c\\d')
+ self.assertEqual(vfspath(p), '\\\\a\\b\\c\\d')
def test_parts(self):
P = self.cls
diff --git a/Lib/test/test_pathlib/test_pathlib.py b/Lib/test/test_pathlib/test_pathlib.py
index 41a79d0dceb..b2e2cdb3338 100644
--- a/Lib/test/test_pathlib/test_pathlib.py
+++ b/Lib/test/test_pathlib/test_pathlib.py
@@ -16,10 +16,11 @@ from unittest import mock
from urllib.request import pathname2url
from test.support import import_helper
+from test.support import cpython_only
from test.support import is_emscripten, is_wasi
from test.support import infinite_recursion
from test.support import os_helper
-from test.support.os_helper import TESTFN, FakePath
+from test.support.os_helper import TESTFN, FS_NONASCII, FakePath
try:
import fcntl
except ImportError:
@@ -76,8 +77,14 @@ def needs_symlinks(fn):
class UnsupportedOperationTest(unittest.TestCase):
def test_is_notimplemented(self):
- self.assertTrue(issubclass(pathlib.UnsupportedOperation, NotImplementedError))
- self.assertTrue(isinstance(pathlib.UnsupportedOperation(), NotImplementedError))
+ self.assertIsSubclass(pathlib.UnsupportedOperation, NotImplementedError)
+ self.assertIsInstance(pathlib.UnsupportedOperation(), NotImplementedError)
+
+
+class LazyImportTest(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ import_helper.ensure_lazy_imports("pathlib", {"shutil"})
#
@@ -293,8 +300,8 @@ class PurePathTest(unittest.TestCase):
clsname = p.__class__.__name__
r = repr(p)
# The repr() is in the form ClassName("forward-slashes path").
- self.assertTrue(r.startswith(clsname + '('), r)
- self.assertTrue(r.endswith(')'), r)
+ self.assertStartsWith(r, clsname + '(')
+ self.assertEndsWith(r, ')')
inner = r[len(clsname) + 1 : -1]
self.assertEqual(eval(inner), p.as_posix())
@@ -763,12 +770,16 @@ class PurePathTest(unittest.TestCase):
self.assertEqual(self.make_uri(P('c:/')), 'file:///c:/')
self.assertEqual(self.make_uri(P('c:/a/b.c')), 'file:///c:/a/b.c')
self.assertEqual(self.make_uri(P('c:/a/b%#c')), 'file:///c:/a/b%25%23c')
- self.assertEqual(self.make_uri(P('c:/a/b\xe9')), 'file:///c:/a/b%C3%A9')
self.assertEqual(self.make_uri(P('//some/share/')), 'file://some/share/')
self.assertEqual(self.make_uri(P('//some/share/a/b.c')),
'file://some/share/a/b.c')
- self.assertEqual(self.make_uri(P('//some/share/a/b%#c\xe9')),
- 'file://some/share/a/b%25%23c%C3%A9')
+
+ from urllib.parse import quote_from_bytes
+ QUOTED_FS_NONASCII = quote_from_bytes(os.fsencode(FS_NONASCII))
+ self.assertEqual(self.make_uri(P('c:/a/b' + FS_NONASCII)),
+ 'file:///c:/a/b' + QUOTED_FS_NONASCII)
+ self.assertEqual(self.make_uri(P('//some/share/a/b%#c' + FS_NONASCII)),
+ 'file://some/share/a/b%25%23c' + QUOTED_FS_NONASCII)
@needs_windows
def test_ordering_windows(self):
@@ -2943,7 +2954,13 @@ class PathTest(PurePathTest):
else:
# ".." segments are normalized first on Windows, so this path is stat()able.
self.assertEqual(set(p.glob("xyzzy/..")), { P(self.base, "xyzzy", "..") })
- self.assertEqual(set(p.glob("/".join([".."] * 50))), { P(self.base, *[".."] * 50)})
+ if sys.platform == "emscripten":
+ # Emscripten will return ELOOP if there are 49 or more ..'s.
+ # Can remove when https://github.com/emscripten-core/emscripten/pull/24591 is merged.
+ NDOTDOTS = 48
+ else:
+ NDOTDOTS = 50
+ self.assertEqual(set(p.glob("/".join([".."] * NDOTDOTS))), { P(self.base, *[".."] * NDOTDOTS)})
def test_glob_inaccessible(self):
P = self.cls
@@ -3290,7 +3307,6 @@ class PathTest(PurePathTest):
self.assertEqual(P.from_uri('file:////foo/bar'), P('//foo/bar'))
self.assertEqual(P.from_uri('file://localhost/foo/bar'), P('/foo/bar'))
if not is_wasi:
- self.assertEqual(P.from_uri('file://127.0.0.1/foo/bar'), P('/foo/bar'))
self.assertEqual(P.from_uri(f'file://{socket.gethostname()}/foo/bar'),
P('/foo/bar'))
self.assertRaises(ValueError, P.from_uri, 'foo/bar')
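Note: the make_uri test above now derives the expected percent-encoding from the filesystem encoding instead of hard-coding UTF-8 bytes. A small sketch of that derivation (the character and the resulting escape are illustrative; the actual bytes depend on the filesystem encoding):

    import os
    from urllib.parse import quote_from_bytes

    ch = "\u00e9"                                # stands in for os_helper.FS_NONASCII
    quoted = quote_from_bytes(os.fsencode(ch))   # '%C3%A9' on a UTF-8 filesystem
    uri = "file:///c:/a/b" + quoted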
diff --git a/Lib/test/test_pdb.py b/Lib/test/test_pdb.py
index ae84fe3ce7d..6b74e21ad73 100644
--- a/Lib/test/test_pdb.py
+++ b/Lib/test/test_pdb.py
@@ -1,7 +1,9 @@
# A test suite for pdb; not very comprehensive at the moment.
+import _colorize
import doctest
import gc
+import io
import os
import pdb
import sys
@@ -18,7 +20,7 @@ from asyncio.events import _set_event_loop_policy
from contextlib import ExitStack, redirect_stdout
from io import StringIO
from test import support
-from test.support import force_not_colorized, has_socket_support, os_helper
+from test.support import has_socket_support, os_helper
from test.support.import_helper import import_module
from test.support.pty_helper import run_pty, FakeInput
from test.support.script_helper import kill_python
@@ -3446,6 +3448,7 @@ def test_pdb_issue_gh_65052():
"""
+@support.force_not_colorized_test_class
@support.requires_subprocess()
class PdbTestCase(unittest.TestCase):
def tearDown(self):
@@ -3740,7 +3743,6 @@ def bœr():
self.assertNotIn(b'Error', stdout,
"Got an error running test script under PDB")
- @force_not_colorized
def test_issue16180(self):
# A syntax error in the debuggee.
script = "def f: pass\n"
@@ -3754,7 +3756,6 @@ def bœr():
'Fail to handle a syntax error in the debuggee.'
.format(expected, stderr))
- @force_not_colorized
def test_issue84583(self):
# A syntax error from ast.literal_eval should not make pdb exit.
script = "import ast; ast.literal_eval('')\n"
@@ -4688,6 +4689,40 @@ class PdbTestInline(unittest.TestCase):
self.assertIn("42", stdout)
+@support.force_colorized_test_class
+class PdbTestColorize(unittest.TestCase):
+ def setUp(self):
+ self._original_can_colorize = _colorize.can_colorize
+ # Force colorize to be enabled because we are sending data
+ # to a StringIO
+ _colorize.can_colorize = lambda *args, **kwargs: True
+
+ def tearDown(self):
+ _colorize.can_colorize = self._original_can_colorize
+
+ def test_code_display(self):
+ output = io.StringIO()
+ p = pdb.Pdb(stdout=output, colorize=True)
+ p.set_trace(commands=['ll', 'c'])
+ self.assertIn("\x1b", output.getvalue())
+
+ output = io.StringIO()
+ p = pdb.Pdb(stdout=output, colorize=False)
+ p.set_trace(commands=['ll', 'c'])
+ self.assertNotIn("\x1b", output.getvalue())
+
+ output = io.StringIO()
+ p = pdb.Pdb(stdout=output)
+ p.set_trace(commands=['ll', 'c'])
+ self.assertNotIn("\x1b", output.getvalue())
+
+ def test_stack_entry(self):
+ output = io.StringIO()
+ p = pdb.Pdb(stdout=output, colorize=True)
+ p.set_trace(commands=['w', 'c'])
+ self.assertIn("\x1b", output.getvalue())
+
+
@support.force_not_colorized_test_class
@support.requires_subprocess()
class TestREPLSession(unittest.TestCase):
@@ -4711,9 +4746,12 @@ class TestREPLSession(unittest.TestCase):
self.assertEqual(p.returncode, 0)
+@support.force_not_colorized_test_class
@support.requires_subprocess()
class PdbTestReadline(unittest.TestCase):
- def setUpClass():
+
+ @classmethod
+ def setUpClass(cls):
# Ensure that the readline module is loaded
# If this fails, the test is skipped because SkipTest will be raised
readline = import_module('readline')
@@ -4812,14 +4850,37 @@ class PdbTestReadline(unittest.TestCase):
self.assertIn(b'I love Python', output)
+ @unittest.skipIf(sys.platform.startswith('freebsd'),
+ '\\x08 is not interpreted as backspace on FreeBSD')
+ def test_multiline_auto_indent(self):
+ script = textwrap.dedent("""
+ import pdb; pdb.Pdb().set_trace()
+ """)
+
+ input = b"def f(x):\n"
+ input += b"if x > 0:\n"
+ input += b"x += 1\n"
+ input += b"return x\n"
+ # We need to do backspaces to remove the auto-indentation
+ input += b"\x08\x08\x08\x08else:\n"
+ input += b"return -x\n"
+ input += b"\n"
+ input += b"f(-21-21)\n"
+ input += b"c\n"
+
+ output = run_pty(script, input)
+
+ self.assertIn(b'42', output)
+
def test_multiline_completion(self):
script = textwrap.dedent("""
import pdb; pdb.Pdb().set_trace()
""")
input = b"def func():\n"
- # Complete: \treturn 40 + 2
- input += b"\tret\t 40 + 2\n"
+ # Auto-indent
+ # Complete: return 40 + 2
+ input += b"ret\t 40 + 2\n"
input += b"\n"
# Complete: func()
input += b"fun\t()\n"
@@ -4829,6 +4890,8 @@ class PdbTestReadline(unittest.TestCase):
self.assertIn(b'42', output)
+ @unittest.skipIf(sys.platform.startswith('freebsd'),
+ '\\x08 is not interpreted as backspace on FreeBSD')
def test_multiline_indent_completion(self):
script = textwrap.dedent("""
import pdb; pdb.Pdb().set_trace()
@@ -4839,12 +4902,13 @@ class PdbTestReadline(unittest.TestCase):
# if the completion is not working as expected
input = textwrap.dedent("""\
def func():
- \ta = 1
- \ta += 1
- \ta += 1
- \tif a > 0:
- a += 1
- \t\treturn a
+ a = 1
+ \x08\ta += 1
+ \x08\x08\ta += 1
+ \x08\x08\x08\ta += 1
+ \x08\x08\x08\x08\tif a > 0:
+ a += 1
+ \x08\x08\x08\x08return a
func()
c
@@ -4852,9 +4916,37 @@ class PdbTestReadline(unittest.TestCase):
output = run_pty(script, input)
- self.assertIn(b'4', output)
+ self.assertIn(b'5', output)
self.assertNotIn(b'Error', output)
+ def test_interact_completion(self):
+ script = textwrap.dedent("""
+ value = "speci"
+ import pdb; pdb.Pdb().set_trace()
+ """)
+
+ # Enter interact mode
+ input = b"interact\n"
+ # Should fail to complete 'display' because that's a pdb command
+ input += b"disp\t\n"
+ # 'value' should still work
+ input += b"val\t + 'al'\n"
+ # Let's define a function to test <tab>
+ input += b"def f():\n"
+ input += b"\treturn 42\n"
+ input += b"\n"
+ input += b"f() * 2\n"
+ # Exit interact mode
+ input += b"exit()\n"
+ # continue
+ input += b"c\n"
+
+ output = run_pty(script, input)
+
+ self.assertIn(b"'disp' is not defined", output)
+ self.assertIn(b'special', output)
+ self.assertIn(b'84', output)
+
def load_tests(loader, tests, pattern):
from test import test_pdb
diff --git a/Lib/test/test_peepholer.py b/Lib/test/test_peepholer.py
index 565e42b04a6..3d7300e1480 100644
--- a/Lib/test/test_peepholer.py
+++ b/Lib/test/test_peepholer.py
@@ -1,4 +1,5 @@
import dis
+import gc
from itertools import combinations, product
import opcode
import sys
@@ -11,7 +12,7 @@ except ImportError:
from test import support
from test.support.bytecode_helper import (
- BytecodeTestCase, CfgOptimizationTestCase, CompilationStepTestCase)
+ BytecodeTestCase, CfgOptimizationTestCase)
def compile_pattern_with_fast_locals(pattern):
@@ -315,7 +316,7 @@ class TestTranforms(BytecodeTestCase):
return -(1.0-1.0)
for instr in dis.get_instructions(negzero):
- self.assertFalse(instr.opname.startswith('UNARY_'))
+ self.assertNotStartsWith(instr.opname, 'UNARY_')
self.check_lnotab(negzero)
def test_constant_folding_binop(self):
@@ -717,9 +718,9 @@ class TestTranforms(BytecodeTestCase):
self.assertEqual(format('x = %d!', 1234), 'x = 1234!')
self.assertEqual(format('x = %x!', 1234), 'x = 4d2!')
self.assertEqual(format('x = %f!', 1234), 'x = 1234.000000!')
- self.assertEqual(format('x = %s!', 1234.5678901), 'x = 1234.5678901!')
- self.assertEqual(format('x = %f!', 1234.5678901), 'x = 1234.567890!')
- self.assertEqual(format('x = %d!', 1234.5678901), 'x = 1234!')
+ self.assertEqual(format('x = %s!', 1234.0000625), 'x = 1234.0000625!')
+ self.assertEqual(format('x = %f!', 1234.0000625), 'x = 1234.000063!')
+ self.assertEqual(format('x = %d!', 1234.0000625), 'x = 1234!')
self.assertEqual(format('x = %s%% %%%%', 1234), 'x = 1234% %%')
self.assertEqual(format('x = %s!', '%% %s'), 'x = %% %s!')
self.assertEqual(format('x = %s, y = %d', 12, 34), 'x = 12, y = 34')
@@ -2472,6 +2473,13 @@ class OptimizeLoadFastTestCase(DirectCfgOptimizerTests):
]
self.check(insts, insts)
+ insts = [
+ ("LOAD_FAST", 0, 1),
+ ("DELETE_FAST", 0, 2),
+ ("POP_TOP", None, 3),
+ ]
+ self.check(insts, insts)
+
def test_unoptimized_if_aliased(self):
insts = [
("LOAD_FAST", 0, 1),
@@ -2606,6 +2614,114 @@ class OptimizeLoadFastTestCase(DirectCfgOptimizerTests):
]
self.cfg_optimization_test(insts, expected, consts=[None])
+ def test_format_simple(self):
+ # FORMAT_SIMPLE will leave its operand on the stack if it's a unicode
+ # object. We treat it conservatively and assume that it always leaves
+ # its operand on the stack.
+ insts = [
+ ("LOAD_FAST", 0, 1),
+ ("FORMAT_SIMPLE", None, 2),
+ ("STORE_FAST", 1, 3),
+ ]
+ self.check(insts, insts)
+
+ insts = [
+ ("LOAD_FAST", 0, 1),
+ ("FORMAT_SIMPLE", None, 2),
+ ("POP_TOP", None, 3),
+ ]
+ expected = [
+ ("LOAD_FAST_BORROW", 0, 1),
+ ("FORMAT_SIMPLE", None, 2),
+ ("POP_TOP", None, 3),
+ ]
+ self.check(insts, expected)
+
+ def test_set_function_attribute(self):
+ # SET_FUNCTION_ATTRIBUTE leaves the function on the stack
+ insts = [
+ ("LOAD_CONST", 0, 1),
+ ("LOAD_FAST", 0, 2),
+ ("SET_FUNCTION_ATTRIBUTE", 2, 3),
+ ("STORE_FAST", 1, 4),
+ ("LOAD_CONST", 0, 5),
+ ("RETURN_VALUE", None, 6)
+ ]
+ self.cfg_optimization_test(insts, insts, consts=[None])
+
+ insts = [
+ ("LOAD_CONST", 0, 1),
+ ("LOAD_FAST", 0, 2),
+ ("SET_FUNCTION_ATTRIBUTE", 2, 3),
+ ("RETURN_VALUE", None, 4)
+ ]
+ expected = [
+ ("LOAD_CONST", 0, 1),
+ ("LOAD_FAST_BORROW", 0, 2),
+ ("SET_FUNCTION_ATTRIBUTE", 2, 3),
+ ("RETURN_VALUE", None, 4)
+ ]
+ self.cfg_optimization_test(insts, expected, consts=[None])
+
+ def test_get_yield_from_iter(self):
+ # GET_YIELD_FROM_ITER may leave its operand on the stack
+ insts = [
+ ("LOAD_FAST", 0, 1),
+ ("GET_YIELD_FROM_ITER", None, 2),
+ ("LOAD_CONST", 0, 3),
+ send := self.Label(),
+ ("SEND", end := self.Label(), 5),
+ ("YIELD_VALUE", 1, 6),
+ ("RESUME", 2, 7),
+ ("JUMP", send, 8),
+ end,
+ ("END_SEND", None, 9),
+ ("LOAD_CONST", 0, 10),
+ ("RETURN_VALUE", None, 11),
+ ]
+ self.cfg_optimization_test(insts, insts, consts=[None])
+
+ def test_push_exc_info(self):
+ insts = [
+ ("LOAD_FAST", 0, 1),
+ ("PUSH_EXC_INFO", None, 2),
+ ]
+ self.check(insts, insts)
+
+ def test_load_special(self):
+ # LOAD_SPECIAL may leave self on the stack
+ insts = [
+ ("LOAD_FAST", 0, 1),
+ ("LOAD_SPECIAL", 0, 2),
+ ("STORE_FAST", 1, 3),
+ ]
+ self.check(insts, insts)
+
+
+ def test_del_in_finally(self):
+ # This loads `obj` onto the stack, executes `del obj`, then returns the
+ # `obj` from the stack. See gh-133371 for more details.
+ def create_obj():
+ obj = [42]
+ try:
+ return obj
+ finally:
+ del obj
+
+ obj = create_obj()
+ # The crash in the linked issue happens while running GC during
+ # interpreter finalization, so run it here manually.
+ gc.collect()
+ self.assertEqual(obj, [42])
+
+ def test_format_simple_unicode(self):
+ # Repro from gh-134889
+ def f():
+ var = f"{1}"
+ var = f"{var}"
+ return var
+ self.assertEqual(f(), "1")
+
if __name__ == "__main__":
diff --git a/Lib/test/test_peg_generator/test_c_parser.py b/Lib/test/test_peg_generator/test_c_parser.py
index 1095e7303c1..aa01a9b8f7e 100644
--- a/Lib/test/test_peg_generator/test_c_parser.py
+++ b/Lib/test/test_peg_generator/test_c_parser.py
@@ -387,10 +387,10 @@ class TestCParser(unittest.TestCase):
test_source = """
stmt = "with (\\n a as b,\\n c as d\\n): pass"
the_ast = parse.parse_string(stmt, mode=1)
- self.assertTrue(ast_dump(the_ast).startswith(
+ self.assertStartsWith(ast_dump(the_ast),
"Module(body=[With(items=[withitem(context_expr=Name(id='a', ctx=Load()), optional_vars=Name(id='b', ctx=Store())), "
"withitem(context_expr=Name(id='c', ctx=Load()), optional_vars=Name(id='d', ctx=Store()))]"
- ))
+ )
"""
self.run_test(grammar_source, test_source)
diff --git a/Lib/test/test_peg_generator/test_pegen.py b/Lib/test/test_peg_generator/test_pegen.py
index d8606521345..d912c558123 100644
--- a/Lib/test/test_peg_generator/test_pegen.py
+++ b/Lib/test/test_peg_generator/test_pegen.py
@@ -91,10 +91,8 @@ class TestPegen(unittest.TestCase):
"""
rules = parse_string(grammar, GrammarParser).rules
self.assertEqual(str(rules["start"]), "start: ','.thing+ NEWLINE")
- self.assertTrue(
- repr(rules["start"]).startswith(
- "Rule('start', None, Rhs([Alt([NamedItem(None, Gather(StringLeaf(\"','\"), NameLeaf('thing'"
- )
+ self.assertStartsWith(repr(rules["start"]),
+ "Rule('start', None, Rhs([Alt([NamedItem(None, Gather(StringLeaf(\"','\"), NameLeaf('thing'"
)
self.assertEqual(str(rules["thing"]), "thing: NUMBER")
parser_class = make_parser(grammar)
diff --git a/Lib/test/test_perf_profiler.py b/Lib/test/test_perf_profiler.py
index c176e505155..7529c853f9c 100644
--- a/Lib/test/test_perf_profiler.py
+++ b/Lib/test/test_perf_profiler.py
@@ -93,9 +93,7 @@ class TestPerfTrampoline(unittest.TestCase):
perf_line, f"Could not find {expected_symbol} in perf file"
)
perf_addr = perf_line.split(" ")[0]
- self.assertFalse(
- perf_addr.startswith("0x"), "Address should not be prefixed with 0x"
- )
+ self.assertNotStartsWith(perf_addr, "0x")
self.assertTrue(
set(perf_addr).issubset(string.hexdigits),
"Address should contain only hex characters",
@@ -508,9 +506,12 @@ def _is_perf_version_at_least(major, minor):
# The output of perf --version looks like "perf version 6.7-3" but
# it can also be perf version "perf version 5.15.143", or even include
# a commit hash in the version string, like "6.12.9.g242e6068fd5c"
+ #
+ # PermissionError is raised if perf does not exist on the Windows Subsystem
+ # for Linux, see #134987
try:
output = subprocess.check_output(["perf", "--version"], text=True)
- except (subprocess.CalledProcessError, FileNotFoundError):
+ except (subprocess.CalledProcessError, FileNotFoundError, PermissionError):
return False
version = output.split()[2]
version = version.split("-")[0]
diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py
index 296d4b882e1..e2384b33345 100644
--- a/Lib/test/test_pickle.py
+++ b/Lib/test/test_pickle.py
@@ -15,7 +15,8 @@ from textwrap import dedent
import doctest
import unittest
from test import support
-from test.support import import_helper, os_helper
+from test.support import cpython_only, import_helper, os_helper
+from test.support.import_helper import ensure_lazy_imports
from test.pickletester import AbstractHookTests
from test.pickletester import AbstractUnpickleTests
@@ -36,6 +37,12 @@ except ImportError:
has_c_implementation = False
+class LazyImportTest(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("pickle", {"re"})
+
+
class PyPickleTests(AbstractPickleModuleTests, unittest.TestCase):
dump = staticmethod(pickle._dump)
dumps = staticmethod(pickle._dumps)
@@ -604,10 +611,10 @@ class CompatPickleTests(unittest.TestCase):
with self.subTest(((module3, name3), (module2, name2))):
if (module2, name2) == ('exceptions', 'OSError'):
attr = getattribute(module3, name3)
- self.assertTrue(issubclass(attr, OSError))
+ self.assertIsSubclass(attr, OSError)
elif (module2, name2) == ('exceptions', 'ImportError'):
attr = getattribute(module3, name3)
- self.assertTrue(issubclass(attr, ImportError))
+ self.assertIsSubclass(attr, ImportError)
else:
module, name = mapping(module2, name2)
if module3[:1] != '_':
@@ -745,6 +752,7 @@ class CommandLineTest(unittest.TestCase):
expect = self.text_normalize(expect)
self.assertListEqual(res.splitlines(), expect.splitlines())
+ @support.force_not_colorized
def test_unknown_flag(self):
stderr = io.StringIO()
with self.assertRaises(SystemExit):
diff --git a/Lib/test/test_platform.py b/Lib/test/test_platform.py
index 6ba630ad527..479649053ab 100644
--- a/Lib/test/test_platform.py
+++ b/Lib/test/test_platform.py
@@ -1,5 +1,8 @@
-import os
+import contextlib
import copy
+import io
+import itertools
+import os
import pickle
import platform
import subprocess
@@ -130,6 +133,22 @@ class PlatformTest(unittest.TestCase):
for terse in (False, True):
res = platform.platform(aliased, terse)
+ def test__platform(self):
+ for src, res in [
+ ('foo bar', 'foo_bar'),
+ (
+ '1/2\\3:4;5"6(7)8(7)6"5;4:3\\2/1',
+ '1-2-3-4-5-6-7-8-7-6-5-4-3-2-1'
+ ),
+ ('--', ''),
+ ('-f', '-f'),
+ ('-foo----', '-foo'),
+ ('--foo---', '-foo'),
+ ('---foo--', '-foo'),
+ ]:
+ with self.subTest(src=src):
+ self.assertEqual(platform._platform(src), res)
+
def test_system(self):
res = platform.system()
@@ -380,15 +399,6 @@ class PlatformTest(unittest.TestCase):
finally:
platform._uname_cache = None
- def test_java_ver(self):
- import re
- msg = re.escape(
- "'java_ver' is deprecated and slated for removal in Python 3.15"
- )
- with self.assertWarnsRegex(DeprecationWarning, msg):
- res = platform.java_ver()
- self.assertEqual(len(res), 4)
-
@unittest.skipUnless(support.MS_WINDOWS, 'This test only makes sense on Windows')
def test_win32_ver(self):
release1, version1, csd1, ptype1 = 'a', 'b', 'c', 'd'
@@ -407,7 +417,7 @@ class PlatformTest(unittest.TestCase):
for v in version.split('.'):
int(v) # should not fail
if csd:
- self.assertTrue(csd.startswith('SP'), msg=csd)
+ self.assertStartsWith(csd, 'SP')
if ptype:
if os.cpu_count() > 1:
self.assertIn('Multiprocessor', ptype)
@@ -741,5 +751,66 @@ class PlatformTest(unittest.TestCase):
self.assertEqual(len(info["SPECIALS"]), 5)
+class CommandLineTest(unittest.TestCase):
+ def setUp(self):
+ platform.invalidate_caches()
+ self.addCleanup(platform.invalidate_caches)
+
+ def invoke_platform(self, *flags):
+ output = io.StringIO()
+ with contextlib.redirect_stdout(output):
+ platform._main(args=flags)
+ return output.getvalue()
+
+ def test_unknown_flag(self):
+ with self.assertRaises(SystemExit):
+ output = io.StringIO()
+ # suppress argparse error message
+ with contextlib.redirect_stderr(output):
+ _ = self.invoke_platform('--unknown')
+ self.assertStartsWith(output.getvalue(), "usage: ")
+
+ def test_invocation(self):
+ flags = (
+ "--terse", "--nonaliased", "terse", "nonaliased"
+ )
+
+ for r in range(len(flags) + 1):
+ for combination in itertools.combinations(flags, r):
+ self.invoke_platform(*combination)
+
+ def test_arg_parsing(self):
+ # For backwards compatibility, the `aliased` and `terse` parameters are
+ # computed based on a combination of positional arguments and flags.
+ #
+ # Test that the arguments are correctly passed to the underlying
+ # `platform.platform()` call.
+ options = (
+ (["--nonaliased"], False, False),
+ (["nonaliased"], False, False),
+ (["--terse"], True, True),
+ (["terse"], True, True),
+ (["nonaliased", "terse"], False, True),
+ (["--nonaliased", "terse"], False, True),
+ (["--terse", "nonaliased"], False, True),
+ )
+
+ for flags, aliased, terse in options:
+ with self.subTest(flags=flags, aliased=aliased, terse=terse):
+ with mock.patch.object(platform, 'platform') as obj:
+ self.invoke_platform(*flags)
+ obj.assert_called_once_with(aliased, terse)
+
+ @support.force_not_colorized
+ def test_help(self):
+ output = io.StringIO()
+
+ with self.assertRaises(SystemExit):
+ with contextlib.redirect_stdout(output):
+ platform._main(args=["--help"])
+
+ self.assertStartsWith(output.getvalue(), "usage:")
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_positional_only_arg.py b/Lib/test/test_positional_only_arg.py
index eea0625012d..e412cb1d58d 100644
--- a/Lib/test/test_positional_only_arg.py
+++ b/Lib/test/test_positional_only_arg.py
@@ -37,8 +37,8 @@ class PositionalOnlyTestCase(unittest.TestCase):
check_syntax_error(self, "def f(/): pass")
check_syntax_error(self, "def f(*, a, /): pass")
check_syntax_error(self, "def f(*, /, a): pass")
- check_syntax_error(self, "def f(a, /, a): pass", "duplicate argument 'a' in function definition")
- check_syntax_error(self, "def f(a, /, *, a): pass", "duplicate argument 'a' in function definition")
+ check_syntax_error(self, "def f(a, /, a): pass", "duplicate parameter 'a' in function definition")
+ check_syntax_error(self, "def f(a, /, *, a): pass", "duplicate parameter 'a' in function definition")
check_syntax_error(self, "def f(a, b/2, c): pass")
check_syntax_error(self, "def f(a, /, c, /): pass")
check_syntax_error(self, "def f(a, /, c, /, d): pass")
@@ -59,8 +59,8 @@ class PositionalOnlyTestCase(unittest.TestCase):
check_syntax_error(self, "async def f(/): pass")
check_syntax_error(self, "async def f(*, a, /): pass")
check_syntax_error(self, "async def f(*, /, a): pass")
- check_syntax_error(self, "async def f(a, /, a): pass", "duplicate argument 'a' in function definition")
- check_syntax_error(self, "async def f(a, /, *, a): pass", "duplicate argument 'a' in function definition")
+ check_syntax_error(self, "async def f(a, /, a): pass", "duplicate parameter 'a' in function definition")
+ check_syntax_error(self, "async def f(a, /, *, a): pass", "duplicate parameter 'a' in function definition")
check_syntax_error(self, "async def f(a, b/2, c): pass")
check_syntax_error(self, "async def f(a, /, c, /): pass")
check_syntax_error(self, "async def f(a, /, c, /, d): pass")
@@ -247,8 +247,8 @@ class PositionalOnlyTestCase(unittest.TestCase):
check_syntax_error(self, "lambda /: None")
check_syntax_error(self, "lambda *, a, /: None")
check_syntax_error(self, "lambda *, /, a: None")
- check_syntax_error(self, "lambda a, /, a: None", "duplicate argument 'a' in function definition")
- check_syntax_error(self, "lambda a, /, *, a: None", "duplicate argument 'a' in function definition")
+ check_syntax_error(self, "lambda a, /, a: None", "duplicate parameter 'a' in function definition")
+ check_syntax_error(self, "lambda a, /, *, a: None", "duplicate parameter 'a' in function definition")
check_syntax_error(self, "lambda a, /, b, /: None")
check_syntax_error(self, "lambda a, /, b, /, c: None")
check_syntax_error(self, "lambda a, /, b, /, c, *, d: None")
diff --git a/Lib/test/test_posix.py b/Lib/test/test_posix.py
index c9cbe1541e7..628920e34b5 100644
--- a/Lib/test/test_posix.py
+++ b/Lib/test/test_posix.py
@@ -1107,7 +1107,7 @@ class PosixTester(unittest.TestCase):
def _test_chflags_regular_file(self, chflags_func, target_file, **kwargs):
st = os.stat(target_file)
- self.assertTrue(hasattr(st, 'st_flags'))
+ self.assertHasAttr(st, 'st_flags')
# ZFS returns EOPNOTSUPP when attempting to set flag UF_IMMUTABLE.
flags = st.st_flags | stat.UF_IMMUTABLE
@@ -1143,7 +1143,7 @@ class PosixTester(unittest.TestCase):
def test_lchflags_symlink(self):
testfn_st = os.stat(os_helper.TESTFN)
- self.assertTrue(hasattr(testfn_st, 'st_flags'))
+ self.assertHasAttr(testfn_st, 'st_flags')
self.addCleanup(os_helper.unlink, _DUMMY_SYMLINK)
os.symlink(os_helper.TESTFN, _DUMMY_SYMLINK)
@@ -1521,6 +1521,51 @@ class PosixTester(unittest.TestCase):
self.assertEqual(cm.exception.errno, errno.EINVAL)
os.close(os.pidfd_open(os.getpid(), 0))
+ @os_helper.skip_unless_hardlink
+ @os_helper.skip_unless_symlink
+ def test_link_follow_symlinks(self):
+ default_follow = sys.platform.startswith(
+ ('darwin', 'freebsd', 'netbsd', 'openbsd', 'dragonfly', 'sunos5'))
+ default_no_follow = sys.platform.startswith(('win32', 'linux'))
+ orig = os_helper.TESTFN
+ symlink = orig + 'symlink'
+ posix.symlink(orig, symlink)
+ self.addCleanup(os_helper.unlink, symlink)
+
+ with self.subTest('no follow_symlinks'):
+ # no follow_symlinks -> platform depending
+ link = orig + 'link'
+ posix.link(symlink, link)
+ self.addCleanup(os_helper.unlink, link)
+ if os.link in os.supports_follow_symlinks or default_follow:
+ self.assertEqual(posix.lstat(link), posix.lstat(orig))
+ elif default_no_follow:
+ self.assertEqual(posix.lstat(link), posix.lstat(symlink))
+
+ with self.subTest('follow_symlinks=False'):
+ # follow_symlinks=False -> duplicate the symlink itself
+ link = orig + 'link_nofollow'
+ try:
+ posix.link(symlink, link, follow_symlinks=False)
+ except NotImplementedError:
+ if os.link in os.supports_follow_symlinks or default_no_follow:
+ raise
+ else:
+ self.addCleanup(os_helper.unlink, link)
+ self.assertEqual(posix.lstat(link), posix.lstat(symlink))
+
+ with self.subTest('follow_symlinks=True'):
+ # follow_symlinks=True -> duplicate the target file
+ link = orig + 'link_following'
+ try:
+ posix.link(symlink, link, follow_symlinks=True)
+ except NotImplementedError:
+ if os.link in os.supports_follow_symlinks or default_follow:
+ raise
+ else:
+ self.addCleanup(os_helper.unlink, link)
+ self.assertEqual(posix.lstat(link), posix.lstat(orig))
+
# tests for the posix *at functions follow
class TestPosixDirFd(unittest.TestCase):
@@ -2173,12 +2218,12 @@ class TestPosixWeaklinking(unittest.TestCase):
def test_pwritev(self):
self._verify_available("HAVE_PWRITEV")
if self.mac_ver >= (10, 16):
- self.assertTrue(hasattr(os, "pwritev"), "os.pwritev is not available")
- self.assertTrue(hasattr(os, "preadv"), "os.readv is not available")
+ self.assertHasAttr(os, "pwritev")
+ self.assertHasAttr(os, "preadv")
else:
- self.assertFalse(hasattr(os, "pwritev"), "os.pwritev is available")
- self.assertFalse(hasattr(os, "preadv"), "os.readv is available")
+ self.assertNotHasAttr(os, "pwritev")
+ self.assertNotHasAttr(os, "preadv")
def test_stat(self):
self._verify_available("HAVE_FSTATAT")
diff --git a/Lib/test/test_posixpath.py b/Lib/test/test_posixpath.py
index fa19d549c26..21f06712548 100644
--- a/Lib/test/test_posixpath.py
+++ b/Lib/test/test_posixpath.py
@@ -4,12 +4,13 @@ import posixpath
import random
import sys
import unittest
-from posixpath import realpath, abspath, dirname, basename
+from functools import partial
+from posixpath import realpath, abspath, dirname, basename, ALLOW_MISSING
from test import support
from test import test_genericpath
from test.support import import_helper
from test.support import os_helper
-from test.support.os_helper import FakePath
+from test.support.os_helper import FakePath, TESTFN
from unittest import mock
try:
@@ -21,7 +22,7 @@ except ImportError:
# An absolute path to a temporary filename for testing. We can't rely on TESTFN
# being an absolute path, so we need this.
-ABSTFN = abspath(os_helper.TESTFN)
+ABSTFN = abspath(TESTFN)
def skip_if_ABSTFN_contains_backslash(test):
"""
@@ -33,21 +34,16 @@ def skip_if_ABSTFN_contains_backslash(test):
msg = "ABSTFN is not a posix path - tests fail"
return [test, unittest.skip(msg)(test)][found_backslash]
-def safe_rmdir(dirname):
- try:
- os.rmdir(dirname)
- except OSError:
- pass
+
+def _parameterize(*parameters):
+ return support.subTests('kwargs', parameters)
+
class PosixPathTest(unittest.TestCase):
def setUp(self):
- self.tearDown()
-
- def tearDown(self):
for suffix in ["", "1", "2"]:
- os_helper.unlink(os_helper.TESTFN + suffix)
- safe_rmdir(os_helper.TESTFN + suffix)
+ self.assertFalse(posixpath.lexists(ABSTFN + suffix))
def test_join(self):
fn = posixpath.join
@@ -194,25 +190,28 @@ class PosixPathTest(unittest.TestCase):
self.assertEqual(posixpath.dirname(b"//foo//bar"), b"//foo")
def test_islink(self):
- self.assertIs(posixpath.islink(os_helper.TESTFN + "1"), False)
- self.assertIs(posixpath.lexists(os_helper.TESTFN + "2"), False)
+ self.assertIs(posixpath.islink(TESTFN + "1"), False)
+ self.assertIs(posixpath.lexists(TESTFN + "2"), False)
- with open(os_helper.TESTFN + "1", "wb") as f:
+ self.addCleanup(os_helper.unlink, TESTFN + "1")
+ with open(TESTFN + "1", "wb") as f:
f.write(b"foo")
- self.assertIs(posixpath.islink(os_helper.TESTFN + "1"), False)
+ self.assertIs(posixpath.islink(TESTFN + "1"), False)
if os_helper.can_symlink():
- os.symlink(os_helper.TESTFN + "1", os_helper.TESTFN + "2")
- self.assertIs(posixpath.islink(os_helper.TESTFN + "2"), True)
- os.remove(os_helper.TESTFN + "1")
- self.assertIs(posixpath.islink(os_helper.TESTFN + "2"), True)
- self.assertIs(posixpath.exists(os_helper.TESTFN + "2"), False)
- self.assertIs(posixpath.lexists(os_helper.TESTFN + "2"), True)
-
- self.assertIs(posixpath.islink(os_helper.TESTFN + "\udfff"), False)
- self.assertIs(posixpath.islink(os.fsencode(os_helper.TESTFN) + b"\xff"), False)
- self.assertIs(posixpath.islink(os_helper.TESTFN + "\x00"), False)
- self.assertIs(posixpath.islink(os.fsencode(os_helper.TESTFN) + b"\x00"), False)
+ self.addCleanup(os_helper.unlink, TESTFN + "2")
+ os.symlink(TESTFN + "1", TESTFN + "2")
+ self.assertIs(posixpath.islink(TESTFN + "2"), True)
+ os.remove(TESTFN + "1")
+ self.assertIs(posixpath.islink(TESTFN + "2"), True)
+ self.assertIs(posixpath.exists(TESTFN + "2"), False)
+ self.assertIs(posixpath.lexists(TESTFN + "2"), True)
+
+ def test_islink_invalid_paths(self):
+ self.assertIs(posixpath.islink(TESTFN + "\udfff"), False)
+ self.assertIs(posixpath.islink(os.fsencode(TESTFN) + b"\xff"), False)
+ self.assertIs(posixpath.islink(TESTFN + "\x00"), False)
+ self.assertIs(posixpath.islink(os.fsencode(TESTFN) + b"\x00"), False)
def test_ismount(self):
self.assertIs(posixpath.ismount("/"), True)
@@ -227,8 +226,9 @@ class PosixPathTest(unittest.TestCase):
os.mkdir(ABSTFN)
self.assertIs(posixpath.ismount(ABSTFN), False)
finally:
- safe_rmdir(ABSTFN)
+ os_helper.rmdir(ABSTFN)
+ def test_ismount_invalid_paths(self):
self.assertIs(posixpath.ismount('/\udfff'), False)
self.assertIs(posixpath.ismount(b'/\xff'), False)
self.assertIs(posixpath.ismount('/\x00'), False)
@@ -241,7 +241,7 @@ class PosixPathTest(unittest.TestCase):
os.symlink("/", ABSTFN)
self.assertIs(posixpath.ismount(ABSTFN), False)
finally:
- os.unlink(ABSTFN)
+ os_helper.unlink(ABSTFN)
@unittest.skipIf(posix is None, "Test requires posix module")
def test_ismount_different_device(self):
@@ -448,32 +448,35 @@ class PosixPathTest(unittest.TestCase):
self.assertEqual(result, expected)
@skip_if_ABSTFN_contains_backslash
- def test_realpath_curdir(self):
- self.assertEqual(realpath('.'), os.getcwd())
- self.assertEqual(realpath('./.'), os.getcwd())
- self.assertEqual(realpath('/'.join(['.'] * 100)), os.getcwd())
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_curdir(self, kwargs):
+ self.assertEqual(realpath('.', **kwargs), os.getcwd())
+ self.assertEqual(realpath('./.', **kwargs), os.getcwd())
+ self.assertEqual(realpath('/'.join(['.'] * 100), **kwargs), os.getcwd())
- self.assertEqual(realpath(b'.'), os.getcwdb())
- self.assertEqual(realpath(b'./.'), os.getcwdb())
- self.assertEqual(realpath(b'/'.join([b'.'] * 100)), os.getcwdb())
+ self.assertEqual(realpath(b'.', **kwargs), os.getcwdb())
+ self.assertEqual(realpath(b'./.', **kwargs), os.getcwdb())
+ self.assertEqual(realpath(b'/'.join([b'.'] * 100), **kwargs), os.getcwdb())
@skip_if_ABSTFN_contains_backslash
- def test_realpath_pardir(self):
- self.assertEqual(realpath('..'), dirname(os.getcwd()))
- self.assertEqual(realpath('../..'), dirname(dirname(os.getcwd())))
- self.assertEqual(realpath('/'.join(['..'] * 100)), '/')
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_pardir(self, kwargs):
+ self.assertEqual(realpath('..', **kwargs), dirname(os.getcwd()))
+ self.assertEqual(realpath('../..', **kwargs), dirname(dirname(os.getcwd())))
+ self.assertEqual(realpath('/'.join(['..'] * 100), **kwargs), '/')
- self.assertEqual(realpath(b'..'), dirname(os.getcwdb()))
- self.assertEqual(realpath(b'../..'), dirname(dirname(os.getcwdb())))
- self.assertEqual(realpath(b'/'.join([b'..'] * 100)), b'/')
+ self.assertEqual(realpath(b'..', **kwargs), dirname(os.getcwdb()))
+ self.assertEqual(realpath(b'../..', **kwargs), dirname(dirname(os.getcwdb())))
+ self.assertEqual(realpath(b'/'.join([b'..'] * 100), **kwargs), b'/')
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_basic(self):
+ @_parameterize({}, {'strict': ALLOW_MISSING})
+ def test_realpath_basic(self, kwargs):
# Basic operation.
try:
os.symlink(ABSTFN+"1", ABSTFN)
- self.assertEqual(realpath(ABSTFN), ABSTFN+"1")
+ self.assertEqual(realpath(ABSTFN, **kwargs), ABSTFN+"1")
finally:
os_helper.unlink(ABSTFN)
@@ -489,23 +492,121 @@ class PosixPathTest(unittest.TestCase):
finally:
os_helper.unlink(ABSTFN)
+ def test_realpath_invalid_paths(self):
+ path = '/\x00'
+ self.assertRaises(ValueError, realpath, path, strict=False)
+ self.assertRaises(ValueError, realpath, path, strict=True)
+ self.assertRaises(ValueError, realpath, path, strict=ALLOW_MISSING)
+ path = b'/\x00'
+ self.assertRaises(ValueError, realpath, path, strict=False)
+ self.assertRaises(ValueError, realpath, path, strict=True)
+ self.assertRaises(ValueError, realpath, path, strict=ALLOW_MISSING)
+ path = '/nonexistent/x\x00'
+ self.assertRaises(ValueError, realpath, path, strict=False)
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+ self.assertRaises(ValueError, realpath, path, strict=ALLOW_MISSING)
+ path = b'/nonexistent/x\x00'
+ self.assertRaises(ValueError, realpath, path, strict=False)
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+ self.assertRaises(ValueError, realpath, path, strict=ALLOW_MISSING)
+ path = '/\x00/..'
+ self.assertRaises(ValueError, realpath, path, strict=False)
+ self.assertRaises(ValueError, realpath, path, strict=True)
+ self.assertRaises(ValueError, realpath, path, strict=ALLOW_MISSING)
+ path = b'/\x00/..'
+ self.assertRaises(ValueError, realpath, path, strict=False)
+ self.assertRaises(ValueError, realpath, path, strict=True)
+ self.assertRaises(ValueError, realpath, path, strict=ALLOW_MISSING)
+
+ path = '/nonexistent/x\x00/..'
+ self.assertRaises(ValueError, realpath, path, strict=False)
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+ self.assertRaises(ValueError, realpath, path, strict=ALLOW_MISSING)
+ path = b'/nonexistent/x\x00/..'
+ self.assertRaises(ValueError, realpath, path, strict=False)
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+ self.assertRaises(ValueError, realpath, path, strict=ALLOW_MISSING)
+
+ path = '/\udfff'
+ if sys.platform == 'win32':
+ self.assertEqual(realpath(path, strict=False), path)
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), path)
+ else:
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=False)
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=True)
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=ALLOW_MISSING)
+ path = '/nonexistent/\udfff'
+ if sys.platform == 'win32':
+ self.assertEqual(realpath(path, strict=False), path)
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), path)
+ else:
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=False)
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=ALLOW_MISSING)
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+ path = '/\udfff/..'
+ if sys.platform == 'win32':
+ self.assertEqual(realpath(path, strict=False), '/')
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), '/')
+ else:
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=False)
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=True)
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=ALLOW_MISSING)
+ path = '/nonexistent/\udfff/..'
+ if sys.platform == 'win32':
+ self.assertEqual(realpath(path, strict=False), '/nonexistent')
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), '/nonexistent')
+ else:
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=False)
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=ALLOW_MISSING)
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+
+ path = b'/\xff'
+ if sys.platform == 'win32':
+ self.assertRaises(UnicodeDecodeError, realpath, path, strict=False)
+ self.assertRaises(UnicodeDecodeError, realpath, path, strict=True)
+ self.assertRaises(UnicodeDecodeError, realpath, path, strict=ALLOW_MISSING)
+ else:
+ self.assertEqual(realpath(path, strict=False), path)
+ if support.is_wasi:
+ self.assertRaises(OSError, realpath, path, strict=True)
+ self.assertRaises(OSError, realpath, path, strict=ALLOW_MISSING)
+ else:
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), path)
+ path = b'/nonexistent/\xff'
+ if sys.platform == 'win32':
+ self.assertRaises(UnicodeDecodeError, realpath, path, strict=False)
+ self.assertRaises(UnicodeDecodeError, realpath, path, strict=ALLOW_MISSING)
+ else:
+ self.assertEqual(realpath(path, strict=False), path)
+ if support.is_wasi:
+ self.assertRaises(OSError, realpath, path, strict=True)
+ self.assertRaises(OSError, realpath, path, strict=ALLOW_MISSING)
+ else:
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_relative(self):
+ @_parameterize({}, {'strict': ALLOW_MISSING})
+ def test_realpath_relative(self, kwargs):
try:
os.symlink(posixpath.relpath(ABSTFN+"1"), ABSTFN)
- self.assertEqual(realpath(ABSTFN), ABSTFN+"1")
+ self.assertEqual(realpath(ABSTFN, **kwargs), ABSTFN+"1")
finally:
os_helper.unlink(ABSTFN)
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_missing_pardir(self):
+ @_parameterize({}, {'strict': ALLOW_MISSING})
+ def test_realpath_missing_pardir(self, kwargs):
try:
- os.symlink(os_helper.TESTFN + "1", os_helper.TESTFN)
- self.assertEqual(realpath("nonexistent/../" + os_helper.TESTFN), ABSTFN + "1")
+ os.symlink(TESTFN + "1", TESTFN)
+ self.assertEqual(
+ realpath("nonexistent/../" + TESTFN, **kwargs), ABSTFN + "1")
finally:
- os_helper.unlink(os_helper.TESTFN)
+ os_helper.unlink(TESTFN)
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
@@ -550,37 +651,38 @@ class PosixPathTest(unittest.TestCase):
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_symlink_loops_strict(self):
+ @_parameterize({'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_symlink_loops_strict(self, kwargs):
# Bug #43757, raise OSError if we get into an infinite symlink loop in
- # strict mode.
+ # the strict modes.
try:
os.symlink(ABSTFN, ABSTFN)
- self.assertRaises(OSError, realpath, ABSTFN, strict=True)
+ self.assertRaises(OSError, realpath, ABSTFN, **kwargs)
os.symlink(ABSTFN+"1", ABSTFN+"2")
os.symlink(ABSTFN+"2", ABSTFN+"1")
- self.assertRaises(OSError, realpath, ABSTFN+"1", strict=True)
- self.assertRaises(OSError, realpath, ABSTFN+"2", strict=True)
+ self.assertRaises(OSError, realpath, ABSTFN+"1", **kwargs)
+ self.assertRaises(OSError, realpath, ABSTFN+"2", **kwargs)
- self.assertRaises(OSError, realpath, ABSTFN+"1/x", strict=True)
- self.assertRaises(OSError, realpath, ABSTFN+"1/..", strict=True)
- self.assertRaises(OSError, realpath, ABSTFN+"1/../x", strict=True)
+ self.assertRaises(OSError, realpath, ABSTFN+"1/x", **kwargs)
+ self.assertRaises(OSError, realpath, ABSTFN+"1/..", **kwargs)
+ self.assertRaises(OSError, realpath, ABSTFN+"1/../x", **kwargs)
os.symlink(ABSTFN+"x", ABSTFN+"y")
self.assertRaises(OSError, realpath,
- ABSTFN+"1/../" + basename(ABSTFN) + "y", strict=True)
+ ABSTFN+"1/../" + basename(ABSTFN) + "y", **kwargs)
self.assertRaises(OSError, realpath,
- ABSTFN+"1/../" + basename(ABSTFN) + "1", strict=True)
+ ABSTFN+"1/../" + basename(ABSTFN) + "1", **kwargs)
os.symlink(basename(ABSTFN) + "a/b", ABSTFN+"a")
- self.assertRaises(OSError, realpath, ABSTFN+"a", strict=True)
+ self.assertRaises(OSError, realpath, ABSTFN+"a", **kwargs)
os.symlink("../" + basename(dirname(ABSTFN)) + "/" +
basename(ABSTFN) + "c", ABSTFN+"c")
- self.assertRaises(OSError, realpath, ABSTFN+"c", strict=True)
+ self.assertRaises(OSError, realpath, ABSTFN+"c", **kwargs)
# Test using relative path as well.
with os_helper.change_cwd(dirname(ABSTFN)):
- self.assertRaises(OSError, realpath, basename(ABSTFN), strict=True)
+ self.assertRaises(OSError, realpath, basename(ABSTFN), **kwargs)
finally:
os_helper.unlink(ABSTFN)
os_helper.unlink(ABSTFN+"1")
@@ -591,28 +693,30 @@ class PosixPathTest(unittest.TestCase):
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_repeated_indirect_symlinks(self):
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_repeated_indirect_symlinks(self, kwargs):
# Issue #6975.
try:
os.mkdir(ABSTFN)
os.symlink('../' + basename(ABSTFN), ABSTFN + '/self')
os.symlink('self/self/self', ABSTFN + '/link')
- self.assertEqual(realpath(ABSTFN + '/link'), ABSTFN)
+ self.assertEqual(realpath(ABSTFN + '/link', **kwargs), ABSTFN)
finally:
os_helper.unlink(ABSTFN + '/self')
os_helper.unlink(ABSTFN + '/link')
- safe_rmdir(ABSTFN)
+ os_helper.rmdir(ABSTFN)
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_deep_recursion(self):
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_deep_recursion(self, kwargs):
depth = 10
try:
os.mkdir(ABSTFN)
for i in range(depth):
os.symlink('/'.join(['%d' % i] * 10), ABSTFN + '/%d' % (i + 1))
os.symlink('.', ABSTFN + '/0')
- self.assertEqual(realpath(ABSTFN + '/%d' % depth), ABSTFN)
+ self.assertEqual(realpath(ABSTFN + '/%d' % depth, **kwargs), ABSTFN)
# Test using relative path as well.
with os_helper.change_cwd(ABSTFN):
@@ -620,11 +724,12 @@ class PosixPathTest(unittest.TestCase):
finally:
for i in range(depth + 1):
os_helper.unlink(ABSTFN + '/%d' % i)
- safe_rmdir(ABSTFN)
+ os_helper.rmdir(ABSTFN)
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_resolve_parents(self):
+ @_parameterize({}, {'strict': ALLOW_MISSING})
+ def test_realpath_resolve_parents(self, kwargs):
# We also need to resolve any symlinks in the parents of a relative
# path passed to realpath. E.g.: current working directory is
# /usr/doc with 'doc' being a symlink to /usr/share/doc. We call
@@ -635,15 +740,17 @@ class PosixPathTest(unittest.TestCase):
os.symlink(ABSTFN + "/y", ABSTFN + "/k")
with os_helper.change_cwd(ABSTFN + "/k"):
- self.assertEqual(realpath("a"), ABSTFN + "/y/a")
+ self.assertEqual(realpath("a", **kwargs),
+ ABSTFN + "/y/a")
finally:
os_helper.unlink(ABSTFN + "/k")
- safe_rmdir(ABSTFN + "/y")
- safe_rmdir(ABSTFN)
+ os_helper.rmdir(ABSTFN + "/y")
+ os_helper.rmdir(ABSTFN)
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_resolve_before_normalizing(self):
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_resolve_before_normalizing(self, kwargs):
# Bug #990669: Symbolic links should be resolved before we
# normalize the path. E.g.: if we have directories 'a', 'k' and 'y'
# in the following hierarchy:
@@ -658,20 +765,21 @@ class PosixPathTest(unittest.TestCase):
os.symlink(ABSTFN + "/k/y", ABSTFN + "/link-y")
# Absolute path.
- self.assertEqual(realpath(ABSTFN + "/link-y/.."), ABSTFN + "/k")
+ self.assertEqual(realpath(ABSTFN + "/link-y/..", **kwargs), ABSTFN + "/k")
# Relative path.
with os_helper.change_cwd(dirname(ABSTFN)):
- self.assertEqual(realpath(basename(ABSTFN) + "/link-y/.."),
+ self.assertEqual(realpath(basename(ABSTFN) + "/link-y/..", **kwargs),
ABSTFN + "/k")
finally:
os_helper.unlink(ABSTFN + "/link-y")
- safe_rmdir(ABSTFN + "/k/y")
- safe_rmdir(ABSTFN + "/k")
- safe_rmdir(ABSTFN)
+ os_helper.rmdir(ABSTFN + "/k/y")
+ os_helper.rmdir(ABSTFN + "/k")
+ os_helper.rmdir(ABSTFN)
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_resolve_first(self):
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_resolve_first(self, kwargs):
# Bug #1213894: The first component of the path, if not absolute,
# must be resolved too.
@@ -681,12 +789,12 @@ class PosixPathTest(unittest.TestCase):
os.symlink(ABSTFN, ABSTFN + "link")
with os_helper.change_cwd(dirname(ABSTFN)):
base = basename(ABSTFN)
- self.assertEqual(realpath(base + "link"), ABSTFN)
- self.assertEqual(realpath(base + "link/k"), ABSTFN + "/k")
+ self.assertEqual(realpath(base + "link", **kwargs), ABSTFN)
+ self.assertEqual(realpath(base + "link/k", **kwargs), ABSTFN + "/k")
finally:
os_helper.unlink(ABSTFN + "link")
- safe_rmdir(ABSTFN + "/k")
- safe_rmdir(ABSTFN)
+ os_helper.rmdir(ABSTFN + "/k")
+ os_helper.rmdir(ABSTFN)
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
@@ -700,27 +808,95 @@ class PosixPathTest(unittest.TestCase):
self.assertEqual(realpath(ABSTFN + '/foo'), ABSTFN + '/foo')
self.assertEqual(realpath(ABSTFN + '/../foo'), dirname(ABSTFN) + '/foo')
self.assertEqual(realpath(ABSTFN + '/foo/..'), ABSTFN)
+ finally:
+ os.chmod(ABSTFN, 0o755, follow_symlinks=False)
+ os_helper.unlink(ABSTFN)
+
+ @os_helper.skip_unless_symlink
+ @skip_if_ABSTFN_contains_backslash
+ @unittest.skipIf(os.chmod not in os.supports_follow_symlinks, "Can't set symlink permissions")
+ @unittest.skipIf(sys.platform != "darwin", "only macOS requires read permission to readlink()")
+ @_parameterize({'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_unreadable_symlink_strict(self, kwargs):
+ try:
+ os.symlink(ABSTFN+"1", ABSTFN)
+ os.chmod(ABSTFN, 0o000, follow_symlinks=False)
+ with self.assertRaises(PermissionError):
+ realpath(ABSTFN, **kwargs)
+ with self.assertRaises(PermissionError):
+                realpath(ABSTFN + '/foo', **kwargs)
with self.assertRaises(PermissionError):
- realpath(ABSTFN, strict=True)
+ realpath(ABSTFN + '/../foo', **kwargs)
+ with self.assertRaises(PermissionError):
+ realpath(ABSTFN + '/foo/..', **kwargs)
finally:
os.chmod(ABSTFN, 0o755, follow_symlinks=False)
os.unlink(ABSTFN)
@skip_if_ABSTFN_contains_backslash
+ @os_helper.skip_unless_symlink
+ def test_realpath_unreadable_directory(self):
+ try:
+ os.mkdir(ABSTFN)
+ os.mkdir(ABSTFN + '/k')
+ os.chmod(ABSTFN, 0o000)
+ self.assertEqual(realpath(ABSTFN, strict=False), ABSTFN)
+ self.assertEqual(realpath(ABSTFN, strict=True), ABSTFN)
+ self.assertEqual(realpath(ABSTFN, strict=ALLOW_MISSING), ABSTFN)
+
+ try:
+ os.stat(ABSTFN)
+ except PermissionError:
+ pass
+ else:
+ self.skipTest('Cannot block permissions')
+
+ self.assertEqual(realpath(ABSTFN + '/k', strict=False),
+ ABSTFN + '/k')
+ self.assertRaises(PermissionError, realpath, ABSTFN + '/k',
+ strict=True)
+ self.assertRaises(PermissionError, realpath, ABSTFN + '/k',
+ strict=ALLOW_MISSING)
+
+ self.assertEqual(realpath(ABSTFN + '/missing', strict=False),
+ ABSTFN + '/missing')
+ self.assertRaises(PermissionError, realpath, ABSTFN + '/missing',
+ strict=True)
+ self.assertRaises(PermissionError, realpath, ABSTFN + '/missing',
+ strict=ALLOW_MISSING)
+ finally:
+ os.chmod(ABSTFN, 0o755)
+ os_helper.rmdir(ABSTFN + '/k')
+ os_helper.rmdir(ABSTFN)
+
+ @skip_if_ABSTFN_contains_backslash
def test_realpath_nonterminal_file(self):
try:
with open(ABSTFN, 'w') as f:
f.write('test_posixpath wuz ere')
self.assertEqual(realpath(ABSTFN, strict=False), ABSTFN)
self.assertEqual(realpath(ABSTFN, strict=True), ABSTFN)
+ self.assertEqual(realpath(ABSTFN, strict=ALLOW_MISSING), ABSTFN)
+
self.assertEqual(realpath(ABSTFN + "/", strict=False), ABSTFN)
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/.", strict=False), ABSTFN)
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/.", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/.",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/..", strict=False), dirname(ABSTFN))
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/..", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/..",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/subdir", strict=False), ABSTFN + "/subdir")
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/subdir", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/subdir",
+ strict=ALLOW_MISSING)
finally:
os_helper.unlink(ABSTFN)
@@ -733,16 +909,30 @@ class PosixPathTest(unittest.TestCase):
os.symlink(ABSTFN + "1", ABSTFN)
self.assertEqual(realpath(ABSTFN, strict=False), ABSTFN + "1")
self.assertEqual(realpath(ABSTFN, strict=True), ABSTFN + "1")
+ self.assertEqual(realpath(ABSTFN, strict=ALLOW_MISSING), ABSTFN + "1")
+
self.assertEqual(realpath(ABSTFN + "/", strict=False), ABSTFN + "1")
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/.", strict=False), ABSTFN + "1")
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/.", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/.",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/..", strict=False), dirname(ABSTFN))
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/..", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/..",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/subdir", strict=False), ABSTFN + "1/subdir")
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/subdir", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/subdir",
+ strict=ALLOW_MISSING)
finally:
os_helper.unlink(ABSTFN)
+ os_helper.unlink(ABSTFN + "1")
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
@@ -754,16 +944,31 @@ class PosixPathTest(unittest.TestCase):
os.symlink(ABSTFN + "1", ABSTFN)
self.assertEqual(realpath(ABSTFN, strict=False), ABSTFN + "2")
self.assertEqual(realpath(ABSTFN, strict=True), ABSTFN + "2")
+        self.assertEqual(realpath(ABSTFN, strict=ALLOW_MISSING), ABSTFN + "2")
+
self.assertEqual(realpath(ABSTFN + "/", strict=False), ABSTFN + "2")
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/.", strict=False), ABSTFN + "2")
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/.", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/.",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/..", strict=False), dirname(ABSTFN))
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/..", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/..",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/subdir", strict=False), ABSTFN + "2/subdir")
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/subdir", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/subdir",
+ strict=ALLOW_MISSING)
finally:
os_helper.unlink(ABSTFN)
+ os_helper.unlink(ABSTFN + "1")
+ os_helper.unlink(ABSTFN + "2")
def test_relpath(self):
(real_getcwd, os.getcwd) = (os.getcwd, lambda: r"/home/user/bar")
@@ -889,8 +1094,8 @@ class PathLikeTests(unittest.TestCase):
path = posixpath
def setUp(self):
- self.file_name = os_helper.TESTFN
- self.file_path = FakePath(os_helper.TESTFN)
+ self.file_name = TESTFN
+ self.file_path = FakePath(TESTFN)
self.addCleanup(os_helper.unlink, self.file_name)
with open(self.file_name, 'xb', 0) as file:
file.write(b"test_posixpath.PathLikeTests")
@@ -947,9 +1152,12 @@ class PathLikeTests(unittest.TestCase):
def test_path_abspath(self):
self.assertPathEqual(self.path.abspath)
- def test_path_realpath(self):
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_path_realpath(self, kwargs):
self.assertPathEqual(self.path.realpath)
+ self.assertPathEqual(partial(self.path.realpath, **kwargs))
+
def test_path_relpath(self):
self.assertPathEqual(self.path.relpath)
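
The _parameterize() helper introduced above wraps test.support.subTests() so that each
realpath test body runs once per keyword set ({}, {'strict': True}, {'strict': ALLOW_MISSING}).
For readers outside the CPython test suite, a rough standalone analogue of the same pattern,
built only on unittest's public subTest API, might look like the sketch below (the names
parameterize and RealpathModes are illustrative and not part of this patch):

# Sketch only: an approximation of the _parameterize() decorator using
# unittest.TestCase.subTest instead of the internal test.support.subTests helper.
import functools
import os
import unittest

def parameterize(*parameter_dicts):
    """Run the decorated test once per keyword-argument dict, each in a subTest."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self):
            for kwargs in parameter_dicts:
                with self.subTest(kwargs=kwargs):
                    func(self, kwargs)
        return wrapper
    return decorator

class RealpathModes(unittest.TestCase):
    # ALLOW_MISSING is left out here because it only exists on interpreters
    # that already carry this patch.
    @parameterize({}, {"strict": True})
    def test_curdir(self, kwargs):
        self.assertEqual(os.path.realpath(".", **kwargs), os.getcwd())

if __name__ == "__main__":
    unittest.main()
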
diff --git a/Lib/test/test_pprint.py b/Lib/test/test_pprint.py
index dfbc2a06e73..41c337ade7e 100644
--- a/Lib/test/test_pprint.py
+++ b/Lib/test/test_pprint.py
@@ -10,6 +10,10 @@ import random
import re
import types
import unittest
+from collections.abc import ItemsView, KeysView, Mapping, MappingView, ValuesView
+
+from test.support import cpython_only
+from test.support.import_helper import ensure_lazy_imports
# list, tuple and dict subclasses that do or don't overwrite __repr__
class list2(list):
@@ -67,6 +71,14 @@ class dict_custom_repr(dict):
def __repr__(self):
return '*'*len(dict.__repr__(self))
+class mappingview_custom_repr(MappingView):
+ def __repr__(self):
+ return '*'*len(MappingView.__repr__(self))
+
+class keysview_custom_repr(KeysView):
+ def __repr__(self):
+ return '*'*len(KeysView.__repr__(self))
+
@dataclasses.dataclass
class dataclass1:
field1: str
@@ -129,6 +141,10 @@ class QueryTestCase(unittest.TestCase):
self.b = list(range(200))
self.a[-12] = self.b
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("pprint", {"dataclasses", "re"})
+
def test_init(self):
pp = pprint.PrettyPrinter()
pp = pprint.PrettyPrinter(indent=4, width=40, depth=5,
@@ -173,10 +189,17 @@ class QueryTestCase(unittest.TestCase):
# Messy dict.
self.d = {}
self.d[0] = self.d[1] = self.d[2] = self.d
+ self.e = {}
+ self.v = ValuesView(self.e)
+ self.m = MappingView(self.e)
+ self.dv = self.e.values()
+ self.e["v"] = self.v
+ self.e["m"] = self.m
+ self.e["dv"] = self.dv
pp = pprint.PrettyPrinter()
- for icky in self.a, self.b, self.d, (self.d, self.d):
+ for icky in self.a, self.b, self.d, (self.d, self.d), self.e, self.v, self.m, self.dv:
self.assertTrue(pprint.isrecursive(icky), "expected isrecursive")
self.assertFalse(pprint.isreadable(icky), "expected not isreadable")
self.assertTrue(pp.isrecursive(icky), "expected isrecursive")
@@ -184,10 +207,11 @@ class QueryTestCase(unittest.TestCase):
# Break the cycles.
self.d.clear()
+ self.e.clear()
del self.a[:]
del self.b[:]
- for safe in self.a, self.b, self.d, (self.d, self.d):
+ for safe in self.a, self.b, self.d, (self.d, self.d), self.e, self.v, self.m, self.dv:
# module-level convenience functions
self.assertFalse(pprint.isrecursive(safe),
"expected not isrecursive for %r" % (safe,))
@@ -230,6 +254,8 @@ class QueryTestCase(unittest.TestCase):
set(), set2(), set3(),
frozenset(), frozenset2(), frozenset3(),
{}, dict2(), dict3(),
+ {}.keys(), {}.values(), {}.items(),
+ MappingView({}), KeysView({}), ItemsView({}), ValuesView({}),
self.assertTrue, pprint,
-6, -6, -6-6j, -1.5, "x", b"x", bytearray(b"x"),
(3,), [3], {3: 6},
@@ -239,6 +265,9 @@ class QueryTestCase(unittest.TestCase):
set({7}), set2({7}), set3({7}),
frozenset({8}), frozenset2({8}), frozenset3({8}),
dict2({5: 6}), dict3({5: 6}),
+ {5: 6}.keys(), {5: 6}.values(), {5: 6}.items(),
+ MappingView({5: 6}), KeysView({5: 6}),
+ ItemsView({5: 6}), ValuesView({5: 6}),
range(10, -11, -1),
True, False, None, ...,
):
@@ -268,6 +297,12 @@ class QueryTestCase(unittest.TestCase):
dict_custom_repr(),
dict_custom_repr({5: 6}),
dict_custom_repr(zip(range(N),range(N))),
+ mappingview_custom_repr({}),
+ mappingview_custom_repr({5: 6}),
+ mappingview_custom_repr(dict(zip(range(N),range(N)))),
+ keysview_custom_repr({}),
+ keysview_custom_repr({5: 6}),
+ keysview_custom_repr(dict(zip(range(N),range(N)))),
):
native = repr(cont)
expected = '*' * len(native)
@@ -296,6 +331,56 @@ class QueryTestCase(unittest.TestCase):
self.assertEqual(pprint.pformat(type(o)), exp)
o = range(100)
+ exp = 'dict_keys([%s])' % ',\n '.join(map(str, o))
+ keys = dict.fromkeys(o).keys()
+ self.assertEqual(pprint.pformat(keys), exp)
+
+ o = range(100)
+ exp = 'dict_values([%s])' % ',\n '.join(map(str, o))
+ values = {v: v for v in o}.values()
+ self.assertEqual(pprint.pformat(values), exp)
+
+ o = range(100)
+ exp = 'dict_items([%s])' % ',\n '.join("(%s, %s)" % (i, i) for i in o)
+ items = {v: v for v in o}.items()
+ self.assertEqual(pprint.pformat(items), exp)
+
+ o = range(100)
+ exp = 'odict_keys([%s])' % ',\n '.join(map(str, o))
+ keys = collections.OrderedDict.fromkeys(o).keys()
+ self.assertEqual(pprint.pformat(keys), exp)
+
+ o = range(100)
+ exp = 'odict_values([%s])' % ',\n '.join(map(str, o))
+ values = collections.OrderedDict({v: v for v in o}).values()
+ self.assertEqual(pprint.pformat(values), exp)
+
+ o = range(100)
+ exp = 'odict_items([%s])' % ',\n '.join("(%s, %s)" % (i, i) for i in o)
+ items = collections.OrderedDict({v: v for v in o}).items()
+ self.assertEqual(pprint.pformat(items), exp)
+
+ o = range(100)
+ exp = 'KeysView({%s})' % (': None,\n '.join(map(str, o)) + ': None')
+ keys_view = KeysView(dict.fromkeys(o))
+ self.assertEqual(pprint.pformat(keys_view), exp)
+
+ o = range(100)
+ exp = 'ItemsView({%s})' % (': None,\n '.join(map(str, o)) + ': None')
+ items_view = ItemsView(dict.fromkeys(o))
+ self.assertEqual(pprint.pformat(items_view), exp)
+
+ o = range(100)
+ exp = 'MappingView({%s})' % (': None,\n '.join(map(str, o)) + ': None')
+ mapping_view = MappingView(dict.fromkeys(o))
+ self.assertEqual(pprint.pformat(mapping_view), exp)
+
+ o = range(100)
+ exp = 'ValuesView({%s})' % (': None,\n '.join(map(str, o)) + ': None')
+ values_view = ValuesView(dict.fromkeys(o))
+ self.assertEqual(pprint.pformat(values_view), exp)
+
+ o = range(100)
exp = '[%s]' % ',\n '.join(map(str, o))
for type in [list, list2]:
self.assertEqual(pprint.pformat(type(o)), exp)
@@ -373,7 +458,7 @@ class QueryTestCase(unittest.TestCase):
return super().__new__(Temperature, celsius_degrees)
def __repr__(self):
kelvin_degrees = self + 273.15
- return f"{kelvin_degrees}°K"
+ return f"{kelvin_degrees:.2f}°K"
self.assertEqual(pprint.pformat(Temperature(1000)), '1273.15°K')
def test_sorted_dict(self):
@@ -418,6 +503,30 @@ OrderedDict([('the', 0),
('a', 6),
('lazy', 7),
('dog', 8)])""")
+ self.assertEqual(pprint.pformat(d.keys(), sort_dicts=False),
+"""\
+odict_keys(['the',
+ 'quick',
+ 'brown',
+ 'fox',
+ 'jumped',
+ 'over',
+ 'a',
+ 'lazy',
+ 'dog'])""")
+ self.assertEqual(pprint.pformat(d.items(), sort_dicts=False),
+"""\
+odict_items([('the', 0),
+ ('quick', 1),
+ ('brown', 2),
+ ('fox', 3),
+ ('jumped', 4),
+ ('over', 5),
+ ('a', 6),
+ ('lazy', 7),
+ ('dog', 8)])""")
+ self.assertEqual(pprint.pformat(d.values(), sort_dicts=False),
+ "odict_values([0, 1, 2, 3, 4, 5, 6, 7, 8])")
def test_mapping_proxy(self):
words = 'the quick brown fox jumped over a lazy dog'.split()
@@ -446,6 +555,152 @@ mappingproxy(OrderedDict([('the', 0),
('lazy', 7),
('dog', 8)]))""")
+ def test_dict_views(self):
+ for dict_class in (dict, collections.OrderedDict, collections.Counter):
+ empty = dict_class({})
+ short = dict_class(dict(zip('edcba', 'edcba')))
+ long = dict_class(dict((chr(x), chr(x)) for x in range(90, 64, -1)))
+ lengths = {"empty": empty, "short": short, "long": long}
+ prefix = "odict" if dict_class is collections.OrderedDict else "dict"
+ for name, d in lengths.items():
+ with self.subTest(length=name, prefix=prefix):
+ is_short = len(d) < 6
+ joiner = ", " if is_short else ",\n "
+ k = d.keys()
+ v = d.values()
+ i = d.items()
+ self.assertEqual(pprint.pformat(k, sort_dicts=True),
+ prefix + "_keys([%s])" %
+ joiner.join(repr(key) for key in sorted(k)))
+ self.assertEqual(pprint.pformat(v, sort_dicts=True),
+ prefix + "_values([%s])" %
+ joiner.join(repr(val) for val in sorted(v)))
+ self.assertEqual(pprint.pformat(i, sort_dicts=True),
+ prefix + "_items([%s])" %
+ joiner.join(repr(item) for item in sorted(i)))
+ self.assertEqual(pprint.pformat(k, sort_dicts=False),
+ prefix + "_keys([%s])" %
+ joiner.join(repr(key) for key in k))
+ self.assertEqual(pprint.pformat(v, sort_dicts=False),
+ prefix + "_values([%s])" %
+ joiner.join(repr(val) for val in v))
+ self.assertEqual(pprint.pformat(i, sort_dicts=False),
+ prefix + "_items([%s])" %
+ joiner.join(repr(item) for item in i))
+
+ def test_abc_views(self):
+ empty = {}
+ short = dict(zip('edcba', 'edcba'))
+ long = dict((chr(x), chr(x)) for x in range(90, 64, -1))
+ lengths = {"empty": empty, "short": short, "long": long}
+ # Test that a subclass that doesn't replace __repr__ works with different lengths
+ class MV(MappingView): pass
+
+ for name, d in lengths.items():
+ with self.subTest(length=name, name="Views"):
+ is_short = len(d) < 6
+ joiner = ", " if is_short else ",\n "
+ i = d.items()
+ s = sorted(i)
+ joined_items = "({%s})" % joiner.join(["%r: %r" % (k, v) for (k, v) in i])
+ sorted_items = "({%s})" % joiner.join(["%r: %r" % (k, v) for (k, v) in s])
+ self.assertEqual(pprint.pformat(KeysView(d), sort_dicts=True),
+ KeysView.__name__ + sorted_items)
+ self.assertEqual(pprint.pformat(ItemsView(d), sort_dicts=True),
+ ItemsView.__name__ + sorted_items)
+ self.assertEqual(pprint.pformat(MappingView(d), sort_dicts=True),
+ MappingView.__name__ + sorted_items)
+ self.assertEqual(pprint.pformat(MV(d), sort_dicts=True),
+ MV.__name__ + sorted_items)
+ self.assertEqual(pprint.pformat(ValuesView(d), sort_dicts=True),
+ ValuesView.__name__ + sorted_items)
+ self.assertEqual(pprint.pformat(KeysView(d), sort_dicts=False),
+ KeysView.__name__ + joined_items)
+ self.assertEqual(pprint.pformat(ItemsView(d), sort_dicts=False),
+ ItemsView.__name__ + joined_items)
+ self.assertEqual(pprint.pformat(MappingView(d), sort_dicts=False),
+ MappingView.__name__ + joined_items)
+ self.assertEqual(pprint.pformat(MV(d), sort_dicts=False),
+ MV.__name__ + joined_items)
+ self.assertEqual(pprint.pformat(ValuesView(d), sort_dicts=False),
+ ValuesView.__name__ + joined_items)
+
+ def test_nested_views(self):
+ d = {1: MappingView({1: MappingView({1: MappingView({1: 2})})})}
+ self.assertEqual(repr(d),
+ "{1: MappingView({1: MappingView({1: MappingView({1: 2})})})}")
+ self.assertEqual(pprint.pformat(d),
+ "{1: MappingView({1: MappingView({1: MappingView({1: 2})})})}")
+ self.assertEqual(pprint.pformat(d, depth=2),
+ "{1: MappingView({1: {...}})}")
+ d = {}
+ d1 = {1: d.values()}
+ d2 = {1: d1.values()}
+ d3 = {1: d2.values()}
+ self.assertEqual(pprint.pformat(d3),
+ "{1: dict_values([dict_values([dict_values([])])])}")
+ self.assertEqual(pprint.pformat(d3, depth=2),
+ "{1: dict_values([{...}])}")
+
+ def test_unorderable_items_views(self):
+ """Check that views with unorderable items have stable sorting."""
+ d = dict((((3+1j), 3), ((1+1j), (1+0j)), (1j, 0j), (500, None), (499, None)))
+ iv = ItemsView(d)
+ self.assertEqual(pprint.pformat(iv),
+ pprint.pformat(iv))
+ self.assertTrue(pprint.pformat(iv).endswith(", 499: None, 500: None})"),
+ pprint.pformat(iv))
+ self.assertEqual(pprint.pformat(d.items()), # Won't be equal unless _safe_tuple
+ pprint.pformat(d.items())) # is used in _safe_repr
+ self.assertTrue(pprint.pformat(d.items()).endswith(", (499, None), (500, None)])"))
+
+ def test_mapping_view_subclass_no_mapping(self):
+ class BMV(MappingView):
+ def __init__(self, d):
+ super().__init__(d)
+ self.mapping = self._mapping
+ del self._mapping
+
+ self.assertRaises(AttributeError, pprint.pformat, BMV({}))
+
+ def test_mapping_subclass_repr(self):
+ """Test that mapping ABC views use their ._mapping's __repr__."""
+ class MyMapping(Mapping):
+ def __init__(self, keys=None):
+ self._keys = {} if keys is None else dict.fromkeys(keys)
+
+ def __getitem__(self, item):
+ return self._keys[item]
+
+ def __len__(self):
+ return len(self._keys)
+
+ def __iter__(self):
+ return iter(self._keys)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}([{', '.join(map(repr, self._keys.keys()))}])"
+
+ m = MyMapping(["test", 1])
+ self.assertEqual(repr(m), "MyMapping(['test', 1])")
+ short_view_repr = "%s(MyMapping(['test', 1]))"
+ self.assertEqual(repr(m.keys()), short_view_repr % "KeysView")
+ self.assertEqual(pprint.pformat(m.items()), short_view_repr % "ItemsView")
+ self.assertEqual(pprint.pformat(m.keys()), short_view_repr % "KeysView")
+ self.assertEqual(pprint.pformat(MappingView(m)), short_view_repr % "MappingView")
+ self.assertEqual(pprint.pformat(m.values()), short_view_repr % "ValuesView")
+
+ alpha = "abcdefghijklmnopqrstuvwxyz"
+ m = MyMapping(alpha)
+ alpha_repr = ", ".join(map(repr, list(alpha)))
+ long_view_repr = "%%s(MyMapping([%s]))" % alpha_repr
+ self.assertEqual(repr(m), "MyMapping([%s])" % alpha_repr)
+ self.assertEqual(repr(m.keys()), long_view_repr % "KeysView")
+ self.assertEqual(pprint.pformat(m.items()), long_view_repr % "ItemsView")
+ self.assertEqual(pprint.pformat(m.keys()), long_view_repr % "KeysView")
+ self.assertEqual(pprint.pformat(MappingView(m)), long_view_repr % "MappingView")
+ self.assertEqual(pprint.pformat(m.values()), long_view_repr % "ValuesView")
+
def test_empty_simple_namespace(self):
ns = types.SimpleNamespace()
formatted = pprint.pformat(ns)
@@ -761,6 +1016,10 @@ frozenset2({0,
'frozenset({' + ','.join(map(repr, skeys)) + '})')
self.assertEqual(clean(pprint.pformat(dict.fromkeys(keys))),
'{' + ','.join('%r:None' % k for k in skeys) + '}')
+ self.assertEqual(clean(pprint.pformat(dict.fromkeys(keys).keys())),
+ 'dict_keys([' + ','.join('%r' % k for k in skeys) + '])')
+ self.assertEqual(clean(pprint.pformat(dict.fromkeys(keys).items())),
+ 'dict_items([' + ','.join('(%r,None)' % k for k in skeys) + '])')
# Issue 10017: TypeError on user-defined types as dict keys.
self.assertEqual(pprint.pformat({Unorderable: 0, 1: 0}),
@@ -1042,6 +1301,66 @@ ChainMap({'a': 6,
('a', 6),
('lazy', 7),
('dog', 8)]))""")
+ self.assertEqual(pprint.pformat(d.keys()),
+"""\
+KeysView(ChainMap({'a': 6,
+ 'brown': 2,
+ 'dog': 8,
+ 'fox': 3,
+ 'jumped': 4,
+ 'lazy': 7,
+ 'over': 5,
+ 'quick': 1,
+ 'the': 0},
+ OrderedDict([('the', 0),
+ ('quick', 1),
+ ('brown', 2),
+ ('fox', 3),
+ ('jumped', 4),
+ ('over', 5),
+ ('a', 6),
+ ('lazy', 7),
+ ('dog', 8)])))""")
+ self.assertEqual(pprint.pformat(d.items()),
+ """\
+ItemsView(ChainMap({'a': 6,
+ 'brown': 2,
+ 'dog': 8,
+ 'fox': 3,
+ 'jumped': 4,
+ 'lazy': 7,
+ 'over': 5,
+ 'quick': 1,
+ 'the': 0},
+ OrderedDict([('the', 0),
+ ('quick', 1),
+ ('brown', 2),
+ ('fox', 3),
+ ('jumped', 4),
+ ('over', 5),
+ ('a', 6),
+ ('lazy', 7),
+ ('dog', 8)])))""")
+ self.assertEqual(pprint.pformat(d.values()),
+ """\
+ValuesView(ChainMap({'a': 6,
+ 'brown': 2,
+ 'dog': 8,
+ 'fox': 3,
+ 'jumped': 4,
+ 'lazy': 7,
+ 'over': 5,
+ 'quick': 1,
+ 'the': 0},
+ OrderedDict([('the', 0),
+ ('quick', 1),
+ ('brown', 2),
+ ('fox', 3),
+ ('jumped', 4),
+ ('over', 5),
+ ('a', 6),
+ ('lazy', 7),
+ ('dog', 8)])))""")
def test_deque(self):
d = collections.deque()
@@ -1089,6 +1408,36 @@ deque([('brown', 2),
'over': 5,
'quick': 1,
'the': 0}""")
+ self.assertEqual(pprint.pformat(d.keys()), """\
+KeysView({'a': 6,
+ 'brown': 2,
+ 'dog': 8,
+ 'fox': 3,
+ 'jumped': 4,
+ 'lazy': 7,
+ 'over': 5,
+ 'quick': 1,
+ 'the': 0})""")
+ self.assertEqual(pprint.pformat(d.items()), """\
+ItemsView({'a': 6,
+ 'brown': 2,
+ 'dog': 8,
+ 'fox': 3,
+ 'jumped': 4,
+ 'lazy': 7,
+ 'over': 5,
+ 'quick': 1,
+ 'the': 0})""")
+ self.assertEqual(pprint.pformat(d.values()), """\
+ValuesView({'a': 6,
+ 'brown': 2,
+ 'dog': 8,
+ 'fox': 3,
+ 'jumped': 4,
+ 'lazy': 7,
+ 'over': 5,
+ 'quick': 1,
+ 'the': 0})""")
def test_user_list(self):
d = collections.UserList()
diff --git a/Lib/test/test_property.py b/Lib/test/test_property.py
index cea241b0f20..26aefdbf042 100644
--- a/Lib/test/test_property.py
+++ b/Lib/test/test_property.py
@@ -87,8 +87,8 @@ class PropertyTests(unittest.TestCase):
self.assertEqual(base.spam, 10)
self.assertEqual(base._spam, 10)
delattr(base, "spam")
- self.assertTrue(not hasattr(base, "spam"))
- self.assertTrue(not hasattr(base, "_spam"))
+ self.assertNotHasAttr(base, "spam")
+ self.assertNotHasAttr(base, "_spam")
base.spam = 20
self.assertEqual(base.spam, 20)
self.assertEqual(base._spam, 20)
diff --git a/Lib/test/test_pstats.py b/Lib/test/test_pstats.py
index d5a5a9738c2..a26a8c1d522 100644
--- a/Lib/test/test_pstats.py
+++ b/Lib/test/test_pstats.py
@@ -1,6 +1,7 @@
import unittest
from test import support
+from test.support.import_helper import ensure_lazy_imports
from io import StringIO
from pstats import SortKey
from enum import StrEnum, _test_simple_enum
@@ -10,6 +11,12 @@ import pstats
import tempfile
import cProfile
+class LazyImportTest(unittest.TestCase):
+ @support.cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("pstats", {"typing"})
+
+
class AddCallersTestCase(unittest.TestCase):
"""Tests for pstats.add_callers helper."""
diff --git a/Lib/test/test_pty.py b/Lib/test/test_pty.py
index c1728f5019d..4836f38c388 100644
--- a/Lib/test/test_pty.py
+++ b/Lib/test/test_pty.py
@@ -20,7 +20,6 @@ import select
import signal
import socket
import io # readline
-import warnings
TEST_STRING_1 = b"I wish to buy a fish license.\n"
TEST_STRING_2 = b"For my pet fish, Eric.\n"
diff --git a/Lib/test/test_pulldom.py b/Lib/test/test_pulldom.py
index 6dc51e4371d..3c8ed251aca 100644
--- a/Lib/test/test_pulldom.py
+++ b/Lib/test/test_pulldom.py
@@ -46,7 +46,7 @@ class PullDOMTestCase(unittest.TestCase):
items = pulldom.parseString(SMALL_SAMPLE)
evt, node = next(items)
# Just check the node is a Document:
- self.assertTrue(hasattr(node, "createElement"))
+ self.assertHasAttr(node, "createElement")
self.assertEqual(pulldom.START_DOCUMENT, evt)
evt, node = next(items)
self.assertEqual(pulldom.START_ELEMENT, evt)
@@ -192,7 +192,7 @@ class ThoroughTestCase(unittest.TestCase):
evt, node = next(pd)
self.assertEqual(pulldom.START_DOCUMENT, evt)
# Just check the node is a Document:
- self.assertTrue(hasattr(node, "createElement"))
+ self.assertHasAttr(node, "createElement")
if before_root:
evt, node = next(pd)
diff --git a/Lib/test/test_pyclbr.py b/Lib/test/test_pyclbr.py
index df05cd07d7e..3e7b2cd0dc9 100644
--- a/Lib/test/test_pyclbr.py
+++ b/Lib/test/test_pyclbr.py
@@ -103,7 +103,7 @@ class PyclbrTest(TestCase):
for name, value in dict.items():
if name in ignore:
continue
- self.assertHasAttr(module, name, ignore)
+ self.assertHasAttr(module, name)
py_item = getattr(module, name)
if isinstance(value, pyclbr.Function):
self.assertIsInstance(py_item, (FunctionType, BuiltinFunctionType))
diff --git a/Lib/test/test_pydoc/test_pydoc.py b/Lib/test/test_pydoc/test_pydoc.py
index 8cb253f67ea..d1d6f4987de 100644
--- a/Lib/test/test_pydoc/test_pydoc.py
+++ b/Lib/test/test_pydoc/test_pydoc.py
@@ -553,7 +553,7 @@ class PydocDocTest(unittest.TestCase):
# of the known subclasses of object. (doc.docclass() used to
# fail if HeapType was imported before running this test, like
# when running tests sequentially.)
- from _testcapi import HeapType
+ from _testcapi import HeapType # noqa: F401
except ImportError:
pass
text = doc.docclass(object)
@@ -1380,7 +1380,7 @@ class PydocImportTest(PydocBaseTest):
helper('modules garbage')
result = help_io.getvalue()
- self.assertTrue(result.startswith(expected))
+ self.assertStartsWith(result, expected)
def test_importfile(self):
try:
@@ -1927,18 +1927,28 @@ class PydocFodderTest(unittest.TestCase):
self.assertIn(' | global_func(x, y) from test.test_pydoc.pydocfodder', lines)
self.assertIn(' | global_func_alias = global_func(x, y)', lines)
self.assertIn(' | global_func2_alias = global_func2(x, y) from test.test_pydoc.pydocfodder', lines)
- self.assertIn(' | count(self, value, /) from builtins.list', lines)
- self.assertIn(' | list_count = count(self, value, /)', lines)
- self.assertIn(' | __repr__(self, /) from builtins.object', lines)
- self.assertIn(' | object_repr = __repr__(self, /)', lines)
+ if not support.MISSING_C_DOCSTRINGS:
+ self.assertIn(' | count(self, value, /) from builtins.list', lines)
+ self.assertIn(' | list_count = count(self, value, /)', lines)
+ self.assertIn(' | __repr__(self, /) from builtins.object', lines)
+ self.assertIn(' | object_repr = __repr__(self, /)', lines)
+ else:
+ self.assertIn(' | count(self, object, /) from builtins.list', lines)
+ self.assertIn(' | list_count = count(self, object, /)', lines)
+ self.assertIn(' | __repr__(...) from builtins.object', lines)
+ self.assertIn(' | object_repr = __repr__(...)', lines)
lines = self.getsection(result, f' | Static methods {where}:', ' | ' + '-'*70)
self.assertIn(' | A_classmethod_ref = A_classmethod(x) class method of test.test_pydoc.pydocfodder.A', lines)
note = '' if cls is pydocfodder.B else ' class method of test.test_pydoc.pydocfodder.B'
self.assertIn(' | B_classmethod_ref = B_classmethod(x)' + note, lines)
self.assertIn(' | A_method_ref = A_method() method of test.test_pydoc.pydocfodder.A instance', lines)
- self.assertIn(' | get(key, default=None, /) method of builtins.dict instance', lines)
- self.assertIn(' | dict_get = get(key, default=None, /) method of builtins.dict instance', lines)
+ if not support.MISSING_C_DOCSTRINGS:
+ self.assertIn(' | get(key, default=None, /) method of builtins.dict instance', lines)
+ self.assertIn(' | dict_get = get(key, default=None, /) method of builtins.dict instance', lines)
+ else:
+ self.assertIn(' | get(...) method of builtins.dict instance', lines)
+ self.assertIn(' | dict_get = get(...) method of builtins.dict instance', lines)
lines = self.getsection(result, f' | Class methods {where}:', ' | ' + '-'*70)
self.assertIn(' | B_classmethod(x)', lines)
@@ -1957,10 +1967,16 @@ class PydocFodderTest(unittest.TestCase):
self.assertIn('global_func(x, y) from test.test_pydoc.pydocfodder', lines)
self.assertIn('global_func_alias = global_func(x, y)', lines)
self.assertIn('global_func2_alias = global_func2(x, y) from test.test_pydoc.pydocfodder', lines)
- self.assertIn('count(self, value, /) from builtins.list', lines)
- self.assertIn('list_count = count(self, value, /)', lines)
- self.assertIn('__repr__(self, /) from builtins.object', lines)
- self.assertIn('object_repr = __repr__(self, /)', lines)
+ if not support.MISSING_C_DOCSTRINGS:
+ self.assertIn('count(self, value, /) from builtins.list', lines)
+ self.assertIn('list_count = count(self, value, /)', lines)
+ self.assertIn('__repr__(self, /) from builtins.object', lines)
+ self.assertIn('object_repr = __repr__(self, /)', lines)
+ else:
+ self.assertIn('count(self, object, /) from builtins.list', lines)
+ self.assertIn('list_count = count(self, object, /)', lines)
+ self.assertIn('__repr__(...) from builtins.object', lines)
+ self.assertIn('object_repr = __repr__(...)', lines)
lines = self.getsection(result, f'Static methods {where}:', '-'*70)
self.assertIn('A_classmethod_ref = A_classmethod(x) class method of test.test_pydoc.pydocfodder.A', lines)
@@ -1997,15 +2013,27 @@ class PydocFodderTest(unittest.TestCase):
self.assertIn(' A_method3 = A_method() method of B instance', lines)
self.assertIn(' A_staticmethod_ref = A_staticmethod(x, y)', lines)
self.assertIn(' A_staticmethod_ref2 = A_staticmethod(y) method of B instance', lines)
- self.assertIn(' get(key, default=None, /) method of builtins.dict instance', lines)
- self.assertIn(' dict_get = get(key, default=None, /) method of builtins.dict instance', lines)
+ if not support.MISSING_C_DOCSTRINGS:
+ self.assertIn(' get(key, default=None, /) method of builtins.dict instance', lines)
+ self.assertIn(' dict_get = get(key, default=None, /) method of builtins.dict instance', lines)
+ else:
+ self.assertIn(' get(...) method of builtins.dict instance', lines)
+ self.assertIn(' dict_get = get(...) method of builtins.dict instance', lines)
+
# unbound methods
self.assertIn(' B_method(self)', lines)
self.assertIn(' B_method2 = B_method(self)', lines)
- self.assertIn(' count(self, value, /) unbound builtins.list method', lines)
- self.assertIn(' list_count = count(self, value, /) unbound builtins.list method', lines)
- self.assertIn(' __repr__(self, /) unbound builtins.object method', lines)
- self.assertIn(' object_repr = __repr__(self, /) unbound builtins.object method', lines)
+ if not support.MISSING_C_DOCSTRINGS:
+ self.assertIn(' count(self, value, /) unbound builtins.list method', lines)
+ self.assertIn(' list_count = count(self, value, /) unbound builtins.list method', lines)
+ self.assertIn(' __repr__(self, /) unbound builtins.object method', lines)
+ self.assertIn(' object_repr = __repr__(self, /) unbound builtins.object method', lines)
+ else:
+ self.assertIn(' count(self, object, /) unbound builtins.list method', lines)
+ self.assertIn(' list_count = count(self, object, /) unbound builtins.list method', lines)
+ self.assertIn(' __repr__(...) unbound builtins.object method', lines)
+ self.assertIn(' object_repr = __repr__(...) unbound builtins.object method', lines)
+
def test_html_doc_routines_in_module(self):
doc = pydoc.HTMLDoc()
@@ -2026,15 +2054,25 @@ class PydocFodderTest(unittest.TestCase):
self.assertIn(' A_method3 = A_method() method of B instance', lines)
self.assertIn(' A_staticmethod_ref = A_staticmethod(x, y)', lines)
self.assertIn(' A_staticmethod_ref2 = A_staticmethod(y) method of B instance', lines)
- self.assertIn(' get(key, default=None, /) method of builtins.dict instance', lines)
- self.assertIn(' dict_get = get(key, default=None, /) method of builtins.dict instance', lines)
+ if not support.MISSING_C_DOCSTRINGS:
+ self.assertIn(' get(key, default=None, /) method of builtins.dict instance', lines)
+ self.assertIn(' dict_get = get(key, default=None, /) method of builtins.dict instance', lines)
+ else:
+ self.assertIn(' get(...) method of builtins.dict instance', lines)
+ self.assertIn(' dict_get = get(...) method of builtins.dict instance', lines)
# unbound methods
self.assertIn(' B_method(self)', lines)
self.assertIn(' B_method2 = B_method(self)', lines)
- self.assertIn(' count(self, value, /) unbound builtins.list method', lines)
- self.assertIn(' list_count = count(self, value, /) unbound builtins.list method', lines)
- self.assertIn(' __repr__(self, /) unbound builtins.object method', lines)
- self.assertIn(' object_repr = __repr__(self, /) unbound builtins.object method', lines)
+ if not support.MISSING_C_DOCSTRINGS:
+ self.assertIn(' count(self, value, /) unbound builtins.list method', lines)
+ self.assertIn(' list_count = count(self, value, /) unbound builtins.list method', lines)
+ self.assertIn(' __repr__(self, /) unbound builtins.object method', lines)
+ self.assertIn(' object_repr = __repr__(self, /) unbound builtins.object method', lines)
+ else:
+ self.assertIn(' count(self, object, /) unbound builtins.list method', lines)
+ self.assertIn(' list_count = count(self, object, /) unbound builtins.list method', lines)
+ self.assertIn(' __repr__(...) unbound builtins.object method', lines)
+ self.assertIn(' object_repr = __repr__(...) unbound builtins.object method', lines)
@unittest.skipIf(
diff --git a/Lib/test/test_pyrepl/support.py b/Lib/test/test_pyrepl/support.py
index 3692e164cb9..4f7f9d77933 100644
--- a/Lib/test/test_pyrepl/support.py
+++ b/Lib/test/test_pyrepl/support.py
@@ -113,9 +113,6 @@ handle_events_narrow_console = partial(
prepare_console=partial(prepare_console, width=10),
)
-reader_no_colors = partial(prepare_reader, can_colorize=False)
-reader_force_colors = partial(prepare_reader, can_colorize=True)
-
class FakeConsole(Console):
def __init__(self, events, encoding="utf-8") -> None:
diff --git a/Lib/test/test_pyrepl/test_eventqueue.py b/Lib/test/test_pyrepl/test_eventqueue.py
index afb55710342..edfe6ac4748 100644
--- a/Lib/test/test_pyrepl/test_eventqueue.py
+++ b/Lib/test/test_pyrepl/test_eventqueue.py
@@ -53,7 +53,7 @@ class EventQueueTestBase:
mock_keymap.compile_keymap.return_value = {"a": "b"}
eq = self.make_eventqueue()
eq.keymap = {b"a": "b"}
- eq.push("a")
+ eq.push(b"a")
mock_keymap.compile_keymap.assert_called()
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "b")
@@ -63,7 +63,7 @@ class EventQueueTestBase:
mock_keymap.compile_keymap.return_value = {"a": "b"}
eq = self.make_eventqueue()
eq.keymap = {b"c": "d"}
- eq.push("a")
+ eq.push(b"a")
mock_keymap.compile_keymap.assert_called()
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "a")
@@ -73,13 +73,13 @@ class EventQueueTestBase:
mock_keymap.compile_keymap.return_value = {"a": "b"}
eq = self.make_eventqueue()
eq.keymap = {b"a": {b"b": "c"}}
- eq.push("a")
+ eq.push(b"a")
mock_keymap.compile_keymap.assert_called()
self.assertTrue(eq.empty())
- eq.push("b")
+ eq.push(b"b")
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "c")
- eq.push("d")
+ eq.push(b"d")
self.assertEqual(eq.events[1].evt, "key")
self.assertEqual(eq.events[1].data, "d")
@@ -88,32 +88,32 @@ class EventQueueTestBase:
mock_keymap.compile_keymap.return_value = {"a": "b"}
eq = self.make_eventqueue()
eq.keymap = {b"a": {b"b": "c"}}
- eq.push("a")
+ eq.push(b"a")
mock_keymap.compile_keymap.assert_called()
self.assertTrue(eq.empty())
eq.flush_buf()
- eq.push("\033")
+ eq.push(b"\033")
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "\033")
- eq.push("b")
+ eq.push(b"b")
self.assertEqual(eq.events[1].evt, "key")
self.assertEqual(eq.events[1].data, "b")
def test_push_special_key(self):
eq = self.make_eventqueue()
eq.keymap = {}
- eq.push("\x1b")
- eq.push("[")
- eq.push("A")
+ eq.push(b"\x1b")
+ eq.push(b"[")
+ eq.push(b"A")
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "\x1b")
def test_push_unrecognized_escape_sequence(self):
eq = self.make_eventqueue()
eq.keymap = {}
- eq.push("\x1b")
- eq.push("[")
- eq.push("Z")
+ eq.push(b"\x1b")
+ eq.push(b"[")
+ eq.push(b"Z")
self.assertEqual(len(eq.events), 3)
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "\x1b")
@@ -122,12 +122,54 @@ class EventQueueTestBase:
self.assertEqual(eq.events[2].evt, "key")
self.assertEqual(eq.events[2].data, "Z")
- def test_push_unicode_character(self):
+ def test_push_unicode_character_as_str(self):
eq = self.make_eventqueue()
eq.keymap = {}
- eq.push("ч")
- self.assertEqual(eq.events[0].evt, "key")
- self.assertEqual(eq.events[0].data, "ч")
+ with self.assertRaises(AssertionError):
+ eq.push("ч")
+ with self.assertRaises(AssertionError):
+ eq.push("ñ")
+
+ def test_push_unicode_character_two_bytes(self):
+ eq = self.make_eventqueue()
+ eq.keymap = {}
+
+ encoded = "ч".encode(eq.encoding, "replace")
+ self.assertEqual(len(encoded), 2)
+
+ eq.push(encoded[0])
+ e = eq.get()
+ self.assertIsNone(e)
+
+ eq.push(encoded[1])
+ e = eq.get()
+ self.assertEqual(e.evt, "key")
+ self.assertEqual(e.data, "ч")
+
+ def test_push_single_chars_and_unicode_character_as_str(self):
+ eq = self.make_eventqueue()
+ eq.keymap = {}
+
+ def _event(evt, data, raw=None):
+ r = raw if raw is not None else data.encode(eq.encoding)
+ e = Event(evt, data, r)
+ return e
+
+ def _push(keys):
+ for k in keys:
+ eq.push(k)
+
+ self.assertIsInstance("ñ", str)
+
+ # If an exception happens during push, the existing events must be
+ # preserved and we can continue to push.
+ _push(b"b")
+ with self.assertRaises(AssertionError):
+ _push("ñ")
+ _push(b"a")
+
+ self.assertEqual(eq.get(), _event("key", "b"))
+ self.assertEqual(eq.get(), _event("key", "a"))
@unittest.skipIf(support.MS_WINDOWS, "No Unix event queue on Windows")
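
The rewritten event-queue tests push raw bytes and expect multi-byte UTF-8 sequences to be
buffered until they are complete. The behaviour being exercised can be pictured with the
standard library's incremental decoder (a sketch of the idea only, not the _pyrepl
implementation):

import codecs

decoder = codecs.getincrementaldecoder("utf-8")()
encoded = "ч".encode("utf-8")            # two bytes: 0xd1 0x87

first = decoder.decode(bytes([encoded[0]]))
print(repr(first))                        # '' -- incomplete sequence, nothing emitted yet

second = decoder.decode(bytes([encoded[1]]))
print(repr(second))                       # 'ч' -- character completed on the second byte
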
diff --git a/Lib/test/test_pyrepl/test_interact.py b/Lib/test/test_pyrepl/test_interact.py
index a20719033fc..8c0eeab6dca 100644
--- a/Lib/test/test_pyrepl/test_interact.py
+++ b/Lib/test/test_pyrepl/test_interact.py
@@ -113,7 +113,7 @@ class TestSimpleInteract(unittest.TestCase):
r = """
def f(x, x): ...
^
-SyntaxError: duplicate argument 'x' in function definition"""
+SyntaxError: duplicate parameter 'x' in function definition"""
self.assertIn(r, f.getvalue())
def test_runsource_shows_syntax_error_for_failed_compilation(self):
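
The expected REPL traceback now says "duplicate parameter" because the compiler's error
message wording changed. The exact text depends on the interpreter build, but the expectation
can be reproduced directly (illustrative sketch):

try:
    compile("def f(x, x): ...", "<repl>", "exec")
except SyntaxError as exc:
    print(exc.msg)  # e.g. "duplicate parameter 'x' in function definition"
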
diff --git a/Lib/test/test_pyrepl/test_pyrepl.py b/Lib/test/test_pyrepl/test_pyrepl.py
index 75a5afad562..98bae7dd703 100644
--- a/Lib/test/test_pyrepl/test_pyrepl.py
+++ b/Lib/test/test_pyrepl/test_pyrepl.py
@@ -8,9 +8,10 @@ import select
import subprocess
import sys
import tempfile
+from pkgutil import ModuleInfo
from unittest import TestCase, skipUnless, skipIf
from unittest.mock import patch
-from test.support import force_not_colorized, make_clean_env
+from test.support import force_not_colorized, make_clean_env, Py_DEBUG
from test.support import SHORT_TIMEOUT, STDLIB_DIR
from test.support.import_helper import import_module
from test.support.os_helper import EnvironmentVarGuard, unlink
@@ -45,6 +46,7 @@ class ReplTestCase(TestCase):
cmdline_args: list[str] | None = None,
cwd: str | None = None,
skip: bool = False,
+ timeout: float = SHORT_TIMEOUT,
) -> tuple[str, int]:
temp_dir = None
if cwd is None:
@@ -52,7 +54,12 @@ class ReplTestCase(TestCase):
cwd = temp_dir.name
try:
return self._run_repl(
- repl_input, env=env, cmdline_args=cmdline_args, cwd=cwd, skip=skip,
+ repl_input,
+ env=env,
+ cmdline_args=cmdline_args,
+ cwd=cwd,
+ skip=skip,
+ timeout=timeout,
)
finally:
if temp_dir is not None:
@@ -66,6 +73,7 @@ class ReplTestCase(TestCase):
cmdline_args: list[str] | None,
cwd: str,
skip: bool,
+ timeout: float,
) -> tuple[str, int]:
assert pty
master_fd, slave_fd = pty.openpty()
@@ -103,7 +111,7 @@ class ReplTestCase(TestCase):
os.write(master_fd, repl_input.encode("utf-8"))
output = []
- while select.select([master_fd], [], [], SHORT_TIMEOUT)[0]:
+ while select.select([master_fd], [], [], timeout)[0]:
try:
data = os.read(master_fd, 1024).decode("utf-8")
if not data:
@@ -114,12 +122,12 @@ class ReplTestCase(TestCase):
else:
os.close(master_fd)
process.kill()
- process.wait(timeout=SHORT_TIMEOUT)
+ process.wait(timeout=timeout)
self.fail(f"Timeout while waiting for output, got: {''.join(output)}")
os.close(master_fd)
try:
- exit_code = process.wait(timeout=SHORT_TIMEOUT)
+ exit_code = process.wait(timeout=timeout)
except subprocess.TimeoutExpired:
process.kill()
exit_code = process.wait()
@@ -445,6 +453,11 @@ class TestPyReplAutoindent(TestCase):
)
# fmt: on
+ events = code_to_events(input_code)
+ reader = self.prepare_reader(events)
+ output = multiline_input(reader)
+ self.assertEqual(output, output_code)
+
def test_auto_indent_continuation(self):
# auto indenting according to previous user indentation
# fmt: off
@@ -905,7 +918,14 @@ class TestPyReplCompleter(TestCase):
class TestPyReplModuleCompleter(TestCase):
def setUp(self):
+ import importlib
+ # Make iter_modules() search only the standard library.
+ # This makes the test more reliable in case there are
+ # other user packages/scripts on PYTHONPATH which can
+ # interfere with the completions.
+ lib_path = os.path.dirname(importlib.__path__[0])
self._saved_sys_path = sys.path
+ sys.path = [lib_path]
def tearDown(self):
sys.path = self._saved_sys_path
@@ -913,19 +933,12 @@ class TestPyReplModuleCompleter(TestCase):
def prepare_reader(self, events, namespace):
console = FakeConsole(events)
config = ReadlineConfig()
+ config.module_completer = ModuleCompleter(namespace)
config.readline_completer = rlcompleter.Completer(namespace).complete
reader = ReadlineAlikeReader(console=console, config=config)
return reader
def test_import_completions(self):
- import importlib
- # Make iter_modules() search only the standard library.
- # This makes the test more reliable in case there are
- # other user packages/scripts on PYTHONPATH which can
- # intefere with the completions.
- lib_path = os.path.dirname(importlib.__path__[0])
- sys.path = [lib_path]
-
cases = (
("import path\t\n", "import pathlib"),
("import importlib.\t\tres\t\n", "import importlib.resources"),
@@ -947,10 +960,17 @@ class TestPyReplModuleCompleter(TestCase):
output = reader.readline()
self.assertEqual(output, expected)
- def test_relative_import_completions(self):
+ @patch("pkgutil.iter_modules", lambda: [ModuleInfo(None, "public", True),
+ ModuleInfo(None, "_private", True)])
+ @patch("sys.builtin_module_names", ())
+ def test_private_completions(self):
cases = (
- ("from .readl\t\n", "from .readline"),
- ("from . import readl\t\n", "from . import readline"),
+            # Return public modules by default
+ ("import \t\n", "import public"),
+ ("from \t\n", "from public"),
+            # Return private modules if explicitly specified
+ ("import _\t\n", "import _private"),
+ ("from _\t\n", "from _private"),
)
for code, expected in cases:
with self.subTest(code=code):
@@ -959,8 +979,63 @@ class TestPyReplModuleCompleter(TestCase):
output = reader.readline()
self.assertEqual(output, expected)
- @patch("pkgutil.iter_modules", lambda: [(None, 'valid_name', None),
- (None, 'invalid-name', None)])
+ @patch(
+ "_pyrepl._module_completer.ModuleCompleter.iter_submodules",
+ lambda *_: [
+ ModuleInfo(None, "public", True),
+ ModuleInfo(None, "_private", True),
+ ],
+ )
+ def test_sub_module_private_completions(self):
+ cases = (
+            # Return public modules by default
+ ("from foo import \t\n", "from foo import public"),
+            # Return private modules if explicitly specified
+ ("from foo import _\t\n", "from foo import _private"),
+ )
+ for code, expected in cases:
+ with self.subTest(code=code):
+ events = code_to_events(code)
+ reader = self.prepare_reader(events, namespace={})
+ output = reader.readline()
+ self.assertEqual(output, expected)
+
+ def test_builtin_completion_top_level(self):
+ import importlib
+ # Make iter_modules() search only the standard library.
+ # This makes the test more reliable in case there are
+ # other user packages/scripts on PYTHONPATH which can
+        # interfere with the completions.
+ lib_path = os.path.dirname(importlib.__path__[0])
+ sys.path = [lib_path]
+
+ cases = (
+ ("import bui\t\n", "import builtins"),
+ ("from bui\t\n", "from builtins"),
+ )
+ for code, expected in cases:
+ with self.subTest(code=code):
+ events = code_to_events(code)
+ reader = self.prepare_reader(events, namespace={})
+ output = reader.readline()
+ self.assertEqual(output, expected)
+
+ def test_relative_import_completions(self):
+ cases = (
+ (None, "from .readl\t\n", "from .readl"),
+ (None, "from . import readl\t\n", "from . import readl"),
+ ("_pyrepl", "from .readl\t\n", "from .readline"),
+ ("_pyrepl", "from . import readl\t\n", "from . import readline"),
+ )
+ for package, code, expected in cases:
+ with self.subTest(code=code):
+ events = code_to_events(code)
+ reader = self.prepare_reader(events, namespace={"__package__": package})
+ output = reader.readline()
+ self.assertEqual(output, expected)
+
+ @patch("pkgutil.iter_modules", lambda: [ModuleInfo(None, "valid_name", True),
+ ModuleInfo(None, "invalid-name", True)])
def test_invalid_identifiers(self):
# Make sure modules which are not valid identifiers
# are not suggested as those cannot be imported via 'import'.
@@ -976,6 +1051,19 @@ class TestPyReplModuleCompleter(TestCase):
output = reader.readline()
self.assertEqual(output, expected)
+ def test_no_fallback_on_regular_completion(self):
+ cases = (
+ ("import pri\t\n", "import pri"),
+ ("from pri\t\n", "from pri"),
+ ("from typing import Na\t\n", "from typing import Na"),
+ )
+ for code, expected in cases:
+ with self.subTest(code=code):
+ events = code_to_events(code)
+ reader = self.prepare_reader(events, namespace={})
+ output = reader.readline()
+ self.assertEqual(output, expected)
+
def test_get_path_and_prefix(self):
cases = (
('', ('', '')),
@@ -1043,11 +1131,15 @@ class TestPyReplModuleCompleter(TestCase):
self.assertEqual(actual, parsed)
# The parser should not get tripped up by any
# other preceding statements
- code = f'import xyz\n{code}'
- with self.subTest(code=code):
+ _code = f'import xyz\n{code}'
+ parser = ImportParser(_code)
+ actual = parser.parse()
+ with self.subTest(code=_code):
self.assertEqual(actual, parsed)
- code = f'import xyz;{code}'
- with self.subTest(code=code):
+ _code = f'import xyz;{code}'
+ parser = ImportParser(_code)
+ actual = parser.parse()
+ with self.subTest(code=_code):
self.assertEqual(actual, parsed)
def test_parse_error(self):
@@ -1320,7 +1412,7 @@ class TestMain(ReplTestCase):
)
@force_not_colorized
- def _run_repl_globals_test(self, expectations, *, as_file=False, as_module=False):
+ def _run_repl_globals_test(self, expectations, *, as_file=False, as_module=False, pythonstartup=False):
clean_env = make_clean_env()
clean_env["NO_COLOR"] = "1" # force_not_colorized doesn't touch subprocesses
@@ -1329,9 +1421,13 @@ class TestMain(ReplTestCase):
blue.mkdir()
mod = blue / "calx.py"
mod.write_text("FOO = 42", encoding="utf-8")
+ startup = blue / "startup.py"
+ startup.write_text("BAR = 64", encoding="utf-8")
commands = [
"print(f'^{" + var + "=}')" for var in expectations
] + ["exit()"]
+ if pythonstartup:
+ clean_env["PYTHONSTARTUP"] = str(startup)
if as_file and as_module:
self.fail("as_file and as_module are mutually exclusive")
elif as_file:
@@ -1350,7 +1446,13 @@ class TestMain(ReplTestCase):
skip=True,
)
else:
- self.fail("Choose one of as_file or as_module")
+ output, exit_code = self.run_repl(
+ commands,
+ cmdline_args=[],
+ env=clean_env,
+ cwd=td,
+ skip=True,
+ )
self.assertEqual(exit_code, 0)
for var, expected in expectations.items():
@@ -1363,6 +1465,23 @@ class TestMain(ReplTestCase):
self.assertNotIn("Exception", output)
self.assertNotIn("Traceback", output)
+ def test_globals_initialized_as_default(self):
+ expectations = {
+ "__name__": "'__main__'",
+ "__package__": "None",
+ # "__file__" is missing in -i, like in the basic REPL
+ }
+ self._run_repl_globals_test(expectations)
+
+ def test_globals_initialized_from_pythonstartup(self):
+ expectations = {
+ "BAR": "64",
+ "__name__": "'__main__'",
+ "__package__": "None",
+ # "__file__" is missing in -i, like in the basic REPL
+ }
+ self._run_repl_globals_test(expectations, pythonstartup=True)
+
def test_inspect_keeps_globals_from_inspected_file(self):
expectations = {
"FOO": "42",
@@ -1372,6 +1491,16 @@ class TestMain(ReplTestCase):
}
self._run_repl_globals_test(expectations, as_file=True)
+ def test_inspect_keeps_globals_from_inspected_file_with_pythonstartup(self):
+ expectations = {
+ "FOO": "42",
+ "BAR": "64",
+ "__name__": "'__main__'",
+ "__package__": "None",
+ # "__file__" is missing in -i, like in the basic REPL
+ }
+ self._run_repl_globals_test(expectations, as_file=True, pythonstartup=True)
+
def test_inspect_keeps_globals_from_inspected_module(self):
expectations = {
"FOO": "42",
@@ -1381,26 +1510,32 @@ class TestMain(ReplTestCase):
}
self._run_repl_globals_test(expectations, as_module=True)
+ def test_inspect_keeps_globals_from_inspected_module_with_pythonstartup(self):
+ expectations = {
+ "FOO": "42",
+ "BAR": "64",
+ "__name__": "'__main__'",
+ "__package__": "'blue'",
+ "__file__": re.compile(r"^'.*calx.py'$"),
+ }
+ self._run_repl_globals_test(expectations, as_module=True, pythonstartup=True)
+
@force_not_colorized
def test_python_basic_repl(self):
env = os.environ.copy()
- commands = ("from test.support import initialized_with_pyrepl\n"
- "initialized_with_pyrepl()\n"
- "exit()\n")
-
+ pyrepl_commands = "clear\nexit()\n"
env.pop("PYTHON_BASIC_REPL", None)
- output, exit_code = self.run_repl(commands, env=env, skip=True)
+ output, exit_code = self.run_repl(pyrepl_commands, env=env, skip=True)
self.assertEqual(exit_code, 0)
- self.assertIn("True", output)
- self.assertNotIn("False", output)
self.assertNotIn("Exception", output)
+ self.assertNotIn("NameError", output)
self.assertNotIn("Traceback", output)
+ basic_commands = "help\nexit()\n"
env["PYTHON_BASIC_REPL"] = "1"
- output, exit_code = self.run_repl(commands, env=env)
+ output, exit_code = self.run_repl(basic_commands, env=env)
self.assertEqual(exit_code, 0)
- self.assertIn("False", output)
- self.assertNotIn("True", output)
+ self.assertIn("Type help() for interactive help", output)
self.assertNotIn("Exception", output)
self.assertNotIn("Traceback", output)
@@ -1537,6 +1672,17 @@ class TestMain(ReplTestCase):
self.assertEqual(exit_code, 0)
self.assertNotIn("TypeError", output)
+ @force_not_colorized
+ def test_non_string_suggestion_candidates(self):
+ commands = ("import runpy\n"
+ "runpy._run_module_code('blech', {0: '', 'bluch': ''}, '')\n"
+ "exit()\n")
+
+ output, exit_code = self.run_repl(commands)
+ self.assertEqual(exit_code, 0)
+ self.assertNotIn("all elements in 'candidates' must be strings", output)
+ self.assertIn("bluch", output)
+
def test_readline_history_file(self):
# skip, if readline module is not available
readline = import_module('readline')
@@ -1561,25 +1707,29 @@ class TestMain(ReplTestCase):
def test_history_survive_crash(self):
env = os.environ.copy()
- commands = "1\nexit()\n"
- output, exit_code = self.run_repl(commands, env=env, skip=True)
with tempfile.NamedTemporaryFile() as hfile:
env["PYTHON_HISTORY"] = hfile.name
- commands = "spam\nimport time\ntime.sleep(1000)\npreved\n"
+
+ commands = "1\n2\n3\nexit()\n"
+ output, exit_code = self.run_repl(commands, env=env, skip=True)
+
+ commands = "spam\nimport time\ntime.sleep(1000)\nquit\n"
try:
- self.run_repl(commands, env=env)
+ self.run_repl(commands, env=env, timeout=3)
except AssertionError:
pass
history = pathlib.Path(hfile.name).read_text()
+ self.assertIn("2", history)
+ self.assertIn("exit()", history)
self.assertIn("spam", history)
- self.assertIn("time", history)
+ self.assertIn("import time", history)
self.assertNotIn("sleep", history)
- self.assertNotIn("preved", history)
+ self.assertNotIn("quit", history)
def test_keyboard_interrupt_after_isearch(self):
- output, exit_code = self.run_repl(["\x12", "\x03", "exit"])
+ output, exit_code = self.run_repl("\x12\x03exit\n")
self.assertEqual(exit_code, 0)
def test_prompt_after_help(self):
@@ -1594,3 +1744,16 @@ class TestMain(ReplTestCase):
# Extra stuff (newline and `exit` rewrites) are necessary
# because of how run_repl works.
self.assertNotIn(">>> \n>>> >>>", cleaned_output)
+
+ @skipUnless(Py_DEBUG, '-X showrefcount requires a Python debug build')
+ def test_showrefcount(self):
+ env = os.environ.copy()
+ env.pop("PYTHON_BASIC_REPL", "")
+ output, _ = self.run_repl("1\n1+2\nexit()\n", cmdline_args=['-Xshowrefcount'], env=env)
+ matches = re.findall(r'\[-?\d+ refs, \d+ blocks\]', output)
+ self.assertEqual(len(matches), 3)
+
+ env["PYTHON_BASIC_REPL"] = "1"
+ output, _ = self.run_repl("1\n1+2\nexit()\n", cmdline_args=['-Xshowrefcount'], env=env)
+ matches = re.findall(r'\[-?\d+ refs, \d+ blocks\]', output)
+ self.assertEqual(len(matches), 3)
diff --git a/Lib/test/test_pyrepl/test_reader.py b/Lib/test/test_pyrepl/test_reader.py
index 109cb603ae8..1f655264f1c 100644
--- a/Lib/test/test_pyrepl/test_reader.py
+++ b/Lib/test/test_pyrepl/test_reader.py
@@ -1,16 +1,24 @@
import itertools
import functools
import rlcompleter
+from textwrap import dedent
from unittest import TestCase
from unittest.mock import MagicMock
+from test.support import force_colorized_test_class, force_not_colorized_test_class
from .support import handle_all_events, handle_events_narrow_console
from .support import ScreenEqualMixin, code_to_events
from .support import prepare_reader, prepare_console
from _pyrepl.console import Event
from _pyrepl.reader import Reader
+from _colorize import default_theme
+overrides = {"reset": "z", "soft_keyword": "K"}
+colors = {overrides.get(k, k[0].lower()): v for k, v in default_theme.syntax.items()}
+
+
+@force_not_colorized_test_class
class TestReader(ScreenEqualMixin, TestCase):
def test_calc_screen_wrap_simple(self):
events = code_to_events(10 * "a")
@@ -120,12 +128,6 @@ class TestReader(ScreenEqualMixin, TestCase):
reader.setpos_from_xy(0, 0)
self.assertEqual(reader.pos, 0)
- def test_control_characters(self):
- code = 'flag = "🏳️‍🌈"'
- events = code_to_events(code)
- reader, _ = handle_all_events(events)
- self.assert_screen_equal(reader, 'flag = "🏳️\\u200d🌈"', clean=True)
-
def test_setpos_from_xy_multiple_lines(self):
# fmt: off
code = (
@@ -355,3 +357,200 @@ class TestReader(ScreenEqualMixin, TestCase):
reader, _ = handle_all_events(events)
reader.setpos_from_xy(8, 0)
self.assertEqual(reader.pos, 7)
+
+@force_colorized_test_class
+class TestReaderInColor(ScreenEqualMixin, TestCase):
+ def test_syntax_highlighting_basic(self):
+ code = dedent(
+ """\
+ import re, sys
+ def funct(case: str = sys.platform) -> None:
+ match = re.search(
+ "(me)",
+ '''
+ Come on
+ Come on now
+ You know that it's time to emerge
+ ''',
+ )
+ match case:
+ case "emscripten": print("on the web")
+ case "ios" | "android": print("on the phone")
+ case _: print('arms around', match.group(1))
+ """
+ )
+ expected = dedent(
+ """\
+ {k}import{z} re{o},{z} sys
+ {a}{k}def{z} {d}funct{z}{o}({z}case{o}:{z} {b}str{z} {o}={z} sys{o}.{z}platform{o}){z} {o}->{z} {k}None{z}{o}:{z}
+ match {o}={z} re{o}.{z}search{o}({z}
+ {s}"(me)"{z}{o},{z}
+ {s}'''{z}
+ {s} Come on{z}
+ {s} Come on now{z}
+ {s} You know that it's time to emerge{z}
+ {s} '''{z}{o},{z}
+ {o}){z}
+ {K}match{z} case{o}:{z}
+ {K}case{z} {s}"emscripten"{z}{o}:{z} {b}print{z}{o}({z}{s}"on the web"{z}{o}){z}
+ {K}case{z} {s}"ios"{z} {o}|{z} {s}"android"{z}{o}:{z} {b}print{z}{o}({z}{s}"on the phone"{z}{o}){z}
+ {K}case{z} {K}_{z}{o}:{z} {b}print{z}{o}({z}{s}'arms around'{z}{o},{z} match{o}.{z}group{o}({z}{n}1{z}{o}){z}{o}){z}
+ """
+ )
+ expected_sync = expected.format(a="", **colors)
+ events = code_to_events(code)
+ reader, _ = handle_all_events(events)
+ self.assert_screen_equal(reader, code, clean=True)
+ self.assert_screen_equal(reader, expected_sync)
+ self.assertEqual(reader.pos, 2**7 + 2**8)
+ self.assertEqual(reader.cxy, (0, 14))
+
+ async_msg = "{k}async{z} ".format(**colors)
+ expected_async = expected.format(a=async_msg, **colors)
+ more_events = itertools.chain(
+ code_to_events(code),
+ [Event(evt="key", data="up", raw=bytearray(b"\x1bOA"))] * 13,
+ code_to_events("async "),
+ )
+ reader, _ = handle_all_events(more_events)
+ self.assert_screen_equal(reader, expected_async)
+ self.assertEqual(reader.pos, 21)
+ self.assertEqual(reader.cxy, (6, 1))
+
+ def test_syntax_highlighting_incomplete_string_first_line(self):
+ code = dedent(
+ """\
+ def unfinished_function(arg: str = "still typing
+ """
+ )
+ expected = dedent(
+ """\
+ {k}def{z} {d}unfinished_function{z}{o}({z}arg{o}:{z} {b}str{z} {o}={z} {s}"still typing{z}
+ """
+ ).format(**colors)
+ events = code_to_events(code)
+ reader, _ = handle_all_events(events)
+ self.assert_screen_equal(reader, code, clean=True)
+ self.assert_screen_equal(reader, expected)
+
+ def test_syntax_highlighting_incomplete_string_another_line(self):
+ code = dedent(
+ """\
+ def unfinished_function(
+ arg: str = "still typing
+ """
+ )
+ expected = dedent(
+ """\
+ {k}def{z} {d}unfinished_function{z}{o}({z}
+ arg{o}:{z} {b}str{z} {o}={z} {s}"still typing{z}
+ """
+ ).format(**colors)
+ events = code_to_events(code)
+ reader, _ = handle_all_events(events)
+ self.assert_screen_equal(reader, code, clean=True)
+ self.assert_screen_equal(reader, expected)
+
+ def test_syntax_highlighting_incomplete_multiline_string(self):
+ code = dedent(
+ """\
+ def unfinished_function():
+ '''Still writing
+ the docstring
+ """
+ )
+ expected = dedent(
+ """\
+ {k}def{z} {d}unfinished_function{z}{o}({z}{o}){z}{o}:{z}
+ {s}'''Still writing{z}
+ {s} the docstring{z}
+ """
+ ).format(**colors)
+ events = code_to_events(code)
+ reader, _ = handle_all_events(events)
+ self.assert_screen_equal(reader, code, clean=True)
+ self.assert_screen_equal(reader, expected)
+
+ def test_syntax_highlighting_incomplete_fstring(self):
+ code = dedent(
+ """\
+ def unfinished_function():
+ var = f"Single-quote but {
+ 1
+ +
+ 1
+ } multi-line!
+ """
+ )
+ expected = dedent(
+ """\
+ {k}def{z} {d}unfinished_function{z}{o}({z}{o}){z}{o}:{z}
+ var {o}={z} {s}f"{z}{s}Single-quote but {z}{o}{OB}{z}
+ {n}1{z}
+ {o}+{z}
+ {n}1{z}
+ {o}{CB}{z}{s} multi-line!{z}
+ """
+ ).format(OB="{", CB="}", **colors)
+ events = code_to_events(code)
+ reader, _ = handle_all_events(events)
+ self.assert_screen_equal(reader, code, clean=True)
+ self.assert_screen_equal(reader, expected)
+
+ def test_syntax_highlighting_indentation_error(self):
+ code = dedent(
+ """\
+ def unfinished_function():
+ var = 1
+ oops
+ """
+ )
+ expected = dedent(
+ """\
+ {k}def{z} {d}unfinished_function{z}{o}({z}{o}){z}{o}:{z}
+ var {o}={z} {n}1{z}
+ oops
+ """
+ ).format(**colors)
+ events = code_to_events(code)
+ reader, _ = handle_all_events(events)
+ self.assert_screen_equal(reader, code, clean=True)
+ self.assert_screen_equal(reader, expected)
+
+ def test_syntax_highlighting_literal_brace_in_fstring_or_tstring(self):
+ code = dedent(
+ """\
+ f"{{"
+ f"}}"
+ f"a{{b"
+ f"a}}b"
+ f"a{{b}}c"
+ t"a{{b}}c"
+ f"{{{0}}}"
+ f"{ {0} }"
+ """
+ )
+ expected = dedent(
+ """\
+ {s}f"{z}{s}<<{z}{s}"{z}
+ {s}f"{z}{s}>>{z}{s}"{z}
+ {s}f"{z}{s}a<<{z}{s}b{z}{s}"{z}
+ {s}f"{z}{s}a>>{z}{s}b{z}{s}"{z}
+ {s}f"{z}{s}a<<{z}{s}b>>{z}{s}c{z}{s}"{z}
+ {s}t"{z}{s}a<<{z}{s}b>>{z}{s}c{z}{s}"{z}
+ {s}f"{z}{s}<<{z}{o}<{z}{n}0{z}{o}>{z}{s}>>{z}{s}"{z}
+ {s}f"{z}{o}<{z} {o}<{z}{n}0{z}{o}>{z} {o}>{z}{s}"{z}
+ """
+ ).format(**colors).replace("<", "{").replace(">", "}")
+ events = code_to_events(code)
+ reader, _ = handle_all_events(events)
+ self.assert_screen_equal(reader, code, clean=True)
+        self.maxDiff = None
+ self.assert_screen_equal(reader, expected)
+
+ def test_control_characters(self):
+ code = 'flag = "🏳️‍🌈"'
+ events = code_to_events(code)
+ reader, _ = handle_all_events(events)
+ self.assert_screen_equal(reader, 'flag = "🏳️\\u200d🌈"', clean=True)
+ self.assert_screen_equal(reader, 'flag {o}={z} {s}"🏳️\\u200d🌈"{z}'.format(**colors))
diff --git a/Lib/test/test_pyrepl/test_unix_console.py b/Lib/test/test_pyrepl/test_unix_console.py
index 2f5c150402b..b3f7dc028fe 100644
--- a/Lib/test/test_pyrepl/test_unix_console.py
+++ b/Lib/test/test_pyrepl/test_unix_console.py
@@ -3,11 +3,12 @@ import os
import sys
import unittest
from functools import partial
-from test.support import os_helper
+from test.support import os_helper, force_not_colorized_test_class
+
from unittest import TestCase
from unittest.mock import MagicMock, call, patch, ANY
-from .support import handle_all_events, code_to_events, reader_no_colors
+from .support import handle_all_events, code_to_events
try:
from _pyrepl.console import Event
@@ -19,6 +20,7 @@ except ImportError:
def unix_console(events, **kwargs):
console = UnixConsole()
console.get_event = MagicMock(side_effect=events)
+ console.getpending = MagicMock(return_value=Event("key", ""))
height = kwargs.get("height", 25)
width = kwargs.get("width", 80)
@@ -33,7 +35,7 @@ def unix_console(events, **kwargs):
handle_events_unix_console = partial(
handle_all_events,
- prepare_console=partial(unix_console),
+ prepare_console=unix_console,
)
handle_events_narrow_unix_console = partial(
handle_all_events,
@@ -118,6 +120,7 @@ TERM_CAPABILITIES = {
)
@patch("termios.tcsetattr", lambda a, b, c: None)
@patch("os.write")
+@force_not_colorized_test_class
class TestConsole(TestCase):
def test_simple_addition(self, _os_write):
code = "12+34"
@@ -253,9 +256,7 @@ class TestConsole(TestCase):
# fmt: on
events = itertools.chain(code_to_events(code))
- reader, console = handle_events_short_unix_console(
- events, prepare_reader=reader_no_colors
- )
+ reader, console = handle_events_short_unix_console(events)
console.height = 2
console.getheightwidth = MagicMock(lambda _: (2, 80))
diff --git a/Lib/test/test_pyrepl/test_utils.py b/Lib/test/test_pyrepl/test_utils.py
index 0d59968206a..8ce1e537138 100644
--- a/Lib/test/test_pyrepl/test_utils.py
+++ b/Lib/test/test_pyrepl/test_utils.py
@@ -1,6 +1,6 @@
from unittest import TestCase
-from _pyrepl.utils import str_width, wlen
+from _pyrepl.utils import str_width, wlen, prev_next_window
class TestUtils(TestCase):
@@ -25,3 +25,38 @@ class TestUtils(TestCase):
self.assertEqual(wlen('hello'), 5)
self.assertEqual(wlen('hello' + '\x1a'), 7)
+
+ def test_prev_next_window(self):
+ def gen_normal():
+ yield 1
+ yield 2
+ yield 3
+ yield 4
+
+ pnw = prev_next_window(gen_normal())
+ self.assertEqual(next(pnw), (None, 1, 2))
+ self.assertEqual(next(pnw), (1, 2, 3))
+ self.assertEqual(next(pnw), (2, 3, 4))
+ self.assertEqual(next(pnw), (3, 4, None))
+ with self.assertRaises(StopIteration):
+ next(pnw)
+
+ def gen_short():
+ yield 1
+
+ pnw = prev_next_window(gen_short())
+ self.assertEqual(next(pnw), (None, 1, None))
+ with self.assertRaises(StopIteration):
+ next(pnw)
+
+ def gen_raise():
+ yield from gen_normal()
+ 1/0
+
+ pnw = prev_next_window(gen_raise())
+ self.assertEqual(next(pnw), (None, 1, 2))
+ self.assertEqual(next(pnw), (1, 2, 3))
+ self.assertEqual(next(pnw), (2, 3, 4))
+ self.assertEqual(next(pnw), (3, 4, None))
+ with self.assertRaises(ZeroDivisionError):
+ next(pnw)
diff --git a/Lib/test/test_pyrepl/test_windows_console.py b/Lib/test/test_pyrepl/test_windows_console.py
index 69f2d5af2a4..f9607e02c60 100644
--- a/Lib/test/test_pyrepl/test_windows_console.py
+++ b/Lib/test/test_pyrepl/test_windows_console.py
@@ -7,11 +7,13 @@ if sys.platform != "win32":
import itertools
from functools import partial
+from test.support import force_not_colorized_test_class
from typing import Iterable
from unittest import TestCase
from unittest.mock import MagicMock, call
from .support import handle_all_events, code_to_events
+from .support import prepare_reader as default_prepare_reader
try:
from _pyrepl.console import Event, Console
@@ -23,14 +25,17 @@ try:
MOVE_DOWN,
ERASE_IN_LINE,
)
+ import _pyrepl.windows_console as wc
except ImportError:
pass
+@force_not_colorized_test_class
class WindowsConsoleTests(TestCase):
def console(self, events, **kwargs) -> Console:
console = WindowsConsole()
console.get_event = MagicMock(side_effect=events)
+ console.getpending = MagicMock(return_value=Event("key", ""))
console.wait = MagicMock()
console._scroll = MagicMock()
console._hide_cursor = MagicMock()
@@ -47,14 +52,22 @@ class WindowsConsoleTests(TestCase):
setattr(console, key, val)
return console
- def handle_events(self, events: Iterable[Event], **kwargs):
- return handle_all_events(events, partial(self.console, **kwargs))
+ def handle_events(
+ self,
+ events: Iterable[Event],
+ prepare_console=None,
+ prepare_reader=None,
+ **kwargs,
+ ):
+ prepare_console = prepare_console or partial(self.console, **kwargs)
+ prepare_reader = prepare_reader or default_prepare_reader
+ return handle_all_events(events, prepare_console, prepare_reader)
def handle_events_narrow(self, events):
return self.handle_events(events, width=5)
- def handle_events_short(self, events):
- return self.handle_events(events, height=1)
+ def handle_events_short(self, events, **kwargs):
+ return self.handle_events(events, height=1, **kwargs)
def handle_events_height_3(self, events):
return self.handle_events(events, height=3)
@@ -341,8 +354,227 @@ class WindowsConsoleTests(TestCase):
Event(evt="key", data='\x1a', raw=bytearray(b'\x1a')),
],
)
- reader, _ = self.handle_events_narrow(events)
+ reader, con = self.handle_events_narrow(events)
self.assertEqual(reader.cxy, (2, 3))
+ con.restore()
+
+
+class WindowsConsoleGetEventTests(TestCase):
+ # Virtual-Key Codes: https://learn.microsoft.com/en-us/windows/win32/inputdev/virtual-key-codes
+ VK_BACK = 0x08
+ VK_RETURN = 0x0D
+ VK_LEFT = 0x25
+ VK_7 = 0x37
+ VK_M = 0x4D
+ # Used for miscellaneous characters; it can vary by keyboard.
+ # For the US standard keyboard, the '" key.
+ # For the German keyboard, the Ä key.
+ VK_OEM_7 = 0xDE
+
+ # State of control keys: https://learn.microsoft.com/en-us/windows/console/key-event-record-str
+ RIGHT_ALT_PRESSED = 0x0001
+ RIGHT_CTRL_PRESSED = 0x0004
+ LEFT_ALT_PRESSED = 0x0002
+ LEFT_CTRL_PRESSED = 0x0008
+ ENHANCED_KEY = 0x0100
+ SHIFT_PRESSED = 0x0010
+
+
+    def get_event(self, input_records, **kwargs) -> Event | None:
+ self.console = WindowsConsole(encoding='utf-8')
+ self.mock = MagicMock(side_effect=input_records)
+ self.console._read_input = self.mock
+ self.console._WindowsConsole__vt_support = kwargs.get("vt_support",
+ False)
+ self.console.wait = MagicMock(return_value=True)
+ event = self.console.get_event(block=False)
+ return event
+
+ def get_input_record(self, unicode_char, vcode=0, control=0):
+ return wc.INPUT_RECORD(
+ wc.KEY_EVENT,
+ wc.ConsoleEvent(KeyEvent=
+ wc.KeyEvent(
+ bKeyDown=True,
+ wRepeatCount=1,
+ wVirtualKeyCode=vcode,
+ wVirtualScanCode=0, # not used
+ uChar=wc.Char(unicode_char),
+ dwControlKeyState=control
+ )))
+
+ def test_EmptyBuffer(self):
+ self.assertEqual(self.get_event([None]), None)
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_WINDOW_BUFFER_SIZE_EVENT(self):
+ ir = wc.INPUT_RECORD(
+ wc.WINDOW_BUFFER_SIZE_EVENT,
+ wc.ConsoleEvent(WindowsBufferSizeEvent=
+ wc.WindowsBufferSizeEvent(
+ wc._COORD(0, 0))))
+ self.assertEqual(self.get_event([ir]), Event("resize", ""))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_KEY_EVENT_up_ignored(self):
+ ir = wc.INPUT_RECORD(
+ wc.KEY_EVENT,
+ wc.ConsoleEvent(KeyEvent=
+ wc.KeyEvent(bKeyDown=False)))
+ self.assertEqual(self.get_event([ir]), None)
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_unhandled_events(self):
+ for event in (wc.FOCUS_EVENT, wc.MENU_EVENT, wc.MOUSE_EVENT):
+ ir = wc.INPUT_RECORD(
+ event,
+ # fake data, nothing is read except bKeyDown
+ wc.ConsoleEvent(KeyEvent=
+ wc.KeyEvent(bKeyDown=False)))
+ self.assertEqual(self.get_event([ir]), None)
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_enter(self):
+ ir = self.get_input_record("\r", self.VK_RETURN)
+ self.assertEqual(self.get_event([ir]), Event("key", "\n"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_backspace(self):
+ ir = self.get_input_record("\x08", self.VK_BACK)
+ self.assertEqual(
+ self.get_event([ir]), Event("key", "backspace"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_m(self):
+ ir = self.get_input_record("m", self.VK_M)
+ self.assertEqual(self.get_event([ir]), Event("key", "m"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_M(self):
+ ir = self.get_input_record("M", self.VK_M, self.SHIFT_PRESSED)
+ self.assertEqual(self.get_event([ir]), Event("key", "M"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_left(self):
+ # VK_LEFT is sent as ENHANCED_KEY
+ ir = self.get_input_record("\x00", self.VK_LEFT, self.ENHANCED_KEY)
+ self.assertEqual(self.get_event([ir]), Event("key", "left"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_left_RIGHT_CTRL_PRESSED(self):
+ ir = self.get_input_record(
+ "\x00", self.VK_LEFT, self.RIGHT_CTRL_PRESSED | self.ENHANCED_KEY)
+ self.assertEqual(
+ self.get_event([ir]), Event("key", "ctrl left"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_left_LEFT_CTRL_PRESSED(self):
+ ir = self.get_input_record(
+ "\x00", self.VK_LEFT, self.LEFT_CTRL_PRESSED | self.ENHANCED_KEY)
+ self.assertEqual(
+ self.get_event([ir]), Event("key", "ctrl left"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_left_RIGHT_ALT_PRESSED(self):
+ ir = self.get_input_record(
+ "\x00", self.VK_LEFT, self.RIGHT_ALT_PRESSED | self.ENHANCED_KEY)
+ self.assertEqual(self.get_event([ir]), Event(evt="key", data="\033"))
+ self.assertEqual(
+ self.console.get_event(), Event("key", "left"))
+ # self.mock is not called again, since the second time we read from the
+ # command queue
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_left_LEFT_ALT_PRESSED(self):
+ ir = self.get_input_record(
+ "\x00", self.VK_LEFT, self.LEFT_ALT_PRESSED | self.ENHANCED_KEY)
+ self.assertEqual(self.get_event([ir]), Event(evt="key", data="\033"))
+ self.assertEqual(
+ self.console.get_event(), Event("key", "left"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_m_LEFT_ALT_PRESSED_and_LEFT_CTRL_PRESSED(self):
+ # For the shift keys, Windows does not send anything when
+ # ALT and CTRL are both pressed, so let's test with VK_M.
+ # get_event() receives this input, but does not
+ # generate an event.
+        # This is the case e.g. for an English keyboard layout; for a
+        # German layout this returns `µ` (see test_AltGr_m).
+ ir = self.get_input_record(
+ "\x00", self.VK_M, self.LEFT_ALT_PRESSED | self.LEFT_CTRL_PRESSED)
+ self.assertEqual(self.get_event([ir]), None)
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_m_LEFT_ALT_PRESSED(self):
+ ir = self.get_input_record(
+ "m", vcode=self.VK_M, control=self.LEFT_ALT_PRESSED)
+ self.assertEqual(self.get_event([ir]), Event(evt="key", data="\033"))
+ self.assertEqual(self.console.get_event(), Event("key", "m"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_m_RIGHT_ALT_PRESSED(self):
+ ir = self.get_input_record(
+ "m", vcode=self.VK_M, control=self.RIGHT_ALT_PRESSED)
+ self.assertEqual(self.get_event([ir]), Event(evt="key", data="\033"))
+ self.assertEqual(self.console.get_event(), Event("key", "m"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_AltGr_7(self):
+ # E.g. on a German keyboard layout, '{' is entered via
+ # AltGr + 7, where AltGr is the right Alt key on the keyboard.
+ # In this case, Windows automatically sets
+        # RIGHT_ALT_PRESSED (0x0001) and LEFT_CTRL_PRESSED (0x0008).
+        # This can also be entered as
+ # LeftAlt + LeftCtrl + 7 or
+ # LeftAlt + RightCtrl + 7
+ # See https://learn.microsoft.com/en-us/windows/console/key-event-record-str
+ # https://learn.microsoft.com/en-us/windows/win32/api/winuser/nf-winuser-vkkeyscanw
+ ir = self.get_input_record(
+ "{", vcode=self.VK_7,
+ control=self.RIGHT_ALT_PRESSED | self.LEFT_CTRL_PRESSED)
+ self.assertEqual(self.get_event([ir]), Event("key", "{"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_AltGr_m(self):
+ # E.g. on a German keyboard layout, this yields 'µ'
+ # Let's use LEFT_ALT_PRESSED and RIGHT_CTRL_PRESSED this
+ # time, to cover that, too. See above in test_AltGr_7.
+ ir = self.get_input_record(
+ "µ", vcode=self.VK_M, control=self.LEFT_ALT_PRESSED | self.RIGHT_CTRL_PRESSED)
+ self.assertEqual(self.get_event([ir]), Event("key", "µ"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_umlaut_a_german(self):
+ ir = self.get_input_record("ä", self.VK_OEM_7)
+ self.assertEqual(self.get_event([ir]), Event("key", "ä"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ # virtual terminal tests
+ # Note: wVirtualKeyCode, wVirtualScanCode and dwControlKeyState
+ # are always zero in this case.
+    # "\r" and backspace are handled specially; everything else
+ # is handled in "elif self.__vt_support:" in WindowsConsole.get_event().
+ # Hence, only one regular key ("m") and a terminal sequence
+    # are sufficient to test here; the real tests happen in test_eventqueue
+ # and test_keymap.
+
+ def test_enter_vt(self):
+ ir = self.get_input_record("\r")
+ self.assertEqual(self.get_event([ir], vt_support=True),
+ Event("key", "\n"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_backspace_vt(self):
+ ir = self.get_input_record("\x7f")
+ self.assertEqual(self.get_event([ir], vt_support=True),
+ Event("key", "backspace", b"\x7f"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_up_vt(self):
+ irs = [self.get_input_record(x) for x in "\x1b[A"]
+ self.assertEqual(self.get_event(irs, vt_support=True),
+ Event(evt='key', data='up', raw=bytearray(b'\x1b[A')))
+ self.assertEqual(self.mock.call_count, 3)
if __name__ == "__main__":
diff --git a/Lib/test/test_queue.py b/Lib/test/test_queue.py
index 7f4fe357034..c855fb8fe2b 100644
--- a/Lib/test/test_queue.py
+++ b/Lib/test/test_queue.py
@@ -6,7 +6,7 @@ import threading
import time
import unittest
import weakref
-from test.support import gc_collect
+from test.support import gc_collect, bigmemtest
from test.support import import_helper
from test.support import threading_helper
@@ -963,33 +963,33 @@ class BaseSimpleQueueTest:
# One producer, one consumer => results appended in well-defined order
self.assertEqual(results, inputs)
- def test_many_threads(self):
+ @bigmemtest(size=50, memuse=100*2**20, dry_run=False)
+ def test_many_threads(self, size):
# Test multiple concurrent put() and get()
- N = 50
q = self.q
inputs = list(range(10000))
- results = self.run_threads(N, q, inputs, self.feed, self.consume)
+ results = self.run_threads(size, q, inputs, self.feed, self.consume)
# Multiple consumers without synchronization append the
# results in random order
self.assertEqual(sorted(results), inputs)
- def test_many_threads_nonblock(self):
+ @bigmemtest(size=50, memuse=100*2**20, dry_run=False)
+ def test_many_threads_nonblock(self, size):
# Test multiple concurrent put() and get(block=False)
- N = 50
q = self.q
inputs = list(range(10000))
- results = self.run_threads(N, q, inputs,
+ results = self.run_threads(size, q, inputs,
self.feed, self.consume_nonblock)
self.assertEqual(sorted(results), inputs)
- def test_many_threads_timeout(self):
+ @bigmemtest(size=50, memuse=100*2**20, dry_run=False)
+ def test_many_threads_timeout(self, size):
# Test multiple concurrent put() and get(timeout=...)
- N = 50
q = self.q
inputs = list(range(1000))
- results = self.run_threads(N, q, inputs,
+ results = self.run_threads(size, q, inputs,
self.feed, self.consume_timeout)
self.assertEqual(sorted(results), inputs)
diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py
index 96f6cc86219..0217ebd132b 100644
--- a/Lib/test/test_random.py
+++ b/Lib/test/test_random.py
@@ -14,6 +14,15 @@ from test import support
from fractions import Fraction
from collections import abc, Counter
+
+class MyIndex:
+ def __init__(self, value):
+ self.value = value
+
+ def __index__(self):
+ return self.value
+
+
class TestBasicOps:
# Superclass with tests common to all generators.
# Subclasses must arrange for self.gen to retrieve the Random instance
@@ -142,6 +151,7 @@ class TestBasicOps:
# Exception raised if size of sample exceeds that of population
self.assertRaises(ValueError, self.gen.sample, population, N+1)
self.assertRaises(ValueError, self.gen.sample, [], -1)
+ self.assertRaises(TypeError, self.gen.sample, population, 1.0)
def test_sample_distribution(self):
# For the entire allowable range of 0 <= k <= N, validate that
@@ -259,6 +269,7 @@ class TestBasicOps:
choices(data, range(4), k=5),
choices(k=5, population=data, weights=range(4)),
choices(k=5, population=data, cum_weights=range(4)),
+ choices(data, k=MyIndex(5)),
]:
self.assertEqual(len(sample), 5)
self.assertEqual(type(sample), list)
@@ -369,118 +380,40 @@ class TestBasicOps:
self.assertEqual(x1, x2)
self.assertEqual(y1, y2)
+ @support.requires_IEEE_754
+ def test_53_bits_per_float(self):
+ span = 2 ** 53
+ cum = 0
+ for i in range(100):
+ cum |= int(self.gen.random() * span)
+ self.assertEqual(cum, span-1)
+
def test_getrandbits(self):
+ getrandbits = self.gen.getrandbits
# Verify ranges
for k in range(1, 1000):
- self.assertTrue(0 <= self.gen.getrandbits(k) < 2**k)
- self.assertEqual(self.gen.getrandbits(0), 0)
+ self.assertTrue(0 <= getrandbits(k) < 2**k)
+ self.assertEqual(getrandbits(0), 0)
# Verify all bits active
- getbits = self.gen.getrandbits
for span in [1, 2, 3, 4, 31, 32, 32, 52, 53, 54, 119, 127, 128, 129]:
all_bits = 2**span-1
cum = 0
cpl_cum = 0
for i in range(100):
- v = getbits(span)
+ v = getrandbits(span)
cum |= v
cpl_cum |= all_bits ^ v
self.assertEqual(cum, all_bits)
self.assertEqual(cpl_cum, all_bits)
# Verify argument checking
- self.assertRaises(TypeError, self.gen.getrandbits)
- self.assertRaises(TypeError, self.gen.getrandbits, 1, 2)
- self.assertRaises(ValueError, self.gen.getrandbits, -1)
- self.assertRaises(TypeError, self.gen.getrandbits, 10.1)
-
- def test_pickling(self):
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- state = pickle.dumps(self.gen, proto)
- origseq = [self.gen.random() for i in range(10)]
- newgen = pickle.loads(state)
- restoredseq = [newgen.random() for i in range(10)]
- self.assertEqual(origseq, restoredseq)
-
- def test_bug_1727780(self):
- # verify that version-2-pickles can be loaded
- # fine, whether they are created on 32-bit or 64-bit
- # platforms, and that version-3-pickles load fine.
- files = [("randv2_32.pck", 780),
- ("randv2_64.pck", 866),
- ("randv3.pck", 343)]
- for file, value in files:
- with open(support.findfile(file),"rb") as f:
- r = pickle.load(f)
- self.assertEqual(int(r.random()*1000), value)
-
- def test_bug_9025(self):
- # Had problem with an uneven distribution in int(n*random())
- # Verify the fix by checking that distributions fall within expectations.
- n = 100000
- randrange = self.gen.randrange
- k = sum(randrange(6755399441055744) % 3 == 2 for i in range(n))
- self.assertTrue(0.30 < k/n < .37, (k/n))
-
- def test_randbytes(self):
- # Verify ranges
- for n in range(1, 10):
- data = self.gen.randbytes(n)
- self.assertEqual(type(data), bytes)
- self.assertEqual(len(data), n)
-
- self.assertEqual(self.gen.randbytes(0), b'')
-
- # Verify argument checking
- self.assertRaises(TypeError, self.gen.randbytes)
- self.assertRaises(TypeError, self.gen.randbytes, 1, 2)
- self.assertRaises(ValueError, self.gen.randbytes, -1)
- self.assertRaises(TypeError, self.gen.randbytes, 1.0)
-
- def test_mu_sigma_default_args(self):
- self.assertIsInstance(self.gen.normalvariate(), float)
- self.assertIsInstance(self.gen.gauss(), float)
-
-
-try:
- random.SystemRandom().random()
-except NotImplementedError:
- SystemRandom_available = False
-else:
- SystemRandom_available = True
-
-@unittest.skipUnless(SystemRandom_available, "random.SystemRandom not available")
-class SystemRandom_TestBasicOps(TestBasicOps, unittest.TestCase):
- gen = random.SystemRandom()
-
- def test_autoseed(self):
- # Doesn't need to do anything except not fail
- self.gen.seed()
-
- def test_saverestore(self):
- self.assertRaises(NotImplementedError, self.gen.getstate)
- self.assertRaises(NotImplementedError, self.gen.setstate, None)
-
- def test_seedargs(self):
- # Doesn't need to do anything except not fail
- self.gen.seed(100)
-
- def test_gauss(self):
- self.gen.gauss_next = None
- self.gen.seed(100)
- self.assertEqual(self.gen.gauss_next, None)
-
- def test_pickling(self):
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- self.assertRaises(NotImplementedError, pickle.dumps, self.gen, proto)
-
- def test_53_bits_per_float(self):
- # This should pass whenever a C double has 53 bit precision.
- span = 2 ** 53
- cum = 0
- for i in range(100):
- cum |= int(self.gen.random() * span)
- self.assertEqual(cum, span-1)
+ self.assertRaises(TypeError, getrandbits)
+ self.assertRaises(TypeError, getrandbits, 1, 2)
+ self.assertRaises(ValueError, getrandbits, -1)
+ self.assertRaises(OverflowError, getrandbits, 1<<1000)
+ self.assertRaises(ValueError, getrandbits, -1<<1000)
+ self.assertRaises(TypeError, getrandbits, 10.1)
def test_bigrand(self):
# The randrange routine should build-up the required number of bits
@@ -559,6 +492,10 @@ class SystemRandom_TestBasicOps(TestBasicOps, unittest.TestCase):
randrange(1000, step=100)
with self.assertRaises(TypeError):
randrange(1000, None, step=100)
+ with self.assertRaises(TypeError):
+ randrange(1000, step=MyIndex(1))
+ with self.assertRaises(TypeError):
+ randrange(1000, None, step=MyIndex(1))
def test_randbelow_logic(self, _log=log, int=int):
# check bitcount transition points: 2**i and 2**(i+1)-1
@@ -581,6 +518,116 @@ class SystemRandom_TestBasicOps(TestBasicOps, unittest.TestCase):
self.assertEqual(k, numbits) # note the stronger assertion
self.assertTrue(2**k > n > 2**(k-1)) # note the stronger assertion
+ def test_randrange_index(self):
+ randrange = self.gen.randrange
+ self.assertIn(randrange(MyIndex(5)), range(5))
+ self.assertIn(randrange(MyIndex(2), MyIndex(7)), range(2, 7))
+ self.assertIn(randrange(MyIndex(5), MyIndex(15), MyIndex(2)), range(5, 15, 2))
+
+ def test_randint(self):
+ randint = self.gen.randint
+ self.assertIn(randint(2, 5), (2, 3, 4, 5))
+ self.assertEqual(randint(2, 2), 2)
+ self.assertIn(randint(MyIndex(2), MyIndex(5)), (2, 3, 4, 5))
+ self.assertEqual(randint(MyIndex(2), MyIndex(2)), 2)
+
+ self.assertRaises(ValueError, randint, 5, 2)
+ self.assertRaises(TypeError, randint)
+ self.assertRaises(TypeError, randint, 2)
+ self.assertRaises(TypeError, randint, 2, 5, 1)
+ self.assertRaises(TypeError, randint, 2.0, 5)
+ self.assertRaises(TypeError, randint, 2, 5.0)
+
+ def test_pickling(self):
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ state = pickle.dumps(self.gen, proto)
+ origseq = [self.gen.random() for i in range(10)]
+ newgen = pickle.loads(state)
+ restoredseq = [newgen.random() for i in range(10)]
+ self.assertEqual(origseq, restoredseq)
+
+ def test_bug_1727780(self):
+ # verify that version-2-pickles can be loaded
+ # fine, whether they are created on 32-bit or 64-bit
+ # platforms, and that version-3-pickles load fine.
+ files = [("randv2_32.pck", 780),
+ ("randv2_64.pck", 866),
+ ("randv3.pck", 343)]
+ for file, value in files:
+ with open(support.findfile(file),"rb") as f:
+ r = pickle.load(f)
+ self.assertEqual(int(r.random()*1000), value)
+
+ def test_bug_9025(self):
+ # Had problem with an uneven distribution in int(n*random())
+ # Verify the fix by checking that distributions fall within expectations.
+ n = 100000
+ randrange = self.gen.randrange
+ k = sum(randrange(6755399441055744) % 3 == 2 for i in range(n))
+ self.assertTrue(0.30 < k/n < .37, (k/n))
+
+ def test_randrange_bug_1590891(self):
+ start = 1000000000000
+ stop = -100000000000000000000
+ step = -200
+ x = self.gen.randrange(start, stop, step)
+ self.assertTrue(stop < x <= start)
+ self.assertEqual((x+stop)%step, 0)
+
+ def test_randbytes(self):
+ # Verify ranges
+ for n in range(1, 10):
+ data = self.gen.randbytes(n)
+ self.assertEqual(type(data), bytes)
+ self.assertEqual(len(data), n)
+
+ self.assertEqual(self.gen.randbytes(0), b'')
+
+ # Verify argument checking
+ self.assertRaises(TypeError, self.gen.randbytes)
+ self.assertRaises(TypeError, self.gen.randbytes, 1, 2)
+ self.assertRaises(ValueError, self.gen.randbytes, -1)
+ self.assertRaises(OverflowError, self.gen.randbytes, 1<<1000)
+ self.assertRaises((ValueError, OverflowError), self.gen.randbytes, -1<<1000)
+ self.assertRaises(TypeError, self.gen.randbytes, 1.0)
+
+ def test_mu_sigma_default_args(self):
+ self.assertIsInstance(self.gen.normalvariate(), float)
+ self.assertIsInstance(self.gen.gauss(), float)
+
+
+try:
+ random.SystemRandom().random()
+except NotImplementedError:
+ SystemRandom_available = False
+else:
+ SystemRandom_available = True
+
+@unittest.skipUnless(SystemRandom_available, "random.SystemRandom not available")
+class SystemRandom_TestBasicOps(TestBasicOps, unittest.TestCase):
+ gen = random.SystemRandom()
+
+ def test_autoseed(self):
+ # Doesn't need to do anything except not fail
+ self.gen.seed()
+
+ def test_saverestore(self):
+ self.assertRaises(NotImplementedError, self.gen.getstate)
+ self.assertRaises(NotImplementedError, self.gen.setstate, None)
+
+ def test_seedargs(self):
+ # Doesn't need to do anything except not fail
+ self.gen.seed(100)
+
+ def test_gauss(self):
+ self.gen.gauss_next = None
+ self.gen.seed(100)
+ self.assertEqual(self.gen.gauss_next, None)
+
+ def test_pickling(self):
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ self.assertRaises(NotImplementedError, pickle.dumps, self.gen, proto)
+
class TestRawMersenneTwister(unittest.TestCase):
@test.support.cpython_only
@@ -766,38 +813,6 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
seed = (1 << (10000 * 8)) - 1 # about 10K bytes
self.gen.seed(seed)
- def test_53_bits_per_float(self):
- # This should pass whenever a C double has 53 bit precision.
- span = 2 ** 53
- cum = 0
- for i in range(100):
- cum |= int(self.gen.random() * span)
- self.assertEqual(cum, span-1)
-
- def test_bigrand(self):
- # The randrange routine should build-up the required number of bits
- # in stages so that all bit positions are active.
- span = 2 ** 500
- cum = 0
- for i in range(100):
- r = self.gen.randrange(span)
- self.assertTrue(0 <= r < span)
- cum |= r
- self.assertEqual(cum, span-1)
-
- def test_bigrand_ranges(self):
- for i in [40,80, 160, 200, 211, 250, 375, 512, 550]:
- start = self.gen.randrange(2 ** (i-2))
- stop = self.gen.randrange(2 ** i)
- if stop <= start:
- continue
- self.assertTrue(start <= self.gen.randrange(start, stop) < stop)
-
- def test_rangelimits(self):
- for start, stop in [(-2,0), (-(2**60)-2,-(2**60)), (2**60,2**60+2)]:
- self.assertEqual(set(range(start,stop)),
- set([self.gen.randrange(start,stop) for i in range(100)]))
-
def test_getrandbits(self):
super().test_getrandbits()
@@ -805,6 +820,25 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
self.gen.seed(1234567)
self.assertEqual(self.gen.getrandbits(100),
97904845777343510404718956115)
+ self.gen.seed(1234567)
+ self.assertEqual(self.gen.getrandbits(MyIndex(100)),
+ 97904845777343510404718956115)
+
+ def test_getrandbits_2G_bits(self):
+ size = 2**31
+ self.gen.seed(1234567)
+ x = self.gen.getrandbits(size)
+ self.assertEqual(x.bit_length(), size)
+ self.assertEqual(x & (2**100-1), 890186470919986886340158459475)
+ self.assertEqual(x >> (size-100), 1226514312032729439655761284440)
+
+ @support.bigmemtest(size=2**32, memuse=1/8+2/15, dry_run=False)
+ def test_getrandbits_4G_bits(self, size):
+ self.gen.seed(1234568)
+ x = self.gen.getrandbits(size)
+ self.assertEqual(x.bit_length(), size)
+ self.assertEqual(x & (2**100-1), 287241425661104632871036099814)
+ self.assertEqual(x >> (size-100), 739728759900339699429794460738)
def test_randrange_uses_getrandbits(self):
# Verify use of getrandbits by randrange
@@ -816,27 +850,6 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
self.assertEqual(self.gen.randrange(2**99),
97904845777343510404718956115)
- def test_randbelow_logic(self, _log=log, int=int):
- # check bitcount transition points: 2**i and 2**(i+1)-1
- # show that: k = int(1.001 + _log(n, 2))
- # is equal to or one greater than the number of bits in n
- for i in range(1, 1000):
- n = 1 << i # check an exact power of two
- numbits = i+1
- k = int(1.00001 + _log(n, 2))
- self.assertEqual(k, numbits)
- self.assertEqual(n, 2**(k-1))
-
- n += n - 1 # check 1 below the next power of two
- k = int(1.00001 + _log(n, 2))
- self.assertIn(k, [numbits, numbits+1])
- self.assertTrue(2**k > n > 2**(k-2))
-
- n -= n >> 15 # check a little farther below the next power of two
- k = int(1.00001 + _log(n, 2))
- self.assertEqual(k, numbits) # note the stronger assertion
- self.assertTrue(2**k > n > 2**(k-1)) # note the stronger assertion
-
def test_randbelow_without_getrandbits(self):
# Random._randbelow() can only use random() when the built-in one
# has been overridden but no new getrandbits() method was supplied.
@@ -871,14 +884,6 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
self.gen._randbelow_without_getrandbits(n, maxsize=maxsize)
self.assertEqual(random_mock.call_count, 2)
- def test_randrange_bug_1590891(self):
- start = 1000000000000
- stop = -100000000000000000000
- step = -200
- x = self.gen.randrange(start, stop, step)
- self.assertTrue(stop < x <= start)
- self.assertEqual((x+stop)%step, 0)
-
def test_choices_algorithms(self):
# The various ways of specifying weights should produce the same results
choices = self.gen.choices
@@ -962,6 +967,14 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
self.assertEqual(self.gen.randbytes(n),
gen2.getrandbits(n * 8).to_bytes(n, 'little'))
+ @support.bigmemtest(size=2**29, memuse=1+16/15, dry_run=False)
+ def test_randbytes_256M(self, size):
+ self.gen.seed(2849427419)
+ x = self.gen.randbytes(size)
+ self.assertEqual(len(x), size)
+ self.assertEqual(x[:12].hex(), 'f6fd9ae63855ab91ea238b4f')
+ self.assertEqual(x[-12:].hex(), '0e7af69a84ee99bf4a11becc')
+
def test_sample_counts_equivalence(self):
# Test the documented strong equivalence to a sample with repeated elements.
# We run this test on random.Random() which makes deterministic selections
@@ -1411,30 +1424,31 @@ class TestModule(unittest.TestCase):
class CommandLineTest(unittest.TestCase):
+ @support.force_not_colorized
def test_parse_args(self):
args, help_text = random._parse_args(shlex.split("--choice a b c"))
self.assertEqual(args.choice, ["a", "b", "c"])
- self.assertTrue(help_text.startswith("usage: "))
+ self.assertStartsWith(help_text, "usage: ")
args, help_text = random._parse_args(shlex.split("--integer 5"))
self.assertEqual(args.integer, 5)
- self.assertTrue(help_text.startswith("usage: "))
+ self.assertStartsWith(help_text, "usage: ")
args, help_text = random._parse_args(shlex.split("--float 2.5"))
self.assertEqual(args.float, 2.5)
- self.assertTrue(help_text.startswith("usage: "))
+ self.assertStartsWith(help_text, "usage: ")
args, help_text = random._parse_args(shlex.split("a b c"))
self.assertEqual(args.input, ["a", "b", "c"])
- self.assertTrue(help_text.startswith("usage: "))
+ self.assertStartsWith(help_text, "usage: ")
args, help_text = random._parse_args(shlex.split("5"))
self.assertEqual(args.input, ["5"])
- self.assertTrue(help_text.startswith("usage: "))
+ self.assertStartsWith(help_text, "usage: ")
args, help_text = random._parse_args(shlex.split("2.5"))
self.assertEqual(args.input, ["2.5"])
- self.assertTrue(help_text.startswith("usage: "))
+ self.assertStartsWith(help_text, "usage: ")
def test_main(self):
for command, expected in [
diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py
index cf8525ed901..e9128ac1d97 100644
--- a/Lib/test/test_re.py
+++ b/Lib/test/test_re.py
@@ -619,6 +619,7 @@ class ReTests(unittest.TestCase):
self.assertEqual(re.fullmatch(r"a.*?b", "axxb").span(), (0, 4))
self.assertIsNone(re.fullmatch(r"a+", "ab"))
self.assertIsNone(re.fullmatch(r"abc$", "abc\n"))
+ self.assertIsNone(re.fullmatch(r"abc\z", "abc\n"))
self.assertIsNone(re.fullmatch(r"abc\Z", "abc\n"))
self.assertIsNone(re.fullmatch(r"(?m)abc$", "abc\n"))
self.assertEqual(re.fullmatch(r"ab(?=c)cd", "abcd").span(), (0, 4))
@@ -802,6 +803,8 @@ class ReTests(unittest.TestCase):
self.assertEqual(re.search(r"\B(b.)\B",
"abc bcd bc abxd", re.ASCII).group(1), "bx")
self.assertEqual(re.search(r"^abc$", "\nabc\n", re.M).group(0), "abc")
+ self.assertEqual(re.search(r"^\Aabc\z$", "abc", re.M).group(0), "abc")
+ self.assertIsNone(re.search(r"^\Aabc\z$", "\nabc\n", re.M))
self.assertEqual(re.search(r"^\Aabc\Z$", "abc", re.M).group(0), "abc")
self.assertIsNone(re.search(r"^\Aabc\Z$", "\nabc\n", re.M))
self.assertEqual(re.search(br"\b(b.)\b",
@@ -813,6 +816,8 @@ class ReTests(unittest.TestCase):
self.assertEqual(re.search(br"\B(b.)\B",
b"abc bcd bc abxd", re.LOCALE).group(1), b"bx")
self.assertEqual(re.search(br"^abc$", b"\nabc\n", re.M).group(0), b"abc")
+ self.assertEqual(re.search(br"^\Aabc\z$", b"abc", re.M).group(0), b"abc")
+ self.assertIsNone(re.search(br"^\Aabc\z$", b"\nabc\n", re.M))
self.assertEqual(re.search(br"^\Aabc\Z$", b"abc", re.M).group(0), b"abc")
self.assertIsNone(re.search(br"^\Aabc\Z$", b"\nabc\n", re.M))
self.assertEqual(re.search(r"\d\D\w\W\s\S",
@@ -836,7 +841,7 @@ class ReTests(unittest.TestCase):
self.assertEqual(re.match(r"[\^a]+", 'a^').group(), 'a^')
self.assertIsNone(re.match(r"[\^a]+", 'b'))
re.purge() # for warnings
- for c in 'ceghijklmopqyzCEFGHIJKLMNOPQRTVXY':
+ for c in 'ceghijklmopqyCEFGHIJKLMNOPQRTVXY':
with self.subTest(c):
self.assertRaises(re.PatternError, re.compile, '\\%c' % c)
for c in 'ceghijklmopqyzABCEFGHIJKLMNOPQRTVXYZ':
@@ -2608,8 +2613,8 @@ class ReTests(unittest.TestCase):
self.assertEqual(re.findall(r'(?>(?:ab){1,3})', 'ababc'), ['abab'])
def test_bug_gh91616(self):
- self.assertTrue(re.fullmatch(r'(?s:(?>.*?\.).*)\Z', "a.txt")) # reproducer
- self.assertTrue(re.fullmatch(r'(?s:(?=(?P<g0>.*?\.))(?P=g0).*)\Z', "a.txt"))
+ self.assertTrue(re.fullmatch(r'(?s:(?>.*?\.).*)\z', "a.txt")) # reproducer
+ self.assertTrue(re.fullmatch(r'(?s:(?=(?P<g0>.*?\.))(?P=g0).*)\z', "a.txt"))
def test_bug_gh100061(self):
# gh-100061
@@ -2863,11 +2868,11 @@ class PatternReprTests(unittest.TestCase):
pattern = 'Very %spattern' % ('long ' * 1000)
r = repr(re.compile(pattern))
self.assertLess(len(r), 300)
- self.assertEqual(r[:30], "re.compile('Very long long lon")
+ self.assertStartsWith(r, "re.compile('Very long long lon")
r = repr(re.compile(pattern, re.I))
self.assertLess(len(r), 300)
- self.assertEqual(r[:30], "re.compile('Very long long lon")
- self.assertEqual(r[-16:], ", re.IGNORECASE)")
+ self.assertStartsWith(r, "re.compile('Very long long lon")
+ self.assertEndsWith(r, ", re.IGNORECASE)")
def test_flags_repr(self):
self.assertEqual(repr(re.I), "re.IGNORECASE")
@@ -2946,7 +2951,7 @@ class ImplementationTest(unittest.TestCase):
self.assertEqual(mod.__name__, name)
self.assertEqual(mod.__package__, '')
for attr in deprecated[name]:
- self.assertTrue(hasattr(mod, attr))
+ self.assertHasAttr(mod, attr)
del sys.modules[name]
@cpython_only
diff --git a/Lib/test/test_readline.py b/Lib/test/test_readline.py
index b9d082b3597..45192fe5082 100644
--- a/Lib/test/test_readline.py
+++ b/Lib/test/test_readline.py
@@ -1,6 +1,7 @@
"""
Very minimal unittests for parts of the readline module.
"""
+import codecs
import locale
import os
import sys
@@ -231,6 +232,13 @@ print("History length:", readline.get_current_history_length())
# writing and reading non-ASCII bytes into/from a TTY works, but
# readline or ncurses ignores non-ASCII bytes on read.
self.skipTest(f"the LC_CTYPE locale is {loc!r}")
+ if sys.flags.utf8_mode:
+ encoding = locale.getencoding()
+ encoding = codecs.lookup(encoding).name # normalize the name
+ if encoding != "utf-8":
+ # gh-133711: The Python UTF-8 Mode ignores the LC_CTYPE locale
+                # and always uses the UTF-8 encoding.
+ self.skipTest(f"the LC_CTYPE encoding is {encoding!r}")
try:
readline.add_history("\xEB\xEF")
diff --git a/Lib/test/test_regrtest.py b/Lib/test/test_regrtest.py
index 7e317d5ab94..5bc3c5924b0 100644
--- a/Lib/test/test_regrtest.py
+++ b/Lib/test/test_regrtest.py
@@ -768,13 +768,16 @@ class BaseTestCase(unittest.TestCase):
self.fail(msg)
return proc
- def run_python(self, args, **kw):
+ def run_python(self, args, isolated=True, **kw):
extraargs = []
if 'uops' in sys._xoptions:
# Pass -X uops along
extraargs.extend(['-X', 'uops'])
- args = [sys.executable, *extraargs, '-X', 'faulthandler', '-I', *args]
- proc = self.run_command(args, **kw)
+ cmd = [sys.executable, *extraargs, '-X', 'faulthandler']
+ if isolated:
+ cmd.append('-I')
+ cmd.extend(args)
+ proc = self.run_command(cmd, **kw)
return proc.stdout
@@ -831,8 +834,8 @@ class ProgramsTestCase(BaseTestCase):
self.check_executed_tests(output, self.tests,
randomize=True, stats=len(self.tests))
- def run_tests(self, args, env=None):
- output = self.run_python(args, env=env)
+ def run_tests(self, args, env=None, isolated=True):
+ output = self.run_python(args, env=env, isolated=isolated)
self.check_output(output)
def test_script_regrtest(self):
@@ -874,7 +877,10 @@ class ProgramsTestCase(BaseTestCase):
self.run_tests(args)
def run_batch(self, *args):
- proc = self.run_command(args)
+ proc = self.run_command(args,
+ # gh-133711: cmd.exe uses the OEM code page
+ # to display the non-ASCII current directory
+ errors="backslashreplace")
self.check_output(proc.stdout)
@unittest.skipUnless(sysconfig.is_python_build(),
@@ -2064,7 +2070,7 @@ class ArgsTestCase(BaseTestCase):
self.check_executed_tests(output, [testname],
failed=[testname],
parallel=True,
- stats=TestStats(1, 1, 0))
+ stats=TestStats(1, 2, 1))
def _check_random_seed(self, run_workers: bool):
# gh-109276: When -r/--randomize is used, random.seed() is called
@@ -2273,7 +2279,6 @@ class ArgsTestCase(BaseTestCase):
def test_xml(self):
code = textwrap.dedent(r"""
import unittest
- from test import support
class VerboseTests(unittest.TestCase):
def test_failed(self):
@@ -2308,6 +2313,50 @@ class ArgsTestCase(BaseTestCase):
for out in testcase.iter('system-out'):
self.assertEqual(out.text, r"abc \x1b def")
+ def test_nonascii(self):
+ code = textwrap.dedent(r"""
+ import unittest
+
+ class NonASCIITests(unittest.TestCase):
+ def test_docstring(self):
+ '''docstring:\u20ac'''
+
+ def test_subtest(self):
+ with self.subTest(param='subtest:\u20ac'):
+ pass
+
+ def test_skip(self):
+ self.skipTest('skipped:\u20ac')
+ """)
+ testname = self.create_test(code=code)
+
+ env = dict(os.environ)
+ env['PYTHONIOENCODING'] = 'ascii'
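+ # With ASCII-only stdio, the non-ASCII docstring and skip reason must
+ # show up escaped (e.g. \u20ac) in the regrtest report checked below.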
+
+ def check(output):
+ self.check_executed_tests(output, testname, stats=TestStats(3, 0, 1))
+ self.assertIn(r'docstring:\u20ac', output)
+ self.assertIn(r'skipped:\u20ac', output)
+
+ # Run sequentially
+ output = self.run_tests('-v', testname, env=env, isolated=False)
+ check(output)
+
+ # Run in parallel
+ output = self.run_tests('-j1', '-v', testname, env=env, isolated=False)
+ check(output)
+
+ def test_pgo_exclude(self):
+ # Get PGO tests
+ output = self.run_tests('--pgo', '--list-tests')
+ pgo_tests = output.strip().split()
+
+ # Exclude test_re
+ output = self.run_tests('--pgo', '--list-tests', '-x', 'test_re')
+ tests = output.strip().split()
+ self.assertNotIn('test_re', tests)
+ self.assertEqual(len(tests), len(pgo_tests) - 1)
+
class TestUtils(unittest.TestCase):
def test_format_duration(self):
diff --git a/Lib/test/test_remote_pdb.py b/Lib/test/test_remote_pdb.py
index 2c4a17abd82..a1c50af15f3 100644
--- a/Lib/test/test_remote_pdb.py
+++ b/Lib/test/test_remote_pdb.py
@@ -1,21 +1,19 @@
import io
-import time
+import itertools
import json
import os
+import re
import signal
import socket
import subprocess
import sys
-import tempfile
import textwrap
-import threading
import unittest
import unittest.mock
-from contextlib import contextmanager
-from pathlib import Path
-from test.support import is_wasi, os_helper, requires_subprocess, SHORT_TIMEOUT
-from test.support.os_helper import temp_dir, TESTFN, unlink
-from typing import Dict, List, Optional, Tuple, Union, Any
+from contextlib import closing, contextmanager, redirect_stdout, redirect_stderr, ExitStack
+from test.support import is_wasi, cpython_only, force_color, requires_subprocess, SHORT_TIMEOUT
+from test.support.os_helper import TESTFN, unlink
+from typing import List
import pdb
from pdb import _PdbServer, _PdbClient
@@ -78,6 +76,746 @@ class MockSocketFile:
return results
+class PdbClientTestCase(unittest.TestCase):
+ """Tests for the _PdbClient class."""
+
+ def do_test(
+ self,
+ *,
+ incoming,
+ simulate_send_failure=False,
+ simulate_sigint_during_stdout_write=False,
+ use_interrupt_socket=False,
+ expected_outgoing=None,
+ expected_outgoing_signals=None,
+ expected_completions=None,
+ expected_exception=None,
+ expected_stdout="",
+ expected_stdout_substring="",
+ expected_state=None,
+ ):
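+ # Drive a _PdbClient through a scripted session: "incoming" holds
+ # ("server", payload) messages to feed the client and ("user", ...)
+ # entries answering its prompts; afterwards, check what the client
+ # sent back, printed, and signaled.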
+ if expected_outgoing is None:
+ expected_outgoing = []
+ if expected_outgoing_signals is None:
+ expected_outgoing_signals = []
+ if expected_completions is None:
+ expected_completions = []
+ if expected_state is None:
+ expected_state = {}
+
+ expected_state.setdefault("write_failed", False)
+ messages = [m for source, m in incoming if source == "server"]
+ prompts = [m["prompt"] for source, m in incoming if source == "user"]
+
+ input_iter = (m for source, m in incoming if source == "user")
+ completions = []
+
+ def mock_input(prompt):
+ message = next(input_iter, None)
+ if message is None:
+ raise EOFError
+
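+ # A "user" entry may carry a completion_request; fake the readline
+ # module and record what client.complete() produces before replying.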
+ if req := message.get("completion_request"):
+ readline_mock = unittest.mock.Mock()
+ readline_mock.get_line_buffer.return_value = req["line"]
+ readline_mock.get_begidx.return_value = req["begidx"]
+ readline_mock.get_endidx.return_value = req["endidx"]
+ unittest.mock.seal(readline_mock)
+ with unittest.mock.patch.dict(sys.modules, {"readline": readline_mock}):
+ for param in itertools.count():
+ prefix = req["line"][req["begidx"] : req["endidx"]]
+ completion = client.complete(prefix, param)
+ if completion is None:
+ break
+ completions.append(completion)
+
+ reply = message["input"]
+ if isinstance(reply, BaseException):
+ raise reply
+ if isinstance(reply, str):
+ return reply
+ return reply()
+
+ with ExitStack() as stack:
+ client_sock, server_sock = socket.socketpair()
+ stack.enter_context(closing(client_sock))
+ stack.enter_context(closing(server_sock))
+
+ server_sock = unittest.mock.Mock(wraps=server_sock)
+
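+ # Pre-load every server payload, then shut down this end's write side
+ # so the client's reads see EOF once the queue is drained.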
+ client_sock.sendall(
+ b"".join(
+ (m if isinstance(m, bytes) else json.dumps(m).encode()) + b"\n"
+ for m in messages
+ )
+ )
+ client_sock.shutdown(socket.SHUT_WR)
+
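+ # Optionally make every outgoing write from the client fail, to
+ # exercise its error handling for a half-closed socket.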
+ if simulate_send_failure:
+ server_sock.sendall = unittest.mock.Mock(
+ side_effect=OSError("sendall failed")
+ )
+ client_sock.shutdown(socket.SHUT_RD)
+
+ stdout = io.StringIO()
+
+ if simulate_sigint_during_stdout_write:
+ orig_stdout_write = stdout.write
+
+ def sigint_stdout_write(s):
+ signal.raise_signal(signal.SIGINT)
+ return orig_stdout_write(s)
+
+ stdout.write = sigint_stdout_write
+
+ input_mock = stack.enter_context(
+ unittest.mock.patch("pdb.input", side_effect=mock_input)
+ )
+ stack.enter_context(redirect_stdout(stdout))
+
+ if use_interrupt_socket:
+ interrupt_sock = unittest.mock.Mock(spec=socket.socket)
+ mock_kill = None
+ else:
+ interrupt_sock = None
+ mock_kill = stack.enter_context(
+ unittest.mock.patch("os.kill", spec=os.kill)
+ )
+
+ client = _PdbClient(
+ pid=12345,
+ server_socket=server_sock,
+ interrupt_sock=interrupt_sock,
+ )
+
+ if expected_exception is not None:
+ exception = expected_exception["exception"]
+ msg = expected_exception["msg"]
+ stack.enter_context(self.assertRaises(exception, msg=msg))
+
+ client.cmdloop()
+
+ sent_msgs = [msg.args[0] for msg in server_sock.sendall.mock_calls]
+ for msg in sent_msgs:
+ assert msg.endswith(b"\n")
+ actual_outgoing = [json.loads(msg) for msg in sent_msgs]
+
+ self.assertEqual(actual_outgoing, expected_outgoing)
+ self.assertEqual(completions, expected_completions)
+ if expected_stdout_substring and not expected_stdout:
+ self.assertIn(expected_stdout_substring, stdout.getvalue())
+ else:
+ self.assertEqual(stdout.getvalue(), expected_stdout)
+ input_mock.assert_has_calls([unittest.mock.call(p) for p in prompts])
+ actual_state = {k: getattr(client, k) for k in expected_state}
+ self.assertEqual(actual_state, expected_state)
+
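+ # Collect the signals the client forwarded, via the interrupt socket
+ # or os.kill() depending on the configuration under test.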
+ if use_interrupt_socket:
+ outgoing_signals = [
+ signal.Signals(int.from_bytes(call.args[0]))
+ for call in interrupt_sock.sendall.call_args_list
+ ]
+ else:
+ assert mock_kill is not None
+ outgoing_signals = []
+ for call in mock_kill.call_args_list:
+ pid, signum = call.args
+ self.assertEqual(pid, 12345)
+ outgoing_signals.append(signal.Signals(signum))
+ self.assertEqual(outgoing_signals, expected_outgoing_signals)
+
+ def test_remote_immediately_closing_the_connection(self):
+ """Test the behavior when the remote closes the connection immediately."""
+ incoming = []
+ expected_outgoing = []
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=expected_outgoing,
+ )
+
+ def test_handling_command_list(self):
+ """Test handling the command_list message."""
+ incoming = [
+ ("server", {"command_list": ["help", "list", "continue"]}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[],
+ expected_state={
+ "pdb_commands": {"help", "list", "continue"},
+ },
+ )
+
+ def test_handling_info_message(self):
+ """Test handling a message payload with type='info'."""
+ incoming = [
+ ("server", {"message": "Some message or other\n", "type": "info"}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[],
+ expected_stdout="Some message or other\n",
+ )
+
+ def test_handling_error_message(self):
+ """Test handling a message payload with type='error'."""
+ incoming = [
+ ("server", {"message": "Some message or other.", "type": "error"}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[],
+ expected_stdout="*** Some message or other.\n",
+ )
+
+ def test_handling_other_message(self):
+ """Test handling a message payload with an unrecognized type."""
+ incoming = [
+ ("server", {"message": "Some message.\n", "type": "unknown"}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[],
+ expected_stdout="Some message.\n",
+ )
+
+ def test_handling_help_for_command(self):
+ """Test handling a request to display help for a command."""
+ incoming = [
+ ("server", {"help": "ll"}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[],
+ expected_stdout_substring="Usage: ll | longlist",
+ )
+
+ def test_handling_help_without_a_specific_topic(self):
+ """Test handling a request to display a help overview."""
+ incoming = [
+ ("server", {"help": ""}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[],
+ expected_stdout_substring="type help <topic>",
+ )
+
+ def test_handling_help_pdb(self):
+ """Test handling a request to display the full PDB manual."""
+ incoming = [
+ ("server", {"help": "pdb"}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[],
+ expected_stdout_substring=">>> import pdb",
+ )
+
+ def test_handling_pdb_prompts(self):
+ """Test responding to pdb's normal prompts."""
+ incoming = [
+ ("server", {"command_list": ["b"]}),
+ ("server", {"prompt": "(Pdb) ", "state": "pdb"}),
+ ("user", {"prompt": "(Pdb) ", "input": "lst ["}),
+ ("user", {"prompt": "... ", "input": "0 ]"}),
+ ("server", {"prompt": "(Pdb) ", "state": "pdb"}),
+ ("user", {"prompt": "(Pdb) ", "input": ""}),
+ ("server", {"prompt": "(Pdb) ", "state": "pdb"}),
+ ("user", {"prompt": "(Pdb) ", "input": "b ["}),
+ ("server", {"prompt": "(Pdb) ", "state": "pdb"}),
+ ("user", {"prompt": "(Pdb) ", "input": "! b ["}),
+ ("user", {"prompt": "... ", "input": "b ]"}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {"reply": "lst [\n0 ]"},
+ {"reply": ""},
+ {"reply": "b ["},
+ {"reply": "!b [\nb ]"},
+ ],
+ expected_state={"state": "pdb"},
+ )
+
+ def test_handling_interact_prompts(self):
+ """Test responding to pdb's interact mode prompts."""
+ incoming = [
+ ("server", {"command_list": ["b"]}),
+ ("server", {"prompt": ">>> ", "state": "interact"}),
+ ("user", {"prompt": ">>> ", "input": "lst ["}),
+ ("user", {"prompt": "... ", "input": "0 ]"}),
+ ("server", {"prompt": ">>> ", "state": "interact"}),
+ ("user", {"prompt": ">>> ", "input": ""}),
+ ("server", {"prompt": ">>> ", "state": "interact"}),
+ ("user", {"prompt": ">>> ", "input": "b ["}),
+ ("user", {"prompt": "... ", "input": "b ]"}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {"reply": "lst [\n0 ]"},
+ {"reply": ""},
+ {"reply": "b [\nb ]"},
+ ],
+ expected_state={"state": "interact"},
+ )
+
+ def test_retry_pdb_prompt_on_syntax_error(self):
+ """Test re-prompting after a SyntaxError in a Python expression."""
+ incoming = [
+ ("server", {"prompt": "(Pdb) ", "state": "pdb"}),
+ ("user", {"prompt": "(Pdb) ", "input": " lst ["}),
+ ("user", {"prompt": "(Pdb) ", "input": "lst ["}),
+ ("user", {"prompt": "... ", "input": " 0 ]"}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {"reply": "lst [\n 0 ]"},
+ ],
+ expected_stdout_substring="*** IndentationError",
+ expected_state={"state": "pdb"},
+ )
+
+ def test_retry_interact_prompt_on_syntax_error(self):
+ """Test re-prompting after a SyntaxError in a Python expression."""
+ incoming = [
+ ("server", {"prompt": ">>> ", "state": "interact"}),
+ ("user", {"prompt": ">>> ", "input": "!lst ["}),
+ ("user", {"prompt": ">>> ", "input": "lst ["}),
+ ("user", {"prompt": "... ", "input": " 0 ]"}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {"reply": "lst [\n 0 ]"},
+ ],
+ expected_stdout_substring="*** SyntaxError",
+ expected_state={"state": "interact"},
+ )
+
+ def test_handling_unrecognized_prompt_type(self):
+ """Test fallback to "dumb" single-line mode for unknown states."""
+ incoming = [
+ ("server", {"prompt": "Do it? ", "state": "confirm"}),
+ ("user", {"prompt": "Do it? ", "input": "! ["}),
+ ("server", {"prompt": "Do it? ", "state": "confirm"}),
+ ("user", {"prompt": "Do it? ", "input": "echo hello"}),
+ ("server", {"prompt": "Do it? ", "state": "confirm"}),
+ ("user", {"prompt": "Do it? ", "input": ""}),
+ ("server", {"prompt": "Do it? ", "state": "confirm"}),
+ ("user", {"prompt": "Do it? ", "input": "echo goodbye"}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {"reply": "! ["},
+ {"reply": "echo hello"},
+ {"reply": ""},
+ {"reply": "echo goodbye"},
+ ],
+ expected_state={"state": "dumb"},
+ )
+
+ def test_sigint_at_prompt(self):
+ """Test signaling when a prompt gets interrupted."""
+ incoming = [
+ ("server", {"prompt": "(Pdb) ", "state": "pdb"}),
+ (
+ "user",
+ {
+ "prompt": "(Pdb) ",
+ "input": lambda: signal.raise_signal(signal.SIGINT),
+ },
+ ),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {"signal": "INT"},
+ ],
+ expected_state={"state": "pdb"},
+ )
+
+ def test_sigint_at_continuation_prompt(self):
+ """Test signaling when a continuation prompt gets interrupted."""
+ incoming = [
+ ("server", {"prompt": "(Pdb) ", "state": "pdb"}),
+ ("user", {"prompt": "(Pdb) ", "input": "if True:"}),
+ (
+ "user",
+ {
+ "prompt": "... ",
+ "input": lambda: signal.raise_signal(signal.SIGINT),
+ },
+ ),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {"signal": "INT"},
+ ],
+ expected_state={"state": "pdb"},
+ )
+
+ def test_sigint_when_writing(self):
+ """Test siginaling when sys.stdout.write() gets interrupted."""
+ incoming = [
+ ("server", {"message": "Some message or other\n", "type": "info"}),
+ ]
+ for use_interrupt_socket in [False, True]:
+ with self.subTest(use_interrupt_socket=use_interrupt_socket):
+ self.do_test(
+ incoming=incoming,
+ simulate_sigint_during_stdout_write=True,
+ use_interrupt_socket=use_interrupt_socket,
+ expected_outgoing=[],
+ expected_outgoing_signals=[signal.SIGINT],
+ expected_stdout="Some message or other\n",
+ )
+
+ def test_eof_at_prompt(self):
+ """Test signaling when a prompt gets an EOFError."""
+ incoming = [
+ ("server", {"prompt": "(Pdb) ", "state": "pdb"}),
+ ("user", {"prompt": "(Pdb) ", "input": EOFError()}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {"signal": "EOF"},
+ ],
+ expected_state={"state": "pdb"},
+ )
+
+ def test_unrecognized_json_message(self):
+ """Test failing after getting an unrecognized payload."""
+ incoming = [
+ ("server", {"monty": "python"}),
+ ("server", {"message": "Some message or other\n", "type": "info"}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[],
+ expected_exception={
+ "exception": RuntimeError,
+ "msg": 'Unrecognized payload b\'{"monty": "python"}\'',
+ },
+ )
+
+ def test_continuing_after_getting_a_non_json_payload(self):
+ """Test continuing after getting a non JSON payload."""
+ incoming = [
+ ("server", b"spam"),
+ ("server", {"message": "Something", "type": "info"}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[],
+ expected_stdout="\n".join(
+ [
+ "*** Invalid JSON from remote: b'spam\\n'",
+ "Something",
+ ]
+ ),
+ )
+
+ def test_write_failing(self):
+ """Test terminating if write fails due to a half closed socket."""
+ incoming = [
+ ("server", {"prompt": "(Pdb) ", "state": "pdb"}),
+ ("user", {"prompt": "(Pdb) ", "input": KeyboardInterrupt()}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[{"signal": "INT"}],
+ simulate_send_failure=True,
+ expected_state={"write_failed": True},
+ )
+
+ def test_completion_in_pdb_state(self):
+ """Test requesting tab completions at a (Pdb) prompt."""
+ # GIVEN
+ incoming = [
+ ("server", {"prompt": "(Pdb) ", "state": "pdb"}),
+ (
+ "user",
+ {
+ "prompt": "(Pdb) ",
+ "completion_request": {
+ "line": " mod._",
+ "begidx": 8,
+ "endidx": 9,
+ },
+ "input": "print(\n mod.__name__)",
+ },
+ ),
+ ("server", {"completions": ["__name__", "__file__"]}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {
+ "complete": {
+ "text": "_",
+ "line": "mod._",
+ "begidx": 4,
+ "endidx": 5,
+ }
+ },
+ {"reply": "print(\n mod.__name__)"},
+ ],
+ expected_completions=["__name__", "__file__"],
+ expected_state={"state": "pdb"},
+ )
+
+ def test_multiline_completion_in_pdb_state(self):
+ """Test requesting tab completions at a (Pdb) continuation prompt."""
+ # GIVEN
+ incoming = [
+ ("server", {"prompt": "(Pdb) ", "state": "pdb"}),
+ ("user", {"prompt": "(Pdb) ", "input": "if True:"}),
+ (
+ "user",
+ {
+ "prompt": "... ",
+ "completion_request": {
+ "line": " b",
+ "begidx": 4,
+ "endidx": 5,
+ },
+ "input": " bool()",
+ },
+ ),
+ ("server", {"completions": ["bin", "bool", "bytes"]}),
+ ("user", {"prompt": "... ", "input": ""}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {
+ "complete": {
+ "text": "b",
+ "line": "! b",
+ "begidx": 2,
+ "endidx": 3,
+ }
+ },
+ {"reply": "if True:\n bool()\n"},
+ ],
+ expected_completions=["bin", "bool", "bytes"],
+ expected_state={"state": "pdb"},
+ )
+
+ def test_completion_in_interact_state(self):
+ """Test requesting tab completions at a >>> prompt."""
+ incoming = [
+ ("server", {"prompt": ">>> ", "state": "interact"}),
+ (
+ "user",
+ {
+ "prompt": ">>> ",
+ "completion_request": {
+ "line": " mod.__",
+ "begidx": 8,
+ "endidx": 10,
+ },
+ "input": "print(\n mod.__name__)",
+ },
+ ),
+ ("server", {"completions": ["__name__", "__file__"]}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {
+ "complete": {
+ "text": "__",
+ "line": "mod.__",
+ "begidx": 4,
+ "endidx": 6,
+ }
+ },
+ {"reply": "print(\n mod.__name__)"},
+ ],
+ expected_completions=["__name__", "__file__"],
+ expected_state={"state": "interact"},
+ )
+
+ def test_completion_in_unknown_state(self):
+ """Test requesting tab completions at an unrecognized prompt."""
+ incoming = [
+ ("server", {"command_list": ["p"]}),
+ ("server", {"prompt": "Do it? ", "state": "confirm"}),
+ (
+ "user",
+ {
+ "prompt": "Do it? ",
+ "completion_request": {
+ "line": "_",
+ "begidx": 0,
+ "endidx": 1,
+ },
+ "input": "__name__",
+ },
+ ),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {"reply": "__name__"},
+ ],
+ expected_state={"state": "dumb"},
+ )
+
+ def test_write_failure_during_completion(self):
+ """Test failing to write to the socket to request tab completions."""
+ incoming = [
+ ("server", {"prompt": ">>> ", "state": "interact"}),
+ (
+ "user",
+ {
+ "prompt": ">>> ",
+ "completion_request": {
+ "line": "xy",
+ "begidx": 0,
+ "endidx": 2,
+ },
+ "input": "xyz",
+ },
+ ),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {
+ "complete": {
+ "text": "xy",
+ "line": "xy",
+ "begidx": 0,
+ "endidx": 2,
+ }
+ },
+ {"reply": "xyz"},
+ ],
+ simulate_send_failure=True,
+ expected_completions=[],
+ expected_state={"state": "interact", "write_failed": True},
+ )
+
+ def test_read_failure_during_completion(self):
+ """Test failing to read tab completions from the socket."""
+ incoming = [
+ ("server", {"prompt": ">>> ", "state": "interact"}),
+ (
+ "user",
+ {
+ "prompt": ">>> ",
+ "completion_request": {
+ "line": "xy",
+ "begidx": 0,
+ "endidx": 2,
+ },
+ "input": "xyz",
+ },
+ ),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {
+ "complete": {
+ "text": "xy",
+ "line": "xy",
+ "begidx": 0,
+ "endidx": 2,
+ }
+ },
+ {"reply": "xyz"},
+ ],
+ expected_completions=[],
+ expected_state={"state": "interact"},
+ )
+
+ def test_reading_invalid_json_during_completion(self):
+ """Test receiving invalid JSON when getting tab completions."""
+ incoming = [
+ ("server", {"prompt": ">>> ", "state": "interact"}),
+ (
+ "user",
+ {
+ "prompt": ">>> ",
+ "completion_request": {
+ "line": "xy",
+ "begidx": 0,
+ "endidx": 2,
+ },
+ "input": "xyz",
+ },
+ ),
+ ("server", b'{"completions": '),
+ ("user", {"prompt": ">>> ", "input": "xyz"}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {
+ "complete": {
+ "text": "xy",
+ "line": "xy",
+ "begidx": 0,
+ "endidx": 2,
+ }
+ },
+ {"reply": "xyz"},
+ ],
+ expected_stdout_substring="*** json.decoder.JSONDecodeError",
+ expected_completions=[],
+ expected_state={"state": "interact"},
+ )
+
+ def test_reading_empty_json_during_completion(self):
+ """Test receiving an empty JSON object when getting tab completions."""
+ incoming = [
+ ("server", {"prompt": ">>> ", "state": "interact"}),
+ (
+ "user",
+ {
+ "prompt": ">>> ",
+ "completion_request": {
+ "line": "xy",
+ "begidx": 0,
+ "endidx": 2,
+ },
+ "input": "xyz",
+ },
+ ),
+ ("server", {}),
+ ("user", {"prompt": ">>> ", "input": "xyz"}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {
+ "complete": {
+ "text": "xy",
+ "line": "xy",
+ "begidx": 0,
+ "endidx": 2,
+ }
+ },
+ {"reply": "xyz"},
+ ],
+ expected_stdout=(
+ "*** RuntimeError: Failed to get valid completions."
+ " Got: {}\n"
+ ),
+ expected_completions=[],
+ expected_state={"state": "interact"},
+ )
+
+
class RemotePdbTestCase(unittest.TestCase):
"""Tests for the _PdbServer class."""
@@ -298,6 +1036,8 @@ class PdbConnectTestCase(unittest.TestCase):
frame=frame,
commands="",
version=pdb._PdbServer.protocol_version(),
+ signal_raising_thread=False,
+ colorize=False,
)
return x # This line won't be reached in debugging
@@ -355,23 +1095,6 @@ class PdbConnectTestCase(unittest.TestCase):
client_file.write(json.dumps({"reply": command}).encode() + b"\n")
client_file.flush()
- def _send_interrupt(self, pid):
- """Helper to send an interrupt signal to the debugger."""
- # with tempfile.NamedTemporaryFile("w", delete_on_close=False) as interrupt_script:
- interrupt_script = TESTFN + "_interrupt_script.py"
- with open(interrupt_script, 'w') as f:
- f.write(
- 'import pdb, sys\n'
- 'print("Hello, world!")\n'
- 'if inst := pdb.Pdb._last_pdb_instance:\n'
- ' inst.set_trace(sys._getframe(1))\n'
- )
- self.addCleanup(unlink, interrupt_script)
- try:
- sys.remote_exec(pid, interrupt_script)
- except PermissionError:
- self.skipTest("Insufficient permissions to execute code in remote process")
-
def test_connect_and_basic_commands(self):
"""Test connecting to a remote debugger and sending basic commands."""
self._create_script()
@@ -484,6 +1207,8 @@ class PdbConnectTestCase(unittest.TestCase):
frame=frame,
commands="",
version=pdb._PdbServer.protocol_version(),
+ signal_raising_thread=True,
+ colorize=False,
)
print("Connected to debugger")
iterations = 50
@@ -499,6 +1224,10 @@ class PdbConnectTestCase(unittest.TestCase):
self._create_script(script=script)
process, client_file = self._connect_and_get_client_file()
+ # Accept a 2nd connection from the subprocess to tell it about signals
+ signal_sock, _ = self.server_sock.accept()
+ self.addCleanup(signal_sock.close)
+
with kill_on_error(process):
# Skip initial messages until we get to the prompt
self._read_until_prompt(client_file)
@@ -514,7 +1243,7 @@ class PdbConnectTestCase(unittest.TestCase):
break
# Inject a script to interrupt the running process
- self._send_interrupt(process.pid)
+ signal_sock.sendall(signal.SIGINT.to_bytes())
messages = self._read_until_prompt(client_file)
# Verify we got the keyboard interrupt message.
@@ -570,6 +1299,8 @@ class PdbConnectTestCase(unittest.TestCase):
frame=frame,
commands="",
version=fake_version,
+ signal_raising_thread=False,
+ colorize=False,
)
# This should print if the debugger detaches correctly
@@ -697,5 +1428,151 @@ class PdbConnectTestCase(unittest.TestCase):
self.assertIn("Function returned: 42", stdout)
self.assertEqual(process.returncode, 0)
+
+def _supports_remote_attaching():
+ PROCESS_VM_READV_SUPPORTED = False
+
+ try:
+ from _remote_debugging import PROCESS_VM_READV_SUPPORTED
+ except ImportError:
+ pass
+
+ return PROCESS_VM_READV_SUPPORTED
+
+
+@unittest.skipIf(not sys.is_remote_debug_enabled(), "Remote debugging is not enabled")
+@unittest.skipIf(sys.platform != "darwin" and sys.platform != "linux" and sys.platform != "win32",
+ "Test only runs on Linux, Windows and MacOS")
+@unittest.skipIf(sys.platform == "linux" and not _supports_remote_attaching(),
+ "Testing on Linux requires process_vm_readv support")
+@cpython_only
+@requires_subprocess()
+class PdbAttachTestCase(unittest.TestCase):
+ def setUp(self):
+ # Create a server socket that will wait for the debugger to connect
+ self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.sock.bind(('127.0.0.1', 0)) # Let OS assign port
+ self.sock.listen(1)
+ self.port = self.sock.getsockname()[1]
+ self._create_script()
+
+ def _create_script(self, script=None):
+ # Create a file for subprocess script
+ script = textwrap.dedent(
+ f"""
+ import socket
+ import time
+
+ def foo():
+ return bar()
+
+ def bar():
+ return baz()
+
+ def baz():
+ x = 1
+ # Trigger attach
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.connect(('127.0.0.1', {self.port}))
+ sock.close()
+ count = 0
+ while x == 1 and count < 100:
+ count += 1
+ time.sleep(0.1)
+ return x
+
+ result = foo()
+ print(f"Function returned: {{result}}")
+ """
+ )
+
+ self.script_path = TESTFN + "_connect_test.py"
+ with open(self.script_path, 'w') as f:
+ f.write(script)
+
+ def tearDown(self):
+ self.sock.close()
+ try:
+ unlink(self.script_path)
+ except OSError:
+ pass
+
+ def do_integration_test(self, client_stdin):
+ process = subprocess.Popen(
+ [sys.executable, self.script_path],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True
+ )
+ self.addCleanup(process.stdout.close)
+ self.addCleanup(process.stderr.close)
+
+ # Wait for the process to reach our attachment point
+ self.sock.settimeout(10)
+ conn, _ = self.sock.accept()
+ conn.close()
+
+ client_stdin = io.StringIO(client_stdin)
+ client_stdout = io.StringIO()
+ client_stderr = io.StringIO()
+
+ self.addCleanup(client_stdin.close)
+ self.addCleanup(client_stdout.close)
+ self.addCleanup(client_stderr.close)
+ self.addCleanup(process.wait)
+
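+ # Run the pdb client in this process with scripted stdin and
+ # argv ["pdb", "-p", <pid>] so it attaches to the spawned script.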
+ with (
+ unittest.mock.patch("sys.stdin", client_stdin),
+ redirect_stdout(client_stdout),
+ redirect_stderr(client_stderr),
+ unittest.mock.patch("sys.argv", ["pdb", "-p", str(process.pid)]),
+ ):
+ try:
+ pdb.main()
+ except PermissionError:
+ self.skipTest("Insufficient permissions for remote execution")
+
+ process.wait()
+ server_stdout = process.stdout.read()
+ server_stderr = process.stderr.read()
+
+ if process.returncode != 0:
+ print("server failed")
+ print(f"server stdout:\n{server_stdout}")
+ print(f"server stderr:\n{server_stderr}")
+
+ self.assertEqual(process.returncode, 0)
+ return {
+ "client": {
+ "stdout": client_stdout.getvalue(),
+ "stderr": client_stderr.getvalue(),
+ },
+ "server": {
+ "stdout": server_stdout,
+ "stderr": server_stderr,
+ },
+ }
+
+ def test_attach_to_process_without_colors(self):
+ with force_color(False):
+ output = self.do_integration_test("ll\nx=42\n")
+ self.assertEqual(output["client"]["stderr"], "")
+ self.assertEqual(output["server"]["stderr"], "")
+
+ self.assertEqual(output["server"]["stdout"], "Function returned: 42\n")
+ self.assertIn("while x == 1", output["client"]["stdout"])
+ self.assertNotIn("\x1b", output["client"]["stdout"])
+
+ def test_attach_to_process_with_colors(self):
+ with force_color(True):
+ output = self.do_integration_test("ll\nx=42\n")
+ self.assertEqual(output["client"]["stderr"], "")
+ self.assertEqual(output["server"]["stderr"], "")
+
+ self.assertEqual(output["server"]["stdout"], "Function returned: 42\n")
+ self.assertIn("\x1b", output["client"]["stdout"])
+ self.assertNotIn("while x == 1", output["client"]["stdout"])
+ self.assertIn("while x == 1", re.sub("\x1b[^m]*m", "", output["client"]["stdout"]))
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_repl.py b/Lib/test/test_repl.py
index 228b326699e..f4a4634fc62 100644
--- a/Lib/test/test_repl.py
+++ b/Lib/test/test_repl.py
@@ -38,8 +38,8 @@ def spawn_repl(*args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kw):
# line option '-i' and the process name set to '<stdin>'.
# The directory of argv[0] must match the directory of the Python
# executable for the Popen() call to python to succeed as the directory
- # path may be used by Py_GetPath() to build the default module search
- # path.
+ # path may be used by PyConfig_Get("module_search_paths") to build the
+ # default module search path.
stdin_fname = os.path.join(os.path.dirname(sys.executable), "<stdin>")
cmd_line = [stdin_fname, '-I', '-i']
cmd_line.extend(args)
@@ -197,7 +197,7 @@ class TestInteractiveInterpreter(unittest.TestCase):
expected_lines = [
' def f(x, x): ...',
' ^',
- "SyntaxError: duplicate argument 'x' in function definition"
+ "SyntaxError: duplicate parameter 'x' in function definition"
]
self.assertEqual(output.splitlines()[4:-1], expected_lines)
diff --git a/Lib/test/test_reprlib.py b/Lib/test/test_reprlib.py
index ffeb1fba7b8..22a55b57c07 100644
--- a/Lib/test/test_reprlib.py
+++ b/Lib/test/test_reprlib.py
@@ -3,6 +3,7 @@
Nick Mathewson
"""
+import annotationlib
import sys
import os
import shutil
@@ -11,7 +12,7 @@ import importlib.util
import unittest
import textwrap
-from test.support import verbose
+from test.support import verbose, EqualToForwardRef
from test.support.os_helper import create_empty_file
from reprlib import repr as r # Don't shadow builtin repr
from reprlib import Repr
@@ -150,14 +151,38 @@ class ReprTests(unittest.TestCase):
eq(r(frozenset({1, 2, 3, 4, 5, 6, 7})), "frozenset({1, 2, 3, 4, 5, 6, ...})")
def test_numbers(self):
- eq = self.assertEqual
- eq(r(123), repr(123))
- eq(r(123), repr(123))
- eq(r(1.0/3), repr(1.0/3))
-
- n = 10**100
- expected = repr(n)[:18] + "..." + repr(n)[-19:]
- eq(r(n), expected)
+ for x in [123, 1.0 / 3]:
+ self.assertEqual(r(x), repr(x))
+
+ max_digits = sys.get_int_max_str_digits()
+ for k in [100, max_digits - 1]:
+ with self.subTest(f'10 ** {k}', k=k):
+ n = 10 ** k
+ expected = repr(n)[:18] + "..." + repr(n)[-19:]
+ self.assertEqual(r(n), expected)
+
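+ # Above the limit repr() raises, but reprlib.repr() falls back to a
+ # "<int instance with roughly N digits ...>" placeholder.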
+ def re_msg(n, d):
+ return (rf'<{n.__class__.__name__} instance with roughly {d} '
+ rf'digits \(limit at {max_digits}\) at 0x[a-f0-9]+>')
+
+ k = max_digits
+ with self.subTest(f'10 ** {k}', k=k):
+ n = 10 ** k
+ self.assertRaises(ValueError, repr, n)
+ self.assertRegex(r(n), re_msg(n, k + 1))
+
+ for k in [max_digits + 1, 2 * max_digits]:
+ self.assertGreater(k, 100)
+ with self.subTest(f'10 ** {k}', k=k):
+ n = 10 ** k
+ self.assertRaises(ValueError, repr, n)
+ self.assertRegex(r(n), re_msg(n, k + 1))
+ with self.subTest(f'10 ** {k} - 1', k=k):
+ n = 10 ** k - 1
+ # Here, since math.log10(10**k - 1) == math.log10(10**k) in floating
+ # point, the digit count of n = 10**k - 1 is overestimated by one.
+ self.assertRaises(ValueError, repr, n)
+ self.assertRegex(r(n), re_msg(n, k + 1))
def test_instance(self):
eq = self.assertEqual
@@ -172,13 +197,13 @@ class ReprTests(unittest.TestCase):
eq(r(i3), ("<ClassWithFailingRepr instance at %#x>"%id(i3)))
s = r(ClassWithFailingRepr)
- self.assertTrue(s.startswith("<class "))
- self.assertTrue(s.endswith(">"))
+ self.assertStartsWith(s, "<class ")
+ self.assertEndsWith(s, ">")
self.assertIn(s.find("..."), [12, 13])
def test_lambda(self):
r = repr(lambda x: x)
- self.assertTrue(r.startswith("<function ReprTests.test_lambda.<locals>.<lambda"), r)
+ self.assertStartsWith(r, "<function ReprTests.test_lambda.<locals>.<lambda")
# XXX anonymous functions? see func_repr
def test_builtin_function(self):
@@ -186,8 +211,8 @@ class ReprTests(unittest.TestCase):
# Functions
eq(repr(hash), '<built-in function hash>')
# Methods
- self.assertTrue(repr(''.split).startswith(
- '<built-in method split of str object at 0x'))
+ self.assertStartsWith(repr(''.split),
+ '<built-in method split of str object at 0x')
def test_range(self):
eq = self.assertEqual
@@ -372,20 +397,20 @@ class ReprTests(unittest.TestCase):
'object': {
1: 'two',
b'three': [
- (4.5, 6.7),
+ (4.5, 6.25),
[set((8, 9)), frozenset((10, 11))],
],
},
'tests': (
(dict(indent=None), '''\
- {1: 'two', b'three': [(4.5, 6.7), [{8, 9}, frozenset({10, 11})]]}'''),
+ {1: 'two', b'three': [(4.5, 6.25), [{8, 9}, frozenset({10, 11})]]}'''),
(dict(indent=False), '''\
{
1: 'two',
b'three': [
(
4.5,
- 6.7,
+ 6.25,
),
[
{
@@ -405,7 +430,7 @@ class ReprTests(unittest.TestCase):
b'three': [
(
4.5,
- 6.7,
+ 6.25,
),
[
{
@@ -425,7 +450,7 @@ class ReprTests(unittest.TestCase):
b'three': [
(
4.5,
- 6.7,
+ 6.25,
),
[
{
@@ -445,7 +470,7 @@ class ReprTests(unittest.TestCase):
b'three': [
(
4.5,
- 6.7,
+ 6.25,
),
[
{
@@ -465,7 +490,7 @@ class ReprTests(unittest.TestCase):
b'three': [
(
4.5,
- 6.7,
+ 6.25,
),
[
{
@@ -493,7 +518,7 @@ class ReprTests(unittest.TestCase):
b'three': [
(
4.5,
- 6.7,
+ 6.25,
),
[
{
@@ -513,7 +538,7 @@ class ReprTests(unittest.TestCase):
-->b'three': [
-->-->(
-->-->-->4.5,
- -->-->-->6.7,
+ -->-->-->6.25,
-->-->),
-->-->[
-->-->-->{
@@ -533,7 +558,7 @@ class ReprTests(unittest.TestCase):
....b'three': [
........(
............4.5,
- ............6.7,
+ ............6.25,
........),
........[
............{
@@ -729,8 +754,8 @@ class baz:
importlib.invalidate_caches()
from areallylongpackageandmodulenametotestreprtruncation.areallylongpackageandmodulenametotestreprtruncation import baz
ibaz = baz.baz()
- self.assertTrue(repr(ibaz).startswith(
- "<%s.baz object at 0x" % baz.__name__))
+ self.assertStartsWith(repr(ibaz),
+ "<%s.baz object at 0x" % baz.__name__)
def test_method(self):
self._check_path_limitations('qux')
@@ -743,13 +768,13 @@ class aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
from areallylongpackageandmodulenametotestreprtruncation.areallylongpackageandmodulenametotestreprtruncation import qux
# Unbound methods first
r = repr(qux.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.amethod)
- self.assertTrue(r.startswith('<function aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.amethod'), r)
+ self.assertStartsWith(r, '<function aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.amethod')
# Bound method next
iqux = qux.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa()
r = repr(iqux.amethod)
- self.assertTrue(r.startswith(
+ self.assertStartsWith(r,
'<bound method aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.amethod of <%s.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa object at 0x' \
- % (qux.__name__,) ), r)
+ % (qux.__name__,) )
@unittest.skip('needs a built-in function with a really long name')
def test_builtin_function(self):
@@ -829,5 +854,19 @@ class TestRecursiveRepr(unittest.TestCase):
self.assertEqual(type_params[0].__name__, 'T')
self.assertEqual(type_params[0].__bound__, str)
+ def test_annotations(self):
+ class My:
+ @recursive_repr()
+ def __repr__(self, default: undefined = ...):
+ return default
+
+ annotations = annotationlib.get_annotations(
+ My.__repr__, format=annotationlib.Format.FORWARDREF
+ )
+ self.assertEqual(
+ annotations,
+ {'default': EqualToForwardRef("undefined", owner=My.__repr__)}
+ )
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_rlcompleter.py b/Lib/test/test_rlcompleter.py
index 1cff6a218f8..a8914953ce9 100644
--- a/Lib/test/test_rlcompleter.py
+++ b/Lib/test/test_rlcompleter.py
@@ -54,11 +54,26 @@ class TestRlcompleter(unittest.TestCase):
['str.{}('.format(x) for x in dir(str)
if x.startswith('s')])
self.assertEqual(self.stdcompleter.attr_matches('tuple.foospamegg'), [])
- expected = sorted({'None.%s%s' % (x,
- '()' if x in ('__init_subclass__', '__class__')
- else '' if x == '__doc__'
- else '(')
- for x in dir(None)})
+
+ def create_expected_for_none():
+ if not MISSING_C_DOCSTRINGS:
+ parentheses = ('__init_subclass__', '__class__')
+ else:
+ # When `--without-doc-strings` is used, `__class__`
+ # won't have a known signature.
+ parentheses = ('__init_subclass__',)
+
+ items = set()
+ for x in dir(None):
+ if x in parentheses:
+ items.add(f'None.{x}()')
+ elif x == '__doc__':
+ items.add(f'None.{x}')
+ else:
+ items.add(f'None.{x}(')
+ return sorted(items)
+
+ expected = create_expected_for_none()
self.assertEqual(self.stdcompleter.attr_matches('None.'), expected)
self.assertEqual(self.stdcompleter.attr_matches('None._'), expected)
self.assertEqual(self.stdcompleter.attr_matches('None.__'), expected)
@@ -73,7 +88,7 @@ class TestRlcompleter(unittest.TestCase):
['CompleteMe._ham'])
matches = self.completer.attr_matches('CompleteMe.__')
for x in matches:
- self.assertTrue(x.startswith('CompleteMe.__'), x)
+ self.assertStartsWith(x, 'CompleteMe.__')
self.assertIn('CompleteMe.__name__', matches)
self.assertIn('CompleteMe.__new__(', matches)
diff --git a/Lib/test/test_runpy.py b/Lib/test/test_runpy.py
index ada78ec8e6b..a2a07c04f58 100644
--- a/Lib/test/test_runpy.py
+++ b/Lib/test/test_runpy.py
@@ -796,7 +796,7 @@ class TestExit(unittest.TestCase):
# Use -E to ignore PYTHONSAFEPATH
cmd = [sys.executable, '-E', *cmd]
proc = subprocess.run(cmd, *args, **kwargs, text=True, stderr=subprocess.PIPE)
- self.assertTrue(proc.stderr.endswith("\nKeyboardInterrupt\n"), proc.stderr)
+ self.assertEndsWith(proc.stderr, "\nKeyboardInterrupt\n")
self.assertEqual(proc.returncode, self.EXPECTED_CODE)
def test_pymain_run_file(self):
diff --git a/Lib/test/test_sax.py b/Lib/test/test_sax.py
index 0d0f86c145b..5c10bcedc69 100644
--- a/Lib/test/test_sax.py
+++ b/Lib/test/test_sax.py
@@ -1,5 +1,4 @@
# regression test for SAX 2.0
-# $Id$
from xml.sax import make_parser, ContentHandler, \
SAXException, SAXReaderNotAvailable, SAXParseException
@@ -832,8 +831,9 @@ class StreamReaderWriterXmlgenTest(XmlgenTest, unittest.TestCase):
fname = os_helper.TESTFN + '-codecs'
def ioclass(self):
- writer = codecs.open(self.fname, 'w', encoding='ascii',
- errors='xmlcharrefreplace', buffering=0)
+ with self.assertWarns(DeprecationWarning):
+ writer = codecs.open(self.fname, 'w', encoding='ascii',
+ errors='xmlcharrefreplace', buffering=0)
def cleanup():
writer.close()
os_helper.unlink(self.fname)
diff --git a/Lib/test/test_scope.py b/Lib/test/test_scope.py
index 24a366efc6c..520fbc1b662 100644
--- a/Lib/test/test_scope.py
+++ b/Lib/test/test_scope.py
@@ -778,7 +778,7 @@ class ScopeTests(unittest.TestCase):
class X:
locals()["x"] = 43
del x
- self.assertFalse(hasattr(X, "x"))
+ self.assertNotHasAttr(X, "x")
self.assertEqual(x, 42)
@cpython_only
diff --git a/Lib/test/test_script_helper.py b/Lib/test/test_script_helper.py
index f7871fd3b77..eeea6c4842b 100644
--- a/Lib/test/test_script_helper.py
+++ b/Lib/test/test_script_helper.py
@@ -74,8 +74,7 @@ class TestScriptHelperEnvironment(unittest.TestCase):
"""Code coverage for interpreter_requires_environment()."""
def setUp(self):
- self.assertTrue(
- hasattr(script_helper, '__cached_interp_requires_environment'))
+ self.assertHasAttr(script_helper, '__cached_interp_requires_environment')
# Reset the private cached state.
script_helper.__dict__['__cached_interp_requires_environment'] = None
diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py
index c01e323553d..c0df9507bd7 100644
--- a/Lib/test/test_set.py
+++ b/Lib/test/test_set.py
@@ -237,7 +237,7 @@ class TestJointOps:
if type(self.s) not in (set, frozenset):
self.assertEqual(self.s.x, dup.x)
self.assertEqual(self.s.z, dup.z)
- self.assertFalse(hasattr(self.s, 'y'))
+ self.assertNotHasAttr(self.s, 'y')
del self.s.x, self.s.z
def test_iterator_pickling(self):
@@ -876,8 +876,8 @@ class TestBasicOps:
def check_repr_against_values(self):
text = repr(self.set)
- self.assertTrue(text.startswith('{'))
- self.assertTrue(text.endswith('}'))
+ self.assertStartsWith(text, '{')
+ self.assertEndsWith(text, '}')
result = text[1:-1].split(', ')
result.sort()
diff --git a/Lib/test/test_shlex.py b/Lib/test/test_shlex.py
index f35571ea886..a13ddcb76b7 100644
--- a/Lib/test/test_shlex.py
+++ b/Lib/test/test_shlex.py
@@ -3,6 +3,7 @@ import itertools
import shlex
import string
import unittest
+from test.support import cpython_only
from test.support import import_helper
@@ -364,6 +365,7 @@ class ShlexTest(unittest.TestCase):
with self.assertRaises(AttributeError):
shlex_instance.punctuation_chars = False
+ @cpython_only
def test_lazy_imports(self):
import_helper.ensure_lazy_imports('shlex', {'collections', 're', 'os'})
diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py
index ed01163074a..ebb6cf88336 100644
--- a/Lib/test/test_shutil.py
+++ b/Lib/test/test_shutil.py
@@ -427,12 +427,12 @@ class TestRmTree(BaseTest, unittest.TestCase):
else:
self.assertIs(func, os.listdir)
self.assertIn(arg, [TESTFN, self.child_dir_path])
- self.assertTrue(issubclass(exc[0], OSError))
+ self.assertIsSubclass(exc[0], OSError)
self.errorState += 1
else:
self.assertEqual(func, os.rmdir)
self.assertEqual(arg, TESTFN)
- self.assertTrue(issubclass(exc[0], OSError))
+ self.assertIsSubclass(exc[0], OSError)
self.errorState = 3
@unittest.skipIf(sys.platform[:6] == 'cygwin',
@@ -2153,6 +2153,10 @@ class TestArchives(BaseTest, unittest.TestCase):
def test_unpack_archive_bztar(self):
self.check_unpack_tarball('bztar')
+ @support.requires_zstd()
+ def test_unpack_archive_zstdtar(self):
+ self.check_unpack_tarball('zstdtar')
+
@support.requires_lzma()
@unittest.skipIf(AIX and not _maxdataOK(), "AIX MAXDATA must be 0x20000000 or larger")
def test_unpack_archive_xztar(self):
@@ -3475,7 +3479,7 @@ class PublicAPITests(unittest.TestCase):
"""Ensures that the correct values are exposed in the public API."""
def test_module_all_attribute(self):
- self.assertTrue(hasattr(shutil, '__all__'))
+ self.assertHasAttr(shutil, '__all__')
target_api = ['copyfileobj', 'copyfile', 'copymode', 'copystat',
'copy', 'copy2', 'copytree', 'move', 'rmtree', 'Error',
'SpecialFileError', 'make_archive',
@@ -3488,7 +3492,7 @@ class PublicAPITests(unittest.TestCase):
target_api.append('disk_usage')
self.assertEqual(set(shutil.__all__), set(target_api))
with self.assertWarns(DeprecationWarning):
- from shutil import ExecError
+ from shutil import ExecError # noqa: F401
if __name__ == '__main__':
diff --git a/Lib/test/test_site.py b/Lib/test/test_site.py
index a7e9241f44d..d0e32942635 100644
--- a/Lib/test/test_site.py
+++ b/Lib/test/test_site.py
@@ -307,8 +307,7 @@ class HelperFunctionsTests(unittest.TestCase):
with EnvironmentVarGuard() as environ:
environ['PYTHONUSERBASE'] = 'xoxo'
- self.assertTrue(site.getuserbase().startswith('xoxo'),
- site.getuserbase())
+ self.assertStartsWith(site.getuserbase(), 'xoxo')
@unittest.skipUnless(HAS_USER_SITE, 'need user site')
def test_getusersitepackages(self):
@@ -318,7 +317,7 @@ class HelperFunctionsTests(unittest.TestCase):
# the call sets USER_BASE *and* USER_SITE
self.assertEqual(site.USER_SITE, user_site)
- self.assertTrue(user_site.startswith(site.USER_BASE), user_site)
+ self.assertStartsWith(user_site, site.USER_BASE)
self.assertEqual(site.USER_BASE, site.getuserbase())
def test_getsitepackages(self):
@@ -359,11 +358,10 @@ class HelperFunctionsTests(unittest.TestCase):
environ.unset('PYTHONUSERBASE', 'APPDATA')
user_base = site.getuserbase()
- self.assertTrue(user_base.startswith('~' + os.sep),
- user_base)
+ self.assertStartsWith(user_base, '~' + os.sep)
user_site = site.getusersitepackages()
- self.assertTrue(user_site.startswith(user_base), user_site)
+ self.assertStartsWith(user_site, user_base)
with mock.patch('os.path.isdir', return_value=False) as mock_isdir, \
mock.patch.object(site, 'addsitedir') as mock_addsitedir, \
@@ -495,18 +493,18 @@ class ImportSideEffectTests(unittest.TestCase):
def test_setting_quit(self):
# 'quit' and 'exit' should be injected into builtins
- self.assertTrue(hasattr(builtins, "quit"))
- self.assertTrue(hasattr(builtins, "exit"))
+ self.assertHasAttr(builtins, "quit")
+ self.assertHasAttr(builtins, "exit")
def test_setting_copyright(self):
# 'copyright', 'credits', and 'license' should be in builtins
- self.assertTrue(hasattr(builtins, "copyright"))
- self.assertTrue(hasattr(builtins, "credits"))
- self.assertTrue(hasattr(builtins, "license"))
+ self.assertHasAttr(builtins, "copyright")
+ self.assertHasAttr(builtins, "credits")
+ self.assertHasAttr(builtins, "license")
def test_setting_help(self):
# 'help' should be set in builtins
- self.assertTrue(hasattr(builtins, "help"))
+ self.assertHasAttr(builtins, "help")
def test_sitecustomize_executed(self):
# If sitecustomize is available, it should have been imported.
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index ace97ce0cbe..3dd67b2a2ab 100644
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -2,8 +2,9 @@ import unittest
from unittest import mock
from test import support
from test.support import (
- is_apple, os_helper, refleak_helper, socket_helper, threading_helper
+ cpython_only, is_apple, os_helper, refleak_helper, socket_helper, threading_helper
)
+from test.support.import_helper import ensure_lazy_imports
import _thread as thread
import array
import contextlib
@@ -257,6 +258,12 @@ HAVE_SOCKET_HYPERV = _have_socket_hyperv()
# Size in bytes of the int type
SIZEOF_INT = array.array("i").itemsize
+class TestLazyImport(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("socket", {"array", "selectors"})
+
+
class SocketTCPTest(unittest.TestCase):
def setUp(self):
@@ -1078,9 +1085,7 @@ class GeneralModuleTests(unittest.TestCase):
'IPV6_USE_MIN_MTU',
}
for opt in opts:
- self.assertTrue(
- hasattr(socket, opt), f"Missing RFC3542 socket option '{opt}'"
- )
+ self.assertHasAttr(socket, opt)
def testHostnameRes(self):
# Testing hostname resolution mechanisms
@@ -1586,11 +1591,11 @@ class GeneralModuleTests(unittest.TestCase):
@unittest.skipUnless(os.name == "nt", "Windows specific")
def test_sock_ioctl(self):
- self.assertTrue(hasattr(socket.socket, 'ioctl'))
- self.assertTrue(hasattr(socket, 'SIO_RCVALL'))
- self.assertTrue(hasattr(socket, 'RCVALL_ON'))
- self.assertTrue(hasattr(socket, 'RCVALL_OFF'))
- self.assertTrue(hasattr(socket, 'SIO_KEEPALIVE_VALS'))
+ self.assertHasAttr(socket.socket, 'ioctl')
+ self.assertHasAttr(socket, 'SIO_RCVALL')
+ self.assertHasAttr(socket, 'RCVALL_ON')
+ self.assertHasAttr(socket, 'RCVALL_OFF')
+ self.assertHasAttr(socket, 'SIO_KEEPALIVE_VALS')
s = socket.socket()
self.addCleanup(s.close)
self.assertRaises(ValueError, s.ioctl, -1, None)
@@ -6075,10 +6080,10 @@ class UDPLITETimeoutTest(SocketUDPLITETest):
class TestExceptions(unittest.TestCase):
def testExceptionTree(self):
- self.assertTrue(issubclass(OSError, Exception))
- self.assertTrue(issubclass(socket.herror, OSError))
- self.assertTrue(issubclass(socket.gaierror, OSError))
- self.assertTrue(issubclass(socket.timeout, OSError))
+ self.assertIsSubclass(OSError, Exception)
+ self.assertIsSubclass(socket.herror, OSError)
+ self.assertIsSubclass(socket.gaierror, OSError)
+ self.assertIsSubclass(socket.timeout, OSError)
self.assertIs(socket.error, OSError)
self.assertIs(socket.timeout, TimeoutError)
diff --git a/Lib/test/test_source_encoding.py b/Lib/test/test_source_encoding.py
index 61b00778f83..1399f3fcd2d 100644
--- a/Lib/test/test_source_encoding.py
+++ b/Lib/test/test_source_encoding.py
@@ -145,8 +145,7 @@ class MiscSourceEncodingTest(unittest.TestCase):
compile(input, "<string>", "exec")
expected = "'ascii' codec can't decode byte 0xe2 in position 16: " \
"ordinal not in range(128)"
- self.assertTrue(c.exception.args[0].startswith(expected),
- msg=c.exception.args[0])
+ self.assertStartsWith(c.exception.args[0], expected)
def test_file_parse_error_multiline(self):
# gh96611:
diff --git a/Lib/test/test_sqlite3/test_cli.py b/Lib/test/test_sqlite3/test_cli.py
index dcd90d11d46..720fa3c4c1e 100644
--- a/Lib/test/test_sqlite3/test_cli.py
+++ b/Lib/test/test_sqlite3/test_cli.py
@@ -1,12 +1,26 @@
"""sqlite3 CLI tests."""
import sqlite3
+import sys
+import textwrap
import unittest
+import unittest.mock
+import os
from sqlite3.__main__ import main as cli
+from test.support.import_helper import import_module
from test.support.os_helper import TESTFN, unlink
-from test.support import captured_stdout, captured_stderr, captured_stdin
+from test.support.pty_helper import run_pty
+from test.support import (
+ captured_stdout,
+ captured_stderr,
+ captured_stdin,
+ force_not_colorized_test_class,
+ requires_subprocess,
+ verbose,
+)
+@force_not_colorized_test_class
class CommandLineInterface(unittest.TestCase):
def _do_test(self, *args, expect_success=True):
@@ -63,6 +77,7 @@ class CommandLineInterface(unittest.TestCase):
self.assertIn("(0,)", out)
+@force_not_colorized_test_class
class InteractiveSession(unittest.TestCase):
MEMORY_DB_MSG = "Connected to a transient in-memory database"
PS1 = "sqlite> "
@@ -110,6 +125,38 @@ class InteractiveSession(unittest.TestCase):
self.assertEqual(out.count(self.PS2), 0)
self.assertIn(sqlite3.sqlite_version, out)
+ def test_interact_empty_source(self):
+ out, err = self.run_cli(commands=("", " "))
+ self.assertIn(self.MEMORY_DB_MSG, err)
+ self.assertEndsWith(out, self.PS1)
+ self.assertEqual(out.count(self.PS1), 3)
+ self.assertEqual(out.count(self.PS2), 0)
+
+ def test_interact_dot_commands_unknown(self):
+ out, err = self.run_cli(commands=(".unknown_command", ))
+ self.assertIn(self.MEMORY_DB_MSG, err)
+ self.assertEndsWith(out, self.PS1)
+ self.assertEqual(out.count(self.PS1), 2)
+ self.assertEqual(out.count(self.PS2), 0)
+ self.assertIn('Error: unknown command: "', err)
+ # test "unknown_command" is pointed out in the error message
+ self.assertIn("unknown_command", err)
+
+ def test_interact_dot_commands_empty(self):
+ out, err = self.run_cli(commands=("."))
+ self.assertIn(self.MEMORY_DB_MSG, err)
+ self.assertEndsWith(out, self.PS1)
+ self.assertEqual(out.count(self.PS1), 2)
+ self.assertEqual(out.count(self.PS2), 0)
+
+ def test_interact_dot_commands_with_whitespaces(self):
+ out, err = self.run_cli(commands=(".version ", ". version"))
+ self.assertIn(self.MEMORY_DB_MSG, err)
+ self.assertEqual(out.count(sqlite3.sqlite_version + "\n"), 2)
+ self.assertEndsWith(out, self.PS1)
+ self.assertEqual(out.count(self.PS1), 3)
+ self.assertEqual(out.count(self.PS2), 0)
+
def test_interact_valid_sql(self):
out, err = self.run_cli(commands=("SELECT 1;",))
self.assertIn(self.MEMORY_DB_MSG, err)
@@ -152,6 +199,117 @@ class InteractiveSession(unittest.TestCase):
out, _ = self.run_cli(TESTFN, commands=("SELECT count(t) FROM t;",))
self.assertIn("(0,)\n", out)
+ def test_color(self):
+ with unittest.mock.patch("_colorize.can_colorize", return_value=True):
+ out, err = self.run_cli(commands="TEXT\n")
+ self.assertIn("\x1b[1;35msqlite> \x1b[0m", out)
+ self.assertIn("\x1b[1;35m ... \x1b[0m\x1b", out)
+ out, err = self.run_cli(commands=("sel;",))
+ self.assertIn('\x1b[1;35mOperationalError (SQLITE_ERROR)\x1b[0m: '
+ '\x1b[35mnear "sel": syntax error\x1b[0m', err)
+
+
+@requires_subprocess()
+@force_not_colorized_test_class
+class Completion(unittest.TestCase):
+ PS1 = "sqlite> "
+
+ @classmethod
+ def setUpClass(cls):
+ _sqlite3 = import_module("_sqlite3")
+ if not hasattr(_sqlite3, "SQLITE_KEYWORDS"):
+ raise unittest.SkipTest("unable to determine SQLite keywords")
+
+ readline = import_module("readline")
+ if readline.backend == "editline":
+ raise unittest.SkipTest("libedit readline is not supported")
+
+ def write_input(self, input_, env=None):
+ script = textwrap.dedent("""
+ import readline
+ from sqlite3.__main__ import main
+
+ readline.parse_and_bind("set colored-completion-prefix off")
+ main()
+ """)
+ return run_pty(script, input_, env)
+
+ def test_complete_sql_keywords(self):
+ # List candidates starting with 'S', there should be multiple matches.
+ input_ = b"S\t\tEL\t 1;\n.quit\n"
+ output = self.write_input(input_)
+ self.assertIn(b"SELECT", output)
+ self.assertIn(b"SET", output)
+ self.assertIn(b"SAVEPOINT", output)
+ self.assertIn(b"(1,)", output)
+
+ # Keywords are completed in upper case even for lower case user input.
+ input_ = b"sel\t\t 1;\n.quit\n"
+ output = self.write_input(input_)
+ self.assertIn(b"SELECT", output)
+ self.assertIn(b"(1,)", output)
+
+ @unittest.skipIf(sys.platform.startswith("freebsd"),
+ "Two actual tabs are inserted when there are no matching"
+ " completions in the pseudo-terminal opened by run_pty()"
+ " on FreeBSD")
+ def test_complete_no_match(self):
+ input_ = b"xyzzy\t\t\b\b\b\b\b\b\b.quit\n"
+ # Set NO_COLOR to disable coloring for self.PS1.
+ output = self.write_input(input_, env={**os.environ, "NO_COLOR": "1"})
+ lines = output.decode().splitlines()
+ indices = (
+ i for i, line in enumerate(lines, 1)
+ if line.startswith(f"{self.PS1}xyzzy")
+ )
+ line_num = next(indices, -1)
+ self.assertNotEqual(line_num, -1)
+ # Completions occupy lines; assert that there are no extra lines when
+ # there is nothing to complete.
+ self.assertEqual(line_num, len(lines))
+
+ def test_complete_no_input(self):
+ from _sqlite3 import SQLITE_KEYWORDS
+
+ script = textwrap.dedent("""
+ import readline
+ from sqlite3.__main__ import main
+
+ # Configure readline to ...:
+ # - hide control sequences surrounding each candidate
+ # - hide "Display all xxx possibilities? (y or n)"
+ # - hide "--More--"
+ # - show candidates one per line
+ readline.parse_and_bind("set colored-completion-prefix off")
+ readline.parse_and_bind("set colored-stats off")
+ readline.parse_and_bind("set completion-query-items 0")
+ readline.parse_and_bind("set page-completions off")
+ readline.parse_and_bind("set completion-display-width 0")
+ readline.parse_and_bind("set show-all-if-ambiguous off")
+ readline.parse_and_bind("set show-all-if-unmodified off")
+
+ main()
+ """)
+ input_ = b"\t\t.quit\n"
+ output = run_pty(script, input_, env={**os.environ, "NO_COLOR": "1"})
+ try:
+ lines = output.decode().splitlines()
+ indices = [
+ i for i, line in enumerate(lines)
+ if line.startswith(self.PS1)
+ ]
+ self.assertEqual(len(indices), 2)
+ start, end = indices
+ candidates = [l.strip() for l in lines[start+1:end]]
+ self.assertEqual(candidates, sorted(SQLITE_KEYWORDS))
+ except:
+ if verbose:
+ print(' PTY output: '.center(30, '-'))
+ print(output.decode(errors='replace'))
+ print(' end PTY output '.center(30, '-'))
+ raise
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py
index c3aa3bf2d7b..291e0356253 100644
--- a/Lib/test/test_sqlite3/test_dbapi.py
+++ b/Lib/test/test_sqlite3/test_dbapi.py
@@ -550,17 +550,9 @@ class ConnectionTests(unittest.TestCase):
cx.execute("insert into u values(0)")
def test_connect_positional_arguments(self):
- regex = (
- r"Passing more than 1 positional argument to sqlite3.connect\(\)"
- " is deprecated. Parameters 'timeout', 'detect_types', "
- "'isolation_level', 'check_same_thread', 'factory', "
- "'cached_statements' and 'uri' will become keyword-only "
- "parameters in Python 3.15."
- )
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
- cx = sqlite.connect(":memory:", 1.0)
- cx.close()
- self.assertEqual(cm.filename, __file__)
+ with self.assertRaisesRegex(TypeError,
+ r'connect\(\) takes at most 1 positional arguments'):
+ sqlite.connect(":memory:", 1.0)
def test_connection_resource_warning(self):
with self.assertWarns(ResourceWarning):
diff --git a/Lib/test/test_sqlite3/test_factory.py b/Lib/test/test_sqlite3/test_factory.py
index cc9f1ec5c4b..776659e3b16 100644
--- a/Lib/test/test_sqlite3/test_factory.py
+++ b/Lib/test/test_sqlite3/test_factory.py
@@ -71,18 +71,9 @@ class ConnectionFactoryTests(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(Factory, self).__init__(*args, **kwargs)
- regex = (
- r"Passing more than 1 positional argument to _sqlite3.Connection\(\) "
- r"is deprecated. Parameters 'timeout', 'detect_types', "
- r"'isolation_level', 'check_same_thread', 'factory', "
- r"'cached_statements' and 'uri' will become keyword-only "
- r"parameters in Python 3.15."
- )
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
- with memory_database(5.0, 0, None, True, Factory) as con:
- self.assertIsNone(con.isolation_level)
- self.assertIsInstance(con, Factory)
- self.assertEqual(cm.filename, __file__)
+ with self.assertRaisesRegex(TypeError,
+ r'connect\(\) takes at most 1 positional arguments'):
+ memory_database(5.0, 0, None, True, Factory)
class CursorFactoryTests(MemoryDatabaseMixin, unittest.TestCase):
diff --git a/Lib/test/test_sqlite3/test_hooks.py b/Lib/test/test_sqlite3/test_hooks.py
index 53b8a39bf29..2b907e35131 100644
--- a/Lib/test/test_sqlite3/test_hooks.py
+++ b/Lib/test/test_sqlite3/test_hooks.py
@@ -220,16 +220,9 @@ class ProgressTests(MemoryDatabaseMixin, unittest.TestCase):
""")
def test_progress_handler_keyword_args(self):
- regex = (
- r"Passing keyword argument 'progress_handler' to "
- r"_sqlite3.Connection.set_progress_handler\(\) is deprecated. "
- r"Parameter 'progress_handler' will become positional-only in "
- r"Python 3.15."
- )
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
+ with self.assertRaisesRegex(TypeError,
+ 'takes at least 1 positional argument'):
self.con.set_progress_handler(progress_handler=lambda: None, n=1)
- self.assertEqual(cm.filename, __file__)
class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase):
@@ -353,16 +346,9 @@ class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase):
cx.execute("select 1")
def test_trace_keyword_args(self):
- regex = (
- r"Passing keyword argument 'trace_callback' to "
- r"_sqlite3.Connection.set_trace_callback\(\) is deprecated. "
- r"Parameter 'trace_callback' will become positional-only in "
- r"Python 3.15."
- )
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
+ with self.assertRaisesRegex(TypeError,
+ 'takes exactly 1 positional argument'):
self.con.set_trace_callback(trace_callback=lambda: None)
- self.assertEqual(cm.filename, __file__)
if __name__ == "__main__":
diff --git a/Lib/test/test_sqlite3/test_userfunctions.py b/Lib/test/test_sqlite3/test_userfunctions.py
index 3abc43a3b1a..11cf877a011 100644
--- a/Lib/test/test_sqlite3/test_userfunctions.py
+++ b/Lib/test/test_sqlite3/test_userfunctions.py
@@ -422,27 +422,9 @@ class FunctionTests(unittest.TestCase):
self.con.execute, "select badreturn()")
def test_func_keyword_args(self):
- regex = (
- r"Passing keyword arguments 'name', 'narg' and 'func' to "
- r"_sqlite3.Connection.create_function\(\) is deprecated. "
- r"Parameters 'name', 'narg' and 'func' will become "
- r"positional-only in Python 3.15."
- )
-
- def noop():
- return None
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
- self.con.create_function("noop", 0, func=noop)
- self.assertEqual(cm.filename, __file__)
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
- self.con.create_function("noop", narg=0, func=noop)
- self.assertEqual(cm.filename, __file__)
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
- self.con.create_function(name="noop", narg=0, func=noop)
- self.assertEqual(cm.filename, __file__)
+ with self.assertRaisesRegex(TypeError,
+ 'takes exactly 3 positional arguments'):
+ self.con.create_function("noop", 0, func=lambda: None)
class WindowSumInt:
@@ -737,25 +719,9 @@ class AggregateTests(unittest.TestCase):
self.assertEqual(val, txt)
def test_agg_keyword_args(self):
- regex = (
- r"Passing keyword arguments 'name', 'n_arg' and 'aggregate_class' to "
- r"_sqlite3.Connection.create_aggregate\(\) is deprecated. "
- r"Parameters 'name', 'n_arg' and 'aggregate_class' will become "
- r"positional-only in Python 3.15."
- )
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
+ with self.assertRaisesRegex(TypeError,
+ 'takes exactly 3 positional arguments'):
self.con.create_aggregate("test", 1, aggregate_class=AggrText)
- self.assertEqual(cm.filename, __file__)
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
- self.con.create_aggregate("test", n_arg=1, aggregate_class=AggrText)
- self.assertEqual(cm.filename, __file__)
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
- self.con.create_aggregate(name="test", n_arg=0,
- aggregate_class=AggrText)
- self.assertEqual(cm.filename, __file__)
class AuthorizerTests(unittest.TestCase):
@@ -800,16 +766,9 @@ class AuthorizerTests(unittest.TestCase):
self.con.execute("select c2 from t1")
def test_authorizer_keyword_args(self):
- regex = (
- r"Passing keyword argument 'authorizer_callback' to "
- r"_sqlite3.Connection.set_authorizer\(\) is deprecated. "
- r"Parameter 'authorizer_callback' will become positional-only in "
- r"Python 3.15."
- )
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
+ with self.assertRaisesRegex(TypeError,
+ 'takes exactly 1 positional argument'):
self.con.set_authorizer(authorizer_callback=lambda: None)
- self.assertEqual(cm.filename, __file__)
class AuthorizerRaiseExceptionTests(AuthorizerTests):
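The sqlite3 hunks above all encode the same transition: the 3.14 DeprecationWarnings for extra positional arguments and for keyword callback parameters become hard TypeErrors. A minimal sketch of the behaviour the updated assertions expect (assuming a build where the deprecations have been turned into errors; on older interpreters the first call still succeeds and merely warns):

    import sqlite3

    try:
        cx = sqlite3.connect(":memory:", 1.0)
    except TypeError as exc:
        print(exc)          # connect() takes at most 1 positional arguments
    else:
        cx.close()

    cx = sqlite3.connect(":memory:")
    try:
        cx.create_function("noop", 0, func=lambda: None)
    except TypeError as exc:
        print(exc)          # ... takes exactly 3 positional arguments
    finally:
        cx.close()
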
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 395b2ef88ab..f123f6ece40 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -31,6 +31,7 @@ import weakref
import platform
import sysconfig
import functools
+from contextlib import nullcontext
try:
import ctypes
except ImportError:
@@ -539,9 +540,9 @@ class BasicSocketTests(unittest.TestCase):
openssl_ver = f"OpenSSL {major:d}.{minor:d}.{patch:d}"
else:
openssl_ver = f"OpenSSL {major:d}.{minor:d}.{fix:d}"
- self.assertTrue(
- s.startswith((openssl_ver, libressl_ver, "AWS-LC")),
- (s, t, hex(n))
+ self.assertStartsWith(
+ s, (openssl_ver, libressl_ver, "AWS-LC"),
+ (t, hex(n))
)
@support.cpython_only
@@ -1668,7 +1669,7 @@ class SSLErrorTests(unittest.TestCase):
regex = "(NO_START_LINE|UNSUPPORTED_PUBLIC_KEY_TYPE)"
self.assertRegex(cm.exception.reason, regex)
s = str(cm.exception)
- self.assertTrue("NO_START_LINE" in s, s)
+ self.assertIn("NO_START_LINE", s)
def test_subclass(self):
# Check that the appropriate SSLError subclass is raised
@@ -1683,7 +1684,7 @@ class SSLErrorTests(unittest.TestCase):
with self.assertRaises(ssl.SSLWantReadError) as cm:
c.do_handshake()
s = str(cm.exception)
- self.assertTrue(s.startswith("The operation did not complete (read)"), s)
+ self.assertStartsWith(s, "The operation did not complete (read)")
# For compatibility
self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ)
@@ -2843,6 +2844,7 @@ class ThreadedTests(unittest.TestCase):
# See GH-124984: OpenSSL is not thread safe.
threads = []
+ warnings_filters = sys.flags.context_aware_warnings
global USE_SAME_TEST_CONTEXT
USE_SAME_TEST_CONTEXT = True
try:
@@ -2851,7 +2853,10 @@ class ThreadedTests(unittest.TestCase):
self.test_alpn_protocols,
self.test_getpeercert,
self.test_crl_check,
- self.test_check_hostname_idn,
+ functools.partial(
+ self.test_check_hostname_idn,
+ warnings_filters=warnings_filters,
+ ),
self.test_wrong_cert_tls12,
self.test_wrong_cert_tls13,
):
@@ -3097,7 +3102,7 @@ class ThreadedTests(unittest.TestCase):
cipher = s.cipher()[0].split('-')
self.assertTrue(cipher[:2], ('ECDHE', 'ECDSA'))
- def test_check_hostname_idn(self):
+ def test_check_hostname_idn(self, warnings_filters=True):
if support.verbose:
sys.stdout.write("\n")
@@ -3152,16 +3157,30 @@ class ThreadedTests(unittest.TestCase):
server_hostname="python.example.org") as s:
with self.assertRaises(ssl.CertificateError):
s.connect((HOST, server.port))
- with ThreadedEchoServer(context=server_context, chatty=True) as server:
- with warnings_helper.check_no_resource_warning(self):
- with self.assertRaises(UnicodeError):
- context.wrap_socket(socket.socket(),
- server_hostname='.pythontest.net')
- with ThreadedEchoServer(context=server_context, chatty=True) as server:
- with warnings_helper.check_no_resource_warning(self):
- with self.assertRaises(UnicodeDecodeError):
- context.wrap_socket(socket.socket(),
- server_hostname=b'k\xf6nig.idn.pythontest.net')
+ with (
+ ThreadedEchoServer(context=server_context, chatty=True) as server,
+ (
+ warnings_helper.check_no_resource_warning(self)
+ if warnings_filters
+ else nullcontext()
+ ),
+ self.assertRaises(UnicodeError),
+ ):
+ context.wrap_socket(socket.socket(), server_hostname='.pythontest.net')
+
+ with (
+ ThreadedEchoServer(context=server_context, chatty=True) as server,
+ (
+ warnings_helper.check_no_resource_warning(self)
+ if warnings_filters
+ else nullcontext()
+ ),
+ self.assertRaises(UnicodeDecodeError),
+ ):
+ context.wrap_socket(
+ socket.socket(),
+ server_hostname=b'k\xf6nig.idn.pythontest.net',
+ )
def test_wrong_cert_tls12(self):
"""Connecting when the server rejects the client's certificate
@@ -4488,6 +4507,7 @@ class ThreadedTests(unittest.TestCase):
@requires_tls_version('TLSv1_3')
@unittest.skipUnless(ssl.HAS_PSK, 'TLS-PSK disabled on this OpenSSL build')
+ @unittest.skipUnless(ssl.HAS_PSK_TLS13, 'TLS 1.3 PSK disabled on this OpenSSL build')
def test_psk_tls1_3(self):
psk = bytes.fromhex('deadbeef')
identity_hint = 'identity-hint'
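The rewritten test_check_hostname_idn chooses its warning check at run time: a real context manager when warning filters are context-aware, contextlib.nullcontext() otherwise. The same pattern in isolation (names here are illustrative, not from the patch):

    from contextlib import nullcontext
    import warnings

    def maybe_catch(enabled):
        # Pick the real context manager when the check applies, a no-op
        # nullcontext() otherwise -- the same shape as the test above.
        return warnings.catch_warnings() if enabled else nullcontext()

    for enabled in (True, False):
        with maybe_catch(enabled):
            pass  # body is identical either way; only the wrapper differs
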
diff --git a/Lib/test/test_stable_abi_ctypes.py b/Lib/test/test_stable_abi_ctypes.py
index 1e6f69d49e9..5a6ba9de337 100644
--- a/Lib/test/test_stable_abi_ctypes.py
+++ b/Lib/test/test_stable_abi_ctypes.py
@@ -658,7 +658,11 @@ SYMBOL_NAMES = (
"PySys_AuditTuple",
"PySys_FormatStderr",
"PySys_FormatStdout",
+ "PySys_GetAttr",
+ "PySys_GetAttrString",
"PySys_GetObject",
+ "PySys_GetOptionalAttr",
+ "PySys_GetOptionalAttrString",
"PySys_GetXOptions",
"PySys_HasWarnOptions",
"PySys_ResetWarnOptions",
diff --git a/Lib/test/test_stat.py b/Lib/test/test_stat.py
index 49013a4bcd8..5fd25d5012c 100644
--- a/Lib/test/test_stat.py
+++ b/Lib/test/test_stat.py
@@ -157,7 +157,7 @@ class TestFilemode:
os.chmod(TESTFN, 0o700)
st_mode, modestr = self.get_mode()
- self.assertEqual(modestr[:3], '-rw')
+ self.assertStartsWith(modestr, '-rw')
self.assertS_IS("REG", st_mode)
self.assertEqual(self.statmod.S_IFMT(st_mode),
self.statmod.S_IFREG)
@@ -256,7 +256,7 @@ class TestFilemode:
"FILE_ATTRIBUTE_* constants are Win32 specific")
def test_file_attribute_constants(self):
for key, value in sorted(self.file_attributes.items()):
- self.assertTrue(hasattr(self.statmod, key), key)
+ self.assertHasAttr(self.statmod, key)
modvalue = getattr(self.statmod, key)
self.assertEqual(value, modvalue, key)
@@ -314,7 +314,7 @@ class TestFilemode:
self.assertEqual(self.statmod.S_ISGID, 0o002000)
self.assertEqual(self.statmod.S_ISVTX, 0o001000)
- self.assertFalse(hasattr(self.statmod, "S_ISTXT"))
+ self.assertNotHasAttr(self.statmod, "S_ISTXT")
self.assertEqual(self.statmod.S_IREAD, self.statmod.S_IRUSR)
self.assertEqual(self.statmod.S_IWRITE, self.statmod.S_IWUSR)
self.assertEqual(self.statmod.S_IEXEC, self.statmod.S_IXUSR)
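Many hunks in this patch swap assertTrue()/assertFalse(hasattr(...)) idioms for the newer assert helpers. A small self-contained sketch of those helpers, assuming a unittest version that provides them (3.14+):

    import unittest

    class NewAssertsDemo(unittest.TestCase):
        def test_helpers(self):
            # Assumed 3.14+ helpers; they report clearer failure messages
            # than the assertTrue()/hasattr() forms they replace.
            self.assertStartsWith("-rw-r--r--", "-rw")
            self.assertEndsWith("ValueError: 42\n", "42\n")
            self.assertHasAttr(unittest, "TestCase")
            self.assertIsSubclass(bool, int)

    if __name__ == "__main__":
        unittest.main()
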
diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py
index c69baa4bf4d..8250b0aef09 100644
--- a/Lib/test/test_statistics.py
+++ b/Lib/test/test_statistics.py
@@ -645,7 +645,7 @@ class TestNumericTestCase(unittest.TestCase):
def test_numerictestcase_is_testcase(self):
# Ensure that NumericTestCase actually is a TestCase.
- self.assertTrue(issubclass(NumericTestCase, unittest.TestCase))
+ self.assertIsSubclass(NumericTestCase, unittest.TestCase)
def test_error_msg_numeric(self):
# Test the error message generated for numeric comparisons.
@@ -683,32 +683,23 @@ class GlobalsTest(unittest.TestCase):
def test_meta(self):
# Test for the existence of metadata.
for meta in self.expected_metadata:
- self.assertTrue(hasattr(self.module, meta),
- "%s not present" % meta)
+ self.assertHasAttr(self.module, meta)
def test_check_all(self):
# Check everything in __all__ exists and is public.
module = self.module
for name in module.__all__:
# No private names in __all__:
- self.assertFalse(name.startswith("_"),
+ self.assertNotStartsWith(name, "_",
'private name "%s" in __all__' % name)
# And anything in __all__ must exist:
- self.assertTrue(hasattr(module, name),
- 'missing name "%s" in __all__' % name)
+ self.assertHasAttr(module, name)
class StatisticsErrorTest(unittest.TestCase):
def test_has_exception(self):
- errmsg = (
- "Expected StatisticsError to be a ValueError, but got a"
- " subclass of %r instead."
- )
- self.assertTrue(hasattr(statistics, 'StatisticsError'))
- self.assertTrue(
- issubclass(statistics.StatisticsError, ValueError),
- errmsg % statistics.StatisticsError.__base__
- )
+ self.assertHasAttr(statistics, 'StatisticsError')
+ self.assertIsSubclass(statistics.StatisticsError, ValueError)
# === Tests for private utility functions ===
@@ -2355,6 +2346,7 @@ class TestGeometricMean(unittest.TestCase):
class TestKDE(unittest.TestCase):
+ @support.requires_resource('cpu')
def test_kde(self):
kde = statistics.kde
StatisticsError = statistics.StatisticsError
@@ -3327,7 +3319,8 @@ class TestNormalDistC(unittest.TestCase, TestNormalDist):
def load_tests(loader, tests, ignore):
"""Used for doctest/unittest integration."""
tests.addTests(doctest.DocTestSuite())
- tests.addTests(doctest.DocTestSuite(statistics))
+ if sys.float_repr_style == 'short':
+ tests.addTests(doctest.DocTestSuite(statistics))
return tests
diff --git a/Lib/test/test_str.py b/Lib/test/test_str.py
index d6a7bd0da59..2584fbf72d3 100644
--- a/Lib/test/test_str.py
+++ b/Lib/test/test_str.py
@@ -1231,10 +1231,10 @@ class StrTest(string_tests.StringLikeTest,
self.assertEqual('{0:\x00^6}'.format(3), '\x00\x003\x00\x00\x00')
self.assertEqual('{0:<6}'.format(3), '3 ')
- self.assertEqual('{0:\x00<6}'.format(3.14), '3.14\x00\x00')
- self.assertEqual('{0:\x01<6}'.format(3.14), '3.14\x01\x01')
- self.assertEqual('{0:\x00^6}'.format(3.14), '\x003.14\x00')
- self.assertEqual('{0:^6}'.format(3.14), ' 3.14 ')
+ self.assertEqual('{0:\x00<6}'.format(3.25), '3.25\x00\x00')
+ self.assertEqual('{0:\x01<6}'.format(3.25), '3.25\x01\x01')
+ self.assertEqual('{0:\x00^6}'.format(3.25), '\x003.25\x00')
+ self.assertEqual('{0:^6}'.format(3.25), ' 3.25 ')
self.assertEqual('{0:\x00<12}'.format(3+2.0j), '(3+2j)\x00\x00\x00\x00\x00\x00')
self.assertEqual('{0:\x01<12}'.format(3+2.0j), '(3+2j)\x01\x01\x01\x01\x01\x01')
diff --git a/Lib/test/test_strftime.py b/Lib/test/test_strftime.py
index 752e31359cf..375f6aaedd8 100644
--- a/Lib/test/test_strftime.py
+++ b/Lib/test/test_strftime.py
@@ -39,7 +39,21 @@ class StrftimeTest(unittest.TestCase):
if now[3] < 12: self.ampm='(AM|am)'
else: self.ampm='(PM|pm)'
- self.jan1 = time.localtime(time.mktime((now[0], 1, 1, 0, 0, 0, 0, 1, 0)))
+ jan1 = time.struct_time(
+ (
+ now.tm_year, # Year
+ 1, # Month (January)
+ 1, # Day (1st)
+ 0, # Hour (0)
+ 0, # Minute (0)
+ 0, # Second (0)
+ -1, # tm_wday (will be determined)
+ 1, # tm_yday (day 1 of the year)
+ -1, # tm_isdst (let the system determine)
+ )
+ )
+ # use mktime to get the correct tm_wday and tm_isdst values
+ self.jan1 = time.localtime(time.mktime(jan1))
try:
if now[8]: self.tz = time.tzname[1]
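The new setUp code builds a struct_time with sentinel fields and lets the mktime()/localtime() round trip resolve them. The same normalization as a standalone sketch:

    import time

    # Unknown weekday/DST are passed as -1 and filled in by the round trip.
    jan1 = time.struct_time((2025, 1, 1, 0, 0, 0, -1, 1, -1))
    jan1 = time.localtime(time.mktime(jan1))
    print(jan1.tm_wday, jan1.tm_isdst)
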
diff --git a/Lib/test/test_string/__init__.py b/Lib/test/test_string/__init__.py
new file mode 100644
index 00000000000..4b16ecc3115
--- /dev/null
+++ b/Lib/test/test_string/__init__.py
@@ -0,0 +1,5 @@
+import os
+from test.support import load_package_tests
+
+def load_tests(*args):
+ return load_package_tests(os.path.dirname(__file__), *args)
diff --git a/Lib/test/test_string/_support.py b/Lib/test/test_string/_support.py
new file mode 100644
index 00000000000..abdddaf187b
--- /dev/null
+++ b/Lib/test/test_string/_support.py
@@ -0,0 +1,54 @@
+from string.templatelib import Interpolation
+
+
+class TStringBaseCase:
+ def assertTStringEqual(self, t, strings, interpolations):
+ """Test template string literal equality.
+
+ The *strings* argument must be a tuple of strings equal to *t.strings*.
+
+ The *interpolations* argument must be a sequence of tuples which are
+ compared against *t.interpolations*. Each tuple consists of
+ (value, expression, conversion, format_spec), though the final two
+ items may be omitted, and are assumed to be None and '' respectively.
+ """
+ self.assertEqual(t.strings, strings)
+ self.assertEqual(len(t.interpolations), len(interpolations))
+
+ for i, exp in zip(t.interpolations, interpolations, strict=True):
+ if len(exp) == 4:
+ actual = (i.value, i.expression, i.conversion, i.format_spec)
+ self.assertEqual(actual, exp)
+ continue
+
+ if len(exp) == 3:
+ self.assertEqual((i.value, i.expression, i.conversion), exp)
+ self.assertEqual(i.format_spec, '')
+ continue
+
+ self.assertEqual((i.value, i.expression), exp)
+ self.assertEqual(i.format_spec, '')
+ self.assertIsNone(i.conversion)
+
+
+def convert(value, conversion):
+ if conversion == "a":
+ return ascii(value)
+ elif conversion == "r":
+ return repr(value)
+ elif conversion == "s":
+ return str(value)
+ return value
+
+
+def fstring(template):
+ parts = []
+ for item in template:
+ match item:
+ case str() as s:
+ parts.append(s)
+ case Interpolation(value, _, conversion, format_spec):
+ value = convert(value, conversion)
+ value = format(value, format_spec)
+ parts.append(value)
+ return "".join(parts)
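A hedged usage sketch of the fstring() helper defined above (assumes string.templatelib is available and the CPython test package is importable); it mirrors the interleaving test further down:

    from string.templatelib import Template, Interpolation
    from test.test_string._support import fstring   # the helper defined above

    t = Template('Hello, ', Interpolation('Maria', 'name', None, ''), '!')
    print(fstring(t))   # Hello, Maria!
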
diff --git a/Lib/test/test_string.py b/Lib/test/test_string/test_string.py
index f6d112d8a93..5394fe4e12c 100644
--- a/Lib/test/test_string.py
+++ b/Lib/test/test_string/test_string.py
@@ -2,6 +2,14 @@ import unittest
import string
from string import Template
import types
+from test.support import cpython_only
+from test.support.import_helper import ensure_lazy_imports
+
+
+class LazyImportTest(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("base64", {"re", "collections"})
class ModuleTest(unittest.TestCase):
diff --git a/Lib/test/test_string/test_templatelib.py b/Lib/test/test_string/test_templatelib.py
new file mode 100644
index 00000000000..85fcff486d6
--- /dev/null
+++ b/Lib/test/test_string/test_templatelib.py
@@ -0,0 +1,160 @@
+import pickle
+import unittest
+from collections.abc import Iterator, Iterable
+from string.templatelib import Template, Interpolation
+
+from test.test_string._support import TStringBaseCase, fstring
+
+
+class TestTemplate(unittest.TestCase, TStringBaseCase):
+
+ def test_common(self):
+ self.assertEqual(type(t'').__name__, 'Template')
+ self.assertEqual(type(t'').__qualname__, 'Template')
+ self.assertEqual(type(t'').__module__, 'string.templatelib')
+
+ a = 'a'
+ i = t'{a}'.interpolations[0]
+ self.assertEqual(type(i).__name__, 'Interpolation')
+ self.assertEqual(type(i).__qualname__, 'Interpolation')
+ self.assertEqual(type(i).__module__, 'string.templatelib')
+
+ def test_final_types(self):
+ with self.assertRaisesRegex(TypeError, 'is not an acceptable base type'):
+ class Sub(Template): ...
+
+ with self.assertRaisesRegex(TypeError, 'is not an acceptable base type'):
+ class Sub(Interpolation): ...
+
+ def test_basic_creation(self):
+ # Simple t-string creation
+ t = t'Hello, world'
+ self.assertIsInstance(t, Template)
+ self.assertTStringEqual(t, ('Hello, world',), ())
+ self.assertEqual(fstring(t), 'Hello, world')
+
+ # Empty t-string
+ t = t''
+ self.assertTStringEqual(t, ('',), ())
+ self.assertEqual(fstring(t), '')
+
+ # Multi-line t-string
+ t = t"""Hello,
+world"""
+ self.assertEqual(t.strings, ('Hello,\nworld',))
+ self.assertEqual(len(t.interpolations), 0)
+ self.assertEqual(fstring(t), 'Hello,\nworld')
+
+ def test_creation_interleaving(self):
+ # Should add strings on either side
+ t = Template(Interpolation('Maria', 'name', None, ''))
+ self.assertTStringEqual(t, ('', ''), [('Maria', 'name')])
+ self.assertEqual(fstring(t), 'Maria')
+
+ # Should prepend empty string
+ t = Template(Interpolation('Maria', 'name', None, ''), ' is my name')
+ self.assertTStringEqual(t, ('', ' is my name'), [('Maria', 'name')])
+ self.assertEqual(fstring(t), 'Maria is my name')
+
+ # Should append empty string
+ t = Template('Hello, ', Interpolation('Maria', 'name', None, ''))
+ self.assertTStringEqual(t, ('Hello, ', ''), [('Maria', 'name')])
+ self.assertEqual(fstring(t), 'Hello, Maria')
+
+ # Should concatenate strings
+ t = Template('Hello', ', ', Interpolation('Maria', 'name', None, ''),
+ '!')
+ self.assertTStringEqual(t, ('Hello, ', '!'), [('Maria', 'name')])
+ self.assertEqual(fstring(t), 'Hello, Maria!')
+
+ # Should add strings on either side and in between
+ t = Template(Interpolation('Maria', 'name', None, ''),
+ Interpolation('Python', 'language', None, ''))
+ self.assertTStringEqual(
+ t, ('', '', ''), [('Maria', 'name'), ('Python', 'language')]
+ )
+ self.assertEqual(fstring(t), 'MariaPython')
+
+ def test_template_values(self):
+ t = t'Hello, world'
+ self.assertEqual(t.values, ())
+
+ name = "Lys"
+ t = t'Hello, {name}'
+ self.assertEqual(t.values, ("Lys",))
+
+ country = "GR"
+ age = 0
+ t = t'Hello, {name}, {age} from {country}'
+ self.assertEqual(t.values, ("Lys", 0, "GR"))
+
+ def test_pickle_template(self):
+ user = 'test'
+ for template in (
+ t'',
+ t"No values",
+ t'With inter {user}',
+ t'With ! {user!r}',
+ t'With format {1 / 0.3:.2f}',
+ Template(),
+ Template('a'),
+ Template(Interpolation('Nikita', 'name', None, '')),
+ Template('a', Interpolation('Nikita', 'name', 'r', '')),
+ ):
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ with self.subTest(proto=proto, template=template):
+ pickled = pickle.dumps(template, protocol=proto)
+ unpickled = pickle.loads(pickled)
+
+ self.assertEqual(unpickled.values, template.values)
+ self.assertEqual(fstring(unpickled), fstring(template))
+
+ def test_pickle_interpolation(self):
+ for interpolation in (
+ Interpolation('Nikita', 'name', None, ''),
+ Interpolation('Nikita', 'name', 'r', ''),
+ Interpolation(1/3, 'x', None, '.2f'),
+ ):
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ with self.subTest(proto=proto, interpolation=interpolation):
+ pickled = pickle.dumps(interpolation, protocol=proto)
+ unpickled = pickle.loads(pickled)
+
+ self.assertEqual(unpickled.value, interpolation.value)
+ self.assertEqual(unpickled.expression, interpolation.expression)
+ self.assertEqual(unpickled.conversion, interpolation.conversion)
+ self.assertEqual(unpickled.format_spec, interpolation.format_spec)
+
+
+class TemplateIterTests(unittest.TestCase):
+ def test_abc(self):
+ self.assertIsInstance(iter(t''), Iterable)
+ self.assertIsInstance(iter(t''), Iterator)
+
+ def test_final(self):
+ TemplateIter = type(iter(t''))
+ with self.assertRaisesRegex(TypeError, 'is not an acceptable base type'):
+ class Sub(TemplateIter): ...
+
+ def test_iter(self):
+ x = 1
+ res = list(iter(t'abc {x} yz'))
+
+ self.assertEqual(res[0], 'abc ')
+ self.assertIsInstance(res[1], Interpolation)
+ self.assertEqual(res[1].value, 1)
+ self.assertEqual(res[1].expression, 'x')
+ self.assertEqual(res[1].conversion, None)
+ self.assertEqual(res[1].format_spec, '')
+ self.assertEqual(res[2], ' yz')
+
+ def test_exhausted(self):
+ # See https://github.com/python/cpython/issues/134119.
+ template_iter = iter(t"{1}")
+ self.assertIsInstance(next(template_iter), Interpolation)
+ self.assertRaises(StopIteration, next, template_iter)
+ self.assertRaises(StopIteration, next, template_iter)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_strptime.py b/Lib/test/test_strptime.py
index 268230f6da7..0241e543cd7 100644
--- a/Lib/test/test_strptime.py
+++ b/Lib/test/test_strptime.py
@@ -221,14 +221,16 @@ class StrptimeTests(unittest.TestCase):
self.assertRaises(ValueError, _strptime._strptime_time, data_string="%d",
format="%A")
for bad_format in ("%", "% ", "%\n"):
- with self.assertRaisesRegex(ValueError, "stray % in format "):
+ with (self.subTest(format=bad_format),
+ self.assertRaisesRegex(ValueError, "stray % in format ")):
_strptime._strptime_time("2005", bad_format)
- for bad_format in ("%e", "%Oe", "%O", "%O ", "%Ee", "%E", "%E ",
- "%.", "%+", "%_", "%~", "%\\",
+ for bad_format in ("%i", "%Oi", "%O", "%O ", "%Ee", "%E", "%E ",
+ "%.", "%+", "%~", "%\\",
"%O.", "%O+", "%O_", "%O~", "%O\\"):
directive = bad_format[1:].rstrip()
- with self.assertRaisesRegex(ValueError,
- f"'{re.escape(directive)}' is a bad directive in format "):
+ with (self.subTest(format=bad_format),
+ self.assertRaisesRegex(ValueError,
+ f"'{re.escape(directive)}' is a bad directive in format ")):
_strptime._strptime_time("2005", bad_format)
msg_week_no_year_or_weekday = r"ISO week directive '%V' must be used with " \
@@ -335,6 +337,15 @@ class StrptimeTests(unittest.TestCase):
self.roundtrip('%B', 1, (1900, m, 1, 0, 0, 0, 0, 1, 0))
self.roundtrip('%b', 1, (1900, m, 1, 0, 0, 0, 0, 1, 0))
+ @run_with_locales('LC_TIME', 'az_AZ', 'ber_DZ', 'ber_MA', 'crh_UA')
+ def test_month_locale2(self):
+ # Test for month directives
+ # Month name contains 'İ' ('\u0130')
+ self.roundtrip('%B', 1, (2025, 6, 1, 0, 0, 0, 6, 152, 0))
+ self.roundtrip('%b', 1, (2025, 6, 1, 0, 0, 0, 6, 152, 0))
+ self.roundtrip('%B', 1, (2025, 7, 1, 0, 0, 0, 1, 182, 0))
+ self.roundtrip('%b', 1, (2025, 7, 1, 0, 0, 0, 1, 182, 0))
+
def test_day(self):
# Test for day directives
self.roundtrip('%d %Y', 2)
@@ -480,13 +491,11 @@ class StrptimeTests(unittest.TestCase):
# * Year is not included: ha_NG.
# * Use non-Gregorian calendar: lo_LA, thai, th_TH.
# On Windows: ar_IN, ar_SA, fa_IR, ps_AF.
- #
- # BUG: Generates regexp that does not match the current date and time
- # for lzh_TW.
@run_with_locales('LC_TIME', 'C', 'en_US', 'fr_FR', 'de_DE', 'ja_JP',
'he_IL', 'eu_ES', 'ar_AE', 'mfe_MU', 'yo_NG',
'csb_PL', 'br_FR', 'gez_ET', 'brx_IN',
- 'my_MM', 'or_IN', 'shn_MM', 'az_IR')
+ 'my_MM', 'or_IN', 'shn_MM', 'az_IR',
+ 'byn_ER', 'wal_ET', 'lzh_TW')
def test_date_time_locale(self):
# Test %c directive
loc = locale.getlocale(locale.LC_TIME)[0]
@@ -525,11 +534,9 @@ class StrptimeTests(unittest.TestCase):
# NB: Does not roundtrip because these locales use a non-Gregorian calendar:
# lo_LA, thai, th_TH. On Windows: ar_IN, ar_SA, fa_IR, ps_AF.
- # BUG: Generates regexp that does not match the current date
- # for lzh_TW.
@run_with_locales('LC_TIME', 'C', 'en_US', 'fr_FR', 'de_DE', 'ja_JP',
'he_IL', 'eu_ES', 'ar_AE',
- 'az_IR', 'my_MM', 'or_IN', 'shn_MM')
+ 'az_IR', 'my_MM', 'or_IN', 'shn_MM', 'lzh_TW')
def test_date_locale(self):
# Test %x directive
now = time.time()
@@ -546,7 +553,7 @@ class StrptimeTests(unittest.TestCase):
# NB: Dates before 1969 do not roundtrip on many locales, including C.
@unittest.skipIf(support.linked_to_musl(), "musl libc issue, bpo-46390")
@run_with_locales('LC_TIME', 'en_US', 'fr_FR', 'de_DE', 'ja_JP',
- 'eu_ES', 'ar_AE', 'my_MM', 'shn_MM')
+ 'eu_ES', 'ar_AE', 'my_MM', 'shn_MM', 'lzh_TW')
def test_date_locale2(self):
# Test %x directive
loc = locale.getlocale(locale.LC_TIME)[0]
@@ -562,11 +569,11 @@ class StrptimeTests(unittest.TestCase):
# norwegian, nynorsk.
# * Hours are in 12-hour notation without AM/PM indication: hy_AM,
# ms_MY, sm_WS.
- # BUG: Generates regexp that does not match the current time for lzh_TW.
@run_with_locales('LC_TIME', 'C', 'en_US', 'fr_FR', 'de_DE', 'ja_JP',
'aa_ET', 'am_ET', 'az_IR', 'byn_ER', 'fa_IR', 'gez_ET',
'my_MM', 'om_ET', 'or_IN', 'shn_MM', 'sid_ET', 'so_SO',
- 'ti_ET', 'tig_ER', 'wal_ET')
+ 'ti_ET', 'tig_ER', 'wal_ET', 'lzh_TW',
+ 'ar_SA', 'bg_BG')
def test_time_locale(self):
# Test %X directive
loc = locale.getlocale(locale.LC_TIME)[0]
diff --git a/Lib/test/test_strtod.py b/Lib/test/test_strtod.py
index 2727514fad4..570de390a95 100644
--- a/Lib/test/test_strtod.py
+++ b/Lib/test/test_strtod.py
@@ -19,7 +19,7 @@ strtod_parser = re.compile(r""" # A numeric string consists of:
(?P<int>\d*) # having a (possibly empty) integer part
(?:\.(?P<frac>\d*))? # followed by an optional fractional part
(?:E(?P<exp>[-+]?\d+))? # and an optional exponent
- \Z
+ \z
""", re.VERBOSE | re.IGNORECASE).match
# Pure Python version of correctly rounded string->float conversion.
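The anchor change above swaps \Z for \z. A short sketch of why an explicit end-of-string anchor matters here, assuming \z behaves like \Z (no trailing-newline allowance, unlike $):

    import re

    print(bool(re.search(r"\d$", "42\n")))    # True: $ allows a final newline
    print(bool(re.search(r"\d\Z", "42\n")))   # False: \Z is end-of-string only
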
diff --git a/Lib/test/test_struct.py b/Lib/test/test_struct.py
index a410fd5a194..7df01f28f09 100644
--- a/Lib/test/test_struct.py
+++ b/Lib/test/test_struct.py
@@ -22,12 +22,6 @@ byteorders = '', '@', '=', '<', '>', '!'
INF = float('inf')
NAN = float('nan')
-try:
- struct.pack('D', 1j)
- have_c_complex = True
-except struct.error:
- have_c_complex = False
-
def iter_integer_formats(byteorders=byteorders):
for code in integer_codes:
for byteorder in byteorders:
@@ -796,7 +790,6 @@ class StructTest(ComplexesAreIdenticalMixin, unittest.TestCase):
s = struct.Struct('=i2H')
self.assertEqual(repr(s), f'Struct({s.format!r})')
- @unittest.skipUnless(have_c_complex, "requires C11 complex type support")
def test_c_complex_round_trip(self):
values = [complex(*_) for _ in combinations([1, -1, 0.0, -0.0, 2,
-3, INF, -INF, NAN], 2)]
@@ -806,19 +799,6 @@ class StructTest(ComplexesAreIdenticalMixin, unittest.TestCase):
round_trip = struct.unpack(f, struct.pack(f, z))[0]
self.assertComplexesAreIdentical(z, round_trip)
- @unittest.skipIf(have_c_complex, "requires no C11 complex type support")
- def test_c_complex_error(self):
- msg1 = "'F' format not supported on this system"
- msg2 = "'D' format not supported on this system"
- with self.assertRaisesRegex(struct.error, msg1):
- struct.pack('F', 1j)
- with self.assertRaisesRegex(struct.error, msg1):
- struct.unpack('F', b'1')
- with self.assertRaisesRegex(struct.error, msg2):
- struct.pack('D', 1j)
- with self.assertRaisesRegex(struct.error, msg2):
- struct.unpack('D', b'1')
-
class UnpackIteratorTest(unittest.TestCase):
"""
diff --git a/Lib/test/test_structseq.py b/Lib/test/test_structseq.py
index d0bc0bd7b61..9622151143c 100644
--- a/Lib/test/test_structseq.py
+++ b/Lib/test/test_structseq.py
@@ -42,7 +42,7 @@ class StructSeqTest(unittest.TestCase):
# os.stat() gives a complicated struct sequence.
st = os.stat(__file__)
rep = repr(st)
- self.assertTrue(rep.startswith("os.stat_result"))
+ self.assertStartsWith(rep, "os.stat_result")
self.assertIn("st_mode=", rep)
self.assertIn("st_ino=", rep)
self.assertIn("st_dev=", rep)
@@ -307,7 +307,7 @@ class StructSeqTest(unittest.TestCase):
self.assertEqual(t5.tm_mon, 2)
# named invisible fields
- self.assertTrue(hasattr(t, 'tm_zone'), f"{t} has no attribute 'tm_zone'")
+ self.assertHasAttr(t, 'tm_zone')
with self.assertRaisesRegex(AttributeError, 'readonly attribute'):
t.tm_zone = 'some other zone'
self.assertEqual(t2.tm_zone, t.tm_zone)
diff --git a/Lib/test/test_subprocess.py b/Lib/test/test_subprocess.py
index 3cb755cd56c..f0e350c71f6 100644
--- a/Lib/test/test_subprocess.py
+++ b/Lib/test/test_subprocess.py
@@ -162,6 +162,20 @@ class ProcessTestCase(BaseTestCase):
[sys.executable, "-c", "while True: pass"],
timeout=0.1)
+ def test_timeout_exception(self):
+ try:
+ subprocess.run([sys.executable, '-c', 'import time;time.sleep(9)'], timeout=-1)
+ except subprocess.TimeoutExpired as e:
+ self.assertIn("-1 seconds", str(e))
+ else:
+ self.fail("Expected TimeoutExpired exception not raised")
+ try:
+ subprocess.run([sys.executable, '-c', 'import time;time.sleep(9)'], timeout=0)
+ except subprocess.TimeoutExpired as e:
+ self.assertIn("0 seconds", str(e))
+ else:
+ self.fail("Expected TimeoutExpired exception not raised")
+
def test_check_call_zero(self):
# check_call() function with zero return code
rc = subprocess.check_call(ZERO_RETURN_CMD)
@@ -1164,7 +1178,7 @@ class ProcessTestCase(BaseTestCase):
self.assertEqual("line1\nline2\nline3\nline4\nline5\n", stdout)
# A Python debug build pushes something like "[42442 refs]\n"
# to stderr when the subprocess exits.
- self.assertTrue(stderr.startswith("eline2\neline6\neline7\n"))
+ self.assertStartsWith(stderr, "eline2\neline6\neline7\n")
def test_universal_newlines_communicate_encodings(self):
# Check that universal newlines mode works for various encodings,
@@ -1496,7 +1510,7 @@ class ProcessTestCase(BaseTestCase):
"[sys.executable, '-c', 'print(\"Hello World!\")'])",
'assert retcode == 0'))
output = subprocess.check_output([sys.executable, '-c', code])
- self.assertTrue(output.startswith(b'Hello World!'), ascii(output))
+ self.assertStartsWith(output, b'Hello World!')
def test_handles_closed_on_exception(self):
# If CreateProcess exits with an error, ensure the
@@ -1821,8 +1835,8 @@ class RunFuncTestCase(BaseTestCase):
capture_output=True)
lines = cp.stderr.splitlines()
self.assertEqual(len(lines), 2, lines)
- self.assertTrue(lines[0].startswith(b"<string>:2: EncodingWarning: "))
- self.assertTrue(lines[1].startswith(b"<string>:3: EncodingWarning: "))
+ self.assertStartsWith(lines[0], b"<string>:2: EncodingWarning: ")
+ self.assertStartsWith(lines[1], b"<string>:3: EncodingWarning: ")
def _get_test_grp_name():
diff --git a/Lib/test/test_super.py b/Lib/test/test_super.py
index 5cef612a340..193c8b7d7f3 100644
--- a/Lib/test/test_super.py
+++ b/Lib/test/test_super.py
@@ -547,11 +547,11 @@ class TestSuper(unittest.TestCase):
self.assertEqual(s.__reduce__, e.__reduce__)
self.assertEqual(s.__reduce_ex__, e.__reduce_ex__)
self.assertEqual(s.__getstate__, e.__getstate__)
- self.assertFalse(hasattr(s, '__getnewargs__'))
- self.assertFalse(hasattr(s, '__getnewargs_ex__'))
- self.assertFalse(hasattr(s, '__setstate__'))
- self.assertFalse(hasattr(s, '__copy__'))
- self.assertFalse(hasattr(s, '__deepcopy__'))
+ self.assertNotHasAttr(s, '__getnewargs__')
+ self.assertNotHasAttr(s, '__getnewargs_ex__')
+ self.assertNotHasAttr(s, '__setstate__')
+ self.assertNotHasAttr(s, '__copy__')
+ self.assertNotHasAttr(s, '__deepcopy__')
def test_pickling(self):
e = E()
diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py
index 468bac82924..e48a2464ee5 100644
--- a/Lib/test/test_support.py
+++ b/Lib/test/test_support.py
@@ -407,10 +407,10 @@ class TestSupport(unittest.TestCase):
with support.swap_attr(obj, "y", 5) as y:
self.assertEqual(obj.y, 5)
self.assertIsNone(y)
- self.assertFalse(hasattr(obj, 'y'))
+ self.assertNotHasAttr(obj, 'y')
with support.swap_attr(obj, "y", 5):
del obj.y
- self.assertFalse(hasattr(obj, 'y'))
+ self.assertNotHasAttr(obj, 'y')
def test_swap_item(self):
D = {"x":1}
@@ -561,6 +561,7 @@ class TestSupport(unittest.TestCase):
['-Wignore', '-X', 'dev'],
['-X', 'faulthandler'],
['-X', 'importtime'],
+ ['-X', 'importtime=2'],
['-X', 'showrefcount'],
['-X', 'tracemalloc'],
['-X', 'tracemalloc=3'],
diff --git a/Lib/test/test_syntax.py b/Lib/test/test_syntax.py
index 4c001f9c9b0..c52d2421941 100644
--- a/Lib/test/test_syntax.py
+++ b/Lib/test/test_syntax.py
@@ -382,6 +382,13 @@ SyntaxError: invalid syntax
Traceback (most recent call last):
SyntaxError: invalid syntax
+# But prefixes of soft keywords should
+# still raise specialized errors
+
+>>> (mat x)
+Traceback (most recent call last):
+SyntaxError: invalid syntax. Perhaps you forgot a comma?
+
From compiler_complex_args():
>>> def f(None=1):
@@ -419,7 +426,7 @@ SyntaxError: invalid syntax
>>> def foo(/,a,b=,c):
... pass
Traceback (most recent call last):
-SyntaxError: at least one argument must precede /
+SyntaxError: at least one parameter must precede /
>>> def foo(a,/,/,b,c):
... pass
@@ -454,67 +461,67 @@ SyntaxError: / must be ahead of *
>>> def foo(a,*b=3,c):
... pass
Traceback (most recent call last):
-SyntaxError: var-positional argument cannot have default value
+SyntaxError: var-positional parameter cannot have default value
>>> def foo(a,*b: int=,c):
... pass
Traceback (most recent call last):
-SyntaxError: var-positional argument cannot have default value
+SyntaxError: var-positional parameter cannot have default value
>>> def foo(a,**b=3):
... pass
Traceback (most recent call last):
-SyntaxError: var-keyword argument cannot have default value
+SyntaxError: var-keyword parameter cannot have default value
>>> def foo(a,**b: int=3):
... pass
Traceback (most recent call last):
-SyntaxError: var-keyword argument cannot have default value
+SyntaxError: var-keyword parameter cannot have default value
>>> def foo(a,*a, b, **c, d):
... pass
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> def foo(a,*a, b, **c, d=4):
... pass
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> def foo(a,*a, b, **c, *d):
... pass
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> def foo(a,*a, b, **c, **d):
... pass
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> def foo(a=1,/,**b,/,c):
... pass
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> def foo(*b,*d):
... pass
Traceback (most recent call last):
-SyntaxError: * argument may appear only once
+SyntaxError: * may appear only once
>>> def foo(a,*b,c,*d,*e,c):
... pass
Traceback (most recent call last):
-SyntaxError: * argument may appear only once
+SyntaxError: * may appear only once
>>> def foo(a,b,/,c,*b,c,*d,*e,c):
... pass
Traceback (most recent call last):
-SyntaxError: * argument may appear only once
+SyntaxError: * may appear only once
>>> def foo(a,b,/,c,*b,c,*d,**e):
... pass
Traceback (most recent call last):
-SyntaxError: * argument may appear only once
+SyntaxError: * may appear only once
>>> def foo(a=1,/*,b,c):
... pass
@@ -538,7 +545,7 @@ SyntaxError: expected default value expression
>>> lambda /,a,b,c: None
Traceback (most recent call last):
-SyntaxError: at least one argument must precede /
+SyntaxError: at least one parameter must precede /
>>> lambda a,/,/,b,c: None
Traceback (most recent call last):
@@ -570,47 +577,47 @@ SyntaxError: expected comma between / and *
>>> lambda a,*b=3,c: None
Traceback (most recent call last):
-SyntaxError: var-positional argument cannot have default value
+SyntaxError: var-positional parameter cannot have default value
>>> lambda a,**b=3: None
Traceback (most recent call last):
-SyntaxError: var-keyword argument cannot have default value
+SyntaxError: var-keyword parameter cannot have default value
>>> lambda a, *a, b, **c, d: None
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> lambda a,*a, b, **c, d=4: None
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> lambda a,*a, b, **c, *d: None
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> lambda a,*a, b, **c, **d: None
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> lambda a=1,/,**b,/,c: None
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> lambda *b,*d: None
Traceback (most recent call last):
-SyntaxError: * argument may appear only once
+SyntaxError: * may appear only once
>>> lambda a,*b,c,*d,*e,c: None
Traceback (most recent call last):
-SyntaxError: * argument may appear only once
+SyntaxError: * may appear only once
>>> lambda a,b,/,c,*b,c,*d,*e,c: None
Traceback (most recent call last):
-SyntaxError: * argument may appear only once
+SyntaxError: * may appear only once
>>> lambda a,b,/,c,*b,c,*d,**e: None
Traceback (most recent call last):
-SyntaxError: * argument may appear only once
+SyntaxError: * may appear only once
>>> lambda a=1,d=,c: None
Traceback (most recent call last):
@@ -1304,7 +1311,7 @@ Missing parens after function definition
Traceback (most recent call last):
SyntaxError: expected '('
-Parenthesized arguments in function definitions
+Parenthesized parameters in function definitions
>>> def f(x, (y, z), w):
... pass
@@ -1431,6 +1438,23 @@ Better error message for using `except as` with not a name:
Traceback (most recent call last):
SyntaxError: cannot use except* statement with literal
+Regression tests for gh-133999:
+
+ >>> try: pass
+ ... except TypeError as name: raise from None
+ Traceback (most recent call last):
+ SyntaxError: did you forget an expression between 'raise' and 'from'?
+
+ >>> try: pass
+ ... except* TypeError as name: raise from None
+ Traceback (most recent call last):
+ SyntaxError: did you forget an expression between 'raise' and 'from'?
+
+ >>> match 1:
+ ... case 1 | 2 as abc: raise from None
+ Traceback (most recent call last):
+ SyntaxError: did you forget an expression between 'raise' and 'from'?
+
Ensure that early = are not matched by the parser as invalid comparisons
>>> f(2, 4, x=34); 1 $ 2
Traceback (most recent call last):
@@ -1678,6 +1702,28 @@ Make sure that the old "raise X, Y[, Z]" form is gone:
...
SyntaxError: invalid syntax
+Better errors for `raise` statement:
+
+ >>> raise ValueError from
+ Traceback (most recent call last):
+ SyntaxError: did you forget an expression after 'from'?
+
+ >>> raise mod.ValueError() from
+ Traceback (most recent call last):
+ SyntaxError: did you forget an expression after 'from'?
+
+ >>> raise from exc
+ Traceback (most recent call last):
+ SyntaxError: did you forget an expression between 'raise' and 'from'?
+
+ >>> raise from None
+ Traceback (most recent call last):
+ SyntaxError: did you forget an expression between 'raise' and 'from'?
+
+ >>> raise from
+ Traceback (most recent call last):
+ SyntaxError: did you forget an expression between 'raise' and 'from'?
+
Check that multiple exception types with missing parentheses
raise a custom exception only when using 'as'
@@ -1877,6 +1923,86 @@ SyntaxError: cannot assign to f-string expression here. Maybe you meant '==' ins
Traceback (most recent call last):
SyntaxError: cannot assign to f-string expression here. Maybe you meant '==' instead of '='?
+>>> ub''
+Traceback (most recent call last):
+SyntaxError: 'u' and 'b' prefixes are incompatible
+
+>>> bu"привет"
+Traceback (most recent call last):
+SyntaxError: 'u' and 'b' prefixes are incompatible
+
+>>> ur''
+Traceback (most recent call last):
+SyntaxError: 'u' and 'r' prefixes are incompatible
+
+>>> ru"\t"
+Traceback (most recent call last):
+SyntaxError: 'u' and 'r' prefixes are incompatible
+
+>>> uf'{1 + 1}'
+Traceback (most recent call last):
+SyntaxError: 'u' and 'f' prefixes are incompatible
+
+>>> fu""
+Traceback (most recent call last):
+SyntaxError: 'u' and 'f' prefixes are incompatible
+
+>>> ut'{1}'
+Traceback (most recent call last):
+SyntaxError: 'u' and 't' prefixes are incompatible
+
+>>> tu"234"
+Traceback (most recent call last):
+SyntaxError: 'u' and 't' prefixes are incompatible
+
+>>> bf'{x!r}'
+Traceback (most recent call last):
+SyntaxError: 'b' and 'f' prefixes are incompatible
+
+>>> fb"text"
+Traceback (most recent call last):
+SyntaxError: 'b' and 'f' prefixes are incompatible
+
+>>> bt"text"
+Traceback (most recent call last):
+SyntaxError: 'b' and 't' prefixes are incompatible
+
+>>> tb''
+Traceback (most recent call last):
+SyntaxError: 'b' and 't' prefixes are incompatible
+
+>>> tf"{0.3:.02f}"
+Traceback (most recent call last):
+SyntaxError: 'f' and 't' prefixes are incompatible
+
+>>> ft'{x=}'
+Traceback (most recent call last):
+SyntaxError: 'f' and 't' prefixes are incompatible
+
+>>> tfu"{x=}"
+Traceback (most recent call last):
+SyntaxError: 'u' and 'f' prefixes are incompatible
+
+>>> turf"{x=}"
+Traceback (most recent call last):
+SyntaxError: 'u' and 'r' prefixes are incompatible
+
+>>> burft"{x=}"
+Traceback (most recent call last):
+SyntaxError: 'u' and 'b' prefixes are incompatible
+
+>>> brft"{x=}"
+Traceback (most recent call last):
+SyntaxError: 'b' and 'f' prefixes are incompatible
+
+>>> t'{x}' = 42
+Traceback (most recent call last):
+SyntaxError: cannot assign to t-string expression here. Maybe you meant '==' instead of '='?
+
+>>> t'{x}-{y}' = 42
+Traceback (most recent call last):
+SyntaxError: cannot assign to t-string expression here. Maybe you meant '==' instead of '='?
+
>>> (x, y, z=3, d, e)
Traceback (most recent call last):
SyntaxError: invalid syntax. Maybe you meant '==' or ':=' instead of '='?
@@ -1957,6 +2083,56 @@ SyntaxError: cannot assign to __debug__
Traceback (most recent call last):
SyntaxError: cannot assign to __debug__
+>>> import a as b.c
+Traceback (most recent call last):
+SyntaxError: cannot use attribute as import target
+
+>>> import a.b as (a, b)
+Traceback (most recent call last):
+SyntaxError: cannot use tuple as import target
+
+>>> import a, a.b as 1
+Traceback (most recent call last):
+SyntaxError: cannot use literal as import target
+
+>>> import a.b as 'a', a
+Traceback (most recent call last):
+SyntaxError: cannot use literal as import target
+
+>>> from a import (b as c.d)
+Traceback (most recent call last):
+SyntaxError: cannot use attribute as import target
+
+>>> from a import b as 1
+Traceback (most recent call last):
+SyntaxError: cannot use literal as import target
+
+>>> from a import (
+... b as f())
+Traceback (most recent call last):
+SyntaxError: cannot use function call as import target
+
+>>> from a import (
+... b as [],
+... )
+Traceback (most recent call last):
+SyntaxError: cannot use list as import target
+
+>>> from a import (
+... b,
+... c as ()
+... )
+Traceback (most recent call last):
+SyntaxError: cannot use tuple as import target
+
+>>> from a import b, c as d[e]
+Traceback (most recent call last):
+SyntaxError: cannot use subscript as import target
+
+>>> from a import c as d[e], b
+Traceback (most recent call last):
+SyntaxError: cannot use subscript as import target
+
+# Check that we don't raise the "trailing comma" error if there is more
# input to the left of the valid part that we parsed.
@@ -2048,7 +2224,7 @@ Corner-cases that used to fail to raise the correct error:
>>> with (lambda *:0): pass
Traceback (most recent call last):
- SyntaxError: named arguments must follow bare *
+ SyntaxError: named parameters must follow bare *
Corner-cases that used to crash:
@@ -2696,6 +2872,13 @@ class SyntaxErrorTestCase(unittest.TestCase):
"""
self._check_error(source, "parameter and nonlocal", lineno=3)
+ def test_raise_from_error_message(self):
+ source = """if 1:
+ raise AssertionError() from None
+ print(1,,2)
+ """
+ self._check_error(source, "invalid syntax", lineno=3)
+
def test_yield_outside_function(self):
self._check_error("if 0: yield", "outside function")
self._check_error("if 0: yield\nelse: x=1", "outside function")
diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py
index 56413d00823..486bf10a0b5 100644
--- a/Lib/test/test_sys.py
+++ b/Lib/test/test_sys.py
@@ -24,7 +24,7 @@ from test.support import import_helper
from test.support import force_not_colorized
from test.support import SHORT_TIMEOUT
try:
- from test.support import interpreters
+ from concurrent import interpreters
except ImportError:
interpreters = None
import textwrap
@@ -57,7 +57,7 @@ class DisplayHookTest(unittest.TestCase):
dh(None)
self.assertEqual(out.getvalue(), "")
- self.assertTrue(not hasattr(builtins, "_"))
+ self.assertNotHasAttr(builtins, "_")
# sys.displayhook() requires arguments
self.assertRaises(TypeError, dh)
@@ -172,7 +172,7 @@ class ExceptHookTest(unittest.TestCase):
with support.captured_stderr() as err:
sys.__excepthook__(*sys.exc_info())
- self.assertTrue(err.getvalue().endswith("ValueError: 42\n"))
+ self.assertEndsWith(err.getvalue(), "ValueError: 42\n")
self.assertRaises(TypeError, sys.__excepthook__)
@@ -192,7 +192,7 @@ class ExceptHookTest(unittest.TestCase):
err = err.getvalue()
self.assertIn(""" File "b'bytes_filename'", line 123\n""", err)
self.assertIn(""" text\n""", err)
- self.assertTrue(err.endswith("SyntaxError: msg\n"))
+ self.assertEndsWith(err, "SyntaxError: msg\n")
def test_excepthook(self):
with test.support.captured_output("stderr") as stderr:
@@ -269,8 +269,7 @@ class SysModuleTest(unittest.TestCase):
rc, out, err = assert_python_failure('-c', code, **env_vars)
self.assertEqual(rc, 1)
self.assertEqual(out, b'')
- self.assertTrue(err.startswith(expected),
- "%s doesn't start with %s" % (ascii(err), ascii(expected)))
+ self.assertStartsWith(err, expected)
# test that stderr buffer is flushed before the exit message is written
# into stderr
@@ -437,7 +436,7 @@ class SysModuleTest(unittest.TestCase):
@unittest.skipUnless(hasattr(sys, "setdlopenflags"),
'test needs sys.setdlopenflags()')
def test_dlopenflags(self):
- self.assertTrue(hasattr(sys, "getdlopenflags"))
+ self.assertHasAttr(sys, "getdlopenflags")
self.assertRaises(TypeError, sys.getdlopenflags, 42)
oldflags = sys.getdlopenflags()
self.assertRaises(TypeError, sys.setdlopenflags)
@@ -623,8 +622,7 @@ class SysModuleTest(unittest.TestCase):
# And the next record must be for g456().
filename, lineno, funcname, sourceline = stack[i+1]
self.assertEqual(funcname, "g456")
- self.assertTrue((sourceline.startswith("if leave_g.wait(") or
- sourceline.startswith("g_raised.set()")))
+ self.assertStartsWith(sourceline, ("if leave_g.wait(", "g_raised.set()"))
finally:
# Reap the spawned thread.
leave_g.set()
@@ -731,7 +729,7 @@ class SysModuleTest(unittest.TestCase):
info = sys.thread_info
self.assertEqual(len(info), 3)
self.assertIn(info.name, ('nt', 'pthread', 'pthread-stubs', 'solaris', None))
- self.assertIn(info.lock, ('semaphore', 'mutex+cond', None))
+ self.assertIn(info.lock, ('pymutex', None))
if sys.platform.startswith(("linux", "android", "freebsd")):
self.assertEqual(info.name, "pthread")
elif sys.platform == "win32":
@@ -860,7 +858,7 @@ class SysModuleTest(unittest.TestCase):
"hash_randomization", "isolated", "dev_mode", "utf8_mode",
"warn_default_encoding", "safe_path", "int_max_str_digits")
for attr in attrs:
- self.assertTrue(hasattr(sys.flags, attr), attr)
+ self.assertHasAttr(sys.flags, attr)
attr_type = bool if attr in ("dev_mode", "safe_path") else int
self.assertEqual(type(getattr(sys.flags, attr)), attr_type, attr)
self.assertTrue(repr(sys.flags))
@@ -871,12 +869,7 @@ class SysModuleTest(unittest.TestCase):
def assert_raise_on_new_sys_type(self, sys_attr):
# Users are intentionally prevented from creating new instances of
# sys.flags, sys.version_info, and sys.getwindowsversion.
- arg = sys_attr
- attr_type = type(sys_attr)
- with self.assertRaises(TypeError):
- attr_type(arg)
- with self.assertRaises(TypeError):
- attr_type.__new__(attr_type, arg)
+ support.check_disallow_instantiation(self, type(sys_attr), sys_attr)
def test_sys_flags_no_instantiation(self):
self.assert_raise_on_new_sys_type(sys.flags)
@@ -1072,10 +1065,11 @@ class SysModuleTest(unittest.TestCase):
levels = {'alpha': 0xA, 'beta': 0xB, 'candidate': 0xC, 'final': 0xF}
- self.assertTrue(hasattr(sys.implementation, 'name'))
- self.assertTrue(hasattr(sys.implementation, 'version'))
- self.assertTrue(hasattr(sys.implementation, 'hexversion'))
- self.assertTrue(hasattr(sys.implementation, 'cache_tag'))
+ self.assertHasAttr(sys.implementation, 'name')
+ self.assertHasAttr(sys.implementation, 'version')
+ self.assertHasAttr(sys.implementation, 'hexversion')
+ self.assertHasAttr(sys.implementation, 'cache_tag')
+ self.assertHasAttr(sys.implementation, 'supports_isolated_interpreters')
version = sys.implementation.version
self.assertEqual(version[:2], (version.major, version.minor))
@@ -1089,6 +1083,15 @@ class SysModuleTest(unittest.TestCase):
self.assertEqual(sys.implementation.name,
sys.implementation.name.lower())
+ # https://peps.python.org/pep-0734
+ sii = sys.implementation.supports_isolated_interpreters
+ self.assertIsInstance(sii, bool)
+ if test.support.check_impl_detail(cpython=True):
+ if test.support.is_emscripten or test.support.is_wasi:
+ self.assertFalse(sii)
+ else:
+ self.assertTrue(sii)
+
@test.support.cpython_only
def test_debugmallocstats(self):
# Test sys._debugmallocstats()
@@ -1137,23 +1140,12 @@ class SysModuleTest(unittest.TestCase):
b = sys.getallocatedblocks()
self.assertLessEqual(b, a)
try:
- # While we could imagine a Python session where the number of
- # multiple buffer objects would exceed the sharing of references,
- # it is unlikely to happen in a normal test run.
- #
- # In free-threaded builds each code object owns an array of
- # pointers to copies of the bytecode. When the number of
- # code objects is a large fraction of the total number of
- # references, this can cause the total number of allocated
- # blocks to exceed the total number of references.
- #
- # For some reason, iOS seems to trigger the "unlikely to happen"
- # case reliably under CI conditions. It's not clear why; but as
- # this test is checking the behavior of getallocatedblock()
- # under garbage collection, we can skip this pre-condition check
- # for now. See GH-130384.
- if not support.Py_GIL_DISABLED and not support.is_apple_mobile:
- self.assertLess(a, sys.gettotalrefcount())
+ # The reported blocks will include immortalized strings, but the
+ # total ref count will not. This will sanity check that among all
+ # other objects (those eligible for garbage collection) there
+ # are more references being tracked than allocated blocks.
+ interned_immortal = sys.getunicodeinternedsize(_only_immortal=True)
+ self.assertLess(a - interned_immortal, sys.gettotalrefcount())
except AttributeError:
# gettotalrefcount() not available
pass
@@ -1301,6 +1293,7 @@ class SysModuleTest(unittest.TestCase):
for name in sys.stdlib_module_names:
self.assertIsInstance(name, str)
+ @unittest.skipUnless(hasattr(sys, '_stdlib_dir'), 'need sys._stdlib_dir')
def test_stdlib_dir(self):
os = import_helper.import_fresh_module('os')
marker = getattr(os, '__file__', None)
@@ -1419,7 +1412,7 @@ class UnraisableHookTest(unittest.TestCase):
else:
self.assertIn("ValueError", report)
self.assertIn("del is broken", report)
- self.assertTrue(report.endswith("\n"))
+ self.assertEndsWith(report, "\n")
def test_original_unraisablehook_exception_qualname(self):
# See bpo-41031, bpo-45083.
@@ -1955,33 +1948,19 @@ class SizeofTest(unittest.TestCase):
self.assertEqual(out, b"")
self.assertEqual(err, b"")
-
-def _supports_remote_attaching():
- PROCESS_VM_READV_SUPPORTED = False
-
- try:
- from _testexternalinspection import PROCESS_VM_READV_SUPPORTED
- except ImportError:
- pass
-
- return PROCESS_VM_READV_SUPPORTED
-
-@unittest.skipIf(not sys.is_remote_debug_enabled(), "Remote debugging is not enabled")
-@unittest.skipIf(sys.platform != "darwin" and sys.platform != "linux" and sys.platform != "win32",
- "Test only runs on Linux, Windows and MacOS")
-@unittest.skipIf(sys.platform == "linux" and not _supports_remote_attaching(),
- "Test only runs on Linux with process_vm_readv support")
+@test.support.support_remote_exec_only
@test.support.cpython_only
class TestRemoteExec(unittest.TestCase):
def tearDown(self):
test.support.reap_children()
- def _run_remote_exec_test(self, script_code, python_args=None, env=None, prologue=''):
+ def _run_remote_exec_test(self, script_code, python_args=None, env=None,
+ prologue='',
+ script_path=os_helper.TESTFN + '_remote.py'):
# Create the script that will be remotely executed
- script = os_helper.TESTFN + '_remote.py'
- self.addCleanup(os_helper.unlink, script)
+ self.addCleanup(os_helper.unlink, script_path)
- with open(script, 'w') as f:
+ with open(script_path, 'w') as f:
f.write(script_code)
# Create and run the target process
@@ -2050,7 +2029,7 @@ sock.close()
self.assertEqual(response, b"ready")
# Try remote exec on the target process
- sys.remote_exec(proc.pid, script)
+ sys.remote_exec(proc.pid, script_path)
# Signal script to continue
client_socket.sendall(b"continue")
@@ -2073,14 +2052,32 @@ sock.close()
def test_remote_exec(self):
"""Test basic remote exec functionality"""
- script = '''
-print("Remote script executed successfully!")
-'''
+ script = 'print("Remote script executed successfully!")'
returncode, stdout, stderr = self._run_remote_exec_test(script)
# self.assertEqual(returncode, 0)
self.assertIn(b"Remote script executed successfully!", stdout)
self.assertEqual(stderr, b"")
+ def test_remote_exec_bytes(self):
+ script = 'print("Remote script executed successfully!")'
+ script_path = os.fsencode(os_helper.TESTFN) + b'_bytes_remote.py'
+ returncode, stdout, stderr = self._run_remote_exec_test(script,
+ script_path=script_path)
+ self.assertIn(b"Remote script executed successfully!", stdout)
+ self.assertEqual(stderr, b"")
+
+ @unittest.skipUnless(os_helper.TESTFN_UNDECODABLE, 'requires undecodable path')
+ @unittest.skipIf(sys.platform == 'darwin',
+ 'undecodable paths are not supported on macOS')
+ def test_remote_exec_undecodable(self):
+ script = 'print("Remote script executed successfully!")'
+ script_path = os_helper.TESTFN_UNDECODABLE + b'_undecodable_remote.py'
+ for script_path in [script_path, os.fsdecode(script_path)]:
+ returncode, stdout, stderr = self._run_remote_exec_test(script,
+ script_path=script_path)
+ self.assertIn(b"Remote script executed successfully!", stdout)
+ self.assertEqual(stderr, b"")
+
def test_remote_exec_with_self_process(self):
"""Test remote exec with the target process being the same as the test process"""
@@ -2101,7 +2098,7 @@ print("Remote script executed successfully!")
prologue = '''\
import sys
def audit_hook(event, arg):
- print(f"Audit event: {event}, arg: {arg}")
+ print(f"Audit event: {event}, arg: {arg}".encode("ascii", errors="replace"))
sys.addaudithook(audit_hook)
'''
script = '''
@@ -2110,7 +2107,7 @@ print("Remote script executed successfully!")
returncode, stdout, stderr = self._run_remote_exec_test(script, prologue=prologue)
self.assertEqual(returncode, 0)
self.assertIn(b"Remote script executed successfully!", stdout)
- self.assertIn(b"Audit event: remote_debugger_script, arg: ", stdout)
+ self.assertIn(b"Audit event: cpython.remote_debugger_script, arg: ", stdout)
self.assertEqual(stderr, b"")
def test_remote_exec_with_exception(self):
@@ -2157,6 +2154,13 @@ raise Exception("Remote script exception")
with self.assertRaises(OSError):
sys.remote_exec(99999, "print('should not run')")
+ def test_remote_exec_invalid_script(self):
+ """Test remote exec with invalid script type"""
+ with self.assertRaises(TypeError):
+ sys.remote_exec(0, None)
+ with self.assertRaises(TypeError):
+ sys.remote_exec(0, 123)
+
def test_remote_exec_syntax_error(self):
"""Test remote exec with syntax error in script"""
script = '''
@@ -2196,6 +2200,64 @@ this is invalid python code
self.assertIn(b"Remote debugging is not enabled", err)
self.assertEqual(out, b"")
+class TestSysJIT(unittest.TestCase):
+
+ def test_jit_is_available(self):
+ available = sys._jit.is_available()
+ script = f"import sys; assert sys._jit.is_available() is {available}"
+ assert_python_ok("-c", script, PYTHON_JIT="0")
+ assert_python_ok("-c", script, PYTHON_JIT="1")
+
+ def test_jit_is_enabled(self):
+ available = sys._jit.is_available()
+ script = "import sys; assert sys._jit.is_enabled() is {enabled}"
+ assert_python_ok("-c", script.format(enabled=False), PYTHON_JIT="0")
+ assert_python_ok("-c", script.format(enabled=available), PYTHON_JIT="1")
+
+ def test_jit_is_active(self):
+ available = sys._jit.is_available()
+ script = textwrap.dedent(
+ """
+ import _testcapi
+ import _testinternalcapi
+ import sys
+
+ def frame_0_interpreter() -> None:
+ assert sys._jit.is_active() is False
+
+ def frame_1_interpreter() -> None:
+ assert sys._jit.is_active() is False
+ frame_0_interpreter()
+ assert sys._jit.is_active() is False
+
+ def frame_2_jit(expected: bool) -> None:
+ # Inlined into the last loop of frame_3_jit:
+ assert sys._jit.is_active() is expected
+ # Insert C frame:
+ _testcapi.pyobject_vectorcall(frame_1_interpreter, None, None)
+ assert sys._jit.is_active() is expected
+
+ def frame_3_jit() -> None:
+ # JITs just before the last loop:
+ for i in range(_testinternalcapi.TIER2_THRESHOLD + 1):
+ # Careful, doing this in the reverse order breaks tracing:
+ expected = {enabled} and i == _testinternalcapi.TIER2_THRESHOLD
+ assert sys._jit.is_active() is expected
+ frame_2_jit(expected)
+ assert sys._jit.is_active() is expected
+
+ def frame_4_interpreter() -> None:
+ assert sys._jit.is_active() is False
+ frame_3_jit()
+ assert sys._jit.is_active() is False
+
+ assert sys._jit.is_active() is False
+ frame_4_interpreter()
+ assert sys._jit.is_active() is False
+ """
+ )
+ assert_python_ok("-c", script.format(enabled=False), PYTHON_JIT="0")
+ assert_python_ok("-c", script.format(enabled=available), PYTHON_JIT="1")
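The three predicates exercised by TestSysJIT can also be probed directly; a small sketch, assuming an interpreter that exposes sys._jit (the values depend on the build and on PYTHON_JIT):

    import sys

    # is_available(): was JIT support compiled into this build?
    # is_enabled():   is the JIT turned on for this process (e.g. PYTHON_JIT=1)?
    # is_active():    is the currently executing frame JIT-compiled?
    print("available:", sys._jit.is_available())
    print("enabled:  ", sys._jit.is_enabled())
    print("active:   ", sys._jit.is_active())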
if __name__ == "__main__":
diff --git a/Lib/test/test_sysconfig.py b/Lib/test/test_sysconfig.py
index 53e55383bf9..2eb8de4b29f 100644
--- a/Lib/test/test_sysconfig.py
+++ b/Lib/test/test_sysconfig.py
@@ -32,7 +32,6 @@ from sysconfig import (get_paths, get_platform, get_config_vars,
from sysconfig.__main__ import _main, _parse_makefile, _get_pybuilddir, _get_json_data_name
import _imp
import _osx_support
-import _sysconfig
HAS_USER_BASE = sysconfig._HAS_USER_BASE
@@ -186,7 +185,7 @@ class TestSysConfig(unittest.TestCase, VirtualEnvironmentMixin):
# The include directory on POSIX isn't exactly the same as before,
# but it is "within"
sysconfig_includedir = sysconfig.get_path('include', scheme='posix_venv', vars=vars)
- self.assertTrue(sysconfig_includedir.startswith(incpath + os.sep))
+ self.assertStartsWith(sysconfig_includedir, incpath + os.sep)
def test_nt_venv_scheme(self):
# The following directories were hardcoded in the venv module
@@ -531,13 +530,10 @@ class TestSysConfig(unittest.TestCase, VirtualEnvironmentMixin):
Python_h = os.path.join(srcdir, 'Include', 'Python.h')
self.assertTrue(os.path.exists(Python_h), Python_h)
# <srcdir>/PC/pyconfig.h.in always exists even if unused
- pyconfig_h = os.path.join(srcdir, 'PC', 'pyconfig.h.in')
- self.assertTrue(os.path.exists(pyconfig_h), pyconfig_h)
pyconfig_h_in = os.path.join(srcdir, 'pyconfig.h.in')
self.assertTrue(os.path.exists(pyconfig_h_in), pyconfig_h_in)
if os.name == 'nt':
- # <executable dir>/pyconfig.h exists on Windows in a build tree
- pyconfig_h = os.path.join(sys.executable, '..', 'pyconfig.h')
+ pyconfig_h = os.path.join(srcdir, 'PC', 'pyconfig.h')
self.assertTrue(os.path.exists(pyconfig_h), pyconfig_h)
elif os.name == 'posix':
makefile_dir = os.path.dirname(sysconfig.get_makefile_filename())
@@ -572,8 +568,7 @@ class TestSysConfig(unittest.TestCase, VirtualEnvironmentMixin):
expected_suffixes = 'i386-linux-gnu.so', 'x86_64-linux-gnux32.so', 'i386-linux-musl.so'
else: # 8 byte pointer size
expected_suffixes = 'x86_64-linux-gnu.so', 'x86_64-linux-musl.so'
- self.assertTrue(suffix.endswith(expected_suffixes),
- f'unexpected suffix {suffix!r}')
+ self.assertEndsWith(suffix, expected_suffixes)
@unittest.skipUnless(sys.platform == 'android', 'Android-specific test')
def test_android_ext_suffix(self):
@@ -585,13 +580,12 @@ class TestSysConfig(unittest.TestCase, VirtualEnvironmentMixin):
"aarch64": "aarch64-linux-android",
"armv7l": "arm-linux-androideabi",
}[machine]
- self.assertTrue(suffix.endswith(f"-{expected_triplet}.so"),
- f"{machine=}, {suffix=}")
+ self.assertEndsWith(suffix, f"-{expected_triplet}.so")
@unittest.skipUnless(sys.platform == 'darwin', 'OS X-specific test')
def test_osx_ext_suffix(self):
suffix = sysconfig.get_config_var('EXT_SUFFIX')
- self.assertTrue(suffix.endswith('-darwin.so'), suffix)
+ self.assertEndsWith(suffix, '-darwin.so')
def test_always_set_py_debug(self):
self.assertIn('Py_DEBUG', sysconfig.get_config_vars())
@@ -717,8 +711,8 @@ class TestSysConfig(unittest.TestCase, VirtualEnvironmentMixin):
ignore_keys |= {'prefix', 'exec_prefix', 'base', 'platbase', 'installed_base', 'installed_platbase'}
for key in ignore_keys:
- json_config_vars.pop(key)
- system_config_vars.pop(key)
+ json_config_vars.pop(key, None)
+ system_config_vars.pop(key, None)
self.assertEqual(system_config_vars, json_config_vars)
diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py
index fcbaf854cc2..7055e1ed147 100644
--- a/Lib/test/test_tarfile.py
+++ b/Lib/test/test_tarfile.py
@@ -38,6 +38,10 @@ try:
import lzma
except ImportError:
lzma = None
+try:
+ from compression import zstd
+except ImportError:
+ zstd = None
def sha256sum(data):
return sha256(data).hexdigest()
@@ -48,6 +52,7 @@ tarname = support.findfile("testtar.tar", subdir="archivetestdata")
gzipname = os.path.join(TEMPDIR, "testtar.tar.gz")
bz2name = os.path.join(TEMPDIR, "testtar.tar.bz2")
xzname = os.path.join(TEMPDIR, "testtar.tar.xz")
+zstname = os.path.join(TEMPDIR, "testtar.tar.zst")
tmpname = os.path.join(TEMPDIR, "tmp.tar")
dotlessname = os.path.join(TEMPDIR, "testtar")
@@ -90,6 +95,12 @@ class LzmaTest:
open = lzma.LZMAFile if lzma else None
taropen = tarfile.TarFile.xzopen
+@support.requires_zstd()
+class ZstdTest:
+ tarname = zstname
+ suffix = 'zst'
+ open = zstd.ZstdFile if zstd else None
+ taropen = tarfile.TarFile.zstopen
class ReadTest(TarTest):
@@ -271,6 +282,8 @@ class Bz2UstarReadTest(Bz2Test, UstarReadTest):
class LzmaUstarReadTest(LzmaTest, UstarReadTest):
pass
+class ZstdUstarReadTest(ZstdTest, UstarReadTest):
+ pass
class ListTest(ReadTest, unittest.TestCase):
@@ -375,6 +388,8 @@ class Bz2ListTest(Bz2Test, ListTest):
class LzmaListTest(LzmaTest, ListTest):
pass
+class ZstdListTest(ZstdTest, ListTest):
+ pass
class CommonReadTest(ReadTest):
@@ -837,6 +852,8 @@ class Bz2MiscReadTest(Bz2Test, MiscReadTestBase, unittest.TestCase):
class LzmaMiscReadTest(LzmaTest, MiscReadTestBase, unittest.TestCase):
pass
+class ZstdMiscReadTest(ZstdTest, MiscReadTestBase, unittest.TestCase):
+ pass
class StreamReadTest(CommonReadTest, unittest.TestCase):
@@ -909,6 +926,9 @@ class Bz2StreamReadTest(Bz2Test, StreamReadTest):
class LzmaStreamReadTest(LzmaTest, StreamReadTest):
pass
+class ZstdStreamReadTest(ZstdTest, StreamReadTest):
+ pass
+
class TarStreamModeReadTest(StreamModeTest, unittest.TestCase):
def test_stream_mode_no_cache(self):
@@ -925,6 +945,9 @@ class Bz2StreamModeReadTest(Bz2Test, TarStreamModeReadTest):
class LzmaStreamModeReadTest(LzmaTest, TarStreamModeReadTest):
pass
+class ZstdStreamModeReadTest(ZstdTest, TarStreamModeReadTest):
+ pass
+
class DetectReadTest(TarTest, unittest.TestCase):
def _testfunc_file(self, name, mode):
try:
@@ -986,6 +1009,8 @@ class Bz2DetectReadTest(Bz2Test, DetectReadTest):
class LzmaDetectReadTest(LzmaTest, DetectReadTest):
pass
+class ZstdDetectReadTest(ZstdTest, DetectReadTest):
+ pass
class GzipBrokenHeaderCorrectException(GzipTest, unittest.TestCase):
"""
@@ -1625,7 +1650,7 @@ class WriteTest(WriteTestBase, unittest.TestCase):
try:
for t in tar:
if t.name != ".":
- self.assertTrue(t.name.startswith("./"), t.name)
+ self.assertStartsWith(t.name, "./")
finally:
tar.close()
@@ -1666,6 +1691,8 @@ class Bz2WriteTest(Bz2Test, WriteTest):
class LzmaWriteTest(LzmaTest, WriteTest):
pass
+class ZstdWriteTest(ZstdTest, WriteTest):
+ pass
class StreamWriteTest(WriteTestBase, unittest.TestCase):
@@ -1727,6 +1754,9 @@ class Bz2StreamWriteTest(Bz2Test, StreamWriteTest):
class LzmaStreamWriteTest(LzmaTest, StreamWriteTest):
decompressor = lzma.LZMADecompressor if lzma else None
+class ZstdStreamWriteTest(ZstdTest, StreamWriteTest):
+ decompressor = zstd.ZstdDecompressor if zstd else None
+
class _CompressedWriteTest(TarTest):
# This is not actually a standalone test.
# It does not inherit WriteTest because it only makes sense with gz,bz2
@@ -2042,6 +2072,14 @@ class LzmaCreateTest(LzmaTest, CreateTest):
tobj.add(self.file_path)
+class ZstdCreateTest(ZstdTest, CreateTest):
+
+ # Unlike gz and bz2, zstd uses the level keyword instead of compresslevel.
+ # It does not allow for level to be specified when reading.
+ def test_create_with_level(self):
+ with tarfile.open(tmpname, self.mode, level=1) as tobj:
+ tobj.add(self.file_path)
+
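As the comment in ZstdCreateTest notes, the zstd layer takes a level keyword rather than compresslevel, and no level may be passed when reading. A round-trip sketch, assuming a build where compression.zstd and the 'w:zst'/'r:zst' mode strings are available (file names are placeholders):

    import pathlib
    import tarfile

    pathlib.Path('member.txt').write_text('hello')

    # Writing: the zstd compression level is passed as `level`,
    # not `compresslevel` as with gzip/bz2.
    with tarfile.open('example.tar.zst', 'w:zst', level=3) as tar:
        tar.add('member.txt')

    # Reading: no level argument is accepted here.
    with tarfile.open('example.tar.zst', 'r:zst') as tar:
        tar.extractall(path='out', filter='data')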
class CreateWithXModeTest(CreateTest):
prefix = "x"
@@ -2523,6 +2561,8 @@ class Bz2AppendTest(Bz2Test, AppendTestBase, unittest.TestCase):
class LzmaAppendTest(LzmaTest, AppendTestBase, unittest.TestCase):
pass
+class ZstdAppendTest(ZstdTest, AppendTestBase, unittest.TestCase):
+ pass
class LimitsTest(unittest.TestCase):
@@ -2675,6 +2715,31 @@ class MiscTest(unittest.TestCase):
str(excinfo.exception),
)
+ @unittest.skipUnless(os_helper.can_symlink(), 'requires symlink support')
+ @unittest.skipUnless(hasattr(os, 'chmod'), "missing os.chmod")
+ @unittest.mock.patch('os.chmod')
+ def test_deferred_directory_attributes_update(self, mock_chmod):
+ # Regression test for gh-127987: setting attributes on arbitrary files
+ tempdir = os.path.join(TEMPDIR, 'test127987')
+ def mock_chmod_side_effect(path, mode, **kwargs):
+ target_path = os.path.realpath(path)
+ if os.path.commonpath([target_path, tempdir]) != tempdir:
+ raise Exception("should not try to chmod anything outside the destination", target_path)
+ mock_chmod.side_effect = mock_chmod_side_effect
+
+ outside_tree_dir = os.path.join(TEMPDIR, 'outside_tree_dir')
+ with ArchiveMaker() as arc:
+ arc.add('x', symlink_to='.')
+ arc.add('x', type=tarfile.DIRTYPE, mode='?rwsrwsrwt')
+ arc.add('x', symlink_to=outside_tree_dir)
+
+ os.makedirs(outside_tree_dir)
+ try:
+ arc.open().extractall(path=tempdir, filter='tar')
+ finally:
+ os_helper.rmtree(outside_tree_dir)
+ os_helper.rmtree(tempdir)
+
class CommandLineTest(unittest.TestCase):
@@ -2835,7 +2900,7 @@ class CommandLineTest(unittest.TestCase):
support.findfile('tokenize_tests-no-coding-cookie-'
'and-utf8-bom-sig-only.txt',
subdir='tokenizedata')]
- for filetype in (GzipTest, Bz2Test, LzmaTest):
+ for filetype in (GzipTest, Bz2Test, LzmaTest, ZstdTest):
if not filetype.open:
continue
try:
@@ -3235,6 +3300,10 @@ class NoneInfoExtractTests(ReadTest):
got_paths = set(
p.relative_to(directory)
for p in pathlib.Path(directory).glob('**/*'))
+ if self.extraction_filter in (None, 'data'):
+ # The 'data' filter is expected to reject special files
+ for path in 'ustar/fifotype', 'ustar/blktype', 'ustar/chrtype':
+ got_paths.discard(pathlib.Path(path))
self.assertEqual(self.control_paths, got_paths)
@contextmanager
@@ -3450,11 +3519,12 @@ class ArchiveMaker:
with t.open() as tar:
... # `tar` is now a TarFile with 'filename' in it!
"""
- def __init__(self):
+ def __init__(self, **kwargs):
self.bio = io.BytesIO()
+ self.tar_kwargs = dict(kwargs)
def __enter__(self):
- self.tar_w = tarfile.TarFile(mode='w', fileobj=self.bio)
+ self.tar_w = tarfile.TarFile(mode='w', fileobj=self.bio, **self.tar_kwargs)
return self
def __exit__(self, *exc):
@@ -3463,12 +3533,28 @@ class ArchiveMaker:
self.bio = None
def add(self, name, *, type=None, symlink_to=None, hardlink_to=None,
- mode=None, size=None, **kwargs):
- """Add a member to the test archive. Call within `with`."""
+ mode=None, size=None, content=None, **kwargs):
+ """Add a member to the test archive. Call within `with`.
+
+ Provides many shortcuts:
+ - default `type` is based on symlink_to, hardlink_to, and trailing `/`
+ in name (which is stripped)
+ - size & content defaults are based on each other
+ - content can be str or bytes
+ - mode should be textual ('-rwxrwxrwx')
+
+ (add more! this is unstable internal test-only API)
+ """
name = str(name)
tarinfo = tarfile.TarInfo(name).replace(**kwargs)
+ if content is not None:
+ if isinstance(content, str):
+ content = content.encode()
+ size = len(content)
if size is not None:
tarinfo.size = size
+ if content is None:
+ content = bytes(tarinfo.size)
if mode:
tarinfo.mode = _filemode_to_int(mode)
if symlink_to is not None:
@@ -3482,7 +3568,7 @@ class ArchiveMaker:
if type is not None:
tarinfo.type = type
if tarinfo.isreg():
- fileobj = io.BytesIO(bytes(tarinfo.size))
+ fileobj = io.BytesIO(content)
else:
fileobj = None
self.tar_w.addfile(tarinfo, fileobj)
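ArchiveMaker is a thin wrapper around a TarFile writing into a BytesIO, and the expanded docstring above lists the shortcuts add() now provides. For context, the underlying standard-library calls it wraps look roughly like this (a self-contained sketch, not the helper itself):

    import io
    import tarfile

    bio = io.BytesIO()
    with tarfile.TarFile(mode='w', fileobj=bio) as tar_w:
        info = tarfile.TarInfo('dir/file')
        payload = b'payload'
        info.size = len(payload)              # size derived from the content
        tar_w.addfile(info, io.BytesIO(payload))

        link = tarfile.TarInfo('link')
        link.type = tarfile.SYMTYPE           # the helper infers this from symlink_to
        link.linkname = 'dir/file'
        tar_w.addfile(link)                   # no fileobj for non-regular members

    bio.seek(0)
    with tarfile.TarFile(mode='r', fileobj=bio) as tar_r:
        print(tar_r.getnames())               # ['dir/file', 'link']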
@@ -3516,7 +3602,7 @@ class TestExtractionFilters(unittest.TestCase):
destdir = outerdir / 'dest'
@contextmanager
- def check_context(self, tar, filter):
+ def check_context(self, tar, filter, *, check_flag=True):
"""Extracts `tar` to `self.destdir` and allows checking the result
If an error occurs, it must be checked using `expect_exception`
@@ -3525,27 +3611,40 @@ class TestExtractionFilters(unittest.TestCase):
except the destination directory itself and parent directories of
other files.
When checking directories, do so before their contents.
+
+ A file called 'flag' is made in outerdir (i.e. outside destdir)
+ before extraction; it should not be altered nor should its contents
+ be read/copied.
"""
with os_helper.temp_dir(self.outerdir):
+ flag_path = self.outerdir / 'flag'
+ flag_path.write_text('capture me')
try:
tar.extractall(self.destdir, filter=filter)
except Exception as exc:
self.raised_exception = exc
+ self.reraise_exception = True
self.expected_paths = set()
else:
self.raised_exception = None
+ self.reraise_exception = False
self.expected_paths = set(self.outerdir.glob('**/*'))
self.expected_paths.discard(self.destdir)
+ self.expected_paths.discard(flag_path)
try:
- yield
+ yield self
finally:
tar.close()
- if self.raised_exception:
+ if self.reraise_exception:
raise self.raised_exception
self.assertEqual(self.expected_paths, set())
+ if check_flag:
+ self.assertEqual(flag_path.read_text(), 'capture me')
+ else:
+ assert filter == 'fully_trusted'
def expect_file(self, name, type=None, symlink_to=None, mode=None,
- size=None):
+ size=None, content=None):
"""Check a single file. See check_context."""
if self.raised_exception:
raise self.raised_exception
@@ -3564,26 +3663,45 @@ class TestExtractionFilters(unittest.TestCase):
# The symlink might be the same (textually) as what we expect,
# but some systems change the link to an equivalent path, so
# we fall back to samefile().
- if expected != got:
- self.assertTrue(got.samefile(expected))
+ try:
+ if expected != got:
+ self.assertTrue(got.samefile(expected))
+ except Exception as e:
+ # attach a note, so it's shown even if `samefile` fails
+ e.add_note(f'{expected=}, {got=}')
+ raise
elif type == tarfile.REGTYPE or type is None:
self.assertTrue(path.is_file())
elif type == tarfile.DIRTYPE:
self.assertTrue(path.is_dir())
elif type == tarfile.FIFOTYPE:
self.assertTrue(path.is_fifo())
+ elif type == tarfile.SYMTYPE:
+ self.assertTrue(path.is_symlink())
else:
raise NotImplementedError(type)
if size is not None:
self.assertEqual(path.stat().st_size, size)
+ if content is not None:
+ self.assertEqual(path.read_text(), content)
for parent in path.parents:
self.expected_paths.discard(parent)
+ def expect_any_tree(self, name):
+ """Check a directory; forget about its contents."""
+ tree_path = (self.destdir / name).resolve()
+ self.expect_file(tree_path, type=tarfile.DIRTYPE)
+ self.expected_paths = {
+ p for p in self.expected_paths
+ if tree_path not in p.parents
+ }
+
def expect_exception(self, exc_type, message_re='.'):
with self.assertRaisesRegex(exc_type, message_re):
if self.raised_exception is not None:
raise self.raised_exception
- self.raised_exception = None
+ self.reraise_exception = False
+ return self.raised_exception
def test_benign_file(self):
with ArchiveMaker() as arc:
@@ -3669,6 +3787,80 @@ class TestExtractionFilters(unittest.TestCase):
self.expect_file('parent/evil')
@symlink_test
+ @os_helper.skip_unless_symlink
+ def test_realpath_limit_attack(self):
+ # (CVE-2025-4517)
+
+ with ArchiveMaker() as arc:
+ # populate the symlinks and dirs that expand in os.path.realpath()
+ # The component length is chosen so that in common cases, the unexpanded
+ # path fits in PATH_MAX, but it overflows when the final symlink
+ # is expanded
+ steps = "abcdefghijklmnop"
+ if sys.platform == 'win32':
+ component = 'd' * 25
+ elif 'PC_PATH_MAX' in os.pathconf_names:
+ max_path_len = os.pathconf(self.outerdir.parent, "PC_PATH_MAX")
+ path_sep_len = 1
+ dest_len = len(str(self.destdir)) + path_sep_len
+ component_len = (max_path_len - dest_len) // (len(steps) + path_sep_len)
+ component = 'd' * component_len
+ else:
+ raise NotImplementedError(f"Need to guess component length for {sys.platform}")
+ path = ""
+ step_path = ""
+ for i in steps:
+ arc.add(os.path.join(path, component), type=tarfile.DIRTYPE,
+ mode='drwxrwxrwx')
+ arc.add(os.path.join(path, i), symlink_to=component)
+ path = os.path.join(path, component)
+ step_path = os.path.join(step_path, i)
+ # create the final symlink that exceeds PATH_MAX and simply points
+ # to the top dir.
+ # this link will never be expanded by
+ # os.path.realpath(strict=False), nor anything after it.
+ linkpath = os.path.join(*steps, "l"*254)
+ parent_segments = [".."] * len(steps)
+ arc.add(linkpath, symlink_to=os.path.join(*parent_segments))
+ # make a symlink outside to keep the tar command happy
+ arc.add("escape", symlink_to=os.path.join(linkpath, ".."))
+ # use the symlinks above, that are not checked, to create a hardlink
+ # to a file outside of the destination path
+ arc.add("flaglink", hardlink_to=os.path.join("escape", "flag"))
+ # now that we have the hardlink we can overwrite the file
+ arc.add("flaglink", content='overwrite')
+ # we can also create new files as well!
+ arc.add("escape/newfile", content='new')
+
+ with (self.subTest('fully_trusted'),
+ self.check_context(arc.open(), filter='fully_trusted',
+ check_flag=False)):
+ if sys.platform == 'win32':
+ self.expect_exception((FileNotFoundError, FileExistsError))
+ elif self.raised_exception:
+ # Cannot symlink/hardlink: tarfile falls back to getmember()
+ self.expect_exception(KeyError)
+ # Otherwise, this block should never be entered.
+ else:
+ self.expect_any_tree(component)
+ self.expect_file('flaglink', content='overwrite')
+ self.expect_file('../newfile', content='new')
+ self.expect_file('escape', type=tarfile.SYMTYPE)
+ self.expect_file('a', symlink_to=component)
+
+ for filter in 'tar', 'data':
+ with self.subTest(filter), self.check_context(arc.open(), filter=filter):
+ exc = self.expect_exception((OSError, KeyError))
+ if isinstance(exc, OSError):
+ if sys.platform == 'win32':
+ # 3: ERROR_PATH_NOT_FOUND
+ # 5: ERROR_ACCESS_DENIED
+ # 206: ERROR_FILENAME_EXCED_RANGE
+ self.assertIn(exc.winerror, (3, 5, 206))
+ else:
+ self.assertEqual(exc.errno, errno.ENAMETOOLONG)
+
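For reference, the filter functions used throughout these tests can also be called directly on a single member; a minimal sketch showing the 'data' filter rejecting a symlink that points outside the destination (the destination path is a placeholder):

    import tarfile

    member = tarfile.TarInfo('escape')
    member.type = tarfile.SYMTYPE
    member.linkname = '../outside'       # resolves outside the destination

    try:
        tarfile.data_filter(member, '/tmp/dest')
    except tarfile.LinkOutsideDestinationError as exc:
        # The 'data' filter refuses links whose target escapes the
        # extraction root.
        print('rejected:', exc)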
+ @symlink_test
def test_parent_symlink2(self):
# Test interplaying symlinks
# Inspired by 'dirsymlink2b' in jwilk/traversal-archives
@@ -3890,8 +4082,8 @@ class TestExtractionFilters(unittest.TestCase):
arc.add('symlink2', symlink_to=os.path.join(
'linkdir', 'hardlink2'))
arc.add('targetdir/target', size=3)
- arc.add('linkdir/hardlink', hardlink_to='targetdir/target')
- arc.add('linkdir/hardlink2', hardlink_to='linkdir/symlink')
+ arc.add('linkdir/hardlink', hardlink_to=os.path.join('targetdir', 'target'))
+ arc.add('linkdir/hardlink2', hardlink_to=os.path.join('linkdir', 'symlink'))
for filter in 'tar', 'data', 'fully_trusted':
with self.check_context(arc.open(), filter):
@@ -3907,6 +4099,129 @@ class TestExtractionFilters(unittest.TestCase):
self.expect_file('linkdir/symlink', size=3)
self.expect_file('symlink2', size=3)
+ @symlink_test
+ def test_sneaky_hardlink_fallback(self):
+ # (CVE-2025-4330)
+ # Test that when hardlink extraction falls back to extracting members
+ # from the archive, the extracted member is (re-)filtered.
+ with ArchiveMaker() as arc:
+ # Create a directory structure so the c/escape symlink stays
+ # inside the path
+ arc.add("a/t/dummy")
+ # Create b/ directory
+ arc.add("b/")
+ # Point "c" to the bottom of the tree in "a"
+ arc.add("c", symlink_to=os.path.join("a", "t"))
+ # link to a non-existent location under "a"
+ arc.add("c/escape", symlink_to=os.path.join("..", "..",
+ "link_here"))
+ # Move "c" to point to "b" ("c/escape" no longer exists)
+ arc.add("c", symlink_to="b")
+ # Attempt to create a hard link to "c/escape". Since it doesn't
+ # exist, it will fall back to extracting "c/escape", but at "boom".
+ arc.add("boom", hardlink_to=os.path.join("c", "escape"))
+
+ with self.check_context(arc.open(), 'data'):
+ if not os_helper.can_symlink():
+ # When 'c/escape' is extracted, 'c' is a regular
+ # directory, and 'c/escape' *would* point outside
+ # the destination if symlinks were allowed.
+ self.expect_exception(
+ tarfile.LinkOutsideDestinationError)
+ elif sys.platform == "win32":
+ # On Windows, 'c/escape' points outside the destination
+ self.expect_exception(tarfile.LinkOutsideDestinationError)
+ else:
+ e = self.expect_exception(
+ tarfile.LinkFallbackError,
+ "link 'boom' would be extracted as a copy of "
+ + "'c/escape', which was rejected")
+ self.assertIsInstance(e.__cause__,
+ tarfile.LinkOutsideDestinationError)
+ for filter in 'tar', 'fully_trusted':
+ with self.subTest(filter), self.check_context(arc.open(), filter):
+ if not os_helper.can_symlink():
+ self.expect_file("a/t/dummy")
+ self.expect_file("b/")
+ self.expect_file("c/")
+ else:
+ self.expect_file("a/t/dummy")
+ self.expect_file("b/")
+ self.expect_file("a/t/escape", symlink_to='../../link_here')
+ self.expect_file("boom", symlink_to='../../link_here')
+ self.expect_file("c", symlink_to='b')
+
+ @symlink_test
+ def test_exfiltration_via_symlink(self):
+ # (CVE-2025-4138)
+ # Test changing symlinks that result in a symlink pointing outside
+ # the extraction directory, unless prevented by 'data' filter's
+ # normalization.
+ with ArchiveMaker() as arc:
+ arc.add("escape", symlink_to=os.path.join('link', 'link', '..', '..', 'link-here'))
+ arc.add("link", symlink_to='./')
+
+ for filter in 'tar', 'data', 'fully_trusted':
+ with self.check_context(arc.open(), filter):
+ if os_helper.can_symlink():
+ self.expect_file("link", symlink_to='./')
+ if filter == 'data':
+ self.expect_file("escape", symlink_to='link-here')
+ else:
+ self.expect_file("escape",
+ symlink_to='link/link/../../link-here')
+ else:
+ # Nothing is extracted.
+ pass
+
+ @symlink_test
+ def test_chmod_outside_dir(self):
+ # (CVE-2024-12718)
+ # Test that members used for delayed updates of directory metadata
+ # are (re-)filtered.
+ with ArchiveMaker() as arc:
+ # "pwn" is a veeeery innocent symlink:
+ arc.add("a/pwn", symlink_to='.')
+ # But now "pwn" is also a directory, so it's scheduled to have its
+ # metadata updated later:
+ arc.add("a/pwn/", mode='drwxrwxrwx')
+ # Oops, "pwn" is not so innocent any more:
+ arc.add("a/pwn", symlink_to='x/../')
+ # Newly created symlink points to the dest dir,
+ # so it's OK for the "data" filter.
+ arc.add('a/x', symlink_to=('../'))
+ # But now "pwn" points outside the dest dir
+
+ for filter in 'tar', 'data', 'fully_trusted':
+ with self.check_context(arc.open(), filter) as cc:
+ if not os_helper.can_symlink():
+ self.expect_file("a/pwn/")
+ elif filter == 'data':
+ self.expect_file("a/x", symlink_to='../')
+ self.expect_file("a/pwn", symlink_to='.')
+ else:
+ self.expect_file("a/x", symlink_to='../')
+ self.expect_file("a/pwn", symlink_to='x/../')
+ if sys.platform != "win32":
+ st_mode = cc.outerdir.stat().st_mode
+ self.assertNotEqual(st_mode & 0o777, 0o777)
+
+ def test_link_fallback_normalizes(self):
+ # Make sure hardlink fallbacks work for non-normalized paths for all
+ # filters
+ with ArchiveMaker() as arc:
+ arc.add("dir/")
+ arc.add("dir/../afile")
+ arc.add("link1", hardlink_to='dir/../afile')
+ arc.add("link2", hardlink_to='dir/../dir/../afile')
+
+ for filter in 'tar', 'data', 'fully_trusted':
+ with self.check_context(arc.open(), filter) as cc:
+ self.expect_file("dir/")
+ self.expect_file("afile")
+ self.expect_file("link1")
+ self.expect_file("link2")
+
def test_modes(self):
# Test how file modes are extracted
# (Note that the modes are ignored on platforms without working chmod)
@@ -4031,24 +4346,64 @@ class TestExtractionFilters(unittest.TestCase):
# The 'tar' filter returns TarInfo objects with the same name/type.
# (It can also fail for particularly "evil" input, but we don't have
# that in the test archive.)
- with tarfile.TarFile.open(tarname) as tar:
+ with tarfile.TarFile.open(tarname, encoding="iso8859-1") as tar:
for tarinfo in tar.getmembers():
- filtered = tarfile.tar_filter(tarinfo, '')
+ try:
+ filtered = tarfile.tar_filter(tarinfo, '')
+ except UnicodeEncodeError:
+ continue
self.assertIs(filtered.name, tarinfo.name)
self.assertIs(filtered.type, tarinfo.type)
def test_data_filter(self):
# The 'data' filter either raises, or returns TarInfo with the same
# name/type.
- with tarfile.TarFile.open(tarname) as tar:
+ with tarfile.TarFile.open(tarname, encoding="iso8859-1") as tar:
for tarinfo in tar.getmembers():
try:
filtered = tarfile.data_filter(tarinfo, '')
- except tarfile.FilterError:
+ except (tarfile.FilterError, UnicodeEncodeError):
continue
self.assertIs(filtered.name, tarinfo.name)
self.assertIs(filtered.type, tarinfo.type)
+ @unittest.skipIf(sys.platform == 'win32', 'requires native bytes paths')
+ def test_filter_unencodable(self):
+ # Sanity check using a valid path.
+ tarinfo = tarfile.TarInfo(os_helper.TESTFN)
+ filtered = tarfile.tar_filter(tarinfo, '')
+ self.assertIs(filtered.name, tarinfo.name)
+ filtered = tarfile.data_filter(tarinfo, '')
+ self.assertIs(filtered.name, tarinfo.name)
+
+ tarinfo = tarfile.TarInfo('test\x00')
+ self.assertRaises(ValueError, tarfile.tar_filter, tarinfo, '')
+ self.assertRaises(ValueError, tarfile.data_filter, tarinfo, '')
+ tarinfo = tarfile.TarInfo('\ud800')
+ self.assertRaises(UnicodeEncodeError, tarfile.tar_filter, tarinfo, '')
+ self.assertRaises(UnicodeEncodeError, tarfile.data_filter, tarinfo, '')
+
+ @unittest.skipIf(sys.platform == 'win32', 'requires native bytes paths')
+ def test_extract_unencodable(self):
+ # Create a member with name \xed\xa0\x80 which is UTF-8 encoded
+ # lone surrogate \ud800.
+ with ArchiveMaker(encoding='ascii', errors='surrogateescape') as arc:
+ arc.add('\udced\udca0\udc80')
+ with os_helper.temp_cwd() as tmp:
+ tar = arc.open(encoding='utf-8', errors='surrogatepass',
+ errorlevel=1)
+ self.assertEqual(tar.getnames(), ['\ud800'])
+ with self.assertRaises(UnicodeEncodeError):
+ tar.extractall()
+ self.assertEqual(os.listdir(), [])
+
+ tar = arc.open(encoding='utf-8', errors='surrogatepass',
+ errorlevel=0, debug=1)
+ with support.captured_stderr() as stderr:
+ tar.extractall()
+ self.assertEqual(os.listdir(), [])
+ self.assertIn('tarfile: UnicodeEncodeError ', stderr.getvalue())
+
def test_change_default_filter_on_instance(self):
tar = tarfile.TarFile(tarname, 'r')
def strict_filter(tarinfo, path):
@@ -4161,13 +4516,13 @@ class TestExtractionFilters(unittest.TestCase):
# If errorlevel is 0, errors affected by errorlevel are ignored
with self.check_context(arc.open(errorlevel=0), extracterror_filter):
- self.expect_file('file')
+ pass
with self.check_context(arc.open(errorlevel=0), filtererror_filter):
- self.expect_file('file')
+ pass
with self.check_context(arc.open(errorlevel=0), oserror_filter):
- self.expect_file('file')
+ pass
with self.check_context(arc.open(errorlevel=0), tarerror_filter):
self.expect_exception(tarfile.TarError)
@@ -4178,7 +4533,7 @@ class TestExtractionFilters(unittest.TestCase):
# If 1, all fatal errors are raised
with self.check_context(arc.open(errorlevel=1), extracterror_filter):
- self.expect_file('file')
+ pass
with self.check_context(arc.open(errorlevel=1), filtererror_filter):
self.expect_exception(tarfile.FilterError)
@@ -4257,7 +4612,7 @@ def setUpModule():
data = fobj.read()
# Create compressed tarfiles.
- for c in GzipTest, Bz2Test, LzmaTest:
+ for c in GzipTest, Bz2Test, LzmaTest, ZstdTest:
if c.open:
os_helper.unlink(c.tarname)
testtarnames.append(c.tarname)
diff --git a/Lib/test/test_tempfile.py b/Lib/test/test_tempfile.py
index d46d3c0f040..52b13b98cbc 100644
--- a/Lib/test/test_tempfile.py
+++ b/Lib/test/test_tempfile.py
@@ -516,11 +516,11 @@ class TestMkstempInner(TestBadTempdir, BaseTestCase):
_mock_candidate_names('aaa', 'aaa', 'bbb'):
(fd1, name1) = self.make_temp()
os.close(fd1)
- self.assertTrue(name1.endswith('aaa'))
+ self.assertEndsWith(name1, 'aaa')
(fd2, name2) = self.make_temp()
os.close(fd2)
- self.assertTrue(name2.endswith('bbb'))
+ self.assertEndsWith(name2, 'bbb')
def test_collision_with_existing_directory(self):
# _mkstemp_inner tries another name when a directory with
@@ -528,11 +528,11 @@ class TestMkstempInner(TestBadTempdir, BaseTestCase):
with _inside_empty_temp_dir(), \
_mock_candidate_names('aaa', 'aaa', 'bbb'):
dir = tempfile.mkdtemp()
- self.assertTrue(dir.endswith('aaa'))
+ self.assertEndsWith(dir, 'aaa')
(fd, name) = self.make_temp()
os.close(fd)
- self.assertTrue(name.endswith('bbb'))
+ self.assertEndsWith(name, 'bbb')
class TestGetTempPrefix(BaseTestCase):
@@ -828,9 +828,9 @@ class TestMkdtemp(TestBadTempdir, BaseTestCase):
_mock_candidate_names('aaa', 'aaa', 'bbb'):
file = tempfile.NamedTemporaryFile(delete=False)
file.close()
- self.assertTrue(file.name.endswith('aaa'))
+ self.assertEndsWith(file.name, 'aaa')
dir = tempfile.mkdtemp()
- self.assertTrue(dir.endswith('bbb'))
+ self.assertEndsWith(dir, 'bbb')
def test_collision_with_existing_directory(self):
# mkdtemp tries another name when a directory with
@@ -838,9 +838,9 @@ class TestMkdtemp(TestBadTempdir, BaseTestCase):
with _inside_empty_temp_dir(), \
_mock_candidate_names('aaa', 'aaa', 'bbb'):
dir1 = tempfile.mkdtemp()
- self.assertTrue(dir1.endswith('aaa'))
+ self.assertEndsWith(dir1, 'aaa')
dir2 = tempfile.mkdtemp()
- self.assertTrue(dir2.endswith('bbb'))
+ self.assertEndsWith(dir2, 'bbb')
def test_for_tempdir_is_bytes_issue40701_api_warts(self):
orig_tempdir = tempfile.tempdir
diff --git a/Lib/test/test_termios.py b/Lib/test/test_termios.py
index e5d11cf84d2..ce8392a6ccd 100644
--- a/Lib/test/test_termios.py
+++ b/Lib/test/test_termios.py
@@ -290,8 +290,8 @@ class TestModule(unittest.TestCase):
self.assertGreaterEqual(value, 0)
def test_exception(self):
- self.assertTrue(issubclass(termios.error, Exception))
- self.assertFalse(issubclass(termios.error, OSError))
+ self.assertIsSubclass(termios.error, Exception)
+ self.assertNotIsSubclass(termios.error, OSError)
if __name__ == '__main__':
diff --git a/Lib/test/test_threadedtempfile.py b/Lib/test/test_threadedtempfile.py
index 420fc6ec8be..acb427b0c78 100644
--- a/Lib/test/test_threadedtempfile.py
+++ b/Lib/test/test_threadedtempfile.py
@@ -15,6 +15,7 @@ provoking a 2.0 failure under Linux.
import tempfile
+from test import support
from test.support import threading_helper
import unittest
import io
@@ -49,7 +50,8 @@ class TempFileGreedy(threading.Thread):
class ThreadedTempFileTest(unittest.TestCase):
- def test_main(self):
+ @support.bigmemtest(size=NUM_THREADS, memuse=60*2**20, dry_run=False)
+ def test_main(self, size):
threads = [TempFileGreedy() for i in range(NUM_THREADS)]
with threading_helper.start_threads(threads, startEvent.set):
pass
diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py
index b7688863626..13b55d0f0a2 100644
--- a/Lib/test/test_threading.py
+++ b/Lib/test/test_threading.py
@@ -5,7 +5,7 @@ Tests for the threading module.
import test.support
from test.support import threading_helper, requires_subprocess, requires_gil_enabled
from test.support import verbose, cpython_only, os_helper
-from test.support.import_helper import import_module
+from test.support.import_helper import ensure_lazy_imports, import_module
from test.support.script_helper import assert_python_ok, assert_python_failure
from test.support import force_not_colorized
@@ -28,7 +28,7 @@ from test import lock_tests
from test import support
try:
- from test.support import interpreters
+ from concurrent import interpreters
except ImportError:
interpreters = None
@@ -121,6 +121,10 @@ class ThreadTests(BaseTestCase):
maxDiff = 9999
@cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("threading", {"functools", "warnings"})
+
+ @cpython_only
def test_name(self):
def func(): pass
@@ -526,7 +530,8 @@ class ThreadTests(BaseTestCase):
finally:
sys.setswitchinterval(old_interval)
- def test_join_from_multiple_threads(self):
+ @support.bigmemtest(size=20, memuse=72*2**20, dry_run=False)
+ def test_join_from_multiple_threads(self, size):
# Thread.join() should be thread-safe
errors = []
@@ -1219,18 +1224,18 @@ class ThreadTests(BaseTestCase):
import threading
done = threading.Event()
- def loop():
+ def set_event():
done.set()
-
class Cycle:
def __init__(self):
self.self_ref = self
- self.thr = threading.Thread(target=loop, daemon=True)
+ self.thr = threading.Thread(target=set_event, daemon=True)
self.thr.start()
- done.wait()
+ self.thr.join()
def __del__(self):
+ assert done.is_set()
assert not self.thr.is_alive()
self.thr.join()
assert not self.thr.is_alive()
@@ -1248,7 +1253,7 @@ class ThreadTests(BaseTestCase):
# its state should be removed from interpreter' thread states list
# to avoid its double cleanup
try:
- from resource import setrlimit, RLIMIT_NPROC
+ from resource import setrlimit, RLIMIT_NPROC # noqa: F401
except ImportError as err:
self.skipTest(err) # RLIMIT_NPROC is specific to Linux and BSD
code = """if 1:
@@ -1279,12 +1284,6 @@ class ThreadTests(BaseTestCase):
@cpython_only
def test_finalize_daemon_thread_hang(self):
- if support.check_sanitizer(thread=True, memory=True):
- # the thread running `time.sleep(100)` below will still be alive
- # at process exit
- self.skipTest(
- "https://github.com/python/cpython/issues/124878 - Known"
- " race condition that TSAN identifies.")
# gh-87135: tests that daemon threads hang during finalization
script = textwrap.dedent('''
import os
@@ -1347,6 +1346,35 @@ class ThreadTests(BaseTestCase):
''')
assert_python_ok("-c", script)
+ @skip_unless_reliable_fork
+ @unittest.skipUnless(hasattr(threading, 'get_native_id'), "test needs threading.get_native_id()")
+ def test_native_id_after_fork(self):
+ script = """if True:
+ import threading
+ import os
+ from test import support
+
+ parent_thread_native_id = threading.current_thread().native_id
+ print(parent_thread_native_id, flush=True)
+ assert parent_thread_native_id == threading.get_native_id()
+ childpid = os.fork()
+ if childpid == 0:
+ print(threading.current_thread().native_id, flush=True)
+ assert threading.current_thread().native_id == threading.get_native_id()
+ else:
+ try:
+ assert parent_thread_native_id == threading.current_thread().native_id
+ assert parent_thread_native_id == threading.get_native_id()
+ finally:
+ support.wait_process(childpid, exitcode=0)
+ """
+ rc, out, err = assert_python_ok('-c', script)
+ self.assertEqual(rc, 0)
+ self.assertEqual(err, b"")
+ native_ids = out.strip().splitlines()
+ self.assertEqual(len(native_ids), 2)
+ self.assertNotEqual(native_ids[0], native_ids[1])
+
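The new test checks that the main thread's native_id is refreshed in the child after fork(). The behaviour can be observed directly; a POSIX-only sketch:

    import os
    import threading

    parent_native_id = threading.get_native_id()
    pid = os.fork()
    if pid == 0:
        # The child's main thread is a new OS thread, so its native id
        # differs from the parent's.
        assert threading.get_native_id() != parent_native_id
        os._exit(0)
    os.waitpid(pid, 0)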
class ThreadJoinOnShutdown(BaseTestCase):
def _run_and_join(self, script):
@@ -1427,7 +1455,8 @@ class ThreadJoinOnShutdown(BaseTestCase):
self._run_and_join(script)
@unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug")
- def test_4_daemon_threads(self):
+ @support.bigmemtest(size=40, memuse=70*2**20, dry_run=False)
+ def test_4_daemon_threads(self, size):
# Check that a daemon thread cannot crash the interpreter on shutdown
# by manipulating internal structures that are being disposed of in
# the main thread.
@@ -2131,8 +2160,7 @@ class CRLockTests(lock_tests.RLockTests):
]
for args, kwargs in arg_types:
with self.subTest(args=args, kwargs=kwargs):
- with self.assertWarns(DeprecationWarning):
- threading.RLock(*args, **kwargs)
+ self.assertRaises(TypeError, threading.RLock, *args, **kwargs)
# Subtypes with custom `__init__` are allowed (but, not recommended):
class CustomRLock(self.locktype):
@@ -2150,6 +2178,9 @@ class ConditionAsRLockTests(lock_tests.RLockTests):
# Condition uses an RLock by default and exports its API.
locktype = staticmethod(threading.Condition)
+ def test_constructor_noargs(self):
+ self.skipTest("Condition allows positional arguments")
+
def test_recursion_count(self):
self.skipTest("Condition does not expose _recursion_count()")
diff --git a/Lib/test/test_time.py b/Lib/test/test_time.py
index d06f65270ef..5312faa5077 100644
--- a/Lib/test/test_time.py
+++ b/Lib/test/test_time.py
@@ -761,17 +761,17 @@ class TestPytime(unittest.TestCase):
# Get the localtime and examine it for the offset and zone.
lt = time.localtime()
- self.assertTrue(hasattr(lt, "tm_gmtoff"))
- self.assertTrue(hasattr(lt, "tm_zone"))
+ self.assertHasAttr(lt, "tm_gmtoff")
+ self.assertHasAttr(lt, "tm_zone")
# See if the offset and zone are similar to the module
# attributes.
if lt.tm_gmtoff is None:
- self.assertTrue(not hasattr(time, "timezone"))
+ self.assertNotHasAttr(time, "timezone")
else:
self.assertEqual(lt.tm_gmtoff, -[time.timezone, time.altzone][lt.tm_isdst])
if lt.tm_zone is None:
- self.assertTrue(not hasattr(time, "tzname"))
+ self.assertNotHasAttr(time, "tzname")
else:
self.assertEqual(lt.tm_zone, time.tzname[lt.tm_isdst])
@@ -1184,11 +1184,11 @@ class TestTimeWeaklinking(unittest.TestCase):
if mac_ver >= (10, 12):
for name in clock_names:
- self.assertTrue(hasattr(time, name), f"time.{name} is not available")
+ self.assertHasAttr(time, name)
else:
for name in clock_names:
- self.assertFalse(hasattr(time, name), f"time.{name} is available")
+ self.assertNotHasAttr(time, name)
if __name__ == "__main__":
diff --git a/Lib/test/test_timeit.py b/Lib/test/test_timeit.py
index f5ae0a84eb3..2aeebea9f93 100644
--- a/Lib/test/test_timeit.py
+++ b/Lib/test/test_timeit.py
@@ -222,8 +222,8 @@ class TestTimeit(unittest.TestCase):
def assert_exc_string(self, exc_string, expected_exc_name):
exc_lines = exc_string.splitlines()
self.assertGreater(len(exc_lines), 2)
- self.assertTrue(exc_lines[0].startswith('Traceback'))
- self.assertTrue(exc_lines[-1].startswith(expected_exc_name))
+ self.assertStartsWith(exc_lines[0], 'Traceback')
+ self.assertStartsWith(exc_lines[-1], expected_exc_name)
def test_print_exc(self):
s = io.StringIO()
diff --git a/Lib/test/test_tkinter/support.py b/Lib/test/test_tkinter/support.py
index ebb9e00ff91..46b01e6f131 100644
--- a/Lib/test/test_tkinter/support.py
+++ b/Lib/test/test_tkinter/support.py
@@ -58,7 +58,7 @@ class AbstractDefaultRootTest:
destroy_default_root()
tkinter.NoDefaultRoot()
self.assertRaises(RuntimeError, constructor)
- self.assertFalse(hasattr(tkinter, '_default_root'))
+ self.assertNotHasAttr(tkinter, '_default_root')
def destroy_default_root():
diff --git a/Lib/test/test_tkinter/test_misc.py b/Lib/test/test_tkinter/test_misc.py
index 96ea3f0117c..0c76e07066f 100644
--- a/Lib/test/test_tkinter/test_misc.py
+++ b/Lib/test/test_tkinter/test_misc.py
@@ -497,7 +497,7 @@ class MiscTest(AbstractTkTest, unittest.TestCase):
self.assertEqual(vi.serial, 0)
else:
self.assertEqual(vi.micro, 0)
- self.assertTrue(str(vi).startswith(f'{vi.major}.{vi.minor}'))
+ self.assertStartsWith(str(vi), f'{vi.major}.{vi.minor}')
def test_embedded_null(self):
widget = tkinter.Entry(self.root)
@@ -609,7 +609,7 @@ class EventTest(AbstractTkTest, unittest.TestCase):
self.assertIsInstance(e.serial, int)
self.assertEqual(e.time, '??')
self.assertIs(e.send_event, False)
- self.assertFalse(hasattr(e, 'focus'))
+ self.assertNotHasAttr(e, 'focus')
self.assertEqual(e.num, '??')
self.assertEqual(e.state, '??')
self.assertEqual(e.char, '??')
@@ -642,7 +642,7 @@ class EventTest(AbstractTkTest, unittest.TestCase):
self.assertIsInstance(e.serial, int)
self.assertEqual(e.time, '??')
self.assertIs(e.send_event, False)
- self.assertFalse(hasattr(e, 'focus'))
+ self.assertNotHasAttr(e, 'focus')
self.assertEqual(e.num, '??')
self.assertEqual(e.state, '??')
self.assertEqual(e.char, '??')
@@ -676,7 +676,7 @@ class EventTest(AbstractTkTest, unittest.TestCase):
self.assertIsInstance(e.serial, int)
self.assertEqual(e.time, 0)
self.assertIs(e.send_event, False)
- self.assertFalse(hasattr(e, 'focus'))
+ self.assertNotHasAttr(e, 'focus')
self.assertEqual(e.num, '??')
self.assertIsInstance(e.state, int)
self.assertNotEqual(e.state, 0)
@@ -747,7 +747,7 @@ class EventTest(AbstractTkTest, unittest.TestCase):
self.assertIsInstance(e.serial, int)
self.assertEqual(e.time, 0)
self.assertIs(e.send_event, False)
- self.assertFalse(hasattr(e, 'focus'))
+ self.assertNotHasAttr(e, 'focus')
self.assertEqual(e.num, 1)
self.assertEqual(e.state, 0)
self.assertEqual(e.char, '??')
@@ -781,7 +781,7 @@ class EventTest(AbstractTkTest, unittest.TestCase):
self.assertIsInstance(e.serial, int)
self.assertEqual(e.time, 0)
self.assertIs(e.send_event, False)
- self.assertFalse(hasattr(e, 'focus'))
+ self.assertNotHasAttr(e, 'focus')
self.assertEqual(e.num, '??')
self.assertEqual(e.state, 0x100)
self.assertEqual(e.char, '??')
@@ -814,7 +814,7 @@ class EventTest(AbstractTkTest, unittest.TestCase):
self.assertIs(e.widget, f)
self.assertIsInstance(e.serial, int)
self.assertIs(e.send_event, False)
- self.assertFalse(hasattr(e, 'focus'))
+ self.assertNotHasAttr(e, 'focus')
self.assertEqual(e.time, 0)
self.assertEqual(e.num, '??')
self.assertEqual(e.state, 0)
@@ -849,7 +849,7 @@ class EventTest(AbstractTkTest, unittest.TestCase):
self.assertIsInstance(e.serial, int)
self.assertEqual(e.time, 0)
self.assertIs(e.send_event, False)
- self.assertFalse(hasattr(e, 'focus'))
+ self.assertNotHasAttr(e, 'focus')
self.assertEqual(e.num, '??')
self.assertEqual(e.state, 0)
self.assertEqual(e.char, '??')
@@ -1308,17 +1308,17 @@ class DefaultRootTest(AbstractDefaultRootTest, unittest.TestCase):
self.assertIs(tkinter._default_root, root)
tkinter.NoDefaultRoot()
self.assertIs(tkinter._support_default_root, False)
- self.assertFalse(hasattr(tkinter, '_default_root'))
+ self.assertNotHasAttr(tkinter, '_default_root')
# repeated call is no-op
tkinter.NoDefaultRoot()
self.assertIs(tkinter._support_default_root, False)
- self.assertFalse(hasattr(tkinter, '_default_root'))
+ self.assertNotHasAttr(tkinter, '_default_root')
root.destroy()
self.assertIs(tkinter._support_default_root, False)
- self.assertFalse(hasattr(tkinter, '_default_root'))
+ self.assertNotHasAttr(tkinter, '_default_root')
root = tkinter.Tk()
self.assertIs(tkinter._support_default_root, False)
- self.assertFalse(hasattr(tkinter, '_default_root'))
+ self.assertNotHasAttr(tkinter, '_default_root')
root.destroy()
def test_getboolean(self):
diff --git a/Lib/test/test_tkinter/test_widgets.py b/Lib/test/test_tkinter/test_widgets.py
index f6e77973061..ff3f92e9b5e 100644
--- a/Lib/test/test_tkinter/test_widgets.py
+++ b/Lib/test/test_tkinter/test_widgets.py
@@ -354,6 +354,11 @@ class OptionMenuTest(MenubuttonTest, unittest.TestCase):
with self.assertRaisesRegex(TclError, r"^unknown option -image$"):
tkinter.OptionMenu(self.root, None, 'b', image='')
+ def test_specify_name(self):
+ widget = tkinter.OptionMenu(self.root, None, ':)', name="option_menu")
+ self.assertEqual(str(widget), ".option_menu")
+ self.assertIs(self.root.children["option_menu"], widget)
+
@add_configure_tests(IntegerSizeTests, StandardOptionsTests)
class EntryTest(AbstractWidgetTest, unittest.TestCase):
_rounds_pixels = (tk_version < (9, 0))
diff --git a/Lib/test/test_tkinter/widget_tests.py b/Lib/test/test_tkinter/widget_tests.py
index ac7fb5977e0..f518925e994 100644
--- a/Lib/test/test_tkinter/widget_tests.py
+++ b/Lib/test/test_tkinter/widget_tests.py
@@ -65,7 +65,7 @@ class AbstractWidgetTest(AbstractTkTest):
orig = widget[name]
if errmsg is not None:
errmsg = errmsg.format(re.escape(str(value)))
- errmsg = fr'\A{errmsg}\Z'
+ errmsg = fr'\A{errmsg}\z'
with self.assertRaisesRegex(tkinter.TclError, errmsg or ''):
widget[name] = value
self.assertEqual(widget[name], orig)
diff --git a/Lib/test/test_tokenize.py b/Lib/test/test_tokenize.py
index 2d41a5e5ac0..865e0c5b40d 100644
--- a/Lib/test/test_tokenize.py
+++ b/Lib/test/test_tokenize.py
@@ -1,6 +1,8 @@
import contextlib
+import itertools
import os
import re
+import string
import tempfile
import token
import tokenize
@@ -1975,6 +1977,10 @@ if 1:
for case in cases:
self.check_roundtrip(case)
+ self.check_roundtrip(r"t'{ {}}'")
+ self.check_roundtrip(r"t'{f'{ {}}'}{ {}}'")
+ self.check_roundtrip(r"f'{t'{ {}}'}{ {}}'")
+
def test_continuation(self):
# Balancing continuation
@@ -3234,5 +3240,77 @@ class CommandLineTest(unittest.TestCase):
self.check_output(source, expect, flag)
+class StringPrefixTest(unittest.TestCase):
+ @staticmethod
+ def determine_valid_prefixes():
+ # Try all lengths until we find a length that has zero valid
+ # prefixes. This will miss the case where for example there
+ # are no valid 3 character prefixes, but there are valid 4
+ # character prefixes. That seems unlikely.
+
+ single_char_valid_prefixes = set()
+
+ # Find all of the single character string prefixes. Just get
+ # the lowercase version, we'll deal with combinations of upper
+ # and lower case later. I'm using this logic just in case
+ # some uppercase-only prefix is added.
+ for letter in itertools.chain(string.ascii_lowercase, string.ascii_uppercase):
+ try:
+ eval(f'{letter}""')
+ single_char_valid_prefixes.add(letter.lower())
+ except SyntaxError:
+ pass
+
+ # This logic assumes that all combinations of valid prefixes only use
+ # the characters that are valid single character prefixes. That seems
+ # like a valid assumption, but if it ever changes this will need
+ # adjusting.
+ valid_prefixes = set()
+ for length in itertools.count():
+ num_at_this_length = 0
+ for prefix in (
+ "".join(l)
+ for l in itertools.combinations(single_char_valid_prefixes, length)
+ ):
+ for t in itertools.permutations(prefix):
+ for u in itertools.product(*[(c, c.upper()) for c in t]):
+ p = "".join(u)
+ if p == "not":
+ # 'not' can never be a string prefix,
+ # because it's a valid expression: not ""
+ continue
+ try:
+ eval(f'{p}""')
+
+ # No syntax error, so p is a valid string
+ # prefix.
+
+ valid_prefixes.add(p)
+ num_at_this_length += 1
+ except SyntaxError:
+ pass
+ if num_at_this_length == 0:
+ return valid_prefixes
+
+
+ def test_prefixes(self):
+ # Get the list of defined string prefixes. I don't see an
+ # obvious documented way of doing this, but probably the best
+ # thing is to split apart tokenize.StringPrefix.
+
+ # Make sure StringPrefix begins and ends in parens. We're
+ # assuming it's of the form "(a|b|ab)", if a, b, and ab are
+ # valid string prefixes.
+ self.assertEqual(tokenize.StringPrefix[0], '(')
+ self.assertEqual(tokenize.StringPrefix[-1], ')')
+
+ # Then split apart everything else by '|'.
+ defined_prefixes = set(tokenize.StringPrefix[1:-1].split('|'))
+
+ # Now compute the actual allowed string prefixes and compare
+ # to what is defined in the tokenize module.
+ self.assertEqual(defined_prefixes, self.determine_valid_prefixes())
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_tools/i18n_data/docstrings.py b/Lib/test/test_tools/i18n_data/docstrings.py
index 151a55a4b56..14559a632da 100644
--- a/Lib/test/test_tools/i18n_data/docstrings.py
+++ b/Lib/test/test_tools/i18n_data/docstrings.py
@@ -1,7 +1,7 @@
"""Module docstring"""
# Test docstring extraction
-from gettext import gettext as _
+from gettext import gettext as _ # noqa: F401
# Empty docstring
diff --git a/Lib/test/test_tools/test_i18n.py b/Lib/test/test_tools/test_i18n.py
index 8416b1bad82..d1831d68f02 100644
--- a/Lib/test/test_tools/test_i18n.py
+++ b/Lib/test/test_tools/test_i18n.py
@@ -162,6 +162,14 @@ class Test_pygettext(unittest.TestCase):
# This will raise if the date format does not exactly match.
datetime.strptime(creationDate, '%Y-%m-%d %H:%M%z')
+ def test_output_option(self):
+ for opt in ('-o', '--output='):
+ with temp_cwd():
+ assert_python_ok(self.script, f'{opt}test')
+ self.assertTrue(os.path.exists('test'))
+ res = assert_python_ok(self.script, f'{opt}-')
+ self.assertIn(b'Project-Id-Version: PACKAGE VERSION', res.out)
+
def test_funcdocstring(self):
for doc in ('"""doc"""', "r'''doc'''", "R'doc'", 'u"doc"'):
with self.subTest(doc):
diff --git a/Lib/test/test_tools/test_msgfmt.py b/Lib/test/test_tools/test_msgfmt.py
index ea10d4693df..7be606bbff6 100644
--- a/Lib/test/test_tools/test_msgfmt.py
+++ b/Lib/test/test_tools/test_msgfmt.py
@@ -9,18 +9,21 @@ from pathlib import Path
from test.support.os_helper import temp_cwd
from test.support.script_helper import assert_python_failure, assert_python_ok
-from test.test_tools import skip_if_missing, toolsdir
+from test.test_tools import imports_under_tool, skip_if_missing, toolsdir
skip_if_missing('i18n')
data_dir = (Path(__file__).parent / 'msgfmt_data').resolve()
script_dir = Path(toolsdir) / 'i18n'
-msgfmt = script_dir / 'msgfmt.py'
+msgfmt_py = script_dir / 'msgfmt.py'
+
+with imports_under_tool("i18n"):
+ import msgfmt
def compile_messages(po_file, mo_file):
- assert_python_ok(msgfmt, '-o', mo_file, po_file)
+ assert_python_ok(msgfmt_py, '-o', mo_file, po_file)
class CompilationTest(unittest.TestCase):
@@ -92,7 +95,7 @@ class CompilationTest(unittest.TestCase):
with temp_cwd():
Path('bom.po').write_bytes(b'\xef\xbb\xbfmsgid "Python"\nmsgstr "Pioton"\n')
- res = assert_python_failure(msgfmt, 'bom.po')
+ res = assert_python_failure(msgfmt_py, 'bom.po')
err = res.err.decode('utf-8')
self.assertIn('The file bom.po starts with a UTF-8 BOM', err)
@@ -103,7 +106,7 @@ msgid_plural "plural"
msgstr[0] "singular"
''')
- res = assert_python_failure(msgfmt, 'invalid.po')
+ res = assert_python_failure(msgfmt_py, 'invalid.po')
err = res.err.decode('utf-8')
self.assertIn('msgid_plural not preceded by msgid', err)
@@ -114,7 +117,7 @@ msgid "foo"
msgstr[0] "bar"
''')
- res = assert_python_failure(msgfmt, 'invalid.po')
+ res = assert_python_failure(msgfmt_py, 'invalid.po')
err = res.err.decode('utf-8')
self.assertIn('plural without msgid_plural', err)
@@ -126,7 +129,7 @@ msgid_plural "foos"
msgstr "bar"
''')
- res = assert_python_failure(msgfmt, 'invalid.po')
+ res = assert_python_failure(msgfmt_py, 'invalid.po')
err = res.err.decode('utf-8')
self.assertIn('indexed msgstr required for plural', err)
@@ -136,38 +139,136 @@ msgstr "bar"
"foo"
''')
- res = assert_python_failure(msgfmt, 'invalid.po')
+ res = assert_python_failure(msgfmt_py, 'invalid.po')
err = res.err.decode('utf-8')
self.assertIn('Syntax error', err)
+
+class POParserTest(unittest.TestCase):
+ @classmethod
+ def tearDownClass(cls):
+ # msgfmt uses a global variable to store messages,
+ # clear it after the tests.
+ msgfmt.MESSAGES.clear()
+
+ def test_strings(self):
+ # Test that the PO parser correctly handles and unescapes
+ # strings in the PO file.
+ # The PO file format allows for a variety of escape sequences,
+ # including octal and hex escapes.
+ valid_strings = (
+ # empty strings
+ ('""', ''),
+ ('"" "" ""', ''),
+ # allowed escape sequences
+ (r'"\\"', '\\'),
+ (r'"\""', '"'),
+ (r'"\t"', '\t'),
+ (r'"\n"', '\n'),
+ (r'"\r"', '\r'),
+ (r'"\f"', '\f'),
+ (r'"\a"', '\a'),
+ (r'"\b"', '\b'),
+ (r'"\v"', '\v'),
+ # non-empty strings
+ ('"foo"', 'foo'),
+ ('"foo" "bar"', 'foobar'),
+ ('"foo""bar"', 'foobar'),
+ ('"" "foo" ""', 'foo'),
+ # newlines and tabs
+ (r'"foo\nbar"', 'foo\nbar'),
+ (r'"foo\n" "bar"', 'foo\nbar'),
+ (r'"foo\tbar"', 'foo\tbar'),
+ (r'"foo\t" "bar"', 'foo\tbar'),
+ # escaped quotes
+ (r'"foo\"bar"', 'foo"bar'),
+ (r'"foo\"" "bar"', 'foo"bar'),
+ (r'"foo\\" "bar"', 'foo\\bar'),
+ # octal escapes
+ (r'"\120\171\164\150\157\156"', 'Python'),
+ (r'"\120\171\164" "\150\157\156"', 'Python'),
+ (r'"\"\120\171\164" "\150\157\156\""', '"Python"'),
+ # hex escapes
+ (r'"\x50\x79\x74\x68\x6f\x6e"', 'Python'),
+ (r'"\x50\x79\x74" "\x68\x6f\x6e"', 'Python'),
+ (r'"\"\x50\x79\x74" "\x68\x6f\x6e\""', '"Python"'),
+ )
+
+ with temp_cwd():
+ for po_string, expected in valid_strings:
+ with self.subTest(po_string=po_string):
+ # Construct a PO file with a single entry,
+ # compile it, read it into a catalog and
+ # check the result.
+ po = f'msgid {po_string}\nmsgstr "translation"'
+ Path('messages.po').write_text(po)
+ # Reset the global MESSAGES dictionary
+ msgfmt.MESSAGES.clear()
+ msgfmt.make('messages.po', 'messages.mo')
+
+ with open('messages.mo', 'rb') as f:
+ actual = GNUTranslations(f)
+
+ self.assertDictEqual(actual._catalog, {expected: 'translation'})
+
+ invalid_strings = (
+ # "''", # invalid but currently accepted
+ '"',
+ '"""',
+ '"" "',
+ 'foo',
+ '"" "foo',
+ '"foo" foo',
+ '42',
+ '"" 42 ""',
+ # disallowed escape sequences
+ # r'"\'"', # invalid but currently accepted
+ # r'"\e"', # invalid but currently accepted
+ # r'"\8"', # invalid but currently accepted
+ # r'"\9"', # invalid but currently accepted
+ r'"\x"',
+ r'"\u1234"',
+ r'"\N{ROMAN NUMERAL NINE}"'
+ )
+ with temp_cwd():
+ for invalid_string in invalid_strings:
+ with self.subTest(string=invalid_string):
+ po = f'msgid {invalid_string}\nmsgstr "translation"'
+ Path('messages.po').write_text(po)
+ # Reset the global MESSAGES dictionary
+ msgfmt.MESSAGES.clear()
+ with self.assertRaises(Exception):
+ msgfmt.make('messages.po', 'messages.mo')
+
+
class CLITest(unittest.TestCase):
def test_help(self):
for option in ('--help', '-h'):
- res = assert_python_ok(msgfmt, option)
+ res = assert_python_ok(msgfmt_py, option)
err = res.err.decode('utf-8')
self.assertIn('Generate binary message catalog from textual translation description.', err)
def test_version(self):
for option in ('--version', '-V'):
- res = assert_python_ok(msgfmt, option)
+ res = assert_python_ok(msgfmt_py, option)
out = res.out.decode('utf-8').strip()
self.assertEqual('msgfmt.py 1.2', out)
def test_invalid_option(self):
- res = assert_python_failure(msgfmt, '--invalid-option')
+ res = assert_python_failure(msgfmt_py, '--invalid-option')
err = res.err.decode('utf-8')
self.assertIn('Generate binary message catalog from textual translation description.', err)
self.assertIn('option --invalid-option not recognized', err)
def test_no_input_file(self):
- res = assert_python_ok(msgfmt)
+ res = assert_python_ok(msgfmt_py)
err = res.err.decode('utf-8').replace('\r\n', '\n')
self.assertIn('No input file given\n'
"Try `msgfmt --help' for more information.", err)
def test_nonexistent_file(self):
- assert_python_failure(msgfmt, 'nonexistent.po')
+ assert_python_failure(msgfmt_py, 'nonexistent.po')
def update_catalog_snapshots():
diff --git a/Lib/test/test_traceback.py b/Lib/test/test_traceback.py
index a806dbf1582..74b979d0096 100644
--- a/Lib/test/test_traceback.py
+++ b/Lib/test/test_traceback.py
@@ -37,6 +37,12 @@ test_code.co_positions = lambda _: iter([(6, 6, 0, 0)])
test_frame = namedtuple('frame', ['f_code', 'f_globals', 'f_locals'])
test_tb = namedtuple('tb', ['tb_frame', 'tb_lineno', 'tb_next', 'tb_lasti'])
+color_overrides = {"reset": "z", "filename": "fn", "error_highlight": "E"}
+colors = {
+ color_overrides.get(k, k[0].lower()): v
+ for k, v in _colorize.default_theme.traceback.items()
+}
+
LEVENSHTEIN_DATA_FILE = Path(__file__).parent / 'levenshtein_examples.json'
@@ -4182,6 +4188,15 @@ class SuggestionFormattingTestBase:
self.assertNotIn("blech", actual)
self.assertNotIn("oh no!", actual)
+ def test_attribute_error_with_non_string_candidates(self):
+ class T:
+ bluch = 1
+
+ instance = T()
+ instance.__dict__[0] = 1
+ actual = self.get_suggestion(instance, 'blich')
+ self.assertIn("bluch", actual)
+
def test_attribute_error_with_bad_name(self):
def raise_attribute_error_with_bad_name():
raise AttributeError(name=12, obj=23)
@@ -4217,8 +4232,8 @@ class SuggestionFormattingTestBase:
return mod_name
- def get_import_from_suggestion(self, mod_dict, name):
- modname = self.make_module(mod_dict)
+ def get_import_from_suggestion(self, code, name):
+ modname = self.make_module(code)
def callable():
try:
@@ -4295,6 +4310,13 @@ class SuggestionFormattingTestBase:
self.assertIn("'_bluch'", self.get_import_from_suggestion(code, '_luch'))
self.assertNotIn("'_bluch'", self.get_import_from_suggestion(code, 'bluch'))
+ def test_import_from_suggestions_non_string(self):
+ modWithNonStringAttr = textwrap.dedent("""\
+ globals()[0] = 1
+ bluch = 1
+ """)
+ self.assertIn("'bluch'", self.get_import_from_suggestion(modWithNonStringAttr, 'blech'))
+
def test_import_from_suggestions_do_not_trigger_for_long_attributes(self):
code = "blech = None"
@@ -4391,6 +4413,15 @@ class SuggestionFormattingTestBase:
actual = self.get_suggestion(func)
self.assertIn("'ZeroDivisionError'?", actual)
+ def test_name_error_suggestions_with_non_string_candidates(self):
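+ # Non-string keys in globals must not break NameError suggestions.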
+ def func():
+ abc = 1
+ custom_globals = globals().copy()
+ custom_globals[0] = 1
+ print(eval("abv", custom_globals, locals()))
+ actual = self.get_suggestion(func)
+ self.assertIn("abc", actual)
+
def test_name_error_suggestions_do_not_trigger_for_long_names(self):
def func():
somethingverywronghehehehehehe = None
@@ -4555,6 +4586,28 @@ class SuggestionFormattingTestBase:
actual = self.get_suggestion(instance.foo)
self.assertNotIn("self.blech", actual)
+ def test_unbound_local_error_with_side_effect(self):
+ # gh-132385
+ class A:
+ def __getattr__(self, key):
+ if key == 'foo':
+ raise AttributeError('foo')
+ if key == 'spam':
+ raise ValueError('spam')
+
+ def bar(self):
+ foo
+ def baz(self):
+ spam
+
+ suggestion = self.get_suggestion(A().bar)
+ self.assertNotIn('self.', suggestion)
+ self.assertIn("'foo'", suggestion)
+
+ suggestion = self.get_suggestion(A().baz)
+ self.assertNotIn('self.', suggestion)
+ self.assertIn("'spam'", suggestion)
+
def test_unbound_local_error_does_not_match(self):
def func():
something = 3
@@ -4699,6 +4752,8 @@ class MiscTest(unittest.TestCase):
class TestColorizedTraceback(unittest.TestCase):
+ maxDiff = None
+
def test_colorized_traceback(self):
def foo(*args):
x = {'a':{'b': None}}
@@ -4721,9 +4776,9 @@ class TestColorizedTraceback(unittest.TestCase):
e, capture_locals=True
)
lines = "".join(exc.format(colorize=True))
- red = _colorize.ANSIColors.RED
- boldr = _colorize.ANSIColors.BOLD_RED
- reset = _colorize.ANSIColors.RESET
+ red = colors["e"]
+ boldr = colors["E"]
+ reset = colors["z"]
self.assertIn("y = " + red + "x['a']['b']" + reset + boldr + "['c']" + reset, lines)
self.assertIn("return " + red + "(lambda *args: foo(*args))" + reset + boldr + "(1,2,3,4)" + reset, lines)
self.assertIn("return (lambda *args: " + red + "foo" + reset + boldr + "(*args)" + reset + ")(1,2,3,4)", lines)
@@ -4739,18 +4794,16 @@ class TestColorizedTraceback(unittest.TestCase):
e, capture_locals=True
)
actual = "".join(exc.format(colorize=True))
- red = _colorize.ANSIColors.RED
- magenta = _colorize.ANSIColors.MAGENTA
- boldm = _colorize.ANSIColors.BOLD_MAGENTA
- boldr = _colorize.ANSIColors.BOLD_RED
- reset = _colorize.ANSIColors.RESET
- expected = "".join([
- f' File {magenta}"<string>"{reset}, line {magenta}1{reset}\n',
- f' a {boldr}${reset} b\n',
- f' {boldr}^{reset}\n',
- f'{boldm}SyntaxError{reset}: {magenta}invalid syntax{reset}\n']
- )
- self.assertIn(expected, actual)
+ def expected(t, m, fn, l, f, E, e, z):
+ return "".join(
+ [
+ f' File {fn}"<string>"{z}, line {l}1{z}\n',
+ f' a {E}${z} b\n',
+ f' {E}^{z}\n',
+ f'{t}SyntaxError{z}: {m}invalid syntax{z}\n'
+ ]
+ )
+ self.assertIn(expected(**colors), actual)
def test_colorized_traceback_is_the_default(self):
def foo():
@@ -4766,23 +4819,21 @@ class TestColorizedTraceback(unittest.TestCase):
exception_print(e)
actual = tbstderr.getvalue().splitlines()
- red = _colorize.ANSIColors.RED
- boldr = _colorize.ANSIColors.BOLD_RED
- magenta = _colorize.ANSIColors.MAGENTA
- boldm = _colorize.ANSIColors.BOLD_MAGENTA
- reset = _colorize.ANSIColors.RESET
lno_foo = foo.__code__.co_firstlineno
- expected = ['Traceback (most recent call last):',
- f' File {magenta}"{__file__}"{reset}, '
- f'line {magenta}{lno_foo+5}{reset}, in {magenta}test_colorized_traceback_is_the_default{reset}',
- f' {red}foo{reset+boldr}(){reset}',
- f' {red}~~~{reset+boldr}^^{reset}',
- f' File {magenta}"{__file__}"{reset}, '
- f'line {magenta}{lno_foo+1}{reset}, in {magenta}foo{reset}',
- f' {red}1{reset+boldr}/{reset+red}0{reset}',
- f' {red}~{reset+boldr}^{reset+red}~{reset}',
- f'{boldm}ZeroDivisionError{reset}: {magenta}division by zero{reset}']
- self.assertEqual(actual, expected)
+ def expected(t, m, fn, l, f, E, e, z):
+ return [
+ 'Traceback (most recent call last):',
+ f' File {fn}"{__file__}"{z}, '
+ f'line {l}{lno_foo+5}{z}, in {f}test_colorized_traceback_is_the_default{z}',
+ f' {e}foo{z}{E}(){z}',
+ f' {e}~~~{z}{E}^^{z}',
+ f' File {fn}"{__file__}"{z}, '
+ f'line {l}{lno_foo+1}{z}, in {f}foo{z}',
+ f' {e}1{z}{E}/{z}{e}0{z}',
+ f' {e}~{z}{E}^{z}{e}~{z}',
+ f'{t}ZeroDivisionError{z}: {m}division by zero{z}',
+ ]
+ self.assertEqual(actual, expected(**colors))
def test_colorized_traceback_from_exception_group(self):
def foo():
@@ -4800,33 +4851,31 @@ class TestColorizedTraceback(unittest.TestCase):
e, capture_locals=True
)
- red = _colorize.ANSIColors.RED
- boldr = _colorize.ANSIColors.BOLD_RED
- magenta = _colorize.ANSIColors.MAGENTA
- boldm = _colorize.ANSIColors.BOLD_MAGENTA
- reset = _colorize.ANSIColors.RESET
lno_foo = foo.__code__.co_firstlineno
actual = "".join(exc.format(colorize=True)).splitlines()
- expected = [f" + Exception Group Traceback (most recent call last):",
- f' | File {magenta}"{__file__}"{reset}, line {magenta}{lno_foo+9}{reset}, in {magenta}test_colorized_traceback_from_exception_group{reset}',
- f' | {red}foo{reset}{boldr}(){reset}',
- f' | {red}~~~{reset}{boldr}^^{reset}',
- f" | e = ExceptionGroup('test', [ZeroDivisionError('division by zero')])",
- f" | foo = {foo}",
- f' | self = <{__name__}.TestColorizedTraceback testMethod=test_colorized_traceback_from_exception_group>',
- f' | File {magenta}"{__file__}"{reset}, line {magenta}{lno_foo+6}{reset}, in {magenta}foo{reset}',
- f' | raise ExceptionGroup("test", exceptions)',
- f" | exceptions = [ZeroDivisionError('division by zero')]",
- f' | {boldm}ExceptionGroup{reset}: {magenta}test (1 sub-exception){reset}',
- f' +-+---------------- 1 ----------------',
- f' | Traceback (most recent call last):',
- f' | File {magenta}"{__file__}"{reset}, line {magenta}{lno_foo+3}{reset}, in {magenta}foo{reset}',
- f' | {red}1 {reset}{boldr}/{reset}{red} 0{reset}',
- f' | {red}~~{reset}{boldr}^{reset}{red}~~{reset}',
- f" | exceptions = [ZeroDivisionError('division by zero')]",
- f' | {boldm}ZeroDivisionError{reset}: {magenta}division by zero{reset}',
- f' +------------------------------------']
- self.assertEqual(actual, expected)
+ def expected(t, m, fn, l, f, E, e, z):
+ return [
+ f" + Exception Group Traceback (most recent call last):",
+ f' | File {fn}"{__file__}"{z}, line {l}{lno_foo+9}{z}, in {f}test_colorized_traceback_from_exception_group{z}',
+ f' | {e}foo{z}{E}(){z}',
+ f' | {e}~~~{z}{E}^^{z}',
+ f" | e = ExceptionGroup('test', [ZeroDivisionError('division by zero')])",
+ f" | foo = {foo}",
+ f' | self = <{__name__}.TestColorizedTraceback testMethod=test_colorized_traceback_from_exception_group>',
+ f' | File {fn}"{__file__}"{z}, line {l}{lno_foo+6}{z}, in {f}foo{z}',
+ f' | raise ExceptionGroup("test", exceptions)',
+ f" | exceptions = [ZeroDivisionError('division by zero')]",
+ f' | {t}ExceptionGroup{z}: {m}test (1 sub-exception){z}',
+ f' +-+---------------- 1 ----------------',
+ f' | Traceback (most recent call last):',
+ f' | File {fn}"{__file__}"{z}, line {l}{lno_foo+3}{z}, in {f}foo{z}',
+ f' | {e}1 {z}{E}/{z}{e} 0{z}',
+ f' | {e}~~{z}{E}^{z}{e}~~{z}',
+ f" | exceptions = [ZeroDivisionError('division by zero')]",
+ f' | {t}ZeroDivisionError{z}: {m}division by zero{z}',
+ f' +------------------------------------',
+ ]
+ self.assertEqual(actual, expected(**colors))
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_tstring.py b/Lib/test/test_tstring.py
new file mode 100644
index 00000000000..aabae385567
--- /dev/null
+++ b/Lib/test/test_tstring.py
@@ -0,0 +1,314 @@
+import unittest
+
+from test.test_string._support import TStringBaseCase, fstring
+
+
+class TestTString(unittest.TestCase, TStringBaseCase):
+ def test_string_representation(self):
+ # Test __repr__
+ t = t"Hello"
+ self.assertEqual(repr(t), "Template(strings=('Hello',), interpolations=())")
+
+ name = "Python"
+ t = t"Hello, {name}"
+ self.assertEqual(repr(t),
+ "Template(strings=('Hello, ', ''), "
+ "interpolations=(Interpolation('Python', 'name', None, ''),))"
+ )
+
+ def test_interpolation_basics(self):
+ # Test basic interpolation
+ name = "Python"
+ t = t"Hello, {name}"
+ self.assertTStringEqual(t, ("Hello, ", ""), [(name, "name")])
+ self.assertEqual(fstring(t), "Hello, Python")
+
+ # Multiple interpolations
+ first = "Python"
+ last = "Developer"
+ t = t"{first} {last}"
+ self.assertTStringEqual(
+ t, ("", " ", ""), [(first, 'first'), (last, 'last')]
+ )
+ self.assertEqual(fstring(t), "Python Developer")
+
+ # Interpolation with expressions
+ a = 10
+ b = 20
+ t = t"Sum: {a + b}"
+ self.assertTStringEqual(t, ("Sum: ", ""), [(a + b, "a + b")])
+ self.assertEqual(fstring(t), "Sum: 30")
+
+ # Interpolation with function
+ def square(x):
+ return x * x
+ t = t"Square: {square(5)}"
+ self.assertTStringEqual(
+ t, ("Square: ", ""), [(square(5), "square(5)")]
+ )
+ self.assertEqual(fstring(t), "Square: 25")
+
+ # Test attribute access in expressions
+ class Person:
+ def __init__(self, name):
+ self.name = name
+
+ def upper(self):
+ return self.name.upper()
+
+ person = Person("Alice")
+ t = t"Name: {person.name}"
+ self.assertTStringEqual(
+ t, ("Name: ", ""), [(person.name, "person.name")]
+ )
+ self.assertEqual(fstring(t), "Name: Alice")
+
+ # Test method calls
+ t = t"Name: {person.upper()}"
+ self.assertTStringEqual(
+ t, ("Name: ", ""), [(person.upper(), "person.upper()")]
+ )
+ self.assertEqual(fstring(t), "Name: ALICE")
+
+ # Test dictionary access
+ data = {"name": "Bob", "age": 30}
+ t = t"Name: {data['name']}, Age: {data['age']}"
+ self.assertTStringEqual(
+ t, ("Name: ", ", Age: ", ""),
+ [(data["name"], "data['name']"), (data["age"], "data['age']")],
+ )
+ self.assertEqual(fstring(t), "Name: Bob, Age: 30")
+
+ def test_format_specifiers(self):
+ # Test basic format specifiers
+ value = 3.14159
+ t = t"Pi: {value:.2f}"
+ self.assertTStringEqual(
+ t, ("Pi: ", ""), [(value, "value", None, ".2f")]
+ )
+ self.assertEqual(fstring(t), "Pi: 3.14")
+
+ def test_conversions(self):
+ # Test !s conversion (str)
+ obj = object()
+ t = t"Object: {obj!s}"
+ self.assertTStringEqual(t, ("Object: ", ""), [(obj, "obj", "s")])
+ self.assertEqual(fstring(t), f"Object: {str(obj)}")
+
+ # Test !r conversion (repr)
+ t = t"Data: {obj!r}"
+ self.assertTStringEqual(t, ("Data: ", ""), [(obj, "obj", "r")])
+ self.assertEqual(fstring(t), f"Data: {repr(obj)}")
+
+ # Test !a conversion (ascii)
+ text = "Café"
+ t = t"ASCII: {text!a}"
+ self.assertTStringEqual(t, ("ASCII: ", ""), [(text, "text", "a")])
+ self.assertEqual(fstring(t), f"ASCII: {ascii(text)}")
+
+ # Test !z conversion (error)
+ num = 1
+ with self.assertRaises(SyntaxError):
+ eval("t'{num!z}'")
+
+ def test_debug_specifier(self):
+ # Test debug specifier
+ value = 42
+ t = t"Value: {value=}"
+ self.assertTStringEqual(
+ t, ("Value: value=", ""), [(value, "value", "r")]
+ )
+ self.assertEqual(fstring(t), "Value: value=42")
+
+ # Test debug specifier with a format spec (the implicit !r conversion is not applied)
+ t = t"Value: {value=:.2f}"
+ self.assertTStringEqual(
+ t, ("Value: value=", ""), [(value, "value", None, ".2f")]
+ )
+ self.assertEqual(fstring(t), "Value: value=42.00")
+
+ # Test debug specifier with conversion
+ t = t"Value: {value=!s}"
+ self.assertTStringEqual(
+ t, ("Value: value=", ""), [(value, "value", "s")]
+ )
+
+ # Test whitespace in the debug specifier
+ t = t"Value: {value = }"
+ self.assertTStringEqual(
+ t, ("Value: value = ", ""), [(value, "value", "r")]
+ )
+ self.assertEqual(fstring(t), "Value: value = 42")
+
+ def test_raw_tstrings(self):
+ path = r"C:\Users"
+ t = rt"{path}\Documents"
+ self.assertTStringEqual(t, ("", r"\Documents"), [(path, "path")])
+ self.assertEqual(fstring(t), r"C:\Users\Documents")
+
+ # Test alternative prefix
+ t = tr"{path}\Documents"
+ self.assertTStringEqual(t, ("", r"\Documents"), [(path, "path")])
+
+
+ def test_template_concatenation(self):
+ # Test template + template
+ t1 = t"Hello, "
+ t2 = t"world"
+ combined = t1 + t2
+ self.assertTStringEqual(combined, ("Hello, world",), ())
+ self.assertEqual(fstring(combined), "Hello, world")
+
+ # Test template + string
+ t1 = t"Hello"
+ combined = t1 + ", world"
+ self.assertTStringEqual(combined, ("Hello, world",), ())
+ self.assertEqual(fstring(combined), "Hello, world")
+
+ # Test template + template with interpolation
+ name = "Python"
+ t1 = t"Hello, "
+ t2 = t"{name}"
+ combined = t1 + t2
+ self.assertTStringEqual(combined, ("Hello, ", ""), [(name, "name")])
+ self.assertEqual(fstring(combined), "Hello, Python")
+
+ # Test string + template
+ t = "Hello, " + t"{name}"
+ self.assertTStringEqual(t, ("Hello, ", ""), [(name, "name")])
+ self.assertEqual(fstring(t), "Hello, Python")
+
+ def test_nested_templates(self):
+ # Test a template inside another template expression
+ name = "Python"
+ inner = t"{name}"
+ t = t"Language: {inner}"
+
+ t_interp = t.interpolations[0]
+ self.assertEqual(t.strings, ("Language: ", ""))
+ self.assertEqual(t_interp.value.strings, ("", ""))
+ self.assertEqual(t_interp.value.interpolations[0].value, name)
+ self.assertEqual(t_interp.value.interpolations[0].expression, "name")
+ self.assertEqual(t_interp.value.interpolations[0].conversion, None)
+ self.assertEqual(t_interp.value.interpolations[0].format_spec, "")
+ self.assertEqual(t_interp.expression, "inner")
+ self.assertEqual(t_interp.conversion, None)
+ self.assertEqual(t_interp.format_spec, "")
+
+ def test_syntax_errors(self):
+ for case, err in (
+ ("t'", "unterminated t-string literal"),
+ ("t'''", "unterminated triple-quoted t-string literal"),
+ ("t''''", "unterminated triple-quoted t-string literal"),
+ ("t'{", "'{' was never closed"),
+ ("t'{'", "t-string: expecting '}'"),
+ ("t'{a'", "t-string: expecting '}'"),
+ ("t'}'", "t-string: single '}' is not allowed"),
+ ("t'{}'", "t-string: valid expression required before '}'"),
+ ("t'{=x}'", "t-string: valid expression required before '='"),
+ ("t'{!x}'", "t-string: valid expression required before '!'"),
+ ("t'{:x}'", "t-string: valid expression required before ':'"),
+ ("t'{x;y}'", "t-string: expecting '=', or '!', or ':', or '}'"),
+ ("t'{x=y}'", "t-string: expecting '!', or ':', or '}'"),
+ ("t'{x!s!}'", "t-string: expecting ':' or '}'"),
+ ("t'{x!s:'", "t-string: expecting '}', or format specs"),
+ ("t'{x!}'", "t-string: missing conversion character"),
+ ("t'{x=!}'", "t-string: missing conversion character"),
+ ("t'{x!z}'", "t-string: invalid conversion character 'z': "
+ "expected 's', 'r', or 'a'"),
+ ("t'{lambda:1}'", "t-string: lambda expressions are not allowed "
+ "without parentheses"),
+ ("t'{x:{;}}'", "t-string: expecting a valid expression after '{'"),
+ ("t'{1:d\n}'", "t-string: newlines are not allowed in format specifiers")
+ ):
+ with self.subTest(case), self.assertRaisesRegex(SyntaxError, err):
+ eval(case)
+
+ def test_runtime_errors(self):
+ # Test missing variables
+ with self.assertRaises(NameError):
+ eval("t'Hello, {name}'")
+
+ def test_literal_concatenation(self):
+ # Test concatenation of t-string literals
+ t = t"Hello, " t"world"
+ self.assertTStringEqual(t, ("Hello, world",), ())
+ self.assertEqual(fstring(t), "Hello, world")
+
+ # Test concatenation with interpolation
+ name = "Python"
+ t = t"Hello, " t"{name}"
+ self.assertTStringEqual(t, ("Hello, ", ""), [(name, "name")])
+ self.assertEqual(fstring(t), "Hello, Python")
+
+ # Test concatenation with string literal
+ name = "Python"
+ t = t"Hello, {name}" "and welcome!"
+ self.assertTStringEqual(
+ t, ("Hello, ", "and welcome!"), [(name, "name")]
+ )
+ self.assertEqual(fstring(t), "Hello, Pythonand welcome!")
+
+ # Test concatenation with Unicode literal
+ name = "Python"
+ t = t"Hello, {name}" u"and welcome!"
+ self.assertTStringEqual(
+ t, ("Hello, ", "and welcome!"), [(name, "name")]
+ )
+ self.assertEqual(fstring(t), "Hello, Pythonand welcome!")
+
+ # Test concatenation with f-string literal
+ tab = '\t'
+ t = t"Tab: {tab}. " f"f-tab: {tab}."
+ self.assertTStringEqual(t, ("Tab: ", ". f-tab: \t."), [(tab, "tab")])
+ self.assertEqual(fstring(t), "Tab: \t. f-tab: \t.")
+
+ # Test concatenation with raw string literal
+ tab = '\t'
+ t = t"Tab: {tab}. " r"Raw tab: \t."
+ self.assertTStringEqual(
+ t, ("Tab: ", r". Raw tab: \t."), [(tab, "tab")]
+ )
+ self.assertEqual(fstring(t), "Tab: \t. Raw tab: \\t.")
+
+ # Test concatenation with raw f-string literal
+ tab = '\t'
+ t = t"Tab: {tab}. " rf"f-tab: {tab}. Raw tab: \t."
+ self.assertTStringEqual(
+ t, ("Tab: ", ". f-tab: \t. Raw tab: \\t."), [(tab, "tab")]
+ )
+ self.assertEqual(fstring(t), "Tab: \t. f-tab: \t. Raw tab: \\t.")
+
+ what = 't'
+ expected_msg = 'cannot mix bytes and nonbytes literals'
+ for case in (
+ "t'{what}-string literal' b'bytes literal'",
+ "t'{what}-string literal' br'raw bytes literal'",
+ ):
+ with self.assertRaisesRegex(SyntaxError, expected_msg):
+ eval(case)
+
+ def test_triple_quoted(self):
+ # Test triple-quoted t-strings
+ t = t"""
+ Hello,
+ world
+ """
+ self.assertTStringEqual(
+ t, ("\n Hello,\n world\n ",), ()
+ )
+ self.assertEqual(fstring(t), "\n Hello,\n world\n ")
+
+ # Test triple-quoted with interpolation
+ name = "Python"
+ t = t"""
+ Hello,
+ {name}
+ """
+ self.assertTStringEqual(
+ t, ("\n Hello,\n ", "\n "), [(name, "name")]
+ )
+ self.assertEqual(fstring(t), "\n Hello,\n Python\n ")
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_ttk/test_extensions.py b/Lib/test/test_ttk/test_extensions.py
index d5e06971697..05bca59e703 100644
--- a/Lib/test/test_ttk/test_extensions.py
+++ b/Lib/test/test_ttk/test_extensions.py
@@ -319,6 +319,12 @@ class OptionMenuTest(AbstractTkTest, unittest.TestCase):
textvar.trace_remove("write", cb_name)
optmenu.destroy()
+ def test_specify_name(self):
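+ # An explicit name= should become the widget's Tk path and its key in the parent's children registry.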
+ textvar = tkinter.StringVar(self.root)
+ widget = ttk.OptionMenu(self.root, textvar, ":)", name="option_menu_ex")
+ self.assertEqual(str(widget), ".option_menu_ex")
+ self.assertIs(self.root.children["option_menu_ex"], widget)
+
class DefaultRootTest(AbstractDefaultRootTest, unittest.TestCase):
diff --git a/Lib/test/test_ttk/test_widgets.py b/Lib/test/test_ttk/test_widgets.py
index d5620becfa7..f33da2a8848 100644
--- a/Lib/test/test_ttk/test_widgets.py
+++ b/Lib/test/test_ttk/test_widgets.py
@@ -490,7 +490,7 @@ class ComboboxTest(EntryTest, unittest.TestCase):
width = self.combo.winfo_width()
x, y = width - 5, 5
if sys.platform != 'darwin': # there's no down arrow on macOS
- self.assertRegex(self.combo.identify(x, y), r'.*downarrow\Z')
+ self.assertRegex(self.combo.identify(x, y), r'.*downarrow\z')
self.combo.event_generate('<Button-1>', x=x, y=y)
self.combo.event_generate('<ButtonRelease-1>', x=x, y=y)
@@ -1250,7 +1250,7 @@ class SpinboxTest(EntryTest, unittest.TestCase):
height = self.spin.winfo_height()
x = width - 5
y = height//2 - 5
- self.assertRegex(self.spin.identify(x, y), r'.*uparrow\Z')
+ self.assertRegex(self.spin.identify(x, y), r'.*uparrow\z')
self.spin.event_generate('<ButtonPress-1>', x=x, y=y)
self.spin.event_generate('<ButtonRelease-1>', x=x, y=y)
self.spin.update_idletasks()
@@ -1260,7 +1260,7 @@ class SpinboxTest(EntryTest, unittest.TestCase):
height = self.spin.winfo_height()
x = width - 5
y = height//2 + 4
- self.assertRegex(self.spin.identify(x, y), r'.*downarrow\Z')
+ self.assertRegex(self.spin.identify(x, y), r'.*downarrow\z')
self.spin.event_generate('<ButtonPress-1>', x=x, y=y)
self.spin.event_generate('<ButtonRelease-1>', x=x, y=y)
self.spin.update_idletasks()
diff --git a/Lib/test/test_type_annotations.py b/Lib/test/test_type_annotations.py
index b72d3dbe516..c66cb058552 100644
--- a/Lib/test/test_type_annotations.py
+++ b/Lib/test/test_type_annotations.py
@@ -327,6 +327,25 @@ class AnnotateTests(unittest.TestCase):
f.__annotations__ = {"z": 43}
self.assertIs(f.__annotate__, None)
+ def test_user_defined_annotate(self):
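+ # A user-defined __annotate__ takes precedence over the synthesized annotations.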
+ class X:
+ a: int
+
+ def __annotate__(format):
+ return {"a": str}
+ self.assertEqual(X.__annotate__(annotationlib.Format.VALUE), {"a": str})
+ self.assertEqual(annotationlib.get_annotations(X), {"a": str})
+
+ mod = build_module(
+ """
+ a: int
+ def __annotate__(format):
+ return {"a": str}
+ """
+ )
+ self.assertEqual(mod.__annotate__(annotationlib.Format.VALUE), {"a": str})
+ self.assertEqual(annotationlib.get_annotations(mod), {"a": str})
+
class DeferredEvaluationTests(unittest.TestCase):
def test_function(self):
@@ -479,6 +498,28 @@ class DeferredEvaluationTests(unittest.TestCase):
self.assertEqual(f.__annotate__(annotationlib.Format.VALUE), annos)
self.assertEqual(f.__annotations__, annos)
+ def test_set_annotations(self):
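+ # __annotations__ stays writable on functions and classes, with and without PEP 563.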
+ function_code = textwrap.dedent("""
+ def f(x: int):
+ pass
+ """)
+ class_code = textwrap.dedent("""
+ class f:
+ x: int
+ """)
+ for future in (False, True):
+ for label, code in (("function", function_code), ("class", class_code)):
+ with self.subTest(future=future, label=label):
+ if future:
+ code = "from __future__ import annotations\n" + code
+ ns = run_code(code)
+ f = ns["f"]
+ anno = "int" if future else int
+ self.assertEqual(f.__annotations__, {"x": anno})
+
+ f.__annotations__ = {"x": str}
+ self.assertEqual(f.__annotations__, {"x": str})
+
def test_name_clash_with_format(self):
# this test would fail if __annotate__'s parameter was called "format"
# during symbol table construction
diff --git a/Lib/test/test_type_comments.py b/Lib/test/test_type_comments.py
index ee8939f62d0..c40c45594f4 100644
--- a/Lib/test/test_type_comments.py
+++ b/Lib/test/test_type_comments.py
@@ -344,7 +344,7 @@ class TypeCommentTests(unittest.TestCase):
todo = set(t.name[1:])
self.assertEqual(len(t.args.args) + len(t.args.posonlyargs),
len(todo) - bool(t.args.vararg) - bool(t.args.kwarg))
- self.assertTrue(t.name.startswith('f'), t.name)
+ self.assertStartsWith(t.name, 'f')
for index, c in enumerate(t.name[1:]):
todo.remove(c)
if c == 'v':
diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py
index 3552b6b4ef8..fc26e71ffcb 100644
--- a/Lib/test/test_types.py
+++ b/Lib/test/test_types.py
@@ -2,7 +2,7 @@
from test.support import (
run_with_locale, cpython_only, no_rerun,
- MISSING_C_DOCSTRINGS, EqualToForwardRef,
+ MISSING_C_DOCSTRINGS, EqualToForwardRef, check_disallow_instantiation,
)
from test.support.script_helper import assert_python_ok
from test.support.import_helper import import_fresh_module
@@ -517,8 +517,8 @@ class TypesTests(unittest.TestCase):
# and a number after the decimal. This is tricky, because
# a totally empty format specifier means something else.
# So, just use a sign flag
- test(1e200, '+g', '+1e+200')
- test(1e200, '+', '+1e+200')
+ test(1.25e200, '+g', '+1.25e+200')
+ test(1.25e200, '+', '+1.25e+200')
test(1.1e200, '+g', '+1.1e+200')
test(1.1e200, '+', '+1.1e+200')
@@ -827,15 +827,15 @@ class UnionTests(unittest.TestCase):
self.assertIsInstance(True, x)
self.assertIsInstance('a', x)
self.assertNotIsInstance(None, x)
- self.assertTrue(issubclass(int, x))
- self.assertTrue(issubclass(bool, x))
- self.assertTrue(issubclass(str, x))
- self.assertFalse(issubclass(type(None), x))
+ self.assertIsSubclass(int, x)
+ self.assertIsSubclass(bool, x)
+ self.assertIsSubclass(str, x)
+ self.assertNotIsSubclass(type(None), x)
for x in (int | None, typing.Union[int, None]):
with self.subTest(x=x):
self.assertIsInstance(None, x)
- self.assertTrue(issubclass(type(None), x))
+ self.assertIsSubclass(type(None), x)
for x in (
int | collections.abc.Mapping,
@@ -844,8 +844,8 @@ class UnionTests(unittest.TestCase):
with self.subTest(x=x):
self.assertIsInstance({}, x)
self.assertNotIsInstance((), x)
- self.assertTrue(issubclass(dict, x))
- self.assertFalse(issubclass(list, x))
+ self.assertIsSubclass(dict, x)
+ self.assertNotIsSubclass(list, x)
def test_instancecheck_and_subclasscheck_order(self):
T = typing.TypeVar('T')
@@ -857,7 +857,7 @@ class UnionTests(unittest.TestCase):
for x in will_resolve:
with self.subTest(x=x):
self.assertIsInstance(1, x)
- self.assertTrue(issubclass(int, x))
+ self.assertIsSubclass(int, x)
wont_resolve = (
T | int,
@@ -890,7 +890,7 @@ class UnionTests(unittest.TestCase):
def __subclasscheck__(cls, sub):
1/0
x = int | BadMeta('A', (), {})
- self.assertTrue(issubclass(int, x))
+ self.assertIsSubclass(int, x)
self.assertRaises(ZeroDivisionError, issubclass, list, x)
def test_or_type_operator_with_TypeVar(self):
@@ -1148,8 +1148,7 @@ class UnionTests(unittest.TestCase):
msg='Check for union reference leak.')
def test_instantiation(self):
- with self.assertRaises(TypeError):
- types.UnionType()
+ check_disallow_instantiation(self, types.UnionType)
self.assertIs(int, types.UnionType[int])
self.assertIs(int, types.UnionType[int, int])
self.assertEqual(int | str, types.UnionType[int, str])
@@ -1399,7 +1398,7 @@ class ClassCreationTests(unittest.TestCase):
def test_new_class_subclass(self):
C = types.new_class("C", (int,))
- self.assertTrue(issubclass(C, int))
+ self.assertIsSubclass(C, int)
def test_new_class_meta(self):
Meta = self.Meta
@@ -1444,7 +1443,7 @@ class ClassCreationTests(unittest.TestCase):
bases=(int,),
kwds=dict(metaclass=Meta, z=2),
exec_body=func)
- self.assertTrue(issubclass(C, int))
+ self.assertIsSubclass(C, int)
self.assertIsInstance(C, Meta)
self.assertEqual(C.x, 0)
self.assertEqual(C.y, 1)
@@ -2513,15 +2512,16 @@ class SubinterpreterTests(unittest.TestCase):
def setUpClass(cls):
global interpreters
try:
- from test.support import interpreters
+ from concurrent import interpreters
except ModuleNotFoundError:
raise unittest.SkipTest('subinterpreters required')
- import test.support.interpreters.channels
+ from test.support import channels # noqa: F401
+ cls.create_channel = staticmethod(channels.create)
@cpython_only
@no_rerun('channels (and queues) might have a refleak; see gh-122199')
def test_static_types_inherited_slots(self):
- rch, sch = interpreters.channels.create()
+ rch, sch = self.create_channel()
script = textwrap.dedent("""
import test.support
@@ -2547,7 +2547,7 @@ class SubinterpreterTests(unittest.TestCase):
main_results = collate_results(raw)
interp = interpreters.create()
- interp.exec('from test.support import interpreters')
+ interp.exec('from concurrent import interpreters')
interp.prepare_main(sch=sch)
interp.exec(script)
raw = rch.recv_nowait()
diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py
index 8c55ba4623e..ef02e8202fc 100644
--- a/Lib/test/test_typing.py
+++ b/Lib/test/test_typing.py
@@ -46,11 +46,10 @@ import abc
import textwrap
import typing
import weakref
-import warnings
import types
from test.support import (
- captured_stderr, cpython_only, infinite_recursion, requires_docstrings, import_helper, run_code,
+ captured_stderr, cpython_only, requires_docstrings, import_helper, run_code,
EqualToForwardRef,
)
from test.typinganndata import (
@@ -6859,12 +6858,10 @@ class GetTypeHintsTests(BaseTestCase):
self.assertEqual(hints, {'value': Final})
def test_top_level_class_var(self):
- # https://bugs.python.org/issue45166
- with self.assertRaisesRegex(
- TypeError,
- r'typing.ClassVar\[int\] is not valid as type argument',
- ):
- get_type_hints(ann_module6)
+ # A module-level ClassVar is not meaningful, but we no longer raise for it.
+ # https://github.com/python/cpython/issues/133959
+ hints = get_type_hints(ann_module6)
+ self.assertEqual(hints, {'wrong': ClassVar[int]})
def test_get_type_hints_typeddict(self):
self.assertEqual(get_type_hints(TotalMovie), {'title': str, 'year': int})
@@ -6967,6 +6964,11 @@ class GetTypeHintsTests(BaseTestCase):
self.assertEqual(get_type_hints(foo, globals(), locals()),
{'a': Callable[..., T]})
+ def test_special_forms_no_forward(self):
+ def f(x: ClassVar[int]):
+ pass
+ self.assertEqual(get_type_hints(f), {'x': ClassVar[int]})
+
def test_special_forms_forward(self):
class C:
@@ -6982,8 +6984,9 @@ class GetTypeHintsTests(BaseTestCase):
self.assertEqual(get_type_hints(C, globals())['b'], Final[int])
self.assertEqual(get_type_hints(C, globals())['x'], ClassVar)
self.assertEqual(get_type_hints(C, globals())['y'], Final)
- with self.assertRaises(TypeError):
- get_type_hints(CF, globals()),
+ lfi = get_type_hints(CF, globals())['b']
+ self.assertIs(get_origin(lfi), list)
+ self.assertEqual(get_args(lfi), (Final[int],))
def test_union_forward_recursion(self):
ValueList = List['Value']
@@ -7216,33 +7219,113 @@ class GetUtilitiesTestCase(TestCase):
class EvaluateForwardRefTests(BaseTestCase):
def test_evaluate_forward_ref(self):
int_ref = ForwardRef('int')
- missing = ForwardRef('missing')
+ self.assertIs(typing.evaluate_forward_ref(int_ref), int)
self.assertIs(
typing.evaluate_forward_ref(int_ref, type_params=()),
int,
)
self.assertIs(
+ typing.evaluate_forward_ref(int_ref, format=annotationlib.Format.VALUE),
+ int,
+ )
+ self.assertIs(
typing.evaluate_forward_ref(
- int_ref, type_params=(), format=annotationlib.Format.FORWARDREF,
+ int_ref, format=annotationlib.Format.FORWARDREF,
),
int,
)
+ self.assertEqual(
+ typing.evaluate_forward_ref(
+ int_ref, format=annotationlib.Format.STRING,
+ ),
+ 'int',
+ )
+
+ def test_evaluate_forward_ref_undefined(self):
+ missing = ForwardRef('missing')
+ with self.assertRaises(NameError):
+ typing.evaluate_forward_ref(missing)
self.assertIs(
typing.evaluate_forward_ref(
- missing, type_params=(), format=annotationlib.Format.FORWARDREF,
+ missing, format=annotationlib.Format.FORWARDREF,
),
missing,
)
self.assertEqual(
typing.evaluate_forward_ref(
- int_ref, type_params=(), format=annotationlib.Format.STRING,
+ missing, format=annotationlib.Format.STRING,
),
- 'int',
+ "missing",
)
- def test_evaluate_forward_ref_no_type_params(self):
- ref = ForwardRef('int')
- self.assertIs(typing.evaluate_forward_ref(ref), int)
+ def test_evaluate_forward_ref_nested(self):
+ ref = ForwardRef("int | list['str']")
+ self.assertEqual(
+ typing.evaluate_forward_ref(ref),
+ int | list[str],
+ )
+ self.assertEqual(
+ typing.evaluate_forward_ref(ref, format=annotationlib.Format.FORWARDREF),
+ int | list[str],
+ )
+ self.assertEqual(
+ typing.evaluate_forward_ref(ref, format=annotationlib.Format.STRING),
+ "int | list['str']",
+ )
+
+ why = ForwardRef('"\'str\'"')
+ self.assertIs(typing.evaluate_forward_ref(why), str)
+
+ def test_evaluate_forward_ref_none(self):
+ none_ref = ForwardRef('None')
+ self.assertIs(typing.evaluate_forward_ref(none_ref), None)
+
+ def test_globals(self):
+ A = "str"
+ ref = ForwardRef('list[A]')
+ with self.assertRaises(NameError):
+ typing.evaluate_forward_ref(ref)
+ self.assertEqual(
+ typing.evaluate_forward_ref(ref, globals={'A': A}),
+ list[str],
+ )
+
+ def test_owner(self):
+ ref = ForwardRef("A")
+
+ with self.assertRaises(NameError):
+ typing.evaluate_forward_ref(ref)
+
+ # We default to the globals of `owner`,
+ # so it no longer raises `NameError`
+ self.assertIs(
+ typing.evaluate_forward_ref(ref, owner=Loop), A
+ )
+
+ def test_inherited_owner(self):
+ # owner passed to evaluate_forward_ref
+ ref = ForwardRef("list['A']")
+ self.assertEqual(
+ typing.evaluate_forward_ref(ref, owner=Loop),
+ list[A],
+ )
+
+ # owner set on the ForwardRef
+ ref = ForwardRef("list['A']", owner=Loop)
+ self.assertEqual(
+ typing.evaluate_forward_ref(ref),
+ list[A],
+ )
+
+ def test_partial_evaluation(self):
+ ref = ForwardRef("list[A]")
+ with self.assertRaises(NameError):
+ typing.evaluate_forward_ref(ref)
+
+ self.assertEqual(
+ typing.evaluate_forward_ref(ref, format=annotationlib.Format.FORWARDREF),
+ list[EqualToForwardRef('A')],
+ )
class CollectionsAbcTests(BaseTestCase):
@@ -8080,78 +8163,13 @@ class NamedTupleTests(BaseTestCase):
self.assertIs(type(a), Group)
self.assertEqual(a, (1, [2]))
- def test_namedtuple_keyword_usage(self):
- with self.assertWarnsRegex(
- DeprecationWarning,
- "Creating NamedTuple classes using keyword arguments is deprecated"
- ):
- LocalEmployee = NamedTuple("LocalEmployee", name=str, age=int)
-
- nick = LocalEmployee('Nick', 25)
- self.assertIsInstance(nick, tuple)
- self.assertEqual(nick.name, 'Nick')
- self.assertEqual(LocalEmployee.__name__, 'LocalEmployee')
- self.assertEqual(LocalEmployee._fields, ('name', 'age'))
- self.assertEqual(LocalEmployee.__annotations__, dict(name=str, age=int))
-
- with self.assertRaisesRegex(
- TypeError,
- "Either list of fields or keywords can be provided to NamedTuple, not both"
- ):
- NamedTuple('Name', [('x', int)], y=str)
-
- with self.assertRaisesRegex(
- TypeError,
- "Either list of fields or keywords can be provided to NamedTuple, not both"
- ):
- NamedTuple('Name', [], y=str)
-
- with self.assertRaisesRegex(
- TypeError,
- (
- r"Cannot pass `None` as the 'fields' parameter "
- r"and also specify fields using keyword arguments"
- )
- ):
- NamedTuple('Name', None, x=int)
-
- def test_namedtuple_special_keyword_names(self):
- with self.assertWarnsRegex(
- DeprecationWarning,
- "Creating NamedTuple classes using keyword arguments is deprecated"
- ):
- NT = NamedTuple("NT", cls=type, self=object, typename=str, fields=list)
-
- self.assertEqual(NT.__name__, 'NT')
- self.assertEqual(NT._fields, ('cls', 'self', 'typename', 'fields'))
- a = NT(cls=str, self=42, typename='foo', fields=[('bar', tuple)])
- self.assertEqual(a.cls, str)
- self.assertEqual(a.self, 42)
- self.assertEqual(a.typename, 'foo')
- self.assertEqual(a.fields, [('bar', tuple)])
-
def test_empty_namedtuple(self):
- expected_warning = re.escape(
- "Failing to pass a value for the 'fields' parameter is deprecated "
- "and will be disallowed in Python 3.15. "
- "To create a NamedTuple class with 0 fields "
- "using the functional syntax, "
- "pass an empty list, e.g. `NT1 = NamedTuple('NT1', [])`."
- )
- with self.assertWarnsRegex(DeprecationWarning, fr"^{expected_warning}$"):
- NT1 = NamedTuple('NT1')
+ with self.assertRaisesRegex(TypeError, "missing.*required.*argument"):
+ BAD = NamedTuple('BAD')
- expected_warning = re.escape(
- "Passing `None` as the 'fields' parameter is deprecated "
- "and will be disallowed in Python 3.15. "
- "To create a NamedTuple class with 0 fields "
- "using the functional syntax, "
- "pass an empty list, e.g. `NT2 = NamedTuple('NT2', [])`."
- )
- with self.assertWarnsRegex(DeprecationWarning, fr"^{expected_warning}$"):
- NT2 = NamedTuple('NT2', None)
-
- NT3 = NamedTuple('NT2', [])
+ NT1 = NamedTuple('NT1', {})
+ NT2 = NamedTuple('NT2', ())
+ NT3 = NamedTuple('NT3', [])
class CNT(NamedTuple):
pass # empty body
@@ -8166,16 +8184,18 @@ class NamedTupleTests(BaseTestCase):
def test_namedtuple_errors(self):
with self.assertRaises(TypeError):
NamedTuple.__new__()
+ with self.assertRaisesRegex(TypeError, "object is not iterable"):
+ NamedTuple('Name', None)
with self.assertRaisesRegex(
TypeError,
- "missing 1 required positional argument"
+ "missing 2 required positional arguments"
):
NamedTuple()
with self.assertRaisesRegex(
TypeError,
- "takes from 1 to 2 positional arguments but 3 were given"
+ "takes 2 positional arguments but 3 were given"
):
NamedTuple('Emp', [('name', str)], None)
@@ -8187,10 +8207,22 @@ class NamedTupleTests(BaseTestCase):
with self.assertRaisesRegex(
TypeError,
- "missing 1 required positional argument: 'typename'"
+ "got some positional-only arguments passed as keyword arguments"
):
NamedTuple(typename='Emp', name=str, id=int)
+ with self.assertRaisesRegex(
+ TypeError,
+ "got an unexpected keyword argument"
+ ):
+ NamedTuple('Name', [('x', int)], y=str)
+
+ with self.assertRaisesRegex(
+ TypeError,
+ "got an unexpected keyword argument"
+ ):
+ NamedTuple('Name', [], y=str)
+
def test_copy_and_pickle(self):
global Emp # pickle wants to reference the class by name
Emp = NamedTuple('Emp', [('name', str), ('cool', int)])
@@ -8538,6 +8570,36 @@ class TypedDictTests(BaseTestCase):
self.assertEqual(Child.__required_keys__, frozenset(['a']))
self.assertEqual(Child.__optional_keys__, frozenset())
+ def test_inheritance_pep563(self):
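+ # TypedDict inheritance should work for any mix of stringized and non-stringized bases.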
+ def _make_td(future, class_name, annos, base, extra_names=None):
+ lines = []
+ if future:
+ lines.append('from __future__ import annotations')
+ lines.append('from typing import TypedDict')
+ lines.append(f'class {class_name}({base}):')
+ for name, anno in annos.items():
+ lines.append(f' {name}: {anno}')
+ code = '\n'.join(lines)
+ ns = run_code(code, extra_names)
+ return ns[class_name]
+
+ for base_future in (True, False):
+ for child_future in (True, False):
+ with self.subTest(base_future=base_future, child_future=child_future):
+ base = _make_td(
+ base_future, "Base", {"base": "int"}, "TypedDict"
+ )
+ self.assertIsNotNone(base.__annotate__)
+ child = _make_td(
+ child_future, "Child", {"child": "int"}, "Base", {"Base": base}
+ )
+ base_anno = ForwardRef("int", module="builtins") if base_future else int
+ child_anno = ForwardRef("int", module="builtins") if child_future else int
+ self.assertEqual(base.__annotations__, {'base': base_anno})
+ self.assertEqual(
+ child.__annotations__, {'child': child_anno, 'base': base_anno}
+ )
+
def test_required_notrequired_keys(self):
self.assertEqual(NontotalMovie.__required_keys__,
frozenset({"title"}))
@@ -8904,39 +8966,27 @@ class TypedDictTests(BaseTestCase):
self.assertEqual(CallTypedDict.__orig_bases__, (TypedDict,))
def test_zero_fields_typeddicts(self):
- T1 = TypedDict("T1", {})
+ T1a = TypedDict("T1a", {})
+ T1b = TypedDict("T1b", [])
+ T1c = TypedDict("T1c", ())
class T2(TypedDict): pass
class T3[tvar](TypedDict): pass
S = TypeVar("S")
class T4(TypedDict, Generic[S]): pass
- expected_warning = re.escape(
- "Failing to pass a value for the 'fields' parameter is deprecated "
- "and will be disallowed in Python 3.15. "
- "To create a TypedDict class with 0 fields "
- "using the functional syntax, "
- "pass an empty dictionary, e.g. `T5 = TypedDict('T5', {})`."
- )
- with self.assertWarnsRegex(DeprecationWarning, fr"^{expected_warning}$"):
- T5 = TypedDict('T5')
-
- expected_warning = re.escape(
- "Passing `None` as the 'fields' parameter is deprecated "
- "and will be disallowed in Python 3.15. "
- "To create a TypedDict class with 0 fields "
- "using the functional syntax, "
- "pass an empty dictionary, e.g. `T6 = TypedDict('T6', {})`."
- )
- with self.assertWarnsRegex(DeprecationWarning, fr"^{expected_warning}$"):
- T6 = TypedDict('T6', None)
-
- for klass in T1, T2, T3, T4, T5, T6:
+ for klass in T1a, T1b, T1c, T2, T3, T4:
with self.subTest(klass=klass.__name__):
self.assertEqual(klass.__annotations__, {})
self.assertEqual(klass.__required_keys__, set())
self.assertEqual(klass.__optional_keys__, set())
self.assertIsInstance(klass(), dict)
+ def test_errors(self):
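+ # The functional TypedDict form now requires an explicit, iterable fields argument.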
+ with self.assertRaisesRegex(TypeError, "missing 1 required.*argument"):
+ TypedDict('TD')
+ with self.assertRaisesRegex(TypeError, "object is not iterable"):
+ TypedDict('TD', None)
+
def test_readonly_inheritance(self):
class Base1(TypedDict):
a: ReadOnly[int]
@@ -10731,6 +10781,9 @@ class UnionGenericAliasTests(BaseTestCase):
with self.assertWarns(DeprecationWarning):
self.assertNotEqual(int, typing._UnionGenericAlias)
+ def test_hashable(self):
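+ # The deprecated _UnionGenericAlias must remain hashable and hash like Union.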
+ self.assertEqual(hash(typing._UnionGenericAlias), hash(Union))
+
def load_tests(loader, tests, pattern):
import doctest
diff --git a/Lib/test/test_unittest/test_case.py b/Lib/test/test_unittest/test_case.py
index a04af55f3fc..d66cab146af 100644
--- a/Lib/test/test_unittest/test_case.py
+++ b/Lib/test/test_unittest/test_case.py
@@ -1989,7 +1989,7 @@ test case
pass
self.assertIsNone(value)
- def testAssertStartswith(self):
+ def testAssertStartsWith(self):
self.assertStartsWith('ababahalamaha', 'ababa')
self.assertStartsWith('ababahalamaha', ('x', 'ababa', 'y'))
self.assertStartsWith(UserString('ababahalamaha'), 'ababa')
@@ -2034,7 +2034,7 @@ test case
self.assertStartsWith('ababahalamaha', 'amaha', msg='abracadabra')
self.assertIn('ababahalamaha', str(cm.exception))
- def testAssertNotStartswith(self):
+ def testAssertNotStartsWith(self):
self.assertNotStartsWith('ababahalamaha', 'amaha')
self.assertNotStartsWith('ababahalamaha', ('x', 'amaha', 'y'))
self.assertNotStartsWith(UserString('ababahalamaha'), 'amaha')
@@ -2079,7 +2079,7 @@ test case
self.assertNotStartsWith('ababahalamaha', 'ababa', msg='abracadabra')
self.assertIn('ababahalamaha', str(cm.exception))
- def testAssertEndswith(self):
+ def testAssertEndsWith(self):
self.assertEndsWith('ababahalamaha', 'amaha')
self.assertEndsWith('ababahalamaha', ('x', 'amaha', 'y'))
self.assertEndsWith(UserString('ababahalamaha'), 'amaha')
@@ -2124,7 +2124,7 @@ test case
self.assertEndsWith('ababahalamaha', 'ababa', msg='abracadabra')
self.assertIn('ababahalamaha', str(cm.exception))
- def testAssertNotEndswith(self):
+ def testAssertNotEndsWith(self):
self.assertNotEndsWith('ababahalamaha', 'ababa')
self.assertNotEndsWith('ababahalamaha', ('x', 'ababa', 'y'))
self.assertNotEndsWith(UserString('ababahalamaha'), 'ababa')
diff --git a/Lib/test/test_unittest/test_result.py b/Lib/test/test_unittest/test_result.py
index 9ac4c52449c..3f44e617303 100644
--- a/Lib/test/test_unittest/test_result.py
+++ b/Lib/test/test_unittest/test_result.py
@@ -1282,14 +1282,22 @@ class TestOutputBuffering(unittest.TestCase):
suite(result)
expected_out = '\nStdout:\ndo cleanup2\ndo cleanup1\n'
self.assertEqual(stdout.getvalue(), expected_out)
- self.assertEqual(len(result.errors), 1)
+ self.assertEqual(len(result.errors), 2)
description = 'tearDownModule (Module)'
test_case, formatted_exc = result.errors[0]
self.assertEqual(test_case.description, description)
self.assertIn('ValueError: bad cleanup2', formatted_exc)
+ self.assertNotIn('ExceptionGroup', formatted_exc)
self.assertNotIn('TypeError', formatted_exc)
self.assertIn(expected_out, formatted_exc)
+ test_case, formatted_exc = result.errors[1]
+ self.assertEqual(test_case.description, description)
+ self.assertIn('TypeError: bad cleanup1', formatted_exc)
+ self.assertNotIn('ExceptionGroup', formatted_exc)
+ self.assertNotIn('ValueError', formatted_exc)
+ self.assertIn(expected_out, formatted_exc)
+
def testBufferSetUpModule_DoModuleCleanups(self):
with captured_stdout() as stdout:
result = unittest.TestResult()
@@ -1313,22 +1321,34 @@ class TestOutputBuffering(unittest.TestCase):
suite(result)
expected_out = '\nStdout:\nset up module\ndo cleanup2\ndo cleanup1\n'
self.assertEqual(stdout.getvalue(), expected_out)
- self.assertEqual(len(result.errors), 2)
+ self.assertEqual(len(result.errors), 3)
description = 'setUpModule (Module)'
test_case, formatted_exc = result.errors[0]
self.assertEqual(test_case.description, description)
self.assertIn('ZeroDivisionError: division by zero', formatted_exc)
+ self.assertNotIn('ExceptionGroup', formatted_exc)
self.assertNotIn('ValueError', formatted_exc)
self.assertNotIn('TypeError', formatted_exc)
self.assertIn('\nStdout:\nset up module\n', formatted_exc)
+
test_case, formatted_exc = result.errors[1]
self.assertIn(expected_out, formatted_exc)
self.assertEqual(test_case.description, description)
self.assertIn('ValueError: bad cleanup2', formatted_exc)
+ self.assertNotIn('ExceptionGroup', formatted_exc)
self.assertNotIn('ZeroDivisionError', formatted_exc)
self.assertNotIn('TypeError', formatted_exc)
self.assertIn(expected_out, formatted_exc)
+ test_case, formatted_exc = result.errors[2]
+ self.assertIn(expected_out, formatted_exc)
+ self.assertEqual(test_case.description, description)
+ self.assertIn('TypeError: bad cleanup1', formatted_exc)
+ self.assertNotIn('ExceptionGroup', formatted_exc)
+ self.assertNotIn('ZeroDivisionError', formatted_exc)
+ self.assertNotIn('ValueError', formatted_exc)
+ self.assertIn(expected_out, formatted_exc)
+
def testBufferTearDownModule_DoModuleCleanups(self):
with captured_stdout() as stdout:
result = unittest.TestResult()
@@ -1355,21 +1375,32 @@ class TestOutputBuffering(unittest.TestCase):
suite(result)
expected_out = '\nStdout:\ntear down module\ndo cleanup2\ndo cleanup1\n'
self.assertEqual(stdout.getvalue(), expected_out)
- self.assertEqual(len(result.errors), 2)
+ self.assertEqual(len(result.errors), 3)
description = 'tearDownModule (Module)'
test_case, formatted_exc = result.errors[0]
self.assertEqual(test_case.description, description)
self.assertIn('ZeroDivisionError: division by zero', formatted_exc)
+ self.assertNotIn('ExceptionGroup', formatted_exc)
self.assertNotIn('ValueError', formatted_exc)
self.assertNotIn('TypeError', formatted_exc)
self.assertIn('\nStdout:\ntear down module\n', formatted_exc)
+
test_case, formatted_exc = result.errors[1]
self.assertEqual(test_case.description, description)
self.assertIn('ValueError: bad cleanup2', formatted_exc)
+ self.assertNotIn('ExceptionGroup', formatted_exc)
self.assertNotIn('ZeroDivisionError', formatted_exc)
self.assertNotIn('TypeError', formatted_exc)
self.assertIn(expected_out, formatted_exc)
+ test_case, formatted_exc = result.errors[2]
+ self.assertEqual(test_case.description, description)
+ self.assertIn('TypeError: bad cleanup1', formatted_exc)
+ self.assertNotIn('ExceptionGroup', formatted_exc)
+ self.assertNotIn('ZeroDivisionError', formatted_exc)
+ self.assertNotIn('ValueError', formatted_exc)
+ self.assertIn(expected_out, formatted_exc)
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_unittest/test_runner.py b/Lib/test/test_unittest/test_runner.py
index 4d3cfd60b8d..a47e2ebb59d 100644
--- a/Lib/test/test_unittest/test_runner.py
+++ b/Lib/test/test_unittest/test_runner.py
@@ -13,6 +13,7 @@ from test.test_unittest.support import (
LoggingResult,
ResultWithNoStartTestRunStopTestRun,
)
+from test.support.testcase import ExceptionIsLikeMixin
def resultFactory(*_):
@@ -604,7 +605,7 @@ class TestClassCleanup(unittest.TestCase):
@support.force_not_colorized_test_class
-class TestModuleCleanUp(unittest.TestCase):
+class TestModuleCleanUp(ExceptionIsLikeMixin, unittest.TestCase):
def test_add_and_do_ModuleCleanup(self):
module_cleanups = []
@@ -646,11 +647,50 @@ class TestModuleCleanUp(unittest.TestCase):
[(module_cleanup_good, (1, 2, 3),
dict(four='hello', five='goodbye')),
(module_cleanup_bad, (), {})])
- with self.assertRaises(CustomError) as e:
+ with self.assertRaises(Exception) as e:
unittest.case.doModuleCleanups()
- self.assertEqual(str(e.exception), 'CleanUpExc')
+ self.assertExceptionIsLike(e.exception,
+ ExceptionGroup('module cleanup failed',
+ [CustomError('CleanUpExc')]))
self.assertEqual(unittest.case._module_cleanups, [])
+ def test_doModuleCleanup_with_multiple_errors_in_addModuleCleanup(self):
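+ # Multiple failing module cleanups are collected into an ExceptionGroup, most recently added first.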
+ def module_cleanup_bad1():
+ raise TypeError('CleanUpExc1')
+
+ def module_cleanup_bad2():
+ raise ValueError('CleanUpExc2')
+
+ class Module:
+ unittest.addModuleCleanup(module_cleanup_bad1)
+ unittest.addModuleCleanup(module_cleanup_bad2)
+ with self.assertRaises(ExceptionGroup) as e:
+ unittest.case.doModuleCleanups()
+ self.assertExceptionIsLike(e.exception,
+ ExceptionGroup('module cleanup failed', [
+ ValueError('CleanUpExc2'),
+ TypeError('CleanUpExc1'),
+ ]))
+
+ def test_doModuleCleanup_with_exception_group_in_addModuleCleanup(self):
+ def module_cleanup_bad():
+ raise ExceptionGroup('CleanUpExc', [
+ ValueError('CleanUpExc2'),
+ TypeError('CleanUpExc1'),
+ ])
+
+ class Module:
+ unittest.addModuleCleanup(module_cleanup_bad)
+ with self.assertRaises(ExceptionGroup) as e:
+ unittest.case.doModuleCleanups()
+ self.assertExceptionIsLike(e.exception,
+ ExceptionGroup('module cleanup failed', [
+ ExceptionGroup('CleanUpExc', [
+ ValueError('CleanUpExc2'),
+ TypeError('CleanUpExc1'),
+ ]),
+ ]))
+
def test_addModuleCleanup_arg_errors(self):
cleanups = []
def cleanup(*args, **kwargs):
@@ -871,9 +911,11 @@ class TestModuleCleanUp(unittest.TestCase):
ordering = []
blowUp = True
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestableTest)
- with self.assertRaises(CustomError) as cm:
+ with self.assertRaises(Exception) as cm:
suite.debug()
- self.assertEqual(str(cm.exception), 'CleanUpExc')
+ self.assertExceptionIsLike(cm.exception,
+ ExceptionGroup('module cleanup failed',
+ [CustomError('CleanUpExc')]))
self.assertEqual(ordering, ['setUpModule', 'setUpClass', 'test',
'tearDownClass', 'tearDownModule', 'cleanup_exc'])
self.assertEqual(unittest.case._module_cleanups, [])
diff --git a/Lib/test/test_unittest/testmock/testhelpers.py b/Lib/test/test_unittest/testmock/testhelpers.py
index d1e48bde982..0e82c723ec3 100644
--- a/Lib/test/test_unittest/testmock/testhelpers.py
+++ b/Lib/test/test_unittest/testmock/testhelpers.py
@@ -1050,6 +1050,7 @@ class SpecSignatureTest(unittest.TestCase):
create_autospec(WithPostInit()),
]:
with self.subTest(mock=mock):
+ self.assertIsInstance(mock, WithPostInit)
self.assertIsInstance(mock.a, int)
self.assertIsInstance(mock.b, int)
@@ -1072,6 +1073,7 @@ class SpecSignatureTest(unittest.TestCase):
create_autospec(WithDefault(1)),
]:
with self.subTest(mock=mock):
+ self.assertIsInstance(mock, WithDefault)
self.assertIsInstance(mock.a, int)
self.assertIsInstance(mock.b, int)
@@ -1087,6 +1089,7 @@ class SpecSignatureTest(unittest.TestCase):
create_autospec(WithMethod(1)),
]:
with self.subTest(mock=mock):
+ self.assertIsInstance(mock, WithMethod)
self.assertIsInstance(mock.a, int)
mock.b.assert_not_called()
@@ -1102,11 +1105,29 @@ class SpecSignatureTest(unittest.TestCase):
create_autospec(WithNonFields(1)),
]:
with self.subTest(mock=mock):
+ self.assertIsInstance(mock, WithNonFields)
with self.assertRaisesRegex(AttributeError, msg):
mock.a
with self.assertRaisesRegex(AttributeError, msg):
mock.b
+ def test_dataclass_special_attrs(self):
+ @dataclass
+ class Description:
+ name: str
+
+ for mock in [
+ create_autospec(Description, instance=True),
+ create_autospec(Description(1)),
+ ]:
+ with self.subTest(mock=mock):
+ self.assertIsInstance(mock, Description)
+ self.assertIs(mock.__class__, Description)
+ self.assertIsInstance(mock.__dataclass_fields__, MagicMock)
+ self.assertIsInstance(mock.__dataclass_params__, MagicMock)
+ self.assertIsInstance(mock.__match_args__, MagicMock)
+ self.assertIsInstance(mock.__hash__, MagicMock)
+
class TestCallList(unittest.TestCase):
def test_args_list_contains_call_list(self):
diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py
index 839326f6436..d4db5e60af7 100644
--- a/Lib/test/test_unparse.py
+++ b/Lib/test/test_unparse.py
@@ -202,6 +202,15 @@ class UnparseTestCase(ASTTestCase):
self.check_ast_roundtrip('f" something { my_dict["key"] } something else "')
self.check_ast_roundtrip('f"{f"{f"{f"{f"{f"{1+1}"}"}"}"}"}"')
+ def test_tstrings(self):
+ self.check_ast_roundtrip("t'foo'")
+ self.check_ast_roundtrip("t'foo {bar}'")
+ self.check_ast_roundtrip("t'foo {bar!s:.2f}'")
+ self.check_ast_roundtrip("t'foo {bar}' f'{bar}'")
+ self.check_ast_roundtrip("f'{bar}' t'foo {bar}'")
+ self.check_ast_roundtrip("t'foo {bar}' fr'\\hello {bar}'")
+ self.check_ast_roundtrip("t'foo {bar}' u'bar'")
+
def test_strings(self):
self.check_ast_roundtrip("u'foo'")
self.check_ast_roundtrip("r'foo'")
@@ -808,6 +817,15 @@ class CosmeticTestCase(ASTTestCase):
self.check_ast_roundtrip("def f[T: int = int, **P = int, *Ts = *int]():\n pass")
self.check_ast_roundtrip("class C[T: int = int, **P = int, *Ts = *int]():\n pass")
+ def test_tstr(self):
+ self.check_ast_roundtrip("t'{a + b}'")
+ self.check_ast_roundtrip("t'{a + b:x}'")
+ self.check_ast_roundtrip("t'{a + b!s}'")
+ self.check_ast_roundtrip("t'{ {a}}'")
+ self.check_ast_roundtrip("t'{ {a}=}'")
+ self.check_ast_roundtrip("t'{{a}}'")
+ self.check_ast_roundtrip("t''")
+
class ManualASTCreationTestCase(unittest.TestCase):
"""Test that AST nodes created without a type_params field unparse correctly."""
@@ -918,7 +936,7 @@ class DirectoryTestCase(ASTTestCase):
run_always_files = {"test_grammar.py", "test_syntax.py", "test_compile.py",
"test_ast.py", "test_asdl_parser.py", "test_fstring.py",
"test_patma.py", "test_type_alias.py", "test_type_params.py",
- "test_tokenize.py"}
+ "test_tokenize.py", "test_tstring.py"}
_files_to_test = None
diff --git a/Lib/test/test_urllib.py b/Lib/test/test_urllib.py
index 90de828cc71..1d889ae7cf4 100644
--- a/Lib/test/test_urllib.py
+++ b/Lib/test/test_urllib.py
@@ -109,7 +109,7 @@ class urlopen_FileTests(unittest.TestCase):
finally:
f.close()
self.pathname = os_helper.TESTFN
- self.quoted_pathname = urllib.parse.quote(self.pathname)
+ self.quoted_pathname = urllib.parse.quote(os.fsencode(self.pathname))
self.returned_obj = urllib.request.urlopen("file:%s" % self.quoted_pathname)
def tearDown(self):
@@ -1551,7 +1551,8 @@ class Pathname_Tests(unittest.TestCase):
urllib.request.url2pathname(url, require_scheme=True),
expected_path)
- error_subtests = [
+ def test_url2pathname_require_scheme_errors(self):
+ subtests = [
'',
':',
'foo',
@@ -1561,13 +1562,21 @@ class Pathname_Tests(unittest.TestCase):
'data:file:foo',
'data:file://foo',
]
- for url in error_subtests:
+ for url in subtests:
with self.subTest(url=url):
self.assertRaises(
urllib.error.URLError,
urllib.request.url2pathname,
url, require_scheme=True)
+ @unittest.skipIf(support.is_emscripten, "Fixed by https://github.com/emscripten-core/emscripten/pull/24593")
+ def test_url2pathname_resolve_host(self):
+ fn = urllib.request.url2pathname
+ sep = os.path.sep
+ self.assertEqual(fn('//127.0.0.1/foo/bar', resolve_host=True), f'{sep}foo{sep}bar')
+ self.assertEqual(fn(f'//{socket.gethostname()}/foo/bar'), f'{sep}foo{sep}bar')
+ self.assertEqual(fn(f'//{socket.gethostname()}/foo/bar', resolve_host=True), f'{sep}foo{sep}bar')
+
@unittest.skipUnless(sys.platform == 'win32',
'test specific to Windows pathnames.')
def test_url2pathname_win(self):
@@ -1598,6 +1607,7 @@ class Pathname_Tests(unittest.TestCase):
self.assertEqual(fn('//server/path/to/file'), '\\\\server\\path\\to\\file')
self.assertEqual(fn('////server/path/to/file'), '\\\\server\\path\\to\\file')
self.assertEqual(fn('/////server/path/to/file'), '\\\\server\\path\\to\\file')
+ self.assertEqual(fn('//127.0.0.1/path/to/file'), '\\\\127.0.0.1\\path\\to\\file')
# Localhost paths
self.assertEqual(fn('//localhost/C:/path/to/file'), 'C:\\path\\to\\file')
self.assertEqual(fn('//localhost/C|/path/to/file'), 'C:\\path\\to\\file')
@@ -1622,8 +1632,7 @@ class Pathname_Tests(unittest.TestCase):
self.assertRaises(urllib.error.URLError, fn, '//:80/foo/bar')
self.assertRaises(urllib.error.URLError, fn, '//:/foo/bar')
self.assertRaises(urllib.error.URLError, fn, '//c:80/foo/bar')
- self.assertEqual(fn('//127.0.0.1/foo/bar'), '/foo/bar')
- self.assertEqual(fn(f'//{socket.gethostname()}/foo/bar'), '/foo/bar')
+ self.assertRaises(urllib.error.URLError, fn, '//127.0.0.1/foo/bar')
@unittest.skipUnless(os_helper.FS_NONASCII, 'need os_helper.FS_NONASCII')
def test_url2pathname_nonascii(self):
diff --git a/Lib/test/test_urlparse.py b/Lib/test/test_urlparse.py
index aabc360289a..b2bde5a9b1d 100644
--- a/Lib/test/test_urlparse.py
+++ b/Lib/test/test_urlparse.py
@@ -2,6 +2,7 @@ import sys
import unicodedata
import unittest
import urllib.parse
+from test import support
RFC1808_BASE = "http://a/b/c/d;p?q#f"
RFC2396_BASE = "http://a/b/c/d;p?q"
@@ -156,27 +157,25 @@ class UrlParseTestCase(unittest.TestCase):
self.assertEqual(result3.hostname, result.hostname)
self.assertEqual(result3.port, result.port)
- def test_qsl(self):
- for orig, expect in parse_qsl_test_cases:
- result = urllib.parse.parse_qsl(orig, keep_blank_values=True)
- self.assertEqual(result, expect, "Error parsing %r" % orig)
- expect_without_blanks = [v for v in expect if len(v[1])]
- result = urllib.parse.parse_qsl(orig, keep_blank_values=False)
- self.assertEqual(result, expect_without_blanks,
- "Error parsing %r" % orig)
-
- def test_qs(self):
- for orig, expect in parse_qs_test_cases:
- result = urllib.parse.parse_qs(orig, keep_blank_values=True)
- self.assertEqual(result, expect, "Error parsing %r" % orig)
- expect_without_blanks = {v: expect[v]
- for v in expect if len(expect[v][0])}
- result = urllib.parse.parse_qs(orig, keep_blank_values=False)
- self.assertEqual(result, expect_without_blanks,
- "Error parsing %r" % orig)
-
- def test_roundtrips(self):
- str_cases = [
+ @support.subTests('orig,expect', parse_qsl_test_cases)
+ def test_qsl(self, orig, expect):
+ result = urllib.parse.parse_qsl(orig, keep_blank_values=True)
+ self.assertEqual(result, expect)
+ expect_without_blanks = [v for v in expect if len(v[1])]
+ result = urllib.parse.parse_qsl(orig, keep_blank_values=False)
+ self.assertEqual(result, expect_without_blanks)
+
+ @support.subTests('orig,expect', parse_qs_test_cases)
+ def test_qs(self, orig, expect):
+ result = urllib.parse.parse_qs(orig, keep_blank_values=True)
+ self.assertEqual(result, expect)
+ expect_without_blanks = {v: expect[v]
+ for v in expect if len(expect[v][0])}
+ result = urllib.parse.parse_qs(orig, keep_blank_values=False)
+ self.assertEqual(result, expect_without_blanks)
+
+ @support.subTests('bytes', (False, True))
+ @support.subTests('url,parsed,split', [
('path/to/file',
('', '', 'path/to/file', '', '', ''),
('', '', 'path/to/file', '', '')),
@@ -263,23 +262,21 @@ class UrlParseTestCase(unittest.TestCase):
('sch_me:path/to/file',
('', '', 'sch_me:path/to/file', '', '', ''),
('', '', 'sch_me:path/to/file', '', '')),
- ]
- def _encode(t):
- return (t[0].encode('ascii'),
- tuple(x.encode('ascii') for x in t[1]),
- tuple(x.encode('ascii') for x in t[2]))
- bytes_cases = [_encode(x) for x in str_cases]
- str_cases += [
('schème:path/to/file',
('', '', 'schème:path/to/file', '', '', ''),
('', '', 'schème:path/to/file', '', '')),
- ]
- for url, parsed, split in str_cases + bytes_cases:
- with self.subTest(url):
- self.checkRoundtrips(url, parsed, split)
-
- def test_roundtrips_normalization(self):
- str_cases = [
+ ])
+ def test_roundtrips(self, bytes, url, parsed, split):
+ if bytes:
+ if not url.isascii():
+ self.skipTest('non-ASCII bytes')
+ url = str_encode(url)
+ parsed = tuple_encode(parsed)
+ split = tuple_encode(split)
+ self.checkRoundtrips(url, parsed, split)
+
+ @support.subTests('bytes', (False, True))
+ @support.subTests('url,url2,parsed,split', [
('///path/to/file',
'/path/to/file',
('', '', '/path/to/file', '', '', ''),
@@ -300,22 +297,18 @@ class UrlParseTestCase(unittest.TestCase):
'https:///tmp/junk.txt',
('https', '', '/tmp/junk.txt', '', '', ''),
('https', '', '/tmp/junk.txt', '', '')),
- ]
- def _encode(t):
- return (t[0].encode('ascii'),
- t[1].encode('ascii'),
- tuple(x.encode('ascii') for x in t[2]),
- tuple(x.encode('ascii') for x in t[3]))
- bytes_cases = [_encode(x) for x in str_cases]
- for url, url2, parsed, split in str_cases + bytes_cases:
- with self.subTest(url):
- self.checkRoundtrips(url, parsed, split, url2)
-
- def test_http_roundtrips(self):
- # urllib.parse.urlsplit treats 'http:' as an optimized special case,
- # so we test both 'http:' and 'https:' in all the following.
- # Three cheers for white box knowledge!
- str_cases = [
+ ])
+ def test_roundtrips_normalization(self, bytes, url, url2, parsed, split):
+ if bytes:
+ url = str_encode(url)
+ url2 = str_encode(url2)
+ parsed = tuple_encode(parsed)
+ split = tuple_encode(split)
+ self.checkRoundtrips(url, parsed, split, url2)
+
+ @support.subTests('bytes', (False, True))
+ @support.subTests('scheme', ('http', 'https'))
+ @support.subTests('url,parsed,split', [
('://www.python.org',
('www.python.org', '', '', '', ''),
('www.python.org', '', '', '')),
@@ -331,23 +324,20 @@ class UrlParseTestCase(unittest.TestCase):
('://a/b/c/d;p?q#f',
('a', '/b/c/d', 'p', 'q', 'f'),
('a', '/b/c/d;p', 'q', 'f')),
- ]
- def _encode(t):
- return (t[0].encode('ascii'),
- tuple(x.encode('ascii') for x in t[1]),
- tuple(x.encode('ascii') for x in t[2]))
- bytes_cases = [_encode(x) for x in str_cases]
- str_schemes = ('http', 'https')
- bytes_schemes = (b'http', b'https')
- str_tests = str_schemes, str_cases
- bytes_tests = bytes_schemes, bytes_cases
- for schemes, test_cases in (str_tests, bytes_tests):
- for scheme in schemes:
- for url, parsed, split in test_cases:
- url = scheme + url
- parsed = (scheme,) + parsed
- split = (scheme,) + split
- self.checkRoundtrips(url, parsed, split)
+ ])
+ def test_http_roundtrips(self, bytes, scheme, url, parsed, split):
+ # urllib.parse.urlsplit treats 'http:' as an optimized special case,
+ # so we test both 'http:' and 'https:' in all the following.
+ # Three cheers for white box knowledge!
+ if bytes:
+ scheme = str_encode(scheme)
+ url = str_encode(url)
+ parsed = tuple_encode(parsed)
+ split = tuple_encode(split)
+ url = scheme + url
+ parsed = (scheme,) + parsed
+ split = (scheme,) + split
+ self.checkRoundtrips(url, parsed, split)
def checkJoin(self, base, relurl, expected, *, relroundtrip=True):
with self.subTest(base=base, relurl=relurl):
@@ -363,12 +353,13 @@ class UrlParseTestCase(unittest.TestCase):
relurlb = urllib.parse.urlunsplit(urllib.parse.urlsplit(relurlb))
self.assertEqual(urllib.parse.urljoin(baseb, relurlb), expectedb)
- def test_unparse_parse(self):
- str_cases = ['Python', './Python','x-newscheme://foo.com/stuff','x://y','x:/y','x:/','/',]
- bytes_cases = [x.encode('ascii') for x in str_cases]
- for u in str_cases + bytes_cases:
- self.assertEqual(urllib.parse.urlunsplit(urllib.parse.urlsplit(u)), u)
- self.assertEqual(urllib.parse.urlunparse(urllib.parse.urlparse(u)), u)
+ @support.subTests('bytes', (False, True))
+ @support.subTests('u', ['Python', './Python','x-newscheme://foo.com/stuff','x://y','x:/y','x:/','/',])
+ def test_unparse_parse(self, bytes, u):
+ if bytes:
+ u = str_encode(u)
+ self.assertEqual(urllib.parse.urlunsplit(urllib.parse.urlsplit(u)), u)
+ self.assertEqual(urllib.parse.urlunparse(urllib.parse.urlparse(u)), u)
def test_RFC1808(self):
# "normal" cases from RFC 1808:
@@ -695,8 +686,8 @@ class UrlParseTestCase(unittest.TestCase):
self.checkJoin('///b/c', '///w', '///w')
self.checkJoin('///b/c', 'w', '///b/w')
- def test_RFC2732(self):
- str_cases = [
+ @support.subTests('bytes', (False, True))
+ @support.subTests('url,hostname,port', [
('http://Test.python.org:5432/foo/', 'test.python.org', 5432),
('http://12.34.56.78:5432/foo/', '12.34.56.78', 5432),
('http://[::1]:5432/foo/', '::1', 5432),
@@ -727,26 +718,28 @@ class UrlParseTestCase(unittest.TestCase):
('http://[::12.34.56.78]:/foo/', '::12.34.56.78', None),
('http://[::ffff:12.34.56.78]:/foo/',
'::ffff:12.34.56.78', None),
- ]
- def _encode(t):
- return t[0].encode('ascii'), t[1].encode('ascii'), t[2]
- bytes_cases = [_encode(x) for x in str_cases]
- for url, hostname, port in str_cases + bytes_cases:
- urlparsed = urllib.parse.urlparse(url)
- self.assertEqual((urlparsed.hostname, urlparsed.port) , (hostname, port))
-
- str_cases = [
+ ])
+ def test_RFC2732(self, bytes, url, hostname, port):
+ if bytes:
+ url = str_encode(url)
+ hostname = str_encode(hostname)
+ urlparsed = urllib.parse.urlparse(url)
+ self.assertEqual((urlparsed.hostname, urlparsed.port), (hostname, port))
+
+ @support.subTests('bytes', (False, True))
+ @support.subTests('invalid_url', [
'http://::12.34.56.78]/',
'http://[::1/foo/',
'ftp://[::1/foo/bad]/bad',
'http://[::1/foo/bad]/bad',
- 'http://[::ffff:12.34.56.78']
- bytes_cases = [x.encode('ascii') for x in str_cases]
- for invalid_url in str_cases + bytes_cases:
- self.assertRaises(ValueError, urllib.parse.urlparse, invalid_url)
-
- def test_urldefrag(self):
- str_cases = [
+ 'http://[::ffff:12.34.56.78'])
+ def test_RFC2732_invalid(self, bytes, invalid_url):
+ if bytes:
+ invalid_url = str_encode(invalid_url)
+ self.assertRaises(ValueError, urllib.parse.urlparse, invalid_url)
+
+ @support.subTests('bytes', (False, True))
+ @support.subTests('url,defrag,frag', [
('http://python.org#frag', 'http://python.org', 'frag'),
('http://python.org', 'http://python.org', ''),
('http://python.org/#frag', 'http://python.org/', 'frag'),
@@ -770,18 +763,18 @@ class UrlParseTestCase(unittest.TestCase):
('http:?q#f', 'http:?q', 'f'),
('//a/b/c;p?q#f', '//a/b/c;p?q', 'f'),
('://a/b/c;p?q#f', '://a/b/c;p?q', 'f'),
- ]
- def _encode(t):
- return type(t)(x.encode('ascii') for x in t)
- bytes_cases = [_encode(x) for x in str_cases]
- for url, defrag, frag in str_cases + bytes_cases:
- with self.subTest(url):
- result = urllib.parse.urldefrag(url)
- hash = '#' if isinstance(url, str) else b'#'
- self.assertEqual(result.geturl(), url.rstrip(hash))
- self.assertEqual(result, (defrag, frag))
- self.assertEqual(result.url, defrag)
- self.assertEqual(result.fragment, frag)
+ ])
+ def test_urldefrag(self, bytes, url, defrag, frag):
+ if bytes:
+ url = str_encode(url)
+ defrag = str_encode(defrag)
+ frag = str_encode(frag)
+ result = urllib.parse.urldefrag(url)
+ hash = '#' if isinstance(url, str) else b'#'
+ self.assertEqual(result.geturl(), url.rstrip(hash))
+ self.assertEqual(result, (defrag, frag))
+ self.assertEqual(result.url, defrag)
+ self.assertEqual(result.fragment, frag)
def test_urlsplit_scoped_IPv6(self):
p = urllib.parse.urlsplit('http://[FE80::822a:a8ff:fe49:470c%tESt]:1234')
@@ -981,42 +974,35 @@ class UrlParseTestCase(unittest.TestCase):
self.assertEqual(p.scheme, "https")
self.assertEqual(p.geturl(), "https://www.python.org/")
- def test_attributes_bad_port(self):
+ @support.subTests('bytes', (False, True))
+ @support.subTests('parse', (urllib.parse.urlsplit, urllib.parse.urlparse))
+ @support.subTests('port', ("foo", "1.5", "-1", "0x10", "-0", "1_1", " 1", "1 ", "६"))
+ def test_attributes_bad_port(self, bytes, parse, port):
"""Check handling of invalid ports."""
- for bytes in (False, True):
- for parse in (urllib.parse.urlsplit, urllib.parse.urlparse):
- for port in ("foo", "1.5", "-1", "0x10", "-0", "1_1", " 1", "1 ", "६"):
- with self.subTest(bytes=bytes, parse=parse, port=port):
- netloc = "www.example.net:" + port
- url = "http://" + netloc + "/"
- if bytes:
- if netloc.isascii() and port.isascii():
- netloc = netloc.encode("ascii")
- url = url.encode("ascii")
- else:
- continue
- p = parse(url)
- self.assertEqual(p.netloc, netloc)
- with self.assertRaises(ValueError):
- p.port
+ netloc = "www.example.net:" + port
+ url = "http://" + netloc + "/"
+ if bytes:
+ if not (netloc.isascii() and port.isascii()):
+ self.skipTest('non-ASCII bytes')
+ netloc = str_encode(netloc)
+ url = str_encode(url)
+ p = parse(url)
+ self.assertEqual(p.netloc, netloc)
+ with self.assertRaises(ValueError):
+ p.port
- def test_attributes_bad_scheme(self):
+ @support.subTests('bytes', (False, True))
+ @support.subTests('parse', (urllib.parse.urlsplit, urllib.parse.urlparse))
+ @support.subTests('scheme', (".", "+", "-", "0", "http&", "६http"))
+ def test_attributes_bad_scheme(self, bytes, parse, scheme):
"""Check handling of invalid schemes."""
- for bytes in (False, True):
- for parse in (urllib.parse.urlsplit, urllib.parse.urlparse):
- for scheme in (".", "+", "-", "0", "http&", "६http"):
- with self.subTest(bytes=bytes, parse=parse, scheme=scheme):
- url = scheme + "://www.example.net"
- if bytes:
- if url.isascii():
- url = url.encode("ascii")
- else:
- continue
- p = parse(url)
- if bytes:
- self.assertEqual(p.scheme, b"")
- else:
- self.assertEqual(p.scheme, "")
+ url = scheme + "://www.example.net"
+ if bytes:
+ if not url.isascii():
+ self.skipTest('non-ASCII bytes')
+ url = url.encode("ascii")
+ p = parse(url)
+ self.assertEqual(p.scheme, b"" if bytes else "")
def test_attributes_without_netloc(self):
# This example is straight from RFC 3261. It looks like it
@@ -1128,24 +1114,21 @@ class UrlParseTestCase(unittest.TestCase):
self.assertEqual(urllib.parse.urlparse(b"x-newscheme://foo.com/stuff?query"),
(b'x-newscheme', b'foo.com', b'/stuff', b'', b'query', b''))
- def test_default_scheme(self):
+ @support.subTests('func', (urllib.parse.urlparse, urllib.parse.urlsplit))
+ def test_default_scheme(self, func):
# Exercise the scheme parameter of urlparse() and urlsplit()
- for func in (urllib.parse.urlparse, urllib.parse.urlsplit):
- with self.subTest(function=func):
- result = func("http://example.net/", "ftp")
- self.assertEqual(result.scheme, "http")
- result = func(b"http://example.net/", b"ftp")
- self.assertEqual(result.scheme, b"http")
- self.assertEqual(func("path", "ftp").scheme, "ftp")
- self.assertEqual(func("path", scheme="ftp").scheme, "ftp")
- self.assertEqual(func(b"path", scheme=b"ftp").scheme, b"ftp")
- self.assertEqual(func("path").scheme, "")
- self.assertEqual(func(b"path").scheme, b"")
- self.assertEqual(func(b"path", "").scheme, b"")
-
- def test_parse_fragments(self):
- # Exercise the allow_fragments parameter of urlparse() and urlsplit()
- tests = (
+ result = func("http://example.net/", "ftp")
+ self.assertEqual(result.scheme, "http")
+ result = func(b"http://example.net/", b"ftp")
+ self.assertEqual(result.scheme, b"http")
+ self.assertEqual(func("path", "ftp").scheme, "ftp")
+ self.assertEqual(func("path", scheme="ftp").scheme, "ftp")
+ self.assertEqual(func(b"path", scheme=b"ftp").scheme, b"ftp")
+ self.assertEqual(func("path").scheme, "")
+ self.assertEqual(func(b"path").scheme, b"")
+ self.assertEqual(func(b"path", "").scheme, b"")
+
+ @support.subTests('url,attr,expected_frag', (
("http:#frag", "path", "frag"),
("//example.net#frag", "path", "frag"),
("index.html#frag", "path", "frag"),
@@ -1156,24 +1139,24 @@ class UrlParseTestCase(unittest.TestCase):
("//abc#@frag", "path", "@frag"),
("//abc:80#@frag", "path", "@frag"),
("//abc#@frag:80", "path", "@frag:80"),
- )
- for url, attr, expected_frag in tests:
- for func in (urllib.parse.urlparse, urllib.parse.urlsplit):
- if attr == "params" and func is urllib.parse.urlsplit:
- attr = "path"
- with self.subTest(url=url, function=func):
- result = func(url, allow_fragments=False)
- self.assertEqual(result.fragment, "")
- self.assertEndsWith(getattr(result, attr),
- "#" + expected_frag)
- self.assertEqual(func(url, "", False).fragment, "")
-
- result = func(url, allow_fragments=True)
- self.assertEqual(result.fragment, expected_frag)
- self.assertNotEndsWith(getattr(result, attr), expected_frag)
- self.assertEqual(func(url, "", True).fragment,
- expected_frag)
- self.assertEqual(func(url).fragment, expected_frag)
+ ))
+ @support.subTests('func', (urllib.parse.urlparse, urllib.parse.urlsplit))
+ def test_parse_fragments(self, url, attr, expected_frag, func):
+ # Exercise the allow_fragments parameter of urlparse() and urlsplit()
+ if attr == "params" and func is urllib.parse.urlsplit:
+ attr = "path"
+ result = func(url, allow_fragments=False)
+ self.assertEqual(result.fragment, "")
+ self.assertEndsWith(getattr(result, attr),
+ "#" + expected_frag)
+ self.assertEqual(func(url, "", False).fragment, "")
+
+ result = func(url, allow_fragments=True)
+ self.assertEqual(result.fragment, expected_frag)
+ self.assertNotEndsWith(getattr(result, attr), expected_frag)
+ self.assertEqual(func(url, "", True).fragment,
+ expected_frag)
+ self.assertEqual(func(url).fragment, expected_frag)
def test_mixed_types_rejected(self):
# Several functions that process either strings or ASCII encoded bytes
@@ -1199,7 +1182,14 @@ class UrlParseTestCase(unittest.TestCase):
with self.assertRaisesRegex(TypeError, "Cannot mix str"):
urllib.parse.urljoin(b"http://python.org", "http://python.org")
- def _check_result_type(self, str_type):
+ @support.subTests('result_type', [
+ urllib.parse.DefragResult,
+ urllib.parse.SplitResult,
+ urllib.parse.ParseResult,
+ ])
+ def test_result_pairs(self, result_type):
+ # Check encoding and decoding between result pairs
+ str_type = result_type
num_args = len(str_type._fields)
bytes_type = str_type._encoded_counterpart
self.assertIs(bytes_type._decoded_counterpart, str_type)
@@ -1224,16 +1214,6 @@ class UrlParseTestCase(unittest.TestCase):
self.assertEqual(str_result.encode(encoding, errors), bytes_args)
self.assertEqual(str_result.encode(encoding, errors), bytes_result)
- def test_result_pairs(self):
- # Check encoding and decoding between result pairs
- result_types = [
- urllib.parse.DefragResult,
- urllib.parse.SplitResult,
- urllib.parse.ParseResult,
- ]
- for result_type in result_types:
- self._check_result_type(result_type)
-
def test_parse_qs_encoding(self):
result = urllib.parse.parse_qs("key=\u0141%E9", encoding="latin-1")
self.assertEqual(result, {'key': ['\u0141\xE9']})
@@ -1265,8 +1245,7 @@ class UrlParseTestCase(unittest.TestCase):
urllib.parse.parse_qsl('&'.join(['a=a']*11), max_num_fields=10)
urllib.parse.parse_qsl('&'.join(['a=a']*10), max_num_fields=10)
- def test_parse_qs_separator(self):
- parse_qs_semicolon_cases = [
+ @support.subTests('orig,expect', [
(";", {}),
(";;", {}),
(";a=b", {'a': ['b']}),
@@ -1277,17 +1256,14 @@ class UrlParseTestCase(unittest.TestCase):
(b";a=b", {b'a': [b'b']}),
(b"a=a+b;b=b+c", {b'a': [b'a b'], b'b': [b'b c']}),
(b"a=1;a=2", {b'a': [b'1', b'2']}),
- ]
- for orig, expect in parse_qs_semicolon_cases:
- with self.subTest(f"Original: {orig!r}, Expected: {expect!r}"):
- result = urllib.parse.parse_qs(orig, separator=';')
- self.assertEqual(result, expect, "Error parsing %r" % orig)
- result_bytes = urllib.parse.parse_qs(orig, separator=b';')
- self.assertEqual(result_bytes, expect, "Error parsing %r" % orig)
-
-
- def test_parse_qsl_separator(self):
- parse_qsl_semicolon_cases = [
+ ])
+ def test_parse_qs_separator(self, orig, expect):
+ result = urllib.parse.parse_qs(orig, separator=';')
+ self.assertEqual(result, expect)
+ result_bytes = urllib.parse.parse_qs(orig, separator=b';')
+ self.assertEqual(result_bytes, expect)
+
+ @support.subTests('orig,expect', [
(";", []),
(";;", []),
(";a=b", [('a', 'b')]),
@@ -1298,13 +1274,12 @@ class UrlParseTestCase(unittest.TestCase):
(b";a=b", [(b'a', b'b')]),
(b"a=a+b;b=b+c", [(b'a', b'a b'), (b'b', b'b c')]),
(b"a=1;a=2", [(b'a', b'1'), (b'a', b'2')]),
- ]
- for orig, expect in parse_qsl_semicolon_cases:
- with self.subTest(f"Original: {orig!r}, Expected: {expect!r}"):
- result = urllib.parse.parse_qsl(orig, separator=';')
- self.assertEqual(result, expect, "Error parsing %r" % orig)
- result_bytes = urllib.parse.parse_qsl(orig, separator=b';')
- self.assertEqual(result_bytes, expect, "Error parsing %r" % orig)
+ ])
+ def test_parse_qsl_separator(self, orig, expect):
+ result = urllib.parse.parse_qsl(orig, separator=';')
+ self.assertEqual(result, expect)
+ result_bytes = urllib.parse.parse_qsl(orig, separator=b';')
+ self.assertEqual(result_bytes, expect)
def test_parse_qsl_bytes(self):
self.assertEqual(urllib.parse.parse_qsl(b'a=b'), [(b'a', b'b')])
@@ -1695,11 +1670,12 @@ class Utility_Tests(unittest.TestCase):
self.assertRaises(UnicodeError, urllib.parse._to_bytes,
'http://www.python.org/medi\u00e6val')
- def test_unwrap(self):
- for wrapped_url in ('<URL:scheme://host/path>', '<scheme://host/path>',
- 'URL:scheme://host/path', 'scheme://host/path'):
- url = urllib.parse.unwrap(wrapped_url)
- self.assertEqual(url, 'scheme://host/path')
+ @support.subTests('wrapped_url',
+ ('<URL:scheme://host/path>', '<scheme://host/path>',
+ 'URL:scheme://host/path', 'scheme://host/path'))
+ def test_unwrap(self, wrapped_url):
+ url = urllib.parse.unwrap(wrapped_url)
+ self.assertEqual(url, 'scheme://host/path')
class DeprecationTest(unittest.TestCase):
@@ -1780,5 +1756,11 @@ class DeprecationTest(unittest.TestCase):
'urllib.parse.to_bytes() is deprecated as of 3.8')
+def str_encode(s):
+ return s.encode('ascii')
+
+def tuple_encode(t):
+ return tuple(str_encode(x) for x in t)
+
if __name__ == "__main__":
unittest.main()
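Note: most of the rewritten urlparse tests lean on the test.support.subTests decorator, which replaces a hand-written for/subTest loop with declarative parameters: each case is bound to the named arguments and run inside its own subTest. A condensed sketch of the pattern, mirroring the test_unwrap rewrite above:

    import unittest
    import urllib.parse
    from test import support

    class UnwrapExample(unittest.TestCase):
        # One subTest per wrapped form; the value is passed as the wrapped_url argument.
        @support.subTests('wrapped_url', ('<URL:scheme://host/path>', 'scheme://host/path'))
        def test_unwrap(self, wrapped_url):
            self.assertEqual(urllib.parse.unwrap(wrapped_url), 'scheme://host/path')
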
diff --git a/Lib/test/test_userdict.py b/Lib/test/test_userdict.py
index ace84ef564d..75de9ea252d 100644
--- a/Lib/test/test_userdict.py
+++ b/Lib/test/test_userdict.py
@@ -166,7 +166,7 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol):
def test_missing(self):
# Make sure UserDict doesn't have a __missing__ method
- self.assertEqual(hasattr(collections.UserDict, "__missing__"), False)
+ self.assertNotHasAttr(collections.UserDict, "__missing__")
# Test several cases:
# (D) subclass defines __missing__ method returning a value
# (E) subclass defines __missing__ method raising RuntimeError
diff --git a/Lib/test/test_uuid.py b/Lib/test/test_uuid.py
index 958be5408ce..7ddacf07a2c 100755
--- a/Lib/test/test_uuid.py
+++ b/Lib/test/test_uuid.py
@@ -14,6 +14,7 @@ from unittest import mock
from test import support
from test.support import import_helper
+from test.support.script_helper import assert_python_ok
py_uuid = import_helper.import_fresh_module('uuid', blocked=['_uuid'])
c_uuid = import_helper.import_fresh_module('uuid', fresh=['_uuid'])
@@ -1217,10 +1218,37 @@ class BaseTestUUID:
class TestUUIDWithoutExtModule(BaseTestUUID, unittest.TestCase):
uuid = py_uuid
+
@unittest.skipUnless(c_uuid, 'requires the C _uuid module')
class TestUUIDWithExtModule(BaseTestUUID, unittest.TestCase):
uuid = c_uuid
+ def check_has_stable_libuuid_extractable_node(self):
+ if not self.uuid._has_stable_extractable_node:
+ self.skipTest("libuuid cannot deduce MAC address")
+
+ @unittest.skipUnless(os.name == 'posix', 'POSIX only')
+ def test_unix_getnode_from_libuuid(self):
+ self.check_has_stable_libuuid_extractable_node()
+ script = 'import uuid; print(uuid._unix_getnode())'
+ _, n_a, _ = assert_python_ok('-c', script)
+ _, n_b, _ = assert_python_ok('-c', script)
+ n_a, n_b = n_a.decode().strip(), n_b.decode().strip()
+ self.assertTrue(n_a.isdigit())
+ self.assertTrue(n_b.isdigit())
+ self.assertEqual(n_a, n_b)
+
+ @unittest.skipUnless(os.name == 'nt', 'Windows only')
+ def test_windows_getnode_from_libuuid(self):
+ self.check_has_stable_libuuid_extractable_node()
+ script = 'import uuid; print(uuid._windll_getnode())'
+ _, n_a, _ = assert_python_ok('-c', script)
+ _, n_b, _ = assert_python_ok('-c', script)
+ n_a, n_b = n_a.decode().strip(), n_b.decode().strip()
+ self.assertTrue(n_a.isdigit())
+ self.assertTrue(n_b.isdigit())
+ self.assertEqual(n_a, n_b)
+
class BaseTestInternals:
_uuid = py_uuid
diff --git a/Lib/test/test_venv.py b/Lib/test/test_venv.py
index adc86a49b06..d62f3fba2d1 100644
--- a/Lib/test/test_venv.py
+++ b/Lib/test/test_venv.py
@@ -774,7 +774,7 @@ class BasicTest(BaseTest):
with open(script_path, 'rb') as script:
for i, line in enumerate(script, 1):
error_message = f"CR LF found in line {i}"
- self.assertFalse(line.endswith(b'\r\n'), error_message)
+ self.assertNotEndsWith(line, b'\r\n', error_message)
@requireVenvCreate
def test_scm_ignore_files_git(self):
@@ -978,7 +978,7 @@ class EnsurePipTest(BaseTest):
self.assertEqual(err, "")
out = out.decode("latin-1") # Force to text, prevent decoding errors
expected_version = "pip {}".format(ensurepip.version())
- self.assertEqual(out[:len(expected_version)], expected_version)
+ self.assertStartsWith(out, expected_version)
env_dir = os.fsencode(self.env_dir).decode("latin-1")
self.assertIn(env_dir, out)
@@ -1008,7 +1008,7 @@ class EnsurePipTest(BaseTest):
err, flags=re.MULTILINE)
# Ignore warning about missing optional module:
try:
- import ssl
+ import ssl # noqa: F401
except ImportError:
err = re.sub(
"^WARNING: Disabling truststore since ssl support is missing$",
diff --git a/Lib/test/test_warnings/__init__.py b/Lib/test/test_warnings/__init__.py
index 03126cebe03..5c3b1250ceb 100644
--- a/Lib/test/test_warnings/__init__.py
+++ b/Lib/test/test_warnings/__init__.py
@@ -102,7 +102,7 @@ class PublicAPITests(BaseTest):
"""
def test_module_all_attribute(self):
- self.assertTrue(hasattr(self.module, '__all__'))
+ self.assertHasAttr(self.module, '__all__')
target_api = ["warn", "warn_explicit", "showwarning",
"formatwarning", "filterwarnings", "simplefilter",
"resetwarnings", "catch_warnings", "deprecated"]
@@ -735,7 +735,7 @@ class CWarnTests(WarnTests, unittest.TestCase):
# test.import_helper.import_fresh_module utility function
def test_accelerated(self):
self.assertIsNot(original_warnings, self.module)
- self.assertFalse(hasattr(self.module.warn, '__code__'))
+ self.assertNotHasAttr(self.module.warn, '__code__')
class PyWarnTests(WarnTests, unittest.TestCase):
module = py_warnings
@@ -744,7 +744,7 @@ class PyWarnTests(WarnTests, unittest.TestCase):
# test.import_helper.import_fresh_module utility function
def test_pure_python(self):
self.assertIsNot(original_warnings, self.module)
- self.assertTrue(hasattr(self.module.warn, '__code__'))
+ self.assertHasAttr(self.module.warn, '__code__')
class WCmdLineTests(BaseTest):
@@ -1528,12 +1528,12 @@ a=A()
# (_warnings will try to import it)
code = "f = open(%a)" % __file__
rc, out, err = assert_python_ok("-Wd", "-c", code)
- self.assertTrue(err.startswith(expected), ascii(err))
+ self.assertStartsWith(err, expected)
# import the warnings module
code = "import warnings; f = open(%a)" % __file__
rc, out, err = assert_python_ok("-Wd", "-c", code)
- self.assertTrue(err.startswith(expected), ascii(err))
+ self.assertStartsWith(err, expected)
class AsyncTests(BaseTest):
@@ -2018,10 +2018,70 @@ class DeprecatedTests(PyPublicAPITests):
self.assertFalse(inspect.iscoroutinefunction(Cls.sync))
self.assertTrue(inspect.iscoroutinefunction(Cls.coro))
+ def test_inspect_class_signature(self):
+ class Cls1: # no __init__ or __new__
+ pass
+
+ class Cls2: # __new__ only
+ def __new__(cls, x, y):
+ return super().__new__(cls)
+
+ class Cls3: # __init__ only
+ def __init__(self, x, y):
+ pass
+
+ class Cls4: # __new__ and __init__
+ def __new__(cls, x, y):
+ return super().__new__(cls)
+
+ def __init__(self, x, y):
+ pass
+
+ class Cls5(Cls1): # inherits no __init__ or __new__
+ pass
+
+ class Cls6(Cls2): # inherits __new__ only
+ pass
+
+ class Cls7(Cls3): # inherits __init__ only
+ pass
+
+ class Cls8(Cls4): # inherits __new__ and __init__
+ pass
+
+ # The `@deprecated` decorator will update the class in-place.
+ # Test the child classes first.
+ for cls in reversed((Cls1, Cls2, Cls3, Cls4, Cls5, Cls6, Cls7, Cls8)):
+ with self.subTest(f'class {cls.__name__} signature'):
+ try:
+ original_signature = inspect.signature(cls)
+ except ValueError:
+ original_signature = None
+ try:
+ original_new_signature = inspect.signature(cls.__new__)
+ except ValueError:
+ original_new_signature = None
+
+ deprecated_cls = deprecated("depr")(cls)
+
+ try:
+ deprecated_signature = inspect.signature(deprecated_cls)
+ except ValueError:
+ deprecated_signature = None
+ self.assertEqual(original_signature, deprecated_signature)
+
+ try:
+ deprecated_new_signature = inspect.signature(deprecated_cls.__new__)
+ except ValueError:
+ deprecated_new_signature = None
+ self.assertEqual(original_new_signature, deprecated_new_signature)
+
+
def setUpModule():
py_warnings.onceregistry.clear()
c_warnings.onceregistry.clear()
+
tearDownModule = setUpModule
if __name__ == "__main__":
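Note: the new test_inspect_class_signature test pins down one property of the warnings.deprecated decorator: wrapping a class must leave inspect.signature(cls) and inspect.signature(cls.__new__) unchanged. A compact sketch of that property, assuming the decorator behaves as exercised above:

    import inspect
    from warnings import deprecated

    class Point:
        def __init__(self, x, y):
            self.x, self.y = x, y

    before = inspect.signature(Point)
    DeprecatedPoint = deprecated("illustrative message")(Point)
    # The decorator updates the class in place, so the reported signature must survive.
    assert inspect.signature(DeprecatedPoint) == before
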
diff --git a/Lib/test/test_wave.py b/Lib/test/test_wave.py
index 5e771c8de96..6c3362857fc 100644
--- a/Lib/test/test_wave.py
+++ b/Lib/test/test_wave.py
@@ -136,32 +136,6 @@ class MiscTestCase(unittest.TestCase):
not_exported = {'WAVE_FORMAT_PCM', 'WAVE_FORMAT_EXTENSIBLE', 'KSDATAFORMAT_SUBTYPE_PCM'}
support.check__all__(self, wave, not_exported=not_exported)
- def test_read_deprecations(self):
- filename = support.findfile('pluck-pcm8.wav', subdir='audiodata')
- with wave.open(filename) as reader:
- with self.assertWarns(DeprecationWarning):
- with self.assertRaises(wave.Error):
- reader.getmark('mark')
- with self.assertWarns(DeprecationWarning):
- self.assertIsNone(reader.getmarkers())
-
- def test_write_deprecations(self):
- with io.BytesIO(b'') as tmpfile:
- with wave.open(tmpfile, 'wb') as writer:
- writer.setnchannels(1)
- writer.setsampwidth(1)
- writer.setframerate(1)
- writer.setcomptype('NONE', 'not compressed')
-
- with self.assertWarns(DeprecationWarning):
- with self.assertRaises(wave.Error):
- writer.setmark(0, 0, 'mark')
- with self.assertWarns(DeprecationWarning):
- with self.assertRaises(wave.Error):
- writer.getmark('mark')
- with self.assertWarns(DeprecationWarning):
- self.assertIsNone(writer.getmarkers())
-
class WaveLowLevelTest(unittest.TestCase):
diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py
index 4faad6629fe..4c7c900eb56 100644
--- a/Lib/test/test_weakref.py
+++ b/Lib/test/test_weakref.py
@@ -432,7 +432,7 @@ class ReferencesTestCase(TestBase):
self.assertEqual(proxy.foo, 2,
"proxy does not reflect attribute modification")
del o.foo
- self.assertFalse(hasattr(proxy, 'foo'),
+ self.assertNotHasAttr(proxy, 'foo',
"proxy does not reflect attribute removal")
proxy.foo = 1
@@ -442,7 +442,7 @@ class ReferencesTestCase(TestBase):
self.assertEqual(o.foo, 2,
"object does not reflect attribute modification via proxy")
del proxy.foo
- self.assertFalse(hasattr(o, 'foo'),
+ self.assertNotHasAttr(o, 'foo',
"object does not reflect attribute removal via proxy")
def test_proxy_deletion(self):
@@ -1108,7 +1108,7 @@ class SubclassableWeakrefTestCase(TestBase):
self.assertEqual(r.slot1, "abc")
self.assertEqual(r.slot2, "def")
self.assertEqual(r.meth(), "abcdef")
- self.assertFalse(hasattr(r, "__dict__"))
+ self.assertNotHasAttr(r, "__dict__")
def test_subclass_refs_with_cycle(self):
"""Confirm https://bugs.python.org/issue3100 is fixed."""
diff --git a/Lib/test/test_weakset.py b/Lib/test/test_weakset.py
index 76e8e5c8ab7..c1e4f9c8366 100644
--- a/Lib/test/test_weakset.py
+++ b/Lib/test/test_weakset.py
@@ -466,7 +466,7 @@ class TestWeakSet(unittest.TestCase):
self.assertIsNot(dup, s)
self.assertIs(dup.x, s.x)
self.assertIs(dup.z, s.z)
- self.assertFalse(hasattr(dup, 'y'))
+ self.assertNotHasAttr(dup, 'y')
dup = copy.deepcopy(s)
self.assertIsInstance(dup, cls)
@@ -476,7 +476,7 @@ class TestWeakSet(unittest.TestCase):
self.assertIsNot(dup.x, s.x)
self.assertEqual(dup.z, s.z)
self.assertIsNot(dup.z, s.z)
- self.assertFalse(hasattr(dup, 'y'))
+ self.assertNotHasAttr(dup, 'y')
if __name__ == "__main__":
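Note: the weakref and weakset hunks, like many files in this patch, swap assertTrue/assertFalse wrappers for the dedicated unittest helpers (assertHasAttr, assertNotHasAttr, assertStartsWith, assertEndsWith, assertIsSubclass and friends), which report the offending value on failure. A small sketch of the same helpers in isolation, assuming a unittest version that provides them:

    import unittest

    class HelperAssertions(unittest.TestCase):
        def test_helpers(self):
            # Each helper replaces an assertTrue(...) wrapper and gives a clearer failure message.
            self.assertHasAttr(str, 'join')
            self.assertNotHasAttr(object(), 'missing')
            self.assertStartsWith('#!python\n', '#!')
            self.assertEndsWith('archive.zip', '.zip')
            self.assertIsSubclass(bool, int)
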
diff --git a/Lib/test/test_webbrowser.py b/Lib/test/test_webbrowser.py
index 4c3ea1cd8df..6b577ae100e 100644
--- a/Lib/test/test_webbrowser.py
+++ b/Lib/test/test_webbrowser.py
@@ -6,7 +6,6 @@ import subprocess
import sys
import unittest
import webbrowser
-from functools import partial
from test import support
from test.support import import_helper
from test.support import is_apple_mobile
diff --git a/Lib/test/test_winconsoleio.py b/Lib/test/test_winconsoleio.py
index d9076e77c15..1bae884ed9a 100644
--- a/Lib/test/test_winconsoleio.py
+++ b/Lib/test/test_winconsoleio.py
@@ -17,9 +17,9 @@ ConIO = io._WindowsConsoleIO
class WindowsConsoleIOTests(unittest.TestCase):
def test_abc(self):
- self.assertTrue(issubclass(ConIO, io.RawIOBase))
- self.assertFalse(issubclass(ConIO, io.BufferedIOBase))
- self.assertFalse(issubclass(ConIO, io.TextIOBase))
+ self.assertIsSubclass(ConIO, io.RawIOBase)
+ self.assertNotIsSubclass(ConIO, io.BufferedIOBase)
+ self.assertNotIsSubclass(ConIO, io.TextIOBase)
def test_open_fd(self):
self.assertRaisesRegex(ValueError,
diff --git a/Lib/test/test_with.py b/Lib/test/test_with.py
index fd7abd1782e..f16611b29a2 100644
--- a/Lib/test/test_with.py
+++ b/Lib/test/test_with.py
@@ -679,7 +679,7 @@ class AssignmentTargetTestCase(unittest.TestCase):
class C: pass
blah = C()
with mock_contextmanager_generator() as blah.foo:
- self.assertEqual(hasattr(blah, "foo"), True)
+ self.assertHasAttr(blah, "foo")
def testMultipleComplexTargets(self):
class C:
diff --git a/Lib/test/test_wmi.py b/Lib/test/test_wmi.py
index ac7c9cb3a5a..90eb40439d4 100644
--- a/Lib/test/test_wmi.py
+++ b/Lib/test/test_wmi.py
@@ -70,8 +70,8 @@ class WmiTests(unittest.TestCase):
def test_wmi_query_multiple_rows(self):
# Multiple instances should have an extra null separator
r = wmi_exec_query("SELECT ProcessId FROM Win32_Process WHERE ProcessId < 1000")
- self.assertFalse(r.startswith("\0"), r)
- self.assertFalse(r.endswith("\0"), r)
+ self.assertNotStartsWith(r, "\0")
+ self.assertNotEndsWith(r, "\0")
it = iter(r.split("\0"))
try:
while True:
diff --git a/Lib/test/test_wsgiref.py b/Lib/test/test_wsgiref.py
index b047f7b06f8..e04a4d2c221 100644
--- a/Lib/test/test_wsgiref.py
+++ b/Lib/test/test_wsgiref.py
@@ -149,9 +149,9 @@ class IntegrationTests(TestCase):
start_response("200 OK", ('Content-Type','text/plain'))
return ["Hello, world!"]
out, err = run_amock(validator(bad_app))
- self.assertTrue(out.endswith(
+ self.assertEndsWith(out,
b"A server error occurred. Please contact the administrator."
- ))
+ )
self.assertEqual(
err.splitlines()[-2],
"AssertionError: Headers (('Content-Type', 'text/plain')) must"
@@ -174,9 +174,9 @@ class IntegrationTests(TestCase):
for status, exc_message in tests:
with self.subTest(status=status):
out, err = run_amock(create_bad_app(status))
- self.assertTrue(out.endswith(
+ self.assertEndsWith(out,
b"A server error occurred. Please contact the administrator."
- ))
+ )
self.assertEqual(err.splitlines()[-2], exc_message)
def test_wsgi_input(self):
@@ -185,9 +185,9 @@ class IntegrationTests(TestCase):
s("200 OK", [("Content-Type", "text/plain; charset=utf-8")])
return [b"data"]
out, err = run_amock(validator(bad_app))
- self.assertTrue(out.endswith(
+ self.assertEndsWith(out,
b"A server error occurred. Please contact the administrator."
- ))
+ )
self.assertEqual(
err.splitlines()[-2], "AssertionError"
)
@@ -200,7 +200,7 @@ class IntegrationTests(TestCase):
])
return [b"data"]
out, err = run_amock(validator(app))
- self.assertTrue(err.endswith('"GET / HTTP/1.0" 200 4\n'))
+ self.assertEndsWith(err, '"GET / HTTP/1.0" 200 4\n')
ver = sys.version.split()[0].encode('ascii')
py = python_implementation().encode('ascii')
pyver = py + b"/" + ver
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py
index 5fe9d688410..38be2cd437f 100644
--- a/Lib/test/test_xml_etree.py
+++ b/Lib/test/test_xml_etree.py
@@ -225,8 +225,7 @@ class ElementTreeTest(unittest.TestCase):
self.assertTrue(ET.iselement(element), msg="not an element")
direlem = dir(element)
for attr in 'tag', 'attrib', 'text', 'tail':
- self.assertTrue(hasattr(element, attr),
- msg='no %s member' % attr)
+ self.assertHasAttr(element, attr)
self.assertIn(attr, direlem,
msg='no %s visible by dir' % attr)
@@ -251,7 +250,7 @@ class ElementTreeTest(unittest.TestCase):
# Make sure all standard element methods exist.
def check_method(method):
- self.assertTrue(hasattr(method, '__call__'),
+ self.assertHasAttr(method, '__call__',
msg="%s not callable" % method)
check_method(element.append)
@@ -2960,6 +2959,50 @@ class BadElementTest(ElementTestCase, unittest.TestCase):
del b
gc_collect()
+ def test_deepcopy_clear(self):
+ # Prevent crashes when __deepcopy__() clears the children list.
+ # See https://github.com/python/cpython/issues/133009.
+ class X(ET.Element):
+ def __deepcopy__(self, memo):
+ root.clear()
+ return self
+
+ root = ET.Element('a')
+ evil = X('x')
+ root.extend([evil, ET.Element('y')])
+ if is_python_implementation():
+ # Mutating a list over which we iterate raises an error.
+ self.assertRaises(RuntimeError, copy.deepcopy, root)
+ else:
+ c = copy.deepcopy(root)
+ # In the C implementation, we can still copy the evil element.
+ self.assertListEqual(list(c), [evil])
+
+ def test_deepcopy_grow(self):
+ # Prevent crashes when __deepcopy__() mutates the children list.
+ # See https://github.com/python/cpython/issues/133009.
+ a = ET.Element('a')
+ b = ET.Element('b')
+ c = ET.Element('c')
+
+ class X(ET.Element):
+ def __deepcopy__(self, memo):
+ root.append(a)
+ root.append(b)
+ return self
+
+ root = ET.Element('top')
+ evil1, evil2 = X('1'), X('2')
+ root.extend([evil1, c, evil2])
+ children = list(copy.deepcopy(root))
+ # mock deep copies
+ self.assertIs(children[0], evil1)
+ self.assertIs(children[2], evil2)
+ # true deep copies
+ self.assertEqual(children[1].tag, c.tag)
+ self.assertEqual([c.tag for c in children[3:]],
+ [a.tag, b.tag, a.tag, b.tag])
+
class MutationDeleteElementPath(str):
def __new__(cls, elem, *args):
diff --git a/Lib/test/test_xxlimited.py b/Lib/test/test_xxlimited.py
index 6dbfb3f4393..b52e78bc4fb 100644
--- a/Lib/test/test_xxlimited.py
+++ b/Lib/test/test_xxlimited.py
@@ -31,7 +31,7 @@ class CommonTests:
self.assertEqual(self.module.foo(1, 2), 3)
def test_str(self):
- self.assertTrue(issubclass(self.module.Str, str))
+ self.assertIsSubclass(self.module.Str, str)
self.assertIsNot(self.module.Str, str)
custom_string = self.module.Str("abcd")
diff --git a/Lib/test/test_zipapp.py b/Lib/test/test_zipapp.py
index d4766c59a10..8fb0a68deba 100644
--- a/Lib/test/test_zipapp.py
+++ b/Lib/test/test_zipapp.py
@@ -259,7 +259,7 @@ class ZipAppTest(unittest.TestCase):
(source / '__main__.py').touch()
target = io.BytesIO()
zipapp.create_archive(str(source), target, interpreter='python')
- self.assertTrue(target.getvalue().startswith(b'#!python\n'))
+ self.assertStartsWith(target.getvalue(), b'#!python\n')
def test_read_shebang(self):
# Test that we can read the shebang line correctly.
@@ -300,7 +300,7 @@ class ZipAppTest(unittest.TestCase):
zipapp.create_archive(str(source), str(target), interpreter='python')
new_target = io.BytesIO()
zipapp.create_archive(str(target), new_target, interpreter='python2.7')
- self.assertTrue(new_target.getvalue().startswith(b'#!python2.7\n'))
+ self.assertStartsWith(new_target.getvalue(), b'#!python2.7\n')
def test_read_from_pathlike_obj(self):
# Test that we can copy an archive using a path-like object
@@ -326,7 +326,7 @@ class ZipAppTest(unittest.TestCase):
new_target = io.BytesIO()
temp_archive.seek(0)
zipapp.create_archive(temp_archive, new_target, interpreter='python2.7')
- self.assertTrue(new_target.getvalue().startswith(b'#!python2.7\n'))
+ self.assertStartsWith(new_target.getvalue(), b'#!python2.7\n')
def test_remove_shebang(self):
# Test that we can remove the shebang from a file.
diff --git a/Lib/test/test_zipfile/__main__.py b/Lib/test/test_zipfile/__main__.py
index e25ac946edf..90da74ade38 100644
--- a/Lib/test/test_zipfile/__main__.py
+++ b/Lib/test/test_zipfile/__main__.py
@@ -1,6 +1,6 @@
import unittest
-from . import load_tests # noqa: F401
+from . import load_tests
if __name__ == "__main__":
diff --git a/Lib/test/test_zipfile/_path/_test_params.py b/Lib/test/test_zipfile/_path/_test_params.py
index bc95b4ebf4a..00a9eaf2f99 100644
--- a/Lib/test/test_zipfile/_path/_test_params.py
+++ b/Lib/test/test_zipfile/_path/_test_params.py
@@ -1,5 +1,5 @@
-import types
import functools
+import types
from ._itertools import always_iterable
diff --git a/Lib/test/test_zipfile/_path/test_complexity.py b/Lib/test/test_zipfile/_path/test_complexity.py
index b505dd7c376..7c108fc6ab8 100644
--- a/Lib/test/test_zipfile/_path/test_complexity.py
+++ b/Lib/test/test_zipfile/_path/test_complexity.py
@@ -8,10 +8,8 @@ import zipfile
from ._functools import compose
from ._itertools import consume
-
from ._support import import_or_skip
-
big_o = import_or_skip('big_o')
pytest = import_or_skip('pytest')
diff --git a/Lib/test/test_zipfile/_path/test_path.py b/Lib/test/test_zipfile/_path/test_path.py
index 0afabc0c668..696134023a5 100644
--- a/Lib/test/test_zipfile/_path/test_path.py
+++ b/Lib/test/test_zipfile/_path/test_path.py
@@ -1,6 +1,6 @@
+import contextlib
import io
import itertools
-import contextlib
import pathlib
import pickle
import stat
@@ -9,12 +9,11 @@ import unittest
import zipfile
import zipfile._path
-from test.support.os_helper import temp_dir, FakePath
+from test.support.os_helper import FakePath, temp_dir
from ._functools import compose
from ._itertools import Counter
-
-from ._test_params import parameterize, Invoked
+from ._test_params import Invoked, parameterize
class jaraco:
@@ -193,10 +192,10 @@ class TestPath(unittest.TestCase):
"""EncodingWarning must blame the read_text and open calls."""
assert sys.flags.warn_default_encoding
root = zipfile.Path(alpharep)
- with self.assertWarns(EncodingWarning) as wc:
+ with self.assertWarns(EncodingWarning) as wc: # noqa: F821 (astral-sh/ruff#13296)
root.joinpath("a.txt").read_text()
assert __file__ == wc.filename
- with self.assertWarns(EncodingWarning) as wc:
+ with self.assertWarns(EncodingWarning) as wc: # noqa: F821 (astral-sh/ruff#13296)
root.joinpath("a.txt").open("r").close()
assert __file__ == wc.filename
@@ -365,6 +364,17 @@ class TestPath(unittest.TestCase):
assert root.name == 'alpharep.zip' == root.filename.name
@pass_alpharep
+ def test_root_on_disk(self, alpharep):
+ """
+ The name/stem of the root should match the zipfile on disk.
+
+ This condition must hold across platforms.
+ """
+ root = zipfile.Path(self.zipfile_ondisk(alpharep))
+ assert root.name == 'alpharep.zip' == root.filename.name
+ assert root.stem == 'alpharep' == root.filename.stem
+
+ @pass_alpharep
def test_suffix(self, alpharep):
"""
The suffix of the root should be the suffix of the zipfile.
diff --git a/Lib/test/test_zipfile/_path/write-alpharep.py b/Lib/test/test_zipfile/_path/write-alpharep.py
index 48c09b53717..7418391abad 100644
--- a/Lib/test/test_zipfile/_path/write-alpharep.py
+++ b/Lib/test/test_zipfile/_path/write-alpharep.py
@@ -1,4 +1,3 @@
from . import test_path
-
__name__ == '__main__' and test_path.build_alpharep_fixture().extractall('alpharep')
diff --git a/Lib/test/test_zipfile/test_core.py b/Lib/test/test_zipfile/test_core.py
index 7c8a82d821a..ada96813709 100644
--- a/Lib/test/test_zipfile/test_core.py
+++ b/Lib/test/test_zipfile/test_core.py
@@ -23,11 +23,13 @@ from test import archiver_tests
from test.support import script_helper, os_helper
from test.support import (
findfile, requires_zlib, requires_bz2, requires_lzma,
- captured_stdout, captured_stderr, requires_subprocess,
+ requires_zstd, captured_stdout, captured_stderr, requires_subprocess,
+ cpython_only
)
from test.support.os_helper import (
TESTFN, unlink, rmtree, temp_dir, temp_cwd, fd_count, FakePath
)
+from test.support.import_helper import ensure_lazy_imports
TESTFN2 = TESTFN + "2"
@@ -49,6 +51,13 @@ def get_files(test):
yield f
test.assertFalse(f.closed)
+
+class LazyImportTest(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("zipfile", {"typing"})
+
+
class AbstractTestsWithSourceFile:
@classmethod
def setUpClass(cls):
@@ -693,6 +702,10 @@ class LzmaTestsWithSourceFile(AbstractTestsWithSourceFile,
unittest.TestCase):
compression = zipfile.ZIP_LZMA
+@requires_zstd()
+class ZstdTestsWithSourceFile(AbstractTestsWithSourceFile,
+ unittest.TestCase):
+ compression = zipfile.ZIP_ZSTANDARD
class AbstractTestZip64InSmallFiles:
# These tests test the ZIP64 functionality without using large files,
@@ -1270,6 +1283,10 @@ class LzmaTestZip64InSmallFiles(AbstractTestZip64InSmallFiles,
unittest.TestCase):
compression = zipfile.ZIP_LZMA
+@requires_zstd()
+class ZstdTestZip64InSmallFiles(AbstractTestZip64InSmallFiles,
+ unittest.TestCase):
+ compression = zipfile.ZIP_ZSTANDARD
class AbstractWriterTests:
@@ -1339,6 +1356,9 @@ class Bzip2WriterTests(AbstractWriterTests, unittest.TestCase):
class LzmaWriterTests(AbstractWriterTests, unittest.TestCase):
compression = zipfile.ZIP_LZMA
+@requires_zstd()
+class ZstdWriterTests(AbstractWriterTests, unittest.TestCase):
+ compression = zipfile.ZIP_ZSTANDARD
class PyZipFileTests(unittest.TestCase):
def assertCompiledIn(self, name, namelist):
@@ -1971,6 +1991,25 @@ class OtherTests(unittest.TestCase):
self.assertFalse(zipfile.is_zipfile(fp))
fp.seek(0, 0)
self.assertFalse(zipfile.is_zipfile(fp))
+ # - passing non-zipfile with ZIP header elements
+ # data created using pyPNG like so:
+ # d = [(ord('P'), ord('K'), 5, 6), (ord('P'), ord('K'), 6, 6)]
+ # w = png.Writer(1,2,alpha=True,compression=0)
+ # f = open('onepix.png', 'wb')
+ # w.write(f, d)
+ # w.close()
+ data = (b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00"
+ b"\x00\x02\x08\x06\x00\x00\x00\x99\x81\xb6'\x00\x00\x00\x15I"
+ b"DATx\x01\x01\n\x00\xf5\xff\x00PK\x05\x06\x00PK\x06\x06\x07"
+ b"\xac\x01N\xc6|a\r\x00\x00\x00\x00IEND\xaeB`\x82")
+ # - passing a filename
+ with open(TESTFN, "wb") as fp:
+ fp.write(data)
+ self.assertFalse(zipfile.is_zipfile(TESTFN))
+ # - passing a file-like object
+ fp = io.BytesIO()
+ fp.write(data)
+ self.assertFalse(zipfile.is_zipfile(fp))
def test_damaged_zipfile(self):
"""Check that zipfiles with missing bytes at the end raise BadZipFile."""
@@ -2669,6 +2708,17 @@ class LzmaBadCrcTests(AbstractBadCrcTests, unittest.TestCase):
b'ePK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00'
b'\x00>\x00\x00\x00\x00\x00')
+@requires_zstd()
+class ZstdBadCrcTests(AbstractBadCrcTests, unittest.TestCase):
+ compression = zipfile.ZIP_ZSTANDARD
+ zip_with_bad_crc = (
+ b'PK\x03\x04?\x00\x00\x00]\x00\x00\x00!\x00V\xb1\x17J\x14\x00'
+ b'\x00\x00\x0b\x00\x00\x00\x05\x00\x00\x00afile(\xb5/\xfd\x00'
+ b'XY\x00\x00Hello WorldPK\x01\x02?\x03?\x00\x00\x00]\x00\x00\x00'
+ b'!\x00V\xb0\x17J\x14\x00\x00\x00\x0b\x00\x00\x00\x05\x00\x00\x00'
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00\x00afilePK'
+ b'\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00\x007\x00\x00\x00'
+ b'\x00\x00')
class DecryptionTests(unittest.TestCase):
"""Check that ZIP decryption works. Since the library does not
@@ -2896,6 +2946,10 @@ class LzmaTestsWithRandomBinaryFiles(AbstractTestsWithRandomBinaryFiles,
unittest.TestCase):
compression = zipfile.ZIP_LZMA
+@requires_zstd()
+class ZstdTestsWithRandomBinaryFiles(AbstractTestsWithRandomBinaryFiles,
+ unittest.TestCase):
+ compression = zipfile.ZIP_ZSTANDARD
# Provide the tell() method but not seek()
class Tellable:
@@ -3144,7 +3198,7 @@ class TestWithDirectory(unittest.TestCase):
with zipfile.ZipFile(TESTFN, "w") as zipf:
zipf.write(dirpath)
zinfo = zipf.filelist[0]
- self.assertTrue(zinfo.filename.endswith("/x/"))
+ self.assertEndsWith(zinfo.filename, "/x/")
self.assertEqual(zinfo.external_attr, (mode << 16) | 0x10)
zipf.write(dirpath, "y")
zinfo = zipf.filelist[1]
@@ -3152,7 +3206,7 @@ class TestWithDirectory(unittest.TestCase):
self.assertEqual(zinfo.external_attr, (mode << 16) | 0x10)
with zipfile.ZipFile(TESTFN, "r") as zipf:
zinfo = zipf.filelist[0]
- self.assertTrue(zinfo.filename.endswith("/x/"))
+ self.assertEndsWith(zinfo.filename, "/x/")
self.assertEqual(zinfo.external_attr, (mode << 16) | 0x10)
zinfo = zipf.filelist[1]
self.assertTrue(zinfo.filename, "y/")
@@ -3172,7 +3226,7 @@ class TestWithDirectory(unittest.TestCase):
self.assertEqual(zinfo.external_attr, (0o40775 << 16) | 0x10)
with zipfile.ZipFile(TESTFN, "r") as zipf:
zinfo = zipf.filelist[0]
- self.assertTrue(zinfo.filename.endswith("x/"))
+ self.assertEndsWith(zinfo.filename, "x/")
self.assertEqual(zinfo.external_attr, (0o40775 << 16) | 0x10)
target = os.path.join(TESTFN2, "target")
os.mkdir(target)
@@ -3607,7 +3661,7 @@ class EncodedMetadataTests(unittest.TestCase):
except OSError:
pass
except UnicodeEncodeError:
- self.skipTest(f'cannot encode file name {fn!r}')
+ self.skipTest(f'cannot encode file name {fn!a}')
zipfile.main(["--metadata-encoding=shift_jis", "-e", TESTFN, TESTFN2])
listing = os.listdir(TESTFN2)
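Note: the zipfile hunks register Zstandard counterparts for the existing compression test classes via the new requires_zstd decorator and the ZIP_ZSTANDARD constant. A minimal round-trip sketch for the new method, assuming a build where Zstandard support is compiled in:

    import io
    import zipfile

    buf = io.BytesIO()
    # Write one member with Zstandard compression, then read it back.
    with zipfile.ZipFile(buf, 'w', compression=zipfile.ZIP_ZSTANDARD) as zf:
        zf.writestr('afile', b'Hello World')
    with zipfile.ZipFile(buf) as zf:
        assert zf.read('afile') == b'Hello World'
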
diff --git a/Lib/test/test_zipimport.py b/Lib/test/test_zipimport.py
index 1f288c8b45d..b5b4acf5f85 100644
--- a/Lib/test/test_zipimport.py
+++ b/Lib/test/test_zipimport.py
@@ -835,11 +835,11 @@ class UncompressedZipImportTestCase(ImportHooksBaseTestCase):
s = io.StringIO()
print_tb(tb, 1, s)
- self.assertTrue(s.getvalue().endswith(
+ self.assertEndsWith(s.getvalue(),
' def do_raise(): raise TypeError\n'
'' if support.has_no_debug_ranges() else
' ^^^^^^^^^^^^^^^\n'
- ))
+ )
else:
raise AssertionError("This ought to be impossible")
diff --git a/Lib/test/test_zlib.py b/Lib/test/test_zlib.py
index 4d97fe56f3a..c57ab51eca1 100644
--- a/Lib/test/test_zlib.py
+++ b/Lib/test/test_zlib.py
@@ -119,6 +119,114 @@ class ChecksumTestCase(unittest.TestCase):
self.assertEqual(binascii.crc32(b'spam'), zlib.crc32(b'spam'))
+class ChecksumCombineMixin:
+ """Mixin class for testing checksum combination."""
+
+ N = 1000
+ default_iv: int
+
+ def parse_iv(self, iv):
+ """Parse an IV value.
+
+ - The default IV is returned if *iv* is None.
+ - A random IV is returned if *iv* is -1.
+ - Otherwise, *iv* is returned as is.
+ """
+ if iv is None:
+ return self.default_iv
+ if iv == -1:
+ return random.randint(1, 0x80000000)
+ return iv
+
+ def checksum(self, data, init=None):
+ """Compute the checksum of data with a given initial value.
+
+ The *init* value is parsed by ``parse_iv``.
+ """
+ iv = self.parse_iv(init)
+ return self._checksum(data, iv)
+
+ def _checksum(self, data, init):
+ raise NotImplementedError
+
+ def combine(self, a, b, blen):
+ """Combine two checksums together."""
+ raise NotImplementedError
+
+ def get_random_data(self, data_len, *, iv=None):
+ """Get a triplet (data, iv, checksum)."""
+ data = random.randbytes(data_len)
+ init = self.parse_iv(iv)
+ checksum = self.checksum(data, init)
+ return data, init, checksum
+
+ def test_combine_empty(self):
+ for _ in range(self.N):
+ a, iv, checksum = self.get_random_data(32, iv=-1)
+ res = self.combine(iv, self.checksum(a), len(a))
+ self.assertEqual(res, checksum)
+
+ def test_combine_no_iv(self):
+ for _ in range(self.N):
+ a, _, chk_a = self.get_random_data(32)
+ b, _, chk_b = self.get_random_data(64)
+ res = self.combine(chk_a, chk_b, len(b))
+ self.assertEqual(res, self.checksum(a + b))
+
+ def test_combine_no_iv_invalid_length(self):
+ a, _, chk_a = self.get_random_data(32)
+ b, _, chk_b = self.get_random_data(64)
+ checksum = self.checksum(a + b)
+ for invalid_len in [1, len(a), 48, len(b) + 1, 191]:
+ invalid_res = self.combine(chk_a, chk_b, invalid_len)
+ self.assertNotEqual(invalid_res, checksum)
+
+ self.assertRaises(TypeError, self.combine, 0, 0, "len")
+
+ def test_combine_with_iv(self):
+ for _ in range(self.N):
+ a, iv_a, chk_a_with_iv = self.get_random_data(32, iv=-1)
+ chk_a_no_iv = self.checksum(a)
+ b, iv_b, chk_b_with_iv = self.get_random_data(64, iv=-1)
+ chk_b_no_iv = self.checksum(b)
+
+ # We can represent c = COMBINE(CHK(a, iv_a), CHK(b, iv_b)) as:
+ #
+ # c = CHK(CHK(b'', iv_a) + CHK(a) + CHK(b'', iv_b) + CHK(b))
+ # = COMBINE(
+ # COMBINE(CHK(b'', iv_a), CHK(a)),
+ # COMBINE(CHK(b'', iv_b), CHK(b)),
+ # )
+ # = COMBINE(COMBINE(iv_a, CHK(a)), COMBINE(iv_b, CHK(b)))
+ tmp0 = self.combine(iv_a, chk_a_no_iv, len(a))
+ tmp1 = self.combine(iv_b, chk_b_no_iv, len(b))
+ expected = self.combine(tmp0, tmp1, len(b))
+ checksum = self.combine(chk_a_with_iv, chk_b_with_iv, len(b))
+ self.assertEqual(checksum, expected)
+
+
+class CRC32CombineTestCase(ChecksumCombineMixin, unittest.TestCase):
+
+ default_iv = 0
+
+ def _checksum(self, data, init):
+ return zlib.crc32(data, init)
+
+ def combine(self, a, b, blen):
+ return zlib.crc32_combine(a, b, blen)
+
+
+class Adler32CombineTestCase(ChecksumCombineMixin, unittest.TestCase):
+
+ default_iv = 1
+
+ def _checksum(self, data, init):
+ return zlib.adler32(data, init)
+
+ def combine(self, a, b, blen):
+ return zlib.adler32_combine(a, b, blen)
+
+
# Issue #10276 - check that inputs >=4 GiB are handled correctly.
class ChecksumBigBufferTestCase(unittest.TestCase):
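Note: ChecksumCombineMixin above encodes the identity behind the new zlib.crc32_combine and zlib.adler32_combine helpers: the checksum of a concatenation can be rebuilt from the two partial checksums plus the length of the second part, without touching the data again. A short worked check of that identity, using the same call shape as the tests (combine(chk_a, chk_b, len(b))):

    import zlib

    a, b = b'spam' * 100, b'eggs' * 250
    combined = zlib.crc32_combine(zlib.crc32(a), zlib.crc32(b), len(b))
    # Combining the partial checksums must match the checksum of the concatenation.
    assert combined == zlib.crc32(a + b)
    # The same holds for Adler-32, whose default initial value is 1 rather than 0.
    assert zlib.adler32_combine(zlib.adler32(a), zlib.adler32(b), len(b)) == zlib.adler32(a + b)
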
diff --git a/Lib/test/test_zoneinfo/test_zoneinfo.py b/Lib/test/test_zoneinfo/test_zoneinfo.py
index b0dbd768cab..f313e394f49 100644
--- a/Lib/test/test_zoneinfo/test_zoneinfo.py
+++ b/Lib/test/test_zoneinfo/test_zoneinfo.py
@@ -237,7 +237,6 @@ class ZoneInfoTest(TzPathUserMixin, ZoneInfoTestBase):
"../zoneinfo/America/Los_Angeles", # Traverses above TZPATH
"America/../America/Los_Angeles", # Not normalized
"America/./Los_Angeles",
- "",
]
for bad_key in bad_keys:
@@ -1916,8 +1915,8 @@ class ExtensionBuiltTest(unittest.TestCase):
def test_cache_location(self):
# The pure Python version stores caches on attributes, but the C
# extension stores them in C globals (at least for now)
- self.assertFalse(hasattr(c_zoneinfo.ZoneInfo, "_weak_cache"))
- self.assertTrue(hasattr(py_zoneinfo.ZoneInfo, "_weak_cache"))
+ self.assertNotHasAttr(c_zoneinfo.ZoneInfo, "_weak_cache")
+ self.assertHasAttr(py_zoneinfo.ZoneInfo, "_weak_cache")
def test_gc_tracked(self):
import gc
diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py
new file mode 100644
index 00000000000..d4c28aed38e
--- /dev/null
+++ b/Lib/test/test_zstd.py
@@ -0,0 +1,2794 @@
+import array
+import gc
+import io
+import pathlib
+import random
+import re
+import os
+import unittest
+import tempfile
+import threading
+
+from test.support.import_helper import import_module
+from test.support import threading_helper
+from test.support import _1M
+
+_zstd = import_module("_zstd")
+zstd = import_module("compression.zstd")
+
+from compression.zstd import (
+ open,
+ compress,
+ decompress,
+ ZstdCompressor,
+ ZstdDecompressor,
+ ZstdDict,
+ ZstdError,
+ zstd_version,
+ zstd_version_info,
+ COMPRESSION_LEVEL_DEFAULT,
+ get_frame_info,
+ get_frame_size,
+ finalize_dict,
+ train_dict,
+ CompressionParameter,
+ DecompressionParameter,
+ Strategy,
+ ZstdFile,
+)
+
+_1K = 1024
+_130_1K = 130 * _1K
+DICT_SIZE1 = 3*_1K
+
+DAT_130K_D = None
+DAT_130K_C = None
+
+DECOMPRESSED_DAT = None
+COMPRESSED_DAT = None
+
+DECOMPRESSED_100_PLUS_32KB = None
+COMPRESSED_100_PLUS_32KB = None
+
+SKIPPABLE_FRAME = None
+
+THIS_FILE_BYTES = None
+THIS_FILE_STR = None
+COMPRESSED_THIS_FILE = None
+
+COMPRESSED_BOGUS = None
+
+SAMPLES = None
+
+TRAINED_DICT = None
+
+SUPPORT_MULTITHREADING = False
+
+C_INT_MIN = -(2**31)
+C_INT_MAX = (2**31) - 1
+
+
+def setUpModule():
+ global SUPPORT_MULTITHREADING
+ SUPPORT_MULTITHREADING = CompressionParameter.nb_workers.bounds() != (0, 0)
+ # Uncompressed size is 130 KiB, more than one zstd block;
+ # the frame epilogue adds a 4-byte checksum.
+ global DAT_130K_D
+ DAT_130K_D = bytes([random.randint(0, 127) for _ in range(130*_1K)])
+
+ global DAT_130K_C
+ DAT_130K_C = compress(DAT_130K_D, options={CompressionParameter.checksum_flag:1})
+
+ global DECOMPRESSED_DAT
+ DECOMPRESSED_DAT = b'abcdefg123456' * 1000
+
+ global COMPRESSED_DAT
+ COMPRESSED_DAT = compress(DECOMPRESSED_DAT)
+
+ global DECOMPRESSED_100_PLUS_32KB
+ DECOMPRESSED_100_PLUS_32KB = b'a' * (100 + 32*_1K)
+
+ global COMPRESSED_100_PLUS_32KB
+ COMPRESSED_100_PLUS_32KB = compress(DECOMPRESSED_100_PLUS_32KB)
+
+ global SKIPPABLE_FRAME
+ SKIPPABLE_FRAME = (0x184D2A50).to_bytes(4, byteorder='little') + \
+ (32*_1K).to_bytes(4, byteorder='little') + \
+ b'a' * (32*_1K)
+
+ global THIS_FILE_BYTES, THIS_FILE_STR
+ with io.open(os.path.abspath(__file__), 'rb') as f:
+ THIS_FILE_BYTES = f.read()
+ THIS_FILE_BYTES = re.sub(rb'\r?\n', rb'\n', THIS_FILE_BYTES)
+ THIS_FILE_STR = THIS_FILE_BYTES.decode('utf-8')
+
+ global COMPRESSED_THIS_FILE
+ COMPRESSED_THIS_FILE = compress(THIS_FILE_BYTES)
+
+ global COMPRESSED_BOGUS
+ COMPRESSED_BOGUS = DECOMPRESSED_DAT
+
+ # dict data
+ words = [b'red', b'green', b'yellow', b'black', b'white', b'blue',
+ b'lilac', b'purple', b'navy', b'gold', b'silver', b'olive',
+ b'dog', b'cat', b'tiger', b'lion', b'fish', b'bird']
+ lst = []
+ for i in range(300):
+ sample = [b'%s = %d' % (random.choice(words), random.randrange(100))
+ for j in range(20)]
+ sample = b'\n'.join(sample)
+
+ lst.append(sample)
+ global SAMPLES
+ SAMPLES = lst
+ assert len(SAMPLES) > 10
+
+ global TRAINED_DICT
+ TRAINED_DICT = train_dict(SAMPLES, 3*_1K)
+ assert len(TRAINED_DICT.dict_content) <= 3*_1K
+
+
+class FunctionsTestCase(unittest.TestCase):
+
+ def test_version(self):
+ s = ".".join((str(i) for i in zstd_version_info))
+ self.assertEqual(s, zstd_version)
+
+ def test_compressionLevel_values(self):
+ min, max = CompressionParameter.compression_level.bounds()
+ self.assertIs(type(COMPRESSION_LEVEL_DEFAULT), int)
+ self.assertIs(type(min), int)
+ self.assertIs(type(max), int)
+ self.assertLess(min, max)
+
+ def test_roundtrip_default(self):
+ raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6]
+ dat1 = compress(raw_dat)
+ dat2 = decompress(dat1)
+ self.assertEqual(dat2, raw_dat)
+
+ def test_roundtrip_level(self):
+ raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6]
+ level_min, level_max = CompressionParameter.compression_level.bounds()
+
+ for level in range(max(-20, level_min), level_max + 1):
+ dat1 = compress(raw_dat, level)
+ dat2 = decompress(dat1)
+ self.assertEqual(dat2, raw_dat)
+
+ def test_get_frame_info(self):
+ # no dict
+ info = get_frame_info(COMPRESSED_100_PLUS_32KB[:20])
+ self.assertEqual(info.decompressed_size, 32 * _1K + 100)
+ self.assertEqual(info.dictionary_id, 0)
+
+ # use dict
+ dat = compress(b"a" * 345, zstd_dict=TRAINED_DICT)
+ info = get_frame_info(dat)
+ self.assertEqual(info.decompressed_size, 345)
+ self.assertEqual(info.dictionary_id, TRAINED_DICT.dict_id)
+
+ with self.assertRaisesRegex(ZstdError, "not less than the frame header"):
+ get_frame_info(b"aaaaaaaaaaaaaa")
+
+ def test_get_frame_size(self):
+ size = get_frame_size(COMPRESSED_100_PLUS_32KB)
+ self.assertEqual(size, len(COMPRESSED_100_PLUS_32KB))
+
+ with self.assertRaisesRegex(ZstdError, "not less than this complete frame"):
+ get_frame_size(b"aaaaaaaaaaaaaa")
+
+ def test_decompress_2x130_1K(self):
+ decompressed_size = get_frame_info(DAT_130K_C).decompressed_size
+ self.assertEqual(decompressed_size, _130_1K)
+
+ dat = decompress(DAT_130K_C + DAT_130K_C)
+ self.assertEqual(len(dat), 2 * _130_1K)
+
+
+class CompressorTestCase(unittest.TestCase):
+
+ def test_simple_compress_bad_args(self):
+ # ZstdCompressor
+ self.assertRaises(TypeError, ZstdCompressor, [])
+ self.assertRaises(TypeError, ZstdCompressor, level=3.14)
+ self.assertRaises(TypeError, ZstdCompressor, level="abc")
+ self.assertRaises(TypeError, ZstdCompressor, options=b"abc")
+
+ self.assertRaises(TypeError, ZstdCompressor, zstd_dict=123)
+ self.assertRaises(TypeError, ZstdCompressor, zstd_dict=b"abcd1234")
+ self.assertRaises(TypeError, ZstdCompressor, zstd_dict={1: 2, 3: 4})
+
+ # valid range for compression level is [-(1<<17), 22]
+ msg = r'illegal compression level {}; the valid range is \[-?\d+, -?\d+\]'
+ with self.assertRaisesRegex(ValueError, msg.format(C_INT_MAX)):
+ ZstdCompressor(C_INT_MAX)
+ with self.assertRaisesRegex(ValueError, msg.format(C_INT_MIN)):
+ ZstdCompressor(C_INT_MIN)
+ msg = r'illegal compression level; the valid range is \[-?\d+, -?\d+\]'
+ with self.assertRaisesRegex(ValueError, msg):
+ ZstdCompressor(level=-(2**1000))
+ with self.assertRaisesRegex(ValueError, msg):
+ ZstdCompressor(level=2**1000)
+
+ with self.assertRaises(ValueError):
+ ZstdCompressor(options={CompressionParameter.window_log: 100})
+ with self.assertRaises(ValueError):
+ ZstdCompressor(options={3333: 100})
+
+ # Method bad arguments
+ zc = ZstdCompressor()
+ self.assertRaises(TypeError, zc.compress)
+ self.assertRaises((TypeError, ValueError), zc.compress, b"foo", b"bar")
+ self.assertRaises(TypeError, zc.compress, "str")
+ self.assertRaises((TypeError, ValueError), zc.flush, b"foo")
+ self.assertRaises(TypeError, zc.flush, b"blah", 1)
+
+ self.assertRaises(ValueError, zc.compress, b'', -1)
+ self.assertRaises(ValueError, zc.compress, b'', 3)
+        self.assertRaises(ValueError, zc.flush, zc.CONTINUE)  # CONTINUE (0) is not a flush mode
+ self.assertRaises(ValueError, zc.flush, 3)
+
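+        # Valid modes below: CONTINUE buffers input, FLUSH_BLOCK ends the
+        # current block, FLUSH_FRAME ends the current frame.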
+ zc.compress(b'')
+ zc.compress(b'', zc.CONTINUE)
+ zc.compress(b'', zc.FLUSH_BLOCK)
+ zc.compress(b'', zc.FLUSH_FRAME)
+ empty = zc.flush()
+ zc.flush(zc.FLUSH_BLOCK)
+ zc.flush(zc.FLUSH_FRAME)
+
+ def test_compress_parameters(self):
+ d = {CompressionParameter.compression_level : 10,
+
+ CompressionParameter.window_log : 12,
+ CompressionParameter.hash_log : 10,
+ CompressionParameter.chain_log : 12,
+ CompressionParameter.search_log : 12,
+ CompressionParameter.min_match : 4,
+ CompressionParameter.target_length : 12,
+ CompressionParameter.strategy : Strategy.lazy,
+
+ CompressionParameter.enable_long_distance_matching : 1,
+ CompressionParameter.ldm_hash_log : 12,
+ CompressionParameter.ldm_min_match : 11,
+ CompressionParameter.ldm_bucket_size_log : 5,
+ CompressionParameter.ldm_hash_rate_log : 12,
+
+ CompressionParameter.content_size_flag : 1,
+ CompressionParameter.checksum_flag : 1,
+ CompressionParameter.dict_id_flag : 0,
+
+ CompressionParameter.nb_workers : 2 if SUPPORT_MULTITHREADING else 0,
+ CompressionParameter.job_size : 5*_1M if SUPPORT_MULTITHREADING else 0,
+ CompressionParameter.overlap_log : 9 if SUPPORT_MULTITHREADING else 0,
+ }
+ ZstdCompressor(options=d)
+
+ d1 = d.copy()
+ # larger than signed int
+ d1[CompressionParameter.ldm_bucket_size_log] = C_INT_MAX
+ with self.assertRaises(ValueError):
+ ZstdCompressor(options=d1)
+ # smaller than signed int
+ d1[CompressionParameter.ldm_bucket_size_log] = C_INT_MIN
+ with self.assertRaises(ValueError):
+ ZstdCompressor(options=d1)
+
+ # out of bounds compression level
+ level_min, level_max = CompressionParameter.compression_level.bounds()
+ with self.assertRaises(ValueError):
+ compress(b'', level_max+1)
+ with self.assertRaises(ValueError):
+ compress(b'', level_min-1)
+ with self.assertRaises(ValueError):
+ compress(b'', 2**1000)
+ with self.assertRaises(ValueError):
+ compress(b'', -(2**1000))
+ with self.assertRaises(ValueError):
+ compress(b'', options={
+ CompressionParameter.compression_level: level_max+1})
+ with self.assertRaises(ValueError):
+ compress(b'', options={
+ CompressionParameter.compression_level: level_min-1})
+
+        # when the zstd build doesn't support multithreaded compression,
+        # the multithreading parameters must be rejected
+ if not SUPPORT_MULTITHREADING:
+ with self.assertRaises(ValueError):
+ ZstdCompressor(options={CompressionParameter.nb_workers:4})
+ with self.assertRaises(ValueError):
+ ZstdCompressor(options={CompressionParameter.job_size:4})
+ with self.assertRaises(ValueError):
+ ZstdCompressor(options={CompressionParameter.overlap_log:4})
+
+ # out of bounds error msg
+ option = {CompressionParameter.window_log:100}
+ with self.assertRaisesRegex(
+ ValueError,
+ "compression parameter 'window_log' received an illegal value 100; "
+ r'the valid range is \[-?\d+, -?\d+\]',
+ ):
+ compress(b'', options=option)
+
+ def test_unknown_compression_parameter(self):
+ KEY = 100001234
+ option = {CompressionParameter.compression_level: 10,
+ KEY: 200000000}
+ pattern = rf"invalid compression parameter 'unknown parameter \(key {KEY}\)'"
+ with self.assertRaisesRegex(ValueError, pattern):
+ ZstdCompressor(options=option)
+
+ @unittest.skipIf(not SUPPORT_MULTITHREADING,
+ "zstd build doesn't support multi-threaded compression")
+ def test_zstd_multithread_compress(self):
+ size = 40*_1M
+ b = THIS_FILE_BYTES * (size // len(THIS_FILE_BYTES))
+
+ options = {CompressionParameter.compression_level : 4,
+ CompressionParameter.nb_workers : 2}
+
+ # compress()
+ dat1 = compress(b, options=options)
+ dat2 = decompress(dat1)
+ self.assertEqual(dat2, b)
+
+ # ZstdCompressor
+ c = ZstdCompressor(options=options)
+ dat1 = c.compress(b, c.CONTINUE)
+ dat2 = c.compress(b, c.FLUSH_BLOCK)
+ dat3 = c.compress(b, c.FLUSH_FRAME)
+ dat4 = decompress(dat1+dat2+dat3)
+ self.assertEqual(dat4, b * 3)
+
+ # ZstdFile
+ with ZstdFile(io.BytesIO(), 'w', options=options) as f:
+ f.write(b)
+
+ def test_compress_flushblock(self):
+ point = len(THIS_FILE_BYTES) // 2
+
+ c = ZstdCompressor()
+ self.assertEqual(c.last_mode, c.FLUSH_FRAME)
+ dat1 = c.compress(THIS_FILE_BYTES[:point])
+ self.assertEqual(c.last_mode, c.CONTINUE)
+ dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_BLOCK)
+ self.assertEqual(c.last_mode, c.FLUSH_BLOCK)
+ dat2 = c.flush()
+ pattern = "Compressed data ended before the end-of-stream marker"
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(dat1)
+
+ dat3 = decompress(dat1 + dat2)
+
+ self.assertEqual(dat3, THIS_FILE_BYTES)
+
+ def test_compress_flushframe(self):
+ # test compress & decompress
+ point = len(THIS_FILE_BYTES) // 2
+
+ c = ZstdCompressor()
+
+ dat1 = c.compress(THIS_FILE_BYTES[:point])
+ self.assertEqual(c.last_mode, c.CONTINUE)
+
+ dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_FRAME)
+ self.assertEqual(c.last_mode, c.FLUSH_FRAME)
+
+ nt = get_frame_info(dat1)
+ self.assertEqual(nt.decompressed_size, None) # no content size
+
+ dat2 = decompress(dat1)
+
+ self.assertEqual(dat2, THIS_FILE_BYTES)
+
+ # single .FLUSH_FRAME mode has content size
+ c = ZstdCompressor()
+ dat = c.compress(THIS_FILE_BYTES, mode=c.FLUSH_FRAME)
+ self.assertEqual(c.last_mode, c.FLUSH_FRAME)
+
+ nt = get_frame_info(dat)
+ self.assertEqual(nt.decompressed_size, len(THIS_FILE_BYTES))
+
+ def test_compress_empty(self):
+        # compressing b'' still produces a complete (non-empty) frame
+ self.assertNotEqual(compress(b''), b'')
+
+ c = ZstdCompressor()
+ self.assertNotEqual(c.compress(b'', c.FLUSH_FRAME), b'')
+
+ def test_set_pledged_input_size(self):
+ DAT = DECOMPRESSED_100_PLUS_32KB
+ CHUNK_SIZE = len(DAT) // 3
+
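+        # set_pledged_input_size() declares the total input size up front so
+        # the frame header can record the content size; a minimal sketch of
+        # the intended usage (assuming it is called at a frame boundary,
+        # before any data is fed to the current frame):
+        #
+        #     c = ZstdCompressor()
+        #     c.set_pledged_input_size(len(data))
+        #     frame = c.compress(data) + c.flush()
+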
+ # wrong value
+ c = ZstdCompressor()
+ with self.assertRaisesRegex(ValueError,
+ r'should be a positive int less than \d+'):
+ c.set_pledged_input_size(-300)
+ # overflow
+ with self.assertRaisesRegex(ValueError,
+ r'should be a positive int less than \d+'):
+ c.set_pledged_input_size(2**64)
+ # ZSTD_CONTENTSIZE_ERROR is invalid
+ with self.assertRaisesRegex(ValueError,
+ r'should be a positive int less than \d+'):
+ c.set_pledged_input_size(2**64-2)
+ # ZSTD_CONTENTSIZE_UNKNOWN should use None
+ with self.assertRaisesRegex(ValueError,
+ r'should be a positive int less than \d+'):
+ c.set_pledged_input_size(2**64-1)
+
+ # check valid values are settable
+ c.set_pledged_input_size(2**63)
+ c.set_pledged_input_size(2**64-3)
+
+ # check that zero means empty frame
+ c = ZstdCompressor(level=1)
+ c.set_pledged_input_size(0)
+ c.compress(b'')
+ dat = c.flush()
+ ret = get_frame_info(dat)
+ self.assertEqual(ret.decompressed_size, 0)
+
+        # wrong mode
+ c = ZstdCompressor(level=1)
+ c.compress(b'123456')
+ self.assertEqual(c.last_mode, c.CONTINUE)
+ with self.assertRaisesRegex(ValueError,
+ r'last_mode == FLUSH_FRAME'):
+ c.set_pledged_input_size(300)
+
+ # None value
+ c = ZstdCompressor(level=1)
+ c.set_pledged_input_size(None)
+ dat = c.compress(DAT) + c.flush()
+
+ ret = get_frame_info(dat)
+ self.assertEqual(ret.decompressed_size, None)
+
+ # correct value
+ c = ZstdCompressor(level=1)
+ c.set_pledged_input_size(len(DAT))
+
+ chunks = []
+ posi = 0
+ while posi < len(DAT):
+ dat = c.compress(DAT[posi:posi+CHUNK_SIZE])
+ posi += CHUNK_SIZE
+ chunks.append(dat)
+
+ dat = c.flush()
+ chunks.append(dat)
+ chunks = b''.join(chunks)
+
+ ret = get_frame_info(chunks)
+ self.assertEqual(ret.decompressed_size, len(DAT))
+ self.assertEqual(decompress(chunks), DAT)
+
+ c.set_pledged_input_size(len(DAT)) # the second frame
+ dat = c.compress(DAT) + c.flush()
+
+ ret = get_frame_info(dat)
+ self.assertEqual(ret.decompressed_size, len(DAT))
+ self.assertEqual(decompress(dat), DAT)
+
+ # not enough data
+ c = ZstdCompressor(level=1)
+ c.set_pledged_input_size(len(DAT)+1)
+
+ for start in range(0, len(DAT), CHUNK_SIZE):
+ end = min(start+CHUNK_SIZE, len(DAT))
+ _dat = c.compress(DAT[start:end])
+
+ with self.assertRaises(ZstdError):
+ c.flush()
+
+ # too much data
+ c = ZstdCompressor(level=1)
+ c.set_pledged_input_size(len(DAT))
+
+ for start in range(0, len(DAT), CHUNK_SIZE):
+ end = min(start+CHUNK_SIZE, len(DAT))
+ _dat = c.compress(DAT[start:end])
+
+ with self.assertRaises(ZstdError):
+ c.compress(b'extra', ZstdCompressor.FLUSH_FRAME)
+
+ # content size not set if content_size_flag == 0
+ c = ZstdCompressor(options={CompressionParameter.content_size_flag: 0})
+ c.set_pledged_input_size(10)
+ dat1 = c.compress(b"hello")
+ dat2 = c.compress(b"world")
+ dat3 = c.flush()
+ frame_data = get_frame_info(dat1 + dat2 + dat3)
+ self.assertIsNone(frame_data.decompressed_size)
+
+
+class DecompressorTestCase(unittest.TestCase):
+
+ def test_simple_decompress_bad_args(self):
+ # ZstdDecompressor
+ self.assertRaises(TypeError, ZstdDecompressor, ())
+ self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=123)
+ self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=b'abc')
+ self.assertRaises(TypeError, ZstdDecompressor, zstd_dict={1:2, 3:4})
+
+ self.assertRaises(TypeError, ZstdDecompressor, options=123)
+ self.assertRaises(TypeError, ZstdDecompressor, options='abc')
+ self.assertRaises(TypeError, ZstdDecompressor, options=b'abc')
+
+ with self.assertRaises(ValueError):
+ ZstdDecompressor(options={C_INT_MAX: 100})
+ with self.assertRaises(ValueError):
+ ZstdDecompressor(options={C_INT_MIN: 100})
+ with self.assertRaises(ValueError):
+ ZstdDecompressor(options={0: C_INT_MAX})
+ with self.assertRaises(OverflowError):
+ ZstdDecompressor(options={2**1000: 100})
+ with self.assertRaises(OverflowError):
+ ZstdDecompressor(options={-(2**1000): 100})
+ with self.assertRaises(OverflowError):
+ ZstdDecompressor(options={0: -(2**1000)})
+
+ with self.assertRaises(ValueError):
+ ZstdDecompressor(options={DecompressionParameter.window_log_max: 100})
+ with self.assertRaises(ValueError):
+ ZstdDecompressor(options={3333: 100})
+
+ empty = compress(b'')
+ lzd = ZstdDecompressor()
+ self.assertRaises(TypeError, lzd.decompress)
+ self.assertRaises(TypeError, lzd.decompress, b"foo", b"bar")
+ self.assertRaises(TypeError, lzd.decompress, "str")
+ lzd.decompress(empty)
+
+ def test_decompress_parameters(self):
+ d = {DecompressionParameter.window_log_max : 15}
+ ZstdDecompressor(options=d)
+
+ d1 = d.copy()
+ # larger than signed int
+ d1[DecompressionParameter.window_log_max] = 2**1000
+ with self.assertRaises(OverflowError):
+ ZstdDecompressor(None, d1)
+ # smaller than signed int
+ d1[DecompressionParameter.window_log_max] = -(2**1000)
+ with self.assertRaises(OverflowError):
+ ZstdDecompressor(None, d1)
+
+ d1[DecompressionParameter.window_log_max] = C_INT_MAX
+ with self.assertRaises(ValueError):
+ ZstdDecompressor(None, d1)
+ d1[DecompressionParameter.window_log_max] = C_INT_MIN
+ with self.assertRaises(ValueError):
+ ZstdDecompressor(None, d1)
+
+ # out of bounds error msg
+ options = {DecompressionParameter.window_log_max:100}
+ with self.assertRaisesRegex(
+ ValueError,
+ "decompression parameter 'window_log_max' received an illegal value 100; "
+ r'the valid range is \[-?\d+, -?\d+\]',
+ ):
+ decompress(b'', options=options)
+
+        # out of bounds decompression parameter
+ options[DecompressionParameter.window_log_max] = C_INT_MAX
+ with self.assertRaises(ValueError):
+ decompress(b'', options=options)
+ options[DecompressionParameter.window_log_max] = C_INT_MIN
+ with self.assertRaises(ValueError):
+ decompress(b'', options=options)
+ options[DecompressionParameter.window_log_max] = 2**1000
+ with self.assertRaises(OverflowError):
+ decompress(b'', options=options)
+ options[DecompressionParameter.window_log_max] = -(2**1000)
+ with self.assertRaises(OverflowError):
+ decompress(b'', options=options)
+
+ def test_unknown_decompression_parameter(self):
+ KEY = 100001234
+ options = {DecompressionParameter.window_log_max: DecompressionParameter.window_log_max.bounds()[1],
+ KEY: 200000000}
+ pattern = rf"invalid decompression parameter 'unknown parameter \(key {KEY}\)'"
+ with self.assertRaisesRegex(ValueError, pattern):
+ ZstdDecompressor(options=options)
+
+ def test_decompress_epilogue_flags(self):
+ # DAT_130K_C has a 4 bytes checksum at frame epilogue
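+        # The cases below exercise .needs_input, which appears to mirror the
+        # bz2/lzma decompressors: it is False once the frame has ended or
+        # while the output is capped by max_length, and True when more input
+        # is needed to make progress.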
+
+ # full unlimited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C)
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.needs_input)
+
+ with self.assertRaises(EOFError):
+ dat = d.decompress(b'')
+
+ # full limited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C, _130_1K)
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.needs_input)
+
+ with self.assertRaises(EOFError):
+ dat = d.decompress(b'', 0)
+
+ # [:-4] unlimited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-4])
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.needs_input)
+
+ dat = d.decompress(b'')
+ self.assertEqual(len(dat), 0)
+ self.assertTrue(d.needs_input)
+
+ # [:-4] limited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-4], _130_1K)
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.needs_input)
+
+ dat = d.decompress(b'', 0)
+ self.assertEqual(len(dat), 0)
+ self.assertFalse(d.needs_input)
+
+ # [:-3] unlimited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-3])
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.needs_input)
+
+ dat = d.decompress(b'')
+ self.assertEqual(len(dat), 0)
+ self.assertTrue(d.needs_input)
+
+ # [:-3] limited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-3], _130_1K)
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.needs_input)
+
+ dat = d.decompress(b'', 0)
+ self.assertEqual(len(dat), 0)
+ self.assertFalse(d.needs_input)
+
+ # [:-1] unlimited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-1])
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.needs_input)
+
+ dat = d.decompress(b'')
+ self.assertEqual(len(dat), 0)
+ self.assertTrue(d.needs_input)
+
+ # [:-1] limited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-1], _130_1K)
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.needs_input)
+
+ dat = d.decompress(b'', 0)
+ self.assertEqual(len(dat), 0)
+ self.assertFalse(d.needs_input)
+
+ def test_decompressor_arg(self):
+ zd = ZstdDict(b'12345678', is_raw=True)
+
+ with self.assertRaises(TypeError):
+ d = ZstdDecompressor(zstd_dict={})
+
+ with self.assertRaises(TypeError):
+ d = ZstdDecompressor(options=zd)
+
+ ZstdDecompressor()
+ ZstdDecompressor(zd, {})
+ ZstdDecompressor(zstd_dict=zd, options={DecompressionParameter.window_log_max:25})
+
+ def test_decompressor_1(self):
+ # empty
+ d = ZstdDecompressor()
+ dat = d.decompress(b'')
+
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+
+ # 130_1K full
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C)
+
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+
+ # 130_1K full, limit output
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C, _130_1K)
+
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+
+ # 130_1K, without 4 bytes checksum
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-4])
+
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+
+ # above, limit output
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-4], _130_1K)
+
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.eof)
+ self.assertFalse(d.needs_input)
+
+ # full, unused_data
+ TRAIL = b'89234893abcd'
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C + TRAIL, _130_1K)
+
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, TRAIL)
+
+ def test_decompressor_chunks_read_300(self):
+ TRAIL = b'89234893abcd'
+ DAT = DAT_130K_C + TRAIL
+ d = ZstdDecompressor()
+
+ bi = io.BytesIO(DAT)
+ lst = []
+ while True:
+ if d.needs_input:
+ dat = bi.read(300)
+ if not dat:
+ break
+ else:
+ raise Exception('should not get here')
+
+ ret = d.decompress(dat)
+ lst.append(ret)
+ if d.eof:
+ break
+
+ ret = b''.join(lst)
+
+ self.assertEqual(len(ret), _130_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data + bi.read(), TRAIL)
+
+ def test_decompressor_chunks_read_3(self):
+ TRAIL = b'89234893'
+ DAT = DAT_130K_C + TRAIL
+ d = ZstdDecompressor()
+
+ bi = io.BytesIO(DAT)
+ lst = []
+ while True:
+ if d.needs_input:
+ dat = bi.read(3)
+ if not dat:
+ break
+ else:
+ dat = b''
+
+ ret = d.decompress(dat, 1)
+ lst.append(ret)
+ if d.eof:
+ break
+
+ ret = b''.join(lst)
+
+ self.assertEqual(len(ret), _130_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data + bi.read(), TRAIL)
+
+ def test_decompress_empty(self):
+ with self.assertRaises(ZstdError):
+ decompress(b'')
+
+ d = ZstdDecompressor()
+ self.assertEqual(d.decompress(b''), b'')
+ self.assertFalse(d.eof)
+
+ def test_decompress_empty_content_frame(self):
+ DAT = compress(b'')
+ # decompress
+ self.assertGreaterEqual(len(DAT), 4)
+ self.assertEqual(decompress(DAT), b'')
+
+ with self.assertRaises(ZstdError):
+ decompress(DAT[:-1])
+
+ # ZstdDecompressor
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT)
+ self.assertEqual(dat, b'')
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT[:-1])
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+
+class DecompressorFlagsTestCase(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ options = {CompressionParameter.checksum_flag:1}
+ c = ZstdCompressor(options=options)
+
+ cls.DECOMPRESSED_42 = b'a'*42
+ cls.FRAME_42 = c.compress(cls.DECOMPRESSED_42, c.FLUSH_FRAME)
+
+ cls.DECOMPRESSED_60 = b'a'*60
+ cls.FRAME_60 = c.compress(cls.DECOMPRESSED_60, c.FLUSH_FRAME)
+
+ cls.FRAME_42_60 = cls.FRAME_42 + cls.FRAME_60
+ cls.DECOMPRESSED_42_60 = cls.DECOMPRESSED_42 + cls.DECOMPRESSED_60
+
+ cls._130_1K = 130*_1K
+
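+        # Frames produced by streaming compress() + flush() do not record the
+        # decompressed size in the frame header, hence "UNKNOWN" below.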
+ c = ZstdCompressor()
+ cls.UNKNOWN_FRAME_42 = c.compress(cls.DECOMPRESSED_42) + c.flush()
+ cls.UNKNOWN_FRAME_60 = c.compress(cls.DECOMPRESSED_60) + c.flush()
+ cls.UNKNOWN_FRAME_42_60 = cls.UNKNOWN_FRAME_42 + cls.UNKNOWN_FRAME_60
+
+ cls.TRAIL = b'12345678abcdefg!@#$%^&*()_+|'
+
+    def test_function_decompress(self):
+        self.assertEqual(len(decompress(COMPRESSED_100_PLUS_32KB)), 100+32*_1K)
+
+ # 1 frame
+ self.assertEqual(decompress(self.FRAME_42), self.DECOMPRESSED_42)
+
+ self.assertEqual(decompress(self.UNKNOWN_FRAME_42), self.DECOMPRESSED_42)
+
+ pattern = r"Compressed data ended before the end-of-stream marker"
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(self.FRAME_42[:1])
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(self.FRAME_42[:-4])
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(self.FRAME_42[:-1])
+
+ # 2 frames
+ self.assertEqual(decompress(self.FRAME_42_60), self.DECOMPRESSED_42_60)
+
+ self.assertEqual(decompress(self.UNKNOWN_FRAME_42_60), self.DECOMPRESSED_42_60)
+
+ self.assertEqual(decompress(self.FRAME_42 + self.UNKNOWN_FRAME_60),
+ self.DECOMPRESSED_42_60)
+
+ self.assertEqual(decompress(self.UNKNOWN_FRAME_42 + self.FRAME_60),
+ self.DECOMPRESSED_42_60)
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(self.FRAME_42_60[:-4])
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(self.UNKNOWN_FRAME_42_60[:-1])
+
+ # 130_1K
+ self.assertEqual(decompress(DAT_130K_C), DAT_130K_D)
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(DAT_130K_C[:-4])
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(DAT_130K_C[:-1])
+
+ # Unknown frame descriptor
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ decompress(b'aaaaaaaaa')
+
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ decompress(self.FRAME_42 + b'aaaaaaaaa')
+
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ decompress(self.UNKNOWN_FRAME_42_60 + b'aaaaaaaaa')
+
+ # doesn't match checksum
+ checksum = DAT_130K_C[-4:]
+ if checksum[0] == 255:
+ wrong_checksum = bytes([254]) + checksum[1:]
+ else:
+ wrong_checksum = bytes([checksum[0]+1]) + checksum[1:]
+
+ dat = DAT_130K_C[:-4] + wrong_checksum
+
+ with self.assertRaisesRegex(ZstdError, "doesn't match checksum"):
+ decompress(dat)
+
+ def test_function_skippable(self):
+ self.assertEqual(decompress(SKIPPABLE_FRAME), b'')
+ self.assertEqual(decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME), b'')
+
+ # 1 frame + 2 skippable
+ self.assertEqual(len(decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + DAT_130K_C)),
+ self._130_1K)
+
+ self.assertEqual(len(decompress(DAT_130K_C + SKIPPABLE_FRAME + SKIPPABLE_FRAME)),
+ self._130_1K)
+
+ self.assertEqual(len(decompress(SKIPPABLE_FRAME + DAT_130K_C + SKIPPABLE_FRAME)),
+ self._130_1K)
+
+ # unknown size
+ self.assertEqual(decompress(SKIPPABLE_FRAME + self.UNKNOWN_FRAME_60),
+ self.DECOMPRESSED_60)
+
+ self.assertEqual(decompress(self.UNKNOWN_FRAME_60 + SKIPPABLE_FRAME),
+ self.DECOMPRESSED_60)
+
+ # 2 frames + 1 skippable
+ self.assertEqual(decompress(self.FRAME_42 + SKIPPABLE_FRAME + self.FRAME_60),
+ self.DECOMPRESSED_42_60)
+
+ self.assertEqual(decompress(SKIPPABLE_FRAME + self.FRAME_42_60),
+ self.DECOMPRESSED_42_60)
+
+ self.assertEqual(decompress(self.UNKNOWN_FRAME_42_60 + SKIPPABLE_FRAME),
+ self.DECOMPRESSED_42_60)
+
+ # incomplete
+ with self.assertRaises(ZstdError):
+ decompress(SKIPPABLE_FRAME[:1])
+
+ with self.assertRaises(ZstdError):
+ decompress(SKIPPABLE_FRAME[:-1])
+
+ with self.assertRaises(ZstdError):
+ decompress(self.FRAME_42 + SKIPPABLE_FRAME[:-1])
+
+ # Unknown frame descriptor
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ decompress(b'aaaaaaaaa' + SKIPPABLE_FRAME)
+
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ decompress(SKIPPABLE_FRAME + b'aaaaaaaaa')
+
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + b'aaaaaaaaa')
+
+ def test_decompressor_1(self):
+ # empty 1
+ d = ZstdDecompressor()
+
+ dat = d.decompress(b'')
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ dat = d.decompress(b'', 0)
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ dat = d.decompress(COMPRESSED_100_PLUS_32KB + b'a')
+ self.assertEqual(dat, DECOMPRESSED_100_PLUS_32KB)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'a')
+ self.assertEqual(d.unused_data, b'a') # twice
+
+ # empty 2
+ d = ZstdDecompressor()
+
+ dat = d.decompress(b'', 0)
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ dat = d.decompress(b'')
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ dat = d.decompress(COMPRESSED_100_PLUS_32KB + b'a')
+ self.assertEqual(dat, DECOMPRESSED_100_PLUS_32KB)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'a')
+ self.assertEqual(d.unused_data, b'a') # twice
+
+ # 1 frame
+ d = ZstdDecompressor()
+ dat = d.decompress(self.FRAME_42)
+
+ self.assertEqual(dat, self.DECOMPRESSED_42)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ with self.assertRaises(EOFError):
+ d.decompress(b'')
+
+ # 1 frame, trail
+ d = ZstdDecompressor()
+ dat = d.decompress(self.FRAME_42 + self.TRAIL)
+
+ self.assertEqual(dat, self.DECOMPRESSED_42)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, self.TRAIL)
+ self.assertEqual(d.unused_data, self.TRAIL) # twice
+
+ # 1 frame, 32_1K
+ temp = compress(b'a'*(32*_1K))
+ d = ZstdDecompressor()
+ dat = d.decompress(temp, 32*_1K)
+
+ self.assertEqual(dat, b'a'*(32*_1K))
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ with self.assertRaises(EOFError):
+ d.decompress(b'')
+
+ # 1 frame, 32_1K+100, trail
+ d = ZstdDecompressor()
+ dat = d.decompress(COMPRESSED_100_PLUS_32KB+self.TRAIL, 100) # 100 bytes
+
+ self.assertEqual(len(dat), 100)
+ self.assertFalse(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+
+ dat = d.decompress(b'') # 32_1K
+
+ self.assertEqual(len(dat), 32*_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, self.TRAIL)
+ self.assertEqual(d.unused_data, self.TRAIL) # twice
+
+ with self.assertRaises(EOFError):
+ d.decompress(b'')
+
+ # incomplete 1
+ d = ZstdDecompressor()
+ dat = d.decompress(self.FRAME_60[:1])
+
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ # incomplete 2
+ d = ZstdDecompressor()
+
+ dat = d.decompress(self.FRAME_60[:-4])
+ self.assertEqual(dat, self.DECOMPRESSED_60)
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ # incomplete 3
+ d = ZstdDecompressor()
+
+ dat = d.decompress(self.FRAME_60[:-1])
+ self.assertEqual(dat, self.DECOMPRESSED_60)
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+
+ # incomplete 4
+ d = ZstdDecompressor()
+
+ dat = d.decompress(self.FRAME_60[:-4], 60)
+ self.assertEqual(dat, self.DECOMPRESSED_60)
+ self.assertFalse(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ dat = d.decompress(b'')
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ # Unknown frame descriptor
+ d = ZstdDecompressor()
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ d.decompress(b'aaaaaaaaa')
+
+ def test_decompressor_skippable(self):
+ # 1 skippable
+ d = ZstdDecompressor()
+ dat = d.decompress(SKIPPABLE_FRAME)
+
+ self.assertEqual(dat, b'')
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ # 1 skippable, max_length=0
+ d = ZstdDecompressor()
+ dat = d.decompress(SKIPPABLE_FRAME, 0)
+
+ self.assertEqual(dat, b'')
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ # 1 skippable, trail
+ d = ZstdDecompressor()
+ dat = d.decompress(SKIPPABLE_FRAME + self.TRAIL)
+
+ self.assertEqual(dat, b'')
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, self.TRAIL)
+ self.assertEqual(d.unused_data, self.TRAIL) # twice
+
+ # incomplete
+ d = ZstdDecompressor()
+ dat = d.decompress(SKIPPABLE_FRAME[:-1])
+
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ # incomplete
+ d = ZstdDecompressor()
+ dat = d.decompress(SKIPPABLE_FRAME[:-1], 0)
+
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ dat = d.decompress(b'')
+
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+
+class ZstdDictTestCase(unittest.TestCase):
+
+ def test_is_raw(self):
+ # must be passed as a keyword argument
+ with self.assertRaises(TypeError):
+ ZstdDict(bytes(8), True)
+
+ # content < 8
+ b = b'1234567'
+ with self.assertRaises(ValueError):
+ ZstdDict(b)
+
+ # content == 8
+ b = b'12345678'
+ zd = ZstdDict(b, is_raw=True)
+ self.assertEqual(zd.dict_id, 0)
+
+ temp = compress(b'aaa12345678', level=3, zstd_dict=zd)
+ self.assertEqual(b'aaa12345678', decompress(temp, zd))
+
+ # is_raw == False
+ b = b'12345678abcd'
+ with self.assertRaises(ValueError):
+ ZstdDict(b)
+
+ # read only attributes
+ with self.assertRaises(AttributeError):
+ zd.dict_content = b
+
+ with self.assertRaises(AttributeError):
+ zd.dict_id = 10000
+
+ # ZstdDict arguments
+ zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=False)
+ self.assertNotEqual(zd.dict_id, 0)
+
+ zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=True)
+        self.assertNotEqual(zd.dict_id, 0)  # dict_id is still read from the content, even with is_raw=True
+
+ with self.assertRaises(TypeError):
+ ZstdDict("12345678abcdef", is_raw=True)
+ with self.assertRaises(TypeError):
+ ZstdDict(TRAINED_DICT)
+
+ # invalid parameter
+ with self.assertRaises(TypeError):
+ ZstdDict(desk333=345)
+
+ def test_invalid_dict(self):
+ DICT_MAGIC = 0xEC30A437.to_bytes(4, byteorder='little')
+ dict_content = DICT_MAGIC + b'abcdefghighlmnopqrstuvwxyz'
+
+ # corrupted
+ zd = ZstdDict(dict_content, is_raw=False)
+ with self.assertRaisesRegex(ZstdError, r'ZSTD_CDict.*?content\.$'):
+ ZstdCompressor(zstd_dict=zd.as_digested_dict)
+ with self.assertRaisesRegex(ZstdError, r'ZSTD_DDict.*?content\.$'):
+ ZstdDecompressor(zd)
+
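+        # zstd_dict is assumed to also accept, internally, the (ZstdDict, int)
+        # pairs produced by .as_digested_dict/.as_undigested_dict/.as_prefix;
+        # any other tuple shape or mode value should be rejected as a wrong type.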
+ # wrong type
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdCompressor(zstd_dict=[zd, 1])
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdCompressor(zstd_dict=(zd, 1.0))
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdCompressor(zstd_dict=(zd,))
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdCompressor(zstd_dict=(zd, 1, 2))
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdCompressor(zstd_dict=(zd, -1))
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdCompressor(zstd_dict=(zd, 3))
+ with self.assertRaises(OverflowError):
+ ZstdCompressor(zstd_dict=(zd, 2**1000))
+ with self.assertRaises(OverflowError):
+ ZstdCompressor(zstd_dict=(zd, -2**1000))
+
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdDecompressor(zstd_dict=[zd, 1])
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdDecompressor(zstd_dict=(zd, 1.0))
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdDecompressor((zd,))
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdDecompressor((zd, 1, 2))
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdDecompressor((zd, -1))
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdDecompressor((zd, 3))
+ with self.assertRaises(OverflowError):
+ ZstdDecompressor((zd, 2**1000))
+ with self.assertRaises(OverflowError):
+ ZstdDecompressor((zd, -2**1000))
+
+ def test_train_dict(self):
+ TRAINED_DICT = train_dict(SAMPLES, DICT_SIZE1)
+ ZstdDict(TRAINED_DICT.dict_content, is_raw=False)
+
+ self.assertNotEqual(TRAINED_DICT.dict_id, 0)
+ self.assertGreater(len(TRAINED_DICT.dict_content), 0)
+ self.assertLessEqual(len(TRAINED_DICT.dict_content), DICT_SIZE1)
+ self.assertTrue(re.match(r'^<ZstdDict dict_id=\d+ dict_size=\d+>$', str(TRAINED_DICT)))
+
+ # compress/decompress
+ c = ZstdCompressor(zstd_dict=TRAINED_DICT)
+ for sample in SAMPLES:
+ dat1 = compress(sample, zstd_dict=TRAINED_DICT)
+ dat2 = decompress(dat1, TRAINED_DICT)
+ self.assertEqual(sample, dat2)
+
+ dat1 = c.compress(sample)
+ dat1 += c.flush()
+ dat2 = decompress(dat1, TRAINED_DICT)
+ self.assertEqual(sample, dat2)
+
+ def test_finalize_dict(self):
+ DICT_SIZE2 = 200*_1K
+ C_LEVEL = 6
+
+ try:
+ dic2 = finalize_dict(TRAINED_DICT, SAMPLES, DICT_SIZE2, C_LEVEL)
+ except NotImplementedError:
+            # < v1.4.5 at compile-time, >= v1.4.5 at run-time
+ return
+
+ self.assertNotEqual(dic2.dict_id, 0)
+ self.assertGreater(len(dic2.dict_content), 0)
+ self.assertLessEqual(len(dic2.dict_content), DICT_SIZE2)
+
+ # compress/decompress
+ c = ZstdCompressor(C_LEVEL, zstd_dict=dic2)
+ for sample in SAMPLES:
+ dat1 = compress(sample, C_LEVEL, zstd_dict=dic2)
+ dat2 = decompress(dat1, dic2)
+ self.assertEqual(sample, dat2)
+
+ dat1 = c.compress(sample)
+ dat1 += c.flush()
+ dat2 = decompress(dat1, dic2)
+ self.assertEqual(sample, dat2)
+
+ # dict mismatch
+ self.assertNotEqual(TRAINED_DICT.dict_id, dic2.dict_id)
+
+ dat1 = compress(SAMPLES[0], zstd_dict=TRAINED_DICT)
+ with self.assertRaises(ZstdError):
+ decompress(dat1, dic2)
+
+ def test_train_dict_arguments(self):
+ with self.assertRaises(ValueError):
+ train_dict([], 100*_1K)
+
+ with self.assertRaises(ValueError):
+ train_dict(SAMPLES, -100)
+
+ with self.assertRaises(ValueError):
+ train_dict(SAMPLES, 0)
+
+ def test_finalize_dict_arguments(self):
+ with self.assertRaises(TypeError):
+ finalize_dict({1:2}, (b'aaa', b'bbb'), 100*_1K, 2)
+
+ with self.assertRaises(ValueError):
+ finalize_dict(TRAINED_DICT, [], 100*_1K, 2)
+
+ with self.assertRaises(ValueError):
+ finalize_dict(TRAINED_DICT, SAMPLES, -100, 2)
+
+ with self.assertRaises(ValueError):
+ finalize_dict(TRAINED_DICT, SAMPLES, 0, 2)
+
+ def test_train_dict_c(self):
+ # argument wrong type
+ with self.assertRaises(TypeError):
+ _zstd.train_dict({}, (), 100)
+ with self.assertRaises(TypeError):
+ _zstd.train_dict(bytearray(), (), 100)
+ with self.assertRaises(TypeError):
+ _zstd.train_dict(b'', 99, 100)
+ with self.assertRaises(TypeError):
+ _zstd.train_dict(b'', [], 100)
+ with self.assertRaises(TypeError):
+ _zstd.train_dict(b'', (), 100.1)
+ with self.assertRaises(TypeError):
+ _zstd.train_dict(b'', (99.1,), 100)
+ with self.assertRaises(ValueError):
+ _zstd.train_dict(b'abc', (4, -1), 100)
+ with self.assertRaises(ValueError):
+ _zstd.train_dict(b'abc', (2,), 100)
+ with self.assertRaises(ValueError):
+ _zstd.train_dict(b'', (99,), 100)
+
+ # size > size_t
+ with self.assertRaises(ValueError):
+ _zstd.train_dict(b'', (2**1000,), 100)
+ with self.assertRaises(ValueError):
+ _zstd.train_dict(b'', (-2**1000,), 100)
+
+ # dict_size <= 0
+ with self.assertRaises(ValueError):
+ _zstd.train_dict(b'', (), 0)
+ with self.assertRaises(ValueError):
+ _zstd.train_dict(b'', (), -1)
+
+ with self.assertRaises(ZstdError):
+ _zstd.train_dict(b'', (), 1)
+
+ def test_finalize_dict_c(self):
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict(1, 2, 3, 4, 5)
+
+ # argument wrong type
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict({}, b'', (), 100, 5)
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict(bytearray(TRAINED_DICT.dict_content), b'', (), 100, 5)
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, {}, (), 100, 5)
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, bytearray(), (), 100, 5)
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', 99, 100, 5)
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', [], 100, 5)
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100.1, 5)
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5.1)
+
+ with self.assertRaises(ValueError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (4, -1), 100, 5)
+ with self.assertRaises(ValueError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (2,), 100, 5)
+ with self.assertRaises(ValueError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (99,), 100, 5)
+
+ # size > size_t
+ with self.assertRaises(ValueError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**1000,), 100, 5)
+ with self.assertRaises(ValueError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (-2**1000,), 100, 5)
+
+ # dict_size <= 0
+ with self.assertRaises(ValueError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 0, 5)
+ with self.assertRaises(ValueError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -1, 5)
+ with self.assertRaises(OverflowError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 2**1000, 5)
+ with self.assertRaises(OverflowError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -2**1000, 5)
+
+ with self.assertRaises(OverflowError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 2**1000)
+ with self.assertRaises(OverflowError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, -2**1000)
+
+ with self.assertRaises(ZstdError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5)
+
+ def test_train_buffer_protocol_samples(self):
+ def _nbytes(dat):
+ if isinstance(dat, (bytes, bytearray)):
+ return len(dat)
+ return memoryview(dat).nbytes
+
+ # prepare samples
+ chunk_lst = []
+ wrong_size_lst = []
+ correct_size_lst = []
+ for _ in range(300):
+ arr = array.array('Q', [random.randint(0, 20) for i in range(20)])
+ chunk_lst.append(arr)
+ correct_size_lst.append(_nbytes(arr))
+ wrong_size_lst.append(len(arr))
+ concatenation = b''.join(chunk_lst)
+
+ # wrong size list
+ with self.assertRaisesRegex(ValueError,
+ "The samples size tuple doesn't match the concatenation's size"):
+ _zstd.train_dict(concatenation, tuple(wrong_size_lst), 100*_1K)
+
+ # correct size list
+ _zstd.train_dict(concatenation, tuple(correct_size_lst), 3*_1K)
+
+ # wrong size list
+ with self.assertRaisesRegex(ValueError,
+ "The samples size tuple doesn't match the concatenation's size"):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content,
+ concatenation, tuple(wrong_size_lst), 300*_1K, 5)
+
+ # correct size list
+ _zstd.finalize_dict(TRAINED_DICT.dict_content,
+ concatenation, tuple(correct_size_lst), 300*_1K, 5)
+
+ def test_as_prefix(self):
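+        # .as_prefix uses the raw content as reference data (a "prefix") for
+        # the frame being coded; the same prefix must be supplied again to
+        # decompress, and no dictionary ID is written to the frame.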
+ # V1
+ V1 = THIS_FILE_BYTES
+ zd = ZstdDict(V1, is_raw=True)
+
+ # V2
+ mid = len(V1) // 2
+ V2 = V1[:mid] + \
+ (b'a' if V1[mid] != int.from_bytes(b'a') else b'b') + \
+ V1[mid+1:]
+
+ # compress
+ dat = compress(V2, zstd_dict=zd.as_prefix)
+ self.assertEqual(get_frame_info(dat).dictionary_id, 0)
+
+ # decompress
+ self.assertEqual(decompress(dat, zd.as_prefix), V2)
+
+ # use wrong prefix
+ zd2 = ZstdDict(SAMPLES[0], is_raw=True)
+ try:
+ decompressed = decompress(dat, zd2.as_prefix)
+ except ZstdError: # expected
+ pass
+ else:
+ self.assertNotEqual(decompressed, V2)
+
+ # read only attribute
+ with self.assertRaises(AttributeError):
+ zd.as_prefix = b'1234'
+
+ def test_as_digested_dict(self):
+ zd = TRAINED_DICT
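+        # A digested dictionary is preprocessed once and can be reused cheaply;
+        # an undigested dictionary is loaded from the raw content on each use.
+        # Both forms are expected to round-trip identically.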
+
+ # test .as_digested_dict
+ dat = compress(SAMPLES[0], zstd_dict=zd.as_digested_dict)
+ self.assertEqual(decompress(dat, zd.as_digested_dict), SAMPLES[0])
+ with self.assertRaises(AttributeError):
+ zd.as_digested_dict = b'1234'
+
+ # test .as_undigested_dict
+ dat = compress(SAMPLES[0], zstd_dict=zd.as_undigested_dict)
+ self.assertEqual(decompress(dat, zd.as_undigested_dict), SAMPLES[0])
+ with self.assertRaises(AttributeError):
+ zd.as_undigested_dict = b'1234'
+
+ def test_advanced_compression_parameters(self):
+ options = {CompressionParameter.compression_level: 6,
+ CompressionParameter.window_log: 20,
+ CompressionParameter.enable_long_distance_matching: 1}
+
+ # automatically select
+ dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT)
+ self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0])
+
+ # explicitly select
+ dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT.as_digested_dict)
+ self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0])
+
+ def test_len(self):
+ self.assertEqual(len(TRAINED_DICT), len(TRAINED_DICT.dict_content))
+ self.assertIn(str(len(TRAINED_DICT)), str(TRAINED_DICT))
+
+
+class FileTestCase(unittest.TestCase):
+ def setUp(self):
+ self.DECOMPRESSED_42 = b'a'*42
+ self.FRAME_42 = compress(self.DECOMPRESSED_42)
+
+ def test_init(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ pass
+ with ZstdFile(io.BytesIO(), "w") as f:
+ pass
+ with ZstdFile(io.BytesIO(), "x") as f:
+ pass
+ with ZstdFile(io.BytesIO(), "a") as f:
+ pass
+
+ with ZstdFile(io.BytesIO(), "w", level=12) as f:
+ pass
+ with ZstdFile(io.BytesIO(), "w", options={CompressionParameter.checksum_flag:1}) as f:
+ pass
+ with ZstdFile(io.BytesIO(), "w", options={}) as f:
+ pass
+ with ZstdFile(io.BytesIO(), "w", level=20, zstd_dict=TRAINED_DICT) as f:
+ pass
+
+ with ZstdFile(io.BytesIO(), "r", options={DecompressionParameter.window_log_max:25}) as f:
+ pass
+ with ZstdFile(io.BytesIO(), "r", options={}, zstd_dict=TRAINED_DICT) as f:
+ pass
+
+ def test_init_with_PathLike_filename(self):
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ filename = pathlib.Path(tmp_f.name)
+
+ with ZstdFile(filename, "a") as f:
+ f.write(DECOMPRESSED_100_PLUS_32KB)
+ with ZstdFile(filename) as f:
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+
+ with ZstdFile(filename, "a") as f:
+ f.write(DECOMPRESSED_100_PLUS_32KB)
+ with ZstdFile(filename) as f:
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB * 2)
+
+ os.remove(filename)
+
+ def test_init_with_filename(self):
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ filename = pathlib.Path(tmp_f.name)
+
+ with ZstdFile(filename) as f:
+ pass
+ with ZstdFile(filename, "w") as f:
+ pass
+ with ZstdFile(filename, "a") as f:
+ pass
+
+ os.remove(filename)
+
+ def test_init_mode(self):
+ bi = io.BytesIO()
+
+ with ZstdFile(bi, "r"):
+ pass
+ with ZstdFile(bi, "rb"):
+ pass
+ with ZstdFile(bi, "w"):
+ pass
+ with ZstdFile(bi, "wb"):
+ pass
+ with ZstdFile(bi, "a"):
+ pass
+ with ZstdFile(bi, "ab"):
+ pass
+
+ def test_init_with_x_mode(self):
+ with tempfile.NamedTemporaryFile() as tmp_f:
+ filename = pathlib.Path(tmp_f.name)
+
+ for mode in ("x", "xb"):
+ with ZstdFile(filename, mode):
+ pass
+ with self.assertRaises(FileExistsError):
+ with ZstdFile(filename, mode):
+ pass
+ os.remove(filename)
+
+ def test_init_bad_mode(self):
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), (3, "x"))
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "xt")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "x+")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rx")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "wx")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rt")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r+")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "wt")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "w+")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rw")
+
+ with self.assertRaisesRegex(TypeError,
+ r"not be a CompressionParameter"):
+ ZstdFile(io.BytesIO(), 'rb',
+ options={CompressionParameter.compression_level:5})
+ with self.assertRaisesRegex(TypeError,
+ r"not be a DecompressionParameter"):
+ ZstdFile(io.BytesIO(), 'wb',
+ options={DecompressionParameter.window_log_max:21})
+
+ with self.assertRaises(TypeError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", level=12)
+
+ def test_init_bad_check(self):
+ with self.assertRaises(TypeError):
+ ZstdFile(io.BytesIO(), "w", level='asd')
+        # unknown parameter keys and out-of-range values should be invalid.
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(), "w", options={999:9999})
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(), "w", options={CompressionParameter.window_log:99})
+
+ with self.assertRaises(TypeError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", options=33)
+
+ with self.assertRaises(OverflowError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB),
+ options={DecompressionParameter.window_log_max:2**31})
+
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB),
+ options={444:333})
+
+ with self.assertRaises(TypeError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), zstd_dict={1:2})
+
+ with self.assertRaises(TypeError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), zstd_dict=b'dict123456')
+
+ def test_init_close_fp(self):
+ # get a temp file name
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ tmp_f.write(DAT_130K_C)
+ filename = tmp_f.name
+
+ with self.assertRaises(TypeError):
+ ZstdFile(filename, options={'a':'b'})
+
+ # for PyPy
+ gc.collect()
+
+ os.remove(filename)
+
+ def test_close(self):
+ with io.BytesIO(COMPRESSED_100_PLUS_32KB) as src:
+ f = ZstdFile(src)
+ f.close()
+ # ZstdFile.close() should not close the underlying file object.
+ self.assertFalse(src.closed)
+ # Try closing an already-closed ZstdFile.
+ f.close()
+ self.assertFalse(src.closed)
+
+ # Test with a real file on disk, opened directly by ZstdFile.
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ filename = pathlib.Path(tmp_f.name)
+
+ f = ZstdFile(filename)
+ fp = f._fp
+ f.close()
+ # Here, ZstdFile.close() *should* close the underlying file object.
+ self.assertTrue(fp.closed)
+ # Try closing an already-closed ZstdFile.
+ f.close()
+
+ os.remove(filename)
+
+ def test_closed(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ try:
+ self.assertFalse(f.closed)
+ f.read()
+ self.assertFalse(f.closed)
+ finally:
+ f.close()
+ self.assertTrue(f.closed)
+
+ f = ZstdFile(io.BytesIO(), "w")
+ try:
+ self.assertFalse(f.closed)
+ finally:
+ f.close()
+ self.assertTrue(f.closed)
+
+ def test_fileno(self):
+ # 1
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ try:
+ self.assertRaises(io.UnsupportedOperation, f.fileno)
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.fileno)
+
+ # 2
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ filename = pathlib.Path(tmp_f.name)
+
+ f = ZstdFile(filename)
+ try:
+ self.assertEqual(f.fileno(), f._fp.fileno())
+ self.assertIsInstance(f.fileno(), int)
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.fileno)
+
+ os.remove(filename)
+
+ # 3, no .fileno() method
+ class C:
+ def read(self, size=-1):
+ return b'123'
+ with ZstdFile(C(), 'rb') as f:
+ with self.assertRaisesRegex(AttributeError, r'fileno'):
+ f.fileno()
+
+ def test_name(self):
+ # 1
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ try:
+ with self.assertRaises(AttributeError):
+ f.name
+ finally:
+ f.close()
+ with self.assertRaises(ValueError):
+ f.name
+
+ # 2
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ filename = pathlib.Path(tmp_f.name)
+
+ f = ZstdFile(filename)
+ try:
+ self.assertEqual(f.name, f._fp.name)
+ self.assertIsInstance(f.name, str)
+ finally:
+ f.close()
+ with self.assertRaises(ValueError):
+ f.name
+
+ os.remove(filename)
+
+ # 3, no .filename property
+ class C:
+ def read(self, size=-1):
+ return b'123'
+ with ZstdFile(C(), 'rb') as f:
+ with self.assertRaisesRegex(AttributeError, r'name'):
+ f.name
+
+ def test_seekable(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ try:
+ self.assertTrue(f.seekable())
+ f.read()
+ self.assertTrue(f.seekable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.seekable)
+
+ f = ZstdFile(io.BytesIO(), "w")
+ try:
+ self.assertFalse(f.seekable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.seekable)
+
+ src = io.BytesIO(COMPRESSED_100_PLUS_32KB)
+ src.seekable = lambda: False
+ f = ZstdFile(src)
+ try:
+ self.assertFalse(f.seekable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.seekable)
+
+ def test_readable(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ try:
+ self.assertTrue(f.readable())
+ f.read()
+ self.assertTrue(f.readable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.readable)
+
+ f = ZstdFile(io.BytesIO(), "w")
+ try:
+ self.assertFalse(f.readable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.readable)
+
+ def test_writable(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ try:
+ self.assertFalse(f.writable())
+ f.read()
+ self.assertFalse(f.writable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.writable)
+
+ f = ZstdFile(io.BytesIO(), "w")
+ try:
+ self.assertTrue(f.writable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.writable)
+
+ def test_read_0(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ self.assertEqual(f.read(0), b"")
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB),
+ options={DecompressionParameter.window_log_max:20}) as f:
+ self.assertEqual(f.read(0), b"")
+
+ # empty file
+ with ZstdFile(io.BytesIO(b'')) as f:
+ self.assertEqual(f.read(0), b"")
+ with self.assertRaises(EOFError):
+ f.read(10)
+
+ with ZstdFile(io.BytesIO(b'')) as f:
+ with self.assertRaises(EOFError):
+ f.read(10)
+
+ def test_read_10(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ chunks = []
+ while True:
+ result = f.read(10)
+ if not result:
+ break
+ self.assertLessEqual(len(result), 10)
+ chunks.append(result)
+ self.assertEqual(b"".join(chunks), DECOMPRESSED_100_PLUS_32KB)
+
+ def test_read_multistream(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 5)) as f:
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB * 5)
+
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB + SKIPPABLE_FRAME)) as f:
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB + COMPRESSED_DAT)) as f:
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB + DECOMPRESSED_DAT)
+
+ def test_read_incomplete(self):
+ with ZstdFile(io.BytesIO(DAT_130K_C[:-200])) as f:
+ self.assertRaises(EOFError, f.read)
+
+ # Trailing data isn't a valid compressed stream
+ with ZstdFile(io.BytesIO(self.FRAME_42 + b'12345')) as f:
+ self.assertRaises(ZstdError, f.read)
+
+ with ZstdFile(io.BytesIO(SKIPPABLE_FRAME + b'12345')) as f:
+ self.assertRaises(ZstdError, f.read)
+
+ def test_read_truncated(self):
+ # Drop stream epilogue: 4 bytes checksum
+ truncated = DAT_130K_C[:-4]
+ with ZstdFile(io.BytesIO(truncated)) as f:
+ self.assertRaises(EOFError, f.read)
+
+ with ZstdFile(io.BytesIO(truncated)) as f:
+            # important case: reading exactly the decompressed size succeeds
+            # even though the 4-byte checksum is missing; it must not raise EOFError.
+ self.assertEqual(f.read(130*_1K), DAT_130K_D)
+ with self.assertRaises(EOFError):
+ f.read(1)
+
+ # Incomplete header
+ for i in range(1, 20):
+ with ZstdFile(io.BytesIO(truncated[:i])) as f:
+ self.assertRaises(EOFError, f.read, 1)
+
+ def test_read_bad_args(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_DAT))
+ f.close()
+ self.assertRaises(ValueError, f.read)
+ with ZstdFile(io.BytesIO(), "w") as f:
+ self.assertRaises(ValueError, f.read)
+ with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f:
+ self.assertRaises(TypeError, f.read, float())
+
+ def test_read_bad_data(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_BOGUS)) as f:
+ self.assertRaises(ZstdError, f.read)
+
+ def test_read_exception(self):
+ class C:
+ def read(self, size=-1):
+ raise OSError
+ with ZstdFile(C()) as f:
+ with self.assertRaises(OSError):
+ f.read(10)
+
+ def test_read1(self):
+ with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
+ blocks = []
+ while True:
+ result = f.read1()
+ if not result:
+ break
+ blocks.append(result)
+ self.assertEqual(b"".join(blocks), DAT_130K_D)
+ self.assertEqual(f.read1(), b"")
+
+ def test_read1_0(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f:
+ self.assertEqual(f.read1(0), b"")
+
+ def test_read1_10(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f:
+ blocks = []
+ while True:
+ result = f.read1(10)
+ if not result:
+ break
+ blocks.append(result)
+ self.assertEqual(b"".join(blocks), DECOMPRESSED_DAT)
+ self.assertEqual(f.read1(), b"")
+
+ def test_read1_multistream(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 5)) as f:
+ blocks = []
+ while True:
+ result = f.read1()
+ if not result:
+ break
+ blocks.append(result)
+ self.assertEqual(b"".join(blocks), DECOMPRESSED_100_PLUS_32KB * 5)
+ self.assertEqual(f.read1(), b"")
+
+ def test_read1_bad_args(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ f.close()
+ self.assertRaises(ValueError, f.read1)
+ with ZstdFile(io.BytesIO(), "w") as f:
+ self.assertRaises(ValueError, f.read1)
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ self.assertRaises(TypeError, f.read1, None)
+
+ def test_readinto(self):
+ arr = array.array("I", range(100))
+ self.assertEqual(len(arr), 100)
+ self.assertEqual(len(arr) * arr.itemsize, 400)
+ ba = bytearray(300)
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ # 0 length output buffer
+ self.assertEqual(f.readinto(ba[0:0]), 0)
+
+ # use correct length for buffer protocol object
+ self.assertEqual(f.readinto(arr), 400)
+ self.assertEqual(arr.tobytes(), DECOMPRESSED_100_PLUS_32KB[:400])
+
+ # normal readinto
+ self.assertEqual(f.readinto(ba), 300)
+ self.assertEqual(ba, DECOMPRESSED_100_PLUS_32KB[400:700])
+
+ def test_peek(self):
+ with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
+ result = f.peek()
+ self.assertGreater(len(result), 0)
+ self.assertTrue(DAT_130K_D.startswith(result))
+ self.assertEqual(f.read(), DAT_130K_D)
+ with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
+ result = f.peek(10)
+ self.assertGreater(len(result), 0)
+ self.assertTrue(DAT_130K_D.startswith(result))
+ self.assertEqual(f.read(), DAT_130K_D)
+
+ def test_peek_bad_args(self):
+ with ZstdFile(io.BytesIO(), "w") as f:
+ self.assertRaises(ValueError, f.peek)
+
+ def test_iterator(self):
+ with io.BytesIO(THIS_FILE_BYTES) as f:
+ lines = f.readlines()
+ compressed = compress(THIS_FILE_BYTES)
+
+ # iter
+ with ZstdFile(io.BytesIO(compressed)) as f:
+ self.assertListEqual(list(iter(f)), lines)
+
+ # readline
+ with ZstdFile(io.BytesIO(compressed)) as f:
+ for line in lines:
+ self.assertEqual(f.readline(), line)
+ self.assertEqual(f.readline(), b'')
+ self.assertEqual(f.readline(), b'')
+
+ # readlines
+ with ZstdFile(io.BytesIO(compressed)) as f:
+ self.assertListEqual(f.readlines(), lines)
+
+ def test_decompress_limited(self):
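+        # Guard against decompression bombs: reading a single byte must not
+        # eagerly decompress the entire stream.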
+ _ZSTD_DStreamInSize = 128*_1K + 3
+
+ bomb = compress(b'\0' * int(2e6), level=10)
+ self.assertLess(len(bomb), _ZSTD_DStreamInSize)
+
+ decomp = ZstdFile(io.BytesIO(bomb))
+ self.assertEqual(decomp.read(1), b'\0')
+
+ # BufferedReader uses 128 KiB buffer in __init__.py
+ max_decomp = 128*_1K
+ self.assertLessEqual(decomp._buffer.raw.tell(), max_decomp,
+ "Excessive amount of data was decompressed")
+
+ def test_write(self):
+ raw_data = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6]
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w") as f:
+ f.write(raw_data)
+
+ comp = ZstdCompressor()
+ expected = comp.compress(raw_data) + comp.flush()
+ self.assertEqual(dst.getvalue(), expected)
+
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w", level=12) as f:
+ f.write(raw_data)
+
+ comp = ZstdCompressor(12)
+ expected = comp.compress(raw_data) + comp.flush()
+ self.assertEqual(dst.getvalue(), expected)
+
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w", options={CompressionParameter.checksum_flag:1}) as f:
+ f.write(raw_data)
+
+ comp = ZstdCompressor(options={CompressionParameter.checksum_flag:1})
+ expected = comp.compress(raw_data) + comp.flush()
+ self.assertEqual(dst.getvalue(), expected)
+
+ with io.BytesIO() as dst:
+ options = {CompressionParameter.compression_level:-5,
+ CompressionParameter.checksum_flag:1}
+ with ZstdFile(dst, "w",
+ options=options) as f:
+ f.write(raw_data)
+
+ comp = ZstdCompressor(options=options)
+ expected = comp.compress(raw_data) + comp.flush()
+ self.assertEqual(dst.getvalue(), expected)
+
+ def test_write_empty_frame(self):
+ # .FLUSH_FRAME generates an empty content frame
+ c = ZstdCompressor()
+ self.assertNotEqual(c.flush(c.FLUSH_FRAME), b'')
+ self.assertNotEqual(c.flush(c.FLUSH_FRAME), b'')
+
+ # don't generate an empty content frame if nothing was written
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ pass
+ self.assertEqual(bo.getvalue(), b'')
+
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.flush(f.FLUSH_FRAME)
+ self.assertEqual(bo.getvalue(), b'')
+
+ # .write(b'') generates an empty content frame
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.write(b'')
+ self.assertNotEqual(bo.getvalue(), b'')
+
+ # FLUSH_BLOCK with no data still results in an empty content frame
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.flush(f.FLUSH_BLOCK)
+ self.assertNotEqual(bo.getvalue(), b'')
+
+ def test_write_empty_block(self):
+ # If there is no internal data, .FLUSH_BLOCK returns b''.
+ c = ZstdCompressor()
+ self.assertEqual(c.flush(c.FLUSH_BLOCK), b'')
+ self.assertNotEqual(c.compress(b'123', c.FLUSH_BLOCK),
+ b'')
+ self.assertEqual(c.flush(c.FLUSH_BLOCK), b'')
+ self.assertEqual(c.compress(b''), b'')
+ self.assertEqual(c.compress(b''), b'')
+ self.assertEqual(c.flush(c.FLUSH_BLOCK), b'')
+
+ # mode == .last_mode: a second FLUSH_BLOCK writes nothing more
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.write(b'123')
+ f.flush(f.FLUSH_BLOCK)
+ fp_pos = f._fp.tell()
+ self.assertNotEqual(fp_pos, 0)
+ f.flush(f.FLUSH_BLOCK)
+ self.assertEqual(f._fp.tell(), fp_pos)
+
+ # mode != .last_mode, but flushing with no pending data still writes nothing
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.flush(f.FLUSH_BLOCK)
+ self.assertEqual(f._fp.tell(), 0)
+ f.write(b'')
+ f.flush(f.FLUSH_BLOCK)
+ self.assertEqual(f._fp.tell(), 0)
+
+ def test_write_101(self):
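+ # Writing in 101-byte chunks should produce the same output as one-shot compression.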
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w") as f:
+ for start in range(0, len(THIS_FILE_BYTES), 101):
+ f.write(THIS_FILE_BYTES[start:start+101])
+
+ comp = ZstdCompressor()
+ expected = comp.compress(THIS_FILE_BYTES) + comp.flush()
+ self.assertEqual(dst.getvalue(), expected)
+
+ def test_write_append(self):
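+ # Each append-mode open starts a new frame, so the result is three concatenated frames.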
+ def comp(data):
+ comp = ZstdCompressor()
+ return comp.compress(data) + comp.flush()
+
+ part1 = THIS_FILE_BYTES[:_1K]
+ part2 = THIS_FILE_BYTES[_1K:1536]
+ part3 = THIS_FILE_BYTES[1536:]
+ expected = b"".join(comp(x) for x in (part1, part2, part3))
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w") as f:
+ f.write(part1)
+ with ZstdFile(dst, "a") as f:
+ f.write(part2)
+ with ZstdFile(dst, "a") as f:
+ f.write(part3)
+ self.assertEqual(dst.getvalue(), expected)
+
+ def test_write_bad_args(self):
+ f = ZstdFile(io.BytesIO(), "w")
+ f.close()
+ self.assertRaises(ValueError, f.write, b"foo")
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r") as f:
+ self.assertRaises(ValueError, f.write, b"bar")
+ with ZstdFile(io.BytesIO(), "w") as f:
+ self.assertRaises(TypeError, f.write, None)
+ self.assertRaises(TypeError, f.write, "text")
+ self.assertRaises(TypeError, f.write, 789)
+
+ def test_writelines(self):
+ def comp(data):
+ comp = ZstdCompressor()
+ return comp.compress(data) + comp.flush()
+
+ with io.BytesIO(THIS_FILE_BYTES) as f:
+ lines = f.readlines()
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w") as f:
+ f.writelines(lines)
+ expected = comp(THIS_FILE_BYTES)
+ self.assertEqual(dst.getvalue(), expected)
+
+ def test_seek_forward(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.seek(555)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[555:])
+
+ def test_seek_forward_across_streams(self):
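+ # Seeking past the end of the first frame continues into the next concatenated frame.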
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f:
+ f.seek(len(DECOMPRESSED_100_PLUS_32KB) + 123)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[123:])
+
+ def test_seek_forward_relative_to_current(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.read(100)
+ f.seek(1236, 1)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[1336:])
+
+ def test_seek_forward_relative_to_end(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.seek(-555, 2)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[-555:])
+
+ def test_seek_backward(self):
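+ # Backward seeks are supported, typically by rewinding and re-decompressing from the start.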
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.read(1001)
+ f.seek(211)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[211:])
+
+ def test_seek_backward_across_streams(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f:
+ f.read(len(DECOMPRESSED_100_PLUS_32KB) + 333)
+ f.seek(737)
+ self.assertEqual(f.read(),
+ DECOMPRESSED_100_PLUS_32KB[737:] + DECOMPRESSED_100_PLUS_32KB)
+
+ def test_seek_backward_relative_to_end(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.seek(-150, 2)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[-150:])
+
+ def test_seek_past_end(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.seek(len(DECOMPRESSED_100_PLUS_32KB) + 9001)
+ self.assertEqual(f.tell(), len(DECOMPRESSED_100_PLUS_32KB))
+ self.assertEqual(f.read(), b"")
+
+ def test_seek_past_start(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.seek(-88)
+ self.assertEqual(f.tell(), 0)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+
+ def test_seek_bad_args(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ f.close()
+ self.assertRaises(ValueError, f.seek, 0)
+ with ZstdFile(io.BytesIO(), "w") as f:
+ self.assertRaises(ValueError, f.seek, 0)
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ self.assertRaises(ValueError, f.seek, 0, 3)
+ # io.BufferedReader raises TypeError instead of ValueError
+ self.assertRaises((TypeError, ValueError), f.seek, 9, ())
+ self.assertRaises(TypeError, f.seek, None)
+ self.assertRaises(TypeError, f.seek, b"derp")
+
+ def test_seek_not_seekable(self):
+ class C(io.BytesIO):
+ def seekable(self):
+ return False
+ obj = C(COMPRESSED_100_PLUS_32KB)
+ with ZstdFile(obj, 'r') as f:
+ d = f.read(1)
+ self.assertFalse(f.seekable())
+ with self.assertRaisesRegex(io.UnsupportedOperation,
+ 'File or stream is not seekable'):
+ f.seek(0)
+ d += f.read()
+ self.assertEqual(d, DECOMPRESSED_100_PLUS_32KB)
+
+ def test_tell(self):
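+ # tell() reports the current uncompressed position in both read and write modes.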
+ with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
+ pos = 0
+ while True:
+ self.assertEqual(f.tell(), pos)
+ result = f.read(random.randint(171, 189))
+ if not result:
+ break
+ pos += len(result)
+ self.assertEqual(f.tell(), len(DAT_130K_D))
+ with ZstdFile(io.BytesIO(), "w") as f:
+ for pos in range(0, len(DAT_130K_D), 143):
+ self.assertEqual(f.tell(), pos)
+ f.write(DAT_130K_D[pos:pos+143])
+ self.assertEqual(f.tell(), len(DAT_130K_D))
+
+ def test_tell_bad_args(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ f.close()
+ self.assertRaises(ValueError, f.tell)
+
+ def test_file_dict(self):
+ # default
+ bi = io.BytesIO()
+ with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with ZstdFile(bi, zstd_dict=TRAINED_DICT) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ # .as_(un)digested_dict
+ bi = io.BytesIO()
+ with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT.as_digested_dict) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with ZstdFile(bi, zstd_dict=TRAINED_DICT.as_undigested_dict) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ def test_file_prefix(self):
+ bi = io.BytesIO()
+ with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT.as_prefix) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with ZstdFile(bi, zstd_dict=TRAINED_DICT.as_prefix) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ def test_UnsupportedOperation(self):
+ # 1: writing to a file opened for reading is unsupported
+ with ZstdFile(io.BytesIO(), 'r') as f:
+ with self.assertRaises(io.UnsupportedOperation):
+ f.write(b'1234')
+
+ # 2: in write mode the wrapped object must have a write() method
+ class T:
+ def read(self, size):
+ return b'a' * size
+
+ with self.assertRaises(TypeError): # on creation
+ with ZstdFile(T(), 'w') as f:
+ pass
+
+ # 3: reading or seeking a file opened for writing is unsupported; a closed file raises ValueError
+ with ZstdFile(io.BytesIO(), 'w') as f:
+ with self.assertRaises(io.UnsupportedOperation):
+ f.read(100)
+ with self.assertRaises(io.UnsupportedOperation):
+ f.seek(100)
+ self.assertEqual(f.closed, True)
+ with self.assertRaises(ValueError):
+ f.readable()
+ with self.assertRaises(ValueError):
+ f.tell()
+ with self.assertRaises(ValueError):
+ f.read(100)
+
+ def test_read_readinto_readinto1(self):
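+ # Randomly interleave read(), readinto() and readinto1(); the joined output must equal the original data.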
+ lst = []
+ with ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE*5)) as f:
+ while True:
+ method = random.randint(0, 2)
+ size = random.randint(0, 300)
+
+ if method == 0:
+ dat = f.read(size)
+ if not dat and size:
+ break
+ lst.append(dat)
+ elif method == 1:
+ ba = bytearray(size)
+ read_size = f.readinto(ba)
+ if read_size == 0 and size:
+ break
+ lst.append(bytes(ba[:read_size]))
+ elif method == 2:
+ ba = bytearray(size)
+ read_size = f.readinto1(ba)
+ if read_size == 0 and size:
+ break
+ lst.append(bytes(ba[:read_size]))
+ self.assertEqual(b''.join(lst), THIS_FILE_BYTES*5)
+
+ def test_zstdfile_flush(self):
+ # closed
+ f = ZstdFile(io.BytesIO(), 'w')
+ f.close()
+ with self.assertRaises(ValueError):
+ f.flush()
+
+ # read
+ with ZstdFile(io.BytesIO(), 'r') as f:
+ # does nothing for read-only stream
+ f.flush()
+
+ # write
+ DAT = b'abcd'
+ bi = io.BytesIO()
+ with ZstdFile(bi, 'w') as f:
+ self.assertEqual(f.write(DAT), len(DAT))
+ self.assertEqual(f.tell(), len(DAT))
+ self.assertEqual(bi.tell(), 0) # not enough for a block
+
+ self.assertEqual(f.flush(), None)
+ self.assertEqual(f.tell(), len(DAT))
+ self.assertGreater(bi.tell(), 0) # flushed
+
+ # write, no .flush() method
+ class C:
+ def write(self, b):
+ return len(b)
+ with ZstdFile(C(), 'w') as f:
+ self.assertEqual(f.write(DAT), len(DAT))
+ self.assertEqual(f.tell(), len(DAT))
+
+ self.assertEqual(f.flush(), None)
+ self.assertEqual(f.tell(), len(DAT))
+
+ def test_zstdfile_flush_mode(self):
+ self.assertEqual(ZstdFile.FLUSH_BLOCK, ZstdCompressor.FLUSH_BLOCK)
+ self.assertEqual(ZstdFile.FLUSH_FRAME, ZstdCompressor.FLUSH_FRAME)
+ with self.assertRaises(AttributeError):
+ ZstdFile.CONTINUE
+
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ # flush block
+ self.assertEqual(f.write(b'123'), 3)
+ self.assertIsNone(f.flush(f.FLUSH_BLOCK))
+ p1 = bo.tell()
+ # mode == .last_mode, flush should return early without writing
+ self.assertIsNone(f.flush())
+ p2 = bo.tell()
+ self.assertEqual(p1, p2)
+ # flush frame
+ self.assertEqual(f.write(b'456'), 3)
+ self.assertIsNone(f.flush(mode=f.FLUSH_FRAME))
+ # flush frame
+ self.assertEqual(f.write(b'789'), 3)
+ self.assertIsNone(f.flush(f.FLUSH_FRAME))
+ p1 = bo.tell()
+ # mode == .last_mode, flush should return early without writing
+ self.assertIsNone(f.flush(f.FLUSH_FRAME))
+ p2 = bo.tell()
+ self.assertEqual(p1, p2)
+ self.assertEqual(decompress(bo.getvalue()), b'123456789')
+
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.write(b'123')
+ with self.assertRaisesRegex(ValueError, r'\.FLUSH_.*?\.FLUSH_'):
+ f.flush(ZstdCompressor.CONTINUE)
+ with self.assertRaises(ValueError):
+ f.flush(-1)
+ with self.assertRaises(ValueError):
+ f.flush(123456)
+ with self.assertRaises(TypeError):
+ f.flush(node=ZstdCompressor.CONTINUE)
+ with self.assertRaises((TypeError, ValueError)):
+ f.flush('FLUSH_FRAME')
+ with self.assertRaises(TypeError):
+ f.flush(b'456', f.FLUSH_BLOCK)
+
+ def test_zstdfile_truncate(self):
+ with ZstdFile(io.BytesIO(), 'w') as f:
+ with self.assertRaises(io.UnsupportedOperation):
+ f.truncate(200)
+
+ def test_zstdfile_iter_issue45475(self):
+ lines = [line for line in ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE))]
+ self.assertGreater(len(lines), 0)
+
+ def test_append_new_file(self):
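+ # Opening a non-existent path in append mode should create the file.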
+ with tempfile.NamedTemporaryFile(delete=True) as tmp_f:
+ filename = tmp_f.name
+
+ with ZstdFile(filename, 'a') as f:
+ pass
+ self.assertTrue(os.path.isfile(filename))
+
+ os.remove(filename)
+
+class OpenTestCase(unittest.TestCase):
+
+ def test_binary_modes(self):
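+ # 'rb' reads decompressed bytes, 'wb' compresses written bytes, and 'ab' appends
+ # so that the whole output decompresses to the data written twice.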
+ with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb") as f:
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+ with io.BytesIO() as bio:
+ with open(bio, "wb") as f:
+ f.write(DECOMPRESSED_100_PLUS_32KB)
+ file_data = decompress(bio.getvalue())
+ self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB)
+ with open(bio, "ab") as f:
+ f.write(DECOMPRESSED_100_PLUS_32KB)
+ file_data = decompress(bio.getvalue())
+ self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB * 2)
+
+ def test_text_modes(self):
+ # empty input
+ with self.assertRaises(EOFError):
+ with open(io.BytesIO(b''), "rt", encoding="utf-8", newline='\n') as reader:
+ for _ in reader:
+ pass
+
+ # read
+ uncompressed = THIS_FILE_STR.replace(os.linesep, "\n")
+ with open(io.BytesIO(COMPRESSED_THIS_FILE), "rt", encoding="utf-8") as f:
+ self.assertEqual(f.read(), uncompressed)
+
+ with io.BytesIO() as bio:
+ # write
+ with open(bio, "wt", encoding="utf-8") as f:
+ f.write(uncompressed)
+ file_data = decompress(bio.getvalue()).decode("utf-8")
+ self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed)
+ # append
+ with open(bio, "at", encoding="utf-8") as f:
+ f.write(uncompressed)
+ file_data = decompress(bio.getvalue()).decode("utf-8")
+ self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed * 2)
+
+ def test_bad_params(self):
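+ # Invalid mode strings, and text-only arguments combined with binary mode, are rejected with ValueError.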
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ TESTFN = pathlib.Path(tmp_f.name)
+
+ with self.assertRaises(ValueError):
+ open(TESTFN, "")
+ with self.assertRaises(ValueError):
+ open(TESTFN, "rbt")
+ with self.assertRaises(ValueError):
+ open(TESTFN, "rb", encoding="utf-8")
+ with self.assertRaises(ValueError):
+ open(TESTFN, "rb", errors="ignore")
+ with self.assertRaises(ValueError):
+ open(TESTFN, "rb", newline="\n")
+
+ os.remove(TESTFN)
+
+ def test_option(self):
+ options = {DecompressionParameter.window_log_max:25}
+ with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb", options=options) as f:
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+
+ options = {CompressionParameter.compression_level:12}
+ with io.BytesIO() as bio:
+ with open(bio, "wb", options=options) as f:
+ f.write(DECOMPRESSED_100_PLUS_32KB)
+ file_data = decompress(bio.getvalue())
+ self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB)
+
+ def test_encoding(self):
+ uncompressed = THIS_FILE_STR.replace(os.linesep, "\n")
+
+ with io.BytesIO() as bio:
+ with open(bio, "wt", encoding="utf-16-le") as f:
+ f.write(uncompressed)
+ file_data = decompress(bio.getvalue()).decode("utf-16-le")
+ self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed)
+ bio.seek(0)
+ with open(bio, "rt", encoding="utf-16-le") as f:
+ self.assertEqual(f.read().replace(os.linesep, "\n"), uncompressed)
+
+ def test_encoding_error_handler(self):
+ with io.BytesIO(compress(b"foo\xffbar")) as bio:
+ with open(bio, "rt", encoding="ascii", errors="ignore") as f:
+ self.assertEqual(f.read(), "foobar")
+
+ def test_newline(self):
+ # Test with explicit newline (universal newline mode disabled).
+ text = THIS_FILE_STR.replace(os.linesep, "\n")
+ with io.BytesIO() as bio:
+ with open(bio, "wt", encoding="utf-8", newline="\n") as f:
+ f.write(text)
+ bio.seek(0)
+ with open(bio, "rt", encoding="utf-8", newline="\r") as f:
+ self.assertEqual(f.readlines(), [text])
+
+ def test_x_mode(self):
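+ # Exclusive-creation modes ('x', 'xb', 'xt') fail with FileExistsError if the file already exists.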
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ TESTFN = pathlib.Path(tmp_f.name)
+
+ for mode in ("x", "xb", "xt"):
+ os.remove(TESTFN)
+
+ if mode == "xt":
+ encoding = "utf-8"
+ else:
+ encoding = None
+ with open(TESTFN, mode, encoding=encoding):
+ pass
+ with self.assertRaises(FileExistsError):
+ with open(TESTFN, mode):
+ pass
+
+ os.remove(TESTFN)
+
+ def test_open_dict(self):
+ # default
+ bi = io.BytesIO()
+ with open(bi, 'w', zstd_dict=TRAINED_DICT) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with open(bi, zstd_dict=TRAINED_DICT) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ # .as_(un)digested_dict
+ bi = io.BytesIO()
+ with open(bi, 'w', zstd_dict=TRAINED_DICT.as_digested_dict) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with open(bi, zstd_dict=TRAINED_DICT.as_undigested_dict) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ # invalid dictionary
+ bi = io.BytesIO()
+ with self.assertRaisesRegex(TypeError, 'zstd_dict'):
+ open(bi, 'w', zstd_dict={1:2, 2:3})
+
+ with self.assertRaisesRegex(TypeError, 'zstd_dict'):
+ open(bi, 'w', zstd_dict=b'1234567890')
+
+ def test_open_prefix(self):
+ bi = io.BytesIO()
+ with open(bi, 'w', zstd_dict=TRAINED_DICT.as_prefix) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with open(bi, zstd_dict=TRAINED_DICT.as_prefix) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ def test_buffer_protocol(self):
+ # don't use len() for buffer protocol objects
+ arr = array.array("i", range(1000))
+ LENGTH = len(arr) * arr.itemsize
+
+ with open(io.BytesIO(), "wb") as f:
+ self.assertEqual(f.write(arr), LENGTH)
+ self.assertEqual(f.tell(), LENGTH)
+
+class FreeThreadingMethodTests(unittest.TestCase):
+
+ @threading_helper.reap_threads
+ @threading_helper.requires_working_threading()
+ def test_compress_locking(self):
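+ # All threads compress identical data; with correct locking the combined
+ # output should match the serially produced reference output.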
+ input = b'a'* (16*_1K)
+ num_threads = 8
+
+ comp = ZstdCompressor()
+ parts = []
+ for _ in range(num_threads):
+ res = comp.compress(input, ZstdCompressor.FLUSH_BLOCK)
+ if res:
+ parts.append(res)
+ rest1 = comp.flush()
+ expected = b''.join(parts) + rest1
+
+ comp = ZstdCompressor()
+ output = []
+ def run_method(method, input_data, output_data):
+ res = method(input_data, ZstdCompressor.FLUSH_BLOCK)
+ if res:
+ output_data.append(res)
+ threads = []
+
+ for i in range(num_threads):
+ thread = threading.Thread(target=run_method, args=(comp.compress, input, output))
+
+ threads.append(thread)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ rest2 = comp.flush()
+ self.assertEqual(rest1, rest2)
+ actual = b''.join(output) + rest2
+ self.assertEqual(expected, actual)
+
+ @threading_helper.reap_threads
+ @threading_helper.requires_working_threading()
+ def test_decompress_locking(self):
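+ # Each thread decompresses a bounded chunk; the joined output should match
+ # the output of the same calls made serially.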
+ input = compress(b'a'* (16*_1K))
+ num_threads = 8
+ # to ensure we decompress over multiple calls, limit the output size of each call
+ window_size = _1K * 16//num_threads
+
+ decomp = ZstdDecompressor()
+ parts = []
+ for _ in range(num_threads):
+ res = decomp.decompress(input, window_size)
+ if res:
+ parts.append(res)
+ expected = b''.join(parts)
+
+ decomp = ZstdDecompressor()
+ output = []
+ def run_method(method, input_data, output_data):
+ res = method(input_data, window_size)
+ if res:
+ output_data.append(res)
+ threads = []
+
+ for i in range(num_threads):
+ thread = threading.Thread(target=run_method, args=(decomp.decompress, input, output))
+
+ threads.append(thread)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ actual = b''.join(output)
+ self.assertEqual(expected, actual)
+
+ @threading_helper.reap_threads
+ @threading_helper.requires_working_threading()
+ def test_compress_shared_dict(self):
+ num_threads = 8
+
+ def run_method(b):
+ level = threading.get_ident() % 4
+ # sync threads to increase chance of contention on
+ # capsule storing dictionary levels
+ b.wait()
+ ZstdCompressor(level=level,
+ zstd_dict=TRAINED_DICT.as_digested_dict)
+ b.wait()
+ ZstdCompressor(level=level,
+ zstd_dict=TRAINED_DICT.as_undigested_dict)
+ b.wait()
+ ZstdCompressor(level=level,
+ zstd_dict=TRAINED_DICT.as_prefix)
+ threads = []
+
+ b = threading.Barrier(num_threads)
+ for i in range(num_threads):
+ thread = threading.Thread(target=run_method, args=(b,))
+
+ threads.append(thread)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ @threading_helper.reap_threads
+ @threading_helper.requires_working_threading()
+ def test_decompress_shared_dict(self):
+ num_threads = 8
+
+ def run_method(b):
+ # sync threads to increase chance of contention on
+ # decompression dictionary
+ b.wait()
+ ZstdDecompressor(zstd_dict=TRAINED_DICT.as_digested_dict)
+ b.wait()
+ ZstdDecompressor(zstd_dict=TRAINED_DICT.as_undigested_dict)
+ b.wait()
+ ZstdDecompressor(zstd_dict=TRAINED_DICT.as_prefix)
+ threads = []
+
+ b = threading.Barrier(num_threads)
+ for i in range(num_threads):
+ thread = threading.Thread(target=run_method, args=(b,))
+
+ threads.append(thread)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+
+if __name__ == "__main__":
+ unittest.main()