aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/_ast_unparse.py12
-rw-r--r--Lib/_colorize.py255
-rw-r--r--Lib/_compat_pickle.py3
-rw-r--r--Lib/_opcode_metadata.py46
-rw-r--r--Lib/_py_warnings.py2
-rw-r--r--Lib/_pydatetime.py53
-rw-r--r--Lib/_pydecimal.py8
-rw-r--r--Lib/_pyio.py23
-rw-r--r--Lib/_pyrepl/_module_completer.py24
-rw-r--r--Lib/_pyrepl/base_eventqueue.py18
-rw-r--r--Lib/_pyrepl/commands.py10
-rw-r--r--Lib/_pyrepl/main.py11
-rw-r--r--Lib/_pyrepl/reader.py9
-rw-r--r--Lib/_pyrepl/readline.py6
-rw-r--r--Lib/_pyrepl/simple_interact.py16
-rw-r--r--Lib/_pyrepl/unix_console.py21
-rw-r--r--Lib/_pyrepl/utils.py49
-rw-r--r--Lib/_pyrepl/windows_console.py67
-rw-r--r--Lib/_strptime.py247
-rw-r--r--Lib/_threading_local.py122
-rw-r--r--Lib/annotationlib.py522
-rw-r--r--Lib/argparse.py128
-rw-r--r--Lib/ast.py61
-rw-r--r--Lib/asyncio/__main__.py40
-rw-r--r--Lib/asyncio/base_events.py2
-rw-r--r--Lib/asyncio/base_subprocess.py7
-rw-r--r--Lib/asyncio/futures.py17
-rw-r--r--Lib/asyncio/graph.py6
-rw-r--r--Lib/asyncio/selector_events.py2
-rw-r--r--Lib/asyncio/taskgroups.py7
-rw-r--r--Lib/asyncio/tasks.py50
-rw-r--r--Lib/asyncio/tools.py260
-rw-r--r--Lib/calendar.py8
-rw-r--r--Lib/cmd.py2
-rw-r--r--Lib/code.py4
-rw-r--r--Lib/compileall.py4
-rw-r--r--Lib/compression/bz2.py (renamed from Lib/compression/bz2/__init__.py)0
-rw-r--r--Lib/compression/gzip.py (renamed from Lib/compression/gzip/__init__.py)0
-rw-r--r--Lib/compression/lzma.py (renamed from Lib/compression/lzma/__init__.py)0
-rw-r--r--Lib/compression/zlib.py (renamed from Lib/compression/zlib/__init__.py)0
-rw-r--r--Lib/compression/zstd/__init__.py242
-rw-r--r--Lib/compression/zstd/_zstdfile.py345
-rw-r--r--Lib/concurrent/futures/_base.py27
-rw-r--r--Lib/concurrent/futures/interpreter.py212
-rw-r--r--Lib/concurrent/futures/process.py5
-rw-r--r--Lib/concurrent/interpreters/__init__.py (renamed from Lib/test/support/interpreters/__init__.py)50
-rw-r--r--Lib/concurrent/interpreters/_crossinterp.py (renamed from Lib/test/support/interpreters/_crossinterp.py)2
-rw-r--r--Lib/concurrent/interpreters/_queues.py (renamed from Lib/test/support/interpreters/queues.py)145
-rw-r--r--Lib/configparser.py11
-rw-r--r--Lib/ctypes/__init__.py6
-rw-r--r--Lib/ctypes/_layout.py21
-rw-r--r--Lib/curses/__init__.py17
-rw-r--r--Lib/dataclasses.py53
-rw-r--r--Lib/dbm/dumb.py32
-rw-r--r--Lib/dbm/sqlite3.py4
-rw-r--r--Lib/difflib.py54
-rw-r--r--Lib/dis.py2
-rw-r--r--Lib/doctest.py115
-rw-r--r--Lib/email/_header_value_parser.py6
-rw-r--r--Lib/email/feedparser.py2
-rw-r--r--Lib/email/header.py17
-rw-r--r--Lib/email/message.py2
-rw-r--r--Lib/email/utils.py10
-rw-r--r--Lib/encodings/aliases.py2
-rw-r--r--Lib/encodings/palmos.py2
-rw-r--r--Lib/ensurepip/__init__.py2
-rw-r--r--Lib/fnmatch.py4
-rw-r--r--Lib/fractions.py16
-rw-r--r--Lib/functools.py3
-rw-r--r--Lib/genericpath.py11
-rw-r--r--Lib/getpass.py69
-rw-r--r--Lib/glob.py32
-rw-r--r--Lib/gzip.py4
-rw-r--r--Lib/hashlib.py12
-rw-r--r--Lib/heapq.py51
-rw-r--r--Lib/html/parser.py67
-rw-r--r--Lib/http/client.py9
-rw-r--r--Lib/http/server.py394
-rw-r--r--Lib/idlelib/NEWS2x.txt2
-rw-r--r--Lib/idlelib/News3.txt7
-rw-r--r--Lib/idlelib/configdialog.py2
-rw-r--r--Lib/idlelib/debugger.py2
-rw-r--r--Lib/idlelib/editor.py2
-rw-r--r--Lib/idlelib/idle_test/htest.py2
-rwxr-xr-xLib/idlelib/pyshell.py2
-rw-r--r--Lib/inspect.py6
-rw-r--r--Lib/ipaddress.py17
-rw-r--r--Lib/json/encoder.py5
-rw-r--r--Lib/json/tool.py43
-rw-r--r--Lib/linecache.py67
-rw-r--r--Lib/locale.py5
-rw-r--r--Lib/logging/__init__.py33
-rw-r--r--Lib/logging/config.py2
-rw-r--r--Lib/mimetypes.py4
-rw-r--r--Lib/multiprocessing/connection.py2
-rw-r--r--Lib/multiprocessing/context.py2
-rw-r--r--Lib/multiprocessing/forkserver.py4
-rw-r--r--Lib/multiprocessing/sharedctypes.py7
-rw-r--r--Lib/multiprocessing/util.py79
-rw-r--r--Lib/netrc.py33
-rw-r--r--Lib/ntpath.py41
-rw-r--r--Lib/os.py3
-rw-r--r--Lib/pathlib/__init__.py7
-rw-r--r--Lib/pathlib/_os.py28
-rw-r--r--Lib/pathlib/types.py51
-rw-r--r--Lib/pdb.py394
-rw-r--r--Lib/pickle.py4
-rw-r--r--Lib/pickletools.py4
-rw-r--r--Lib/platform.py133
-rw-r--r--Lib/posixpath.py57
-rw-r--r--Lib/pprint.py79
-rw-r--r--Lib/py_compile.py2
-rw-r--r--Lib/pydoc.py6
-rw-r--r--Lib/pydoc_data/topics.py109
-rw-r--r--Lib/random.py2
-rw-r--r--Lib/re/__init__.py2
-rw-r--r--Lib/re/_parser.py3
-rw-r--r--Lib/reprlib.py19
-rw-r--r--Lib/shelve.py5
-rw-r--r--Lib/shutil.py19
-rw-r--r--Lib/site.py8
-rw-r--r--Lib/socketserver.py2
-rw-r--r--Lib/sqlite3/__main__.py73
-rw-r--r--Lib/sqlite3/_completer.py42
-rw-r--r--Lib/ssl.py2
-rw-r--r--Lib/string/__init__.py47
-rw-r--r--Lib/subprocess.py14
-rw-r--r--Lib/sysconfig/__init__.py15
-rw-r--r--Lib/tarfile.py230
-rw-r--r--Lib/tempfile.py5
-rw-r--r--Lib/test/.ruff.toml10
-rw-r--r--Lib/test/_code_definitions.py160
-rw-r--r--Lib/test/_test_embed_structseq.py2
-rw-r--r--Lib/test/_test_gc_fast_cycles.py48
-rw-r--r--Lib/test/_test_multiprocessing.py84
-rw-r--r--Lib/test/audit-tests.py28
-rw-r--r--Lib/test/datetimetester.py35
-rw-r--r--Lib/test/libregrtest/main.py23
-rw-r--r--Lib/test/libregrtest/setup.py11
-rw-r--r--Lib/test/libregrtest/single.py2
-rw-r--r--Lib/test/libregrtest/tsan.py2
-rw-r--r--Lib/test/libregrtest/utils.py42
-rw-r--r--Lib/test/lock_tests.py43
-rw-r--r--Lib/test/mapping_tests.py4
-rw-r--r--Lib/test/mp_preload_flush.py15
-rw-r--r--Lib/test/pickletester.py38
-rw-r--r--Lib/test/pythoninfo.py24
-rwxr-xr-xLib/test/re_tests.py2
-rw-r--r--Lib/test/subprocessdata/fd_status.py4
-rw-r--r--Lib/test/support/__init__.py121
-rw-r--r--Lib/test/support/channels.py (renamed from Lib/test/support/interpreters/channels.py)93
-rw-r--r--Lib/test/support/hashlib_helper.py221
-rw-r--r--Lib/test/support/import_helper.py2
-rw-r--r--Lib/test/support/strace_helper.py7
-rw-r--r--Lib/test/support/warnings_helper.py3
-rw-r--r--Lib/test/test__interpchannels.py8
-rw-r--r--Lib/test/test__interpreters.py37
-rw-r--r--Lib/test/test__osx_support.py4
-rw-r--r--Lib/test/test_abstract_numbers.py30
-rw-r--r--Lib/test/test_annotationlib.py381
-rw-r--r--Lib/test/test_argparse.py167
-rw-r--r--Lib/test/test_asdl_parser.py8
-rw-r--r--Lib/test/test_ast/test_ast.py450
-rw-r--r--Lib/test/test_asyncgen.py9
-rw-r--r--Lib/test/test_asyncio/test_eager_task_factory.py37
-rw-r--r--Lib/test/test_asyncio/test_futures.py58
-rw-r--r--Lib/test/test_asyncio/test_locks.py2
-rw-r--r--Lib/test/test_asyncio/test_selector_events.py16
-rw-r--r--Lib/test/test_asyncio/test_ssl.py10
-rw-r--r--Lib/test/test_asyncio/test_tasks.py73
-rw-r--r--Lib/test/test_asyncio/test_tools.py1706
-rw-r--r--Lib/test/test_audit.py10
-rw-r--r--Lib/test/test_base64.py10
-rw-r--r--Lib/test/test_baseexception.py8
-rw-r--r--Lib/test/test_binascii.py6
-rw-r--r--Lib/test/test_binop.py2
-rw-r--r--Lib/test/test_buffer.py4
-rw-r--r--Lib/test/test_bufio.py2
-rw-r--r--Lib/test/test_build_details.py16
-rw-r--r--Lib/test/test_builtin.py8
-rw-r--r--Lib/test/test_bytes.py10
-rw-r--r--Lib/test/test_bz2.py4
-rw-r--r--Lib/test/test_calendar.py5
-rw-r--r--Lib/test/test_call.py4
-rw-r--r--Lib/test/test_capi/test_bytearray.py6
-rw-r--r--Lib/test/test_capi/test_bytes.py8
-rw-r--r--Lib/test/test_capi/test_config.py3
-rw-r--r--Lib/test/test_capi/test_float.py6
-rw-r--r--Lib/test/test_capi/test_import.py2
-rw-r--r--Lib/test/test_capi/test_misc.py6
-rw-r--r--Lib/test/test_capi/test_object.py11
-rw-r--r--Lib/test/test_capi/test_opt.py547
-rw-r--r--Lib/test/test_capi/test_sys.py64
-rw-r--r--Lib/test/test_capi/test_type.py10
-rw-r--r--Lib/test/test_capi/test_unicode.py21
-rw-r--r--Lib/test/test_class.py1
-rw-r--r--Lib/test/test_clinic.py51
-rw-r--r--Lib/test/test_cmd.py32
-rw-r--r--Lib/test/test_cmd_line.py52
-rw-r--r--Lib/test/test_cmd_line_script.py11
-rw-r--r--Lib/test/test_code.py389
-rw-r--r--Lib/test/test_code_module.py2
-rw-r--r--Lib/test/test_codeccallbacks.py38
-rw-r--r--Lib/test/test_codecs.py63
-rw-r--r--Lib/test/test_codeop.py2
-rw-r--r--Lib/test/test_collections.py2
-rw-r--r--Lib/test/test_compileall.py2
-rw-r--r--Lib/test/test_compiler_assemble.py2
-rw-r--r--Lib/test/test_concurrent_futures/test_future.py57
-rw-r--r--Lib/test/test_concurrent_futures/test_init.py4
-rw-r--r--Lib/test/test_concurrent_futures/test_interpreter_pool.py265
-rw-r--r--Lib/test/test_concurrent_futures/test_shutdown.py58
-rw-r--r--Lib/test/test_configparser.py12
-rw-r--r--Lib/test/test_contextlib.py8
-rw-r--r--Lib/test/test_contextlib_async.py8
-rw-r--r--Lib/test/test_copy.py5
-rw-r--r--Lib/test/test_coroutines.py2
-rw-r--r--Lib/test/test_cprofile.py19
-rw-r--r--Lib/test/test_crossinterp.py711
-rw-r--r--Lib/test/test_csv.py55
-rw-r--r--Lib/test/test_ctypes/_support.py1
-rw-r--r--Lib/test/test_ctypes/test_aligned_structures.py1
-rw-r--r--Lib/test/test_ctypes/test_bitfields.py5
-rw-r--r--Lib/test/test_ctypes/test_byteswap.py3
-rw-r--r--Lib/test/test_ctypes/test_generated_structs.py13
-rw-r--r--Lib/test/test_ctypes/test_incomplete.py10
-rw-r--r--Lib/test/test_ctypes/test_parameters.py4
-rw-r--r--Lib/test/test_ctypes/test_pep3118.py3
-rw-r--r--Lib/test/test_ctypes/test_structunion.py18
-rw-r--r--Lib/test/test_ctypes/test_structures.py31
-rw-r--r--Lib/test/test_ctypes/test_unaligned_structures.py2
-rw-r--r--Lib/test/test_curses.py42
-rw-r--r--Lib/test/test_dataclasses/__init__.py78
-rw-r--r--Lib/test/test_dbm.py63
-rw-r--r--Lib/test/test_dbm_gnu.py27
-rw-r--r--Lib/test/test_dbm_sqlite3.py4
-rw-r--r--Lib/test/test_decimal.py3
-rw-r--r--Lib/test/test_deque.py2
-rw-r--r--Lib/test/test_descr.py49
-rw-r--r--Lib/test/test_dict.py61
-rw-r--r--Lib/test/test_difflib.py6
-rw-r--r--Lib/test/test_difflib_expect.html48
-rw-r--r--Lib/test/test_dis.py217
-rw-r--r--Lib/test/test_doctest/sample_doctest_errors.py46
-rw-r--r--Lib/test/test_doctest/test_doctest.py447
-rw-r--r--Lib/test/test_doctest/test_doctest_errors.txt14
-rw-r--r--Lib/test/test_doctest/test_doctest_skip.txt2
-rw-r--r--Lib/test/test_doctest/test_doctest_skip2.txt6
-rw-r--r--Lib/test/test_dynamicclassattribute.py4
-rw-r--r--Lib/test/test_email/test__header_value_parser.py78
-rw-r--r--Lib/test/test_email/test_email.py30
-rw-r--r--Lib/test/test_email/test_utils.py10
-rw-r--r--Lib/test/test_embed.py15
-rw-r--r--Lib/test/test_enum.py27
-rw-r--r--Lib/test/test_errno.py6
-rw-r--r--Lib/test/test_exception_group.py14
-rw-r--r--Lib/test/test_exceptions.py23
-rw-r--r--Lib/test/test_external_inspection.py664
-rw-r--r--Lib/test/test_faulthandler.py23
-rw-r--r--Lib/test/test_fcntl.py59
-rw-r--r--Lib/test/test_fileinput.py2
-rw-r--r--Lib/test/test_fileio.py12
-rw-r--r--Lib/test/test_float.py2
-rw-r--r--Lib/test/test_fnmatch.py98
-rw-r--r--Lib/test/test_format.py10
-rw-r--r--Lib/test/test_fractions.py216
-rw-r--r--Lib/test/test_free_threading/test_dict.py16
-rw-r--r--Lib/test/test_free_threading/test_functools.py75
-rw-r--r--Lib/test/test_free_threading/test_generators.py51
-rw-r--r--Lib/test/test_free_threading/test_heapq.py267
-rw-r--r--Lib/test/test_free_threading/test_io.py109
-rw-r--r--Lib/test/test_free_threading/test_itertools.py95
-rw-r--r--Lib/test/test_free_threading/test_itertools_batched.py38
-rw-r--r--Lib/test/test_free_threading/test_itertools_combinatoric.py51
-rw-r--r--Lib/test/test_fstring.py10
-rw-r--r--Lib/test/test_functools.py28
-rw-r--r--Lib/test/test_future_stmt/test_future.py5
-rw-r--r--Lib/test/test_gc.py66
-rw-r--r--Lib/test/test_generated_cases.py433
-rw-r--r--Lib/test/test_genericalias.py18
-rw-r--r--Lib/test/test_genericpath.py6
-rw-r--r--Lib/test/test_getpass.py39
-rw-r--r--Lib/test/test_gettext.py7
-rw-r--r--Lib/test/test_glob.py80
-rw-r--r--Lib/test/test_grammar.py14
-rw-r--r--Lib/test/test_gzip.py9
-rw-r--r--Lib/test/test_hashlib.py253
-rw-r--r--Lib/test/test_heapq.py197
-rw-r--r--Lib/test/test_hmac.py153
-rw-r--r--Lib/test/test_htmlparser.py213
-rw-r--r--Lib/test/test_http_cookiejar.py187
-rw-r--r--Lib/test/test_httpservers.py738
-rw-r--r--Lib/test/test_idle.py2
-rw-r--r--Lib/test/test_import/__init__.py6
-rw-r--r--Lib/test/test_importlib/import_/test_relative_imports.py15
-rw-r--r--Lib/test/test_importlib/test_locks.py1
-rw-r--r--Lib/test/test_importlib/test_threaded_import.py15
-rw-r--r--Lib/test/test_inspect/test_inspect.py49
-rw-r--r--Lib/test/test_int.py2
-rw-r--r--Lib/test/test_interpreters/test_api.py884
-rw-r--r--Lib/test/test_interpreters/test_channels.py56
-rw-r--r--Lib/test/test_interpreters/test_lifecycle.py4
-rw-r--r--Lib/test/test_interpreters/test_queues.py271
-rw-r--r--Lib/test/test_interpreters/test_stress.py32
-rw-r--r--Lib/test/test_interpreters/utils.py3
-rw-r--r--Lib/test/test_io.py103
-rw-r--r--Lib/test/test_ioctl.py16
-rw-r--r--Lib/test/test_ipaddress.py54
-rw-r--r--Lib/test/test_isinstance.py2
-rw-r--r--Lib/test/test_iter.py2
-rw-r--r--Lib/test/test_json/test_dump.py8
-rw-r--r--Lib/test/test_json/test_fail.py2
-rw-r--r--Lib/test/test_json/test_recursion.py3
-rw-r--r--Lib/test/test_json/test_tool.py87
-rw-r--r--Lib/test/test_launcher.py8
-rw-r--r--Lib/test/test_linecache.py37
-rw-r--r--Lib/test/test_list.py15
-rw-r--r--Lib/test/test_listcomps.py2
-rw-r--r--Lib/test/test_locale.py13
-rw-r--r--Lib/test/test_logging.py27
-rw-r--r--Lib/test/test_lzma.py4
-rw-r--r--Lib/test/test_math.py43
-rw-r--r--Lib/test/test_memoryio.py4
-rw-r--r--Lib/test/test_memoryview.py20
-rw-r--r--Lib/test/test_mimetypes.py8
-rw-r--r--Lib/test/test_minidom.py182
-rw-r--r--Lib/test/test_monitoring.py15
-rw-r--r--Lib/test/test_netrc.py13
-rw-r--r--Lib/test/test_ntpath.py360
-rw-r--r--Lib/test/test_opcache.py14
-rw-r--r--Lib/test/test_optparse.py11
-rw-r--r--Lib/test/test_ordered_dict.py8
-rw-r--r--Lib/test/test_os.py28
-rw-r--r--Lib/test/test_pathlib/support/lexical_path.py11
-rw-r--r--Lib/test/test_pathlib/support/local_path.py10
-rw-r--r--Lib/test/test_pathlib/support/zip_path.py45
-rw-r--r--Lib/test/test_pathlib/test_join_windows.py17
-rw-r--r--Lib/test/test_pathlib/test_pathlib.py36
-rw-r--r--Lib/test/test_pdb.py90
-rw-r--r--Lib/test/test_peepholer.py126
-rw-r--r--Lib/test/test_peg_generator/test_c_parser.py4
-rw-r--r--Lib/test/test_peg_generator/test_pegen.py6
-rw-r--r--Lib/test/test_perf_profiler.py9
-rw-r--r--Lib/test/test_pickle.py14
-rw-r--r--Lib/test/test_platform.py93
-rw-r--r--Lib/test/test_positional_only_arg.py12
-rw-r--r--Lib/test/test_posix.py57
-rw-r--r--Lib/test/test_posixpath.py388
-rw-r--r--Lib/test/test_pprint.py355
-rw-r--r--Lib/test/test_property.py4
-rw-r--r--Lib/test/test_pstats.py7
-rw-r--r--Lib/test/test_pty.py1
-rw-r--r--Lib/test/test_pulldom.py4
-rw-r--r--Lib/test/test_pyclbr.py2
-rw-r--r--Lib/test/test_pydoc/test_pydoc.py4
-rw-r--r--Lib/test/test_pyrepl/support.py3
-rw-r--r--Lib/test/test_pyrepl/test_eventqueue.py78
-rw-r--r--Lib/test/test_pyrepl/test_interact.py2
-rw-r--r--Lib/test/test_pyrepl/test_pyrepl.py212
-rw-r--r--Lib/test/test_pyrepl/test_reader.py90
-rw-r--r--Lib/test/test_pyrepl/test_unix_console.py13
-rw-r--r--Lib/test/test_pyrepl/test_windows_console.py227
-rw-r--r--Lib/test/test_queue.py20
-rw-r--r--Lib/test/test_random.py340
-rw-r--r--Lib/test/test_re.py19
-rw-r--r--Lib/test/test_readline.py8
-rw-r--r--Lib/test/test_regrtest.py65
-rw-r--r--Lib/test/test_remote_pdb.py457
-rw-r--r--Lib/test/test_repl.py6
-rw-r--r--Lib/test/test_reprlib.py97
-rw-r--r--Lib/test/test_rlcompleter.py27
-rw-r--r--Lib/test/test_runpy.py2
-rw-r--r--Lib/test/test_scope.py2
-rw-r--r--Lib/test/test_script_helper.py3
-rw-r--r--Lib/test/test_set.py6
-rw-r--r--Lib/test/test_shlex.py2
-rw-r--r--Lib/test/test_shutil.py12
-rw-r--r--Lib/test/test_site.py22
-rw-r--r--Lib/test/test_socket.py31
-rw-r--r--Lib/test/test_source_encoding.py3
-rw-r--r--Lib/test/test_sqlite3/test_cli.py160
-rw-r--r--Lib/test/test_sqlite3/test_dbapi.py14
-rw-r--r--Lib/test/test_sqlite3/test_factory.py15
-rw-r--r--Lib/test/test_sqlite3/test_hooks.py22
-rw-r--r--Lib/test/test_sqlite3/test_userfunctions.py55
-rw-r--r--Lib/test/test_ssl.py54
-rw-r--r--Lib/test/test_stable_abi_ctypes.py4
-rw-r--r--Lib/test/test_stat.py6
-rw-r--r--Lib/test/test_statistics.py25
-rw-r--r--Lib/test/test_str.py8
-rw-r--r--Lib/test/test_strftime.py16
-rw-r--r--Lib/test/test_string/_support.py1
-rw-r--r--Lib/test/test_string/test_string.py8
-rw-r--r--Lib/test/test_string/test_templatelib.py7
-rw-r--r--Lib/test/test_strptime.py37
-rw-r--r--Lib/test/test_strtod.py2
-rw-r--r--Lib/test/test_structseq.py4
-rw-r--r--Lib/test/test_subprocess.py22
-rw-r--r--Lib/test/test_super.py10
-rw-r--r--Lib/test/test_support.py5
-rw-r--r--Lib/test/test_syntax.py109
-rw-r--r--Lib/test/test_sys.py192
-rw-r--r--Lib/test/test_sysconfig.py20
-rw-r--r--Lib/test/test_tarfile.py405
-rw-r--r--Lib/test/test_tempfile.py16
-rw-r--r--Lib/test/test_termios.py4
-rw-r--r--Lib/test/test_threadedtempfile.py4
-rw-r--r--Lib/test/test_threading.py57
-rw-r--r--Lib/test/test_time.py12
-rw-r--r--Lib/test/test_timeit.py4
-rw-r--r--Lib/test/test_tkinter/support.py2
-rw-r--r--Lib/test/test_tkinter/test_misc.py24
-rw-r--r--Lib/test/test_tkinter/widget_tests.py2
-rw-r--r--Lib/test/test_tokenize.py78
-rw-r--r--Lib/test/test_tools/i18n_data/docstrings.py2
-rw-r--r--Lib/test/test_traceback.py143
-rw-r--r--Lib/test/test_tstring.py1
-rw-r--r--Lib/test/test_ttk/test_widgets.py6
-rw-r--r--Lib/test/test_type_annotations.py41
-rw-r--r--Lib/test/test_type_comments.py2
-rw-r--r--Lib/test/test_types.py40
-rw-r--r--Lib/test/test_typing.py279
-rw-r--r--Lib/test/test_unittest/test_case.py8
-rw-r--r--Lib/test/test_unittest/test_result.py37
-rw-r--r--Lib/test/test_unittest/test_runner.py52
-rw-r--r--Lib/test/test_unittest/testmock/testhelpers.py21
-rw-r--r--Lib/test/test_unparse.py9
-rw-r--r--Lib/test/test_urllib.py19
-rw-r--r--Lib/test/test_urlparse.py398
-rw-r--r--Lib/test/test_userdict.py2
-rwxr-xr-xLib/test/test_uuid.py28
-rw-r--r--Lib/test/test_venv.py6
-rw-r--r--Lib/test/test_warnings/__init__.py10
-rw-r--r--Lib/test/test_wave.py26
-rw-r--r--Lib/test/test_weakref.py6
-rw-r--r--Lib/test/test_weakset.py4
-rw-r--r--Lib/test/test_webbrowser.py1
-rw-r--r--Lib/test/test_winconsoleio.py6
-rw-r--r--Lib/test/test_with.py2
-rw-r--r--Lib/test/test_wmi.py4
-rw-r--r--Lib/test/test_wsgiref.py14
-rw-r--r--Lib/test/test_xml_etree.py49
-rw-r--r--Lib/test/test_xxlimited.py2
-rw-r--r--Lib/test/test_zipapp.py6
-rw-r--r--Lib/test/test_zipfile/__main__.py2
-rw-r--r--Lib/test/test_zipfile/_path/_test_params.py2
-rw-r--r--Lib/test/test_zipfile/_path/test_complexity.py2
-rw-r--r--Lib/test/test_zipfile/_path/test_path.py22
-rw-r--r--Lib/test/test_zipfile/_path/write-alpharep.py1
-rw-r--r--Lib/test/test_zipfile/test_core.py64
-rw-r--r--Lib/test/test_zipimport.py4
-rw-r--r--Lib/test/test_zlib.py108
-rw-r--r--Lib/test/test_zoneinfo/test_zoneinfo.py5
-rw-r--r--Lib/test/test_zstd.py2794
-rw-r--r--Lib/textwrap.py4
-rw-r--r--Lib/threading.py19
-rw-r--r--Lib/tokenize.py20
-rw-r--r--Lib/tomllib/_parser.py10
-rw-r--r--Lib/tomllib/_re.py6
-rw-r--r--Lib/tomllib/mypy.ini2
-rw-r--r--Lib/trace.py2
-rw-r--r--Lib/traceback.py118
-rw-r--r--Lib/types.py2
-rw-r--r--Lib/typing.py144
-rw-r--r--Lib/unittest/case.py4
-rw-r--r--Lib/unittest/main.py4
-rw-r--r--Lib/unittest/mock.py15
-rw-r--r--Lib/unittest/runner.py89
-rw-r--r--Lib/unittest/suite.py20
-rw-r--r--Lib/urllib/parse.py2
-rw-r--r--Lib/urllib/request.py19
-rw-r--r--Lib/uuid.py24
-rw-r--r--Lib/venv/__init__.py11
-rw-r--r--Lib/wave.py29
-rw-r--r--Lib/webbrowser.py4
-rw-r--r--Lib/wsgiref/handlers.py3
-rw-r--r--Lib/xmlrpc/server.py2
-rw-r--r--Lib/zipapp.py2
-rw-r--r--Lib/zipfile/__init__.py80
-rw-r--r--Lib/zipfile/_path/__init__.py38
-rw-r--r--Lib/zipfile/_path/_functools.py20
-rw-r--r--Lib/zipfile/_path/glob.py5
-rw-r--r--Lib/zoneinfo/_common.py8
-rw-r--r--Lib/zoneinfo/_tzpath.py5
-rw-r--r--Lib/zoneinfo/_zoneinfo.py6
485 files changed, 23618 insertions, 6417 deletions
diff --git a/Lib/_ast_unparse.py b/Lib/_ast_unparse.py
index 0b669edb2ff..c25066eb107 100644
--- a/Lib/_ast_unparse.py
+++ b/Lib/_ast_unparse.py
@@ -627,6 +627,9 @@ class Unparser(NodeVisitor):
self._ftstring_helper(fstring_parts)
def _tstring_helper(self, node):
+ if not node.values:
+ self._write_ftstring([], "t")
+ return
last_idx = 0
for i, value in enumerate(node.values):
# This can happen if we have an implicit concat of a t-string
@@ -679,9 +682,12 @@ class Unparser(NodeVisitor):
unparser.set_precedence(_Precedence.TEST.next(), inner)
return unparser.visit(inner)
- def _write_interpolation(self, node):
+ def _write_interpolation(self, node, is_interpolation=False):
with self.delimit("{", "}"):
- expr = self._unparse_interpolation_value(node.value)
+ if is_interpolation:
+ expr = node.str
+ else:
+ expr = self._unparse_interpolation_value(node.value)
if expr.startswith("{"):
# Separate pair of opening brackets as "{ {"
self.write(" ")
@@ -696,7 +702,7 @@ class Unparser(NodeVisitor):
self._write_interpolation(node)
def visit_Interpolation(self, node):
- self._write_interpolation(node)
+ self._write_interpolation(node, is_interpolation=True)
def visit_Name(self, node):
self.write(node.id)
diff --git a/Lib/_colorize.py b/Lib/_colorize.py
index 54895488e74..4a310a40235 100644
--- a/Lib/_colorize.py
+++ b/Lib/_colorize.py
@@ -1,28 +1,17 @@
-from __future__ import annotations
import io
import os
import sys
+from collections.abc import Callable, Iterator, Mapping
+from dataclasses import dataclass, field, Field
+
COLORIZE = True
+
# types
if False:
- from typing import IO, Literal
-
- type ColorTag = Literal[
- "PROMPT",
- "KEYWORD",
- "BUILTIN",
- "COMMENT",
- "STRING",
- "NUMBER",
- "OP",
- "DEFINITION",
- "SOFT_KEYWORD",
- "RESET",
- ]
-
- theme: dict[ColorTag, str]
+ from typing import IO, Self, ClassVar
+ _theme: Theme
class ANSIColors:
@@ -86,6 +75,186 @@ for attr, code in ANSIColors.__dict__.items():
setattr(NoColors, attr, "")
+#
+# Experimental theming support (see gh-133346)
+#
+
+# - Create a theme by copying an existing `Theme` with one or more sections
+# replaced, using `default_theme.copy_with()`;
+# - create a theme section by copying an existing `ThemeSection` with one or
+# more colors replaced, using for example `default_theme.syntax.copy_with()`;
+# - create a theme from scratch by instantiating a `Theme` data class with
+# the required sections (which are also dataclass instances).
+#
+# Then call `_colorize.set_theme(your_theme)` to set it.
+#
+# Put your theme configuration in $PYTHONSTARTUP for the interactive shell,
+# or sitecustomize.py in your virtual environment or Python installation for
+# other uses. Your applications can call `_colorize.set_theme()` too.
+#
+# Note that thanks to the dataclasses providing default values for all fields,
+# creating a new theme or theme section from scratch is possible without
+# specifying all keys.
+#
+# For example, here's a theme that makes punctuation and operators less prominent:
+#
+# try:
+# from _colorize import set_theme, default_theme, Syntax, ANSIColors
+# except ImportError:
+# pass
+# else:
+# theme_with_dim_operators = default_theme.copy_with(
+# syntax=Syntax(op=ANSIColors.INTENSE_BLACK),
+# )
+# set_theme(theme_with_dim_operators)
+# del set_theme, default_theme, Syntax, ANSIColors, theme_with_dim_operators
+#
+# Guarding the import ensures that your .pythonstartup file will still work in
+# Python 3.13 and older. Deleting the variables ensures they don't remain in your
+# interactive shell's global scope.
+
+class ThemeSection(Mapping[str, str]):
+ """A mixin/base class for theme sections.
+
+ It enables dictionary access to a section, as well as implements convenience
+ methods.
+ """
+
+ # The two types below are just that: types to inform the type checker that the
+ # mixin will work in context of those fields existing
+ __dataclass_fields__: ClassVar[dict[str, Field[str]]]
+ _name_to_value: Callable[[str], str]
+
+ def __post_init__(self) -> None:
+ name_to_value = {}
+ for color_name in self.__dataclass_fields__:
+ name_to_value[color_name] = getattr(self, color_name)
+ super().__setattr__('_name_to_value', name_to_value.__getitem__)
+
+ def copy_with(self, **kwargs: str) -> Self:
+ color_state: dict[str, str] = {}
+ for color_name in self.__dataclass_fields__:
+ color_state[color_name] = getattr(self, color_name)
+ color_state.update(kwargs)
+ return type(self)(**color_state)
+
+ @classmethod
+ def no_colors(cls) -> Self:
+ color_state: dict[str, str] = {}
+ for color_name in cls.__dataclass_fields__:
+ color_state[color_name] = ""
+ return cls(**color_state)
+
+ def __getitem__(self, key: str) -> str:
+ return self._name_to_value(key)
+
+ def __len__(self) -> int:
+ return len(self.__dataclass_fields__)
+
+ def __iter__(self) -> Iterator[str]:
+ return iter(self.__dataclass_fields__)
+
+
+@dataclass(frozen=True)
+class Argparse(ThemeSection):
+ usage: str = ANSIColors.BOLD_BLUE
+ prog: str = ANSIColors.BOLD_MAGENTA
+ prog_extra: str = ANSIColors.MAGENTA
+ heading: str = ANSIColors.BOLD_BLUE
+ summary_long_option: str = ANSIColors.CYAN
+ summary_short_option: str = ANSIColors.GREEN
+ summary_label: str = ANSIColors.YELLOW
+ summary_action: str = ANSIColors.GREEN
+ long_option: str = ANSIColors.BOLD_CYAN
+ short_option: str = ANSIColors.BOLD_GREEN
+ label: str = ANSIColors.BOLD_YELLOW
+ action: str = ANSIColors.BOLD_GREEN
+ reset: str = ANSIColors.RESET
+
+
+@dataclass(frozen=True)
+class Syntax(ThemeSection):
+ prompt: str = ANSIColors.BOLD_MAGENTA
+ keyword: str = ANSIColors.BOLD_BLUE
+ builtin: str = ANSIColors.CYAN
+ comment: str = ANSIColors.RED
+ string: str = ANSIColors.GREEN
+ number: str = ANSIColors.YELLOW
+ op: str = ANSIColors.RESET
+ definition: str = ANSIColors.BOLD
+ soft_keyword: str = ANSIColors.BOLD_BLUE
+ reset: str = ANSIColors.RESET
+
+
+@dataclass(frozen=True)
+class Traceback(ThemeSection):
+ type: str = ANSIColors.BOLD_MAGENTA
+ message: str = ANSIColors.MAGENTA
+ filename: str = ANSIColors.MAGENTA
+ line_no: str = ANSIColors.MAGENTA
+ frame: str = ANSIColors.MAGENTA
+ error_highlight: str = ANSIColors.BOLD_RED
+ error_range: str = ANSIColors.RED
+ reset: str = ANSIColors.RESET
+
+
+@dataclass(frozen=True)
+class Unittest(ThemeSection):
+ passed: str = ANSIColors.GREEN
+ warn: str = ANSIColors.YELLOW
+ fail: str = ANSIColors.RED
+ fail_info: str = ANSIColors.BOLD_RED
+ reset: str = ANSIColors.RESET
+
+
+@dataclass(frozen=True)
+class Theme:
+ """A suite of themes for all sections of Python.
+
+ When adding a new one, remember to also modify `copy_with` and `no_colors`
+ below.
+ """
+ argparse: Argparse = field(default_factory=Argparse)
+ syntax: Syntax = field(default_factory=Syntax)
+ traceback: Traceback = field(default_factory=Traceback)
+ unittest: Unittest = field(default_factory=Unittest)
+
+ def copy_with(
+ self,
+ *,
+ argparse: Argparse | None = None,
+ syntax: Syntax | None = None,
+ traceback: Traceback | None = None,
+ unittest: Unittest | None = None,
+ ) -> Self:
+ """Return a new Theme based on this instance with some sections replaced.
+
+ Themes are immutable to protect against accidental modifications that
+ could lead to invalid terminal states.
+ """
+ return type(self)(
+ argparse=argparse or self.argparse,
+ syntax=syntax or self.syntax,
+ traceback=traceback or self.traceback,
+ unittest=unittest or self.unittest,
+ )
+
+ @classmethod
+ def no_colors(cls) -> Self:
+ """Return a new Theme where colors in all sections are empty strings.
+
+ This allows writing user code as if colors are always used. The color
+ fields will be ANSI color code strings when colorization is desired
+ and possible, and empty strings otherwise.
+ """
+ return cls(
+ argparse=Argparse.no_colors(),
+ syntax=Syntax.no_colors(),
+ traceback=Traceback.no_colors(),
+ unittest=Unittest.no_colors(),
+ )
+
+
def get_colors(
colorize: bool = False, *, file: IO[str] | IO[bytes] | None = None
) -> ANSIColors:
@@ -138,26 +307,40 @@ def can_colorize(*, file: IO[str] | IO[bytes] | None = None) -> bool:
return hasattr(file, "isatty") and file.isatty()
-def set_theme(t: dict[ColorTag, str] | None = None) -> None:
- global theme
+default_theme = Theme()
+theme_no_color = default_theme.no_colors()
+
+
+def get_theme(
+ *,
+ tty_file: IO[str] | IO[bytes] | None = None,
+ force_color: bool = False,
+ force_no_color: bool = False,
+) -> Theme:
+ """Returns the currently set theme, potentially in a zero-color variant.
+
+ In cases where colorizing is not possible (see `can_colorize`), the returned
+ theme contains all empty strings in all color definitions.
+ See `Theme.no_colors()` for more information.
+
+ It is recommended not to cache the result of this function for extended
+ periods of time because the user might influence theme selection by
+ the interactive shell, a debugger, or application-specific code. The
+ environment (including environment variable state and console configuration
+ on Windows) can also change in the course of the application life cycle.
+ """
+ if force_color or (not force_no_color and can_colorize(file=tty_file)):
+ return _theme
+ return theme_no_color
+
+
+def set_theme(t: Theme) -> None:
+ global _theme
- if t:
- theme = t
- return
+ if not isinstance(t, Theme):
+ raise ValueError(f"Expected Theme object, found {t}")
- colors = get_colors()
- theme = {
- "PROMPT": colors.BOLD_MAGENTA,
- "KEYWORD": colors.BOLD_BLUE,
- "BUILTIN": colors.CYAN,
- "COMMENT": colors.RED,
- "STRING": colors.GREEN,
- "NUMBER": colors.YELLOW,
- "OP": colors.RESET,
- "DEFINITION": colors.BOLD,
- "SOFT_KEYWORD": colors.BOLD_BLUE,
- "RESET": colors.RESET,
- }
+ _theme = t
-set_theme()
+set_theme(default_theme)
diff --git a/Lib/_compat_pickle.py b/Lib/_compat_pickle.py
index 439f8c02f4b..a9813264324 100644
--- a/Lib/_compat_pickle.py
+++ b/Lib/_compat_pickle.py
@@ -175,7 +175,6 @@ IMPORT_MAPPING.update({
'SimpleDialog': 'tkinter.simpledialog',
'DocXMLRPCServer': 'xmlrpc.server',
'SimpleHTTPServer': 'http.server',
- 'CGIHTTPServer': 'http.server',
# For compatibility with broken pickles saved in old Python 3 versions
'UserDict': 'collections',
'UserList': 'collections',
@@ -217,8 +216,6 @@ REVERSE_NAME_MAPPING.update({
('DocXMLRPCServer', 'DocCGIXMLRPCRequestHandler'),
('http.server', 'SimpleHTTPRequestHandler'):
('SimpleHTTPServer', 'SimpleHTTPRequestHandler'),
- ('http.server', 'CGIHTTPRequestHandler'):
- ('CGIHTTPServer', 'CGIHTTPRequestHandler'),
('_socket', 'socket'): ('socket', '_socketobject'),
})
diff --git a/Lib/_opcode_metadata.py b/Lib/_opcode_metadata.py
index b9304ec3c03..f168d169a32 100644
--- a/Lib/_opcode_metadata.py
+++ b/Lib/_opcode_metadata.py
@@ -6,10 +6,6 @@ _specializations = {
"RESUME": [
"RESUME_CHECK",
],
- "LOAD_CONST": [
- "LOAD_CONST_MORTAL",
- "LOAD_CONST_IMMORTAL",
- ],
"TO_BOOL": [
"TO_BOOL_ALWAYS_TRUE",
"TO_BOOL_BOOL",
@@ -186,28 +182,26 @@ _specialized_opmap = {
'LOAD_ATTR_PROPERTY': 187,
'LOAD_ATTR_SLOT': 188,
'LOAD_ATTR_WITH_HINT': 189,
- 'LOAD_CONST_IMMORTAL': 190,
- 'LOAD_CONST_MORTAL': 191,
- 'LOAD_GLOBAL_BUILTIN': 192,
- 'LOAD_GLOBAL_MODULE': 193,
- 'LOAD_SUPER_ATTR_ATTR': 194,
- 'LOAD_SUPER_ATTR_METHOD': 195,
- 'RESUME_CHECK': 196,
- 'SEND_GEN': 197,
- 'STORE_ATTR_INSTANCE_VALUE': 198,
- 'STORE_ATTR_SLOT': 199,
- 'STORE_ATTR_WITH_HINT': 200,
- 'STORE_SUBSCR_DICT': 201,
- 'STORE_SUBSCR_LIST_INT': 202,
- 'TO_BOOL_ALWAYS_TRUE': 203,
- 'TO_BOOL_BOOL': 204,
- 'TO_BOOL_INT': 205,
- 'TO_BOOL_LIST': 206,
- 'TO_BOOL_NONE': 207,
- 'TO_BOOL_STR': 208,
- 'UNPACK_SEQUENCE_LIST': 209,
- 'UNPACK_SEQUENCE_TUPLE': 210,
- 'UNPACK_SEQUENCE_TWO_TUPLE': 211,
+ 'LOAD_GLOBAL_BUILTIN': 190,
+ 'LOAD_GLOBAL_MODULE': 191,
+ 'LOAD_SUPER_ATTR_ATTR': 192,
+ 'LOAD_SUPER_ATTR_METHOD': 193,
+ 'RESUME_CHECK': 194,
+ 'SEND_GEN': 195,
+ 'STORE_ATTR_INSTANCE_VALUE': 196,
+ 'STORE_ATTR_SLOT': 197,
+ 'STORE_ATTR_WITH_HINT': 198,
+ 'STORE_SUBSCR_DICT': 199,
+ 'STORE_SUBSCR_LIST_INT': 200,
+ 'TO_BOOL_ALWAYS_TRUE': 201,
+ 'TO_BOOL_BOOL': 202,
+ 'TO_BOOL_INT': 203,
+ 'TO_BOOL_LIST': 204,
+ 'TO_BOOL_NONE': 205,
+ 'TO_BOOL_STR': 206,
+ 'UNPACK_SEQUENCE_LIST': 207,
+ 'UNPACK_SEQUENCE_TUPLE': 208,
+ 'UNPACK_SEQUENCE_TWO_TUPLE': 209,
}
opmap = {
diff --git a/Lib/_py_warnings.py b/Lib/_py_warnings.py
index 3cdc6ffe198..cbaa9445862 100644
--- a/Lib/_py_warnings.py
+++ b/Lib/_py_warnings.py
@@ -371,7 +371,7 @@ def _setoption(arg):
if message:
message = re.escape(message)
if module:
- module = re.escape(module) + r'\Z'
+ module = re.escape(module) + r'\z'
if lineno:
try:
lineno = int(lineno)
diff --git a/Lib/_pydatetime.py b/Lib/_pydatetime.py
index e3db1b52b11..bc35823f701 100644
--- a/Lib/_pydatetime.py
+++ b/Lib/_pydatetime.py
@@ -467,6 +467,7 @@ def _parse_isoformat_time(tstr):
hour, minute, second, microsecond = time_comps
became_next_day = False
error_from_components = False
+ error_from_tz = None
if (hour == 24):
if all(time_comp == 0 for time_comp in time_comps[1:]):
hour = 0
@@ -500,14 +501,22 @@ def _parse_isoformat_time(tstr):
else:
tzsign = -1 if tstr[tz_pos - 1] == '-' else 1
- td = timedelta(hours=tz_comps[0], minutes=tz_comps[1],
- seconds=tz_comps[2], microseconds=tz_comps[3])
-
- tzi = timezone(tzsign * td)
+ try:
+ # This function is intended to validate datetimes, but because
+ # we restrict time zones to ±24h, it serves here as well.
+ _check_time_fields(hour=tz_comps[0], minute=tz_comps[1],
+ second=tz_comps[2], microsecond=tz_comps[3],
+ fold=0)
+ except ValueError as e:
+ error_from_tz = e
+ else:
+ td = timedelta(hours=tz_comps[0], minutes=tz_comps[1],
+ seconds=tz_comps[2], microseconds=tz_comps[3])
+ tzi = timezone(tzsign * td)
time_comps.append(tzi)
- return time_comps, became_next_day, error_from_components
+ return time_comps, became_next_day, error_from_components, error_from_tz
# tuple[int, int, int] -> tuple[int, int, int] version of date.fromisocalendar
def _isoweek_to_gregorian(year, week, day):
@@ -1127,8 +1136,8 @@ class date:
This is 'YYYY-MM-DD'.
References:
- - http://www.w3.org/TR/NOTE-datetime
- - http://www.cl.cam.ac.uk/~mgk25/iso-time.html
+ - https://www.w3.org/TR/NOTE-datetime
+ - https://www.cl.cam.ac.uk/~mgk25/iso-time.html
"""
return "%04d-%02d-%02d" % (self._year, self._month, self._day)
@@ -1262,7 +1271,7 @@ class date:
The first week is 1; Monday is 1 ... Sunday is 7.
ISO calendar algorithm taken from
- http://www.phys.uu.nl/~vgent/calendar/isocalendar.htm
+ https://www.phys.uu.nl/~vgent/calendar/isocalendar.htm
(used with permission)
"""
year = self._year
@@ -1633,9 +1642,21 @@ class time:
time_string = time_string.removeprefix('T')
try:
- return cls(*_parse_isoformat_time(time_string)[0])
- except Exception:
- raise ValueError(f'Invalid isoformat string: {time_string!r}')
+ time_components, _, error_from_components, error_from_tz = (
+ _parse_isoformat_time(time_string)
+ )
+ except ValueError:
+ raise ValueError(
+ f'Invalid isoformat string: {time_string!r}') from None
+ else:
+ if error_from_tz:
+ raise error_from_tz
+ if error_from_components:
+ raise ValueError(
+ "Minute, second, and microsecond must be 0 when hour is 24"
+ )
+
+ return cls(*time_components)
def strftime(self, format):
"""Format using strftime(). The date part of the timestamp passed
@@ -1947,11 +1968,16 @@ class datetime(date):
if tstr:
try:
- time_components, became_next_day, error_from_components = _parse_isoformat_time(tstr)
+ (time_components,
+ became_next_day,
+ error_from_components,
+ error_from_tz) = _parse_isoformat_time(tstr)
except ValueError:
raise ValueError(
f'Invalid isoformat string: {date_string!r}') from None
else:
+ if error_from_tz:
+ raise error_from_tz
if error_from_components:
raise ValueError("minute, second, and microsecond must be 0 when hour is 24")
@@ -2089,7 +2115,6 @@ class datetime(date):
else:
ts = (self - _EPOCH) // timedelta(seconds=1)
localtm = _time.localtime(ts)
- local = datetime(*localtm[:6])
# Extract TZ data
gmtoff = localtm.tm_gmtoff
zone = localtm.tm_zone
@@ -2139,7 +2164,7 @@ class datetime(date):
By default, the fractional part is omitted if self.microsecond == 0.
If self.tzinfo is not None, the UTC offset is also attached, giving
- giving a full format of 'YYYY-MM-DD HH:MM:SS.mmmmmm+HH:MM'.
+ a full format of 'YYYY-MM-DD HH:MM:SS.mmmmmm+HH:MM'.
Optional argument sep specifies the separator between date and
time, default 'T'.
diff --git a/Lib/_pydecimal.py b/Lib/_pydecimal.py
index 4b09207eca6..781b38ec26b 100644
--- a/Lib/_pydecimal.py
+++ b/Lib/_pydecimal.py
@@ -6096,7 +6096,7 @@ _parser = re.compile(r""" # A numeric string consists of:
(?P<diag>\d*) # with (possibly empty) diagnostic info.
)
# \s*
- \Z
+ \z
""", re.VERBOSE | re.IGNORECASE).match
_all_zeros = re.compile('0*$').match
@@ -6120,11 +6120,11 @@ _parse_format_specifier_regex = re.compile(r"""\A
(?P<no_neg_0>z)?
(?P<alt>\#)?
(?P<zeropad>0)?
-(?P<minimumwidth>(?!0)\d+)?
+(?P<minimumwidth>\d+)?
(?P<thousands_sep>[,_])?
-(?:\.(?P<precision>0|(?!0)\d+))?
+(?:\.(?P<precision>\d+))?
(?P<type>[eEfFgGn%])?
-\Z
+\z
""", re.VERBOSE|re.DOTALL)
del re
diff --git a/Lib/_pyio.py b/Lib/_pyio.py
index a870de5b532..fb2a6d049ca 100644
--- a/Lib/_pyio.py
+++ b/Lib/_pyio.py
@@ -407,6 +407,9 @@ class IOBase(metaclass=abc.ABCMeta):
if closed:
return
+ if dealloc_warn := getattr(self, "_dealloc_warn", None):
+ dealloc_warn(self)
+
# If close() fails, the caller logs the exception with
# sys.unraisablehook. close() must be called at the end at __del__().
self.close()
@@ -645,8 +648,6 @@ class RawIOBase(IOBase):
self._unsupported("write")
io.RawIOBase.register(RawIOBase)
-from _io import FileIO
-RawIOBase.register(FileIO)
class BufferedIOBase(IOBase):
@@ -853,6 +854,10 @@ class _BufferedIOMixin(BufferedIOBase):
else:
return "<{}.{} name={!r}>".format(modname, clsname, name)
+ def _dealloc_warn(self, source):
+ if dealloc_warn := getattr(self.raw, "_dealloc_warn", None):
+ dealloc_warn(source)
+
### Lower-level APIs ###
def fileno(self):
@@ -1563,7 +1568,8 @@ class FileIO(RawIOBase):
if not isinstance(fd, int):
raise TypeError('expected integer from opener')
if fd < 0:
- raise OSError('Negative file descriptor')
+ # bpo-27066: Raise a ValueError for bad value.
+ raise ValueError(f'opener returned {fd}')
owned_fd = fd
if not noinherit_flag:
os.set_inheritable(fd, False)
@@ -1600,12 +1606,11 @@ class FileIO(RawIOBase):
raise
self._fd = fd
- def __del__(self):
+ def _dealloc_warn(self, source):
if self._fd >= 0 and self._closefd and not self.closed:
import warnings
- warnings.warn('unclosed file %r' % (self,), ResourceWarning,
+ warnings.warn(f'unclosed file {source!r}', ResourceWarning,
stacklevel=2, source=self)
- self.close()
def __getstate__(self):
raise TypeError(f"cannot pickle {self.__class__.__name__!r} object")
@@ -1780,7 +1785,7 @@ class FileIO(RawIOBase):
if not self.closed:
self._stat_atopen = None
try:
- if self._closefd:
+ if self._closefd and self._fd >= 0:
os.close(self._fd)
finally:
super().close()
@@ -2689,6 +2694,10 @@ class TextIOWrapper(TextIOBase):
def newlines(self):
return self._decoder.newlines if self._decoder else None
+ def _dealloc_warn(self, source):
+ if dealloc_warn := getattr(self.buffer, "_dealloc_warn", None):
+ dealloc_warn(source)
+
class StringIO(TextIOWrapper):
"""Text I/O implementation using an in-memory buffer.
diff --git a/Lib/_pyrepl/_module_completer.py b/Lib/_pyrepl/_module_completer.py
index 347f05607c7..1e9462a4215 100644
--- a/Lib/_pyrepl/_module_completer.py
+++ b/Lib/_pyrepl/_module_completer.py
@@ -17,8 +17,8 @@ if TYPE_CHECKING:
def make_default_module_completer() -> ModuleCompleter:
- # Inside pyrepl, __package__ is set to '_pyrepl'
- return ModuleCompleter(namespace={'__package__': '_pyrepl'})
+ # Inside pyrepl, __package__ is set to None by default
+ return ModuleCompleter(namespace={'__package__': None})
class ModuleCompleter:
@@ -42,11 +42,11 @@ class ModuleCompleter:
self._global_cache: list[pkgutil.ModuleInfo] = []
self._curr_sys_path: list[str] = sys.path[:]
- def get_completions(self, line: str) -> list[str]:
+ def get_completions(self, line: str) -> list[str] | None:
"""Return the next possible import completions for 'line'."""
result = ImportParser(line).parse()
if not result:
- return []
+ return None
try:
return self.complete(*result)
except Exception:
@@ -81,8 +81,11 @@ class ModuleCompleter:
def _find_modules(self, path: str, prefix: str) -> list[str]:
if not path:
# Top-level import (e.g. `import foo<tab>`` or `from foo<tab>`)`
- return [name for _, name, _ in self.global_cache
- if name.startswith(prefix)]
+ builtin_modules = [name for name in sys.builtin_module_names
+ if self.is_suggestion_match(name, prefix)]
+ third_party_modules = [module.name for module in self.global_cache
+ if self.is_suggestion_match(module.name, prefix)]
+ return sorted(builtin_modules + third_party_modules)
if path.startswith('.'):
# Convert relative path to absolute path
@@ -97,7 +100,14 @@ class ModuleCompleter:
if mod_info.ispkg and mod_info.name == segment]
modules = self.iter_submodules(modules)
return [module.name for module in modules
- if module.name.startswith(prefix)]
+ if self.is_suggestion_match(module.name, prefix)]
+
+ def is_suggestion_match(self, module_name: str, prefix: str) -> bool:
+ if prefix:
+ return module_name.startswith(prefix)
+ # For consistency with attribute completion, which
+ # does not suggest private attributes unless requested.
+ return not module_name.startswith("_")
def iter_submodules(self, parent_modules: list[pkgutil.ModuleInfo]) -> Iterator[pkgutil.ModuleInfo]:
"""Iterate over all submodules of the given parent modules."""
diff --git a/Lib/_pyrepl/base_eventqueue.py b/Lib/_pyrepl/base_eventqueue.py
index e018c4fc183..0589a0f437e 100644
--- a/Lib/_pyrepl/base_eventqueue.py
+++ b/Lib/_pyrepl/base_eventqueue.py
@@ -69,18 +69,14 @@ class BaseEventQueue:
trace('added event {event}', event=event)
self.events.append(event)
- def push(self, char: int | bytes | str) -> None:
+ def push(self, char: int | bytes) -> None:
"""
Processes a character by updating the buffer and handling special key mappings.
"""
+ assert isinstance(char, (int, bytes))
ord_char = char if isinstance(char, int) else ord(char)
- if ord_char > 255:
- assert isinstance(char, str)
- char = bytes(char.encode(self.encoding, "replace"))
- self.buf.extend(char)
- else:
- char = bytes(bytearray((ord_char,)))
- self.buf.append(ord_char)
+ char = ord_char.to_bytes()
+ self.buf.append(ord_char)
if char in self.keymap:
if self.keymap is self.compiled_keymap:
@@ -91,7 +87,7 @@ class BaseEventQueue:
if isinstance(k, dict):
self.keymap = k
else:
- self.insert(Event('key', k, self.flush_buf()))
+ self.insert(Event('key', k, bytes(self.flush_buf())))
self.keymap = self.compiled_keymap
elif self.buf and self.buf[0] == 27: # escape
@@ -100,7 +96,7 @@ class BaseEventQueue:
# the docstring in keymap.py
trace('unrecognized escape sequence, propagating...')
self.keymap = self.compiled_keymap
- self.insert(Event('key', '\033', bytearray(b'\033')))
+ self.insert(Event('key', '\033', b'\033'))
for _c in self.flush_buf()[1:]:
self.push(_c)
@@ -110,5 +106,5 @@ class BaseEventQueue:
except UnicodeError:
return
else:
- self.insert(Event('key', decoded, self.flush_buf()))
+ self.insert(Event('key', decoded, bytes(self.flush_buf())))
self.keymap = self.compiled_keymap
diff --git a/Lib/_pyrepl/commands.py b/Lib/_pyrepl/commands.py
index 2054a8e400f..50c824995d8 100644
--- a/Lib/_pyrepl/commands.py
+++ b/Lib/_pyrepl/commands.py
@@ -370,6 +370,13 @@ class self_insert(EditCommand):
r = self.reader
text = self.event * r.get_arg()
r.insert(text)
+ if r.paste_mode:
+ data = ""
+ ev = r.console.getpending()
+ data += ev.data
+ if data:
+ r.insert(data)
+ r.last_refresh_cache.invalidated = True
class insert_nl(EditCommand):
@@ -439,7 +446,7 @@ class help(Command):
import _sitebuiltins
with self.reader.suspend():
- self.reader.msg = _sitebuiltins._Helper()() # type: ignore[assignment, call-arg]
+ self.reader.msg = _sitebuiltins._Helper()() # type: ignore[assignment]
class invalid_key(Command):
@@ -484,7 +491,6 @@ class perform_bracketed_paste(Command):
data = ""
start = time.time()
while done not in data:
- self.reader.console.wait(100)
ev = self.reader.console.getpending()
data += ev.data
trace(
diff --git a/Lib/_pyrepl/main.py b/Lib/_pyrepl/main.py
index a6f824dcc4a..447eb1e551e 100644
--- a/Lib/_pyrepl/main.py
+++ b/Lib/_pyrepl/main.py
@@ -1,6 +1,7 @@
import errno
import os
import sys
+import types
CAN_USE_PYREPL: bool
@@ -29,12 +30,10 @@ def interactive_console(mainmodule=None, quiet=False, pythonstartup=False):
print(FAIL_REASON, file=sys.stderr)
return sys._baserepl()
- if mainmodule:
- namespace = mainmodule.__dict__
- else:
- import __main__
- namespace = __main__.__dict__
- namespace.pop("__pyrepl_interactive_console", None)
+ if not mainmodule:
+ mainmodule = types.ModuleType("__main__")
+
+ namespace = mainmodule.__dict__
# sys._baserepl() above does this internally, we do it here
startup_path = os.getenv("PYTHONSTARTUP")
diff --git a/Lib/_pyrepl/reader.py b/Lib/_pyrepl/reader.py
index 65c2230dfd6..0ebd9162eca 100644
--- a/Lib/_pyrepl/reader.py
+++ b/Lib/_pyrepl/reader.py
@@ -28,7 +28,7 @@ from contextlib import contextmanager
from dataclasses import dataclass, field, fields
from . import commands, console, input
-from .utils import wlen, unbracket, disp_str, gen_colors
+from .utils import wlen, unbracket, disp_str, gen_colors, THEME
from .trace import trace
@@ -491,11 +491,8 @@ class Reader:
prompt = self.ps1
if self.can_colorize:
- prompt = (
- f"{_colorize.theme["PROMPT"]}"
- f"{prompt}"
- f"{_colorize.theme["RESET"]}"
- )
+ t = THEME()
+ prompt = f"{t.prompt}{prompt}{t.reset}"
return prompt
def push_input_trans(self, itrans: input.KeymapTranslator) -> None:
diff --git a/Lib/_pyrepl/readline.py b/Lib/_pyrepl/readline.py
index 560a9db1921..9560ae779ab 100644
--- a/Lib/_pyrepl/readline.py
+++ b/Lib/_pyrepl/readline.py
@@ -134,7 +134,8 @@ class ReadlineAlikeReader(historical_reader.HistoricalReader, CompletingReader):
return "".join(b[p + 1 : self.pos])
def get_completions(self, stem: str) -> list[str]:
- if module_completions := self.get_module_completions():
+ module_completions = self.get_module_completions()
+ if module_completions is not None:
return module_completions
if len(stem) == 0 and self.more_lines is not None:
b = self.buffer
@@ -165,7 +166,7 @@ class ReadlineAlikeReader(historical_reader.HistoricalReader, CompletingReader):
result.sort()
return result
- def get_module_completions(self) -> list[str]:
+ def get_module_completions(self) -> list[str] | None:
line = self.get_line()
return self.config.module_completer.get_completions(line)
@@ -606,6 +607,7 @@ def _setup(namespace: Mapping[str, Any]) -> None:
# set up namespace in rlcompleter, which requires it to be a bona fide dict
if not isinstance(namespace, dict):
namespace = dict(namespace)
+ _wrapper.config.module_completer = ModuleCompleter(namespace)
_wrapper.config.readline_completer = RLCompleter(namespace).complete
# this is not really what readline.c does. Better than nothing I guess
diff --git a/Lib/_pyrepl/simple_interact.py b/Lib/_pyrepl/simple_interact.py
index e2274629b65..965b853c34b 100644
--- a/Lib/_pyrepl/simple_interact.py
+++ b/Lib/_pyrepl/simple_interact.py
@@ -31,6 +31,7 @@ import os
import sys
import code
import warnings
+import errno
from .readline import _get_reader, multiline_input, append_history_file
@@ -110,6 +111,10 @@ def run_multiline_interactive_console(
more_lines = functools.partial(_more_lines, console)
input_n = 0
+ _is_x_showrefcount_set = sys._xoptions.get("showrefcount")
+ _is_pydebug_build = hasattr(sys, "gettotalrefcount")
+ show_ref_count = _is_x_showrefcount_set and _is_pydebug_build
+
def maybe_run_command(statement: str) -> bool:
statement = statement.strip()
if statement in console.locals or statement not in REPL_COMMANDS:
@@ -149,6 +154,7 @@ def run_multiline_interactive_console(
append_history_file()
except (FileNotFoundError, PermissionError, OSError) as e:
warnings.warn(f"failed to open the history file for writing: {e}")
+
input_n += 1
except KeyboardInterrupt:
r = _get_reader()
@@ -162,3 +168,13 @@ def run_multiline_interactive_console(
except MemoryError:
console.write("\nMemoryError\n")
console.resetbuffer()
+ except SystemExit:
+ raise
+ except:
+ console.showtraceback()
+ console.resetbuffer()
+ if show_ref_count:
+ console.write(
+ f"[{sys.gettotalrefcount()} refs,"
+ f" {sys.getallocatedblocks()} blocks]\n"
+ )
diff --git a/Lib/_pyrepl/unix_console.py b/Lib/_pyrepl/unix_console.py
index 07b160d2324..d21cdd9b076 100644
--- a/Lib/_pyrepl/unix_console.py
+++ b/Lib/_pyrepl/unix_console.py
@@ -29,6 +29,7 @@ import signal
import struct
import termios
import time
+import types
import platform
from fcntl import ioctl
@@ -39,6 +40,12 @@ from .trace import trace
from .unix_eventqueue import EventQueue
from .utils import wlen
+# declare posix optional to allow None assignment on other platforms
+posix: types.ModuleType | None
+try:
+ import posix
+except ImportError:
+ posix = None
TYPE_CHECKING = False
@@ -197,6 +204,12 @@ class UnixConsole(Console):
self.event_queue = EventQueue(self.input_fd, self.encoding)
self.cursor_visible = 1
+ signal.signal(signal.SIGCONT, self._sigcont_handler)
+
+ def _sigcont_handler(self, signum, frame):
+ self.restore()
+ self.prepare()
+
def __read(self, n: int) -> bytes:
return os.read(self.input_fd, n)
@@ -550,11 +563,9 @@ class UnixConsole(Console):
@property
def input_hook(self):
- try:
- import posix
- except ImportError:
- return None
- if posix._is_inputhook_installed():
+ # avoid inline imports here so the repl doesn't get flooded
+ # with import logging from -X importtime=2
+ if posix is not None and posix._is_inputhook_installed():
return posix._inputhook
def __enable_bracketed_paste(self) -> None:
diff --git a/Lib/_pyrepl/utils.py b/Lib/_pyrepl/utils.py
index fe154aa59a0..e04fbdc6c8a 100644
--- a/Lib/_pyrepl/utils.py
+++ b/Lib/_pyrepl/utils.py
@@ -23,6 +23,11 @@ IDENTIFIERS_AFTER = {"def", "class"}
BUILTINS = {str(name) for name in dir(builtins) if not name.startswith('_')}
+def THEME(**kwargs):
+ # Not cached: the user can modify the theme inside the interactive session.
+ return _colorize.get_theme(**kwargs).syntax
+
+
class Span(NamedTuple):
"""Span indexing that's inclusive on both ends."""
@@ -36,15 +41,21 @@ class Span(NamedTuple):
@classmethod
def from_token(cls, token: TI, line_len: list[int]) -> Self:
+ end_offset = -1
+ if (token.type in {T.FSTRING_MIDDLE, T.TSTRING_MIDDLE}
+ and token.string.endswith(("{", "}"))):
+ # gh-134158: a visible trailing brace comes from a double brace in input
+ end_offset += 1
+
return cls(
line_len[token.start[0] - 1] + token.start[1],
- line_len[token.end[0] - 1] + token.end[1] - 1,
+ line_len[token.end[0] - 1] + token.end[1] + end_offset,
)
class ColorSpan(NamedTuple):
span: Span
- tag: _colorize.ColorTag
+ tag: str
@functools.cache
@@ -97,6 +108,8 @@ def gen_colors(buffer: str) -> Iterator[ColorSpan]:
for color in gen_colors_from_token_stream(gen, line_lengths):
yield color
last_emitted = color
+ except SyntaxError:
+ return
except tokenize.TokenError as te:
yield from recover_unterminated_string(
te, line_lengths, last_emitted, buffer
@@ -135,7 +148,7 @@ def recover_unterminated_string(
span = Span(start, end)
trace("yielding span {a} -> {b}", a=span.start, b=span.end)
- yield ColorSpan(span, "STRING")
+ yield ColorSpan(span, "string")
else:
trace(
"unhandled token error({buffer}) = {te}",
@@ -164,28 +177,28 @@ def gen_colors_from_token_stream(
| T.TSTRING_START | T.TSTRING_MIDDLE | T.TSTRING_END
):
span = Span.from_token(token, line_lengths)
- yield ColorSpan(span, "STRING")
+ yield ColorSpan(span, "string")
case T.COMMENT:
span = Span.from_token(token, line_lengths)
- yield ColorSpan(span, "COMMENT")
+ yield ColorSpan(span, "comment")
case T.NUMBER:
span = Span.from_token(token, line_lengths)
- yield ColorSpan(span, "NUMBER")
+ yield ColorSpan(span, "number")
case T.OP:
if token.string in "([{":
bracket_level += 1
elif token.string in ")]}":
bracket_level -= 1
span = Span.from_token(token, line_lengths)
- yield ColorSpan(span, "OP")
+ yield ColorSpan(span, "op")
case T.NAME:
if is_def_name:
is_def_name = False
span = Span.from_token(token, line_lengths)
- yield ColorSpan(span, "DEFINITION")
+ yield ColorSpan(span, "definition")
elif keyword.iskeyword(token.string):
span = Span.from_token(token, line_lengths)
- yield ColorSpan(span, "KEYWORD")
+ yield ColorSpan(span, "keyword")
if token.string in IDENTIFIERS_AFTER:
is_def_name = True
elif (
@@ -194,10 +207,10 @@ def gen_colors_from_token_stream(
and is_soft_keyword_used(prev_token, token, next_token)
):
span = Span.from_token(token, line_lengths)
- yield ColorSpan(span, "SOFT_KEYWORD")
+ yield ColorSpan(span, "soft_keyword")
elif token.string in BUILTINS:
span = Span.from_token(token, line_lengths)
- yield ColorSpan(span, "BUILTIN")
+ yield ColorSpan(span, "builtin")
keyword_first_sets_match = {"False", "None", "True", "await", "lambda", "not"}
@@ -249,7 +262,10 @@ def is_soft_keyword_used(*tokens: TI | None) -> bool:
def disp_str(
- buffer: str, colors: list[ColorSpan] | None = None, start_index: int = 0
+ buffer: str,
+ colors: list[ColorSpan] | None = None,
+ start_index: int = 0,
+ force_color: bool = False,
) -> tuple[CharBuffer, CharWidths]:
r"""Decompose the input buffer into a printable variant with applied colors.
@@ -290,15 +306,16 @@ def disp_str(
# move past irrelevant spans
colors.pop(0)
+ theme = THEME(force_color=force_color)
pre_color = ""
post_color = ""
if colors and colors[0].span.start < start_index:
# looks like we're continuing a previous color (e.g. a multiline str)
- pre_color = _colorize.theme[colors[0].tag]
+ pre_color = theme[colors[0].tag]
for i, c in enumerate(buffer, start_index):
if colors and colors[0].span.start == i: # new color starts now
- pre_color = _colorize.theme[colors[0].tag]
+ pre_color = theme[colors[0].tag]
if c == "\x1a": # CTRL-Z on Windows
chars.append(c)
@@ -315,7 +332,7 @@ def disp_str(
char_widths.append(str_width(c))
if colors and colors[0].span.end == i: # current color ends now
- post_color = _colorize.theme["RESET"]
+ post_color = theme.reset
colors.pop(0)
chars[-1] = pre_color + chars[-1] + post_color
@@ -325,7 +342,7 @@ def disp_str(
if colors and colors[0].span.start < i and colors[0].span.end > i:
# even though the current color should be continued, reset it for now.
# the next call to `disp_str()` will revive it.
- chars[-1] += _colorize.theme["RESET"]
+ chars[-1] += theme.reset
return chars, char_widths
diff --git a/Lib/_pyrepl/windows_console.py b/Lib/_pyrepl/windows_console.py
index 77985e59a93..c56dcd6d7dd 100644
--- a/Lib/_pyrepl/windows_console.py
+++ b/Lib/_pyrepl/windows_console.py
@@ -24,6 +24,7 @@ import os
import sys
import ctypes
+import types
from ctypes.wintypes import (
_COORD,
WORD,
@@ -58,6 +59,12 @@ except:
self.err = err
self.descr = descr
+# declare nt optional to allow None assignment on other platforms
+nt: types.ModuleType | None
+try:
+ import nt
+except ImportError:
+ nt = None
TYPE_CHECKING = False
@@ -121,9 +128,8 @@ class _error(Exception):
def _supports_vt():
try:
- import nt
return nt._supports_virtual_terminal()
- except (ImportError, AttributeError):
+ except AttributeError:
return False
class WindowsConsole(Console):
@@ -235,11 +241,9 @@ class WindowsConsole(Console):
@property
def input_hook(self):
- try:
- import nt
- except ImportError:
- return None
- if nt._is_inputhook_installed():
+ # avoid inline imports here so the repl doesn't get flooded
+ # with import logging from -X importtime=2
+ if nt is not None and nt._is_inputhook_installed():
return nt._inputhook
def __write_changed_line(
@@ -415,10 +419,7 @@ class WindowsConsole(Console):
return info.srWindow.Bottom # type: ignore[no-any-return]
- def _read_input(self, block: bool = True) -> INPUT_RECORD | None:
- if not block and not self.wait(timeout=0):
- return None
-
+ def _read_input(self) -> INPUT_RECORD | None:
rec = INPUT_RECORD()
read = DWORD()
if not ReadConsoleInput(InHandle, rec, 1, read):
@@ -427,14 +428,10 @@ class WindowsConsole(Console):
return rec
def _read_input_bulk(
- self, block: bool, n: int
+ self, n: int
) -> tuple[ctypes.Array[INPUT_RECORD], int]:
rec = (n * INPUT_RECORD)()
read = DWORD()
-
- if not block and not self.wait(timeout=0):
- return rec, 0
-
if not ReadConsoleInput(InHandle, rec, n, read):
raise WinError(GetLastError())
@@ -445,8 +442,11 @@ class WindowsConsole(Console):
and there is no event pending, otherwise waits for the
completion of an event."""
+ if not block and not self.wait(timeout=0):
+ return None
+
while self.event_queue.empty():
- rec = self._read_input(block)
+ rec = self._read_input()
if rec is None:
return None
@@ -464,7 +464,7 @@ class WindowsConsole(Console):
if key == "\r":
# Make enter unix-like
- return Event(evt="key", data="\n", raw=b"\n")
+ return Event(evt="key", data="\n")
elif key_event.wVirtualKeyCode == 8:
# Turn backspace directly into the command
key = "backspace"
@@ -476,24 +476,29 @@ class WindowsConsole(Console):
key = f"ctrl {key}"
elif key_event.dwControlKeyState & ALT_ACTIVE:
# queue the key, return the meta command
- self.event_queue.insert(Event(evt="key", data=key, raw=key))
+ self.event_queue.insert(Event(evt="key", data=key))
return Event(evt="key", data="\033") # keymap.py uses this for meta
- return Event(evt="key", data=key, raw=key)
+ return Event(evt="key", data=key)
if block:
continue
return None
elif self.__vt_support:
# If virtual terminal is enabled, scanning VT sequences
- self.event_queue.push(rec.Event.KeyEvent.uChar.UnicodeChar)
+ for char in raw_key.encode(self.event_queue.encoding, "replace"):
+ self.event_queue.push(char)
continue
if key_event.dwControlKeyState & ALT_ACTIVE:
- # queue the key, return the meta command
- self.event_queue.insert(Event(evt="key", data=key, raw=raw_key))
- return Event(evt="key", data="\033") # keymap.py uses this for meta
-
- return Event(evt="key", data=key, raw=raw_key)
+ # Do not swallow characters that have been entered via AltGr:
+ # Windows internally converts AltGr to CTRL+ALT, see
+ # https://learn.microsoft.com/en-us/windows/win32/api/winuser/nf-winuser-vkkeyscanw
+ if not key_event.dwControlKeyState & CTRL_ACTIVE:
+ # queue the key, return the meta command
+ self.event_queue.insert(Event(evt="key", data=key))
+ return Event(evt="key", data="\033") # keymap.py uses this for meta
+
+ return Event(evt="key", data=key)
return self.event_queue.get()
def push_char(self, char: int | bytes) -> None:
@@ -542,12 +547,20 @@ class WindowsConsole(Console):
if e2:
e.data += e2.data
- recs, rec_count = self._read_input_bulk(False, 1024)
+ recs, rec_count = self._read_input_bulk(1024)
for i in range(rec_count):
rec = recs[i]
+ # In case of a legacy console, we do not only receive a keydown
+ # event, but also a keyup event - and for uppercase letters
+ # an additional SHIFT_PRESSED event.
if rec and rec.EventType == KEY_EVENT:
key_event = rec.Event.KeyEvent
+ if not key_event.bKeyDown:
+ continue
ch = key_event.uChar.UnicodeChar
+ if ch == "\x00":
+ # ignore SHIFT_PRESSED and special keys
+ continue
if ch == "\r":
ch += "\n"
e.data += ch
diff --git a/Lib/_strptime.py b/Lib/_strptime.py
index aa63933a49d..cdc55e8daaf 100644
--- a/Lib/_strptime.py
+++ b/Lib/_strptime.py
@@ -14,6 +14,7 @@ import os
import time
import locale
import calendar
+import re
from re import compile as re_compile
from re import sub as re_sub
from re import IGNORECASE
@@ -41,6 +42,29 @@ def _findall(haystack, needle):
yield i
i += len(needle)
+def _fixmonths(months):
+ yield from months
+ # The lower case of 'İ' ('\u0130') is 'i\u0307'.
+ # The re module only supports 1-to-1 character matching in
+ # case-insensitive mode.
+ for s in months:
+ if 'i\u0307' in s:
+ yield s.replace('i\u0307', '\u0130')
+
+lzh_TW_alt_digits = (
+ # 〇:一:二:三:四:五:六:七:八:九
+ '\u3007', '\u4e00', '\u4e8c', '\u4e09', '\u56db',
+ '\u4e94', '\u516d', '\u4e03', '\u516b', '\u4e5d',
+ # 十:十一:十二:十三:十四:十五:十六:十七:十八:十九
+ '\u5341', '\u5341\u4e00', '\u5341\u4e8c', '\u5341\u4e09', '\u5341\u56db',
+ '\u5341\u4e94', '\u5341\u516d', '\u5341\u4e03', '\u5341\u516b', '\u5341\u4e5d',
+ # 廿:廿一:廿二:廿三:廿四:廿五:廿六:廿七:廿八:廿九
+ '\u5eff', '\u5eff\u4e00', '\u5eff\u4e8c', '\u5eff\u4e09', '\u5eff\u56db',
+ '\u5eff\u4e94', '\u5eff\u516d', '\u5eff\u4e03', '\u5eff\u516b', '\u5eff\u4e5d',
+ # 卅:卅一
+ '\u5345', '\u5345\u4e00')
+
+
class LocaleTime(object):
"""Stores and handles locale-specific information related to time.
@@ -84,6 +108,7 @@ class LocaleTime(object):
self.__calc_weekday()
self.__calc_month()
self.__calc_am_pm()
+ self.__calc_alt_digits()
self.__calc_timezone()
self.__calc_date_time()
if _getlang() != self.lang:
@@ -119,9 +144,43 @@ class LocaleTime(object):
am_pm.append(time.strftime("%p", time_tuple).lower().strip())
self.am_pm = am_pm
+ def __calc_alt_digits(self):
+ # Set self.LC_alt_digits by using time.strftime().
+
+ # The magic data should contain all decimal digits.
+ time_tuple = time.struct_time((1998, 1, 27, 10, 43, 56, 1, 27, 0))
+ s = time.strftime("%x%X", time_tuple)
+ if s.isascii():
+ # Fast path -- all digits are ASCII.
+ self.LC_alt_digits = ()
+ return
+
+ digits = ''.join(sorted(set(re.findall(r'\d', s))))
+ if len(digits) == 10 and ord(digits[-1]) == ord(digits[0]) + 9:
+ # All 10 decimal digits from the same set.
+ if digits.isascii():
+ # All digits are ASCII.
+ self.LC_alt_digits = ()
+ return
+
+ self.LC_alt_digits = [a + b for a in digits for b in digits]
+ # Test whether the numbers contain leading zero.
+ time_tuple2 = time.struct_time((2000, 1, 1, 1, 1, 1, 5, 1, 0))
+ if self.LC_alt_digits[1] not in time.strftime("%x %X", time_tuple2):
+ self.LC_alt_digits[:10] = digits
+ return
+
+ # Either non-Gregorian calendar or non-decimal numbers.
+ if {'\u4e00', '\u4e03', '\u4e5d', '\u5341', '\u5eff'}.issubset(s):
+ # lzh_TW
+ self.LC_alt_digits = lzh_TW_alt_digits
+ return
+
+ self.LC_alt_digits = None
+
def __calc_date_time(self):
- # Set self.date_time, self.date, & self.time by using
- # time.strftime().
+ # Set self.LC_date_time, self.LC_date, self.LC_time and
+ # self.LC_time_ampm by using time.strftime().
# Use (1999,3,17,22,44,55,2,76,0) for magic date because the amount of
# overloaded numbers is minimized. The order in which searches for
@@ -129,26 +188,32 @@ class LocaleTime(object):
# possible ambiguity for what something represents.
time_tuple = time.struct_time((1999,3,17,22,44,55,2,76,0))
time_tuple2 = time.struct_time((1999,1,3,1,1,1,6,3,0))
- replacement_pairs = [
+ replacement_pairs = []
+
+ # Non-ASCII digits
+ if self.LC_alt_digits or self.LC_alt_digits is None:
+ for n, d in [(19, '%OC'), (99, '%Oy'), (22, '%OH'),
+ (44, '%OM'), (55, '%OS'), (17, '%Od'),
+ (3, '%Om'), (2, '%Ow'), (10, '%OI')]:
+ if self.LC_alt_digits is None:
+ s = chr(0x660 + n // 10) + chr(0x660 + n % 10)
+ replacement_pairs.append((s, d))
+ if n < 10:
+ replacement_pairs.append((s[1], d))
+ elif len(self.LC_alt_digits) > n:
+ replacement_pairs.append((self.LC_alt_digits[n], d))
+ else:
+ replacement_pairs.append((time.strftime(d, time_tuple), d))
+ replacement_pairs += [
('1999', '%Y'), ('99', '%y'), ('22', '%H'),
('44', '%M'), ('55', '%S'), ('76', '%j'),
('17', '%d'), ('03', '%m'), ('3', '%m'),
# '3' needed for when no leading zero.
('2', '%w'), ('10', '%I'),
- # Non-ASCII digits
- ('\u0661\u0669\u0669\u0669', '%Y'),
- ('\u0669\u0669', '%Oy'),
- ('\u0662\u0662', '%OH'),
- ('\u0664\u0664', '%OM'),
- ('\u0665\u0665', '%OS'),
- ('\u0661\u0667', '%Od'),
- ('\u0660\u0663', '%Om'),
- ('\u0663', '%Om'),
- ('\u0662', '%Ow'),
- ('\u0661\u0660', '%OI'),
]
+
date_time = []
- for directive in ('%c', '%x', '%X'):
+ for directive in ('%c', '%x', '%X', '%r'):
current_format = time.strftime(directive, time_tuple).lower()
current_format = current_format.replace('%', '%%')
# The month and the day of the week formats are treated specially
@@ -172,9 +237,10 @@ class LocaleTime(object):
if tz:
current_format = current_format.replace(tz, "%Z")
# Transform all non-ASCII digits to digits in range U+0660 to U+0669.
- current_format = re_sub(r'\d(?<![0-9])',
- lambda m: chr(0x0660 + int(m[0])),
- current_format)
+ if not current_format.isascii() and self.LC_alt_digits is None:
+ current_format = re_sub(r'\d(?<![0-9])',
+ lambda m: chr(0x0660 + int(m[0])),
+ current_format)
for old, new in replacement_pairs:
current_format = current_format.replace(old, new)
# If %W is used, then Sunday, 2005-01-03 will fall on week 0 since
@@ -189,6 +255,7 @@ class LocaleTime(object):
self.LC_date_time = date_time[0]
self.LC_date = date_time[1]
self.LC_time = date_time[2]
+ self.LC_time_ampm = date_time[3]
def __find_month_format(self, directive):
"""Find the month format appropriate for the current locale.
@@ -213,7 +280,7 @@ class LocaleTime(object):
full_indices &= indices
indices = set(_findall(datetime, self.a_month[m]))
if abbr_indices is None:
- abbr_indices = indices
+ abbr_indices = set(indices)
else:
abbr_indices &= indices
if not full_indices and not abbr_indices:
@@ -241,7 +308,7 @@ class LocaleTime(object):
if self.f_weekday[wd] != self.a_weekday[wd]:
indices = set(_findall(datetime, self.a_weekday[wd]))
if abbr_indices is None:
- abbr_indices = indices
+ abbr_indices = set(indices)
else:
abbr_indices &= indices
if not full_indices and not abbr_indices:
@@ -288,8 +355,10 @@ class TimeRE(dict):
# The " [1-9]" part of the regex is to make %c from ANSI C work
'd': r"(?P<d>3[0-1]|[1-2]\d|0[1-9]|[1-9]| [1-9])",
'f': r"(?P<f>[0-9]{1,6})",
- 'H': r"(?P<H>2[0-3]|[0-1]\d|\d)",
+ 'H': r"(?P<H>2[0-3]|[0-1]\d|\d| \d)",
+ 'k': r"(?P<H>2[0-3]|[0-1]\d|\d| \d)",
'I': r"(?P<I>1[0-2]|0[1-9]|[1-9]| [1-9])",
+ 'l': r"(?P<I>1[0-2]|0[1-9]|[1-9]| [1-9])",
'G': r"(?P<G>\d\d\d\d)",
'j': r"(?P<j>36[0-6]|3[0-5]\d|[1-2]\d\d|0[1-9]\d|00[1-9]|[1-9]\d|0[1-9]|[1-9])",
'm': r"(?P<m>1[0-2]|0[1-9]|[1-9])",
@@ -302,26 +371,59 @@ class TimeRE(dict):
# W is set below by using 'U'
'y': r"(?P<y>\d\d)",
'Y': r"(?P<Y>\d\d\d\d)",
- 'z': r"(?P<z>[+-]\d\d:?[0-5]\d(:?[0-5]\d(\.\d{1,6})?)?|(?-i:Z))",
+ 'z': r"(?P<z>([+-]\d\d:?[0-5]\d(:?[0-5]\d(\.\d{1,6})?)?)|(?-i:Z))?",
'A': self.__seqToRE(self.locale_time.f_weekday, 'A'),
'a': self.__seqToRE(self.locale_time.a_weekday, 'a'),
- 'B': self.__seqToRE(self.locale_time.f_month[1:], 'B'),
- 'b': self.__seqToRE(self.locale_time.a_month[1:], 'b'),
+ 'B': self.__seqToRE(_fixmonths(self.locale_time.f_month[1:]), 'B'),
+ 'b': self.__seqToRE(_fixmonths(self.locale_time.a_month[1:]), 'b'),
'p': self.__seqToRE(self.locale_time.am_pm, 'p'),
'Z': self.__seqToRE((tz for tz_names in self.locale_time.timezone
for tz in tz_names),
'Z'),
'%': '%'}
- for d in 'dmyHIMS':
- mapping['O' + d] = r'(?P<%s>\d\d|\d| \d)' % d
- mapping['Ow'] = r'(?P<w>\d)'
+ if self.locale_time.LC_alt_digits is None:
+ for d in 'dmyCHIMS':
+ mapping['O' + d] = r'(?P<%s>\d\d|\d| \d)' % d
+ mapping['Ow'] = r'(?P<w>\d)'
+ else:
+ mapping.update({
+ 'Od': self.__seqToRE(self.locale_time.LC_alt_digits[1:32], 'd',
+ '3[0-1]|[1-2][0-9]|0[1-9]|[1-9]'),
+ 'Om': self.__seqToRE(self.locale_time.LC_alt_digits[1:13], 'm',
+ '1[0-2]|0[1-9]|[1-9]'),
+ 'Ow': self.__seqToRE(self.locale_time.LC_alt_digits[:7], 'w',
+ '[0-6]'),
+ 'Oy': self.__seqToRE(self.locale_time.LC_alt_digits, 'y',
+ '[0-9][0-9]'),
+ 'OC': self.__seqToRE(self.locale_time.LC_alt_digits, 'C',
+ '[0-9][0-9]'),
+ 'OH': self.__seqToRE(self.locale_time.LC_alt_digits[:24], 'H',
+ '2[0-3]|[0-1][0-9]|[0-9]'),
+ 'OI': self.__seqToRE(self.locale_time.LC_alt_digits[1:13], 'I',
+ '1[0-2]|0[1-9]|[1-9]'),
+ 'OM': self.__seqToRE(self.locale_time.LC_alt_digits[:60], 'M',
+ '[0-5][0-9]|[0-9]'),
+ 'OS': self.__seqToRE(self.locale_time.LC_alt_digits[:62], 'S',
+ '6[0-1]|[0-5][0-9]|[0-9]'),
+ })
+ mapping.update({
+ 'e': mapping['d'],
+ 'Oe': mapping['Od'],
+ 'P': mapping['p'],
+ 'Op': mapping['p'],
+ 'W': mapping['U'].replace('U', 'W'),
+ })
mapping['W'] = mapping['U'].replace('U', 'W')
+
base.__init__(mapping)
+ base.__setitem__('T', self.pattern('%H:%M:%S'))
+ base.__setitem__('R', self.pattern('%H:%M'))
+ base.__setitem__('r', self.pattern(self.locale_time.LC_time_ampm))
base.__setitem__('X', self.pattern(self.locale_time.LC_time))
base.__setitem__('x', self.pattern(self.locale_time.LC_date))
base.__setitem__('c', self.pattern(self.locale_time.LC_date_time))
- def __seqToRE(self, to_convert, directive):
+ def __seqToRE(self, to_convert, directive, altregex=None):
"""Convert a list to a regex string for matching a directive.
Want possible matching values to be from longest to shortest. This
@@ -337,8 +439,9 @@ class TimeRE(dict):
else:
return ''
regex = '|'.join(re_escape(stuff) for stuff in to_convert)
- regex = '(?P<%s>%s' % (directive, regex)
- return '%s)' % regex
+ if altregex is not None:
+ regex += '|' + altregex
+ return '(?P<%s>%s)' % (directive, regex)
def pattern(self, format):
"""Return regex pattern for the format string.
@@ -365,7 +468,7 @@ class TimeRE(dict):
nonlocal day_of_month_in_format
day_of_month_in_format = True
return self[format_char]
- format = re_sub(r'%([OE]?\\?.?)', repl, format)
+ format = re_sub(r'%[-_0^#]*[0-9]*([OE]?\\?.?)', repl, format)
if day_of_month_in_format and not year_in_format:
import warnings
warnings.warn("""\
@@ -467,6 +570,15 @@ def _strptime(data_string, format="%a %b %d %H:%M:%S %Y"):
# values
weekday = julian = None
found_dict = found.groupdict()
+ if locale_time.LC_alt_digits:
+ def parse_int(s):
+ try:
+ return locale_time.LC_alt_digits.index(s)
+ except ValueError:
+ return int(s)
+ else:
+ parse_int = int
+
for group_key in found_dict.keys():
# Directives not explicitly handled below:
# c, x, X
@@ -474,30 +586,34 @@ def _strptime(data_string, format="%a %b %d %H:%M:%S %Y"):
# U, W
# worthless without day of the week
if group_key == 'y':
- year = int(found_dict['y'])
- # Open Group specification for strptime() states that a %y
- #value in the range of [00, 68] is in the century 2000, while
- #[69,99] is in the century 1900
- if year <= 68:
- year += 2000
+ year = parse_int(found_dict['y'])
+ if 'C' in found_dict:
+ century = parse_int(found_dict['C'])
+ year += century * 100
else:
- year += 1900
+ # Open Group specification for strptime() states that a %y
+ #value in the range of [00, 68] is in the century 2000, while
+ #[69,99] is in the century 1900
+ if year <= 68:
+ year += 2000
+ else:
+ year += 1900
elif group_key == 'Y':
year = int(found_dict['Y'])
elif group_key == 'G':
iso_year = int(found_dict['G'])
elif group_key == 'm':
- month = int(found_dict['m'])
+ month = parse_int(found_dict['m'])
elif group_key == 'B':
month = locale_time.f_month.index(found_dict['B'].lower())
elif group_key == 'b':
month = locale_time.a_month.index(found_dict['b'].lower())
elif group_key == 'd':
- day = int(found_dict['d'])
+ day = parse_int(found_dict['d'])
elif group_key == 'H':
- hour = int(found_dict['H'])
+ hour = parse_int(found_dict['H'])
elif group_key == 'I':
- hour = int(found_dict['I'])
+ hour = parse_int(found_dict['I'])
ampm = found_dict.get('p', '').lower()
# If there was no AM/PM indicator, we'll treat this like AM
if ampm in ('', locale_time.am_pm[0]):
@@ -513,9 +629,9 @@ def _strptime(data_string, format="%a %b %d %H:%M:%S %Y"):
if hour != 12:
hour += 12
elif group_key == 'M':
- minute = int(found_dict['M'])
+ minute = parse_int(found_dict['M'])
elif group_key == 'S':
- second = int(found_dict['S'])
+ second = parse_int(found_dict['S'])
elif group_key == 'f':
s = found_dict['f']
# Pad to always return microseconds.
@@ -548,27 +664,28 @@ def _strptime(data_string, format="%a %b %d %H:%M:%S %Y"):
iso_week = int(found_dict['V'])
elif group_key == 'z':
z = found_dict['z']
- if z == 'Z':
- gmtoff = 0
- else:
- if z[3] == ':':
- z = z[:3] + z[4:]
- if len(z) > 5:
- if z[5] != ':':
- msg = f"Inconsistent use of : in {found_dict['z']}"
- raise ValueError(msg)
- z = z[:5] + z[6:]
- hours = int(z[1:3])
- minutes = int(z[3:5])
- seconds = int(z[5:7] or 0)
- gmtoff = (hours * 60 * 60) + (minutes * 60) + seconds
- gmtoff_remainder = z[8:]
- # Pad to always return microseconds.
- gmtoff_remainder_padding = "0" * (6 - len(gmtoff_remainder))
- gmtoff_fraction = int(gmtoff_remainder + gmtoff_remainder_padding)
- if z.startswith("-"):
- gmtoff = -gmtoff
- gmtoff_fraction = -gmtoff_fraction
+ if z:
+ if z == 'Z':
+ gmtoff = 0
+ else:
+ if z[3] == ':':
+ z = z[:3] + z[4:]
+ if len(z) > 5:
+ if z[5] != ':':
+ msg = f"Inconsistent use of : in {found_dict['z']}"
+ raise ValueError(msg)
+ z = z[:5] + z[6:]
+ hours = int(z[1:3])
+ minutes = int(z[3:5])
+ seconds = int(z[5:7] or 0)
+ gmtoff = (hours * 60 * 60) + (minutes * 60) + seconds
+ gmtoff_remainder = z[8:]
+ # Pad to always return microseconds.
+ gmtoff_remainder_padding = "0" * (6 - len(gmtoff_remainder))
+ gmtoff_fraction = int(gmtoff_remainder + gmtoff_remainder_padding)
+ if z.startswith("-"):
+ gmtoff = -gmtoff
+ gmtoff_fraction = -gmtoff_fraction
elif group_key == 'Z':
# Since -1 is default value only need to worry about setting tz if
# it can be something other than -1.
diff --git a/Lib/_threading_local.py b/Lib/_threading_local.py
index b006d76c4e2..0b9e5d3bbf6 100644
--- a/Lib/_threading_local.py
+++ b/Lib/_threading_local.py
@@ -4,128 +4,6 @@
class. Depending on the version of Python you're using, there may be a
faster one available. You should always import the `local` class from
`threading`.)
-
-Thread-local objects support the management of thread-local data.
-If you have data that you want to be local to a thread, simply create
-a thread-local object and use its attributes:
-
- >>> mydata = local()
- >>> mydata.number = 42
- >>> mydata.number
- 42
-
-You can also access the local-object's dictionary:
-
- >>> mydata.__dict__
- {'number': 42}
- >>> mydata.__dict__.setdefault('widgets', [])
- []
- >>> mydata.widgets
- []
-
-What's important about thread-local objects is that their data are
-local to a thread. If we access the data in a different thread:
-
- >>> log = []
- >>> def f():
- ... items = sorted(mydata.__dict__.items())
- ... log.append(items)
- ... mydata.number = 11
- ... log.append(mydata.number)
-
- >>> import threading
- >>> thread = threading.Thread(target=f)
- >>> thread.start()
- >>> thread.join()
- >>> log
- [[], 11]
-
-we get different data. Furthermore, changes made in the other thread
-don't affect data seen in this thread:
-
- >>> mydata.number
- 42
-
-Of course, values you get from a local object, including a __dict__
-attribute, are for whatever thread was current at the time the
-attribute was read. For that reason, you generally don't want to save
-these values across threads, as they apply only to the thread they
-came from.
-
-You can create custom local objects by subclassing the local class:
-
- >>> class MyLocal(local):
- ... number = 2
- ... def __init__(self, /, **kw):
- ... self.__dict__.update(kw)
- ... def squared(self):
- ... return self.number ** 2
-
-This can be useful to support default values, methods and
-initialization. Note that if you define an __init__ method, it will be
-called each time the local object is used in a separate thread. This
-is necessary to initialize each thread's dictionary.
-
-Now if we create a local object:
-
- >>> mydata = MyLocal(color='red')
-
-Now we have a default number:
-
- >>> mydata.number
- 2
-
-an initial color:
-
- >>> mydata.color
- 'red'
- >>> del mydata.color
-
-And a method that operates on the data:
-
- >>> mydata.squared()
- 4
-
-As before, we can access the data in a separate thread:
-
- >>> log = []
- >>> thread = threading.Thread(target=f)
- >>> thread.start()
- >>> thread.join()
- >>> log
- [[('color', 'red')], 11]
-
-without affecting this thread's data:
-
- >>> mydata.number
- 2
- >>> mydata.color
- Traceback (most recent call last):
- ...
- AttributeError: 'MyLocal' object has no attribute 'color'
-
-Note that subclasses can define slots, but they are not thread
-local. They are shared across threads:
-
- >>> class MyLocal(local):
- ... __slots__ = 'number'
-
- >>> mydata = MyLocal()
- >>> mydata.number = 42
- >>> mydata.color = 'red'
-
-So, the separate thread:
-
- >>> thread = threading.Thread(target=f)
- >>> thread.start()
- >>> thread.join()
-
-affects what we see:
-
- >>> mydata.number
- 11
-
->>> del mydata
"""
from weakref import ref
diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py
index 971f636f971..c83a1573ccd 100644
--- a/Lib/annotationlib.py
+++ b/Lib/annotationlib.py
@@ -12,7 +12,7 @@ __all__ = [
"ForwardRef",
"call_annotate_function",
"call_evaluate_function",
- "get_annotate_function",
+ "get_annotate_from_class_namespace",
"get_annotations",
"annotations_to_string",
"type_repr",
@@ -27,6 +27,9 @@ class Format(enum.IntEnum):
_sentinel = object()
+# Following `NAME_ERROR_MSG` in `ceval_macros.h`:
+_NAME_ERROR_MSG = "name '{name:.200}' is not defined"
+
# Slots shared by ForwardRef and _Stringifier. The __forward__ names must be
# preserved for compatibility with the old typing.ForwardRef class. The remaining
@@ -38,6 +41,7 @@ _SLOTS = (
"__weakref__",
"__arg__",
"__globals__",
+ "__extra_names__",
"__code__",
"__ast_node__",
"__cell__",
@@ -82,6 +86,7 @@ class ForwardRef:
# is created through __class__ assignment on a _Stringifier object.
self.__globals__ = None
self.__cell__ = None
+ self.__extra_names__ = None
# These are initially None but serve as a cache and may be set to a non-None
# value later.
self.__code__ = None
@@ -90,11 +95,28 @@ class ForwardRef:
def __init_subclass__(cls, /, *args, **kwds):
raise TypeError("Cannot subclass ForwardRef")
- def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
+ def evaluate(
+ self,
+ *,
+ globals=None,
+ locals=None,
+ type_params=None,
+ owner=None,
+ format=Format.VALUE,
+ ):
"""Evaluate the forward reference and return the value.
If the forward reference cannot be evaluated, raise an exception.
"""
+ match format:
+ case Format.STRING:
+ return self.__forward_arg__
+ case Format.VALUE:
+ is_forwardref_format = False
+ case Format.FORWARDREF:
+ is_forwardref_format = True
+ case _:
+ raise NotImplementedError(format)
if self.__cell__ is not None:
try:
return self.__cell__.cell_contents
@@ -151,21 +173,42 @@ class ForwardRef:
if not self.__forward_is_class__ or param_name not in globals:
globals[param_name] = param
locals.pop(param_name, None)
+ if self.__extra_names__:
+ locals = {**locals, **self.__extra_names__}
arg = self.__forward_arg__
if arg.isidentifier() and not keyword.iskeyword(arg):
if arg in locals:
- value = locals[arg]
+ return locals[arg]
elif arg in globals:
- value = globals[arg]
+ return globals[arg]
elif hasattr(builtins, arg):
return getattr(builtins, arg)
+ elif is_forwardref_format:
+ return self
else:
- raise NameError(arg)
+ raise NameError(_NAME_ERROR_MSG.format(name=arg), name=arg)
else:
code = self.__forward_code__
- value = eval(code, globals=globals, locals=locals)
- return value
+ try:
+ return eval(code, globals=globals, locals=locals)
+ except Exception:
+ if not is_forwardref_format:
+ raise
+ new_locals = _StringifierDict(
+ {**builtins.__dict__, **locals},
+ globals=globals,
+ owner=owner,
+ is_class=self.__forward_is_class__,
+ format=format,
+ )
+ try:
+ result = eval(code, globals=globals, locals=new_locals)
+ except Exception:
+ return self
+ else:
+ new_locals.transmogrify()
+ return result
def _evaluate(self, globalns, localns, type_params=_sentinel, *, recursive_guard):
import typing
@@ -231,6 +274,10 @@ class ForwardRef:
and self.__forward_is_class__ == other.__forward_is_class__
and self.__cell__ == other.__cell__
and self.__owner__ == other.__owner__
+ and (
+ (tuple(sorted(self.__extra_names__.items())) if self.__extra_names__ else None) ==
+ (tuple(sorted(other.__extra_names__.items())) if other.__extra_names__ else None)
+ )
)
def __hash__(self):
@@ -241,6 +288,7 @@ class ForwardRef:
self.__forward_is_class__,
self.__cell__,
self.__owner__,
+ tuple(sorted(self.__extra_names__.items())) if self.__extra_names__ else None,
))
def __or__(self, other):
@@ -260,6 +308,9 @@ class ForwardRef:
return f"ForwardRef({self.__forward_arg__!r}{''.join(extra)})"
+_Template = type(t"")
+
+
class _Stringifier:
# Must match the slots on ForwardRef, so we can turn an instance of one into an
# instance of the other in place.
@@ -274,6 +325,7 @@ class _Stringifier:
cell=None,
*,
stringifier_dict,
+ extra_names=None,
):
# Either an AST node or a simple str (for the common case where a ForwardRef
# represent a single name).
@@ -285,6 +337,7 @@ class _Stringifier:
self.__code__ = None
self.__ast_node__ = node
self.__globals__ = globals
+ self.__extra_names__ = extra_names
self.__cell__ = cell
self.__owner__ = owner
self.__stringifier_dict__ = stringifier_dict
@@ -292,28 +345,65 @@ class _Stringifier:
def __convert_to_ast(self, other):
if isinstance(other, _Stringifier):
if isinstance(other.__ast_node__, str):
- return ast.Name(id=other.__ast_node__)
- return other.__ast_node__
- elif isinstance(other, slice):
+ return ast.Name(id=other.__ast_node__), other.__extra_names__
+ return other.__ast_node__, other.__extra_names__
+ elif type(other) is _Template:
+ return _template_to_ast(other), None
+ elif (
+ # In STRING format we don't bother with the create_unique_name() dance;
+ # it's better to emit the repr() of the object instead of an opaque name.
+ self.__stringifier_dict__.format == Format.STRING
+ or other is None
+ or type(other) in (str, int, float, bool, complex)
+ ):
+ return ast.Constant(value=other), None
+ elif type(other) is dict:
+ extra_names = {}
+ keys = []
+ values = []
+ for key, value in other.items():
+ new_key, new_extra_names = self.__convert_to_ast(key)
+ if new_extra_names is not None:
+ extra_names.update(new_extra_names)
+ keys.append(new_key)
+ new_value, new_extra_names = self.__convert_to_ast(value)
+ if new_extra_names is not None:
+ extra_names.update(new_extra_names)
+ values.append(new_value)
+ return ast.Dict(keys, values), extra_names
+ elif type(other) in (list, tuple, set):
+ extra_names = {}
+ elts = []
+ for elt in other:
+ new_elt, new_extra_names = self.__convert_to_ast(elt)
+ if new_extra_names is not None:
+ extra_names.update(new_extra_names)
+ elts.append(new_elt)
+ ast_class = {list: ast.List, tuple: ast.Tuple, set: ast.Set}[type(other)]
+ return ast_class(elts), extra_names
+ else:
+ name = self.__stringifier_dict__.create_unique_name()
+ return ast.Name(id=name), {name: other}
+
+ def __convert_to_ast_getitem(self, other):
+ if isinstance(other, slice):
+ extra_names = {}
+
+ def conv(obj):
+ if obj is None:
+ return None
+ new_obj, new_extra_names = self.__convert_to_ast(obj)
+ if new_extra_names is not None:
+ extra_names.update(new_extra_names)
+ return new_obj
+
return ast.Slice(
- lower=(
- self.__convert_to_ast(other.start)
- if other.start is not None
- else None
- ),
- upper=(
- self.__convert_to_ast(other.stop)
- if other.stop is not None
- else None
- ),
- step=(
- self.__convert_to_ast(other.step)
- if other.step is not None
- else None
- ),
- )
+ lower=conv(other.start),
+ upper=conv(other.stop),
+ step=conv(other.step),
+ ), extra_names
else:
- return ast.Constant(value=other)
+ return self.__convert_to_ast(other)
def __get_ast(self):
node = self.__ast_node__
@@ -321,13 +411,19 @@ class _Stringifier:
return ast.Name(id=node)
return node
- def __make_new(self, node):
+ def __make_new(self, node, extra_names=None):
+ new_extra_names = {}
+ if self.__extra_names__ is not None:
+ new_extra_names.update(self.__extra_names__)
+ if extra_names is not None:
+ new_extra_names.update(extra_names)
stringifier = _Stringifier(
node,
self.__globals__,
self.__owner__,
self.__forward_is_class__,
stringifier_dict=self.__stringifier_dict__,
+ extra_names=new_extra_names or None,
)
self.__stringifier_dict__.stringifiers.append(stringifier)
return stringifier
@@ -343,27 +439,37 @@ class _Stringifier:
if self.__ast_node__ == "__classdict__":
raise KeyError
if isinstance(other, tuple):
- elts = [self.__convert_to_ast(elt) for elt in other]
+ extra_names = {}
+ elts = []
+ for elt in other:
+ new_elt, new_extra_names = self.__convert_to_ast_getitem(elt)
+ if new_extra_names is not None:
+ extra_names.update(new_extra_names)
+ elts.append(new_elt)
other = ast.Tuple(elts)
else:
- other = self.__convert_to_ast(other)
+ other, extra_names = self.__convert_to_ast_getitem(other)
assert isinstance(other, ast.AST), repr(other)
- return self.__make_new(ast.Subscript(self.__get_ast(), other))
+ return self.__make_new(ast.Subscript(self.__get_ast(), other), extra_names)
def __getattr__(self, attr):
return self.__make_new(ast.Attribute(self.__get_ast(), attr))
def __call__(self, *args, **kwargs):
- return self.__make_new(
- ast.Call(
- self.__get_ast(),
- [self.__convert_to_ast(arg) for arg in args],
- [
- ast.keyword(key, self.__convert_to_ast(value))
- for key, value in kwargs.items()
- ],
- )
- )
+ extra_names = {}
+ ast_args = []
+ for arg in args:
+ new_arg, new_extra_names = self.__convert_to_ast(arg)
+ if new_extra_names is not None:
+ extra_names.update(new_extra_names)
+ ast_args.append(new_arg)
+ ast_kwargs = []
+ for key, value in kwargs.items():
+ new_value, new_extra_names = self.__convert_to_ast(value)
+ if new_extra_names is not None:
+ extra_names.update(new_extra_names)
+ ast_kwargs.append(ast.keyword(key, new_value))
+ return self.__make_new(ast.Call(self.__get_ast(), ast_args, ast_kwargs), extra_names)
def __iter__(self):
yield self.__make_new(ast.Starred(self.__get_ast()))
@@ -378,8 +484,9 @@ class _Stringifier:
def _make_binop(op: ast.AST):
def binop(self, other):
+ rhs, extra_names = self.__convert_to_ast(other)
return self.__make_new(
- ast.BinOp(self.__get_ast(), op, self.__convert_to_ast(other))
+ ast.BinOp(self.__get_ast(), op, rhs), extra_names
)
return binop
@@ -402,8 +509,9 @@ class _Stringifier:
def _make_rbinop(op: ast.AST):
def rbinop(self, other):
+ new_other, extra_names = self.__convert_to_ast(other)
return self.__make_new(
- ast.BinOp(self.__convert_to_ast(other), op, self.__get_ast())
+ ast.BinOp(new_other, op, self.__get_ast()), extra_names
)
return rbinop
@@ -426,12 +534,14 @@ class _Stringifier:
def _make_compare(op):
def compare(self, other):
+ rhs, extra_names = self.__convert_to_ast(other)
return self.__make_new(
ast.Compare(
left=self.__get_ast(),
ops=[op],
- comparators=[self.__convert_to_ast(other)],
- )
+ comparators=[rhs],
+ ),
+ extra_names,
)
return compare
@@ -458,14 +568,42 @@ class _Stringifier:
del _make_unary_op
+def _template_to_ast(template):
+ values = []
+ for part in template:
+ match part:
+ case str():
+ values.append(ast.Constant(value=part))
+ # Interpolation, but we don't want to import the string module
+ case _:
+ interp = ast.Interpolation(
+ str=part.expression,
+ value=ast.parse(part.expression),
+ conversion=(
+ ord(part.conversion)
+ if part.conversion is not None
+ else -1
+ ),
+ format_spec=(
+ ast.Constant(value=part.format_spec)
+ if part.format_spec != ""
+ else None
+ ),
+ )
+ values.append(interp)
+ return ast.TemplateStr(values=values)
+
+
class _StringifierDict(dict):
- def __init__(self, namespace, globals=None, owner=None, is_class=False):
+ def __init__(self, namespace, *, globals=None, owner=None, is_class=False, format):
super().__init__(namespace)
self.namespace = namespace
self.globals = globals
self.owner = owner
self.is_class = is_class
self.stringifiers = []
+ self.next_id = 1
+ self.format = format
def __missing__(self, key):
fwdref = _Stringifier(
@@ -478,6 +616,19 @@ class _StringifierDict(dict):
self.stringifiers.append(fwdref)
return fwdref
+ def transmogrify(self):
+ for obj in self.stringifiers:
+ obj.__class__ = ForwardRef
+ obj.__stringifier_dict__ = None # not needed for ForwardRef
+ if isinstance(obj.__ast_node__, str):
+ obj.__arg__ = obj.__ast_node__
+ obj.__ast_node__ = None
+
+ def create_unique_name(self):
+ name = f"__annotationlib_name_{self.next_id}__"
+ self.next_id += 1
+ return name
+
def call_evaluate_function(evaluate, format, *, owner=None):
"""Call an evaluate function. Evaluate functions are normally generated for
@@ -521,20 +672,11 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
# possibly constants if the annotate function uses them directly). We then
# convert each of those into a string to get an approximation of the
# original source.
- globals = _StringifierDict({})
- if annotate.__closure__:
- freevars = annotate.__code__.co_freevars
- new_closure = []
- for i, cell in enumerate(annotate.__closure__):
- if i < len(freevars):
- name = freevars[i]
- else:
- name = "__cell__"
- fwdref = _Stringifier(name, stringifier_dict=globals)
- new_closure.append(types.CellType(fwdref))
- closure = tuple(new_closure)
- else:
- closure = None
+ globals = _StringifierDict({}, format=format)
+ is_class = isinstance(owner, type)
+ closure = _build_closure(
+ annotate, owner, is_class, globals, allow_evaluation=False
+ )
func = types.FunctionType(
annotate.__code__,
globals,
@@ -544,9 +686,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
)
annos = func(Format.VALUE_WITH_FAKE_GLOBALS)
if _is_evaluate:
- return annos if isinstance(annos, str) else repr(annos)
+ return _stringify_single(annos)
return {
- key: val if isinstance(val, str) else repr(val)
+ key: _stringify_single(val)
for key, val in annos.items()
}
elif format == Format.FORWARDREF:
@@ -569,33 +711,43 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
 # that returns a bool and a defined set of attributes.
namespace = {**annotate.__builtins__, **annotate.__globals__}
is_class = isinstance(owner, type)
- globals = _StringifierDict(namespace, annotate.__globals__, owner, is_class)
- if annotate.__closure__:
- freevars = annotate.__code__.co_freevars
- new_closure = []
- for i, cell in enumerate(annotate.__closure__):
- try:
- cell.cell_contents
- except ValueError:
- if i < len(freevars):
- name = freevars[i]
- else:
- name = "__cell__"
- fwdref = _Stringifier(
- name,
- cell=cell,
- owner=owner,
- globals=annotate.__globals__,
- is_class=is_class,
- stringifier_dict=globals,
- )
- globals.stringifiers.append(fwdref)
- new_closure.append(types.CellType(fwdref))
- else:
- new_closure.append(cell)
- closure = tuple(new_closure)
+ globals = _StringifierDict(
+ namespace,
+ globals=annotate.__globals__,
+ owner=owner,
+ is_class=is_class,
+ format=format,
+ )
+ closure = _build_closure(
+ annotate, owner, is_class, globals, allow_evaluation=True
+ )
+ func = types.FunctionType(
+ annotate.__code__,
+ globals,
+ closure=closure,
+ argdefs=annotate.__defaults__,
+ kwdefaults=annotate.__kwdefaults__,
+ )
+ try:
+ result = func(Format.VALUE_WITH_FAKE_GLOBALS)
+ except Exception:
+ pass
else:
- closure = None
+ globals.transmogrify()
+ return result
+
+ # Try again, but do not provide any globals. This allows us to return
+ # a value in certain cases where an exception gets raised during evaluation.
+ globals = _StringifierDict(
+ {},
+ globals=annotate.__globals__,
+ owner=owner,
+ is_class=is_class,
+ format=format,
+ )
+ closure = _build_closure(
+ annotate, owner, is_class, globals, allow_evaluation=False
+ )
func = types.FunctionType(
annotate.__code__,
globals,
@@ -604,13 +756,21 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
kwdefaults=annotate.__kwdefaults__,
)
result = func(Format.VALUE_WITH_FAKE_GLOBALS)
- for obj in globals.stringifiers:
- obj.__class__ = ForwardRef
- obj.__stringifier_dict__ = None # not needed for ForwardRef
- if isinstance(obj.__ast_node__, str):
- obj.__arg__ = obj.__ast_node__
- obj.__ast_node__ = None
- return result
+ globals.transmogrify()
+ if _is_evaluate:
+ if isinstance(result, ForwardRef):
+ return result.evaluate(format=Format.FORWARDREF)
+ else:
+ return result
+ else:
+ return {
+ key: (
+ val.evaluate(format=Format.FORWARDREF)
+ if isinstance(val, ForwardRef)
+ else val
+ )
+ for key, val in result.items()
+ }
elif format == Format.VALUE:
# Should be impossible because __annotate__ functions must not raise
# NotImplementedError for this format.
@@ -619,20 +779,61 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
raise ValueError(f"Invalid format: {format!r}")
-def get_annotate_function(obj):
- """Get the __annotate__ function for an object.
+def _build_closure(annotate, owner, is_class, stringifier_dict, *, allow_evaluation):
+ if not annotate.__closure__:
+ return None
+ freevars = annotate.__code__.co_freevars
+ new_closure = []
+ for i, cell in enumerate(annotate.__closure__):
+ if i < len(freevars):
+ name = freevars[i]
+ else:
+ name = "__cell__"
+ new_cell = None
+ if allow_evaluation:
+ try:
+ cell.cell_contents
+ except ValueError:
+ pass
+ else:
+ new_cell = cell
+ if new_cell is None:
+ fwdref = _Stringifier(
+ name,
+ cell=cell,
+ owner=owner,
+ globals=annotate.__globals__,
+ is_class=is_class,
+ stringifier_dict=stringifier_dict,
+ )
+ stringifier_dict.stringifiers.append(fwdref)
+ new_cell = types.CellType(fwdref)
+ new_closure.append(new_cell)
+ return tuple(new_closure)
+
+
+def _stringify_single(anno):
+ if anno is ...:
+ return "..."
+ # We have to handle str specially to support PEP 563 stringified annotations.
+ elif isinstance(anno, str):
+ return anno
+ elif isinstance(anno, _Template):
+ return ast.unparse(_template_to_ast(anno))
+ else:
+ return repr(anno)
+
- obj may be a function, class, or module, or a user-defined type with
- an `__annotate__` attribute.
+def get_annotate_from_class_namespace(obj):
+ """Retrieve the annotate function from a class namespace dictionary.
- Returns the __annotate__ function or None.
+ Return None if the namespace does not contain an annotate function.
+ This is useful in metaclass ``__new__`` methods to retrieve the annotate function.
"""
- if isinstance(obj, dict):
- try:
- return obj["__annotate__"]
- except KeyError:
- return obj.get("__annotate_func__", None)
- return getattr(obj, "__annotate__", None)
+ try:
+ return obj["__annotate__"]
+ except KeyError:
+ return obj.get("__annotate_func__", None)
def get_annotations(
@@ -703,7 +904,7 @@ def get_annotations(
# For FORWARDREF, we use __annotations__ if it exists
try:
ann = _get_dunder_annotations(obj)
- except NameError:
+ except Exception:
pass
else:
if ann is not None:
@@ -724,7 +925,7 @@ def get_annotations(
# But if we didn't get it, we use __annotations__ instead.
ann = _get_dunder_annotations(obj)
if ann is not None:
- return annotations_to_string(ann)
+ return annotations_to_string(ann)
case Format.VALUE_WITH_FAKE_GLOBALS:
raise ValueError("The VALUE_WITH_FAKE_GLOBALS format is for internal use only")
case _:
@@ -741,48 +942,49 @@ def get_annotations(
if not eval_str:
return dict(ann)
- if isinstance(obj, type):
- # class
- obj_globals = None
- module_name = getattr(obj, "__module__", None)
- if module_name:
- module = sys.modules.get(module_name, None)
- if module:
- obj_globals = getattr(module, "__dict__", None)
- obj_locals = dict(vars(obj))
- unwrap = obj
- elif isinstance(obj, types.ModuleType):
- # module
- obj_globals = getattr(obj, "__dict__")
- obj_locals = None
- unwrap = None
- elif callable(obj):
- # this includes types.Function, types.BuiltinFunctionType,
- # types.BuiltinMethodType, functools.partial, functools.singledispatch,
- # "class funclike" from Lib/test/test_inspect... on and on it goes.
- obj_globals = getattr(obj, "__globals__", None)
- obj_locals = None
- unwrap = obj
- else:
- obj_globals = obj_locals = unwrap = None
-
- if unwrap is not None:
- while True:
- if hasattr(unwrap, "__wrapped__"):
- unwrap = unwrap.__wrapped__
- continue
- if functools := sys.modules.get("functools"):
- if isinstance(unwrap, functools.partial):
- unwrap = unwrap.func
+ if globals is None or locals is None:
+ if isinstance(obj, type):
+ # class
+ obj_globals = None
+ module_name = getattr(obj, "__module__", None)
+ if module_name:
+ module = sys.modules.get(module_name, None)
+ if module:
+ obj_globals = getattr(module, "__dict__", None)
+ obj_locals = dict(vars(obj))
+ unwrap = obj
+ elif isinstance(obj, types.ModuleType):
+ # module
+ obj_globals = getattr(obj, "__dict__")
+ obj_locals = None
+ unwrap = None
+ elif callable(obj):
+ # this includes types.Function, types.BuiltinFunctionType,
+ # types.BuiltinMethodType, functools.partial, functools.singledispatch,
+ # "class funclike" from Lib/test/test_inspect... on and on it goes.
+ obj_globals = getattr(obj, "__globals__", None)
+ obj_locals = None
+ unwrap = obj
+ else:
+ obj_globals = obj_locals = unwrap = None
+
+ if unwrap is not None:
+ while True:
+ if hasattr(unwrap, "__wrapped__"):
+ unwrap = unwrap.__wrapped__
continue
- break
- if hasattr(unwrap, "__globals__"):
- obj_globals = unwrap.__globals__
+ if functools := sys.modules.get("functools"):
+ if isinstance(unwrap, functools.partial):
+ unwrap = unwrap.func
+ continue
+ break
+ if hasattr(unwrap, "__globals__"):
+ obj_globals = unwrap.__globals__
- if globals is None:
- globals = obj_globals
- if locals is None:
- locals = obj_locals
+ if globals is None:
+ globals = obj_globals
+ if locals is None:
+ locals = obj_locals
# "Inject" type parameters into the local namespace
# (unless they are shadowed by assignments *in* the local namespace),
@@ -811,6 +1013,9 @@ def type_repr(value):
if value.__module__ == "builtins":
return value.__qualname__
return f"{value.__module__}.{value.__qualname__}"
+ elif isinstance(value, _Template):
+ tree = _template_to_ast(value)
+ return ast.unparse(tree)
if value is ...:
return "..."
return repr(value)
@@ -832,7 +1037,7 @@ def _get_and_call_annotate(obj, format):
May not return a fresh dictionary.
"""
- annotate = get_annotate_function(obj)
+ annotate = getattr(obj, "__annotate__", None)
if annotate is not None:
ann = call_annotate_function(annotate, format, owner=obj)
if not isinstance(ann, dict):
@@ -841,14 +1046,27 @@ def _get_and_call_annotate(obj, format):
return None
+_BASE_GET_ANNOTATIONS = type.__dict__["__annotations__"].__get__
+
+
def _get_dunder_annotations(obj):
"""Return the annotations for an object, checking that it is a dictionary.
Does not return a fresh dictionary.
"""
- ann = getattr(obj, "__annotations__", None)
- if ann is None:
- return None
+ # This special case is needed to support types defined under
+ # from __future__ import annotations, where accessing the __annotations__
+ # attribute directly might return annotations for the wrong class.
+ if isinstance(obj, type):
+ try:
+ ann = _BASE_GET_ANNOTATIONS(obj)
+ except AttributeError:
+ # For static types, the descriptor raises AttributeError.
+ return None
+ else:
+ ann = getattr(obj, "__annotations__", None)
+ if ann is None:
+ return None
if not isinstance(ann, dict):
raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")
diff --git a/Lib/argparse.py b/Lib/argparse.py
index c0dcd0bbff0..83258cf3e0f 100644
--- a/Lib/argparse.py
+++ b/Lib/argparse.py
@@ -167,7 +167,6 @@ class HelpFormatter(object):
indent_increment=2,
max_help_position=24,
width=None,
- prefix_chars='-',
color=False,
):
# default setting for width
@@ -176,16 +175,7 @@ class HelpFormatter(object):
width = shutil.get_terminal_size().columns
width -= 2
- from _colorize import ANSIColors, NoColors, can_colorize, decolor
-
- if color and can_colorize():
- self._ansi = ANSIColors()
- self._decolor = decolor
- else:
- self._ansi = NoColors
- self._decolor = lambda text: text
-
- self._prefix_chars = prefix_chars
+ self._set_color(color)
self._prog = prog
self._indent_increment = indent_increment
self._max_help_position = min(max_help_position,
@@ -202,9 +192,20 @@ class HelpFormatter(object):
self._whitespace_matcher = _re.compile(r'\s+', _re.ASCII)
self._long_break_matcher = _re.compile(r'\n\n\n+')
+ def _set_color(self, color):
+ from _colorize import can_colorize, decolor, get_theme
+
+ if color and can_colorize():
+ self._theme = get_theme(force_color=True).argparse
+ self._decolor = decolor
+ else:
+ self._theme = get_theme(force_no_color=True).argparse
+ self._decolor = lambda text: text
+
# ===============================
# Section and indentation methods
# ===============================
+
def _indent(self):
self._current_indent += self._indent_increment
self._level += 1
@@ -237,14 +238,12 @@ class HelpFormatter(object):
# add the heading if the section was non-empty
if self.heading is not SUPPRESS and self.heading is not None:
- bold_blue = self.formatter._ansi.BOLD_BLUE
- reset = self.formatter._ansi.RESET
-
current_indent = self.formatter._current_indent
heading_text = _('%(heading)s:') % dict(heading=self.heading)
+ t = self.formatter._theme
heading = (
f'{" " * current_indent}'
- f'{bold_blue}{heading_text}{reset}\n'
+ f'{t.heading}{heading_text}{t.reset}\n'
)
else:
heading = ''
@@ -258,6 +257,7 @@ class HelpFormatter(object):
# ========================
# Message building methods
# ========================
+
def start_section(self, heading):
self._indent()
section = self._Section(self, self._current_section, heading)
@@ -301,6 +301,7 @@ class HelpFormatter(object):
# =======================
# Help-formatting methods
# =======================
+
def format_help(self):
help = self._root_section.format_help()
if help:
@@ -314,10 +315,7 @@ class HelpFormatter(object):
if part and part is not SUPPRESS])
def _format_usage(self, usage, actions, groups, prefix):
- bold_blue = self._ansi.BOLD_BLUE
- bold_magenta = self._ansi.BOLD_MAGENTA
- magenta = self._ansi.MAGENTA
- reset = self._ansi.RESET
+ t = self._theme
if prefix is None:
prefix = _('usage: ')
@@ -325,15 +323,15 @@ class HelpFormatter(object):
# if usage is specified, use that
if usage is not None:
usage = (
- magenta
+ t.prog_extra
+ usage
- % {"prog": f"{bold_magenta}{self._prog}{reset}{magenta}"}
- + reset
+ % {"prog": f"{t.prog}{self._prog}{t.reset}{t.prog_extra}"}
+ + t.reset
)
# if no optionals or positionals are available, usage is just prog
elif usage is None and not actions:
- usage = f"{bold_magenta}{self._prog}{reset}"
+ usage = f"{t.prog}{self._prog}{t.reset}"
# if optionals and positionals are available, calculate usage
elif usage is None:
@@ -411,23 +409,16 @@ class HelpFormatter(object):
usage = '\n'.join(lines)
usage = usage.removeprefix(prog)
- usage = f"{bold_magenta}{prog}{reset}{usage}"
+ usage = f"{t.prog}{prog}{t.reset}{usage}"
# prefix with 'usage:'
- return f'{bold_blue}{prefix}{reset}{usage}\n\n'
+ return f'{t.usage}{prefix}{t.reset}{usage}\n\n'
def _format_actions_usage(self, actions, groups):
return ' '.join(self._get_actions_usage_parts(actions, groups))
def _is_long_option(self, string):
- return len(string) >= 2 and string[1] in self._prefix_chars
-
- def _is_short_option(self, string):
- return (
- not self._is_long_option(string)
- and len(string) >= 1
- and string[0] in self._prefix_chars
- )
+ return len(string) > 2
def _get_actions_usage_parts(self, actions, groups):
# find group indices and identify actions in groups
@@ -452,10 +443,7 @@ class HelpFormatter(object):
# collect all actions format strings
parts = []
- cyan = self._ansi.CYAN
- green = self._ansi.GREEN
- yellow = self._ansi.YELLOW
- reset = self._ansi.RESET
+ t = self._theme
for action in actions:
# suppressed arguments are marked with None
@@ -465,7 +453,11 @@ class HelpFormatter(object):
# produce all arg strings
elif not action.option_strings:
default = self._get_default_metavar_for_positional(action)
- part = green + self._format_args(action, default) + reset
+ part = (
+ t.summary_action
+ + self._format_args(action, default)
+ + t.reset
+ )
# if it's in a group, strip the outer []
if action in group_actions:
@@ -475,26 +467,26 @@ class HelpFormatter(object):
# produce the first way to invoke the option in brackets
else:
option_string = action.option_strings[0]
+ if self._is_long_option(option_string):
+ option_color = t.summary_long_option
+ else:
+ option_color = t.summary_short_option
# if the Optional doesn't take a value, format is:
# -s or --long
if action.nargs == 0:
part = action.format_usage()
- if self._is_long_option(part):
- part = f"{cyan}{part}{reset}"
- elif self._is_short_option(part):
- part = f"{green}{part}{reset}"
+ part = f"{option_color}{part}{t.reset}"
# if the Optional takes a value, format is:
# -s ARGS or --long ARGS
else:
default = self._get_default_metavar_for_optional(action)
args_string = self._format_args(action, default)
- if self._is_long_option(option_string):
- option_string = f"{cyan}{option_string}"
- elif self._is_short_option(option_string):
- option_string = f"{green}{option_string}"
- part = f"{option_string} {yellow}{args_string}{reset}"
+ part = (
+ f"{option_color}{option_string} "
+ f"{t.summary_label}{args_string}{t.reset}"
+ )
# make it look optional if it's not required or in a group
if not action.required and action not in group_actions:
@@ -590,17 +582,14 @@ class HelpFormatter(object):
return self._join_parts(parts)
def _format_action_invocation(self, action):
- bold_green = self._ansi.BOLD_GREEN
- bold_cyan = self._ansi.BOLD_CYAN
- bold_yellow = self._ansi.BOLD_YELLOW
- reset = self._ansi.RESET
+ t = self._theme
if not action.option_strings:
default = self._get_default_metavar_for_positional(action)
return (
- bold_green
+ t.action
+ ' '.join(self._metavar_formatter(action, default)(1))
- + reset
+ + t.reset
)
else:
@@ -609,11 +598,9 @@ class HelpFormatter(object):
parts = []
for s in strings:
if self._is_long_option(s):
- parts.append(f"{bold_cyan}{s}{reset}")
- elif self._is_short_option(s):
- parts.append(f"{bold_green}{s}{reset}")
+ parts.append(f"{t.long_option}{s}{t.reset}")
else:
- parts.append(s)
+ parts.append(f"{t.short_option}{s}{t.reset}")
return parts
# if the Optional doesn't take a value, format is:
@@ -628,7 +615,7 @@ class HelpFormatter(object):
default = self._get_default_metavar_for_optional(action)
option_strings = color_option_strings(action.option_strings)
args_string = (
- f"{bold_yellow}{self._format_args(action, default)}{reset}"
+ f"{t.label}{self._format_args(action, default)}{t.reset}"
)
return ', '.join(option_strings) + ' ' + args_string
@@ -1483,6 +1470,7 @@ class _ActionsContainer(object):
# ====================
# Registration methods
# ====================
+
def register(self, registry_name, value, object):
registry = self._registries.setdefault(registry_name, {})
registry[value] = object
@@ -1493,6 +1481,7 @@ class _ActionsContainer(object):
# ==================================
# Namespace default accessor methods
# ==================================
+
def set_defaults(self, **kwargs):
self._defaults.update(kwargs)
@@ -1512,6 +1501,7 @@ class _ActionsContainer(object):
# =======================
# Adding argument actions
# =======================
+
def add_argument(self, *args, **kwargs):
"""
add_argument(dest, ..., name=value, ...)
@@ -1544,7 +1534,7 @@ class _ActionsContainer(object):
action_name = kwargs.get('action')
action_class = self._pop_action_class(kwargs)
if not callable(action_class):
- raise ValueError('unknown action {action_class!r}')
+ raise ValueError(f'unknown action {action_class!r}')
action = action_class(**kwargs)
# raise an error if action for positional argument does not
@@ -1937,6 +1927,7 @@ class ArgumentParser(_AttributeHolder, _ActionsContainer):
# =======================
# Pretty __repr__ methods
# =======================
+
def _get_kwargs(self):
names = [
'prog',
@@ -1951,6 +1942,7 @@ class ArgumentParser(_AttributeHolder, _ActionsContainer):
# ==================================
# Optional/Positional adding methods
# ==================================
+
def add_subparsers(self, **kwargs):
if self._subparsers is not None:
raise ValueError('cannot have multiple subparser arguments')
@@ -2004,6 +1996,7 @@ class ArgumentParser(_AttributeHolder, _ActionsContainer):
# =====================================
# Command line argument parsing methods
# =====================================
+
def parse_args(self, args=None, namespace=None):
args, argv = self.parse_known_args(args, namespace)
if argv:
@@ -2598,6 +2591,7 @@ class ArgumentParser(_AttributeHolder, _ActionsContainer):
# ========================
# Value conversion methods
# ========================
+
def _get_values(self, action, arg_strings):
# optional argument produces a default when not present
if not arg_strings and action.nargs == OPTIONAL:
@@ -2697,6 +2691,7 @@ class ArgumentParser(_AttributeHolder, _ActionsContainer):
# =======================
# Help-formatting methods
# =======================
+
def format_usage(self):
formatter = self._get_formatter()
formatter.add_usage(self.usage, self._actions,
@@ -2727,20 +2722,14 @@ class ArgumentParser(_AttributeHolder, _ActionsContainer):
return formatter.format_help()
def _get_formatter(self):
- if isinstance(self.formatter_class, type) and issubclass(
- self.formatter_class, HelpFormatter
- ):
- return self.formatter_class(
- prog=self.prog,
- prefix_chars=self.prefix_chars,
- color=self.color,
- )
- else:
- return self.formatter_class(prog=self.prog)
+ formatter = self.formatter_class(prog=self.prog)
+ formatter._set_color(self.color)
+ return formatter
# =====================
# Help-printing methods
# =====================
+
def print_usage(self, file=None):
if file is None:
file = _sys.stdout
@@ -2762,6 +2751,7 @@ class ArgumentParser(_AttributeHolder, _ActionsContainer):
# ===============
# Exiting methods
# ===============
+
def exit(self, status=0, message=None):
if message:
self._print_message(message, _sys.stderr)
diff --git a/Lib/ast.py b/Lib/ast.py
index aa788e6eb62..6d3daf64f5c 100644
--- a/Lib/ast.py
+++ b/Lib/ast.py
@@ -147,18 +147,22 @@ def dump(
if value is None and getattr(cls, name, ...) is None:
keywords = True
continue
- if (
- not show_empty
- and (value is None or value == [])
- # Special cases:
- # `Constant(value=None)` and `MatchSingleton(value=None)`
- and not isinstance(node, (Constant, MatchSingleton))
- ):
- args_buffer.append(repr(value))
- continue
- elif not keywords:
- args.extend(args_buffer)
- args_buffer = []
+ if not show_empty:
+ if value == []:
+ field_type = cls._field_types.get(name, object)
+ if getattr(field_type, '__origin__', ...) is list:
+ if not keywords:
+ args_buffer.append(repr(value))
+ continue
+ elif isinstance(value, Load):
+ field_type = cls._field_types.get(name, object)
+ if field_type is expr_context:
+ if not keywords:
+ args_buffer.append(repr(value))
+ continue
+ if not keywords:
+ args.extend(args_buffer)
+ args_buffer = []
value, simple = _format(value, level)
allsimple = allsimple and simple
if keywords:
@@ -626,11 +630,11 @@ def unparse(ast_obj):
return unparser.visit(ast_obj)
-def main():
+def main(args=None):
import argparse
import sys
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument('infile', nargs='?', default='-',
help='the file to parse; defaults to stdin')
parser.add_argument('-m', '--mode', default='exec',
@@ -643,7 +647,16 @@ def main():
'column offsets')
parser.add_argument('-i', '--indent', type=int, default=3,
help='indentation of nodes (number of spaces)')
- args = parser.parse_args()
+ parser.add_argument('--feature-version',
+ type=str, default=None, metavar='VERSION',
+ help='Python version in the format 3.x '
+ '(for example, 3.10)')
+ parser.add_argument('-O', '--optimize',
+ type=int, default=-1, metavar='LEVEL',
+ help='optimization level for parser (default -1)')
+ parser.add_argument('--show-empty', default=False, action='store_true',
+ help='show empty lists and fields in dump output')
+ args = parser.parse_args(args)
if args.infile == '-':
name = '<stdin>'
@@ -652,8 +665,22 @@ def main():
name = args.infile
with open(args.infile, 'rb') as infile:
source = infile.read()
- tree = parse(source, name, args.mode, type_comments=args.no_type_comments)
- print(dump(tree, include_attributes=args.include_attributes, indent=args.indent))
+
+ # Process feature_version
+ feature_version = None
+ if args.feature_version:
+ try:
+ major, minor = map(int, args.feature_version.split('.', 1))
+ except ValueError:
+ parser.error('Invalid format for --feature-version; '
+ 'expected format 3.x (for example, 3.10)')
+
+ feature_version = (major, minor)
+
+ tree = parse(source, name, args.mode, type_comments=args.no_type_comments,
+ feature_version=feature_version, optimize=args.optimize)
+ print(dump(tree, include_attributes=args.include_attributes,
+ indent=args.indent, show_empty=args.show_empty))
if __name__ == '__main__':
main()
diff --git a/Lib/asyncio/__main__.py b/Lib/asyncio/__main__.py
index 69f5a30cfe5..21ca5c5f62a 100644
--- a/Lib/asyncio/__main__.py
+++ b/Lib/asyncio/__main__.py
@@ -1,5 +1,7 @@
+import argparse
import ast
import asyncio
+import asyncio.tools
import concurrent.futures
import contextvars
import inspect
@@ -10,7 +12,7 @@ import threading
import types
import warnings
-from _colorize import can_colorize, ANSIColors # type: ignore[import-not-found]
+from _colorize import get_theme
from _pyrepl.console import InteractiveColoredConsole
from . import futures
@@ -101,8 +103,9 @@ class REPLThread(threading.Thread):
exec(startup_code, console.locals)
ps1 = getattr(sys, "ps1", ">>> ")
- if can_colorize() and CAN_USE_PYREPL:
- ps1 = f"{ANSIColors.BOLD_MAGENTA}{ps1}{ANSIColors.RESET}"
+ if CAN_USE_PYREPL:
+ theme = get_theme().syntax
+ ps1 = f"{theme.prompt}{ps1}{theme.reset}"
console.write(f"{ps1}import asyncio\n")
if CAN_USE_PYREPL:
@@ -140,6 +143,37 @@ class REPLThread(threading.Thread):
if __name__ == '__main__':
+ parser = argparse.ArgumentParser(
+ prog="python3 -m asyncio",
+ description="Interactive asyncio shell and CLI tools",
+ color=True,
+ )
+ subparsers = parser.add_subparsers(help="sub-commands", dest="command")
+ ps = subparsers.add_parser(
+ "ps", help="Display a table of all pending tasks in a process"
+ )
+ ps.add_argument("pid", type=int, help="Process ID to inspect")
+ pstree = subparsers.add_parser(
+ "pstree", help="Display a tree of all pending tasks in a process"
+ )
+ pstree.add_argument("pid", type=int, help="Process ID to inspect")
+ args = parser.parse_args()
+ match args.command:
+ case "ps":
+ asyncio.tools.display_awaited_by_tasks_table(args.pid)
+ sys.exit(0)
+ case "pstree":
+ asyncio.tools.display_awaited_by_tasks_tree(args.pid)
+ sys.exit(0)
+ case None:
+ pass # continue to the interactive shell
+ case _:
+ # shouldn't happen as an invalid command-line wouldn't parse
+ # but let's keep it for the next person adding a command
+ print(f"error: unhandled command {args.command}", file=sys.stderr)
+ parser.print_usage(file=sys.stderr)
+ sys.exit(1)
+
sys.audit("cpython.run_stdin")
if os.getenv('PYTHON_BASIC_REPL'):
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index 29b872ce00e..04fb961e998 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -459,7 +459,7 @@ class BaseEventLoop(events.AbstractEventLoop):
return futures.Future(loop=self)
def create_task(self, coro, **kwargs):
- """Schedule a coroutine object.
+ """Schedule or begin executing a coroutine object.
Return a task object.
"""
diff --git a/Lib/asyncio/base_subprocess.py b/Lib/asyncio/base_subprocess.py
index 9c2ba679ce2..d40af422e61 100644
--- a/Lib/asyncio/base_subprocess.py
+++ b/Lib/asyncio/base_subprocess.py
@@ -104,7 +104,12 @@ class BaseSubprocessTransport(transports.SubprocessTransport):
for proto in self._pipes.values():
if proto is None:
continue
- proto.pipe.close()
+ # See gh-114177
+ # skip closing the pipe if loop is already closed
+ # this can happen e.g. when loop is closed immediately after
+ # process is killed
+ if self._loop and not self._loop.is_closed():
+ proto.pipe.close()
if (self._proc is not None and
# has the child process finished?
diff --git a/Lib/asyncio/futures.py b/Lib/asyncio/futures.py
index d1df6707302..6bd00a64478 100644
--- a/Lib/asyncio/futures.py
+++ b/Lib/asyncio/futures.py
@@ -351,22 +351,19 @@ def _set_concurrent_future_state(concurrent, source):
def _copy_future_state(source, dest):
"""Internal helper to copy state from another Future.
- The other Future may be a concurrent.futures.Future.
+ The other Future must be a concurrent.futures.Future.
"""
- assert source.done()
if dest.cancelled():
return
assert not dest.done()
- if source.cancelled():
+ done, cancelled, result, exception = source._get_snapshot()
+ assert done
+ if cancelled:
dest.cancel()
+ elif exception is not None:
+ dest.set_exception(_convert_future_exc(exception))
else:
- exception = source.exception()
- if exception is not None:
- dest.set_exception(_convert_future_exc(exception))
- else:
- result = source.result()
- dest.set_result(result)
-
+ dest.set_result(result)
def _chain_future(source, destination):
"""Chain two futures so that when one completes, so does the other.
diff --git a/Lib/asyncio/graph.py b/Lib/asyncio/graph.py
index d8df7c9919a..b5bfeb1630a 100644
--- a/Lib/asyncio/graph.py
+++ b/Lib/asyncio/graph.py
@@ -1,6 +1,7 @@
"""Introspection utils for tasks call graphs."""
import dataclasses
+import io
import sys
import types
@@ -16,9 +17,6 @@ __all__ = (
'FutureCallGraph',
)
-if False: # for type checkers
- from typing import TextIO
-
# Sadly, we can't re-use the traceback module's datastructures as those
# are tailored for error reporting, whereas we need to represent an
# async call graph.
@@ -270,7 +268,7 @@ def print_call_graph(
future: futures.Future | None = None,
/,
*,
- file: TextIO | None = None,
+ file: io.Writer[str] | None = None,
depth: int = 1,
limit: int | None = None,
) -> None:
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
index 22147451fa7..6ad84044adf 100644
--- a/Lib/asyncio/selector_events.py
+++ b/Lib/asyncio/selector_events.py
@@ -173,7 +173,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
# listening socket has triggered an EVENT_READ. There may be multiple
# connections waiting for an .accept() so it is called in a loop.
# See https://bugs.python.org/issue27906 for more details.
- for _ in range(backlog):
+ for _ in range(backlog + 1):
try:
conn, addr = sock.accept()
if self._debug:
diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py
index 1633478d1c8..00e8f6d5d1a 100644
--- a/Lib/asyncio/taskgroups.py
+++ b/Lib/asyncio/taskgroups.py
@@ -179,7 +179,7 @@ class TaskGroup:
exc = None
- def create_task(self, coro, *, name=None, context=None):
+ def create_task(self, coro, **kwargs):
"""Create a new task in this group and return it.
Similar to `asyncio.create_task`.
@@ -193,10 +193,7 @@ class TaskGroup:
if self._aborting:
coro.close()
raise RuntimeError(f"TaskGroup {self!r} is shutting down")
- if context is None:
- task = self._loop.create_task(coro, name=name)
- else:
- task = self._loop.create_task(coro, name=name, context=context)
+ task = self._loop.create_task(coro, **kwargs)
futures.future_add_to_awaited_by(task, self._parent_task)
diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py
index 825e91f5594..fbd5c39a7c5 100644
--- a/Lib/asyncio/tasks.py
+++ b/Lib/asyncio/tasks.py
@@ -386,19 +386,13 @@ else:
Task = _CTask = _asyncio.Task
-def create_task(coro, *, name=None, context=None):
+def create_task(coro, **kwargs):
"""Schedule the execution of a coroutine object in a spawn task.
Return a Task object.
"""
loop = events.get_running_loop()
- if context is None:
- # Use legacy API if context is not needed
- task = loop.create_task(coro, name=name)
- else:
- task = loop.create_task(coro, name=name, context=context)
-
- return task
+ return loop.create_task(coro, **kwargs)
# wait() and as_completed() similar to those in PEP 3148.
@@ -914,6 +908,25 @@ def gather(*coros_or_futures, return_exceptions=False):
return outer
+def _log_on_exception(fut):
+ if fut.cancelled():
+ return
+
+ exc = fut.exception()
+ if exc is None:
+ return
+
+ context = {
+ 'message':
+ f'{exc.__class__.__name__} exception in shielded future',
+ 'exception': exc,
+ 'future': fut,
+ }
+ if fut._source_traceback:
+ context['source_traceback'] = fut._source_traceback
+ fut._loop.call_exception_handler(context)
+
+
def shield(arg):
"""Wait for a future, shielding it from cancellation.
@@ -959,14 +972,11 @@ def shield(arg):
else:
cur_task = None
- def _inner_done_callback(inner, cur_task=cur_task):
- if cur_task is not None:
- futures.future_discard_from_awaited_by(inner, cur_task)
+ def _clear_awaited_by_callback(inner):
+ futures.future_discard_from_awaited_by(inner, cur_task)
+ def _inner_done_callback(inner):
if outer.cancelled():
- if not inner.cancelled():
- # Mark inner's result as retrieved.
- inner.exception()
return
if inner.cancelled():
@@ -978,10 +988,16 @@ def shield(arg):
else:
outer.set_result(inner.result())
-
def _outer_done_callback(outer):
if not inner.done():
inner.remove_done_callback(_inner_done_callback)
+ # Keep only one callback to log on cancel
+ inner.remove_done_callback(_log_on_exception)
+ inner.add_done_callback(_log_on_exception)
+
+ if cur_task is not None:
+ inner.add_done_callback(_clear_awaited_by_callback)
+
inner.add_done_callback(_inner_done_callback)
outer.add_done_callback(_outer_done_callback)
@@ -1030,9 +1046,9 @@ def create_eager_task_factory(custom_task_constructor):
used. E.g. `loop.set_task_factory(asyncio.eager_task_factory)`.
"""
- def factory(loop, coro, *, name=None, context=None):
+ def factory(loop, coro, *, eager_start=True, **kwargs):
return custom_task_constructor(
- coro, loop=loop, name=name, context=context, eager_start=True)
+ coro, loop=loop, eager_start=eager_start, **kwargs)
return factory
diff --git a/Lib/asyncio/tools.py b/Lib/asyncio/tools.py
new file mode 100644
index 00000000000..2683f34cc71
--- /dev/null
+++ b/Lib/asyncio/tools.py
@@ -0,0 +1,260 @@
+"""Tools to analyze tasks running in asyncio programs."""
+
+from collections import defaultdict, namedtuple
+from itertools import count
+from enum import Enum
+import sys
+from _remote_debugging import RemoteUnwinder, FrameInfo
+
+class NodeType(Enum):
+ COROUTINE = 1
+ TASK = 2
+
+
+class CycleFoundException(Exception):
+ """Raised when there is a cycle when drawing the call tree."""
+ def __init__(
+ self,
+ cycles: list[list[int]],
+ id2name: dict[int, str],
+ ) -> None:
+ super().__init__(cycles, id2name)
+ self.cycles = cycles
+ self.id2name = id2name
+
+
+
+# ─── indexing helpers ───────────────────────────────────────────
+def _format_stack_entry(elem: str|FrameInfo) -> str:
+ if not isinstance(elem, str):
+ if elem.lineno == 0 and elem.filename == "":
+ return f"{elem.funcname}"
+ else:
+ return f"{elem.funcname} {elem.filename}:{elem.lineno}"
+ return elem
+
+
+def _index(result):
+ id2name, awaits, task_stacks = {}, [], {}
+ for awaited_info in result:
+ for task_info in awaited_info.awaited_by:
+ task_id = task_info.task_id
+ task_name = task_info.task_name
+ id2name[task_id] = task_name
+
+ # Store the internal coroutine stack for this task
+ if task_info.coroutine_stack:
+ for coro_info in task_info.coroutine_stack:
+ call_stack = coro_info.call_stack
+ internal_stack = [_format_stack_entry(frame) for frame in call_stack]
+ task_stacks[task_id] = internal_stack
+
+ # Add the awaited_by relationships (external dependencies)
+ if task_info.awaited_by:
+ for coro_info in task_info.awaited_by:
+ call_stack = coro_info.call_stack
+ parent_task_id = coro_info.task_name
+ stack = [_format_stack_entry(frame) for frame in call_stack]
+ awaits.append((parent_task_id, stack, task_id))
+ return id2name, awaits, task_stacks
+
+
+def _build_tree(id2name, awaits, task_stacks):
+ id2label = {(NodeType.TASK, tid): name for tid, name in id2name.items()}
+ children = defaultdict(list)
+ cor_nodes = defaultdict(dict) # Maps parent -> {frame_name: node_key}
+ next_cor_id = count(1)
+
+ def get_or_create_cor_node(parent, frame):
+ """Get existing coroutine node or create new one under parent"""
+ if frame in cor_nodes[parent]:
+ return cor_nodes[parent][frame]
+
+ node_key = (NodeType.COROUTINE, f"c{next(next_cor_id)}")
+ id2label[node_key] = frame
+ children[parent].append(node_key)
+ cor_nodes[parent][frame] = node_key
+ return node_key
+
+ # Build task dependency tree with coroutine frames
+ for parent_id, stack, child_id in awaits:
+ cur = (NodeType.TASK, parent_id)
+ for frame in reversed(stack):
+ cur = get_or_create_cor_node(cur, frame)
+
+ child_key = (NodeType.TASK, child_id)
+ if child_key not in children[cur]:
+ children[cur].append(child_key)
+
+ # Add coroutine stacks for leaf tasks
+ awaiting_tasks = {parent_id for parent_id, _, _ in awaits}
+ for task_id in id2name:
+ if task_id not in awaiting_tasks and task_id in task_stacks:
+ cur = (NodeType.TASK, task_id)
+ for frame in reversed(task_stacks[task_id]):
+ cur = get_or_create_cor_node(cur, frame)
+
+ return id2label, children
+
+
+def _roots(id2label, children):
+ all_children = {c for kids in children.values() for c in kids}
+ return [n for n in id2label if n not in all_children]
+
+# ─── detect cycles in the task-to-task graph ───────────────────────
+def _task_graph(awaits):
+ """Return {parent_task_id: {child_task_id, …}, …}."""
+ g = defaultdict(set)
+ for parent_id, _stack, child_id in awaits:
+ g[parent_id].add(child_id)
+ return g
+
+
+def _find_cycles(graph):
+ """
+ Depth-first search for back-edges.
+
+ Returns a list of cycles (each cycle is a list of task-ids) or an
+ empty list if the graph is acyclic.
+ """
+ WHITE, GREY, BLACK = 0, 1, 2
+ color = defaultdict(lambda: WHITE)
+ path, cycles = [], []
+
+ def dfs(v):
+ color[v] = GREY
+ path.append(v)
+ for w in graph.get(v, ()):
+ if color[w] == WHITE:
+ dfs(w)
+ elif color[w] == GREY: # back-edge → cycle!
+ i = path.index(w)
+ cycles.append(path[i:] + [w]) # make a copy
+ color[v] = BLACK
+ path.pop()
+
+ for v in list(graph):
+ if color[v] == WHITE:
+ dfs(v)
+ return cycles
+
+
+# ─── PRINT TREE FUNCTION ───────────────────────────────────────
+def get_all_awaited_by(pid):
+ unwinder = RemoteUnwinder(pid)
+ return unwinder.get_all_awaited_by()
+
+
+def build_async_tree(result, task_emoji="(T)", cor_emoji=""):
+ """
+ Build a list of strings for pretty-print an async call tree.
+
+ The call tree is produced by `get_all_async_stacks()`, prefixing tasks
+ with `task_emoji` and coroutine frames with `cor_emoji`.
+ """
+ id2name, awaits, task_stacks = _index(result)
+ g = _task_graph(awaits)
+ cycles = _find_cycles(g)
+ if cycles:
+ raise CycleFoundException(cycles, id2name)
+ labels, children = _build_tree(id2name, awaits, task_stacks)
+
+ def pretty(node):
+ flag = task_emoji if node[0] == NodeType.TASK else cor_emoji
+ return f"{flag} {labels[node]}"
+
+ def render(node, prefix="", last=True, buf=None):
+ if buf is None:
+ buf = []
+ buf.append(f"{prefix}{'└── ' if last else '├── '}{pretty(node)}")
+ new_pref = prefix + (" " if last else "│ ")
+ kids = children.get(node, [])
+ for i, kid in enumerate(kids):
+ render(kid, new_pref, i == len(kids) - 1, buf)
+ return buf
+
+ return [render(root) for root in _roots(labels, children)]
+
+
+def build_task_table(result):
+ id2name, _, _ = _index(result)
+ table = []
+
+ for awaited_info in result:
+ thread_id = awaited_info.thread_id
+ for task_info in awaited_info.awaited_by:
+ # Get task info
+ task_id = task_info.task_id
+ task_name = task_info.task_name
+
+ # Build coroutine stack string
+ frames = [frame for coro in task_info.coroutine_stack
+ for frame in coro.call_stack]
+ coro_stack = " -> ".join(_format_stack_entry(x).split(" ")[0]
+ for x in frames)
+
+ # Handle tasks with no awaiters
+ if not task_info.awaited_by:
+ table.append([thread_id, hex(task_id), task_name, coro_stack,
+ "", "", "0x0"])
+ continue
+
+ # Handle tasks with awaiters
+ for coro_info in task_info.awaited_by:
+ parent_id = coro_info.task_name
+ awaiter_frames = [_format_stack_entry(x).split(" ")[0]
+ for x in coro_info.call_stack]
+ awaiter_chain = " -> ".join(awaiter_frames)
+ awaiter_name = id2name.get(parent_id, "Unknown")
+ parent_id_str = (hex(parent_id) if isinstance(parent_id, int)
+ else str(parent_id))
+
+ table.append([thread_id, hex(task_id), task_name, coro_stack,
+ awaiter_chain, awaiter_name, parent_id_str])
+
+ return table
+
+def _print_cycle_exception(exception: CycleFoundException):
+ print("ERROR: await-graph contains cycles - cannot print a tree!", file=sys.stderr)
+ print("", file=sys.stderr)
+ for c in exception.cycles:
+ inames = " → ".join(exception.id2name.get(tid, hex(tid)) for tid in c)
+ print(f"cycle: {inames}", file=sys.stderr)
+
+
+def _get_awaited_by_tasks(pid: int) -> list:
+ try:
+ return get_all_awaited_by(pid)
+ except RuntimeError as e:
+ while e.__context__ is not None:
+ e = e.__context__
+ print(f"Error retrieving tasks: {e}")
+ sys.exit(1)
+
+
+def display_awaited_by_tasks_table(pid: int) -> None:
+ """Build and print a table of all pending tasks under `pid`."""
+
+ tasks = _get_awaited_by_tasks(pid)
+ table = build_task_table(tasks)
+ # Print the table in a simple tabular format
+ print(
+ f"{'tid':<10} {'task id':<20} {'task name':<20} {'coroutine stack':<50} {'awaiter chain':<50} {'awaiter name':<15} {'awaiter id':<15}"
+ )
+ print("-" * 180)
+ for row in table:
+ print(f"{row[0]:<10} {row[1]:<20} {row[2]:<20} {row[3]:<50} {row[4]:<50} {row[5]:<15} {row[6]:<15}")
+
+
+def display_awaited_by_tasks_tree(pid: int) -> None:
+ """Build and print a tree of all pending tasks under `pid`."""
+
+ tasks = _get_awaited_by_tasks(pid)
+ try:
+ result = build_async_tree(tasks)
+ except CycleFoundException as e:
+ _print_cycle_exception(e)
+ sys.exit(1)
+
+ for tree in result:
+ print("\n".join(tree))
diff --git a/Lib/calendar.py b/Lib/calendar.py
index 01a76ff8e78..3be1b50500e 100644
--- a/Lib/calendar.py
+++ b/Lib/calendar.py
@@ -565,7 +565,7 @@ class HTMLCalendar(Calendar):
Return a formatted year as a complete HTML page.
"""
if encoding is None:
- encoding = sys.getdefaultencoding()
+ encoding = 'utf-8'
v = []
a = v.append
a('<?xml version="1.0" encoding="%s"?>\n' % encoding)
@@ -810,7 +810,7 @@ def timegm(tuple):
def main(args=None):
import argparse
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
textgroup = parser.add_argument_group('text only arguments')
htmlgroup = parser.add_argument_group('html only arguments')
textgroup.add_argument(
@@ -846,7 +846,7 @@ def main(args=None):
parser.add_argument(
"-e", "--encoding",
default=None,
- help="encoding to use for output"
+ help="encoding to use for output (default utf-8)"
)
parser.add_argument(
"-t", "--type",
@@ -890,7 +890,7 @@ def main(args=None):
cal.setfirstweekday(options.first_weekday)
encoding = options.encoding
if encoding is None:
- encoding = sys.getdefaultencoding()
+ encoding = 'utf-8'
optdict = dict(encoding=encoding, css=options.css)
write = sys.stdout.buffer.write
if options.year is None:
diff --git a/Lib/cmd.py b/Lib/cmd.py
index 438b88aa104..51495fb3216 100644
--- a/Lib/cmd.py
+++ b/Lib/cmd.py
@@ -273,7 +273,7 @@ class Cmd:
endidx = readline.get_endidx() - stripped
if begidx>0:
cmd, args, foo = self.parseline(line)
- if cmd == '':
+ if not cmd:
compfunc = self.completedefault
else:
try:
diff --git a/Lib/code.py b/Lib/code.py
index 41331dfd071..f7e275d8801 100644
--- a/Lib/code.py
+++ b/Lib/code.py
@@ -224,7 +224,7 @@ class InteractiveConsole(InteractiveInterpreter):
sys.ps1 = ">>> "
delete_ps1_after = True
try:
- _ps2 = sys.ps2
+ sys.ps2
delete_ps2_after = False
except AttributeError:
sys.ps2 = "... "
@@ -385,7 +385,7 @@ def interact(banner=None, readfunc=None, local=None, exitmsg=None, local_exit=Fa
if __name__ == "__main__":
import argparse
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument('-q', action='store_true',
help="don't print version and copyright messages")
args = parser.parse_args()
diff --git a/Lib/compileall.py b/Lib/compileall.py
index 47e2446356e..67fe370451e 100644
--- a/Lib/compileall.py
+++ b/Lib/compileall.py
@@ -317,7 +317,9 @@ def main():
import argparse
parser = argparse.ArgumentParser(
- description='Utilities to support installing Python libraries.')
+ description='Utilities to support installing Python libraries.',
+ color=True,
+ )
parser.add_argument('-l', action='store_const', const=0,
default=None, dest='maxlevels',
help="don't recurse into subdirectories")
diff --git a/Lib/compression/bz2/__init__.py b/Lib/compression/bz2.py
index 16815d6cd20..16815d6cd20 100644
--- a/Lib/compression/bz2/__init__.py
+++ b/Lib/compression/bz2.py
diff --git a/Lib/compression/gzip/__init__.py b/Lib/compression/gzip.py
index 552f48f948a..552f48f948a 100644
--- a/Lib/compression/gzip/__init__.py
+++ b/Lib/compression/gzip.py
diff --git a/Lib/compression/lzma/__init__.py b/Lib/compression/lzma.py
index b4bc7ccb1db..b4bc7ccb1db 100644
--- a/Lib/compression/lzma/__init__.py
+++ b/Lib/compression/lzma.py
diff --git a/Lib/compression/zlib/__init__.py b/Lib/compression/zlib.py
index 3aa7e2db90e..3aa7e2db90e 100644
--- a/Lib/compression/zlib/__init__.py
+++ b/Lib/compression/zlib.py
diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py
new file mode 100644
index 00000000000..84b25914b0a
--- /dev/null
+++ b/Lib/compression/zstd/__init__.py
@@ -0,0 +1,242 @@
+"""Python bindings to the Zstandard (zstd) compression library (RFC-8878)."""
+
+__all__ = (
+ # compression.zstd
+ 'COMPRESSION_LEVEL_DEFAULT',
+ 'compress',
+ 'CompressionParameter',
+ 'decompress',
+ 'DecompressionParameter',
+ 'finalize_dict',
+ 'get_frame_info',
+ 'Strategy',
+ 'train_dict',
+
+ # compression.zstd._zstdfile
+ 'open',
+ 'ZstdFile',
+
+ # _zstd
+ 'get_frame_size',
+ 'zstd_version',
+ 'zstd_version_info',
+ 'ZstdCompressor',
+ 'ZstdDecompressor',
+ 'ZstdDict',
+ 'ZstdError',
+)
+
+import _zstd
+import enum
+from _zstd import (ZstdCompressor, ZstdDecompressor, ZstdDict, ZstdError,
+ get_frame_size, zstd_version)
+from compression.zstd._zstdfile import ZstdFile, open, _nbytes
+
+# zstd_version_number is (MAJOR * 100 * 100 + MINOR * 100 + RELEASE)
+zstd_version_info = (*divmod(_zstd.zstd_version_number // 100, 100),
+ _zstd.zstd_version_number % 100)
+"""Version number of the runtime zstd library as a tuple of integers."""
+
+COMPRESSION_LEVEL_DEFAULT = _zstd.ZSTD_CLEVEL_DEFAULT
+"""The default compression level for Zstandard, currently '3'."""
+
+
+class FrameInfo:
+ """Information about a Zstandard frame."""
+
+ __slots__ = 'decompressed_size', 'dictionary_id'
+
+ def __init__(self, decompressed_size, dictionary_id):
+ super().__setattr__('decompressed_size', decompressed_size)
+ super().__setattr__('dictionary_id', dictionary_id)
+
+ def __repr__(self):
+ return (f'FrameInfo(decompressed_size={self.decompressed_size}, '
+ f'dictionary_id={self.dictionary_id})')
+
+ def __setattr__(self, name, _):
+ raise AttributeError(f"can't set attribute {name!r}")
+
+
+def get_frame_info(frame_buffer):
+ """Get Zstandard frame information from a frame header.
+
+ *frame_buffer* is a bytes-like object. It should start from the beginning
+ of a frame, and needs to include at least the frame header (6 to 18 bytes).
+
+ The returned FrameInfo object has two attributes.
+ 'decompressed_size' is the size in bytes of the data in the frame when
+ decompressed, or None when the decompressed size is unknown.
+ 'dictionary_id' is an int in the range (0, 2**32). The special value 0
+ means that the dictionary ID was not recorded in the frame header,
+ the frame may or may not need a dictionary to be decoded,
+ and the ID of such a dictionary is not specified.
+ """
+ return FrameInfo(*_zstd.get_frame_info(frame_buffer))
+
+
+def train_dict(samples, dict_size):
+ """Return a ZstdDict representing a trained Zstandard dictionary.
+
+ *samples* is an iterable of samples, where a sample is a bytes-like
+ object representing a file.
+
+ *dict_size* is the dictionary's maximum size, in bytes.
+ """
+ if not isinstance(dict_size, int):
+ ds_cls = type(dict_size).__qualname__
+ raise TypeError(f'dict_size must be an int object, not {ds_cls!r}.')
+
+ samples = tuple(samples)
+ chunks = b''.join(samples)
+ chunk_sizes = tuple(_nbytes(sample) for sample in samples)
+ if not chunks:
+ raise ValueError("samples contained no data; can't train dictionary.")
+ dict_content = _zstd.train_dict(chunks, chunk_sizes, dict_size)
+ return ZstdDict(dict_content)
+
+
+def finalize_dict(zstd_dict, /, samples, dict_size, level):
+ """Return a ZstdDict representing a finalized Zstandard dictionary.
+
+ Given a custom content as a basis for dictionary, and a set of samples,
+ finalize *zstd_dict* by adding headers and statistics according to the
+ Zstandard dictionary format.
+
+ You may compose an effective dictionary content by hand, which is used as
+ basis dictionary, and use some samples to finalize a dictionary. The basis
+ dictionary may be a "raw content" dictionary. See *is_raw* in ZstdDict.
+
+ *samples* is an iterable of samples, where a sample is a bytes-like object
+ representing a file.
+ *dict_size* is the dictionary's maximum size, in bytes.
+ *level* is the expected compression level. The statistics for each
+ compression level differ, so tuning the dictionary to the compression level
+ can provide improvements.
+ """
+
+ if not isinstance(zstd_dict, ZstdDict):
+ raise TypeError('zstd_dict argument should be a ZstdDict object.')
+ if not isinstance(dict_size, int):
+ raise TypeError('dict_size argument should be an int object.')
+ if not isinstance(level, int):
+ raise TypeError('level argument should be an int object.')
+
+ samples = tuple(samples)
+ chunks = b''.join(samples)
+ chunk_sizes = tuple(_nbytes(sample) for sample in samples)
+ if not chunks:
+ raise ValueError("The samples are empty content, can't finalize the "
+ "dictionary.")
+ dict_content = _zstd.finalize_dict(zstd_dict.dict_content, chunks,
+ chunk_sizes, dict_size, level)
+ return ZstdDict(dict_content)
+
+
+def compress(data, level=None, options=None, zstd_dict=None):
+ """Return Zstandard compressed *data* as bytes.
+
+ *level* is an int specifying the compression level to use, defaulting to
+ COMPRESSION_LEVEL_DEFAULT ('3').
+ *options* is a dict object that contains advanced compression
+ parameters. See CompressionParameter for more on options.
+ *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. See
+ the function train_dict for how to train a ZstdDict on sample data.
+
+ For incremental compression, use a ZstdCompressor instead.
+ """
+ comp = ZstdCompressor(level=level, options=options, zstd_dict=zstd_dict)
+ return comp.compress(data, mode=ZstdCompressor.FLUSH_FRAME)
+
+
+def decompress(data, zstd_dict=None, options=None):
+ """Decompress one or more frames of Zstandard compressed *data*.
+
+ *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. See
+ the function train_dict for how to train a ZstdDict on sample data.
+ *options* is a dict object that contains advanced compression
+ parameters. See DecompressionParameter for more on options.
+
+ For incremental decompression, use a ZstdDecompressor instead.
+ """
+ results = []
+ while True:
+ decomp = ZstdDecompressor(options=options, zstd_dict=zstd_dict)
+ results.append(decomp.decompress(data))
+ if not decomp.eof:
+ raise ZstdError('Compressed data ended before the '
+ 'end-of-stream marker was reached')
+ data = decomp.unused_data
+ if not data:
+ break
+ return b''.join(results)
+
+
+class CompressionParameter(enum.IntEnum):
+ """Compression parameters."""
+
+ compression_level = _zstd.ZSTD_c_compressionLevel
+ window_log = _zstd.ZSTD_c_windowLog
+ hash_log = _zstd.ZSTD_c_hashLog
+ chain_log = _zstd.ZSTD_c_chainLog
+ search_log = _zstd.ZSTD_c_searchLog
+ min_match = _zstd.ZSTD_c_minMatch
+ target_length = _zstd.ZSTD_c_targetLength
+ strategy = _zstd.ZSTD_c_strategy
+
+ enable_long_distance_matching = _zstd.ZSTD_c_enableLongDistanceMatching
+ ldm_hash_log = _zstd.ZSTD_c_ldmHashLog
+ ldm_min_match = _zstd.ZSTD_c_ldmMinMatch
+ ldm_bucket_size_log = _zstd.ZSTD_c_ldmBucketSizeLog
+ ldm_hash_rate_log = _zstd.ZSTD_c_ldmHashRateLog
+
+ content_size_flag = _zstd.ZSTD_c_contentSizeFlag
+ checksum_flag = _zstd.ZSTD_c_checksumFlag
+ dict_id_flag = _zstd.ZSTD_c_dictIDFlag
+
+ nb_workers = _zstd.ZSTD_c_nbWorkers
+ job_size = _zstd.ZSTD_c_jobSize
+ overlap_log = _zstd.ZSTD_c_overlapLog
+
+ def bounds(self):
+ """Return the (lower, upper) int bounds of a compression parameter.
+
+ Both the lower and upper bounds are inclusive.
+ """
+ return _zstd.get_param_bounds(self.value, is_compress=True)
+
+
+class DecompressionParameter(enum.IntEnum):
+ """Decompression parameters."""
+
+ window_log_max = _zstd.ZSTD_d_windowLogMax
+
+ def bounds(self):
+ """Return the (lower, upper) int bounds of a decompression parameter.
+
+ Both the lower and upper bounds are inclusive.
+ """
+ return _zstd.get_param_bounds(self.value, is_compress=False)
+
+
+class Strategy(enum.IntEnum):
+ """Compression strategies, listed from fastest to strongest.
+
+ Note that new strategies might be added in the future.
+ Only the order (from fast to strong) is guaranteed,
+ the numeric value might change.
+ """
+
+ fast = _zstd.ZSTD_fast
+ dfast = _zstd.ZSTD_dfast
+ greedy = _zstd.ZSTD_greedy
+ lazy = _zstd.ZSTD_lazy
+ lazy2 = _zstd.ZSTD_lazy2
+ btlazy2 = _zstd.ZSTD_btlazy2
+ btopt = _zstd.ZSTD_btopt
+ btultra = _zstd.ZSTD_btultra
+ btultra2 = _zstd.ZSTD_btultra2
+
+
+# Check validity of the CompressionParameter & DecompressionParameter types
+_zstd.set_parameter_types(CompressionParameter, DecompressionParameter)
diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py
new file mode 100644
index 00000000000..d709f5efc65
--- /dev/null
+++ b/Lib/compression/zstd/_zstdfile.py
@@ -0,0 +1,345 @@
+import io
+from os import PathLike
+from _zstd import ZstdCompressor, ZstdDecompressor, ZSTD_DStreamOutSize
+from compression._common import _streams
+
+__all__ = ('ZstdFile', 'open')
+
+_MODE_CLOSED = 0
+_MODE_READ = 1
+_MODE_WRITE = 2
+
+
+def _nbytes(dat, /):
+ if isinstance(dat, (bytes, bytearray)):
+ return len(dat)
+ with memoryview(dat) as mv:
+ return mv.nbytes
+
+
+class ZstdFile(_streams.BaseStream):
+ """A file-like object providing transparent Zstandard (de)compression.
+
+ A ZstdFile can act as a wrapper for an existing file object, or refer
+ directly to a named file on disk.
+
+ ZstdFile provides a *binary* file interface. Data is read and returned as
+ bytes, and may only be written to objects that support the Buffer Protocol.
+ """
+
+ FLUSH_BLOCK = ZstdCompressor.FLUSH_BLOCK
+ FLUSH_FRAME = ZstdCompressor.FLUSH_FRAME
+
+ def __init__(self, file, /, mode='r', *,
+ level=None, options=None, zstd_dict=None):
+ """Open a Zstandard compressed file in binary mode.
+
+        *file* can be either a file-like object, or a file name to open.
+
+ *mode* can be 'r' for reading (default), 'w' for (over)writing, 'x' for
+ creating exclusively, or 'a' for appending. These can equivalently be
+ given as 'rb', 'wb', 'xb' and 'ab' respectively.
+
+ *level* is an optional int specifying the compression level to use,
+ or COMPRESSION_LEVEL_DEFAULT if not given.
+
+ *options* is an optional dict for advanced compression parameters.
+ See CompressionParameter and DecompressionParameter for the possible
+ options.
+
+ *zstd_dict* is an optional ZstdDict object, a pre-trained Zstandard
+ dictionary. See train_dict() to train ZstdDict on sample data.
+ """
+ self._fp = None
+ self._close_fp = False
+ self._mode = _MODE_CLOSED
+ self._buffer = None
+
+ if not isinstance(mode, str):
+ raise ValueError('mode must be a str')
+ if options is not None and not isinstance(options, dict):
+ raise TypeError('options must be a dict or None')
+ mode = mode.removesuffix('b') # handle rb, wb, xb, ab
+ if mode == 'r':
+ if level is not None:
+ raise TypeError('level is illegal in read mode')
+ self._mode = _MODE_READ
+ elif mode in {'w', 'a', 'x'}:
+ if level is not None and not isinstance(level, int):
+ raise TypeError('level must be int or None')
+ self._mode = _MODE_WRITE
+ self._compressor = ZstdCompressor(level=level, options=options,
+ zstd_dict=zstd_dict)
+ self._pos = 0
+ else:
+ raise ValueError(f'Invalid mode: {mode!r}')
+
+ if isinstance(file, (str, bytes, PathLike)):
+ self._fp = io.open(file, f'{mode}b')
+ self._close_fp = True
+ elif ((mode == 'r' and hasattr(file, 'read'))
+ or (mode != 'r' and hasattr(file, 'write'))):
+ self._fp = file
+ else:
+ raise TypeError('file must be a file-like object '
+ 'or a str, bytes, or PathLike object')
+
+ if self._mode == _MODE_READ:
+ raw = _streams.DecompressReader(
+ self._fp,
+ ZstdDecompressor,
+ zstd_dict=zstd_dict,
+ options=options,
+ )
+ self._buffer = io.BufferedReader(raw)
+
+ def close(self):
+ """Flush and close the file.
+
+ May be called multiple times. Once the file has been closed,
+ any other operation on it will raise ValueError.
+ """
+ if self._fp is None:
+ return
+ try:
+ if self._mode == _MODE_READ:
+ if getattr(self, '_buffer', None):
+ self._buffer.close()
+ self._buffer = None
+ elif self._mode == _MODE_WRITE:
+ self.flush(self.FLUSH_FRAME)
+ self._compressor = None
+ finally:
+ self._mode = _MODE_CLOSED
+ try:
+ if self._close_fp:
+ self._fp.close()
+ finally:
+ self._fp = None
+ self._close_fp = False
+
+ def write(self, data, /):
+ """Write a bytes-like object *data* to the file.
+
+ Returns the number of uncompressed bytes written, which is
+ always the length of data in bytes. Note that due to buffering,
+ the file on disk may not reflect the data written until .flush()
+ or .close() is called.
+ """
+ self._check_can_write()
+
+ length = _nbytes(data)
+
+ compressed = self._compressor.compress(data)
+ self._fp.write(compressed)
+ self._pos += length
+ return length
+
+ def flush(self, mode=FLUSH_BLOCK):
+ """Flush remaining data to the underlying stream.
+
+ The mode argument can be FLUSH_BLOCK or FLUSH_FRAME. Abuse of this
+ method will reduce compression ratio, use it only when necessary.
+
+ If the program is interrupted afterwards, all data can be recovered.
+        To ensure saving to disk, you also need to use os.fsync(fd).
+
+ This method does nothing in reading mode.
+ """
+ if self._mode == _MODE_READ:
+ return
+ self._check_not_closed()
+ if mode not in {self.FLUSH_BLOCK, self.FLUSH_FRAME}:
+ raise ValueError('Invalid mode argument, expected either '
+ 'ZstdFile.FLUSH_FRAME or '
+ 'ZstdFile.FLUSH_BLOCK')
+ if self._compressor.last_mode == mode:
+ return
+ # Flush zstd block/frame, and write.
+ data = self._compressor.flush(mode)
+ self._fp.write(data)
+ if hasattr(self._fp, 'flush'):
+ self._fp.flush()
+
+ def read(self, size=-1):
+ """Read up to size uncompressed bytes from the file.
+
+ If size is negative or omitted, read until EOF is reached.
+ Returns b'' if the file is already at EOF.
+ """
+ if size is None:
+ size = -1
+ self._check_can_read()
+ return self._buffer.read(size)
+
+ def read1(self, size=-1):
+ """Read up to size uncompressed bytes, while trying to avoid
+ making multiple reads from the underlying stream. Reads up to a
+ buffer's worth of data if size is negative.
+
+ Returns b'' if the file is at EOF.
+ """
+ self._check_can_read()
+ if size < 0:
+ # Note this should *not* be io.DEFAULT_BUFFER_SIZE.
+ # ZSTD_DStreamOutSize is the minimum amount to read guaranteeing
+ # a full block is read.
+ size = ZSTD_DStreamOutSize
+ return self._buffer.read1(size)
+
+ def readinto(self, b):
+ """Read bytes into b.
+
+ Returns the number of bytes read (0 for EOF).
+ """
+ self._check_can_read()
+ return self._buffer.readinto(b)
+
+ def readinto1(self, b):
+ """Read bytes into b, while trying to avoid making multiple reads
+ from the underlying stream.
+
+ Returns the number of bytes read (0 for EOF).
+ """
+ self._check_can_read()
+ return self._buffer.readinto1(b)
+
+ def readline(self, size=-1):
+ """Read a line of uncompressed bytes from the file.
+
+ The terminating newline (if present) is retained. If size is
+ non-negative, no more than size bytes will be read (in which
+ case the line may be incomplete). Returns b'' if already at EOF.
+ """
+ self._check_can_read()
+ return self._buffer.readline(size)
+
+ def seek(self, offset, whence=io.SEEK_SET):
+ """Change the file position.
+
+ The new position is specified by offset, relative to the
+ position indicated by whence. Possible values for whence are:
+
+ 0: start of stream (default): offset must not be negative
+ 1: current stream position
+ 2: end of stream; offset must not be positive
+
+ Returns the new file position.
+
+ Note that seeking is emulated, so depending on the arguments,
+ this operation may be extremely slow.
+ """
+ self._check_can_read()
+
+ # BufferedReader.seek() checks seekable
+ return self._buffer.seek(offset, whence)
+
+ def peek(self, size=-1):
+ """Return buffered data without advancing the file position.
+
+ Always returns at least one byte of data, unless at EOF.
+ The exact number of bytes returned is unspecified.
+ """
+ # Relies on the undocumented fact that BufferedReader.peek() always
+ # returns at least one byte (except at EOF)
+ self._check_can_read()
+ return self._buffer.peek(size)
+
+ def __next__(self):
+ if ret := self._buffer.readline():
+ return ret
+ raise StopIteration
+
+ def tell(self):
+ """Return the current file position."""
+ self._check_not_closed()
+ if self._mode == _MODE_READ:
+ return self._buffer.tell()
+ elif self._mode == _MODE_WRITE:
+ return self._pos
+
+ def fileno(self):
+ """Return the file descriptor for the underlying file."""
+ self._check_not_closed()
+ return self._fp.fileno()
+
+ @property
+ def name(self):
+ self._check_not_closed()
+ return self._fp.name
+
+ @property
+ def mode(self):
+ return 'wb' if self._mode == _MODE_WRITE else 'rb'
+
+ @property
+ def closed(self):
+ """True if this file is closed."""
+ return self._mode == _MODE_CLOSED
+
+ def seekable(self):
+ """Return whether the file supports seeking."""
+ return self.readable() and self._buffer.seekable()
+
+ def readable(self):
+ """Return whether the file was opened for reading."""
+ self._check_not_closed()
+ return self._mode == _MODE_READ
+
+ def writable(self):
+ """Return whether the file was opened for writing."""
+ self._check_not_closed()
+ return self._mode == _MODE_WRITE
+
+
+def open(file, /, mode='rb', *, level=None, options=None, zstd_dict=None,
+ encoding=None, errors=None, newline=None):
+ """Open a Zstandard compressed file in binary or text mode.
+
+ file can be either a file name (given as a str, bytes, or PathLike object),
+ in which case the named file is opened, or it can be an existing file object
+ to read from or write to.
+
+ The mode parameter can be 'r', 'rb' (default), 'w', 'wb', 'x', 'xb', 'a',
+ 'ab' for binary mode, or 'rt', 'wt', 'xt', 'at' for text mode.
+
+ The level, options, and zstd_dict parameters specify the settings the same
+ as ZstdFile.
+
+ When using read mode (decompression), the options parameter is a dict
+ representing advanced decompression options. The level parameter is not
+ supported in this case. When using write mode (compression), only one of
+ level, an int representing the compression level, or options, a dict
+ representing advanced compression options, may be passed. In both modes,
+ zstd_dict is a ZstdDict instance containing a trained Zstandard dictionary.
+
+ For binary mode, this function is equivalent to the ZstdFile constructor:
+ ZstdFile(filename, mode, ...). In this case, the encoding, errors and
+ newline parameters must not be provided.
+
+    For text mode, a ZstdFile object is created, and wrapped in an
+ io.TextIOWrapper instance with the specified encoding, error handling
+ behavior, and line ending(s).
+ """
+
+ text_mode = 't' in mode
+ mode = mode.replace('t', '')
+
+ if text_mode:
+ if 'b' in mode:
+ raise ValueError(f'Invalid mode: {mode!r}')
+ else:
+ if encoding is not None:
+ raise ValueError('Argument "encoding" not supported in binary mode')
+ if errors is not None:
+ raise ValueError('Argument "errors" not supported in binary mode')
+ if newline is not None:
+ raise ValueError('Argument "newline" not supported in binary mode')
+
+ binary_file = ZstdFile(file, mode, level=level, options=options,
+ zstd_dict=zstd_dict)
+
+ if text_mode:
+ return io.TextIOWrapper(binary_file, encoding, errors, newline)
+ else:
+ return binary_file
diff --git a/Lib/concurrent/futures/_base.py b/Lib/concurrent/futures/_base.py
index d98b1ebdd58..f506ce68aea 100644
--- a/Lib/concurrent/futures/_base.py
+++ b/Lib/concurrent/futures/_base.py
@@ -558,6 +558,33 @@ class Future(object):
self._condition.notify_all()
self._invoke_callbacks()
+ def _get_snapshot(self):
+ """Get a snapshot of the future's current state.
+
+ This method atomically retrieves the state in one lock acquisition,
+ which is significantly faster than multiple method calls.
+
+ Returns:
+ Tuple of (done, cancelled, result, exception)
+ - done: True if the future is done (cancelled or finished)
+ - cancelled: True if the future was cancelled
+ - result: The result if available and not cancelled
+ - exception: The exception if available and not cancelled
+ """
+ # Fast path: check if already finished without lock
+ if self._state == FINISHED:
+ return True, False, self._result, self._exception
+
+ # Need lock for other states since they can change
+ with self._condition:
+ # We have to check the state again after acquiring the lock
+ # because it may have changed in the meantime.
+ if self._state == FINISHED:
+ return True, False, self._result, self._exception
+ if self._state in {CANCELLED, CANCELLED_AND_NOTIFIED}:
+ return True, True, None, None
+ return False, False, None, None
+
__class_getitem__ = classmethod(types.GenericAlias)
class Executor(object):
diff --git a/Lib/concurrent/futures/interpreter.py b/Lib/concurrent/futures/interpreter.py
index d17688dc9d7..cbb60ce80c1 100644
--- a/Lib/concurrent/futures/interpreter.py
+++ b/Lib/concurrent/futures/interpreter.py
@@ -1,69 +1,39 @@
"""Implements InterpreterPoolExecutor."""
-import contextlib
-import pickle
+from concurrent import interpreters
+import sys
import textwrap
from . import thread as _thread
-import _interpreters
-import _interpqueues
+import traceback
-class ExecutionFailed(_interpreters.InterpreterError):
- """An unhandled exception happened during execution."""
-
- def __init__(self, excinfo):
- msg = excinfo.formatted
- if not msg:
- if excinfo.type and excinfo.msg:
- msg = f'{excinfo.type.__name__}: {excinfo.msg}'
- else:
- msg = excinfo.type.__name__ or excinfo.msg
- super().__init__(msg)
- self.excinfo = excinfo
-
- def __str__(self):
+def do_call(results, func, args, kwargs):
+ try:
+ return func(*args, **kwargs)
+ except BaseException as exc:
+ # Send the captured exception out on the results queue,
+ # but still leave it unhandled for the interpreter to handle.
try:
- formatted = self.excinfo.errdisplay
- except Exception:
- return super().__str__()
- else:
- return textwrap.dedent(f"""
-{super().__str__()}
-
-Uncaught in the interpreter:
-
-{formatted}
- """.strip())
-
-
-UNBOUND = 2 # error; this should not happen.
+ results.put(exc)
+ except interpreters.NotShareableError:
+ # The exception is not shareable.
+ print('exception is not shareable:', file=sys.stderr)
+ traceback.print_exception(exc)
+ results.put(None)
+ raise # re-raise
class WorkerContext(_thread.WorkerContext):
@classmethod
- def prepare(cls, initializer, initargs, shared):
+ def prepare(cls, initializer, initargs):
def resolve_task(fn, args, kwargs):
if isinstance(fn, str):
# XXX Circle back to this later.
raise TypeError('scripts not supported')
- if args or kwargs:
- raise ValueError(f'a script does not take args or kwargs, got {args!r} and {kwargs!r}')
- data = textwrap.dedent(fn)
- kind = 'script'
- # Make sure the script compiles.
- # Ideally we wouldn't throw away the resulting code
- # object. However, there isn't much to be done until
- # code objects are shareable and/or we do a better job
- # of supporting code objects in _interpreters.exec().
- compile(data, '<string>', 'exec')
else:
- # Functions defined in the __main__ module can't be pickled,
- # so they can't be used here. In the future, we could possibly
- # borrow from multiprocessing to work around this.
- data = pickle.dumps((fn, args, kwargs))
- kind = 'function'
- return (data, kind)
+ task = (fn, args, kwargs)
+ return task
if initializer is not None:
try:
@@ -75,73 +45,24 @@ class WorkerContext(_thread.WorkerContext):
else:
initdata = None
def create_context():
- return cls(initdata, shared)
+ return cls(initdata)
return create_context, resolve_task
- @classmethod
- @contextlib.contextmanager
- def _capture_exc(cls, resultsid):
- try:
- yield
- except BaseException as exc:
- # Send the captured exception out on the results queue,
- # but still leave it unhandled for the interpreter to handle.
- err = pickle.dumps(exc)
- _interpqueues.put(resultsid, (None, err), 1, UNBOUND)
- raise # re-raise
-
- @classmethod
- def _send_script_result(cls, resultsid):
- _interpqueues.put(resultsid, (None, None), 0, UNBOUND)
-
- @classmethod
- def _call(cls, func, args, kwargs, resultsid):
- with cls._capture_exc(resultsid):
- res = func(*args or (), **kwargs or {})
- # Send the result back.
- try:
- _interpqueues.put(resultsid, (res, None), 0, UNBOUND)
- except _interpreters.NotShareableError:
- res = pickle.dumps(res)
- _interpqueues.put(resultsid, (res, None), 1, UNBOUND)
-
- @classmethod
- def _call_pickled(cls, pickled, resultsid):
- with cls._capture_exc(resultsid):
- fn, args, kwargs = pickle.loads(pickled)
- cls._call(fn, args, kwargs, resultsid)
-
- def __init__(self, initdata, shared=None):
+ def __init__(self, initdata):
self.initdata = initdata
- self.shared = dict(shared) if shared else None
- self.interpid = None
- self.resultsid = None
+ self.interp = None
+ self.results = None
def __del__(self):
- if self.interpid is not None:
+ if self.interp is not None:
self.finalize()
- def _exec(self, script):
- assert self.interpid is not None
- excinfo = _interpreters.exec(self.interpid, script, restrict=True)
- if excinfo is not None:
- raise ExecutionFailed(excinfo)
-
def initialize(self):
- assert self.interpid is None, self.interpid
- self.interpid = _interpreters.create(reqrefs=True)
+ assert self.interp is None, self.interp
+ self.interp = interpreters.create()
try:
- _interpreters.incref(self.interpid)
-
maxsize = 0
- fmt = 0
- self.resultsid = _interpqueues.create(maxsize, fmt, UNBOUND)
-
- self._exec(f'from {__name__} import WorkerContext')
-
- if self.shared:
- _interpreters.set___main___attrs(
- self.interpid, self.shared, restrict=True)
+ self.results = interpreters.create_queue(maxsize)
if self.initdata:
self.run(self.initdata)
@@ -150,64 +71,25 @@ class WorkerContext(_thread.WorkerContext):
raise # re-raise
def finalize(self):
- interpid = self.interpid
- resultsid = self.resultsid
- self.resultsid = None
- self.interpid = None
- if resultsid is not None:
- try:
- _interpqueues.destroy(resultsid)
- except _interpqueues.QueueNotFoundError:
- pass
- if interpid is not None:
- try:
- _interpreters.decref(interpid)
- except _interpreters.InterpreterNotFoundError:
- pass
+ interp = self.interp
+ results = self.results
+ self.results = None
+ self.interp = None
+ if results is not None:
+ del results
+ if interp is not None:
+ interp.close()
def run(self, task):
- data, kind = task
- if kind == 'script':
- raise NotImplementedError('script kind disabled')
- script = f"""
-with WorkerContext._capture_exc({self.resultsid}):
-{textwrap.indent(data, ' ')}
-WorkerContext._send_script_result({self.resultsid})"""
- elif kind == 'function':
- script = f'WorkerContext._call_pickled({data!r}, {self.resultsid})'
- else:
- raise NotImplementedError(kind)
-
try:
- self._exec(script)
- except ExecutionFailed as exc:
- exc_wrapper = exc
- else:
- exc_wrapper = None
-
- # Return the result, or raise the exception.
- while True:
- try:
- obj = _interpqueues.get(self.resultsid)
- except _interpqueues.QueueNotFoundError:
+ return self.interp.call(do_call, self.results, *task)
+ except interpreters.ExecutionFailed as wrapper:
+ # Wait for the exception data to show up.
+ exc = self.results.get()
+ if exc is None:
+ # The exception must have been not shareable.
raise # re-raise
- except _interpqueues.QueueError:
- continue
- except ModuleNotFoundError:
- # interpreters.queues doesn't exist, which means
- # QueueEmpty doesn't. Act as though it does.
- continue
- else:
- break
- (res, excdata), pickled, unboundop = obj
- assert unboundop is None, unboundop
- if excdata is not None:
- assert res is None, res
- assert pickled
- assert exc_wrapper is not None
- exc = pickle.loads(excdata)
- raise exc from exc_wrapper
- return pickle.loads(res) if pickled else res
+ raise exc from wrapper
class BrokenInterpreterPool(_thread.BrokenThreadPool):
@@ -221,11 +103,11 @@ class InterpreterPoolExecutor(_thread.ThreadPoolExecutor):
BROKEN = BrokenInterpreterPool
@classmethod
- def prepare_context(cls, initializer, initargs, shared):
- return WorkerContext.prepare(initializer, initargs, shared)
+ def prepare_context(cls, initializer, initargs):
+ return WorkerContext.prepare(initializer, initargs)
def __init__(self, max_workers=None, thread_name_prefix='',
- initializer=None, initargs=(), shared=None):
+ initializer=None, initargs=()):
"""Initializes a new InterpreterPoolExecutor instance.
Args:
@@ -235,8 +117,6 @@ class InterpreterPoolExecutor(_thread.ThreadPoolExecutor):
initializer: A callable or script used to initialize
each worker interpreter.
initargs: A tuple of arguments to pass to the initializer.
- shared: A mapping of shareabled objects to be inserted into
- each worker interpreter.
"""
super().__init__(max_workers, thread_name_prefix,
- initializer, initargs, shared=shared)
+ initializer, initargs)
diff --git a/Lib/concurrent/futures/process.py b/Lib/concurrent/futures/process.py
index 76b7b2abe83..a14650bf5fa 100644
--- a/Lib/concurrent/futures/process.py
+++ b/Lib/concurrent/futures/process.py
@@ -755,6 +755,11 @@ class ProcessPoolExecutor(_base.Executor):
self._executor_manager_thread_wakeup
def _adjust_process_count(self):
+ # gh-132969: avoid error when state is reset and executor is still running,
+ # which will happen when shutdown(wait=False) is called.
+ if self._processes is None:
+ return
+
# if there's an idle process, we don't need to spawn a new one.
if self._idle_worker_semaphore.acquire(blocking=False):
return
diff --git a/Lib/test/support/interpreters/__init__.py b/Lib/concurrent/interpreters/__init__.py
index e067f259364..0fd661249a2 100644
--- a/Lib/test/support/interpreters/__init__.py
+++ b/Lib/concurrent/interpreters/__init__.py
@@ -9,6 +9,10 @@ from _interpreters import (
InterpreterError, InterpreterNotFoundError, NotShareableError,
is_shareable,
)
+from ._queues import (
+ create as create_queue,
+ Queue, QueueEmpty, QueueFull,
+)
__all__ = [
@@ -20,21 +24,6 @@ __all__ = [
]
-_queuemod = None
-
-def __getattr__(name):
- if name in ('Queue', 'QueueEmpty', 'QueueFull', 'create_queue'):
- global create_queue, Queue, QueueEmpty, QueueFull
- ns = globals()
- from .queues import (
- create as create_queue,
- Queue, QueueEmpty, QueueFull,
- )
- return ns[name]
- else:
- raise AttributeError(name)
-
-
_EXEC_FAILURE_STR = """
{superstr}
@@ -226,33 +215,32 @@ class Interpreter:
if excinfo is not None:
raise ExecutionFailed(excinfo)
- def call(self, callable, /):
- """Call the object in the interpreter with given args/kwargs.
+ def _call(self, callable, args, kwargs):
+ res, excinfo = _interpreters.call(self._id, callable, args, kwargs, restrict=True)
+ if excinfo is not None:
+ raise ExecutionFailed(excinfo)
+ return res
- Only functions that take no arguments and have no closure
- are supported.
+ def call(self, callable, /, *args, **kwargs):
+ """Call the object in the interpreter with given args/kwargs.
- The return value is discarded.
+ Nearly all callables, args, kwargs, and return values are
+ supported. All "shareable" objects are supported, as are
+ "stateless" functions (meaning non-closures that do not use
+ any globals). This method will fall back to pickle.
If the callable raises an exception then the error display
- (including full traceback) is send back between the interpreters
+ (including full traceback) is sent back between the interpreters
and an ExecutionFailed exception is raised, much like what
happens with Interpreter.exec().
"""
- # XXX Support args and kwargs.
- # XXX Support arbitrary callables.
- # XXX Support returning the return value (e.g. via pickle).
- excinfo = _interpreters.call(self._id, callable, restrict=True)
- if excinfo is not None:
- raise ExecutionFailed(excinfo)
+ return self._call(callable, args, kwargs)
- def call_in_thread(self, callable, /):
+ def call_in_thread(self, callable, /, *args, **kwargs):
"""Return a new thread that calls the object in the interpreter.
The return value and any raised exception are discarded.
"""
- def task():
- self.call(callable)
- t = threading.Thread(target=task)
+ t = threading.Thread(target=self._call, args=(callable, args, kwargs))
t.start()
return t
diff --git a/Lib/test/support/interpreters/_crossinterp.py b/Lib/concurrent/interpreters/_crossinterp.py
index 544e197ba4c..f47eb693ac8 100644
--- a/Lib/test/support/interpreters/_crossinterp.py
+++ b/Lib/concurrent/interpreters/_crossinterp.py
@@ -61,7 +61,7 @@ class UnboundItem:
def __repr__(self):
return f'{self._MODULE}.{self._NAME}'
-# return f'interpreters.queues.UNBOUND'
+# return f'interpreters._queues.UNBOUND'
UNBOUND = object.__new__(UnboundItem)
diff --git a/Lib/test/support/interpreters/queues.py b/Lib/concurrent/interpreters/_queues.py
index deb8e8613af..99987f2f692 100644
--- a/Lib/test/support/interpreters/queues.py
+++ b/Lib/concurrent/interpreters/_queues.py
@@ -1,6 +1,5 @@
"""Cross-interpreter Queues High Level Module."""
-import pickle
import queue
import time
import weakref
@@ -63,29 +62,34 @@ def _resolve_unbound(flag):
return resolved
-def create(maxsize=0, *, syncobj=False, unbounditems=UNBOUND):
+def create(maxsize=0, *, unbounditems=UNBOUND):
"""Return a new cross-interpreter queue.
The queue may be used to pass data safely between interpreters.
- "syncobj" sets the default for Queue.put()
- and Queue.put_nowait().
-
- "unbounditems" likewise sets the default. See Queue.put() for
+ "unbounditems" sets the default for Queue.put(); see that method for
supported values. The default value is UNBOUND, which replaces
the unbound item.
"""
- fmt = _SHARED_ONLY if syncobj else _PICKLED
unbound = _serialize_unbound(unbounditems)
unboundop, = unbound
- qid = _queues.create(maxsize, fmt, unboundop)
- return Queue(qid, _fmt=fmt, _unbound=unbound)
+ qid = _queues.create(maxsize, unboundop, -1)
+ self = Queue(qid)
+ self._set_unbound(unboundop, unbounditems)
+ return self
def list_all():
"""Return a list of all open queues."""
- return [Queue(qid, _fmt=fmt, _unbound=(unboundop,))
- for qid, fmt, unboundop in _queues.list_all()]
+ queues = []
+ for qid, unboundop, _ in _queues.list_all():
+ self = Queue(qid)
+ if not hasattr(self, '_unbound'):
+ self._set_unbound(unboundop)
+ else:
+ assert self._unbound[0] == unboundop
+ queues.append(self)
+ return queues
_known_queues = weakref.WeakValueDictionary()
@@ -93,28 +97,17 @@ _known_queues = weakref.WeakValueDictionary()
class Queue:
"""A cross-interpreter queue."""
- def __new__(cls, id, /, *, _fmt=None, _unbound=None):
+ def __new__(cls, id, /):
# There is only one instance for any given ID.
if isinstance(id, int):
id = int(id)
else:
raise TypeError(f'id must be an int, got {id!r}')
- if _fmt is None:
- if _unbound is None:
- _fmt, op = _queues.get_queue_defaults(id)
- _unbound = (op,)
- else:
- _fmt, _ = _queues.get_queue_defaults(id)
- elif _unbound is None:
- _, op = _queues.get_queue_defaults(id)
- _unbound = (op,)
try:
self = _known_queues[id]
except KeyError:
self = super().__new__(cls)
self._id = id
- self._fmt = _fmt
- self._unbound = _unbound
_known_queues[id] = self
_queues.bind(id)
return self
@@ -143,11 +136,28 @@ class Queue:
def __getstate__(self):
return None
+ def _set_unbound(self, op, items=None):
+ assert not hasattr(self, '_unbound')
+ if items is None:
+ items = _resolve_unbound(op)
+ unbound = (op, items)
+ self._unbound = unbound
+ return unbound
+
@property
def id(self):
return self._id
@property
+ def unbounditems(self):
+ try:
+ _, items = self._unbound
+ except AttributeError:
+ op, _ = _queues.get_queue_defaults(self._id)
+ _, items = self._set_unbound(op)
+ return items
+
+ @property
def maxsize(self):
try:
return self._maxsize
@@ -165,77 +175,56 @@ class Queue:
return _queues.get_count(self._id)
def put(self, obj, timeout=None, *,
- syncobj=None,
- unbound=None,
+ unbounditems=None,
_delay=10 / 1000, # 10 milliseconds
):
"""Add the object to the queue.
This blocks while the queue is full.
- If "syncobj" is None (the default) then it uses the
- queue's default, set with create_queue().
-
- If "syncobj" is false then all objects are supported,
- at the expense of worse performance.
-
- If "syncobj" is true then the object must be "shareable".
- Examples of "shareable" objects include the builtin singletons,
- str, and memoryview. One benefit is that such objects are
- passed through the queue efficiently.
-
- The key difference, though, is conceptual: the corresponding
- object returned from Queue.get() will be strictly equivalent
- to the given obj. In other words, the two objects will be
- effectively indistinguishable from each other, even if the
- object is mutable. The received object may actually be the
- same object, or a copy (immutable values only), or a proxy.
- Regardless, the received object should be treated as though
- the original has been shared directly, whether or not it
- actually is. That's a slightly different and stronger promise
- than just (initial) equality, which is all "syncobj=False"
- can promise.
-
- "unbound" controls the behavior of Queue.get() for the given
+ For most objects, the object received through Queue.get() will
+ be a new one, equivalent to the original and not sharing any
+ actual underlying data. The notable exceptions include
+ cross-interpreter types (like Queue) and memoryview, where the
+ underlying data is actually shared. Furthermore, some types
+ can be sent through a queue more efficiently than others. This
+ group includes various immutable types like int, str, bytes, and
+ tuple (if the items are likewise efficiently shareable). See interpreters.is_shareable().
+
+ "unbounditems" controls the behavior of Queue.get() for the given
object if the current interpreter (calling put()) is later
destroyed.
- If "unbound" is None (the default) then it uses the
+ If "unbounditems" is None (the default) then it uses the
queue's default, set with create_queue(),
which is usually UNBOUND.
- If "unbound" is UNBOUND_ERROR then get() will raise an
+ If "unbounditems" is UNBOUND_ERROR then get() will raise an
ItemInterpreterDestroyed exception if the original interpreter
has been destroyed. This does not otherwise affect the queue;
the next call to put() will work like normal, returning the next
item in the queue.
- If "unbound" is UNBOUND_REMOVE then the item will be removed
+ If "unbounditems" is UNBOUND_REMOVE then the item will be removed
from the queue as soon as the original interpreter is destroyed.
Be aware that this will introduce an imbalance between put()
and get() calls.
- If "unbound" is UNBOUND then it is returned by get() in place
+ If "unbounditems" is UNBOUND then it is returned by get() in place
of the unbound item.
"""
- if syncobj is None:
- fmt = self._fmt
- else:
- fmt = _SHARED_ONLY if syncobj else _PICKLED
- if unbound is None:
- unboundop, = self._unbound
+ if unbounditems is None:
+ unboundop = -1
else:
- unboundop, = _serialize_unbound(unbound)
+ unboundop, = _serialize_unbound(unbounditems)
if timeout is not None:
timeout = int(timeout)
if timeout < 0:
raise ValueError(f'timeout value must be non-negative')
end = time.time() + timeout
- if fmt is _PICKLED:
- obj = pickle.dumps(obj)
while True:
try:
- _queues.put(self._id, obj, fmt, unboundop)
+ _queues.put(self._id, obj, unboundop)
except QueueFull as exc:
if timeout is not None and time.time() >= end:
raise # re-raise
@@ -243,18 +232,12 @@ class Queue:
else:
break
- def put_nowait(self, obj, *, syncobj=None, unbound=None):
- if syncobj is None:
- fmt = self._fmt
+ def put_nowait(self, obj, *, unbounditems=None):
+ if unbounditems is None:
+ unboundop = -1
else:
- fmt = _SHARED_ONLY if syncobj else _PICKLED
- if unbound is None:
- unboundop, = self._unbound
- else:
- unboundop, = _serialize_unbound(unbound)
- if fmt is _PICKLED:
- obj = pickle.dumps(obj)
- _queues.put(self._id, obj, fmt, unboundop)
+ unboundop, = _serialize_unbound(unbounditems)
+ _queues.put(self._id, obj, unboundop)
def get(self, timeout=None, *,
_delay=10 / 1000, # 10 milliseconds
@@ -265,7 +248,7 @@ class Queue:
If the next item's original interpreter has been destroyed
then the "next object" is determined by the value of the
- "unbound" argument to put().
+ "unbounditems" argument to put().
"""
if timeout is not None:
timeout = int(timeout)
@@ -274,7 +257,7 @@ class Queue:
end = time.time() + timeout
while True:
try:
- obj, fmt, unboundop = _queues.get(self._id)
+ obj, unboundop = _queues.get(self._id)
except QueueEmpty as exc:
if timeout is not None and time.time() >= end:
raise # re-raise
@@ -284,10 +267,6 @@ class Queue:
if unboundop is not None:
assert obj is None, repr(obj)
return _resolve_unbound(unboundop)
- if fmt == _PICKLED:
- obj = pickle.loads(obj)
- else:
- assert fmt == _SHARED_ONLY
return obj
def get_nowait(self):
@@ -297,16 +276,12 @@ class Queue:
is the same as get().
"""
try:
- obj, fmt, unboundop = _queues.get(self._id)
+ obj, unboundop = _queues.get(self._id)
except QueueEmpty as exc:
raise # re-raise
if unboundop is not None:
assert obj is None, repr(obj)
return _resolve_unbound(unboundop)
- if fmt == _PICKLED:
- obj = pickle.loads(obj)
- else:
- assert fmt == _SHARED_ONLY
return obj
diff --git a/Lib/configparser.py b/Lib/configparser.py
index 239fda60a02..18af1eadaad 100644
--- a/Lib/configparser.py
+++ b/Lib/configparser.py
@@ -1218,11 +1218,14 @@ class RawConfigParser(MutableMapping):
def _validate_key_contents(self, key):
"""Raises an InvalidWriteError for any keys containing
- delimiters or that match the section header pattern"""
+ delimiters or that begin with the section header pattern"""
if re.match(self.SECTCRE, key):
- raise InvalidWriteError("Cannot write keys matching section pattern")
- if any(delim in key for delim in self._delimiters):
- raise InvalidWriteError("Cannot write key that contains delimiters")
+ raise InvalidWriteError(
+ f"Cannot write key {key}; begins with section pattern")
+ for delim in self._delimiters:
+ if delim in key:
+ raise InvalidWriteError(
+ f"Cannot write key {key}; contains delimiter {delim}")
def _validate_value_types(self, *, section="", option="", value=""):
"""Raises a TypeError for illegal non-string values.
diff --git a/Lib/ctypes/__init__.py b/Lib/ctypes/__init__.py
index 823a3692fd1..d6d07a13f75 100644
--- a/Lib/ctypes/__init__.py
+++ b/Lib/ctypes/__init__.py
@@ -379,12 +379,6 @@ def create_unicode_buffer(init, size=None):
return buf
raise TypeError(init)
-
-def SetPointerType(pointer, cls):
- import warnings
- warnings._deprecated("ctypes.SetPointerType", remove=(3, 15))
- pointer.set_type(cls)
-
def ARRAY(typ, len):
return typ * len
diff --git a/Lib/ctypes/_layout.py b/Lib/ctypes/_layout.py
index 0719e72cfed..2048ccb6a1c 100644
--- a/Lib/ctypes/_layout.py
+++ b/Lib/ctypes/_layout.py
@@ -5,6 +5,7 @@ may change at any time.
"""
import sys
+import warnings
from _ctypes import CField, buffer_info
import ctypes
@@ -66,9 +67,26 @@ def get_layout(cls, input_fields, is_struct, base):
# For clarity, variables that count bits have `bit` in their names.
+ pack = getattr(cls, '_pack_', None)
+
layout = getattr(cls, '_layout_', None)
if layout is None:
- if sys.platform == 'win32' or getattr(cls, '_pack_', None):
+ if sys.platform == 'win32':
+ gcc_layout = False
+ elif pack:
+ if is_struct:
+ base_type_name = 'Structure'
+ else:
+ base_type_name = 'Union'
+ warnings._deprecated(
+ '_pack_ without _layout_',
+ f"Due to '_pack_', the '{cls.__name__}' {base_type_name} will "
+ + "use memory layout compatible with MSVC (Windows). "
+ + "If this is intended, set _layout_ to 'ms'. "
+ + "The implicit default is deprecated and slated to become "
+ + "an error in Python {remove}.",
+ remove=(3, 19),
+ )
gcc_layout = False
else:
gcc_layout = True
@@ -95,7 +113,6 @@ def get_layout(cls, input_fields, is_struct, base):
else:
big_endian = sys.byteorder == 'big'
- pack = getattr(cls, '_pack_', None)
if pack is not None:
try:
pack = int(pack)
diff --git a/Lib/curses/__init__.py b/Lib/curses/__init__.py
index 6165fe6c987..605d5fcbec5 100644
--- a/Lib/curses/__init__.py
+++ b/Lib/curses/__init__.py
@@ -30,9 +30,8 @@ def initscr():
fd=_sys.__stdout__.fileno())
stdscr = _curses.initscr()
for key, value in _curses.__dict__.items():
- if key[0:4] == 'ACS_' or key in ('LINES', 'COLS'):
+ if key.startswith('ACS_') or key in ('LINES', 'COLS'):
setattr(curses, key, value)
-
return stdscr
# This is a similar wrapper for start_color(), which adds the COLORS and
@@ -41,12 +40,9 @@ def initscr():
def start_color():
import _curses, curses
- retval = _curses.start_color()
- if hasattr(_curses, 'COLORS'):
- curses.COLORS = _curses.COLORS
- if hasattr(_curses, 'COLOR_PAIRS'):
- curses.COLOR_PAIRS = _curses.COLOR_PAIRS
- return retval
+ _curses.start_color()
+ curses.COLORS = _curses.COLORS
+ curses.COLOR_PAIRS = _curses.COLOR_PAIRS
# Import Python has_key() implementation if _curses doesn't contain has_key()
@@ -85,10 +81,11 @@ def wrapper(func, /, *args, **kwds):
# Start color, too. Harmless if the terminal doesn't have
# color; user can test with has_color() later on. The try/catch
# works around a minor bit of over-conscientiousness in the curses
- # module -- the error return from C start_color() is ignorable.
+ # module -- the error return from C start_color() is ignorable,
+ # unless they are raised by the interpreter due to other issues.
try:
start_color()
- except:
+ except _curses.error:
pass
return func(stdscr, *args, **kwds)
diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py
index 0f7dc9ae6b8..86d29df0639 100644
--- a/Lib/dataclasses.py
+++ b/Lib/dataclasses.py
@@ -244,6 +244,10 @@ _ATOMIC_TYPES = frozenset({
property,
})
+# Any marker is used in `make_dataclass` to mark unannotated fields as `Any`
+# without importing `typing` module.
+_ANY_MARKER = object()
+
class InitVar:
__slots__ = ('type', )
@@ -1591,7 +1595,7 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
for item in fields:
if isinstance(item, str):
name = item
- tp = 'typing.Any'
+ tp = _ANY_MARKER
elif len(item) == 2:
name, tp, = item
elif len(item) == 3:
@@ -1610,15 +1614,49 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
seen.add(name)
annotations[name] = tp
+ # We initially block the VALUE format, because inside dataclass() we'll
+ # call get_annotations(), which will try the VALUE format first. If we don't
+ # block, that means we'd always end up eagerly importing typing here, which
+ # is what we're trying to avoid.
+ value_blocked = True
+
+ def annotate_method(format):
+ def get_any():
+ match format:
+ case annotationlib.Format.STRING:
+ return 'typing.Any'
+ case annotationlib.Format.FORWARDREF:
+ typing = sys.modules.get("typing")
+ if typing is None:
+ return annotationlib.ForwardRef("Any", module="typing")
+ else:
+ return typing.Any
+ case annotationlib.Format.VALUE:
+ if value_blocked:
+ raise NotImplementedError
+ from typing import Any
+ return Any
+ case _:
+ raise NotImplementedError
+ annos = {
+ ann: get_any() if t is _ANY_MARKER else t
+ for ann, t in annotations.items()
+ }
+ if format == annotationlib.Format.STRING:
+ return annotationlib.annotations_to_string(annos)
+ else:
+ return annos
+
# Update 'ns' with the user-supplied namespace plus our calculated values.
def exec_body_callback(ns):
ns.update(namespace)
ns.update(defaults)
- ns['__annotations__'] = annotations
# We use `types.new_class()` instead of simply `type()` to allow dynamic creation
# of generic dataclasses.
cls = types.new_class(cls_name, bases, {}, exec_body_callback)
+ # For now, set annotations including the _ANY_MARKER.
+ cls.__annotate__ = annotate_method
# For pickling to work, the __module__ variable needs to be set to the frame
# where the dataclass is created.
@@ -1634,10 +1672,13 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
cls.__module__ = module
# Apply the normal provided decorator.
- return decorator(cls, init=init, repr=repr, eq=eq, order=order,
- unsafe_hash=unsafe_hash, frozen=frozen,
- match_args=match_args, kw_only=kw_only, slots=slots,
- weakref_slot=weakref_slot)
+ cls = decorator(cls, init=init, repr=repr, eq=eq, order=order,
+ unsafe_hash=unsafe_hash, frozen=frozen,
+ match_args=match_args, kw_only=kw_only, slots=slots,
+ weakref_slot=weakref_slot)
+ # Now that the class is ready, allow the VALUE format.
+ value_blocked = False
+ return cls
def replace(obj, /, **changes):
diff --git a/Lib/dbm/dumb.py b/Lib/dbm/dumb.py
index def120ffc37..1bc239a84ff 100644
--- a/Lib/dbm/dumb.py
+++ b/Lib/dbm/dumb.py
@@ -9,7 +9,7 @@ XXX TO DO:
- seems to contain a bug when updating...
- reclaim free space (currently, space once occupied by deleted or expanded
-items is never reused)
+items is not reused except if .reorganize() is called)
- support concurrent access (currently, if two processes take turns making
updates, they can mess up the index)
@@ -17,8 +17,6 @@ updates, they can mess up the index)
- support efficient access to large databases (currently, the whole index
is read when the database is opened, and some updates rewrite the whole index)
-- support opening for read-only (flag = 'm')
-
"""
import ast as _ast
@@ -289,6 +287,34 @@ class _Database(collections.abc.MutableMapping):
def __exit__(self, *args):
self.close()
+ def reorganize(self):
+ if self._readonly:
+ raise error('The database is opened for reading only')
+ self._verify_open()
+ # Ensure all changes are committed before reorganizing.
+ self._commit()
+ # Open file in r+ to allow changing in-place.
+ with _io.open(self._datfile, 'rb+') as f:
+ reorganize_pos = 0
+
+ # Iterate over existing keys, sorted by starting byte.
+ for key in sorted(self._index, key = lambda k: self._index[k][0]):
+ pos, siz = self._index[key]
+ f.seek(pos)
+ val = f.read(siz)
+
+ f.seek(reorganize_pos)
+ f.write(val)
+ self._index[key] = (reorganize_pos, siz)
+
+ blocks_occupied = (siz + _BLOCKSIZE - 1) // _BLOCKSIZE
+ reorganize_pos += blocks_occupied * _BLOCKSIZE
+
+ f.truncate(reorganize_pos)
+ # Commit changes to index, which were not in-place.
+ self._commit()
+
+
def open(file, flag='c', mode=0o666):
"""Open the database file, filename, and return corresponding object.
diff --git a/Lib/dbm/sqlite3.py b/Lib/dbm/sqlite3.py
index 7e0ae2a29e3..b296a1bcd1b 100644
--- a/Lib/dbm/sqlite3.py
+++ b/Lib/dbm/sqlite3.py
@@ -15,6 +15,7 @@ LOOKUP_KEY = "SELECT value FROM Dict WHERE key = CAST(? AS BLOB)"
STORE_KV = "REPLACE INTO Dict (key, value) VALUES (CAST(? AS BLOB), CAST(? AS BLOB))"
DELETE_KEY = "DELETE FROM Dict WHERE key = CAST(? AS BLOB)"
ITER_KEYS = "SELECT key FROM Dict"
+REORGANIZE = "VACUUM"
class error(OSError):
@@ -122,6 +123,9 @@ class _Database(MutableMapping):
def __exit__(self, *args):
self.close()
+ def reorganize(self):
+ self._execute(REORGANIZE)
+
def open(filename, /, flag="r", mode=0o666):
"""Open a dbm.sqlite3 database and return the dbm object.
diff --git a/Lib/difflib.py b/Lib/difflib.py
index f1f4e62514a..487936dbf47 100644
--- a/Lib/difflib.py
+++ b/Lib/difflib.py
@@ -78,8 +78,8 @@ class SequenceMatcher:
sequences. As a rule of thumb, a .ratio() value over 0.6 means the
sequences are close matches:
- >>> print(round(s.ratio(), 3))
- 0.866
+ >>> print(round(s.ratio(), 2))
+ 0.87
>>>
If you're only interested in where the sequences match,
@@ -1615,16 +1615,13 @@ def _mdiff(fromlines, tolines, context=None, linejunk=None,
_file_template = """
-<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
- "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
-
-<html>
-
+<!DOCTYPE html>
+<html lang="en">
<head>
- <meta http-equiv="Content-Type"
- content="text/html; charset=%(charset)s" />
- <title></title>
- <style type="text/css">%(styles)s
+ <meta charset="%(charset)s">
+ <meta name="viewport" content="width=device-width, initial-scale=1">
+ <title>Diff comparison</title>
+ <style>%(styles)s
</style>
</head>
@@ -1636,13 +1633,36 @@ _file_template = """
_styles = """
:root {color-scheme: light dark}
- table.diff {font-family: Menlo, Consolas, Monaco, Liberation Mono, Lucida Console, monospace; border:medium}
- .diff_header {background-color:#e0e0e0}
- td.diff_header {text-align:right}
- .diff_next {background-color:#c0c0c0}
+ table.diff {
+ font-family: Menlo, Consolas, Monaco, Liberation Mono, Lucida Console, monospace;
+ border: medium;
+ }
+ .diff_header {
+ background-color: #e0e0e0;
+ font-weight: bold;
+ }
+ td.diff_header {
+ text-align: right;
+ padding: 0 8px;
+ }
+ .diff_next {
+ background-color: #c0c0c0;
+ padding: 4px 0;
+ }
.diff_add {background-color:palegreen}
.diff_chg {background-color:#ffff77}
.diff_sub {background-color:#ffaaaa}
+ table.diff[summary="Legends"] {
+ margin-top: 20px;
+ border: 1px solid #ccc;
+ }
+ table.diff[summary="Legends"] th {
+ background-color: #e0e0e0;
+ padding: 4px 8px;
+ }
+ table.diff[summary="Legends"] td {
+ padding: 4px 8px;
+ }
@media (prefers-color-scheme: dark) {
.diff_header {background-color:#666}
@@ -1650,6 +1670,8 @@ _styles = """
.diff_add {background-color:darkgreen}
.diff_chg {background-color:#847415}
.diff_sub {background-color:darkred}
+ table.diff[summary="Legends"] {border-color:#555}
+ table.diff[summary="Legends"] th{background-color:#666}
}"""
_table_template = """
@@ -1692,7 +1714,7 @@ class HtmlDiff(object):
make_table -- generates HTML for a single side by side table
make_file -- generates complete HTML file with a single side by side table
- See tools/scripts/diff.py for an example usage of this class.
+ See Doc/includes/diff.py for an example usage of this class.
"""
_file_template = _file_template
diff --git a/Lib/dis.py b/Lib/dis.py
index cb6d077a391..d6d2c1386dd 100644
--- a/Lib/dis.py
+++ b/Lib/dis.py
@@ -1131,7 +1131,7 @@ class Bytecode:
def main(args=None):
import argparse
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument('-C', '--show-caches', action='store_true',
help='show inline caches')
parser.add_argument('-O', '--show-offsets', action='store_true',
diff --git a/Lib/doctest.py b/Lib/doctest.py
index e02e73ed722..c8c95ecbb27 100644
--- a/Lib/doctest.py
+++ b/Lib/doctest.py
@@ -101,6 +101,7 @@ import pdb
import re
import sys
import traceback
+import types
import unittest
from io import StringIO, IncrementalNewlineDecoder
from collections import namedtuple
@@ -385,7 +386,7 @@ class _OutputRedirectingPdb(pdb.Pdb):
self.__out = out
self.__debugger_used = False
# do not play signal games in the pdb
- pdb.Pdb.__init__(self, stdout=out, nosigint=True)
+ super().__init__(stdout=out, nosigint=True)
# still use input() to get user input
self.use_rawinput = 1
@@ -1278,6 +1279,11 @@ class DocTestRunner:
# Reporting methods
#/////////////////////////////////////////////////////////////////
+ def report_skip(self, out, test, example):
+ """
+ Report that the given example was skipped.
+ """
+
def report_start(self, out, test, example):
"""
Report that the test runner is about to process the given
@@ -1375,6 +1381,8 @@ class DocTestRunner:
# If 'SKIP' is set, then skip this example.
if self.optionflags & SKIP:
+ if not quiet:
+ self.report_skip(out, test, example)
skips += 1
continue
@@ -1395,11 +1403,11 @@ class DocTestRunner:
exec(compile(example.source, filename, "single",
compileflags, True), test.globs)
self.debugger.set_continue() # ==== Example Finished ====
- exception = None
+ exc_info = None
except KeyboardInterrupt:
raise
- except:
- exception = sys.exc_info()
+ except BaseException as exc:
+ exc_info = type(exc), exc, exc.__traceback__.tb_next
self.debugger.set_continue() # ==== Example Finished ====
got = self._fakeout.getvalue() # the actual output
@@ -1408,21 +1416,21 @@ class DocTestRunner:
# If the example executed without raising any exceptions,
# verify its output.
- if exception is None:
+ if exc_info is None:
if check(example.want, got, self.optionflags):
outcome = SUCCESS
# The example raised an exception: check if it was expected.
else:
- formatted_ex = traceback.format_exception_only(*exception[:2])
- if issubclass(exception[0], SyntaxError):
+ formatted_ex = traceback.format_exception_only(*exc_info[:2])
+ if issubclass(exc_info[0], SyntaxError):
# SyntaxError / IndentationError is special:
# we don't care about the carets / suggestions / etc
# We only care about the error message and notes.
# They start with `SyntaxError:` (or any other class name)
exception_line_prefixes = (
- f"{exception[0].__qualname__}:",
- f"{exception[0].__module__}.{exception[0].__qualname__}:",
+ f"{exc_info[0].__qualname__}:",
+ f"{exc_info[0].__module__}.{exc_info[0].__qualname__}:",
)
exc_msg_index = next(
index
@@ -1433,7 +1441,7 @@ class DocTestRunner:
exc_msg = "".join(formatted_ex)
if not quiet:
- got += _exception_traceback(exception)
+ got += _exception_traceback(exc_info)
# If `example.exc_msg` is None, then we weren't expecting
# an exception.
@@ -1462,7 +1470,7 @@ class DocTestRunner:
elif outcome is BOOM:
if not quiet:
self.report_unexpected_exception(out, test, example,
- exception)
+ exc_info)
failures += 1
else:
assert False, ("unknown outcome", outcome)
@@ -2272,12 +2280,63 @@ def set_unittest_reportflags(flags):
return old
+class _DocTestCaseRunner(DocTestRunner):
+
+ def __init__(self, *args, test_case, test_result, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._test_case = test_case
+ self._test_result = test_result
+ self._examplenum = 0
+
+ def _subTest(self):
+ subtest = unittest.case._SubTest(self._test_case, str(self._examplenum), {})
+ self._examplenum += 1
+ return subtest
+
+ def report_skip(self, out, test, example):
+ unittest.case._addSkip(self._test_result, self._subTest(), '')
+
+ def report_success(self, out, test, example, got):
+ self._test_result.addSubTest(self._test_case, self._subTest(), None)
+
+ def report_unexpected_exception(self, out, test, example, exc_info):
+ tb = self._add_traceback(exc_info[2], test, example)
+ exc_info = (*exc_info[:2], tb)
+ self._test_result.addSubTest(self._test_case, self._subTest(), exc_info)
+
+ def report_failure(self, out, test, example, got):
+ msg = ('Failed example:\n' + _indent(example.source) +
+ self._checker.output_difference(example, got, self.optionflags).rstrip('\n'))
+ exc = self._test_case.failureException(msg)
+ tb = self._add_traceback(None, test, example)
+ exc_info = (type(exc), exc, tb)
+ self._test_result.addSubTest(self._test_case, self._subTest(), exc_info)
+
+ def _add_traceback(self, traceback, test, example):
+ if test.lineno is None or example.lineno is None:
+ lineno = None
+ else:
+ lineno = test.lineno + example.lineno + 1
+ return types.SimpleNamespace(
+ tb_frame = types.SimpleNamespace(
+ f_globals=test.globs,
+ f_code=types.SimpleNamespace(
+ co_filename=test.filename,
+ co_name=test.name,
+ ),
+ ),
+ tb_next = traceback,
+ tb_lasti = -1,
+ tb_lineno = lineno,
+ )
+
+
class DocTestCase(unittest.TestCase):
def __init__(self, test, optionflags=0, setUp=None, tearDown=None,
checker=None):
- unittest.TestCase.__init__(self)
+ super().__init__()
self._dt_optionflags = optionflags
self._dt_checker = checker
self._dt_test = test
@@ -2301,30 +2360,28 @@ class DocTestCase(unittest.TestCase):
test.globs.clear()
test.globs.update(self._dt_globs)
+ def run(self, result=None):
+ self._test_result = result
+ return super().run(result)
+
def runTest(self):
test = self._dt_test
- old = sys.stdout
- new = StringIO()
optionflags = self._dt_optionflags
+ result = self._test_result
if not (optionflags & REPORTING_FLAGS):
# The option flags don't include any reporting flags,
# so add the default reporting flags
optionflags |= _unittest_reportflags
+ if getattr(result, 'failfast', False):
+ optionflags |= FAIL_FAST
- runner = DocTestRunner(optionflags=optionflags,
- checker=self._dt_checker, verbose=False)
-
- try:
- runner.DIVIDER = "-"*70
- results = runner.run(test, out=new.write, clear_globs=False)
- if results.skipped == results.attempted:
- raise unittest.SkipTest("all examples were skipped")
- finally:
- sys.stdout = old
-
- if results.failed:
- raise self.failureException(self.format_failure(new.getvalue()))
+ runner = _DocTestCaseRunner(optionflags=optionflags,
+ checker=self._dt_checker, verbose=False,
+ test_case=self, test_result=result)
+ results = runner.run(test, clear_globs=False)
+ if results.skipped == results.attempted:
+ raise unittest.SkipTest("all examples were skipped")
def format_failure(self, err):
test = self._dt_test
@@ -2439,7 +2496,7 @@ class DocTestCase(unittest.TestCase):
class SkipDocTestCase(DocTestCase):
def __init__(self, module):
self.module = module
- DocTestCase.__init__(self, None)
+ super().__init__(None)
def setUp(self):
self.skipTest("DocTestSuite will not work with -O2 and above")
@@ -2870,7 +2927,7 @@ __test__ = {"_TestClass": _TestClass,
def _test():
import argparse
- parser = argparse.ArgumentParser(description="doctest runner")
+ parser = argparse.ArgumentParser(description="doctest runner", color=True)
parser.add_argument('-v', '--verbose', action='store_true', default=False,
help='print very verbose output for all tests')
parser.add_argument('-o', '--option', action='append',
diff --git a/Lib/email/_header_value_parser.py b/Lib/email/_header_value_parser.py
index 9a51b943733..91243378dc0 100644
--- a/Lib/email/_header_value_parser.py
+++ b/Lib/email/_header_value_parser.py
@@ -1020,6 +1020,8 @@ def _get_ptext_to_endchars(value, endchars):
a flag that is True iff there were any quoted printables decoded.
"""
+ if not value:
+ return '', '', False
fragment, *remainder = _wsp_splitter(value, 1)
vchars = []
escape = False
@@ -1573,7 +1575,7 @@ def get_dtext(value):
def _check_for_early_dl_end(value, domain_literal):
if value:
return False
- domain_literal.append(errors.InvalidHeaderDefect(
+ domain_literal.defects.append(errors.InvalidHeaderDefect(
"end of input inside domain-literal"))
domain_literal.append(ValueTerminal(']', 'domain-literal-end'))
return True
@@ -1592,9 +1594,9 @@ def get_domain_literal(value):
raise errors.HeaderParseError("expected '[' at start of domain-literal "
"but found '{}'".format(value))
value = value[1:]
+ domain_literal.append(ValueTerminal('[', 'domain-literal-start'))
if _check_for_early_dl_end(value, domain_literal):
return domain_literal, value
- domain_literal.append(ValueTerminal('[', 'domain-literal-start'))
if value[0] in WSP:
token, value = get_fws(value)
domain_literal.append(token)
diff --git a/Lib/email/feedparser.py b/Lib/email/feedparser.py
index b2bc4afc1cc..9d80a5822af 100644
--- a/Lib/email/feedparser.py
+++ b/Lib/email/feedparser.py
@@ -30,7 +30,7 @@ from io import StringIO
NLCRE = re.compile(r'\r\n|\r|\n')
NLCRE_bol = re.compile(r'(\r\n|\r|\n)')
-NLCRE_eol = re.compile(r'(\r\n|\r|\n)\Z')
+NLCRE_eol = re.compile(r'(\r\n|\r|\n)\z')
NLCRE_crack = re.compile(r'(\r\n|\r|\n)')
# RFC 2822 $3.6.8 Optional fields. ftext is %d33-57 / %d59-126, Any character
# except controls, SP, and ":".
diff --git a/Lib/email/header.py b/Lib/email/header.py
index 113a81f4131..220a84a7454 100644
--- a/Lib/email/header.py
+++ b/Lib/email/header.py
@@ -59,16 +59,22 @@ _max_append = email.quoprimime._max_append
def decode_header(header):
"""Decode a message header value without converting charset.
- Returns a list of (string, charset) pairs containing each of the decoded
- parts of the header. Charset is None for non-encoded parts of the header,
- otherwise a lower-case string containing the name of the character set
- specified in the encoded string.
+ For historical reasons, this function may return either:
+
+ 1. A list of length 1 containing a pair (str, None).
+ 2. A list of (bytes, charset) pairs containing each of the decoded
+ parts of the header. Charset is None for non-encoded parts of the header,
+ otherwise a lower-case string containing the name of the character set
+ specified in the encoded string.
header may be a string that may or may not contain RFC2047 encoded words,
or it may be a Header object.
An email.errors.HeaderParseError may be raised when certain decoding error
occurs (e.g. a base64 decoding exception).
+
+ This function exists for backwards compatibility only. For new code, we
+ recommend using email.headerregistry.HeaderRegistry instead.
"""
# If it is a Header object, we can just return the encoded chunks.
if hasattr(header, '_chunks'):
@@ -161,6 +167,9 @@ def make_header(decoded_seq, maxlinelen=None, header_name=None,
This function takes one of those sequence of pairs and returns a Header
instance. Optional maxlinelen, header_name, and continuation_ws are as in
the Header constructor.
+
+ This function exists for backwards compatibility only, and is not
+ recommended for use in new code.
"""
h = Header(maxlinelen=maxlinelen, header_name=header_name,
continuation_ws=continuation_ws)
diff --git a/Lib/email/message.py b/Lib/email/message.py
index 87fcab68868..41fcc2b9778 100644
--- a/Lib/email/message.py
+++ b/Lib/email/message.py
@@ -564,7 +564,7 @@ class Message:
msg.add_header('content-disposition', 'attachment', filename='bud.gif')
msg.add_header('content-disposition', 'attachment',
- filename=('utf-8', '', Fußballer.ppt'))
+ filename=('utf-8', '', 'Fußballer.ppt'))
msg.add_header('content-disposition', 'attachment',
filename='Fußballer.ppt'))
"""
diff --git a/Lib/email/utils.py b/Lib/email/utils.py
index 7eab74dc0db..3de1f0d24a1 100644
--- a/Lib/email/utils.py
+++ b/Lib/email/utils.py
@@ -417,8 +417,14 @@ def decode_params(params):
for name, continuations in rfc2231_params.items():
value = []
extended = False
- # Sort by number
- continuations.sort()
+ # Sort by number, treating None as 0 if there is no 0,
+ # and ignore it if there is already a 0.
+ has_zero = any(x[0] == 0 for x in continuations)
+ if has_zero:
+ continuations = [x for x in continuations if x[0] is not None]
+ else:
+ continuations = [(x[0] or 0, x[1], x[2]) for x in continuations]
+ continuations.sort(key=lambda x: x[0])
# And now append all values in numerical order, converting
# %-encodings for the encoded segments. If any of the
# continuation names ends in a *, then the entire string, after
diff --git a/Lib/encodings/aliases.py b/Lib/encodings/aliases.py
index a94bb270671..4ecb6b6e297 100644
--- a/Lib/encodings/aliases.py
+++ b/Lib/encodings/aliases.py
@@ -405,6 +405,8 @@ aliases = {
'iso_8859_8' : 'iso8859_8',
'iso_8859_8_1988' : 'iso8859_8',
'iso_ir_138' : 'iso8859_8',
+ 'iso_8859_8_i' : 'iso8859_8',
+ 'iso_8859_8_e' : 'iso8859_8',
# iso8859_9 codec
'csisolatin5' : 'iso8859_9',
diff --git a/Lib/encodings/palmos.py b/Lib/encodings/palmos.py
index c506d654523..df164ca5b95 100644
--- a/Lib/encodings/palmos.py
+++ b/Lib/encodings/palmos.py
@@ -201,7 +201,7 @@ decoding_table = (
'\u02dc' # 0x98 -> SMALL TILDE
'\u2122' # 0x99 -> TRADE MARK SIGN
'\u0161' # 0x9A -> LATIN SMALL LETTER S WITH CARON
- '\x9b' # 0x9B -> <control>
+ '\u203a' # 0x9B -> SINGLE RIGHT-POINTING ANGLE QUOTATION MARK
'\u0153' # 0x9C -> LATIN SMALL LIGATURE OE
'\x9d' # 0x9D -> <control>
'\x9e' # 0x9E -> <control>
diff --git a/Lib/ensurepip/__init__.py b/Lib/ensurepip/__init__.py
index 6fc9f39b24c..aa641e94a8b 100644
--- a/Lib/ensurepip/__init__.py
+++ b/Lib/ensurepip/__init__.py
@@ -205,7 +205,7 @@ def _uninstall_helper(*, verbosity=0):
def _main(argv=None):
import argparse
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument(
"--version",
action="version",
diff --git a/Lib/fnmatch.py b/Lib/fnmatch.py
index 1dee8330f5d..10e1c936688 100644
--- a/Lib/fnmatch.py
+++ b/Lib/fnmatch.py
@@ -185,7 +185,7 @@ def _translate(pat, star, question_mark):
def _join_translated_parts(parts, star_indices):
if not star_indices:
- return fr'(?s:{"".join(parts)})\Z'
+ return fr'(?s:{"".join(parts)})\z'
iter_star_indices = iter(star_indices)
j = next(iter_star_indices)
buffer = parts[:j] # fixed pieces at the start
@@ -206,4 +206,4 @@ def _join_translated_parts(parts, star_indices):
append('.*')
extend(parts[i:])
res = ''.join(buffer)
- return fr'(?s:{res})\Z'
+ return fr'(?s:{res})\z'
diff --git a/Lib/fractions.py b/Lib/fractions.py
index fa722589fb4..cb05ae7c200 100644
--- a/Lib/fractions.py
+++ b/Lib/fractions.py
@@ -64,7 +64,7 @@ _RATIONAL_FORMAT = re.compile(r"""
(?:\.(?P<decimal>\d*|\d+(_\d+)*))? # an optional fractional part
(?:E(?P<exp>[-+]?\d+(_\d+)*))? # and optional exponent
)
- \s*\Z # and optional whitespace to finish
+ \s*\z # and optional whitespace to finish
""", re.VERBOSE | re.IGNORECASE)
@@ -168,9 +168,9 @@ _FLOAT_FORMAT_SPECIFICATION_MATCHER = re.compile(r"""
# A '0' that's *not* followed by another digit is parsed as a minimum width
# rather than a zeropad flag.
(?P<zeropad>0(?=[0-9]))?
- (?P<minimumwidth>0|[1-9][0-9]*)?
+ (?P<minimumwidth>[0-9]+)?
(?P<thousands_sep>[,_])?
- (?:\.(?P<precision>0|[1-9][0-9]*))?
+ (?:\.(?P<precision>[0-9]+))?
(?P<presentation_type>[eEfFgG%])
""", re.DOTALL | re.VERBOSE).fullmatch
@@ -238,11 +238,6 @@ class Fraction(numbers.Rational):
self._denominator = 1
return self
- elif isinstance(numerator, numbers.Rational):
- self._numerator = numerator.numerator
- self._denominator = numerator.denominator
- return self
-
elif (isinstance(numerator, float) or
(not isinstance(numerator, type) and
hasattr(numerator, 'as_integer_ratio'))):
@@ -278,6 +273,11 @@ class Fraction(numbers.Rational):
if m.group('sign') == '-':
numerator = -numerator
+ elif isinstance(numerator, numbers.Rational):
+ self._numerator = numerator.numerator
+ self._denominator = numerator.denominator
+ return self
+
else:
raise TypeError("argument should be a string or a Rational "
"instance or have the as_integer_ratio() method")
diff --git a/Lib/functools.py b/Lib/functools.py
index 714070c6ac9..7f0eac3f650 100644
--- a/Lib/functools.py
+++ b/Lib/functools.py
@@ -323,6 +323,9 @@ def _partial_new(cls, func, /, *args, **keywords):
"or a descriptor")
if args and args[-1] is Placeholder:
raise TypeError("trailing Placeholders are not allowed")
+ for value in keywords.values():
+ if value is Placeholder:
+ raise TypeError("Placeholder cannot be passed as a keyword argument")
if isinstance(func, base_cls):
pto_phcount = func._phcount
tot_args = func.args
diff --git a/Lib/genericpath.py b/Lib/genericpath.py
index ba7b0a13c7f..9363f564aab 100644
--- a/Lib/genericpath.py
+++ b/Lib/genericpath.py
@@ -8,7 +8,7 @@ import stat
__all__ = ['commonprefix', 'exists', 'getatime', 'getctime', 'getmtime',
'getsize', 'isdevdrive', 'isdir', 'isfile', 'isjunction', 'islink',
- 'lexists', 'samefile', 'sameopenfile', 'samestat']
+ 'lexists', 'samefile', 'sameopenfile', 'samestat', 'ALLOW_MISSING']
# Does a path exist?
@@ -189,3 +189,12 @@ def _check_arg_types(funcname, *args):
f'os.PathLike object, not {s.__class__.__name__!r}') from None
if hasstr and hasbytes:
raise TypeError("Can't mix strings and bytes in path components") from None
+
+# A singleton with a true boolean value.
+@object.__new__
+class ALLOW_MISSING:
+ """Special value for use in realpath()."""
+ def __repr__(self):
+ return 'os.path.ALLOW_MISSING'
+ def __reduce__(self):
+ return self.__class__.__name__
diff --git a/Lib/getpass.py b/Lib/getpass.py
index bd0097ced94..1dd40e25e09 100644
--- a/Lib/getpass.py
+++ b/Lib/getpass.py
@@ -1,6 +1,7 @@
"""Utilities to get a password and/or the current user name.
-getpass(prompt[, stream]) - Prompt for a password, with echo turned off.
+getpass(prompt[, stream[, echo_char]]) - Prompt for a password, with echo
+turned off and optional keyboard feedback.
getuser() - Get the user name from the environment or password database.
GetPassWarning - This UserWarning is issued when getpass() cannot prevent
@@ -25,13 +26,15 @@ __all__ = ["getpass","getuser","GetPassWarning"]
class GetPassWarning(UserWarning): pass
-def unix_getpass(prompt='Password: ', stream=None):
+def unix_getpass(prompt='Password: ', stream=None, *, echo_char=None):
"""Prompt for a password, with echo turned off.
Args:
prompt: Written on stream to ask for the input. Default: 'Password: '
stream: A writable file object to display the prompt. Defaults to
the tty. If no tty is available defaults to sys.stderr.
+ echo_char: A string used to mask input (e.g., '*'). If None, input is
+ hidden.
Returns:
The seKr3t input.
Raises:
@@ -40,6 +43,8 @@ def unix_getpass(prompt='Password: ', stream=None):
Always restores terminal settings before returning.
"""
+ _check_echo_char(echo_char)
+
passwd = None
with contextlib.ExitStack() as stack:
try:
@@ -68,12 +73,16 @@ def unix_getpass(prompt='Password: ', stream=None):
old = termios.tcgetattr(fd) # a copy to save
new = old[:]
new[3] &= ~termios.ECHO # 3 == 'lflags'
+ if echo_char:
+ new[3] &= ~termios.ICANON
tcsetattr_flags = termios.TCSAFLUSH
if hasattr(termios, 'TCSASOFT'):
tcsetattr_flags |= termios.TCSASOFT
try:
termios.tcsetattr(fd, tcsetattr_flags, new)
- passwd = _raw_input(prompt, stream, input=input)
+ passwd = _raw_input(prompt, stream, input=input,
+ echo_char=echo_char)
+
finally:
termios.tcsetattr(fd, tcsetattr_flags, old)
stream.flush() # issue7208
@@ -93,10 +102,11 @@ def unix_getpass(prompt='Password: ', stream=None):
return passwd
-def win_getpass(prompt='Password: ', stream=None):
+def win_getpass(prompt='Password: ', stream=None, *, echo_char=None):
"""Prompt for password with echo off, using Windows getwch()."""
if sys.stdin is not sys.__stdin__:
return fallback_getpass(prompt, stream)
+ _check_echo_char(echo_char)
for c in prompt:
msvcrt.putwch(c)
@@ -108,25 +118,39 @@ def win_getpass(prompt='Password: ', stream=None):
if c == '\003':
raise KeyboardInterrupt
if c == '\b':
+ if echo_char and pw:
+ msvcrt.putwch('\b')
+ msvcrt.putwch(' ')
+ msvcrt.putwch('\b')
pw = pw[:-1]
else:
pw = pw + c
+ if echo_char:
+ msvcrt.putwch(echo_char)
msvcrt.putwch('\r')
msvcrt.putwch('\n')
return pw
-def fallback_getpass(prompt='Password: ', stream=None):
+def fallback_getpass(prompt='Password: ', stream=None, *, echo_char=None):
+ _check_echo_char(echo_char)
import warnings
warnings.warn("Can not control echo on the terminal.", GetPassWarning,
stacklevel=2)
if not stream:
stream = sys.stderr
print("Warning: Password input may be echoed.", file=stream)
- return _raw_input(prompt, stream)
+ return _raw_input(prompt, stream, echo_char=echo_char)
+
+def _check_echo_char(echo_char):
+ # ASCII excluding control characters
+ if echo_char and not (echo_char.isprintable() and echo_char.isascii()):
+ raise ValueError("'echo_char' must be a printable ASCII string, "
+ f"got: {echo_char!r}")
-def _raw_input(prompt="", stream=None, input=None):
+
+def _raw_input(prompt="", stream=None, input=None, echo_char=None):
# This doesn't save the string in the GNU readline history.
if not stream:
stream = sys.stderr
@@ -143,6 +167,8 @@ def _raw_input(prompt="", stream=None, input=None):
stream.write(prompt)
stream.flush()
# NOTE: The Python C API calls flockfile() (and unlock) during readline.
+ if echo_char:
+ return _readline_with_echo_char(stream, input, echo_char)
line = input.readline()
if not line:
raise EOFError
@@ -151,6 +177,35 @@ def _raw_input(prompt="", stream=None, input=None):
return line
+def _readline_with_echo_char(stream, input, echo_char):
+ passwd = ""
+ eof_pressed = False
+ while True:
+ char = input.read(1)
+ if char == '\n' or char == '\r':
+ break
+ elif char == '\x03':
+ raise KeyboardInterrupt
+ elif char == '\x7f' or char == '\b':
+ if passwd:
+ stream.write("\b \b")
+ stream.flush()
+ passwd = passwd[:-1]
+ elif char == '\x04':
+ if eof_pressed:
+ break
+ else:
+ eof_pressed = True
+ elif char == '\x00':
+ continue
+ else:
+ passwd += char
+ stream.write(echo_char)
+ stream.flush()
+ eof_pressed = False
+ return passwd
+
+
def getuser():
"""Get the username from the environment or password database.
diff --git a/Lib/glob.py b/Lib/glob.py
index 8879eff8041..1e48fe43167 100644
--- a/Lib/glob.py
+++ b/Lib/glob.py
@@ -316,7 +316,7 @@ def translate(pat, *, recursive=False, include_hidden=False, seps=None):
if idx < last_part_idx:
results.append(any_sep)
res = ''.join(results)
- return fr'(?s:{res})\Z'
+ return fr'(?s:{res})\z'
@functools.lru_cache(maxsize=512)
@@ -358,6 +358,12 @@ class _GlobberBase:
"""
raise NotImplementedError
+ @staticmethod
+ def stringify_path(path):
+ """Converts the path to a string object
+ """
+ raise NotImplementedError
+
# High-level methods
def compile(self, pat, altsep=None):
@@ -466,8 +472,9 @@ class _GlobberBase:
select_next = self.selector(parts)
def select_recursive(path, exists=False):
- match_pos = len(str(path))
- if match is None or match(str(path), match_pos):
+ path_str = self.stringify_path(path)
+ match_pos = len(path_str)
+ if match is None or match(path_str, match_pos):
yield from select_next(path, exists)
stack = [path]
while stack:
@@ -489,7 +496,7 @@ class _GlobberBase:
pass
if is_dir or not dir_only:
- entry_path_str = str(entry_path)
+ entry_path_str = self.stringify_path(entry_path)
if dir_only:
entry_path = self.concat_path(entry_path, self.sep)
if match is None or match(entry_path_str, match_pos):
@@ -529,19 +536,6 @@ class _StringGlobber(_GlobberBase):
entries = list(scandir_it)
return ((entry, entry.name, entry.path) for entry in entries)
-
-class _PathGlobber(_GlobberBase):
- """Provides shell-style pattern matching and globbing for pathlib paths.
- """
-
@staticmethod
- def lexists(path):
- return path.info.exists(follow_symlinks=False)
-
- @staticmethod
- def scandir(path):
- return ((child.info, child.name, child) for child in path.iterdir())
-
- @staticmethod
- def concat_path(path, text):
- return path.with_segments(str(path) + text)
+ def stringify_path(path):
+ return path # Already a string.
diff --git a/Lib/gzip.py b/Lib/gzip.py
index b7375b25473..c00f51858de 100644
--- a/Lib/gzip.py
+++ b/Lib/gzip.py
@@ -667,7 +667,9 @@ def main():
from argparse import ArgumentParser
parser = ArgumentParser(description=
"A simple command line interface for the gzip module: act like gzip, "
- "but do not delete the input file.")
+ "but do not delete the input file.",
+ color=True,
+ )
group = parser.add_mutually_exclusive_group()
group.add_argument('--fast', action='store_true', help='compress faster')
group.add_argument('--best', action='store_true', help='compress better')
diff --git a/Lib/hashlib.py b/Lib/hashlib.py
index abacac22ea0..0e9bd98aa1f 100644
--- a/Lib/hashlib.py
+++ b/Lib/hashlib.py
@@ -141,29 +141,29 @@ def __get_openssl_constructor(name):
return __get_builtin_constructor(name)
-def __py_new(name, data=b'', **kwargs):
+def __py_new(name, *args, **kwargs):
"""new(name, data=b'', **kwargs) - Return a new hashing object using the
named algorithm; optionally initialized with data (which must be
a bytes-like object).
"""
- return __get_builtin_constructor(name)(data, **kwargs)
+ return __get_builtin_constructor(name)(*args, **kwargs)
-def __hash_new(name, data=b'', **kwargs):
+def __hash_new(name, *args, **kwargs):
"""new(name, data=b'') - Return a new hashing object using the named algorithm;
optionally initialized with data (which must be a bytes-like object).
"""
if name in __block_openssl_constructor:
# Prefer our builtin blake2 implementation.
- return __get_builtin_constructor(name)(data, **kwargs)
+ return __get_builtin_constructor(name)(*args, **kwargs)
try:
- return _hashlib.new(name, data, **kwargs)
+ return _hashlib.new(name, *args, **kwargs)
except ValueError:
# If the _hashlib module (OpenSSL) doesn't support the named
# hash, try using our builtin implementations.
# This allows for SHA224/256 and SHA384/512 support even though
# the OpenSSL library prior to 0.9.8 doesn't provide them.
- return __get_builtin_constructor(name)(data)
+ return __get_builtin_constructor(name)(*args, **kwargs)
try:
diff --git a/Lib/heapq.py b/Lib/heapq.py
index 9649da251f2..6ceb211f1ca 100644
--- a/Lib/heapq.py
+++ b/Lib/heapq.py
@@ -178,7 +178,7 @@ def heapify(x):
for i in reversed(range(n//2)):
_siftup(x, i)
-def _heappop_max(heap):
+def heappop_max(heap):
"""Maxheap version of a heappop."""
lastelt = heap.pop() # raises appropriate IndexError if heap is empty
if heap:
@@ -188,19 +188,32 @@ def _heappop_max(heap):
return returnitem
return lastelt
-def _heapreplace_max(heap, item):
+def heapreplace_max(heap, item):
"""Maxheap version of a heappop followed by a heappush."""
returnitem = heap[0] # raises appropriate IndexError if heap is empty
heap[0] = item
_siftup_max(heap, 0)
return returnitem
-def _heapify_max(x):
+def heappush_max(heap, item):
+ """Maxheap version of a heappush."""
+ heap.append(item)
+ _siftdown_max(heap, 0, len(heap)-1)
+
+def heappushpop_max(heap, item):
+ """Maxheap fast version of a heappush followed by a heappop."""
+ if heap and item < heap[0]:
+ item, heap[0] = heap[0], item
+ _siftup_max(heap, 0)
+ return item
+
+def heapify_max(x):
"""Transform list into a maxheap, in-place, in O(len(x)) time."""
n = len(x)
for i in reversed(range(n//2)):
_siftup_max(x, i)
+
# 'heap' is a heap at all indices >= startpos, except possibly for pos. pos
# is the index of a leaf with a possibly out-of-order value. Restore the
# heap invariant.
@@ -335,9 +348,9 @@ def merge(*iterables, key=None, reverse=False):
h_append = h.append
if reverse:
- _heapify = _heapify_max
- _heappop = _heappop_max
- _heapreplace = _heapreplace_max
+ _heapify = heapify_max
+ _heappop = heappop_max
+ _heapreplace = heapreplace_max
direction = -1
else:
_heapify = heapify
@@ -490,10 +503,10 @@ def nsmallest(n, iterable, key=None):
result = [(elem, i) for i, elem in zip(range(n), it)]
if not result:
return result
- _heapify_max(result)
+ heapify_max(result)
top = result[0][0]
order = n
- _heapreplace = _heapreplace_max
+ _heapreplace = heapreplace_max
for elem in it:
if elem < top:
_heapreplace(result, (elem, order))
@@ -507,10 +520,10 @@ def nsmallest(n, iterable, key=None):
result = [(key(elem), i, elem) for i, elem in zip(range(n), it)]
if not result:
return result
- _heapify_max(result)
+ heapify_max(result)
top = result[0][0]
order = n
- _heapreplace = _heapreplace_max
+ _heapreplace = heapreplace_max
for elem in it:
k = key(elem)
if k < top:
@@ -583,19 +596,13 @@ try:
from _heapq import *
except ImportError:
pass
-try:
- from _heapq import _heapreplace_max
-except ImportError:
- pass
-try:
- from _heapq import _heapify_max
-except ImportError:
- pass
-try:
- from _heapq import _heappop_max
-except ImportError:
- pass
+# For backwards compatibility
+_heappop_max = heappop_max
+_heapreplace_max = heapreplace_max
+_heappush_max = heappush_max
+_heappushpop_max = heappushpop_max
+_heapify_max = heapify_max
if __name__ == "__main__":
diff --git a/Lib/html/parser.py b/Lib/html/parser.py
index 13c95c34e50..ba416e7fa6e 100644
--- a/Lib/html/parser.py
+++ b/Lib/html/parser.py
@@ -12,6 +12,7 @@ import re
import _markupbase
from html import unescape
+from html.entities import html5 as html5_entities
__all__ = ['HTMLParser']
@@ -23,8 +24,10 @@ incomplete = re.compile('&[a-zA-Z#]')
entityref = re.compile('&([a-zA-Z][-.a-zA-Z0-9]*)[^a-zA-Z0-9]')
charref = re.compile('&#(?:[0-9]+|[xX][0-9a-fA-F]+)[^0-9a-fA-F]')
+attr_charref = re.compile(r'&(#[0-9]+|#[xX][0-9a-fA-F]+|[a-zA-Z][a-zA-Z0-9]*)[;=]?')
starttagopen = re.compile('<[a-zA-Z]')
+endtagopen = re.compile('</[a-zA-Z]')
piclose = re.compile('>')
commentclose = re.compile(r'--\s*>')
# Note:
@@ -57,6 +60,22 @@ endendtag = re.compile('>')
# </ and the tag name, so maybe this should be fixed
endtagfind = re.compile(r'</\s*([a-zA-Z][-.a-zA-Z0-9:_]*)\s*>')
+# Character reference processing logic specific to attribute values
+# See: https://html.spec.whatwg.org/multipage/parsing.html#named-character-reference-state
+def _replace_attr_charref(match):
+ ref = match.group(0)
+ # Numeric / hex char refs must always be unescaped
+ if ref.startswith('&#'):
+ return unescape(ref)
+ # Named character / entity references must only be unescaped
+ # if they are an exact match, and they are not followed by an equals sign
+ if not ref.endswith('=') and ref[1:] in html5_entities:
+ return unescape(ref)
+ # Otherwise do not unescape
+ return ref
+
+def _unescape_attrvalue(s):
+ return attr_charref.sub(_replace_attr_charref, s)
class HTMLParser(_markupbase.ParserBase):
@@ -177,7 +196,7 @@ class HTMLParser(_markupbase.ParserBase):
k = self.parse_pi(i)
elif startswith("<!", i):
k = self.parse_html_declaration(i)
- elif (i + 1) < n:
+ elif (i + 1) < n or end:
self.handle_data("<")
k = i + 1
else:
@@ -185,17 +204,35 @@ class HTMLParser(_markupbase.ParserBase):
if k < 0:
if not end:
break
- k = rawdata.find('>', i + 1)
- if k < 0:
- k = rawdata.find('<', i + 1)
- if k < 0:
- k = i + 1
- else:
- k += 1
- if self.convert_charrefs and not self.cdata_elem:
- self.handle_data(unescape(rawdata[i:k]))
+ if starttagopen.match(rawdata, i): # < + letter
+ pass
+ elif startswith("</", i):
+ if i + 2 == n:
+ self.handle_data("</")
+ elif endtagopen.match(rawdata, i): # </ + letter
+ pass
+ else:
+ # bogus comment
+ self.handle_comment(rawdata[i+2:])
+ elif startswith("<!--", i):
+ j = n
+ for suffix in ("--!", "--", "-"):
+ if rawdata.endswith(suffix, i+4):
+ j -= len(suffix)
+ break
+ self.handle_comment(rawdata[i+4:j])
+ elif startswith("<![CDATA[", i):
+ self.unknown_decl(rawdata[i+3:])
+ elif rawdata[i:i+9].lower() == '<!doctype':
+ self.handle_decl(rawdata[i+2:])
+ elif startswith("<!", i):
+ # bogus comment
+ self.handle_comment(rawdata[i+2:])
+ elif startswith("<?", i):
+ self.handle_pi(rawdata[i+2:])
else:
- self.handle_data(rawdata[i:k])
+ raise AssertionError("we should not get here!")
+ k = n
i = self.updatepos(i, k)
elif startswith("&#", i):
match = charref.match(rawdata, i)
@@ -242,7 +279,7 @@ class HTMLParser(_markupbase.ParserBase):
else:
assert 0, "interesting.search() lied"
# end while
- if end and i < n and not self.cdata_elem:
+ if end and i < n:
if self.convert_charrefs and not self.cdata_elem:
self.handle_data(unescape(rawdata[i:n]))
else:
@@ -260,7 +297,7 @@ class HTMLParser(_markupbase.ParserBase):
if rawdata[i:i+4] == '<!--':
# this case is actually already handled in goahead()
return self.parse_comment(i)
- elif rawdata[i:i+3] == '<![':
+ elif rawdata[i:i+9] == '<![CDATA[':
return self.parse_marked_section(i)
elif rawdata[i:i+9].lower() == '<!doctype':
# find the closing >
@@ -277,7 +314,7 @@ class HTMLParser(_markupbase.ParserBase):
def parse_bogus_comment(self, i, report=1):
rawdata = self.rawdata
assert rawdata[i:i+2] in ('<!', '</'), ('unexpected call to '
- 'parse_comment()')
+ 'parse_bogus_comment()')
pos = rawdata.find('>', i+2)
if pos == -1:
return -1
@@ -323,7 +360,7 @@ class HTMLParser(_markupbase.ParserBase):
attrvalue[:1] == '"' == attrvalue[-1:]:
attrvalue = attrvalue[1:-1]
if attrvalue:
- attrvalue = unescape(attrvalue)
+ attrvalue = _unescape_attrvalue(attrvalue)
attrs.append((attrname.lower(), attrvalue))
k = m.end()
diff --git a/Lib/http/client.py b/Lib/http/client.py
index 33a858d34ae..e7a1c7bc3b2 100644
--- a/Lib/http/client.py
+++ b/Lib/http/client.py
@@ -181,11 +181,10 @@ def _strip_ipv6_iface(enc_name: bytes) -> bytes:
return enc_name
class HTTPMessage(email.message.Message):
- # XXX The only usage of this method is in
- # http.server.CGIHTTPRequestHandler. Maybe move the code there so
- # that it doesn't need to be part of the public API. The API has
- # never been defined so this could cause backwards compatibility
- # issues.
+
+ # The getallmatchingheaders() method was only used by the CGI handler
+ # that was removed in Python 3.15. However, since the public API was not
+ # properly defined, it will be kept for backwards compatibility reasons.
def getallmatchingheaders(self, name):
"""Find all header lines matching a given header name.
diff --git a/Lib/http/server.py b/Lib/http/server.py
index a2aad4c9be3..a2ffbe2e44d 100644
--- a/Lib/http/server.py
+++ b/Lib/http/server.py
@@ -1,29 +1,10 @@
"""HTTP server classes.
Note: BaseHTTPRequestHandler doesn't implement any HTTP request; see
-SimpleHTTPRequestHandler for simple implementations of GET, HEAD and POST,
-and (deprecated) CGIHTTPRequestHandler for CGI scripts.
+SimpleHTTPRequestHandler for simple implementations of GET, HEAD and POST.
It does, however, optionally implement HTTP/1.1 persistent connections.
-Notes on CGIHTTPRequestHandler
-------------------------------
-
-This class is deprecated. It implements GET and POST requests to cgi-bin scripts.
-
-If the os.fork() function is not present (Windows), subprocess.Popen() is used,
-with slightly altered but never documented semantics. Use from a threaded
-process is likely to trigger a warning at os.fork() time.
-
-In all cases, the implementation is intentionally naive -- all
-requests are executed synchronously.
-
-SECURITY WARNING: DON'T USE THIS CODE UNLESS YOU ARE INSIDE A FIREWALL
--- it may execute arbitrary Python code or external programs.
-
-Note that status code 200 is sent prior to execution of a CGI script, so
-scripts cannot send other status codes such as 302 (redirect).
-
XXX To do:
- log requests even later (to capture byte count)
@@ -86,10 +67,8 @@ __all__ = [
"HTTPServer", "ThreadingHTTPServer",
"HTTPSServer", "ThreadingHTTPSServer",
"BaseHTTPRequestHandler", "SimpleHTTPRequestHandler",
- "CGIHTTPRequestHandler",
]
-import copy
import datetime
import email.utils
import html
@@ -99,7 +78,6 @@ import itertools
import mimetypes
import os
import posixpath
-import select
import shutil
import socket
import socketserver
@@ -137,7 +115,7 @@ DEFAULT_ERROR_CONTENT_TYPE = "text/html;charset=utf-8"
class HTTPServer(socketserver.TCPServer):
allow_reuse_address = True # Seems to make sense in testing environment
- allow_reuse_port = True
+ allow_reuse_port = False
def server_bind(self):
"""Override server_bind to store the server name."""
@@ -750,7 +728,7 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
f = None
if os.path.isdir(path):
parts = urllib.parse.urlsplit(self.path)
- if not parts.path.endswith('/'):
+ if not parts.path.endswith(('/', '%2f', '%2F')):
# redirect browser - doing basically what apache does
self.send_response(HTTPStatus.MOVED_PERMANENTLY)
new_parts = (parts[0], parts[1], parts[2] + '/',
@@ -840,11 +818,14 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
return None
list.sort(key=lambda a: a.lower())
r = []
+ displaypath = self.path
+ displaypath = displaypath.split('#', 1)[0]
+ displaypath = displaypath.split('?', 1)[0]
try:
- displaypath = urllib.parse.unquote(self.path,
+ displaypath = urllib.parse.unquote(displaypath,
errors='surrogatepass')
except UnicodeDecodeError:
- displaypath = urllib.parse.unquote(self.path)
+ displaypath = urllib.parse.unquote(displaypath)
displaypath = html.escape(displaypath, quote=False)
enc = sys.getfilesystemencoding()
title = f'Directory listing for {displaypath}'
@@ -890,14 +871,14 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
"""
# abandon query parameters
- path = path.split('?',1)[0]
- path = path.split('#',1)[0]
+ path = path.split('#', 1)[0]
+ path = path.split('?', 1)[0]
# Don't forget explicit trailing slash when normalizing. Issue17324
- trailing_slash = path.rstrip().endswith('/')
try:
path = urllib.parse.unquote(path, errors='surrogatepass')
except UnicodeDecodeError:
path = urllib.parse.unquote(path)
+ trailing_slash = path.endswith('/')
path = posixpath.normpath(path)
words = path.split('/')
words = filter(None, words)
@@ -953,56 +934,6 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
return 'application/octet-stream'
-# Utilities for CGIHTTPRequestHandler
-
-def _url_collapse_path(path):
- """
- Given a URL path, remove extra '/'s and '.' path elements and collapse
- any '..' references and returns a collapsed path.
-
- Implements something akin to RFC-2396 5.2 step 6 to parse relative paths.
- The utility of this function is limited to is_cgi method and helps
- preventing some security attacks.
-
- Returns: The reconstituted URL, which will always start with a '/'.
-
- Raises: IndexError if too many '..' occur within the path.
-
- """
- # Query component should not be involved.
- path, _, query = path.partition('?')
- path = urllib.parse.unquote(path)
-
- # Similar to os.path.split(os.path.normpath(path)) but specific to URL
- # path semantics rather than local operating system semantics.
- path_parts = path.split('/')
- head_parts = []
- for part in path_parts[:-1]:
- if part == '..':
- head_parts.pop() # IndexError if more '..' than prior parts
- elif part and part != '.':
- head_parts.append( part )
- if path_parts:
- tail_part = path_parts.pop()
- if tail_part:
- if tail_part == '..':
- head_parts.pop()
- tail_part = ''
- elif tail_part == '.':
- tail_part = ''
- else:
- tail_part = ''
-
- if query:
- tail_part = '?'.join((tail_part, query))
-
- splitpath = ('/' + '/'.join(head_parts), tail_part)
- collapsed_path = "/".join(splitpath)
-
- return collapsed_path
-
-
-
nobody = None
def nobody_uid():
@@ -1026,274 +957,6 @@ def executable(path):
return os.access(path, os.X_OK)
-class CGIHTTPRequestHandler(SimpleHTTPRequestHandler):
-
- """Complete HTTP server with GET, HEAD and POST commands.
-
- GET and HEAD also support running CGI scripts.
-
- The POST command is *only* implemented for CGI scripts.
-
- """
-
- def __init__(self, *args, **kwargs):
- import warnings
- warnings._deprecated("http.server.CGIHTTPRequestHandler",
- remove=(3, 15))
- super().__init__(*args, **kwargs)
-
- # Determine platform specifics
- have_fork = hasattr(os, 'fork')
-
- # Make rfile unbuffered -- we need to read one line and then pass
- # the rest to a subprocess, so we can't use buffered input.
- rbufsize = 0
-
- def do_POST(self):
- """Serve a POST request.
-
- This is only implemented for CGI scripts.
-
- """
-
- if self.is_cgi():
- self.run_cgi()
- else:
- self.send_error(
- HTTPStatus.NOT_IMPLEMENTED,
- "Can only POST to CGI scripts")
-
- def send_head(self):
- """Version of send_head that support CGI scripts"""
- if self.is_cgi():
- return self.run_cgi()
- else:
- return SimpleHTTPRequestHandler.send_head(self)
-
- def is_cgi(self):
- """Test whether self.path corresponds to a CGI script.
-
- Returns True and updates the cgi_info attribute to the tuple
- (dir, rest) if self.path requires running a CGI script.
- Returns False otherwise.
-
- If any exception is raised, the caller should assume that
- self.path was rejected as invalid and act accordingly.
-
- The default implementation tests whether the normalized url
- path begins with one of the strings in self.cgi_directories
- (and the next character is a '/' or the end of the string).
-
- """
- collapsed_path = _url_collapse_path(self.path)
- dir_sep = collapsed_path.find('/', 1)
- while dir_sep > 0 and not collapsed_path[:dir_sep] in self.cgi_directories:
- dir_sep = collapsed_path.find('/', dir_sep+1)
- if dir_sep > 0:
- head, tail = collapsed_path[:dir_sep], collapsed_path[dir_sep+1:]
- self.cgi_info = head, tail
- return True
- return False
-
-
- cgi_directories = ['/cgi-bin', '/htbin']
-
- def is_executable(self, path):
- """Test whether argument path is an executable file."""
- return executable(path)
-
- def is_python(self, path):
- """Test whether argument path is a Python script."""
- head, tail = os.path.splitext(path)
- return tail.lower() in (".py", ".pyw")
-
- def run_cgi(self):
- """Execute a CGI script."""
- dir, rest = self.cgi_info
- path = dir + '/' + rest
- i = path.find('/', len(dir)+1)
- while i >= 0:
- nextdir = path[:i]
- nextrest = path[i+1:]
-
- scriptdir = self.translate_path(nextdir)
- if os.path.isdir(scriptdir):
- dir, rest = nextdir, nextrest
- i = path.find('/', len(dir)+1)
- else:
- break
-
- # find an explicit query string, if present.
- rest, _, query = rest.partition('?')
-
- # dissect the part after the directory name into a script name &
- # a possible additional path, to be stored in PATH_INFO.
- i = rest.find('/')
- if i >= 0:
- script, rest = rest[:i], rest[i:]
- else:
- script, rest = rest, ''
-
- scriptname = dir + '/' + script
- scriptfile = self.translate_path(scriptname)
- if not os.path.exists(scriptfile):
- self.send_error(
- HTTPStatus.NOT_FOUND,
- "No such CGI script (%r)" % scriptname)
- return
- if not os.path.isfile(scriptfile):
- self.send_error(
- HTTPStatus.FORBIDDEN,
- "CGI script is not a plain file (%r)" % scriptname)
- return
- ispy = self.is_python(scriptname)
- if self.have_fork or not ispy:
- if not self.is_executable(scriptfile):
- self.send_error(
- HTTPStatus.FORBIDDEN,
- "CGI script is not executable (%r)" % scriptname)
- return
-
- # Reference: https://www6.uniovi.es/~antonio/ncsa_httpd/cgi/env.html
- # XXX Much of the following could be prepared ahead of time!
- env = copy.deepcopy(os.environ)
- env['SERVER_SOFTWARE'] = self.version_string()
- env['SERVER_NAME'] = self.server.server_name
- env['GATEWAY_INTERFACE'] = 'CGI/1.1'
- env['SERVER_PROTOCOL'] = self.protocol_version
- env['SERVER_PORT'] = str(self.server.server_port)
- env['REQUEST_METHOD'] = self.command
- uqrest = urllib.parse.unquote(rest)
- env['PATH_INFO'] = uqrest
- env['PATH_TRANSLATED'] = self.translate_path(uqrest)
- env['SCRIPT_NAME'] = scriptname
- env['QUERY_STRING'] = query
- env['REMOTE_ADDR'] = self.client_address[0]
- authorization = self.headers.get("authorization")
- if authorization:
- authorization = authorization.split()
- if len(authorization) == 2:
- import base64, binascii
- env['AUTH_TYPE'] = authorization[0]
- if authorization[0].lower() == "basic":
- try:
- authorization = authorization[1].encode('ascii')
- authorization = base64.decodebytes(authorization).\
- decode('ascii')
- except (binascii.Error, UnicodeError):
- pass
- else:
- authorization = authorization.split(':')
- if len(authorization) == 2:
- env['REMOTE_USER'] = authorization[0]
- # XXX REMOTE_IDENT
- if self.headers.get('content-type') is None:
- env['CONTENT_TYPE'] = self.headers.get_content_type()
- else:
- env['CONTENT_TYPE'] = self.headers['content-type']
- length = self.headers.get('content-length')
- if length:
- env['CONTENT_LENGTH'] = length
- referer = self.headers.get('referer')
- if referer:
- env['HTTP_REFERER'] = referer
- accept = self.headers.get_all('accept', ())
- env['HTTP_ACCEPT'] = ','.join(accept)
- ua = self.headers.get('user-agent')
- if ua:
- env['HTTP_USER_AGENT'] = ua
- co = filter(None, self.headers.get_all('cookie', []))
- cookie_str = ', '.join(co)
- if cookie_str:
- env['HTTP_COOKIE'] = cookie_str
- # XXX Other HTTP_* headers
- # Since we're setting the env in the parent, provide empty
- # values to override previously set values
- for k in ('QUERY_STRING', 'REMOTE_HOST', 'CONTENT_LENGTH',
- 'HTTP_USER_AGENT', 'HTTP_COOKIE', 'HTTP_REFERER'):
- env.setdefault(k, "")
-
- self.send_response(HTTPStatus.OK, "Script output follows")
- self.flush_headers()
-
- decoded_query = query.replace('+', ' ')
-
- if self.have_fork:
- # Unix -- fork as we should
- args = [script]
- if '=' not in decoded_query:
- args.append(decoded_query)
- nobody = nobody_uid()
- self.wfile.flush() # Always flush before forking
- pid = os.fork()
- if pid != 0:
- # Parent
- pid, sts = os.waitpid(pid, 0)
- # throw away additional data [see bug #427345]
- while select.select([self.rfile], [], [], 0)[0]:
- if not self.rfile.read(1):
- break
- exitcode = os.waitstatus_to_exitcode(sts)
- if exitcode:
- self.log_error(f"CGI script exit code {exitcode}")
- return
- # Child
- try:
- try:
- os.setuid(nobody)
- except OSError:
- pass
- os.dup2(self.rfile.fileno(), 0)
- os.dup2(self.wfile.fileno(), 1)
- os.execve(scriptfile, args, env)
- except:
- self.server.handle_error(self.request, self.client_address)
- os._exit(127)
-
- else:
- # Non-Unix -- use subprocess
- import subprocess
- cmdline = [scriptfile]
- if self.is_python(scriptfile):
- interp = sys.executable
- if interp.lower().endswith("w.exe"):
- # On Windows, use python.exe, not pythonw.exe
- interp = interp[:-5] + interp[-4:]
- cmdline = [interp, '-u'] + cmdline
- if '=' not in query:
- cmdline.append(query)
- self.log_message("command: %s", subprocess.list2cmdline(cmdline))
- try:
- nbytes = int(length)
- except (TypeError, ValueError):
- nbytes = 0
- p = subprocess.Popen(cmdline,
- stdin=subprocess.PIPE,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- env = env
- )
- if self.command.lower() == "post" and nbytes > 0:
- data = self.rfile.read(nbytes)
- else:
- data = None
- # throw away additional data [see bug #427345]
- while select.select([self.rfile._sock], [], [], 0)[0]:
- if not self.rfile._sock.recv(1):
- break
- stdout, stderr = p.communicate(data)
- self.wfile.write(stdout)
- if stderr:
- self.log_error('%s', stderr)
- p.stderr.close()
- p.stdout.close()
- status = p.returncode
- if status:
- self.log_error("CGI script exit status %#x", status)
- else:
- self.log_message("CGI script exited OK")
-
-
def _get_best_family(*address):
infos = socket.getaddrinfo(
*address,
@@ -1317,8 +980,8 @@ def test(HandlerClass=BaseHTTPRequestHandler,
HandlerClass.protocol_version = protocol
if tls_cert:
- server = ThreadingHTTPSServer(addr, HandlerClass, certfile=tls_cert,
- keyfile=tls_key, password=tls_password)
+ server = ServerClass(addr, HandlerClass, certfile=tls_cert,
+ keyfile=tls_key, password=tls_password)
else:
server = ServerClass(addr, HandlerClass)
@@ -1336,13 +999,12 @@ def test(HandlerClass=BaseHTTPRequestHandler,
print("\nKeyboard interrupt received, exiting.")
sys.exit(0)
-if __name__ == '__main__':
+
+def _main(args=None):
import argparse
import contextlib
- parser = argparse.ArgumentParser()
- parser.add_argument('--cgi', action='store_true',
- help='run as CGI server')
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument('-b', '--bind', metavar='ADDRESS',
help='bind to this address '
'(default: all interfaces)')
@@ -1362,7 +1024,7 @@ if __name__ == '__main__':
parser.add_argument('port', default=8000, type=int, nargs='?',
help='bind to this port '
'(default: %(default)s)')
- args = parser.parse_args()
+ args = parser.parse_args(args)
if not args.tls_cert and args.tls_key:
parser.error("--tls-key requires --tls-cert to be set")
@@ -1378,13 +1040,8 @@ if __name__ == '__main__':
except OSError as e:
parser.error(f"Failed to read TLS password file: {e}")
- if args.cgi:
- handler_class = CGIHTTPRequestHandler
- else:
- handler_class = SimpleHTTPRequestHandler
-
# ensure dual-stack is not disabled; ref #38907
- class DualStackServer(ThreadingHTTPServer):
+ class DualStackServerMixin:
def server_bind(self):
# suppress exception when protocol is IPv4
@@ -1397,9 +1054,16 @@ if __name__ == '__main__':
self.RequestHandlerClass(request, client_address, self,
directory=args.directory)
+ class HTTPDualStackServer(DualStackServerMixin, ThreadingHTTPServer):
+ pass
+ class HTTPSDualStackServer(DualStackServerMixin, ThreadingHTTPSServer):
+ pass
+
+ ServerClass = HTTPSDualStackServer if args.tls_cert else HTTPDualStackServer
+
test(
- HandlerClass=handler_class,
- ServerClass=DualStackServer,
+ HandlerClass=SimpleHTTPRequestHandler,
+ ServerClass=ServerClass,
port=args.port,
bind=args.bind,
protocol=args.protocol,
@@ -1407,3 +1071,7 @@ if __name__ == '__main__':
tls_key=args.tls_key,
tls_password=tls_key_password,
)
+
+
+if __name__ == '__main__':
+ _main()
diff --git a/Lib/idlelib/NEWS2x.txt b/Lib/idlelib/NEWS2x.txt
index 6751ca5f111..3721193007e 100644
--- a/Lib/idlelib/NEWS2x.txt
+++ b/Lib/idlelib/NEWS2x.txt
@@ -1,6 +1,6 @@
What's New in IDLE 2.7? (Merged into 3.1 before 2.7 release.)
=======================
-*Release date: XX-XXX-2010*
+*Release date: 03-Jul-2010*
- idle.py modified and simplified to better support developing experimental
versions of IDLE which are not installed in the standard location.
diff --git a/Lib/idlelib/News3.txt b/Lib/idlelib/News3.txt
index 74d84b38931..30784578cc6 100644
--- a/Lib/idlelib/News3.txt
+++ b/Lib/idlelib/News3.txt
@@ -4,6 +4,13 @@ Released on 2025-10-07
=========================
+gh-112936: IDLE - Include Shell menu in single-process mode,
+though with Restart Shell and View Last Restart disabled.
+Patch by Zhikang Yan.
+
+gh-112938: IDLE - Fix uninterruptible hang when Shell gets
+rapid continuous output.
+
gh-127060: Set TERM environment variable to 'dumb' to not add ANSI escape
sequences for text color in tracebacks. IDLE does not understand them.
Patch by Victor Stinner.
diff --git a/Lib/idlelib/configdialog.py b/Lib/idlelib/configdialog.py
index 4d2adb48570..e618ef07a90 100644
--- a/Lib/idlelib/configdialog.py
+++ b/Lib/idlelib/configdialog.py
@@ -435,7 +435,7 @@ class FontPage(Frame):
self.font_name.set(font.lower())
def set_samples(self, event=None):
- """Update update both screen samples with the font settings.
+ """Update both screen samples with the font settings.
Called on font initialization and change events.
Accesses font_name, font_size, and font_bold Variables.
diff --git a/Lib/idlelib/debugger.py b/Lib/idlelib/debugger.py
index d90dbcd11f9..1fae1d4b0ad 100644
--- a/Lib/idlelib/debugger.py
+++ b/Lib/idlelib/debugger.py
@@ -1,6 +1,6 @@
"""Debug user code with a GUI interface to a subclass of bdb.Bdb.
-The Idb idb and Debugger gui instances each need a reference to each
+The Idb instance 'idb' and Debugger instance 'gui' need references to each
other or to an rpc proxy for each other.
If IDLE is started with '-n', so that user code and idb both run in the
diff --git a/Lib/idlelib/editor.py b/Lib/idlelib/editor.py
index c76db20c587..17b498f63ba 100644
--- a/Lib/idlelib/editor.py
+++ b/Lib/idlelib/editor.py
@@ -1649,7 +1649,7 @@ class IndentSearcher:
self.finished = 1
def run(self):
- """Return 2 lines containing block opener and and indent.
+ """Return 2 lines containing block opener and indent.
Either the indent line or both may be None.
"""
diff --git a/Lib/idlelib/idle_test/htest.py b/Lib/idlelib/idle_test/htest.py
index a7293774eec..b63ff9ec287 100644
--- a/Lib/idlelib/idle_test/htest.py
+++ b/Lib/idlelib/idle_test/htest.py
@@ -337,7 +337,7 @@ _tree_widget_spec = {
'file': 'tree',
'kwds': {},
'msg': "The canvas is scrollable.\n"
- "Click on folders up to to the lowest level."
+ "Click on folders up to the lowest level."
}
_undo_delegator_spec = {
diff --git a/Lib/idlelib/pyshell.py b/Lib/idlelib/pyshell.py
index 60b63d58cdd..74a0e03994f 100755
--- a/Lib/idlelib/pyshell.py
+++ b/Lib/idlelib/pyshell.py
@@ -1350,7 +1350,7 @@ class PyShell(OutputWindow):
self.text.see("insert")
self.text.undo_block_stop()
- _last_newline_re = re.compile(r"[ \t]*(\n[ \t]*)?\Z")
+ _last_newline_re = re.compile(r"[ \t]*(\n[ \t]*)?\z")
def runit(self):
index_before = self.text.index("end-2c")
line = self.text.get("iomark", "end-1c")
diff --git a/Lib/inspect.py b/Lib/inspect.py
index 9592559ba6d..183e67fabf9 100644
--- a/Lib/inspect.py
+++ b/Lib/inspect.py
@@ -2074,13 +2074,11 @@ def _signature_is_functionlike(obj):
code = getattr(obj, '__code__', None)
defaults = getattr(obj, '__defaults__', _void) # Important to use _void ...
kwdefaults = getattr(obj, '__kwdefaults__', _void) # ... and not None here
- annotations = getattr(obj, '__annotations__', None)
return (isinstance(code, types.CodeType) and
isinstance(name, str) and
(defaults is None or isinstance(defaults, tuple)) and
- (kwdefaults is None or isinstance(kwdefaults, dict)) and
- (isinstance(annotations, (dict)) or annotations is None) )
+ (kwdefaults is None or isinstance(kwdefaults, dict)))
def _signature_strip_non_python_syntax(signature):
@@ -3343,7 +3341,7 @@ def _main():
import argparse
import importlib
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument(
'object',
help="The object to be analysed. "
diff --git a/Lib/ipaddress.py b/Lib/ipaddress.py
index 703fa289dda..8b60b9d5c9c 100644
--- a/Lib/ipaddress.py
+++ b/Lib/ipaddress.py
@@ -729,7 +729,7 @@ class _BaseNetwork(_IPAddressBase):
return NotImplemented
def __hash__(self):
- return hash(int(self.network_address) ^ int(self.netmask))
+ return hash((int(self.network_address), int(self.netmask)))
def __contains__(self, other):
# always false if one is v4 and the other is v6.
@@ -1660,8 +1660,18 @@ class _BaseV6:
"""
if not ip_str:
raise AddressValueError('Address cannot be empty')
-
- parts = ip_str.split(':')
+ if len(ip_str) > 45:
+ shorten = ip_str
+ if len(shorten) > 100:
+ shorten = f'{ip_str[:45]}({len(ip_str)-90} chars elided){ip_str[-45:]}'
+ raise AddressValueError(f"At most 45 characters expected in "
+ f"{shorten!r}")
+
+ # We want to allow more parts than the max to be 'split'
+ # to preserve the correct error message when there are
+ # too many parts combined with '::'
+ _max_parts = cls._HEXTET_COUNT + 1
+ parts = ip_str.split(':', maxsplit=_max_parts)
# An IPv6 address needs at least 2 colons (3 parts).
_min_parts = 3
@@ -1681,7 +1691,6 @@ class _BaseV6:
# An IPv6 address can't have more than 8 colons (9 parts).
# The extra colon comes from using the "::" notation for a single
# leading or trailing zero part.
- _max_parts = cls._HEXTET_COUNT + 1
if len(parts) > _max_parts:
msg = "At most %d colons permitted in %r" % (_max_parts-1, ip_str)
raise AddressValueError(msg)
diff --git a/Lib/json/encoder.py b/Lib/json/encoder.py
index 016638549aa..bc446e0f377 100644
--- a/Lib/json/encoder.py
+++ b/Lib/json/encoder.py
@@ -348,7 +348,6 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,
_current_indent_level += 1
newline_indent = '\n' + _indent * _current_indent_level
item_separator = _item_separator + newline_indent
- yield newline_indent
else:
newline_indent = None
item_separator = _item_separator
@@ -381,6 +380,8 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,
f'not {key.__class__.__name__}')
if first:
first = False
+ if newline_indent is not None:
+ yield newline_indent
else:
yield item_separator
yield _encoder(key)
@@ -413,7 +414,7 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,
except BaseException as exc:
exc.add_note(f'when serializing {type(dct).__name__} item {key!r}')
raise
- if newline_indent is not None:
+ if not first and newline_indent is not None:
_current_indent_level -= 1
yield '\n' + _indent * _current_indent_level
yield '}'
diff --git a/Lib/json/tool.py b/Lib/json/tool.py
index 585583da860..1967817add8 100644
--- a/Lib/json/tool.py
+++ b/Lib/json/tool.py
@@ -7,7 +7,7 @@ import argparse
import json
import re
import sys
-from _colorize import ANSIColors, can_colorize
+from _colorize import get_theme, can_colorize
# The string we are colorizing is valid JSON,
@@ -17,34 +17,34 @@ from _colorize import ANSIColors, can_colorize
_color_pattern = re.compile(r'''
(?P<key>"(\\.|[^"\\])*")(?=:) |
(?P<string>"(\\.|[^"\\])*") |
+ (?P<number>NaN|-?Infinity|[0-9\-+.Ee]+) |
(?P<boolean>true|false) |
(?P<null>null)
''', re.VERBOSE)
-
-_colors = {
- 'key': ANSIColors.INTENSE_BLUE,
- 'string': ANSIColors.BOLD_GREEN,
- 'boolean': ANSIColors.BOLD_CYAN,
- 'null': ANSIColors.BOLD_CYAN,
+_group_to_theme_color = {
+ "key": "definition",
+ "string": "string",
+ "number": "number",
+ "boolean": "keyword",
+ "null": "keyword",
}
-def _replace_match_callback(match):
- for key, color in _colors.items():
- if m := match.group(key):
- return f"{color}{m}{ANSIColors.RESET}"
- return match.group()
-
+def _colorize_json(json_str, theme):
+ def _replace_match_callback(match):
+ for group, color in _group_to_theme_color.items():
+ if m := match.group(group):
+ return f"{theme[color]}{m}{theme.reset}"
+ return match.group()
-def _colorize_json(json_str):
return re.sub(_color_pattern, _replace_match_callback, json_str)
def main():
description = ('A simple command line interface for json module '
'to validate and pretty-print JSON objects.')
- parser = argparse.ArgumentParser(description=description)
+ parser = argparse.ArgumentParser(description=description, color=True)
parser.add_argument('infile', nargs='?',
help='a JSON file to be validated or pretty-printed',
default='-')
@@ -100,13 +100,16 @@ def main():
else:
outfile = open(options.outfile, 'w', encoding='utf-8')
with outfile:
- for obj in objs:
- if can_colorize(file=outfile):
+ if can_colorize(file=outfile):
+ t = get_theme(tty_file=outfile).syntax
+ for obj in objs:
json_str = json.dumps(obj, **dump_args)
- outfile.write(_colorize_json(json_str))
- else:
+ outfile.write(_colorize_json(json_str, t))
+ outfile.write('\n')
+ else:
+ for obj in objs:
json.dump(obj, outfile, **dump_args)
- outfile.write('\n')
+ outfile.write('\n')
except ValueError as e:
raise SystemExit(e)
diff --git a/Lib/linecache.py b/Lib/linecache.py
index 87d7d6fda65..ef73d1aa997 100644
--- a/Lib/linecache.py
+++ b/Lib/linecache.py
@@ -33,10 +33,9 @@ def getlines(filename, module_globals=None):
"""Get the lines for a Python source file from the cache.
Update the cache if it doesn't contain an entry for this file already."""
- if filename in cache:
- entry = cache[filename]
- if len(entry) != 1:
- return cache[filename][2]
+ entry = cache.get(filename, None)
+ if entry is not None and len(entry) != 1:
+ return entry[2]
try:
return updatecache(filename, module_globals)
@@ -56,10 +55,9 @@ def _make_key(code):
def _getlines_from_code(code):
code_id = _make_key(code)
- if code_id in _interactive_cache:
- entry = _interactive_cache[code_id]
- if len(entry) != 1:
- return _interactive_cache[code_id][2]
+ entry = _interactive_cache.get(code_id, None)
+ if entry is not None and len(entry) != 1:
+ return entry[2]
return []
@@ -84,12 +82,8 @@ def checkcache(filename=None):
filenames = [filename]
for filename in filenames:
- try:
- entry = cache[filename]
- except KeyError:
- continue
-
- if len(entry) == 1:
+ entry = cache.get(filename, None)
+ if entry is None or len(entry) == 1:
# lazy cache entry, leave it lazy.
continue
size, mtime, lines, fullname = entry
@@ -125,9 +119,7 @@ def updatecache(filename, module_globals=None):
# These import can fail if the interpreter is shutting down
return []
- if filename in cache:
- if len(cache[filename]) != 1:
- cache.pop(filename, None)
+ entry = cache.pop(filename, None)
if _source_unavailable(filename):
return []
@@ -146,9 +138,12 @@ def updatecache(filename, module_globals=None):
# Realise a lazy loader based lookup if there is one
# otherwise try to lookup right now.
- if lazycache(filename, module_globals):
+ lazy_entry = entry if entry is not None and len(entry) == 1 else None
+ if lazy_entry is None:
+ lazy_entry = _make_lazycache_entry(filename, module_globals)
+ if lazy_entry is not None:
try:
- data = cache[filename][0]()
+ data = lazy_entry[0]()
except (ImportError, OSError):
pass
else:
@@ -156,13 +151,14 @@ def updatecache(filename, module_globals=None):
# No luck, the PEP302 loader cannot find the source
# for this module.
return []
- cache[filename] = (
+ entry = (
len(data),
None,
[line + '\n' for line in data.splitlines()],
fullname
)
- return cache[filename][2]
+ cache[filename] = entry
+ return entry[2]
# Try looking through the module search path, which is only useful
# when handling a relative filename.
@@ -211,13 +207,20 @@ def lazycache(filename, module_globals):
get_source method must be found, the filename must be a cacheable
filename, and the filename must not be already cached.
"""
- if filename in cache:
- if len(cache[filename]) == 1:
- return True
- else:
- return False
+ entry = cache.get(filename, None)
+ if entry is not None:
+ return len(entry) == 1
+
+ lazy_entry = _make_lazycache_entry(filename, module_globals)
+ if lazy_entry is not None:
+ cache[filename] = lazy_entry
+ return True
+ return False
+
+
+def _make_lazycache_entry(filename, module_globals):
if not filename or (filename.startswith('<') and filename.endswith('>')):
- return False
+ return None
# Try for a __loader__, if available
if module_globals and '__name__' in module_globals:
spec = module_globals.get('__spec__')
@@ -230,9 +233,10 @@ def lazycache(filename, module_globals):
if name and get_source:
def get_lines(name=name, *args, **kwargs):
return get_source(name, *args, **kwargs)
- cache[filename] = (get_lines,)
- return True
- return False
+ return (get_lines,)
+ return None
+
+
def _register_code(code, string, name):
entry = (len(string),
@@ -245,4 +249,5 @@ def _register_code(code, string, name):
for const in code.co_consts:
if isinstance(const, type(code)):
stack.append(const)
- _interactive_cache[_make_key(code)] = entry
+ key = _make_key(code)
+ _interactive_cache[key] = entry
diff --git a/Lib/locale.py b/Lib/locale.py
index 2feb10e59c9..dfedc6386cb 100644
--- a/Lib/locale.py
+++ b/Lib/locale.py
@@ -883,6 +883,10 @@ del k, v
# updated 'sr@latn' -> 'sr_CS.UTF-8@latin' to 'sr_RS.UTF-8@latin'
# removed 'univ'
# removed 'universal'
+#
+# SS 2025-06-10:
+# Remove 'c.utf8' -> 'en_US.UTF-8' because 'en_US.UTF-8' does not exist
+# on all platforms.
locale_alias = {
'a3': 'az_AZ.KOI8-C',
@@ -962,7 +966,6 @@ locale_alias = {
'c.ascii': 'C',
'c.en': 'C',
'c.iso88591': 'en_US.ISO8859-1',
- 'c.utf8': 'en_US.UTF-8',
'c_c': 'C',
'c_c.c': 'C',
'ca': 'ca_ES.ISO8859-1',
diff --git a/Lib/logging/__init__.py b/Lib/logging/__init__.py
index aa9b79d8cab..c5860d53b1b 100644
--- a/Lib/logging/__init__.py
+++ b/Lib/logging/__init__.py
@@ -591,6 +591,7 @@ class Formatter(object):
%(threadName)s Thread name (if available)
%(taskName)s Task name (if available)
%(process)d Process ID (if available)
+ %(processName)s Process name (if available)
%(message)s The result of record.getMessage(), computed just as
the record is emitted
"""
@@ -2045,6 +2046,15 @@ def basicConfig(**kwargs):
created FileHandler, causing it to be used when the file is
opened in text mode. If not specified, the default value is
`backslashreplace`.
+ formatter If specified, set this formatter instance for all involved
+ handlers.
+ If not specified, the default is to create and use an instance of
+ `logging.Formatter` based on arguments 'format', 'datefmt' and
+ 'style'.
+ When 'formatter' is specified together with any of the three
+ arguments 'format', 'datefmt' and 'style', a `ValueError`
+ is raised to signal that these arguments would lose meaning
+ otherwise.
Note that you could specify a stream created using open(filename, mode)
rather than passing the filename and mode in. However, it should be
@@ -2067,6 +2077,9 @@ def basicConfig(**kwargs):
.. versionchanged:: 3.9
Added the ``encoding`` and ``errors`` parameters.
+
+ .. versionchanged:: 3.15
+ Added the ``formatter`` parameter.
"""
# Add thread safety in case someone mistakenly calls
# basicConfig() from multiple threads
@@ -2102,13 +2115,19 @@ def basicConfig(**kwargs):
stream = kwargs.pop("stream", None)
h = StreamHandler(stream)
handlers = [h]
- dfs = kwargs.pop("datefmt", None)
- style = kwargs.pop("style", '%')
- if style not in _STYLES:
- raise ValueError('Style must be one of: %s' % ','.join(
- _STYLES.keys()))
- fs = kwargs.pop("format", _STYLES[style][1])
- fmt = Formatter(fs, dfs, style)
+ fmt = kwargs.pop("formatter", None)
+ if fmt is None:
+ dfs = kwargs.pop("datefmt", None)
+ style = kwargs.pop("style", '%')
+ if style not in _STYLES:
+ raise ValueError('Style must be one of: %s' % ','.join(
+ _STYLES.keys()))
+ fs = kwargs.pop("format", _STYLES[style][1])
+ fmt = Formatter(fs, dfs, style)
+ else:
+ for forbidden_key in ("datefmt", "format", "style"):
+ if forbidden_key in kwargs:
+ raise ValueError(f"{forbidden_key!r} should not be specified together with 'formatter'")
for h in handlers:
if h.formatter is None:
h.setFormatter(fmt)
diff --git a/Lib/logging/config.py b/Lib/logging/config.py
index c994349fd6e..3d9aa00fa52 100644
--- a/Lib/logging/config.py
+++ b/Lib/logging/config.py
@@ -1018,7 +1018,7 @@ def listen(port=DEFAULT_LOGGING_CONFIG_PORT, verify=None):
"""
allow_reuse_address = True
- allow_reuse_port = True
+ allow_reuse_port = False
def __init__(self, host='localhost', port=DEFAULT_LOGGING_CONFIG_PORT,
handler=None, ready=None, verify=None):
diff --git a/Lib/mimetypes.py b/Lib/mimetypes.py
index b5a1b8da263..33e86d51a0f 100644
--- a/Lib/mimetypes.py
+++ b/Lib/mimetypes.py
@@ -698,7 +698,9 @@ _default_mime_types()
def _parse_args(args):
from argparse import ArgumentParser
- parser = ArgumentParser(description='map filename extensions to MIME types')
+ parser = ArgumentParser(
+ description='map filename extensions to MIME types', color=True
+ )
parser.add_argument(
'-e', '--extension',
action='store_true',
diff --git a/Lib/multiprocessing/connection.py b/Lib/multiprocessing/connection.py
index 5f288a8d393..fc00d286126 100644
--- a/Lib/multiprocessing/connection.py
+++ b/Lib/multiprocessing/connection.py
@@ -76,7 +76,7 @@ def arbitrary_address(family):
if family == 'AF_INET':
return ('localhost', 0)
elif family == 'AF_UNIX':
- return tempfile.mktemp(prefix='listener-', dir=util.get_temp_dir())
+ return tempfile.mktemp(prefix='sock-', dir=util.get_temp_dir())
elif family == 'AF_PIPE':
return tempfile.mktemp(prefix=r'\\.\pipe\pyc-%d-%d-' %
(os.getpid(), next(_mmap_counter)), dir="")
diff --git a/Lib/multiprocessing/context.py b/Lib/multiprocessing/context.py
index d0a3ad00e53..051d567d457 100644
--- a/Lib/multiprocessing/context.py
+++ b/Lib/multiprocessing/context.py
@@ -145,7 +145,7 @@ class BaseContext(object):
'''Check whether this is a fake forked process in a frozen executable.
If so then run code specified by commandline and exit.
'''
- if sys.platform == 'win32' and getattr(sys, 'frozen', False):
+ if self.get_start_method() == 'spawn' and getattr(sys, 'frozen', False):
from .spawn import freeze_support
freeze_support()
diff --git a/Lib/multiprocessing/forkserver.py b/Lib/multiprocessing/forkserver.py
index 681af2610e9..c91891ff162 100644
--- a/Lib/multiprocessing/forkserver.py
+++ b/Lib/multiprocessing/forkserver.py
@@ -222,6 +222,10 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None,
except ImportError:
pass
+ # gh-135335: flush stdout/stderr in case any of the preloaded modules
+ # wrote to them, otherwise children might inherit buffered data
+ util._flush_std_streams()
+
util._close_stdin()
sig_r, sig_w = os.pipe()
diff --git a/Lib/multiprocessing/sharedctypes.py b/Lib/multiprocessing/sharedctypes.py
index 6071707027b..eee1172e6e9 100644
--- a/Lib/multiprocessing/sharedctypes.py
+++ b/Lib/multiprocessing/sharedctypes.py
@@ -37,7 +37,12 @@ typecode_to_type = {
#
def _new_value(type_):
- size = ctypes.sizeof(type_)
+ try:
+ size = ctypes.sizeof(type_)
+ except TypeError as e:
+ raise TypeError("bad typecode (must be a ctypes type or one of "
+ "c, b, B, u, h, H, i, I, l, L, q, Q, f or d)") from e
+
wrapper = heap.BufferWrapper(size)
return rebuild_ctype(type_, wrapper, None)
diff --git a/Lib/multiprocessing/util.py b/Lib/multiprocessing/util.py
index b7192042b9c..a1a537dd48d 100644
--- a/Lib/multiprocessing/util.py
+++ b/Lib/multiprocessing/util.py
@@ -19,7 +19,7 @@ from subprocess import _args_from_interpreter_flags # noqa: F401
from . import process
__all__ = [
- 'sub_debug', 'debug', 'info', 'sub_warning', 'get_logger',
+ 'sub_debug', 'debug', 'info', 'sub_warning', 'warn', 'get_logger',
'log_to_stderr', 'get_temp_dir', 'register_after_fork',
'is_exiting', 'Finalize', 'ForkAwareThreadLock', 'ForkAwareLocal',
'close_all_fds_except', 'SUBDEBUG', 'SUBWARNING',
@@ -34,6 +34,7 @@ SUBDEBUG = 5
DEBUG = 10
INFO = 20
SUBWARNING = 25
+WARNING = 30
LOGGER_NAME = 'multiprocessing'
DEFAULT_LOGGING_FORMAT = '[%(levelname)s/%(processName)s] %(message)s'
@@ -53,6 +54,10 @@ def info(msg, *args):
if _logger:
_logger.log(INFO, msg, *args, stacklevel=2)
+def warn(msg, *args):
+ if _logger:
+ _logger.log(WARNING, msg, *args, stacklevel=2)
+
def sub_warning(msg, *args):
if _logger:
_logger.log(SUBWARNING, msg, *args, stacklevel=2)
@@ -121,6 +126,21 @@ abstract_sockets_supported = _platform_supports_abstract_sockets()
# Function returning a temp directory which will be removed on exit
#
+# Maximum length of a socket file path is usually between 92 and 108 [1],
+# but Linux is known to use a size of 108 [2]. BSD-based systems usually
+# use a size of 104 or 108 and Windows does not create AF_UNIX sockets.
+#
+# [1]: https://pubs.opengroup.org/onlinepubs/9799919799/basedefs/sys_un.h.html
+# [2]: https://man7.org/linux/man-pages/man7/unix.7.html
+
+if sys.platform == 'linux':
+ _SUN_PATH_MAX = 108
+elif sys.platform.startswith(('openbsd', 'freebsd')):
+ _SUN_PATH_MAX = 104
+else:
+ # On Windows platforms, we do not create AF_UNIX sockets.
+ _SUN_PATH_MAX = None if os.name == 'nt' else 92
+
def _remove_temp_dir(rmtree, tempdir):
rmtree(tempdir)
@@ -130,12 +150,67 @@ def _remove_temp_dir(rmtree, tempdir):
if current_process is not None:
current_process._config['tempdir'] = None
+def _get_base_temp_dir(tempfile):
+ """Get a temporary directory where socket files will be created.
+
+ To prevent additional imports, pass a pre-imported 'tempfile' module.
+ """
+ if os.name == 'nt':
+ return None
+ # Most of the time, the default temporary directory is /tmp. Thus,
+ # listener sockets files "$TMPDIR/pymp-XXXXXXXX/sock-XXXXXXXX" do
+ # not have a path length exceeding SUN_PATH_MAX.
+ #
+ # If users specify their own temporary directory, we may be unable
+ # to create those files. Therefore, we fall back to the system-wide
+ # temporary directory /tmp, assumed to exist on POSIX systems.
+ #
+ # See https://github.com/python/cpython/issues/132124.
+ base_tempdir = tempfile.gettempdir()
+ # Files created in a temporary directory are suffixed by a string
+ # generated by tempfile._RandomNameSequence, which, by design,
+ # is 8 characters long.
+ #
+ # Thus, the length of socket filename will be:
+ #
+ # len(base_tempdir + '/pymp-XXXXXXXX' + '/sock-XXXXXXXX')
+ sun_path_len = len(base_tempdir) + 14 + 14
+ if sun_path_len <= _SUN_PATH_MAX:
+ return base_tempdir
+ # Fallback to the default system-wide temporary directory.
+ # This ignores user-defined environment variables.
+ #
+ # On POSIX systems, /tmp MUST be writable by any application [1].
+ # We however emit a warning if this is not the case to prevent
+ # obscure errors later in the execution.
+ #
+ # On some legacy systems, /var/tmp and /usr/tmp can be present
+ # and will be used instead.
+ #
+ # [1]: https://refspecs.linuxfoundation.org/FHS_3.0/fhs/ch03s18.html
+ dirlist = ['/tmp', '/var/tmp', '/usr/tmp']
+ try:
+ base_system_tempdir = tempfile._get_default_tempdir(dirlist)
+ except FileNotFoundError:
+ warn("Process-wide temporary directory %s will not be usable for "
+ "creating socket files and no usable system-wide temporary "
+ "directory was found in %s", base_tempdir, dirlist)
+ # At this point, the system-wide temporary directory is not usable
+ # but we may assume that the user-defined one is, even if we will
+ # not be able to write socket files out there.
+ return base_tempdir
+ warn("Ignoring user-defined temporary directory: %s", base_tempdir)
+ # at most max(map(len, dirlist)) + 14 + 14 = 36 characters
+ assert len(base_system_tempdir) + 14 + 14 <= _SUN_PATH_MAX
+ return base_system_tempdir
+
def get_temp_dir():
# get name of a temp directory which will be automatically cleaned up
tempdir = process.current_process()._config.get('tempdir')
if tempdir is None:
import shutil, tempfile
- tempdir = tempfile.mkdtemp(prefix='pymp-')
+ base_tempdir = _get_base_temp_dir(tempfile)
+ tempdir = tempfile.mkdtemp(prefix='pymp-', dir=base_tempdir)
info('created temp directory %s', tempdir)
# keep a strong reference to shutil.rmtree(), since the finalizer
# can be called late during Python shutdown
diff --git a/Lib/netrc.py b/Lib/netrc.py
index b285fd8e357..2f502c1d533 100644
--- a/Lib/netrc.py
+++ b/Lib/netrc.py
@@ -7,6 +7,19 @@ import os, stat
__all__ = ["netrc", "NetrcParseError"]
+def _can_security_check():
+ # On WASI, getuid() is indicated as a stub but it may also be missing.
+ return os.name == 'posix' and hasattr(os, 'getuid')
+
+
+def _getpwuid(uid):
+ try:
+ import pwd
+ return pwd.getpwuid(uid)[0]
+ except (ImportError, LookupError):
+ return f'uid {uid}'
+
+
class NetrcParseError(Exception):
"""Exception raised on syntax errors in the .netrc file."""
def __init__(self, msg, filename=None, lineno=None):
@@ -142,21 +155,15 @@ class netrc:
self._security_check(fp, default_netrc, self.hosts[entryname][0])
def _security_check(self, fp, default_netrc, login):
- if os.name == 'posix' and default_netrc and login != "anonymous":
+ if _can_security_check() and default_netrc and login != "anonymous":
prop = os.fstat(fp.fileno())
- if prop.st_uid != os.getuid():
- import pwd
- try:
- fowner = pwd.getpwuid(prop.st_uid)[0]
- except KeyError:
- fowner = 'uid %s' % prop.st_uid
- try:
- user = pwd.getpwuid(os.getuid())[0]
- except KeyError:
- user = 'uid %s' % os.getuid()
+ current_user_id = os.getuid()
+ if prop.st_uid != current_user_id:
+ fowner = _getpwuid(prop.st_uid)
+ user = _getpwuid(current_user_id)
raise NetrcParseError(
- (f"~/.netrc file owner ({fowner}, {user}) does not match"
- " current user"))
+ f"~/.netrc file owner ({fowner}) does not match"
+ f" current user ({user})")
if (prop.st_mode & (stat.S_IRWXG | stat.S_IRWXO)):
raise NetrcParseError(
"~/.netrc access too permissive: access"
diff --git a/Lib/ntpath.py b/Lib/ntpath.py
index 5481bb8888e..9cdc16480f9 100644
--- a/Lib/ntpath.py
+++ b/Lib/ntpath.py
@@ -29,7 +29,7 @@ __all__ = ["normcase","isabs","join","splitdrive","splitroot","split","splitext"
"abspath","curdir","pardir","sep","pathsep","defpath","altsep",
"extsep","devnull","realpath","supports_unicode_filenames","relpath",
"samefile", "sameopenfile", "samestat", "commonpath", "isjunction",
- "isdevdrive"]
+ "isdevdrive", "ALLOW_MISSING"]
def _get_bothseps(path):
if isinstance(path, bytes):
@@ -601,9 +601,10 @@ try:
from nt import _findfirstfile, _getfinalpathname, readlink as _nt_readlink
except ImportError:
# realpath is a no-op on systems without _getfinalpathname support.
- realpath = abspath
+ def realpath(path, *, strict=False):
+ return abspath(path)
else:
- def _readlink_deep(path):
+ def _readlink_deep(path, ignored_error=OSError):
# These error codes indicate that we should stop reading links and
# return the path we currently have.
# 1: ERROR_INVALID_FUNCTION
@@ -636,7 +637,7 @@ else:
path = old_path
break
path = normpath(join(dirname(old_path), path))
- except OSError as ex:
+ except ignored_error as ex:
if ex.winerror in allowed_winerror:
break
raise
@@ -645,7 +646,7 @@ else:
break
return path
- def _getfinalpathname_nonstrict(path):
+ def _getfinalpathname_nonstrict(path, ignored_error=OSError):
# These error codes indicate that we should stop resolving the path
# and return the value we currently have.
# 1: ERROR_INVALID_FUNCTION
@@ -661,9 +662,10 @@ else:
# 87: ERROR_INVALID_PARAMETER
# 123: ERROR_INVALID_NAME
# 161: ERROR_BAD_PATHNAME
+ # 1005: ERROR_UNRECOGNIZED_VOLUME
# 1920: ERROR_CANT_ACCESS_FILE
# 1921: ERROR_CANT_RESOLVE_FILENAME (implies unfollowable symlink)
- allowed_winerror = 1, 2, 3, 5, 21, 32, 50, 53, 65, 67, 87, 123, 161, 1920, 1921
+ allowed_winerror = 1, 2, 3, 5, 21, 32, 50, 53, 65, 67, 87, 123, 161, 1005, 1920, 1921
# Non-strict algorithm is to find as much of the target directory
# as we can and join the rest.
@@ -672,17 +674,18 @@ else:
try:
path = _getfinalpathname(path)
return join(path, tail) if tail else path
- except OSError as ex:
+ except ignored_error as ex:
if ex.winerror not in allowed_winerror:
raise
try:
# The OS could not resolve this path fully, so we attempt
# to follow the link ourselves. If we succeed, join the tail
# and return.
- new_path = _readlink_deep(path)
+ new_path = _readlink_deep(path,
+ ignored_error=ignored_error)
if new_path != path:
return join(new_path, tail) if tail else new_path
- except OSError:
+ except ignored_error:
# If we fail to readlink(), let's keep traversing
pass
# If we get these errors, try to get the real name of the file without accessing it.
@@ -690,7 +693,7 @@ else:
try:
name = _findfirstfile(path)
path, _ = split(path)
- except OSError:
+ except ignored_error:
path, name = split(path)
else:
path, name = split(path)
@@ -720,6 +723,15 @@ else:
if normcase(path) == devnull:
return '\\\\.\\NUL'
had_prefix = path.startswith(prefix)
+
+ if strict is ALLOW_MISSING:
+ ignored_error = FileNotFoundError
+ strict = True
+ elif strict:
+ ignored_error = ()
+ else:
+ ignored_error = OSError
+
if not had_prefix and not isabs(path):
path = join(cwd, path)
try:
@@ -727,17 +739,16 @@ else:
initial_winerror = 0
except ValueError as ex:
# gh-106242: Raised for embedded null characters
- # In strict mode, we convert into an OSError.
+ # In strict modes, we convert into an OSError.
# Non-strict mode returns the path as-is, since we've already
# made it absolute.
if strict:
raise OSError(str(ex)) from None
path = normpath(path)
- except OSError as ex:
- if strict:
- raise
+ except ignored_error as ex:
initial_winerror = ex.winerror
- path = _getfinalpathname_nonstrict(path)
+ path = _getfinalpathname_nonstrict(path,
+ ignored_error=ignored_error)
# The path returned by _getfinalpathname will always start with \\?\ -
# strip off that prefix unless it was already provided on the original
# path.
diff --git a/Lib/os.py b/Lib/os.py
index 266e40b56f6..12926c832f5 100644
--- a/Lib/os.py
+++ b/Lib/os.py
@@ -10,7 +10,7 @@ This exports:
- os.extsep is the extension separator (always '.')
- os.altsep is the alternate pathname separator (None or '/')
- os.pathsep is the component separator used in $PATH etc
- - os.linesep is the line separator in text files ('\r' or '\n' or '\r\n')
+ - os.linesep is the line separator in text files ('\n' or '\r\n')
- os.defpath is the default search path for executables
- os.devnull is the file path of the null device ('/dev/null', etc.)
@@ -118,6 +118,7 @@ if _exists("_have_functions"):
_add("HAVE_FCHMODAT", "chmod")
_add("HAVE_FCHOWNAT", "chown")
_add("HAVE_FSTATAT", "stat")
+ _add("HAVE_LSTAT", "lstat")
_add("HAVE_FUTIMESAT", "utime")
_add("HAVE_LINKAT", "link")
_add("HAVE_MKDIRAT", "mkdir")
diff --git a/Lib/pathlib/__init__.py b/Lib/pathlib/__init__.py
index 12cf9f579cb..2dc1f7f7126 100644
--- a/Lib/pathlib/__init__.py
+++ b/Lib/pathlib/__init__.py
@@ -28,8 +28,9 @@ except ImportError:
from pathlib._os import (
PathInfo, DirEntryInfo,
+ magic_open, vfspath,
ensure_different_files, ensure_distinct_paths,
- copyfile2, copyfileobj, magic_open, copy_info,
+ copyfile2, copyfileobj, copy_info,
)
@@ -1164,12 +1165,12 @@ class Path(PurePath):
# os.symlink() incorrectly creates a file-symlink on Windows. Avoid
# this by passing *target_is_dir* to os.symlink() on Windows.
def _copy_from_symlink(self, source, preserve_metadata=False):
- os.symlink(str(source.readlink()), self, source.info.is_dir())
+ os.symlink(vfspath(source.readlink()), self, source.info.is_dir())
if preserve_metadata:
copy_info(source.info, self, follow_symlinks=False)
else:
def _copy_from_symlink(self, source, preserve_metadata=False):
- os.symlink(str(source.readlink()), self)
+ os.symlink(vfspath(source.readlink()), self)
if preserve_metadata:
copy_info(source.info, self, follow_symlinks=False)
diff --git a/Lib/pathlib/_os.py b/Lib/pathlib/_os.py
index 039836941dd..62a4adb555e 100644
--- a/Lib/pathlib/_os.py
+++ b/Lib/pathlib/_os.py
@@ -210,6 +210,26 @@ def magic_open(path, mode='r', buffering=-1, encoding=None, errors=None,
raise TypeError(f"{cls.__name__} can't be opened with mode {mode!r}")
+def vfspath(path):
+ """
+ Return the string representation of a virtual path object.
+ """
+ try:
+ return os.fsdecode(path)
+ except TypeError:
+ pass
+
+ path_type = type(path)
+ try:
+ return path_type.__vfspath__(path)
+ except AttributeError:
+ if hasattr(path_type, '__vfspath__'):
+ raise
+
+ raise TypeError("expected str, bytes, os.PathLike or JoinablePath "
+ "object, not " + path_type.__name__)
+
+
def ensure_distinct_paths(source, target):
"""
Raise OSError(EINVAL) if the other path is within this path.
@@ -225,8 +245,8 @@ def ensure_distinct_paths(source, target):
err = OSError(EINVAL, "Source path is a parent of target path")
else:
return
- err.filename = str(source)
- err.filename2 = str(target)
+ err.filename = vfspath(source)
+ err.filename2 = vfspath(target)
raise err
@@ -247,8 +267,8 @@ def ensure_different_files(source, target):
except (OSError, ValueError):
return
err = OSError(EINVAL, "Source and target are the same file")
- err.filename = str(source)
- err.filename2 = str(target)
+ err.filename = vfspath(source)
+ err.filename2 = vfspath(target)
raise err
diff --git a/Lib/pathlib/types.py b/Lib/pathlib/types.py
index d8f5c34a1a7..42b80221608 100644
--- a/Lib/pathlib/types.py
+++ b/Lib/pathlib/types.py
@@ -11,9 +11,10 @@ Protocols for supporting classes in pathlib.
from abc import ABC, abstractmethod
-from glob import _PathGlobber
+from glob import _GlobberBase
from io import text_encoding
-from pathlib._os import magic_open, ensure_distinct_paths, ensure_different_files, copyfileobj
+from pathlib._os import (magic_open, vfspath, ensure_distinct_paths,
+ ensure_different_files, copyfileobj)
from pathlib import PurePath, Path
from typing import Optional, Protocol, runtime_checkable
@@ -60,6 +61,25 @@ class PathInfo(Protocol):
def is_symlink(self) -> bool: ...
+class _PathGlobber(_GlobberBase):
+ """Provides shell-style pattern matching and globbing for ReadablePath.
+ """
+
+ @staticmethod
+ def lexists(path):
+ return path.info.exists(follow_symlinks=False)
+
+ @staticmethod
+ def scandir(path):
+ return ((child.info, child.name, child) for child in path.iterdir())
+
+ @staticmethod
+ def concat_path(path, text):
+ return path.with_segments(vfspath(path) + text)
+
+ stringify_path = staticmethod(vfspath)
+
+
class _JoinablePath(ABC):
"""Abstract base class for pure path objects.
@@ -86,20 +106,19 @@ class _JoinablePath(ABC):
raise NotImplementedError
@abstractmethod
- def __str__(self):
- """Return the string representation of the path, suitable for
- passing to system calls."""
+ def __vfspath__(self):
+ """Return the string representation of the path."""
raise NotImplementedError
@property
def anchor(self):
"""The concatenation of the drive and root, or ''."""
- return _explode_path(str(self), self.parser.split)[0]
+ return _explode_path(vfspath(self), self.parser.split)[0]
@property
def name(self):
"""The final path component, if any."""
- return self.parser.split(str(self))[1]
+ return self.parser.split(vfspath(self))[1]
@property
def suffix(self):
@@ -135,7 +154,7 @@ class _JoinablePath(ABC):
split = self.parser.split
if split(name)[0]:
raise ValueError(f"Invalid name {name!r}")
- path = str(self)
+ path = vfspath(self)
path = path.removesuffix(split(path)[1]) + name
return self.with_segments(path)
@@ -168,7 +187,7 @@ class _JoinablePath(ABC):
def parts(self):
"""An object providing sequence-like access to the
components in the filesystem path."""
- anchor, parts = _explode_path(str(self), self.parser.split)
+ anchor, parts = _explode_path(vfspath(self), self.parser.split)
if anchor:
parts.append(anchor)
return tuple(reversed(parts))
@@ -179,24 +198,24 @@ class _JoinablePath(ABC):
paths) or a totally different path (if one of the arguments is
anchored).
"""
- return self.with_segments(str(self), *pathsegments)
+ return self.with_segments(vfspath(self), *pathsegments)
def __truediv__(self, key):
try:
- return self.with_segments(str(self), key)
+ return self.with_segments(vfspath(self), key)
except TypeError:
return NotImplemented
def __rtruediv__(self, key):
try:
- return self.with_segments(key, str(self))
+ return self.with_segments(key, vfspath(self))
except TypeError:
return NotImplemented
@property
def parent(self):
"""The logical parent of the path."""
- path = str(self)
+ path = vfspath(self)
parent = self.parser.split(path)[0]
if path != parent:
return self.with_segments(parent)
@@ -206,7 +225,7 @@ class _JoinablePath(ABC):
def parents(self):
"""A sequence of this path's logical parents."""
split = self.parser.split
- path = str(self)
+ path = vfspath(self)
parent = split(path)[0]
parents = []
while path != parent:
@@ -223,7 +242,7 @@ class _JoinablePath(ABC):
case_sensitive = self.parser.normcase('Aa') == 'Aa'
globber = _PathGlobber(self.parser.sep, case_sensitive, recursive=True)
match = globber.compile(pattern, altsep=self.parser.altsep)
- return match(str(self)) is not None
+ return match(vfspath(self)) is not None
class _ReadablePath(_JoinablePath):
@@ -412,7 +431,7 @@ class _WritablePath(_JoinablePath):
while stack:
src, dst = stack.pop()
if not follow_symlinks and src.info.is_symlink():
- dst.symlink_to(str(src.readlink()), src.info.is_dir())
+ dst.symlink_to(vfspath(src.readlink()), src.info.is_dir())
elif src.info.is_dir():
children = src.iterdir()
dst.mkdir()
diff --git a/Lib/pdb.py b/Lib/pdb.py
index 343cf4404d7..fc83728fb6d 100644
--- a/Lib/pdb.py
+++ b/Lib/pdb.py
@@ -75,8 +75,10 @@ import dis
import code
import glob
import json
+import stat
import token
import types
+import atexit
import codeop
import pprint
import signal
@@ -92,10 +94,12 @@ import tokenize
import itertools
import traceback
import linecache
+import selectors
+import threading
import _colorize
+import _pyrepl.utils
-from contextlib import closing
-from contextlib import contextmanager
+from contextlib import ExitStack, closing, contextmanager
from rlcompleter import Completer
from types import CodeType
from warnings import deprecated
@@ -339,7 +343,7 @@ class Pdb(bdb.Bdb, cmd.Cmd):
_last_pdb_instance = None
def __init__(self, completekey='tab', stdin=None, stdout=None, skip=None,
- nosigint=False, readrc=True, mode=None, backend=None):
+ nosigint=False, readrc=True, mode=None, backend=None, colorize=False):
bdb.Bdb.__init__(self, skip=skip, backend=backend if backend else get_default_backend())
cmd.Cmd.__init__(self, completekey, stdin, stdout)
sys.audit("pdb.Pdb")
@@ -352,6 +356,7 @@ class Pdb(bdb.Bdb, cmd.Cmd):
self._wait_for_mainpyfile = False
self.tb_lineno = {}
self.mode = mode
+ self.colorize = colorize and _colorize.can_colorize(file=stdout or sys.stdout)
# Try to load readline if it exists
try:
import readline
@@ -743,12 +748,34 @@ class Pdb(bdb.Bdb, cmd.Cmd):
self.message(repr(obj))
@contextmanager
- def _enable_multiline_completion(self):
+ def _enable_multiline_input(self):
+ try:
+ import readline
+ except ImportError:
+ yield
+ return
+
+ def input_auto_indent():
+ last_index = readline.get_current_history_length()
+ last_line = readline.get_history_item(last_index)
+ if last_line:
+ if last_line.isspace():
+ # If the last line is empty, we don't need to indent
+ return
+
+ last_line = last_line.rstrip('\r\n')
+ indent = len(last_line) - len(last_line.lstrip())
+ if last_line.endswith(":"):
+ indent += 4
+ readline.insert_text(' ' * indent)
+
completenames = self.completenames
try:
self.completenames = self.complete_multiline_names
+ readline.set_startup_hook(input_auto_indent)
yield
finally:
+ readline.set_startup_hook()
self.completenames = completenames
return
@@ -857,7 +884,7 @@ class Pdb(bdb.Bdb, cmd.Cmd):
try:
if (code := codeop.compile_command(line + '\n', '<stdin>', 'single')) is None:
# Multi-line mode
- with self._enable_multiline_completion():
+ with self._enable_multiline_input():
buffer = line
continue_prompt = "... "
while (code := codeop.compile_command(buffer, '<stdin>', 'single')) is None:
@@ -879,7 +906,11 @@ class Pdb(bdb.Bdb, cmd.Cmd):
return None, None, False
else:
line = line.rstrip('\r\n')
- buffer += '\n' + line
+ if line.isspace():
+ # empty line, just continue
+ buffer += '\n'
+ else:
+ buffer += '\n' + line
self.lastcmd = buffer
except SyntaxError as e:
# Maybe it's an await expression/statement
@@ -1036,6 +1067,13 @@ class Pdb(bdb.Bdb, cmd.Cmd):
return True
return False
+ def _colorize_code(self, code):
+ if self.colorize:
+ colors = list(_pyrepl.utils.gen_colors(code))
+ chars, _ = _pyrepl.utils.disp_str(code, colors=colors, force_color=True)
+ code = "".join(chars)
+ return code
+
# interface abstraction functions
def message(self, msg, end='\n'):
@@ -2166,6 +2204,8 @@ class Pdb(bdb.Bdb, cmd.Cmd):
s += '->'
elif lineno == exc_lineno:
s += '>>'
+ if self.colorize:
+ line = self._colorize_code(line)
self.message(s + '\t' + line.rstrip())
def do_whatis(self, arg):
@@ -2365,8 +2405,14 @@ class Pdb(bdb.Bdb, cmd.Cmd):
prefix = '> '
else:
prefix = ' '
- self.message(prefix +
- self.format_stack_entry(frame_lineno, prompt_prefix))
+ stack_entry = self.format_stack_entry(frame_lineno, prompt_prefix)
+ if self.colorize:
+ lines = stack_entry.split(prompt_prefix, 1)
+ if len(lines) > 1:
+ # We have some code to display
+ lines[1] = self._colorize_code(lines[1])
+ stack_entry = prompt_prefix.join(lines)
+ self.message(prefix + stack_entry)
# Provide help
@@ -2604,7 +2650,7 @@ def set_trace(*, header=None, commands=None):
if Pdb._last_pdb_instance is not None:
pdb = Pdb._last_pdb_instance
else:
- pdb = Pdb(mode='inline', backend='monitoring')
+ pdb = Pdb(mode='inline', backend='monitoring', colorize=True)
if header is not None:
pdb.message(header)
pdb.set_trace(sys._getframe().f_back, commands=commands)
@@ -2619,7 +2665,7 @@ async def set_trace_async(*, header=None, commands=None):
if Pdb._last_pdb_instance is not None:
pdb = Pdb._last_pdb_instance
else:
- pdb = Pdb(mode='inline', backend='monitoring')
+ pdb = Pdb(mode='inline', backend='monitoring', colorize=True)
if header is not None:
pdb.message(header)
await pdb.set_trace_async(sys._getframe().f_back, commands=commands)
@@ -2627,13 +2673,26 @@ async def set_trace_async(*, header=None, commands=None):
# Remote PDB
class _PdbServer(Pdb):
- def __init__(self, sockfile, owns_sockfile=True, **kwargs):
+ def __init__(
+ self,
+ sockfile,
+ signal_server=None,
+ owns_sockfile=True,
+ colorize=False,
+ **kwargs,
+ ):
self._owns_sockfile = owns_sockfile
self._interact_state = None
self._sockfile = sockfile
self._command_name_cache = []
self._write_failed = False
- super().__init__(**kwargs)
+ if signal_server:
+ # Only started by the top level _PdbServer, not recursive ones.
+ self._start_signal_listener(signal_server)
+ # Override the `colorize` attribute set by the parent constructor,
+ # because it checks the server's stdout, rather than the client's.
+ super().__init__(colorize=False, **kwargs)
+ self.colorize = colorize
@staticmethod
def protocol_version():
@@ -2688,15 +2747,49 @@ class _PdbServer(Pdb):
f"PDB message doesn't follow the schema! {msg}"
)
+ @classmethod
+ def _start_signal_listener(cls, address):
+ def listener(sock):
+ with closing(sock):
+ # Check if the interpreter is finalizing every quarter of a second.
+ # Clean up and exit if so.
+ sock.settimeout(0.25)
+ sock.shutdown(socket.SHUT_WR)
+ while not shut_down.is_set():
+ try:
+ data = sock.recv(1024)
+ except socket.timeout:
+ continue
+ if data == b"":
+ return # EOF
+ signal.raise_signal(signal.SIGINT)
+
+ def stop_thread():
+ shut_down.set()
+ thread.join()
+
+ # Use a daemon thread so that we don't detach until after all non-daemon
+ # threads are done. Use an atexit handler to stop gracefully at that point,
+ # so that our thread is stopped before the interpreter is torn down.
+ shut_down = threading.Event()
+ thread = threading.Thread(
+ target=listener,
+ args=[socket.create_connection(address, timeout=5)],
+ daemon=True,
+ )
+ atexit.register(stop_thread)
+ thread.start()
+
def _send(self, **kwargs):
self._ensure_valid_message(kwargs)
json_payload = json.dumps(kwargs)
try:
self._sockfile.write(json_payload.encode() + b"\n")
self._sockfile.flush()
- except OSError:
- # This means that the client has abruptly disconnected, but we'll
- # handle that the next time we try to read from the client instead
+ except (OSError, ValueError):
+ # We get an OSError if the network connection has dropped, and a
+ # ValueError if the sockfile has already been closed by detach(). We'll
+ # handle this the next time we try to read from the client instead
# of trying to handle it from everywhere _send() may be called.
# Track this with a flag rather than assuming readline() will ever
# return an empty string because the socket may be half-closed.
@@ -2887,7 +2980,11 @@ class _PdbServer(Pdb):
@typing.override
def _create_recursive_debugger(self):
- return _PdbServer(self._sockfile, owns_sockfile=False)
+ return _PdbServer(
+ self._sockfile,
+ owns_sockfile=False,
+ colorize=self.colorize,
+ )
@typing.override
def _prompt_for_confirmation(self, prompt, default):
@@ -2924,15 +3021,21 @@ class _PdbServer(Pdb):
class _PdbClient:
- def __init__(self, pid, sockfile, interrupt_script):
+ def __init__(self, pid, server_socket, interrupt_sock):
self.pid = pid
- self.sockfile = sockfile
- self.interrupt_script = interrupt_script
+ self.read_buf = b""
+ self.signal_read = None
+ self.signal_write = None
+ self.sigint_received = False
+ self.raise_on_sigint = False
+ self.server_socket = server_socket
+ self.interrupt_sock = interrupt_sock
self.pdb_instance = Pdb()
self.pdb_commands = set()
self.completion_matches = []
self.state = "dumb"
self.write_failed = False
+ self.multiline_block = False
def _ensure_valid_message(self, msg):
# Ensure the message conforms to our protocol.
@@ -2968,8 +3071,7 @@ class _PdbClient:
self._ensure_valid_message(kwargs)
json_payload = json.dumps(kwargs)
try:
- self.sockfile.write(json_payload.encode() + b"\n")
- self.sockfile.flush()
+ self.server_socket.sendall(json_payload.encode() + b"\n")
except OSError:
# This means that the client has abruptly disconnected, but we'll
# handle that the next time we try to read from the client instead
@@ -2978,9 +3080,44 @@ class _PdbClient:
# return an empty string because the socket may be half-closed.
self.write_failed = True
- def read_command(self, prompt):
- reply = input(prompt)
+ def _readline(self):
+ if self.sigint_received:
+ # There's a pending unhandled SIGINT. Handle it now.
+ self.sigint_received = False
+ raise KeyboardInterrupt
+ # Wait for either a SIGINT or a line or EOF from the PDB server.
+ selector = selectors.DefaultSelector()
+ selector.register(self.signal_read, selectors.EVENT_READ)
+ selector.register(self.server_socket, selectors.EVENT_READ)
+
+ while b"\n" not in self.read_buf:
+ for key, _ in selector.select():
+ if key.fileobj == self.signal_read:
+ self.signal_read.recv(1024)
+ if self.sigint_received:
+ # If not, we're reading wakeup events for sigints that
+ # we've previously handled, and can ignore them.
+ self.sigint_received = False
+ raise KeyboardInterrupt
+ elif key.fileobj == self.server_socket:
+ data = self.server_socket.recv(16 * 1024)
+ self.read_buf += data
+ if not data and b"\n" not in self.read_buf:
+ # EOF without a full final line. Drop the partial line.
+ self.read_buf = b""
+ return b""
+
+ ret, sep, self.read_buf = self.read_buf.partition(b"\n")
+ return ret + sep
+
+ def read_input(self, prompt, multiline_block):
+ self.multiline_block = multiline_block
+ with self._sigint_raises_keyboard_interrupt():
+ return input(prompt)
+
+ def read_command(self, prompt):
+ reply = self.read_input(prompt, multiline_block=False)
if self.state == "dumb":
# No logic applied whatsoever, just pass the raw reply back.
return reply
@@ -3003,9 +3140,9 @@ class _PdbClient:
return prefix + reply
# Otherwise, valid first line of a multi-line statement
- continue_prompt = "...".ljust(len(prompt))
+ more_prompt = "...".ljust(len(prompt))
while codeop.compile_command(reply, "<stdin>", "single") is None:
- reply += "\n" + input(continue_prompt)
+ reply += "\n" + self.read_input(more_prompt, multiline_block=True)
return prefix + reply
@@ -3030,11 +3167,70 @@ class _PdbClient:
finally:
readline.set_completer(old_completer)
+ @contextmanager
+ def _sigint_handler(self):
+ # Signal handling strategy:
+ # - When we call input() we want a SIGINT to raise KeyboardInterrupt
+ # - Otherwise we want to write to the wakeup FD and set a flag.
+ # We'll break out of select() when the wakeup FD is written to,
+ # and we'll check the flag whenever we're about to accept input.
+ def handler(signum, frame):
+ self.sigint_received = True
+ if self.raise_on_sigint:
+ # One-shot; don't raise again until the flag is set again.
+ self.raise_on_sigint = False
+ self.sigint_received = False
+ raise KeyboardInterrupt
+
+ sentinel = object()
+ old_handler = sentinel
+ old_wakeup_fd = sentinel
+
+ self.signal_read, self.signal_write = socket.socketpair()
+ with (closing(self.signal_read), closing(self.signal_write)):
+ self.signal_read.setblocking(False)
+ self.signal_write.setblocking(False)
+
+ try:
+ old_handler = signal.signal(signal.SIGINT, handler)
+
+ try:
+ old_wakeup_fd = signal.set_wakeup_fd(
+ self.signal_write.fileno(),
+ warn_on_full_buffer=False,
+ )
+ yield
+ finally:
+ # Restore the old wakeup fd if we installed a new one
+ if old_wakeup_fd is not sentinel:
+ signal.set_wakeup_fd(old_wakeup_fd)
+ finally:
+ self.signal_read = self.signal_write = None
+ if old_handler is not sentinel:
+ # Restore the old handler if we installed a new one
+ signal.signal(signal.SIGINT, old_handler)
+
+ @contextmanager
+ def _sigint_raises_keyboard_interrupt(self):
+ if self.sigint_received:
+ # There's a pending unhandled SIGINT. Handle it now.
+ self.sigint_received = False
+ raise KeyboardInterrupt
+
+ try:
+ self.raise_on_sigint = True
+ yield
+ finally:
+ self.raise_on_sigint = False
+
def cmdloop(self):
- with self.readline_completion(self.complete):
+ with (
+ self._sigint_handler(),
+ self.readline_completion(self.complete),
+ ):
while not self.write_failed:
try:
- if not (payload_bytes := self.sockfile.readline()):
+ if not (payload_bytes := self._readline()):
break
except KeyboardInterrupt:
self.send_interrupt()
@@ -3052,11 +3248,17 @@ class _PdbClient:
self.process_payload(payload)
def send_interrupt(self):
- print(
- "\n*** Program will stop at the next bytecode instruction."
- " (Use 'cont' to resume)."
- )
- sys.remote_exec(self.pid, self.interrupt_script)
+ if self.interrupt_sock is not None:
+ # Write to a socket that the PDB server listens on. This triggers
+ # the remote to raise a SIGINT for itself. We do this because
+ # Windows doesn't allow triggering SIGINT remotely.
+ # See https://stackoverflow.com/a/35792192 for many more details.
+ self.interrupt_sock.sendall(signal.SIGINT.to_bytes())
+ else:
+ # On Unix we can just send a SIGINT to the remote process.
+ # This is preferable to using the signal thread approach that we
+ # use on Windows because it can interrupt IO in the main thread.
+ os.kill(self.pid, signal.SIGINT)
def process_payload(self, payload):
match payload:
@@ -3105,9 +3307,13 @@ class _PdbClient:
origline = readline.get_line_buffer()
line = origline.lstrip()
- stripped = len(origline) - len(line)
- begidx = readline.get_begidx() - stripped
- endidx = readline.get_endidx() - stripped
+ if self.multiline_block:
+ # We're completing a line contained in a multi-line block.
+ # Force the remote to treat it as a Python expression.
+ line = "! " + line
+ offset = len(origline) - len(line)
+ begidx = readline.get_begidx() - offset
+ endidx = readline.get_endidx() - offset
msg = {
"complete": {
@@ -3122,7 +3328,7 @@ class _PdbClient:
if self.write_failed:
return None
- payload = self.sockfile.readline()
+ payload = self._readline()
if not payload:
return None
@@ -3139,11 +3345,31 @@ class _PdbClient:
return None
-def _connect(host, port, frame, commands, version):
+def _connect(
+ *,
+ host,
+ port,
+ frame,
+ commands,
+ version,
+ signal_raising_thread,
+ colorize,
+):
with closing(socket.create_connection((host, port))) as conn:
sockfile = conn.makefile("rwb")
- remote_pdb = _PdbServer(sockfile)
+ # The client requests this thread on Windows but not on Unix.
+ # Most tests don't request this thread, to keep them simpler.
+ if signal_raising_thread:
+ signal_server = (host, port)
+ else:
+ signal_server = None
+
+ remote_pdb = _PdbServer(
+ sockfile,
+ signal_server=signal_server,
+ colorize=colorize,
+ )
weakref.finalize(remote_pdb, sockfile.close)
if Pdb._last_pdb_instance is not None:
@@ -3158,49 +3384,57 @@ def _connect(host, port, frame, commands, version):
f"\nLocal pdb module's protocol version: {attach_ver}"
)
else:
- remote_pdb.rcLines.extend(commands.splitlines())
- remote_pdb.set_trace(frame=frame)
+ remote_pdb.set_trace(frame=frame, commands=commands.splitlines())
def attach(pid, commands=()):
"""Attach to a running process with the given PID."""
- with closing(socket.create_server(("localhost", 0))) as server:
+ with ExitStack() as stack:
+ server = stack.enter_context(
+ closing(socket.create_server(("localhost", 0)))
+ )
port = server.getsockname()[1]
- with tempfile.NamedTemporaryFile("w", delete_on_close=False) as connect_script:
- connect_script.write(
- textwrap.dedent(
- f"""
- import pdb, sys
- pdb._connect(
- host="localhost",
- port={port},
- frame=sys._getframe(1),
- commands={json.dumps("\n".join(commands))},
- version={_PdbServer.protocol_version()},
- )
- """
+ connect_script = stack.enter_context(
+ tempfile.NamedTemporaryFile("w", delete_on_close=False)
+ )
+
+ use_signal_thread = sys.platform == "win32"
+ colorize = _colorize.can_colorize()
+
+ connect_script.write(
+ textwrap.dedent(
+ f"""
+ import pdb, sys
+ pdb._connect(
+ host="localhost",
+ port={port},
+ frame=sys._getframe(1),
+ commands={json.dumps("\n".join(commands))},
+ version={_PdbServer.protocol_version()},
+ signal_raising_thread={use_signal_thread!r},
+ colorize={colorize!r},
)
+ """
)
- connect_script.close()
- sys.remote_exec(pid, connect_script.name)
-
- # TODO Add a timeout? Or don't bother since the user can ^C?
- client_sock, _ = server.accept()
-
- with closing(client_sock):
- sockfile = client_sock.makefile("rwb")
-
- with closing(sockfile):
- with tempfile.NamedTemporaryFile("w", delete_on_close=False) as interrupt_script:
- interrupt_script.write(
- 'import pdb, sys\n'
- 'if inst := pdb.Pdb._last_pdb_instance:\n'
- ' inst.set_trace(sys._getframe(1))\n'
- )
- interrupt_script.close()
+ )
+ connect_script.close()
+ orig_mode = os.stat(connect_script.name).st_mode
+ os.chmod(connect_script.name, orig_mode | stat.S_IROTH | stat.S_IRGRP)
+ sys.remote_exec(pid, connect_script.name)
+
+ # TODO Add a timeout? Or don't bother since the user can ^C?
+ client_sock, _ = server.accept()
+ stack.enter_context(closing(client_sock))
+
+ if use_signal_thread:
+ interrupt_sock, _ = server.accept()
+ stack.enter_context(closing(interrupt_sock))
+ interrupt_sock.setblocking(False)
+ else:
+ interrupt_sock = None
- _PdbClient(pid, sockfile, interrupt_script.name).cmdloop()
+ _PdbClient(pid, client_sock, interrupt_sock).cmdloop()
# Post-Mortem interface
@@ -3258,7 +3492,8 @@ def help():
_usage = """\
Debug the Python program given by pyfile. Alternatively,
an executable module or package to debug can be specified using
-the -m switch.
+the -m switch. You can also attach to a running Python process
+using the -p option with its PID.
Initial commands are read from .pdbrc files in your home directory
and in the current directory, if they exist. Commands supplied with
@@ -3272,10 +3507,13 @@ To let the script run up to a given line X in the debugged file, use
def main():
import argparse
- parser = argparse.ArgumentParser(usage="%(prog)s [-h] [-c command] (-m module | -p pid | pyfile) [args ...]",
- description=_usage,
- formatter_class=argparse.RawDescriptionHelpFormatter,
- allow_abbrev=False)
+ parser = argparse.ArgumentParser(
+ usage="%(prog)s [-h] [-c command] (-m module | -p pid | pyfile) [args ...]",
+ description=_usage,
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ allow_abbrev=False,
+ color=True,
+ )
# We need to maunally get the script from args, because the first positional
# arguments could be either the script we need to debug, or the argument
@@ -3338,7 +3576,7 @@ def main():
# modified by the script being debugged. It's a bad idea when it was
# changed by the user from the command line. There is a "restart" command
# which allows explicit specification of command line arguments.
- pdb = Pdb(mode='cli', backend='monitoring')
+ pdb = Pdb(mode='cli', backend='monitoring', colorize=True)
pdb.rcLines.extend(opts.commands)
while True:
try:
diff --git a/Lib/pickle.py b/Lib/pickle.py
index 4fa3632d1a7..beaefae0479 100644
--- a/Lib/pickle.py
+++ b/Lib/pickle.py
@@ -1911,7 +1911,9 @@ def _main(args=None):
import argparse
import pprint
parser = argparse.ArgumentParser(
- description='display contents of the pickle files')
+ description='display contents of the pickle files',
+ color=True,
+ )
parser.add_argument(
'pickle_file',
nargs='+', help='the pickle file')
diff --git a/Lib/pickletools.py b/Lib/pickletools.py
index 53f25ea4e46..bcddfb722bd 100644
--- a/Lib/pickletools.py
+++ b/Lib/pickletools.py
@@ -2842,7 +2842,9 @@ __test__ = {'disassembler_test': _dis_test,
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
- description='disassemble one or more pickle files')
+ description='disassemble one or more pickle files',
+ color=True,
+ )
parser.add_argument(
'pickle_file',
nargs='+', help='the pickle file')
diff --git a/Lib/platform.py b/Lib/platform.py
index a62192589af..da15bb4717b 100644
--- a/Lib/platform.py
+++ b/Lib/platform.py
@@ -29,7 +29,7 @@
#
# History:
#
-# <see CVS and SVN checkin messages for history>
+# <see checkin messages for history>
#
# 1.0.9 - added invalidate_caches() function to invalidate cached values
# 1.0.8 - changed Windows support to read version from kernel32.dll
@@ -110,7 +110,7 @@ __copyright__ = """
"""
-__version__ = '1.0.9'
+__version__ = '1.1.0'
import collections
import os
@@ -528,53 +528,6 @@ def ios_ver(system="", release="", model="", is_simulator=False):
return IOSVersionInfo(system, release, model, is_simulator)
-def _java_getprop(name, default):
- """This private helper is deprecated in 3.13 and will be removed in 3.15"""
- from java.lang import System
- try:
- value = System.getProperty(name)
- if value is None:
- return default
- return value
- except AttributeError:
- return default
-
-def java_ver(release='', vendor='', vminfo=('', '', ''), osinfo=('', '', '')):
-
- """ Version interface for Jython.
-
- Returns a tuple (release, vendor, vminfo, osinfo) with vminfo being
- a tuple (vm_name, vm_release, vm_vendor) and osinfo being a
- tuple (os_name, os_version, os_arch).
-
- Values which cannot be determined are set to the defaults
- given as parameters (which all default to '').
-
- """
- import warnings
- warnings._deprecated('java_ver', remove=(3, 15))
- # Import the needed APIs
- try:
- import java.lang # noqa: F401
- except ImportError:
- return release, vendor, vminfo, osinfo
-
- vendor = _java_getprop('java.vendor', vendor)
- release = _java_getprop('java.version', release)
- vm_name, vm_release, vm_vendor = vminfo
- vm_name = _java_getprop('java.vm.name', vm_name)
- vm_vendor = _java_getprop('java.vm.vendor', vm_vendor)
- vm_release = _java_getprop('java.vm.version', vm_release)
- vminfo = vm_name, vm_release, vm_vendor
- os_name, os_version, os_arch = osinfo
- os_arch = _java_getprop('java.os.arch', os_arch)
- os_name = _java_getprop('java.os.name', os_name)
- os_version = _java_getprop('java.os.version', os_version)
- osinfo = os_name, os_version, os_arch
-
- return release, vendor, vminfo, osinfo
-
-
AndroidVer = collections.namedtuple(
"AndroidVer", "release api_level manufacturer model device is_emulator")
@@ -659,6 +612,9 @@ def system_alias(system, release, version):
### Various internal helpers
+# Table for cleaning up characters in filenames.
+_SIMPLE_SUBSTITUTIONS = str.maketrans(r' /\:;"()', r'_-------')
+
def _platform(*args):
""" Helper to format the platform string in a filename
@@ -668,28 +624,13 @@ def _platform(*args):
platform = '-'.join(x.strip() for x in filter(len, args))
# Cleanup some possible filename obstacles...
- platform = platform.replace(' ', '_')
- platform = platform.replace('/', '-')
- platform = platform.replace('\\', '-')
- platform = platform.replace(':', '-')
- platform = platform.replace(';', '-')
- platform = platform.replace('"', '-')
- platform = platform.replace('(', '-')
- platform = platform.replace(')', '-')
+ platform = platform.translate(_SIMPLE_SUBSTITUTIONS)
# No need to report 'unknown' information...
platform = platform.replace('unknown', '')
# Fold '--'s and remove trailing '-'
- while True:
- cleaned = platform.replace('--', '-')
- if cleaned == platform:
- break
- platform = cleaned
- while platform and platform[-1] == '-':
- platform = platform[:-1]
-
- return platform
+ return re.sub(r'-{2,}', '-', platform).rstrip('-')
def _node(default=''):
@@ -1034,13 +975,6 @@ def uname():
version = '16bit'
system = 'Windows'
- elif system[:4] == 'java':
- release, vendor, vminfo, osinfo = java_ver()
- system = 'Java'
- version = ', '.join(vminfo)
- if not version:
- version = vendor
-
# System specific extensions
if system == 'OpenVMS':
# OpenVMS seems to have release and version mixed up
@@ -1198,7 +1132,7 @@ def _sys_version(sys_version=None):
# CPython
cpython_sys_version_parser = re.compile(
r'([\w.+]+)\s*' # "version<space>"
- r'(?:experimental free-threading build\s+)?' # "free-threading-build<space>"
+ r'(?:free-threading build\s+)?' # "free-threading-build<space>"
r'\(#?([^,]+)' # "(#buildno"
r'(?:,\s*([\w ]*)' # ", builddate"
r'(?:,\s*([\w :]*))?)?\)\s*' # ", buildtime)<space>"
@@ -1370,15 +1304,6 @@ def platform(aliased=False, terse=False):
platform = _platform(system, release, machine, processor,
'with',
libcname+libcversion)
- elif system == 'Java':
- # Java platforms
- r, v, vminfo, (os_name, os_version, os_arch) = java_ver()
- if terse or not os_name:
- platform = _platform(system, release, version)
- else:
- platform = _platform(system, release, version,
- 'on',
- os_name, os_version, os_arch)
else:
# Generic handler
@@ -1464,9 +1389,41 @@ def invalidate_caches():
### Command line interface
-if __name__ == '__main__':
- # Default is to print the aliased verbose platform string
- terse = ('terse' in sys.argv or '--terse' in sys.argv)
- aliased = (not 'nonaliased' in sys.argv and not '--nonaliased' in sys.argv)
+def _parse_args(args: list[str] | None):
+ import argparse
+
+ parser = argparse.ArgumentParser(color=True)
+ parser.add_argument("args", nargs="*", choices=["nonaliased", "terse"])
+ parser.add_argument(
+ "--terse",
+ action="store_true",
+ help=(
+ "return only the absolute minimum information needed "
+ "to identify the platform"
+ ),
+ )
+ parser.add_argument(
+ "--nonaliased",
+ dest="aliased",
+ action="store_false",
+ help=(
+ "disable system/OS name aliasing. If aliasing is enabled, "
+ "some platforms report system names different from "
+ "their common names, e.g. SunOS is reported as Solaris"
+ ),
+ )
+
+ return parser.parse_args(args)
+
+
+def _main(args: list[str] | None = None):
+ args = _parse_args(args)
+
+ terse = args.terse or ("terse" in args.args)
+ aliased = args.aliased and ('nonaliased' not in args.args)
+
print(platform(aliased, terse))
- sys.exit(0)
+
+
+if __name__ == "__main__":
+ _main()
diff --git a/Lib/posixpath.py b/Lib/posixpath.py
index db72ded8826..d38f3bd5872 100644
--- a/Lib/posixpath.py
+++ b/Lib/posixpath.py
@@ -36,7 +36,7 @@ __all__ = ["normcase","isabs","join","splitdrive","splitroot","split","splitext"
"samefile","sameopenfile","samestat",
"curdir","pardir","sep","pathsep","defpath","altsep","extsep",
"devnull","realpath","supports_unicode_filenames","relpath",
- "commonpath", "isjunction","isdevdrive"]
+ "commonpath", "isjunction","isdevdrive","ALLOW_MISSING"]
def _get_sep(path):
@@ -402,10 +402,18 @@ symbolic links encountered in the path."""
curdir = '.'
pardir = '..'
getcwd = os.getcwd
- return _realpath(filename, strict, sep, curdir, pardir, getcwd)
+ if strict is ALLOW_MISSING:
+ ignored_error = FileNotFoundError
+ strict = True
+ elif strict:
+ ignored_error = ()
+ else:
+ ignored_error = OSError
+
+ lstat = os.lstat
+ readlink = os.readlink
+ maxlinks = None
-def _realpath(filename, strict=False, sep=sep, curdir=curdir, pardir=pardir,
- getcwd=os.getcwd, lstat=os.lstat, readlink=os.readlink, maxlinks=None):
# The stack of unresolved path parts. When popped, a special value of None
# indicates that a symlink target has been resolved, and that the original
# symlink path can be retrieved by popping again. The [::-1] slice is a
@@ -477,27 +485,28 @@ def _realpath(filename, strict=False, sep=sep, curdir=curdir, pardir=pardir,
path = newpath
continue
target = readlink(newpath)
- except OSError:
- if strict:
- raise
- path = newpath
+ except ignored_error:
+ pass
+ else:
+ # Resolve the symbolic link
+ if target.startswith(sep):
+ # Symlink target is absolute; reset resolved path.
+ path = sep
+ if maxlinks is None:
+ # Mark this symlink as seen but not fully resolved.
+ seen[newpath] = None
+ # Push the symlink path onto the stack, and signal its specialness
+ # by also pushing None. When these entries are popped, we'll
+ # record the fully-resolved symlink target in the 'seen' mapping.
+ rest.append(newpath)
+ rest.append(None)
+ # Push the unresolved symlink target parts onto the stack.
+ target_parts = target.split(sep)[::-1]
+ rest.extend(target_parts)
+ part_count += len(target_parts)
continue
- # Resolve the symbolic link
- if target.startswith(sep):
- # Symlink target is absolute; reset resolved path.
- path = sep
- if maxlinks is None:
- # Mark this symlink as seen but not fully resolved.
- seen[newpath] = None
- # Push the symlink path onto the stack, and signal its specialness
- # by also pushing None. When these entries are popped, we'll
- # record the fully-resolved symlink target in the 'seen' mapping.
- rest.append(newpath)
- rest.append(None)
- # Push the unresolved symlink target parts onto the stack.
- target_parts = target.split(sep)[::-1]
- rest.extend(target_parts)
- part_count += len(target_parts)
+ # An error occurred and was ignored.
+ path = newpath
return path
diff --git a/Lib/pprint.py b/Lib/pprint.py
index dc0953cec67..92a2c543ac2 100644
--- a/Lib/pprint.py
+++ b/Lib/pprint.py
@@ -248,6 +248,49 @@ class PrettyPrinter:
_dispatch[_collections.OrderedDict.__repr__] = _pprint_ordered_dict
+ def _pprint_dict_view(self, object, stream, indent, allowance, context, level):
+ """Pretty print dict views (keys, values, items)."""
+ if isinstance(object, self._dict_items_view):
+ key = _safe_tuple
+ else:
+ key = _safe_key
+ write = stream.write
+ write(object.__class__.__name__ + '([')
+ if self._indent_per_level > 1:
+ write((self._indent_per_level - 1) * ' ')
+ length = len(object)
+ if length:
+ if self._sort_dicts:
+ entries = sorted(object, key=key)
+ else:
+ entries = object
+ self._format_items(entries, stream, indent, allowance + 1,
+ context, level)
+ write('])')
+
+ def _pprint_mapping_abc_view(self, object, stream, indent, allowance, context, level):
+ """Pretty print mapping views from collections.abc."""
+ write = stream.write
+ write(object.__class__.__name__ + '(')
+ # Dispatch formatting to the view's _mapping
+ self._format(object._mapping, stream, indent, allowance, context, level)
+ write(')')
+
+ _dict_keys_view = type({}.keys())
+ _dispatch[_dict_keys_view.__repr__] = _pprint_dict_view
+
+ _dict_values_view = type({}.values())
+ _dispatch[_dict_values_view.__repr__] = _pprint_dict_view
+
+ _dict_items_view = type({}.items())
+ _dispatch[_dict_items_view.__repr__] = _pprint_dict_view
+
+ _dispatch[_collections.abc.MappingView.__repr__] = _pprint_mapping_abc_view
+
+ _view_reprs = {cls.__repr__ for cls in
+ (_dict_keys_view, _dict_values_view, _dict_items_view,
+ _collections.abc.MappingView)}
+
def _pprint_list(self, object, stream, indent, allowance, context, level):
stream.write('[')
self._format_items(object, stream, indent, allowance + 1,
@@ -644,6 +687,42 @@ class PrettyPrinter:
del context[objid]
return format % ", ".join(components), readable, recursive
+ if issubclass(typ, _collections.abc.MappingView) and r in self._view_reprs:
+ objid = id(object)
+ if maxlevels and level >= maxlevels:
+ return "{...}", False, objid in context
+ if objid in context:
+ return _recursion(object), False, True
+ key = _safe_key
+ if issubclass(typ, (self._dict_items_view, _collections.abc.ItemsView)):
+ key = _safe_tuple
+ if hasattr(object, "_mapping"):
+ # Dispatch formatting to the view's _mapping
+ mapping_repr, readable, recursive = self.format(
+ object._mapping, context, maxlevels, level)
+ return (typ.__name__ + '(%s)' % mapping_repr), readable, recursive
+ elif hasattr(typ, "_mapping"):
+ # We have a view that somehow has lost its type's _mapping, raise
+ # an error by calling repr() instead of failing cryptically later
+ return repr(object), True, False
+ if self._sort_dicts:
+ object = sorted(object, key=key)
+ context[objid] = 1
+ readable = True
+ recursive = False
+ components = []
+ append = components.append
+ level += 1
+ for val in object:
+ vrepr, vreadable, vrecur = self.format(
+ val, context, maxlevels, level)
+ append(vrepr)
+ readable = readable and vreadable
+ if vrecur:
+ recursive = True
+ del context[objid]
+ return typ.__name__ + '([%s])' % ", ".join(components), readable, recursive
+
rep = repr(object)
return rep, (rep and not rep.startswith('<')), False
diff --git a/Lib/py_compile.py b/Lib/py_compile.py
index 388614e51b1..43d8ec90ffb 100644
--- a/Lib/py_compile.py
+++ b/Lib/py_compile.py
@@ -177,7 +177,7 @@ def main():
import argparse
description = 'A simple command-line interface for py_compile module.'
- parser = argparse.ArgumentParser(description=description)
+ parser = argparse.ArgumentParser(description=description, color=True)
parser.add_argument(
'-q', '--quiet',
action='store_true',
diff --git a/Lib/pydoc.py b/Lib/pydoc.py
index def76d076a2..d508fb70ea4 100644
--- a/Lib/pydoc.py
+++ b/Lib/pydoc.py
@@ -1812,7 +1812,6 @@ def writedocs(dir, pkgpath='', done=None):
def _introdoc():
- import textwrap
ver = '%d.%d' % sys.version_info[:2]
if os.environ.get('PYTHON_BASIC_REPL'):
pyrepl_keys = ''
@@ -2110,7 +2109,7 @@ has the same effect as typing a particular string at the help> prompt.
self.output.write(_introdoc())
def list(self, items, columns=4, width=80):
- items = list(sorted(items))
+ items = sorted(items)
colw = width // columns
rows = (len(items) + columns - 1) // columns
for row in range(rows):
@@ -2142,7 +2141,7 @@ to. Enter any symbol to get more help.
Here is a list of available topics. Enter any topic name to get more help.
''')
- self.list(self.topics.keys())
+ self.list(self.topics.keys(), columns=3)
def showtopic(self, topic, more_xrefs=''):
try:
@@ -2170,7 +2169,6 @@ module "pydoc_data.topics" could not be found.
if more_xrefs:
xrefs = (xrefs or '') + ' ' + more_xrefs
if xrefs:
- import textwrap
text = 'Related help topics: ' + ', '.join(xrefs.split()) + '\n'
wrapped_text = textwrap.wrap(text, 72)
doc += '\n%s\n' % '\n'.join(wrapped_text)
diff --git a/Lib/pydoc_data/topics.py b/Lib/pydoc_data/topics.py
index 3f3d52dcc6b..5f7e14a79d3 100644
--- a/Lib/pydoc_data/topics.py
+++ b/Lib/pydoc_data/topics.py
@@ -1,4 +1,4 @@
-# Autogenerated by Sphinx on Tue Apr 8 14:20:44 2025
+# Autogenerated by Sphinx on Tue May 6 18:33:44 2025
# as part of the release process.
topics = {
@@ -1784,9 +1784,8 @@ Additional information on exceptions can be found in section
Exceptions, and information on using the "raise" statement to generate
exceptions may be found in section The raise statement.
-Changed in version 3.14.0a6 (unreleased): Support for optionally
-dropping grouping parentheses when using multiple exception types. See
-**PEP 758**.
+Changed in version 3.14: Support for optionally dropping grouping
+parentheses when using multiple exception types. See **PEP 758**.
"except" clause
@@ -1976,9 +1975,9 @@ Changed in version 3.8: Prior to Python 3.8, a "continue" statement
was illegal in the "finally" clause due to a problem with the
implementation.
-Changed in version 3.14.0a6 (unreleased): The compiler emits a
-"SyntaxWarning" when a "return", "break" or "continue" appears in a
-"finally" block (see **PEP 765**).
+Changed in version 3.14: The compiler emits a "SyntaxWarning" when a
+"return", "break" or "continue" appears in a "finally" block (see
+**PEP 765**).
The "with" statement
@@ -3933,6 +3932,19 @@ pdb.set_trace(*, header=None, commands=None)
Added in version 3.14: The *commands* argument.
+awaitable pdb.set_trace_async(*, header=None, commands=None)
+
+ async version of "set_trace()". This function should be used inside
+ an async function with "await".
+
+ async def f():
+ await pdb.set_trace_async()
+
+ "await" statements are supported if the debugger is invoked by this
+ function.
+
+ Added in version 3.14.
+
pdb.post_mortem(t=None)
Enter post-mortem debugging of the given exception or traceback
@@ -3970,7 +3982,7 @@ The "run*" functions and "set_trace()" are aliases for instantiating
the "Pdb" class and calling the method of the same name. If you want
to access further features, you have to do this yourself:
-class pdb.Pdb(completekey='tab', stdin=None, stdout=None, skip=None, nosigint=False, readrc=True, mode=None, backend=None)
+class pdb.Pdb(completekey='tab', stdin=None, stdout=None, skip=None, nosigint=False, readrc=True, mode=None, backend=None, colorize=False)
"Pdb" is the debugger class.
@@ -4001,6 +4013,10 @@ class pdb.Pdb(completekey='tab', stdin=None, stdout=None, skip=None, nosigint=Fa
See "set_default_backend()". Otherwise the supported backends are
"'settrace'" and "'monitoring'".
+ The *colorize* argument, if set to "True", will enable colorized
+ output in the debugger, if color is supported. This will highlight
+ source code displayed in pdb.
+
Example call to enable tracing with *skip*:
import pdb; pdb.Pdb(skip=['django.*']).set_trace()
@@ -4018,6 +4034,8 @@ class pdb.Pdb(completekey='tab', stdin=None, stdout=None, skip=None, nosigint=Fa
Added in version 3.14: Added the *backend* argument.
+ Added in version 3.14: Added the *colorize* argument.
+
Changed in version 3.14: Inline breakpoints like "breakpoint()" or
"pdb.set_trace()" will always stop the program at calling frame,
ignoring the *skip* pattern (if any).
@@ -4496,7 +4514,7 @@ exceptions [excnumber]
When using "pdb.pm()" or "Pdb.post_mortem(...)" with a chained
exception instead of a traceback, it allows the user to move
between the chained exceptions using "exceptions" command to list
- exceptions, and "exception <number>" to switch to that exception.
+ exceptions, and "exceptions <number>" to switch to that exception.
Example:
@@ -6752,8 +6770,8 @@ object.__or__(self, other)
called. The "__divmod__()" method should be the equivalent to
using "__floordiv__()" and "__mod__()"; it should not be related to
"__truediv__()". Note that "__pow__()" should be defined to accept
- an optional third argument if the ternary version of the built-in
- "pow()" function is to be supported.
+ an optional third argument if the three-argument version of the
+ built-in "pow()" function is to be supported.
If one of those methods does not support the operation with the
supplied arguments, it should return "NotImplemented".
@@ -6785,8 +6803,13 @@ object.__ror__(self, other)
called if "type(x).__sub__(x, y)" returns "NotImplemented" or
"type(y)" is a subclass of "type(x)". [5]
- Note that ternary "pow()" will not try calling "__rpow__()" (the
- coercion rules would become too complicated).
+ Note that "__rpow__()" should be defined to accept an optional
+ third argument if the three-argument version of the built-in
+ "pow()" function is to be supported.
+
+ Changed in version 3.14.0a7 (unreleased): Three-argument "pow()"
+ now tries calling "__rpow__()" if necessary. Previously it was only
+ called in two-argument "pow()" and the binary power operator.
Note:
@@ -8785,8 +8808,8 @@ object.__or__(self, other)
called. The "__divmod__()" method should be the equivalent to
using "__floordiv__()" and "__mod__()"; it should not be related to
"__truediv__()". Note that "__pow__()" should be defined to accept
- an optional third argument if the ternary version of the built-in
- "pow()" function is to be supported.
+ an optional third argument if the three-argument version of the
+ built-in "pow()" function is to be supported.
If one of those methods does not support the operation with the
supplied arguments, it should return "NotImplemented".
@@ -8818,8 +8841,13 @@ object.__ror__(self, other)
called if "type(x).__sub__(x, y)" returns "NotImplemented" or
"type(y)" is a subclass of "type(x)". [5]
- Note that ternary "pow()" will not try calling "__rpow__()" (the
- coercion rules would become too complicated).
+ Note that "__rpow__()" should be defined to accept an optional
+ third argument if the three-argument version of the built-in
+ "pow()" function is to be supported.
+
+ Changed in version 3.14.0a7 (unreleased): Three-argument "pow()"
+ now tries calling "__rpow__()" if necessary. Previously it was only
+ called in two-argument "pow()" and the binary power operator.
Note:
@@ -10100,9 +10128,8 @@ Additional information on exceptions can be found in section
Exceptions, and information on using the "raise" statement to generate
exceptions may be found in section The raise statement.
-Changed in version 3.14.0a6 (unreleased): Support for optionally
-dropping grouping parentheses when using multiple exception types. See
-**PEP 758**.
+Changed in version 3.14: Support for optionally dropping grouping
+parentheses when using multiple exception types. See **PEP 758**.
"except" clause
@@ -10292,9 +10319,9 @@ Changed in version 3.8: Prior to Python 3.8, a "continue" statement
was illegal in the "finally" clause due to a problem with the
implementation.
-Changed in version 3.14.0a6 (unreleased): The compiler emits a
-"SyntaxWarning" when a "return", "break" or "continue" appears in a
-"finally" block (see **PEP 765**).
+Changed in version 3.14: The compiler emits a "SyntaxWarning" when a
+"return", "break" or "continue" appears in a "finally" block (see
+**PEP 765**).
''',
'types': r'''The standard type hierarchy
***************************
@@ -11141,25 +11168,15 @@ Special attributes
| | collected during class body execution. See also: |
| | "__annotations__ attributes". For best practices |
| | on working with "__annotations__", please see |
-| | "annotationlib". Caution: Accessing the |
-| | "__annotations__" attribute of a class object |
-| | directly may yield incorrect results in the |
-| | presence of metaclasses. In addition, the |
-| | attribute may not exist for some classes. Use |
-| | "annotationlib.get_annotations()" to retrieve |
-| | class annotations safely. Changed in version |
-| | 3.14: Annotations are now lazily evaluated. See |
-| | **PEP 649**. |
+| | "annotationlib". Where possible, use |
+| | "annotationlib.get_annotations()" instead of |
+| | accessing this attribute directly. Changed in |
+| | version 3.14: Annotations are now lazily |
+| | evaluated. See **PEP 649**. |
+----------------------------------------------------+----------------------------------------------------+
| type.__annotate__() | The *annotate function* for this class, or "None" |
| | if the class has no annotations. See also: |
-| | "__annotate__ attributes". Caution: Accessing |
-| | the "__annotate__" attribute of a class object |
-| | directly may yield incorrect results in the |
-| | presence of metaclasses. Use |
-| | "annotationlib.get_annotate_function()" to |
-| | retrieve the annotate function safely. Added in |
-| | version 3.14. |
+| | "__annotate__ attributes". Added in version 3.14. |
+----------------------------------------------------+----------------------------------------------------+
| type.__type_params__ | A "tuple" containing the type parameters of a |
| | generic class. Added in version 3.12. |
@@ -11355,16 +11372,16 @@ the "**keywords" syntax to accept arbitrary keyword arguments; bit
Flags for details on the semantics of each flags that might be
present.
-Future feature declarations ("from __future__ import division") also
-use bits in "co_flags" to indicate whether a code object was compiled
-with a particular feature enabled: bit "0x2000" is set if the function
-was compiled with future division enabled; bits "0x10" and "0x1000"
-were used in earlier versions of Python.
+Future feature declarations (for example, "from __future__ import
+division") also use bits in "co_flags" to indicate whether a code
+object was compiled with a particular feature enabled. See
+"compiler_flag".
Other bits in "co_flags" are reserved for internal use.
-If a code object represents a function and has a docstring, the first
-item in "co_consts" is the docstring of the function.
+If a code object represents a function and has a docstring, the
+"CO_HAS_DOCSTRING" bit is set in "co_flags" and the first item in
+"co_consts" is the docstring of the function.
Methods on code objects
diff --git a/Lib/random.py b/Lib/random.py
index 5e5d0c4c694..86d562f0b8a 100644
--- a/Lib/random.py
+++ b/Lib/random.py
@@ -1011,7 +1011,7 @@ if hasattr(_os, "fork"):
def _parse_args(arg_list: list[str] | None):
import argparse
parser = argparse.ArgumentParser(
- formatter_class=argparse.RawTextHelpFormatter)
+ formatter_class=argparse.RawTextHelpFormatter, color=True)
group = parser.add_mutually_exclusive_group()
group.add_argument(
"-c", "--choice", nargs="+",
diff --git a/Lib/re/__init__.py b/Lib/re/__init__.py
index 7e8abbf6ffe..af2808a77da 100644
--- a/Lib/re/__init__.py
+++ b/Lib/re/__init__.py
@@ -61,7 +61,7 @@ below. If the ordinary character is not on the list, then the
resulting RE will match the second character.
\number Matches the contents of the group of the same number.
\A Matches only at the start of the string.
- \Z Matches only at the end of the string.
+ \z Matches only at the end of the string.
\b Matches the empty string, but only at the start or end of a word.
\B Matches the empty string, but not at the start or end of a word.
\d Matches any decimal digit; equivalent to the set [0-9] in
diff --git a/Lib/re/_parser.py b/Lib/re/_parser.py
index 0990255b22c..35ab7ede2a7 100644
--- a/Lib/re/_parser.py
+++ b/Lib/re/_parser.py
@@ -49,7 +49,8 @@ CATEGORIES = {
r"\S": (IN, [(CATEGORY, CATEGORY_NOT_SPACE)]),
r"\w": (IN, [(CATEGORY, CATEGORY_WORD)]),
r"\W": (IN, [(CATEGORY, CATEGORY_NOT_WORD)]),
- r"\Z": (AT, AT_END_STRING), # end of string
+ r"\z": (AT, AT_END_STRING), # end of string
+ r"\Z": (AT, AT_END_STRING), # end of string (obsolete)
}
FLAGS = {
diff --git a/Lib/reprlib.py b/Lib/reprlib.py
index 19dbe3a07eb..ab18247682b 100644
--- a/Lib/reprlib.py
+++ b/Lib/reprlib.py
@@ -28,7 +28,7 @@ def recursive_repr(fillvalue='...'):
wrapper.__doc__ = getattr(user_function, '__doc__')
wrapper.__name__ = getattr(user_function, '__name__')
wrapper.__qualname__ = getattr(user_function, '__qualname__')
- wrapper.__annotations__ = getattr(user_function, '__annotations__', {})
+ wrapper.__annotate__ = getattr(user_function, '__annotate__', None)
wrapper.__type_params__ = getattr(user_function, '__type_params__', ())
wrapper.__wrapped__ = user_function
return wrapper
@@ -181,7 +181,22 @@ class Repr:
return s
def repr_int(self, x, level):
- s = builtins.repr(x) # XXX Hope this isn't too slow...
+ try:
+ s = builtins.repr(x)
+ except ValueError as exc:
+ assert 'sys.set_int_max_str_digits()' in str(exc)
+ # Those imports must be deferred due to Python's build system
+ # where the reprlib module is imported before the math module.
+ import math, sys
+ # Integers with more than sys.get_int_max_str_digits() digits
+ # are rendered differently as their repr() raises a ValueError.
+ # See https://github.com/python/cpython/issues/135487.
+ k = 1 + int(math.log10(abs(x)))
+ # Note: math.log10(abs(x)) may be overestimated or underestimated,
+ # but for simplicity, we do not compute the exact number of digits.
+ max_digits = sys.get_int_max_str_digits()
+ return (f'<{x.__class__.__name__} instance with roughly {k} '
+ f'digits (limit at {max_digits}) at 0x{id(x):x}>')
if len(s) > self.maxlong:
i = max(0, (self.maxlong-3)//2)
j = max(0, self.maxlong-3-i)
diff --git a/Lib/shelve.py b/Lib/shelve.py
index 50584716e9e..b53dc8b7a8e 100644
--- a/Lib/shelve.py
+++ b/Lib/shelve.py
@@ -171,6 +171,11 @@ class Shelf(collections.abc.MutableMapping):
if hasattr(self.dict, 'sync'):
self.dict.sync()
+ def reorganize(self):
+ self.sync()
+ if hasattr(self.dict, 'reorganize'):
+ self.dict.reorganize()
+
class BsdDbShelf(Shelf):
"""Shelf implementation using the "BSD" db interface.
diff --git a/Lib/shutil.py b/Lib/shutil.py
index 510ae8c6f22..ca0a2ea2f7f 100644
--- a/Lib/shutil.py
+++ b/Lib/shutil.py
@@ -32,6 +32,13 @@ try:
except ImportError:
_LZMA_SUPPORTED = False
+try:
+ from compression import zstd
+ del zstd
+ _ZSTD_SUPPORTED = True
+except ImportError:
+ _ZSTD_SUPPORTED = False
+
_WINDOWS = os.name == 'nt'
posix = nt = None
if os.name == 'posix':
@@ -1006,6 +1013,8 @@ def _make_tarball(base_name, base_dir, compress="gzip", verbose=0, dry_run=0,
tar_compression = 'bz2'
elif _LZMA_SUPPORTED and compress == 'xz':
tar_compression = 'xz'
+ elif _ZSTD_SUPPORTED and compress == 'zst':
+ tar_compression = 'zst'
else:
raise ValueError("bad value for 'compress', or compression format not "
"supported : {0}".format(compress))
@@ -1134,6 +1143,10 @@ if _LZMA_SUPPORTED:
_ARCHIVE_FORMATS['xztar'] = (_make_tarball, [('compress', 'xz')],
"xz'ed tar-file")
+if _ZSTD_SUPPORTED:
+ _ARCHIVE_FORMATS['zstdtar'] = (_make_tarball, [('compress', 'zst')],
+ "zstd'ed tar-file")
+
def get_archive_formats():
"""Returns a list of supported formats for archiving and unarchiving.
@@ -1174,7 +1187,7 @@ def make_archive(base_name, format, root_dir=None, base_dir=None, verbose=0,
'base_name' is the name of the file to create, minus any format-specific
extension; 'format' is the archive format: one of "zip", "tar", "gztar",
- "bztar", or "xztar". Or any other registered format.
+ "bztar", "zstdtar", or "xztar". Or any other registered format.
'root_dir' is a directory that will be the root directory of the
archive; ie. we typically chdir into 'root_dir' before creating the
@@ -1359,6 +1372,10 @@ if _LZMA_SUPPORTED:
_UNPACK_FORMATS['xztar'] = (['.tar.xz', '.txz'], _unpack_tarfile, [],
"xz'ed tar-file")
+if _ZSTD_SUPPORTED:
+ _UNPACK_FORMATS['zstdtar'] = (['.tar.zst', '.tzst'], _unpack_tarfile, [],
+ "zstd'ed tar-file")
+
def _find_unpack_format(filename):
for name, info in _UNPACK_FORMATS.items():
for extension in info[0]:
diff --git a/Lib/site.py b/Lib/site.py
index 5c38b1b17d5..f9327197159 100644
--- a/Lib/site.py
+++ b/Lib/site.py
@@ -75,6 +75,7 @@ import builtins
import _sitebuiltins
import _io as io
import stat
+import errno
# Prefixes for site-packages; add additional prefixes like /usr/local here
PREFIXES = [sys.prefix, sys.exec_prefix]
@@ -578,10 +579,15 @@ def register_readline():
def write_history():
try:
readline_module.write_history_file(history)
- except (FileNotFoundError, PermissionError):
+ except (FileNotFoundError, PermissionError):
# home directory does not exist or is not writable
# https://bugs.python.org/issue19891
pass
+ except OSError as exc:
+ if exc.errno == errno.EROFS:
+ pass # gh-128066: read-only file system
+ else:
+ raise
atexit.register(write_history)
diff --git a/Lib/socketserver.py b/Lib/socketserver.py
index 35b2723de3b..93b0a23be27 100644
--- a/Lib/socketserver.py
+++ b/Lib/socketserver.py
@@ -441,7 +441,7 @@ class TCPServer(BaseServer):
socket_type = socket.SOCK_STREAM
- request_queue_size = 5
+ request_queue_size = getattr(socket, "SOMAXCONN", 5)
allow_reuse_address = False
diff --git a/Lib/sqlite3/__main__.py b/Lib/sqlite3/__main__.py
index 79a6209468d..35344ecceff 100644
--- a/Lib/sqlite3/__main__.py
+++ b/Lib/sqlite3/__main__.py
@@ -10,9 +10,12 @@ import sys
from argparse import ArgumentParser
from code import InteractiveConsole
from textwrap import dedent
+from _colorize import get_theme, theme_no_color
+from ._completer import completer
-def execute(c, sql, suppress_errors=True):
+
+def execute(c, sql, suppress_errors=True, theme=theme_no_color):
"""Helper that wraps execution of SQL code.
This is used both by the REPL and by direct execution from the CLI.
@@ -25,11 +28,15 @@ def execute(c, sql, suppress_errors=True):
for row in c.execute(sql):
print(row)
except sqlite3.Error as e:
+ t = theme.traceback
tp = type(e).__name__
try:
- print(f"{tp} ({e.sqlite_errorname}): {e}", file=sys.stderr)
+ tp += f" ({e.sqlite_errorname})"
except AttributeError:
- print(f"{tp}: {e}", file=sys.stderr)
+ pass
+ print(
+ f"{t.type}{tp}{t.reset}: {t.message}{e}{t.reset}", file=sys.stderr
+ )
if not suppress_errors:
sys.exit(1)
@@ -37,10 +44,11 @@ def execute(c, sql, suppress_errors=True):
class SqliteInteractiveConsole(InteractiveConsole):
"""A simple SQLite REPL."""
- def __init__(self, connection):
+ def __init__(self, connection, use_color=False):
super().__init__()
self._con = connection
self._cur = connection.cursor()
+ self._use_color = use_color
def runsource(self, source, filename="<input>", symbol="single"):
"""Override runsource, the core of the InteractiveConsole REPL.
@@ -48,23 +56,39 @@ class SqliteInteractiveConsole(InteractiveConsole):
Return True if more input is needed; buffering is done automatically.
Return False if input is a complete statement ready for execution.
"""
- match source:
- case ".version":
- print(f"{sqlite3.sqlite_version}")
- case ".help":
- print("Enter SQL code and press enter.")
- case ".quit":
- sys.exit(0)
- case _:
- if not sqlite3.complete_statement(source):
- return True
- execute(self._cur, source)
+ theme = get_theme(force_no_color=not self._use_color)
+
+ if not source or source.isspace():
+ return False
+ if source[0] == ".":
+ match source[1:].strip():
+ case "version":
+ print(sqlite3.sqlite_version)
+ case "help":
+ t = theme.syntax
+ print(f"Enter SQL code or one of the below commands, and press enter.\n\n"
+ f"{t.builtin}.version{t.reset} Print underlying SQLite library version\n"
+ f"{t.builtin}.help{t.reset} Print this help message\n"
+ f"{t.builtin}.quit{t.reset} Exit the CLI, equivalent to CTRL-D\n")
+ case "quit":
+ sys.exit(0)
+ case "":
+ pass
+ case _ as unknown:
+ t = theme.traceback
+ self.write(f'{t.type}Error{t.reset}: {t.message}unknown '
+ f'command: "{unknown}"{t.reset}\n')
+ else:
+ if not sqlite3.complete_statement(source):
+ return True
+ execute(self._cur, source, theme=theme)
return False
def main(*args):
parser = ArgumentParser(
description="Python sqlite3 CLI",
+ color=True,
)
parser.add_argument(
"filename", type=str, default=":memory:", nargs="?",
@@ -104,22 +128,23 @@ def main(*args):
Each command will be run using execute() on the cursor.
Type ".help" for more information; type ".quit" or {eofkey} to quit.
""").strip()
- sys.ps1 = "sqlite> "
- sys.ps2 = " ... "
+
+ theme = get_theme()
+ s = theme.syntax
+
+ sys.ps1 = f"{s.prompt}sqlite> {s.reset}"
+ sys.ps2 = f"{s.prompt} ... {s.reset}"
con = sqlite3.connect(args.filename, isolation_level=None)
try:
if args.sql:
# SQL statement provided on the command-line; execute it directly.
- execute(con, args.sql, suppress_errors=False)
+ execute(con, args.sql, suppress_errors=False, theme=theme)
else:
# No SQL provided; start the REPL.
- console = SqliteInteractiveConsole(con)
- try:
- import readline # noqa: F401
- except ImportError:
- pass
- console.interact(banner, exitmsg="")
+ with completer():
+ console = SqliteInteractiveConsole(con, use_color=True)
+ console.interact(banner, exitmsg="")
finally:
con.close()
diff --git a/Lib/sqlite3/_completer.py b/Lib/sqlite3/_completer.py
new file mode 100644
index 00000000000..f21ef69cad6
--- /dev/null
+++ b/Lib/sqlite3/_completer.py
@@ -0,0 +1,42 @@
+from contextlib import contextmanager
+
+try:
+ from _sqlite3 import SQLITE_KEYWORDS
+except ImportError:
+ SQLITE_KEYWORDS = ()
+
+_completion_matches = []
+
+
+def _complete(text, state):
+ global _completion_matches
+
+ if state == 0:
+ text_upper = text.upper()
+ _completion_matches = [c for c in SQLITE_KEYWORDS if c.startswith(text_upper)]
+ try:
+ return _completion_matches[state] + " "
+ except IndexError:
+ return None
+
+
+@contextmanager
+def completer():
+ try:
+ import readline
+ except ImportError:
+ yield
+ return
+
+ old_completer = readline.get_completer()
+ try:
+ readline.set_completer(_complete)
+ if readline.backend == "editline":
+ # libedit uses "^I" instead of "tab"
+ command_string = "bind ^I rl_complete"
+ else:
+ command_string = "tab: complete"
+ readline.parse_and_bind(command_string)
+ yield
+ finally:
+ readline.set_completer(old_completer)
diff --git a/Lib/ssl.py b/Lib/ssl.py
index 05df4ad7f0f..7e3c4cbd6bb 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -116,7 +116,7 @@ except ImportError:
from _ssl import (
HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_SSLv2, HAS_SSLv3, HAS_TLSv1,
- HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3, HAS_PSK, HAS_PHA
+ HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3, HAS_PSK, HAS_PSK_TLS13, HAS_PHA
)
from _ssl import _DEFAULT_CIPHERS, _OPENSSL_API_VERSION
diff --git a/Lib/string/__init__.py b/Lib/string/__init__.py
index eab5067c9b1..b788d7136f1 100644
--- a/Lib/string/__init__.py
+++ b/Lib/string/__init__.py
@@ -191,16 +191,14 @@ class Template:
########################################################################
-# the Formatter class
-# see PEP 3101 for details and purpose of this class
-
-# The hard parts are reused from the C implementation. They're exposed as "_"
-# prefixed methods of str.
-
+# The Formatter class (PEP 3101).
+#
# The overall parser is implemented in _string.formatter_parser.
-# The field name parser is implemented in _string.formatter_field_name_split
+# The field name parser is implemented in _string.formatter_field_name_split.
class Formatter:
+ """See PEP 3101 for details and purpose of this class."""
+
def format(self, format_string, /, *args, **kwargs):
return self.vformat(format_string, args, kwargs)
@@ -264,22 +262,18 @@ class Formatter:
return ''.join(result), auto_arg_index
-
def get_value(self, key, args, kwargs):
if isinstance(key, int):
return args[key]
else:
return kwargs[key]
-
def check_unused_args(self, used_args, args, kwargs):
pass
-
def format_field(self, value, format_spec):
return format(value, format_spec)
-
def convert_field(self, value, conversion):
# do any conversion on the resulting object
if conversion is None:
@@ -292,28 +286,26 @@ class Formatter:
return ascii(value)
raise ValueError("Unknown conversion specifier {0!s}".format(conversion))
-
- # returns an iterable that contains tuples of the form:
- # (literal_text, field_name, format_spec, conversion)
- # literal_text can be zero length
- # field_name can be None, in which case there's no
- # object to format and output
- # if field_name is not None, it is looked up, formatted
- # with format_spec and conversion and then used
def parse(self, format_string):
+ """
+ Return an iterable that contains tuples of the form
+ (literal_text, field_name, format_spec, conversion).
+
+ *field_name* can be None, in which case there's no object
+ to format and output; otherwise, it is looked up and
+ formatted with *format_spec* and *conversion*.
+ """
return _string.formatter_parser(format_string)
-
- # given a field_name, find the object it references.
- # field_name: the field being looked up, e.g. "0.name"
- # or "lookup[3]"
- # used_args: a set of which args have been used
- # args, kwargs: as passed in to vformat
def get_field(self, field_name, args, kwargs):
- first, rest = _string.formatter_field_name_split(field_name)
+ """Find the object referenced by a given field name.
+ The field name *field_name* can be for instance "0.name"
+ or "lookup[3]". The *args* and *kwargs* arguments are
+ passed to get_value().
+ """
+ first, rest = _string.formatter_field_name_split(field_name)
obj = self.get_value(first, args, kwargs)
-
# loop through the rest of the field_name, doing
# getattr or getitem as needed
for is_attr, i in rest:
@@ -321,5 +313,4 @@ class Formatter:
obj = getattr(obj, i)
else:
obj = obj[i]
-
return obj, first
diff --git a/Lib/subprocess.py b/Lib/subprocess.py
index da5f5729e09..54c2eb515b6 100644
--- a/Lib/subprocess.py
+++ b/Lib/subprocess.py
@@ -1235,8 +1235,11 @@ class Popen:
finally:
self._communication_started = True
-
- sts = self.wait(timeout=self._remaining_time(endtime))
+ try:
+ sts = self.wait(timeout=self._remaining_time(endtime))
+ except TimeoutExpired as exc:
+ exc.timeout = timeout
+ raise
return (stdout, stderr)
@@ -2145,8 +2148,11 @@ class Popen:
selector.unregister(key.fileobj)
key.fileobj.close()
self._fileobj2output[key.fileobj].append(data)
-
- self.wait(timeout=self._remaining_time(endtime))
+ try:
+ self.wait(timeout=self._remaining_time(endtime))
+ except TimeoutExpired as exc:
+ exc.timeout = orig_timeout
+ raise
# All data exchanged. Translate lists into strings.
if stdout is not None:
diff --git a/Lib/sysconfig/__init__.py b/Lib/sysconfig/__init__.py
index dad715eb087..49e0986517c 100644
--- a/Lib/sysconfig/__init__.py
+++ b/Lib/sysconfig/__init__.py
@@ -219,18 +219,7 @@ if os.name == 'nt':
if "_PYTHON_PROJECT_BASE" in os.environ:
_PROJECT_BASE = _safe_realpath(os.environ["_PYTHON_PROJECT_BASE"])
-def is_python_build(check_home=None):
- if check_home is not None:
- import warnings
- warnings.warn(
- (
- 'The check_home argument of sysconfig.is_python_build is '
- 'deprecated and its value is ignored. '
- 'It will be removed in Python 3.15.'
- ),
- DeprecationWarning,
- stacklevel=2,
- )
+def is_python_build():
for fn in ("Setup", "Setup.local"):
if os.path.isfile(os.path.join(_PROJECT_BASE, "Modules", fn)):
return True
@@ -468,7 +457,7 @@ def get_config_h_filename():
"""Return the path of pyconfig.h."""
if _PYTHON_BUILD:
if os.name == "nt":
- inc_dir = os.path.dirname(sys._base_executable)
+ inc_dir = os.path.join(_PROJECT_BASE, 'PC')
else:
inc_dir = _PROJECT_BASE
else:
diff --git a/Lib/tarfile.py b/Lib/tarfile.py
index 82c5f6704cb..068aa13ed70 100644
--- a/Lib/tarfile.py
+++ b/Lib/tarfile.py
@@ -67,7 +67,7 @@ __all__ = ["TarFile", "TarInfo", "is_tarfile", "TarError", "ReadError",
"DEFAULT_FORMAT", "open","fully_trusted_filter", "data_filter",
"tar_filter", "FilterError", "AbsoluteLinkError",
"OutsideDestinationError", "SpecialFileError", "AbsolutePathError",
- "LinkOutsideDestinationError"]
+ "LinkOutsideDestinationError", "LinkFallbackError"]
#---------------------------------------------------------
@@ -399,7 +399,17 @@ class _Stream:
self.exception = lzma.LZMAError
else:
self.cmp = lzma.LZMACompressor(preset=preset)
-
+ elif comptype == "zst":
+ try:
+ from compression import zstd
+ except ImportError:
+ raise CompressionError("compression.zstd module is not available") from None
+ if mode == "r":
+ self.dbuf = b""
+ self.cmp = zstd.ZstdDecompressor()
+ self.exception = zstd.ZstdError
+ else:
+ self.cmp = zstd.ZstdCompressor()
elif comptype != "tar":
raise CompressionError("unknown compression type %r" % comptype)
@@ -591,6 +601,8 @@ class _StreamProxy(object):
return "bz2"
elif self.buf.startswith((b"\x5d\x00\x00\x80", b"\xfd7zXZ")):
return "xz"
+ elif self.buf.startswith(b"\x28\xb5\x2f\xfd"):
+ return "zst"
else:
return "tar"
@@ -754,10 +766,22 @@ class LinkOutsideDestinationError(FilterError):
super().__init__(f'{tarinfo.name!r} would link to {path!r}, '
+ 'which is outside the destination')
+class LinkFallbackError(FilterError):
+ def __init__(self, tarinfo, path):
+ self.tarinfo = tarinfo
+ self._path = path
+ super().__init__(f'link {tarinfo.name!r} would be extracted as a '
+ + f'copy of {path!r}, which was rejected')
+
+# Errors caused by filters -- both "fatal" and "non-fatal" -- that
+# we consider to be issues with the argument, rather than a bug in the
+# filter function
+_FILTER_ERRORS = (FilterError, OSError, ExtractError)
+
def _get_filtered_attrs(member, dest_path, for_data=True):
new_attrs = {}
name = member.name
- dest_path = os.path.realpath(dest_path)
+ dest_path = os.path.realpath(dest_path, strict=os.path.ALLOW_MISSING)
# Strip leading / (tar's directory separator) from filenames.
# Include os.sep (target OS directory separator) as well.
if name.startswith(('/', os.sep)):
@@ -767,7 +791,8 @@ def _get_filtered_attrs(member, dest_path, for_data=True):
# For example, 'C:/foo' on Windows.
raise AbsolutePathError(member)
# Ensure we stay in the destination
- target_path = os.path.realpath(os.path.join(dest_path, name))
+ target_path = os.path.realpath(os.path.join(dest_path, name),
+ strict=os.path.ALLOW_MISSING)
if os.path.commonpath([target_path, dest_path]) != dest_path:
raise OutsideDestinationError(member, target_path)
# Limit permissions (no high bits, and go-w)
@@ -805,6 +830,9 @@ def _get_filtered_attrs(member, dest_path, for_data=True):
if member.islnk() or member.issym():
if os.path.isabs(member.linkname):
raise AbsoluteLinkError(member)
+ normalized = os.path.normpath(member.linkname)
+ if normalized != member.linkname:
+ new_attrs['linkname'] = normalized
if member.issym():
target_path = os.path.join(dest_path,
os.path.dirname(name),
@@ -812,7 +840,8 @@ def _get_filtered_attrs(member, dest_path, for_data=True):
else:
target_path = os.path.join(dest_path,
member.linkname)
- target_path = os.path.realpath(target_path)
+ target_path = os.path.realpath(target_path,
+ strict=os.path.ALLOW_MISSING)
if os.path.commonpath([target_path, dest_path]) != dest_path:
raise LinkOutsideDestinationError(member, target_path)
return new_attrs
@@ -1817,11 +1846,13 @@ class TarFile(object):
'r:gz' open for reading with gzip compression
'r:bz2' open for reading with bzip2 compression
'r:xz' open for reading with lzma compression
+ 'r:zst' open for reading with zstd compression
'a' or 'a:' open for appending, creating the file if necessary
'w' or 'w:' open for writing without compression
'w:gz' open for writing with gzip compression
'w:bz2' open for writing with bzip2 compression
'w:xz' open for writing with lzma compression
+ 'w:zst' open for writing with zstd compression
'x' or 'x:' create a tarfile exclusively without compression, raise
an exception if the file is already created
@@ -1831,16 +1862,20 @@ class TarFile(object):
if the file is already created
'x:xz' create an lzma compressed tarfile, raise an exception
if the file is already created
+ 'x:zst' create a zstd compressed tarfile, raise an exception
+ if the file is already created
'r|*' open a stream of tar blocks with transparent compression
'r|' open an uncompressed stream of tar blocks for reading
'r|gz' open a gzip compressed stream of tar blocks
'r|bz2' open a bzip2 compressed stream of tar blocks
'r|xz' open an lzma compressed stream of tar blocks
+ 'r|zst' open a zstd compressed stream of tar blocks
'w|' open an uncompressed stream for writing
'w|gz' open a gzip compressed stream for writing
'w|bz2' open a bzip2 compressed stream for writing
'w|xz' open an lzma compressed stream for writing
+ 'w|zst' open a zstd compressed stream for writing
"""
if not name and not fileobj:
@@ -2006,12 +2041,48 @@ class TarFile(object):
t._extfileobj = False
return t
+ @classmethod
+ def zstopen(cls, name, mode="r", fileobj=None, level=None, options=None,
+ zstd_dict=None, **kwargs):
+ """Open zstd compressed tar archive name for reading or writing.
+ Appending is not allowed.
+ """
+ if mode not in ("r", "w", "x"):
+ raise ValueError("mode must be 'r', 'w' or 'x'")
+
+ try:
+ from compression.zstd import ZstdFile, ZstdError
+ except ImportError:
+ raise CompressionError("compression.zstd module is not available") from None
+
+ fileobj = ZstdFile(
+ fileobj or name,
+ mode,
+ level=level,
+ options=options,
+ zstd_dict=zstd_dict
+ )
+
+ try:
+ t = cls.taropen(name, mode, fileobj, **kwargs)
+ except (ZstdError, EOFError) as e:
+ fileobj.close()
+ if mode == 'r':
+ raise ReadError("not a zstd file") from e
+ raise
+ except Exception:
+ fileobj.close()
+ raise
+ t._extfileobj = False
+ return t
+
# All *open() methods are registered here.
OPEN_METH = {
"tar": "taropen", # uncompressed tar
"gz": "gzopen", # gzip compressed tar
"bz2": "bz2open", # bzip2 compressed tar
- "xz": "xzopen" # lzma compressed tar
+ "xz": "xzopen", # lzma compressed tar
+ "zst": "zstopen", # zstd compressed tar
}
#--------------------------------------------------------------------------
@@ -2332,30 +2403,58 @@ class TarFile(object):
members = self
for member in members:
- tarinfo = self._get_extract_tarinfo(member, filter_function, path)
+ tarinfo, unfiltered = self._get_extract_tarinfo(
+ member, filter_function, path)
if tarinfo is None:
continue
if tarinfo.isdir():
# For directories, delay setting attributes until later,
# since permissions can interfere with extraction and
# extracting contents can reset mtime.
- directories.append(tarinfo)
+ directories.append(unfiltered)
self._extract_one(tarinfo, path, set_attrs=not tarinfo.isdir(),
- numeric_owner=numeric_owner)
+ numeric_owner=numeric_owner,
+ filter_function=filter_function)
# Reverse sort directories.
directories.sort(key=lambda a: a.name, reverse=True)
+
# Set correct owner, mtime and filemode on directories.
- for tarinfo in directories:
- dirpath = os.path.join(path, tarinfo.name)
+ for unfiltered in directories:
try:
+ # Need to re-apply any filter, to take the *current* filesystem
+ # state into account.
+ try:
+ tarinfo = filter_function(unfiltered, path)
+ except _FILTER_ERRORS as exc:
+ self._log_no_directory_fixup(unfiltered, repr(exc))
+ continue
+ if tarinfo is None:
+ self._log_no_directory_fixup(unfiltered,
+ 'excluded by filter')
+ continue
+ dirpath = os.path.join(path, tarinfo.name)
+ try:
+ lstat = os.lstat(dirpath)
+ except FileNotFoundError:
+ self._log_no_directory_fixup(tarinfo, 'missing')
+ continue
+ if not stat.S_ISDIR(lstat.st_mode):
+ # This is no longer a directory; presumably a later
+ # member overwrote the entry.
+ self._log_no_directory_fixup(tarinfo, 'not a directory')
+ continue
self.chown(tarinfo, dirpath, numeric_owner=numeric_owner)
self.utime(tarinfo, dirpath)
self.chmod(tarinfo, dirpath)
except ExtractError as e:
self._handle_nonfatal_error(e)
+ def _log_no_directory_fixup(self, member, reason):
+ self._dbg(2, "tarfile: Not fixing up directory %r (%s)" %
+ (member.name, reason))
+
def extract(self, member, path="", set_attrs=True, *, numeric_owner=False,
filter=None):
"""Extract a member from the archive to the current working directory,
@@ -2371,42 +2470,57 @@ class TarFile(object):
String names of common filters are accepted.
"""
filter_function = self._get_filter_function(filter)
- tarinfo = self._get_extract_tarinfo(member, filter_function, path)
+ tarinfo, unfiltered = self._get_extract_tarinfo(
+ member, filter_function, path)
if tarinfo is not None:
self._extract_one(tarinfo, path, set_attrs, numeric_owner)
def _get_extract_tarinfo(self, member, filter_function, path):
- """Get filtered TarInfo (or None) from member, which might be a str"""
+ """Get (filtered, unfiltered) TarInfos from *member*
+
+ *member* might be a string.
+
+ Return (None, None) if not found.
+ """
+
if isinstance(member, str):
- tarinfo = self.getmember(member)
+ unfiltered = self.getmember(member)
else:
- tarinfo = member
+ unfiltered = member
- unfiltered = tarinfo
+ filtered = None
try:
- tarinfo = filter_function(tarinfo, path)
- except (OSError, FilterError) as e:
+ filtered = filter_function(unfiltered, path)
+ except (OSError, UnicodeEncodeError, FilterError) as e:
self._handle_fatal_error(e)
except ExtractError as e:
self._handle_nonfatal_error(e)
- if tarinfo is None:
+ if filtered is None:
self._dbg(2, "tarfile: Excluded %r" % unfiltered.name)
- return None
+ return None, None
+
# Prepare the link target for makelink().
- if tarinfo.islnk():
- tarinfo = copy.copy(tarinfo)
- tarinfo._link_target = os.path.join(path, tarinfo.linkname)
- return tarinfo
+ if filtered.islnk():
+ filtered = copy.copy(filtered)
+ filtered._link_target = os.path.join(path, filtered.linkname)
+ return filtered, unfiltered
- def _extract_one(self, tarinfo, path, set_attrs, numeric_owner):
- """Extract from filtered tarinfo to disk"""
+ def _extract_one(self, tarinfo, path, set_attrs, numeric_owner,
+ filter_function=None):
+ """Extract from filtered tarinfo to disk.
+
+ filter_function is only used when extracting a *different*
+ member (e.g. as fallback to creating a symlink)
+ """
self._check("r")
try:
self._extract_member(tarinfo, os.path.join(path, tarinfo.name),
set_attrs=set_attrs,
- numeric_owner=numeric_owner)
- except OSError as e:
+ numeric_owner=numeric_owner,
+ filter_function=filter_function,
+ extraction_root=path)
+ except (OSError, UnicodeEncodeError) as e:
self._handle_fatal_error(e)
except ExtractError as e:
self._handle_nonfatal_error(e)
@@ -2463,9 +2577,13 @@ class TarFile(object):
return None
def _extract_member(self, tarinfo, targetpath, set_attrs=True,
- numeric_owner=False):
- """Extract the TarInfo object tarinfo to a physical
+ numeric_owner=False, *, filter_function=None,
+ extraction_root=None):
+ """Extract the filtered TarInfo object tarinfo to a physical
file called targetpath.
+
+ filter_function is only used when extracting a *different*
+ member (e.g. as fallback to creating a symlink)
"""
# Fetch the TarInfo object for the given name
# and build the destination pathname, replacing
@@ -2494,7 +2612,10 @@ class TarFile(object):
elif tarinfo.ischr() or tarinfo.isblk():
self.makedev(tarinfo, targetpath)
elif tarinfo.islnk() or tarinfo.issym():
- self.makelink(tarinfo, targetpath)
+ self.makelink_with_filter(
+ tarinfo, targetpath,
+ filter_function=filter_function,
+ extraction_root=extraction_root)
elif tarinfo.type not in SUPPORTED_TYPES:
self.makeunknown(tarinfo, targetpath)
else:
@@ -2577,10 +2698,18 @@ class TarFile(object):
os.makedev(tarinfo.devmajor, tarinfo.devminor))
def makelink(self, tarinfo, targetpath):
+ return self.makelink_with_filter(tarinfo, targetpath, None, None)
+
+ def makelink_with_filter(self, tarinfo, targetpath,
+ filter_function, extraction_root):
"""Make a (symbolic) link called targetpath. If it cannot be created
(platform limitation), we try to make a copy of the referenced file
instead of a link.
+
+ filter_function is only used when extracting a *different*
+ member (e.g. as fallback to creating a link).
"""
+ keyerror_to_extracterror = False
try:
# For systems that support symbolic and hard links.
if tarinfo.issym():
@@ -2588,18 +2717,38 @@ class TarFile(object):
# Avoid FileExistsError on following os.symlink.
os.unlink(targetpath)
os.symlink(tarinfo.linkname, targetpath)
+ return
else:
if os.path.exists(tarinfo._link_target):
os.link(tarinfo._link_target, targetpath)
- else:
- self._extract_member(self._find_link_target(tarinfo),
- targetpath)
+ return
except symlink_exception:
+ keyerror_to_extracterror = True
+
+ try:
+ unfiltered = self._find_link_target(tarinfo)
+ except KeyError:
+ if keyerror_to_extracterror:
+ raise ExtractError(
+ "unable to resolve link inside archive") from None
+ else:
+ raise
+
+ if filter_function is None:
+ filtered = unfiltered
+ else:
+ if extraction_root is None:
+ raise ExtractError(
+ "makelink_with_filter: if filter_function is not None, "
+ + "extraction_root must also not be None")
try:
- self._extract_member(self._find_link_target(tarinfo),
- targetpath)
- except KeyError:
- raise ExtractError("unable to resolve link inside archive") from None
+ filtered = filter_function(unfiltered, extraction_root)
+ except _FILTER_ERRORS as cause:
+ raise LinkFallbackError(tarinfo, unfiltered.name) from cause
+ if filtered is not None:
+ self._extract_member(filtered, targetpath,
+ filter_function=filter_function,
+ extraction_root=extraction_root)
def chown(self, tarinfo, targetpath, numeric_owner):
"""Set owner of targetpath according to tarinfo. If numeric_owner
@@ -2883,7 +3032,7 @@ def main():
import argparse
description = 'A simple command-line interface for tarfile module.'
- parser = argparse.ArgumentParser(description=description)
+ parser = argparse.ArgumentParser(description=description, color=True)
parser.add_argument('-v', '--verbose', action='store_true', default=False,
help='Verbose output')
parser.add_argument('--filter', metavar='<filtername>',
@@ -2963,6 +3112,9 @@ def main():
'.tbz': 'bz2',
'.tbz2': 'bz2',
'.tb2': 'bz2',
+ # zstd
+ '.zst': 'zst',
+ '.tzst': 'zst',
}
tar_mode = 'w:' + compressions[ext] if ext in compressions else 'w'
tar_files = args.create
diff --git a/Lib/tempfile.py b/Lib/tempfile.py
index cadb0bed3cc..5e3ccab5f48 100644
--- a/Lib/tempfile.py
+++ b/Lib/tempfile.py
@@ -180,7 +180,7 @@ def _candidate_tempdir_list():
return dirlist
-def _get_default_tempdir():
+def _get_default_tempdir(dirlist=None):
"""Calculate the default directory to use for temporary files.
This routine should be called exactly once.
@@ -190,7 +190,8 @@ def _get_default_tempdir():
service, the name of the test file must be randomized."""
namer = _RandomNameSequence()
- dirlist = _candidate_tempdir_list()
+ if dirlist is None:
+ dirlist = _candidate_tempdir_list()
for dir in dirlist:
if dir != _os.curdir:
diff --git a/Lib/test/.ruff.toml b/Lib/test/.ruff.toml
index a1eac32a83a..f1a967203ce 100644
--- a/Lib/test/.ruff.toml
+++ b/Lib/test/.ruff.toml
@@ -9,8 +9,9 @@ extend-exclude = [
"encoded_modules/module_iso_8859_1.py",
"encoded_modules/module_koi8_r.py",
# SyntaxError because of t-strings
- "test_tstring.py",
+ "test_annotationlib.py",
"test_string/test_templatelib.py",
+ "test_tstring.py",
# New grammar constructions may not yet be recognized by Ruff,
# and tests re-use the same names as only the grammar is being checked.
"test_grammar.py",
@@ -18,5 +19,12 @@ extend-exclude = [
[lint]
select = [
+ "F401", # Unused import
"F811", # Redefinition of unused variable (useful for finding test methods with the same name)
]
+
+[lint.per-file-ignores]
+"*/**/__main__.py" = ["F401"] # Unused import
+"test_import/*.py" = ["F401"] # Unused import
+"test_importlib/*.py" = ["F401"] # Unused import
+"typinganndata/partialexecution/*.py" = ["F401"] # Unused import
diff --git a/Lib/test/_code_definitions.py b/Lib/test/_code_definitions.py
index 06cf6a10231..70c44da2ec6 100644
--- a/Lib/test/_code_definitions.py
+++ b/Lib/test/_code_definitions.py
@@ -1,4 +1,32 @@
+def simple_script():
+ assert True
+
+
+def complex_script():
+ obj = 'a string'
+ pickle = __import__('pickle')
+ def spam_minimal():
+ pass
+ spam_minimal()
+ data = pickle.dumps(obj)
+ res = pickle.loads(data)
+ assert res == obj, (res, obj)
+
+
+def script_with_globals():
+ obj1, obj2 = spam(42)
+ assert obj1 == 42
+ assert obj2 is None
+
+
+def script_with_explicit_empty_return():
+ return None
+
+
+def script_with_return():
+ return True
+
def spam_minimal():
# no arg defaults or kwarg defaults
@@ -12,6 +40,70 @@ def spam_minimal():
return
+def spam_with_builtins():
+ x = 42
+ values = (42,)
+ checks = tuple(callable(v) for v in values)
+ res = callable(values), tuple(values), list(values), checks
+ print(res)
+
+
+def spam_with_globals_and_builtins():
+ func1 = spam
+ func2 = spam_minimal
+ funcs = (func1, func2)
+ checks = tuple(callable(f) for f in funcs)
+ res = callable(funcs), tuple(funcs), list(funcs), checks
+ print(res)
+
+
+def spam_with_global_and_attr_same_name():
+ try:
+ spam_minimal.spam_minimal
+ except AttributeError:
+ pass
+
+
+def spam_full_args(a, b, /, c, d, *args, e, f, **kwargs):
+ return (a, b, c, d, e, f, args, kwargs)
+
+
+def spam_full_args_with_defaults(a=-1, b=-2, /, c=-3, d=-4, *args,
+ e=-5, f=-6, **kwargs):
+ return (a, b, c, d, e, f, args, kwargs)
+
+
+def spam_args_attrs_and_builtins(a, b, /, c, d, *args, e, f, **kwargs):
+ if args.__len__() > 2:
+ return None
+ return a, b, c, d, e, f, args, kwargs
+
+
+def spam_returns_arg(x):
+ return x
+
+
+def spam_raises():
+ raise Exception('spam!')
+
+
+def spam_with_inner_not_closure():
+ def eggs():
+ pass
+ eggs()
+
+
+def spam_with_inner_closure():
+ x = 42
+ def eggs():
+ print(x)
+ eggs()
+
+
+def spam_annotated(a: int, b: str, c: object) -> tuple:
+ return a, b, c
+
+
def spam_full(a, b, /, c, d:int=1, *args, e, f:object=None, **kwargs) -> tuple:
# arg defaults, kwarg defaults
# annotations
@@ -97,7 +189,23 @@ ham_C_closure, *_ = eggs_closure_C(2)
TOP_FUNCTIONS = [
# shallow
+ simple_script,
+ complex_script,
+ script_with_globals,
+ script_with_explicit_empty_return,
+ script_with_return,
spam_minimal,
+ spam_with_builtins,
+ spam_with_globals_and_builtins,
+ spam_with_global_and_attr_same_name,
+ spam_full_args,
+ spam_full_args_with_defaults,
+ spam_args_attrs_and_builtins,
+ spam_returns_arg,
+ spam_raises,
+ spam_with_inner_not_closure,
+ spam_with_inner_closure,
+ spam_annotated,
spam_full,
spam,
# outer func
@@ -127,6 +235,58 @@ FUNCTIONS = [
*NESTED_FUNCTIONS,
]
+STATELESS_FUNCTIONS = [
+ simple_script,
+ complex_script,
+ script_with_explicit_empty_return,
+ script_with_return,
+ spam,
+ spam_minimal,
+ spam_with_builtins,
+ spam_full_args,
+ spam_args_attrs_and_builtins,
+ spam_returns_arg,
+ spam_raises,
+ spam_annotated,
+ spam_with_inner_not_closure,
+ spam_with_inner_closure,
+ spam_N,
+ spam_C,
+ spam_NN,
+ spam_NC,
+ spam_CN,
+ spam_CC,
+ eggs_nested,
+ eggs_nested_N,
+ ham_nested,
+ ham_C_nested
+]
+STATELESS_CODE = [
+ *STATELESS_FUNCTIONS,
+ script_with_globals,
+ spam_full_args_with_defaults,
+ spam_with_globals_and_builtins,
+ spam_with_global_and_attr_same_name,
+ spam_full,
+]
+
+PURE_SCRIPT_FUNCTIONS = [
+ simple_script,
+ complex_script,
+ script_with_explicit_empty_return,
+ spam_minimal,
+ spam_with_builtins,
+ spam_raises,
+ spam_with_inner_not_closure,
+ spam_with_inner_closure,
+]
+SCRIPT_FUNCTIONS = [
+ *PURE_SCRIPT_FUNCTIONS,
+ script_with_globals,
+ spam_with_globals_and_builtins,
+ spam_with_global_and_attr_same_name,
+]
+
# generators
diff --git a/Lib/test/_test_embed_structseq.py b/Lib/test/_test_embed_structseq.py
index 154662efce9..4cac84d7a46 100644
--- a/Lib/test/_test_embed_structseq.py
+++ b/Lib/test/_test_embed_structseq.py
@@ -11,7 +11,7 @@ class TestStructSeq(unittest.TestCase):
# ob_refcnt
self.assertGreaterEqual(sys.getrefcount(obj_type), 1)
# tp_base
- self.assertTrue(issubclass(obj_type, tuple))
+ self.assertIsSubclass(obj_type, tuple)
# tp_bases
self.assertEqual(obj_type.__bases__, (tuple,))
# tp_dict
diff --git a/Lib/test/_test_gc_fast_cycles.py b/Lib/test/_test_gc_fast_cycles.py
new file mode 100644
index 00000000000..4e2c7d72a02
--- /dev/null
+++ b/Lib/test/_test_gc_fast_cycles.py
@@ -0,0 +1,48 @@
+# Run by test_gc.
+from test import support
+import _testinternalcapi
+import gc
+import unittest
+
+class IncrementalGCTests(unittest.TestCase):
+
+ # Use small increments to emulate longer running process in a shorter time
+ @support.gc_threshold(200, 10)
+ def test_incremental_gc_handles_fast_cycle_creation(self):
+
+ class LinkedList:
+
+ # Use slots to reduce the number of implicit objects
+ __slots__ = "next", "prev", "surprise"
+
+ def __init__(self, next=None, prev=None):
+ self.next = next
+ if next is not None:
+ next.prev = self
+ self.prev = prev
+ if prev is not None:
+ prev.next = self
+
+ def make_ll(depth):
+ head = LinkedList()
+ for i in range(depth):
+ head = LinkedList(head, head.prev)
+ return head
+
+ head = make_ll(1000)
+
+ assert(gc.isenabled())
+ olds = []
+ initial_heap_size = _testinternalcapi.get_tracked_heap_size()
+ for i in range(20_000):
+ newhead = make_ll(20)
+ newhead.surprise = head
+ olds.append(newhead)
+ if len(olds) == 20:
+ new_objects = _testinternalcapi.get_tracked_heap_size() - initial_heap_size
+ self.assertLess(new_objects, 27_000, f"Heap growing. Reached limit after {i} iterations")
+ del olds[:]
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py
index 4dc9a31d22f..a1259ff1d63 100644
--- a/Lib/test/_test_multiprocessing.py
+++ b/Lib/test/_test_multiprocessing.py
@@ -513,9 +513,14 @@ class _TestProcess(BaseTestCase):
time.sleep(100)
@classmethod
- def _sleep_no_int_handler(cls):
+ def _sleep_some_event(cls, event):
+ event.set()
+ time.sleep(100)
+
+ @classmethod
+ def _sleep_no_int_handler(cls, event):
signal.signal(signal.SIGINT, signal.SIG_DFL)
- cls._sleep_some()
+ cls._sleep_some_event(event)
@classmethod
def _test_sleep(cls, delay):
@@ -525,7 +530,10 @@ class _TestProcess(BaseTestCase):
if self.TYPE == 'threads':
self.skipTest('test not appropriate for {}'.format(self.TYPE))
- p = self.Process(target=target or self._sleep_some)
+ event = self.Event()
+ if not target:
+ target = self._sleep_some_event
+ p = self.Process(target=target, args=(event,))
p.daemon = True
p.start()
@@ -543,8 +551,11 @@ class _TestProcess(BaseTestCase):
self.assertTimingAlmostEqual(join.elapsed, 0.0)
self.assertEqual(p.is_alive(), True)
- # XXX maybe terminating too soon causes the problems on Gentoo...
- time.sleep(1)
+ timeout = support.SHORT_TIMEOUT
+ if not event.wait(timeout):
+ p.terminate()
+ p.join()
+ self.fail(f"event not signaled in {timeout} seconds")
meth(p)
@@ -2463,6 +2474,12 @@ class _TestValue(BaseTestCase):
self.assertNotHasAttr(arr5, 'get_lock')
self.assertNotHasAttr(arr5, 'get_obj')
+ @unittest.skipIf(c_int is None, "requires _ctypes")
+ def test_invalid_typecode(self):
+ with self.assertRaisesRegex(TypeError, 'bad typecode'):
+ self.Value('x', None)
+ with self.assertRaisesRegex(TypeError, 'bad typecode'):
+ self.RawValue('x', None)
class _TestArray(BaseTestCase):
@@ -2543,6 +2560,12 @@ class _TestArray(BaseTestCase):
self.assertNotHasAttr(arr5, 'get_lock')
self.assertNotHasAttr(arr5, 'get_obj')
+ @unittest.skipIf(c_int is None, "requires _ctypes")
+ def test_invalid_typecode(self):
+ with self.assertRaisesRegex(TypeError, 'bad typecode'):
+ self.Array('x', [])
+ with self.assertRaisesRegex(TypeError, 'bad typecode'):
+ self.RawArray('x', [])
#
#
#
@@ -6778,6 +6801,35 @@ class _TestSpawnedSysPath(BaseTestCase):
self.assertEqual(child_sys_path[1:], sys.path[1:])
self.assertIsNone(import_error, msg=f"child could not import {self._mod_name}")
+ def test_std_streams_flushed_after_preload(self):
+ # gh-135335: Check fork server flushes standard streams after
+ # preloading modules
+ if multiprocessing.get_start_method() != "forkserver":
+ self.skipTest("forkserver specific test")
+
+ # Create a test module in the temporary directory on the child's path
+ # TODO: This can all be simplified once gh-126631 is fixed and we can
+ # use __main__ instead of a module.
+ dirname = os.path.join(self._temp_dir, 'preloaded_module')
+ init_name = os.path.join(dirname, '__init__.py')
+ os.mkdir(dirname)
+ with open(init_name, "w") as f:
+ cmd = '''if 1:
+ import sys
+ print('stderr', end='', file=sys.stderr)
+ print('stdout', end='', file=sys.stdout)
+ '''
+ f.write(cmd)
+
+ name = os.path.join(os.path.dirname(__file__), 'mp_preload_flush.py')
+ env = {'PYTHONPATH': self._temp_dir}
+ _, out, err = test.support.script_helper.assert_python_ok(name, **env)
+
+ # Check stderr first, as it is more likely to be useful to see in the
+ # event of a failure.
+ self.assertEqual(err.decode().rstrip(), 'stderr')
+ self.assertEqual(out.decode().rstrip(), 'stdout')
+
class MiscTestCase(unittest.TestCase):
def test__all__(self):
@@ -6821,6 +6873,28 @@ class MiscTestCase(unittest.TestCase):
self.assertEqual("332833500", out.decode('utf-8').strip())
self.assertFalse(err, msg=err.decode('utf-8'))
+ def test_forked_thread_not_started(self):
+ # gh-134381: Ensure that a thread that has not been started yet in
+ # the parent process can be started within a forked child process.
+
+ if multiprocessing.get_start_method() != "fork":
+ self.skipTest("fork specific test")
+
+ q = multiprocessing.Queue()
+ t = threading.Thread(target=lambda: q.put("done"), daemon=True)
+
+ def child():
+ t.start()
+ t.join()
+
+ p = multiprocessing.Process(target=child)
+ p.start()
+ p.join(support.SHORT_TIMEOUT)
+
+ self.assertEqual(p.exitcode, 0)
+ self.assertEqual(q.get_nowait(), "done")
+ close_queue(q)
+
#
# Mixins
diff --git a/Lib/test/audit-tests.py b/Lib/test/audit-tests.py
index 08b638e4b8d..6884ac0dbe6 100644
--- a/Lib/test/audit-tests.py
+++ b/Lib/test/audit-tests.py
@@ -643,6 +643,34 @@ def test_assert_unicode():
else:
raise RuntimeError("Expected sys.audit(9) to fail.")
+def test_sys_remote_exec():
+ import tempfile
+
+ pid = os.getpid()
+ event_pid = -1
+ event_script_path = ""
+ remote_event_script_path = ""
+ def hook(event, args):
+ if event not in ["sys.remote_exec", "cpython.remote_debugger_script"]:
+ return
+ print(event, args)
+ match event:
+ case "sys.remote_exec":
+ nonlocal event_pid, event_script_path
+ event_pid = args[0]
+ event_script_path = args[1]
+ case "cpython.remote_debugger_script":
+ nonlocal remote_event_script_path
+ remote_event_script_path = args[0]
+
+ sys.addaudithook(hook)
+ with tempfile.NamedTemporaryFile(mode='w+', delete=True) as tmp_file:
+ tmp_file.write("a = 1+1\n")
+ tmp_file.flush()
+ sys.remote_exec(pid, tmp_file.name)
+ assertEqual(event_pid, pid)
+ assertEqual(event_script_path, tmp_file.name)
+ assertEqual(remote_event_script_path, tmp_file.name)
if __name__ == "__main__":
from test.support import suppress_msvcrt_asserts
diff --git a/Lib/test/datetimetester.py b/Lib/test/datetimetester.py
index 55844ec35a9..93b3382b9c6 100644
--- a/Lib/test/datetimetester.py
+++ b/Lib/test/datetimetester.py
@@ -183,7 +183,7 @@ class TestTZInfo(unittest.TestCase):
def __init__(self, offset, name):
self.__offset = offset
self.__name = name
- self.assertTrue(issubclass(NotEnough, tzinfo))
+ self.assertIsSubclass(NotEnough, tzinfo)
ne = NotEnough(3, "NotByALongShot")
self.assertIsInstance(ne, tzinfo)
@@ -232,7 +232,7 @@ class TestTZInfo(unittest.TestCase):
self.assertIs(type(derived), otype)
self.assertEqual(derived.utcoffset(None), offset)
self.assertEqual(derived.tzname(None), oname)
- self.assertFalse(hasattr(derived, 'spam'))
+ self.assertNotHasAttr(derived, 'spam')
def test_issue23600(self):
DSTDIFF = DSTOFFSET = timedelta(hours=1)
@@ -773,6 +773,9 @@ class TestTimeDelta(HarmlessMixedComparison, unittest.TestCase):
microseconds=999999)),
"999999999 days, 23:59:59.999999")
+ # test the Doc/library/datetime.rst recipe
+ eq(f'-({-td(hours=-1)!s})', "-(1:00:00)")
+
def test_repr(self):
name = 'datetime.' + self.theclass.__name__
self.assertEqual(repr(self.theclass(1)),
@@ -810,7 +813,7 @@ class TestTimeDelta(HarmlessMixedComparison, unittest.TestCase):
# Verify td -> string -> td identity.
s = repr(td)
- self.assertTrue(s.startswith('datetime.'))
+ self.assertStartsWith(s, 'datetime.')
s = s[9:]
td2 = eval(s)
self.assertEqual(td, td2)
@@ -1228,7 +1231,7 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase):
self.theclass.today()):
# Verify dt -> string -> date identity.
s = repr(dt)
- self.assertTrue(s.startswith('datetime.'))
+ self.assertStartsWith(s, 'datetime.')
s = s[9:]
dt2 = eval(s)
self.assertEqual(dt, dt2)
@@ -2215,7 +2218,7 @@ class TestDateTime(TestDate):
self.theclass.now()):
# Verify dt -> string -> datetime identity.
s = repr(dt)
- self.assertTrue(s.startswith('datetime.'))
+ self.assertStartsWith(s, 'datetime.')
s = s[9:]
dt2 = eval(s)
self.assertEqual(dt, dt2)
@@ -2969,6 +2972,17 @@ class TestDateTime(TestDate):
with self._assertNotWarns(DeprecationWarning):
self.theclass.strptime('02-29,2024', '%m-%d,%Y')
+ def test_strptime_z_empty(self):
+ for directive in ('z',):
+ string = '2025-04-25 11:42:47'
+ format = f'%Y-%m-%d %H:%M:%S%{directive}'
+ target = self.theclass(2025, 4, 25, 11, 42, 47)
+ with self.subTest(string=string,
+ format=format,
+ target=target):
+ result = self.theclass.strptime(string, format)
+ self.assertEqual(result, target)
+
def test_more_timetuple(self):
# This tests fields beyond those tested by the TestDate.test_timetuple.
t = self.theclass(2004, 12, 31, 6, 22, 33)
@@ -3568,6 +3582,10 @@ class TestDateTime(TestDate):
'2009-04-19T12:30:45.400 +02:30', # Space between ms and timezone (gh-130959)
'2009-04-19T12:30:45.400 ', # Trailing space (gh-130959)
'2009-04-19T12:30:45. 400', # Space before fraction (gh-130959)
+ '2009-04-19T12:30:45+00:90:00', # Time zone field out of range
+ '2009-04-19T12:30:45+00:00:90', # Time zone field out of range
+ '2009-04-19T12:30:45-00:90:00', # Time zone field out of range
+ '2009-04-19T12:30:45-00:00:90', # Time zone field out of range
]
for bad_str in bad_strs:
@@ -3669,7 +3687,7 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase):
# Verify t -> string -> time identity.
s = repr(t)
- self.assertTrue(s.startswith('datetime.'))
+ self.assertStartsWith(s, 'datetime.')
s = s[9:]
t2 = eval(s)
self.assertEqual(t, t2)
@@ -4792,6 +4810,11 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase):
'12:30:45.400 +02:30', # Space between ms and timezone (gh-130959)
'12:30:45.400 ', # Trailing space (gh-130959)
'12:30:45. 400', # Space before fraction (gh-130959)
+ '24:00:00.000001', # Has non-zero microseconds on 24:00
+ '24:00:01.000000', # Has non-zero seconds on 24:00
+ '24:01:00.000000', # Has non-zero minutes on 24:00
+ '12:30:45+00:90:00', # Time zone field out of range
+ '12:30:45+00:00:90', # Time zone field out of range
]
for bad_str in bad_strs:
diff --git a/Lib/test/libregrtest/main.py b/Lib/test/libregrtest/main.py
index 713cbedb299..a2d01b157ac 100644
--- a/Lib/test/libregrtest/main.py
+++ b/Lib/test/libregrtest/main.py
@@ -190,6 +190,12 @@ class Regrtest:
strip_py_suffix(tests)
+ exclude_tests = set()
+ if self.exclude:
+ for arg in self.cmdline_args:
+ exclude_tests.add(arg)
+ self.cmdline_args = []
+
if self.pgo:
# add default PGO tests if no tests are specified
setup_pgo_tests(self.cmdline_args, self.pgo_extended)
@@ -200,17 +206,15 @@ class Regrtest:
if self.tsan_parallel:
setup_tsan_parallel_tests(self.cmdline_args)
- exclude_tests = set()
- if self.exclude:
- for arg in self.cmdline_args:
- exclude_tests.add(arg)
- self.cmdline_args = []
-
alltests = findtests(testdir=self.test_dir,
exclude=exclude_tests)
if not self.fromfile:
selected = tests or self.cmdline_args
+ if exclude_tests:
+ # Support "--pgo/--tsan -x test_xxx" command
+ selected = [name for name in selected
+ if name not in exclude_tests]
if selected:
selected = split_test_packages(selected)
else:
@@ -543,8 +547,6 @@ class Regrtest:
self.first_runtests = runtests
self.logger.set_tests(runtests)
- setup_process()
-
if (runtests.hunt_refleak is not None) and (not self.num_workers):
# gh-109739: WindowsLoadTracker thread interferes with refleak check
use_load_tracker = False
@@ -721,10 +723,7 @@ class Regrtest:
self._execute_python(cmd, environ)
def _init(self):
- # Set sys.stdout encoder error handler to backslashreplace,
- # similar to sys.stderr error handler, to avoid UnicodeEncodeError
- # when printing a traceback or any other non-encodable character.
- sys.stdout.reconfigure(errors="backslashreplace")
+ setup_process()
if self.junit_filename and not os.path.isabs(self.junit_filename):
self.junit_filename = os.path.abspath(self.junit_filename)
diff --git a/Lib/test/libregrtest/setup.py b/Lib/test/libregrtest/setup.py
index c0346aa934d..9bfc414cd61 100644
--- a/Lib/test/libregrtest/setup.py
+++ b/Lib/test/libregrtest/setup.py
@@ -1,5 +1,6 @@
import faulthandler
import gc
+import io
import os
import random
import signal
@@ -40,7 +41,7 @@ def setup_process() -> None:
faulthandler.enable(all_threads=True, file=stderr_fd)
# Display the Python traceback on SIGALRM or SIGUSR1 signal
- signals = []
+ signals: list[signal.Signals] = []
if hasattr(signal, 'SIGALRM'):
signals.append(signal.SIGALRM)
if hasattr(signal, 'SIGUSR1'):
@@ -52,6 +53,14 @@ def setup_process() -> None:
support.record_original_stdout(sys.stdout)
+ # Set sys.stdout encoder error handler to backslashreplace,
+ # similar to sys.stderr error handler, to avoid UnicodeEncodeError
+ # when printing a traceback or any other non-encodable character.
+ #
+ # Use an assertion to fix mypy error.
+ assert isinstance(sys.stdout, io.TextIOWrapper)
+ sys.stdout.reconfigure(errors="backslashreplace")
+
# Some times __path__ and __file__ are not absolute (e.g. while running from
# Lib/) and, if we change the CWD to run the tests in a temporary dir, some
# imports might fail. This affects only the modules imported before os.chdir().
diff --git a/Lib/test/libregrtest/single.py b/Lib/test/libregrtest/single.py
index 57d7b649d2e..958a915626a 100644
--- a/Lib/test/libregrtest/single.py
+++ b/Lib/test/libregrtest/single.py
@@ -283,7 +283,7 @@ def _runtest(result: TestResult, runtests: RunTests) -> None:
try:
setup_tests(runtests)
- if output_on_failure:
+ if output_on_failure or runtests.pgo:
support.verbose = True
stream = io.StringIO()
diff --git a/Lib/test/libregrtest/tsan.py b/Lib/test/libregrtest/tsan.py
index d984a735bdf..3545c5f999f 100644
--- a/Lib/test/libregrtest/tsan.py
+++ b/Lib/test/libregrtest/tsan.py
@@ -8,7 +8,7 @@ TSAN_TESTS = [
'test_capi.test_pyatomic',
'test_code',
'test_ctypes',
- # 'test_concurrent_futures', # gh-130605: too many data races
+ 'test_concurrent_futures',
'test_enum',
'test_functools',
'test_httpservers',
diff --git a/Lib/test/libregrtest/utils.py b/Lib/test/libregrtest/utils.py
index c4a1506c9a7..72b8ea89e62 100644
--- a/Lib/test/libregrtest/utils.py
+++ b/Lib/test/libregrtest/utils.py
@@ -31,7 +31,7 @@ WORKER_WORK_DIR_PREFIX = WORK_DIR_PREFIX + 'worker_'
EXIT_TIMEOUT = 120.0
-ALL_RESOURCES = ('audio', 'curses', 'largefile', 'network',
+ALL_RESOURCES = ('audio', 'console', 'curses', 'largefile', 'network',
'decimal', 'cpu', 'subprocess', 'urlfetch', 'gui', 'walltime')
# Other resources excluded from --use=all:
@@ -335,43 +335,11 @@ def get_build_info():
build.append('with_assert')
# --enable-experimental-jit
- tier2 = re.search('-D_Py_TIER2=([0-9]+)', cflags)
- if tier2:
- tier2 = int(tier2.group(1))
-
- if not sys.flags.ignore_environment:
- PYTHON_JIT = os.environ.get('PYTHON_JIT', None)
- if PYTHON_JIT:
- PYTHON_JIT = (PYTHON_JIT != '0')
- else:
- PYTHON_JIT = None
-
- if tier2 == 1: # =yes
- if PYTHON_JIT == False:
- jit = 'JIT=off'
- else:
- jit = 'JIT'
- elif tier2 == 3: # =yes-off
- if PYTHON_JIT:
- jit = 'JIT'
+ if sys._jit.is_available():
+ if sys._jit.is_enabled():
+ build.append("JIT")
else:
- jit = 'JIT=off'
- elif tier2 == 4: # =interpreter
- if PYTHON_JIT == False:
- jit = 'JIT-interpreter=off'
- else:
- jit = 'JIT-interpreter'
- elif tier2 == 6: # =interpreter-off (Secret option!)
- if PYTHON_JIT:
- jit = 'JIT-interpreter'
- else:
- jit = 'JIT-interpreter=off'
- elif '-D_Py_JIT' in cflags:
- jit = 'JIT'
- else:
- jit = None
- if jit:
- build.append(jit)
+ build.append("JIT (disabled)")
# --enable-framework=name
framework = sysconfig.get_config_var('PYTHONFRAMEWORK')
diff --git a/Lib/test/lock_tests.py b/Lib/test/lock_tests.py
index 009e04e9c0b..691029a1a54 100644
--- a/Lib/test/lock_tests.py
+++ b/Lib/test/lock_tests.py
@@ -124,6 +124,11 @@ class BaseLockTests(BaseTestCase):
lock = self.locktype()
del lock
+ def test_constructor_noargs(self):
+ self.assertRaises(TypeError, self.locktype, 1)
+ self.assertRaises(TypeError, self.locktype, x=1)
+ self.assertRaises(TypeError, self.locktype, 1, x=2)
+
def test_repr(self):
lock = self.locktype()
self.assertRegex(repr(lock), "<unlocked .* object (.*)?at .*>")
@@ -332,6 +337,26 @@ class RLockTests(BaseLockTests):
"""
Tests for recursive locks.
"""
+ def test_repr_count(self):
+ # see gh-134322: check that count values are correct:
+ # when an RLock has just been created, and
+ # in a second thread while the RLock is held by the main thread.
+ lock = self.locktype()
+ self.assertIn("count=0", repr(lock))
+ self.assertIn("<unlocked", repr(lock))
+ lock.acquire()
+ lock.acquire()
+ self.assertIn("count=2", repr(lock))
+ self.assertIn("<locked", repr(lock))
+
+ result = []
+ def call_repr():
+ result.append(repr(lock))
+ with Bunch(call_repr, 1):
+ pass
+ self.assertIn("count=2", result[0])
+ self.assertIn("<locked", result[0])
+
def test_reacquire(self):
lock = self.locktype()
lock.acquire()
@@ -365,6 +390,24 @@ class RLockTests(BaseLockTests):
lock.release()
self.assertFalse(lock.locked())
+ def test_locked_with_2threads(self):
+ # see gh-134323: check that an RLock that
+ # is acquired in a different thread
+ # is still reported as locked in the main thread.
+ result = []
+ rlock = self.locktype()
+ self.assertFalse(rlock.locked())
+ def acquire():
+ result.append(rlock.locked())
+ rlock.acquire()
+ result.append(rlock.locked())
+
+ with Bunch(acquire, 1):
+ pass
+ self.assertTrue(rlock.locked())
+ self.assertFalse(result[0])
+ self.assertTrue(result[1])
+
def test_release_save_unacquired(self):
# Cannot _release_save an unacquired lock
lock = self.locktype()
diff --git a/Lib/test/mapping_tests.py b/Lib/test/mapping_tests.py
index 9d38da5a86e..20306e1526d 100644
--- a/Lib/test/mapping_tests.py
+++ b/Lib/test/mapping_tests.py
@@ -70,8 +70,8 @@ class BasicTestMappingProtocol(unittest.TestCase):
if not d: self.fail("Full mapping must compare to True")
# keys(), items(), iterkeys() ...
def check_iterandlist(iter, lst, ref):
- self.assertTrue(hasattr(iter, '__next__'))
- self.assertTrue(hasattr(iter, '__iter__'))
+ self.assertHasAttr(iter, '__next__')
+ self.assertHasAttr(iter, '__iter__')
x = list(iter)
self.assertTrue(set(x)==set(lst)==set(ref))
check_iterandlist(iter(d.keys()), list(d.keys()),
diff --git a/Lib/test/mp_preload_flush.py b/Lib/test/mp_preload_flush.py
new file mode 100644
index 00000000000..3501554d366
--- /dev/null
+++ b/Lib/test/mp_preload_flush.py
@@ -0,0 +1,15 @@
+import multiprocessing
+import sys
+
+modname = 'preloaded_module'
+if __name__ == '__main__':
+ if modname in sys.modules:
+ raise AssertionError(f'{modname!r} is already in sys.modules')
+ multiprocessing.set_start_method('forkserver')
+ multiprocessing.set_forkserver_preload([modname])
+ for _ in range(2):
+ p = multiprocessing.Process()
+ p.start()
+ p.join()
+elif modname not in sys.modules:
+ raise AssertionError(f'{modname!r} is not in sys.modules')
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py
index bdc7ef62943..9a3a26a8400 100644
--- a/Lib/test/pickletester.py
+++ b/Lib/test/pickletester.py
@@ -1100,6 +1100,11 @@ class AbstractUnpickleTests:
self.check_unpickling_error((pickle.UnpicklingError, OverflowError),
dumped)
+ def test_large_binstring(self):
+ errmsg = 'BINSTRING pickle has negative byte count'
+ with self.assertRaisesRegex(pickle.UnpicklingError, errmsg):
+ self.loads(b'T\0\0\0\x80')
+
def test_get(self):
pickled = b'((lp100000\ng100000\nt.'
unpickled = self.loads(pickled)
@@ -2272,7 +2277,11 @@ class AbstractPicklingErrorTests:
def test_nested_lookup_error(self):
# Nested name does not exist
- obj = REX('AbstractPickleTests.spam')
+ global TestGlobal
+ class TestGlobal:
+ class A:
+ pass
+ obj = REX('TestGlobal.A.B.C')
obj.__module__ = __name__
for proto in protocols:
with self.subTest(proto=proto):
@@ -2280,9 +2289,9 @@ class AbstractPicklingErrorTests:
self.dumps(obj, proto)
self.assertEqual(str(cm.exception),
f"Can't pickle {obj!r}: "
- f"it's not found as {__name__}.AbstractPickleTests.spam")
+ f"it's not found as {__name__}.TestGlobal.A.B.C")
self.assertEqual(str(cm.exception.__context__),
- "type object 'AbstractPickleTests' has no attribute 'spam'")
+ "type object 'A' has no attribute 'B'")
obj.__module__ = None
for proto in protocols:
@@ -2290,21 +2299,25 @@ class AbstractPicklingErrorTests:
with self.assertRaises(pickle.PicklingError) as cm:
self.dumps(obj, proto)
self.assertEqual(str(cm.exception),
- f"Can't pickle {obj!r}: it's not found as __main__.AbstractPickleTests.spam")
+ f"Can't pickle {obj!r}: "
+ f"it's not found as __main__.TestGlobal.A.B.C")
self.assertEqual(str(cm.exception.__context__),
- "module '__main__' has no attribute 'AbstractPickleTests'")
+ "module '__main__' has no attribute 'TestGlobal'")
def test_wrong_object_lookup_error(self):
# Name is bound to different object
- obj = REX('AbstractPickleTests')
+ global TestGlobal
+ class TestGlobal:
+ pass
+ obj = REX('TestGlobal')
obj.__module__ = __name__
- AbstractPickleTests.ham = []
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError) as cm:
self.dumps(obj, proto)
self.assertEqual(str(cm.exception),
- f"Can't pickle {obj!r}: it's not the same object as {__name__}.AbstractPickleTests")
+ f"Can't pickle {obj!r}: "
+ f"it's not the same object as {__name__}.TestGlobal")
self.assertIsNone(cm.exception.__context__)
obj.__module__ = None
@@ -2313,9 +2326,10 @@ class AbstractPicklingErrorTests:
with self.assertRaises(pickle.PicklingError) as cm:
self.dumps(obj, proto)
self.assertEqual(str(cm.exception),
- f"Can't pickle {obj!r}: it's not found as __main__.AbstractPickleTests")
+ f"Can't pickle {obj!r}: "
+ f"it's not found as __main__.TestGlobal")
self.assertEqual(str(cm.exception.__context__),
- "module '__main__' has no attribute 'AbstractPickleTests'")
+ "module '__main__' has no attribute 'TestGlobal'")
def test_local_lookup_error(self):
# Test that whichmodule() errors out cleanly when looking up
@@ -3059,7 +3073,7 @@ class AbstractPickleTests:
pickled = self.dumps(None, proto)
if proto >= 2:
proto_header = pickle.PROTO + bytes([proto])
- self.assertTrue(pickled.startswith(proto_header))
+ self.assertStartsWith(pickled, proto_header)
else:
self.assertEqual(count_opcode(pickle.PROTO, pickled), 0)
@@ -4998,7 +5012,7 @@ class AbstractDispatchTableTests:
p = self.pickler_class(f, 0)
with self.assertRaises(AttributeError):
p.dispatch_table
- self.assertFalse(hasattr(p, 'dispatch_table'))
+ self.assertNotHasAttr(p, 'dispatch_table')
def test_class_dispatch_table(self):
# A dispatch_table attribute can be specified class-wide
diff --git a/Lib/test/pythoninfo.py b/Lib/test/pythoninfo.py
index 682815c3fdd..80a262c18a5 100644
--- a/Lib/test/pythoninfo.py
+++ b/Lib/test/pythoninfo.py
@@ -658,6 +658,16 @@ def collect_zlib(info_add):
copy_attributes(info_add, zlib, 'zlib.%s', attributes)
+def collect_zstd(info_add):
+ try:
+ import _zstd
+ except ImportError:
+ return
+
+ attributes = ('zstd_version',)
+ copy_attributes(info_add, _zstd, 'zstd.%s', attributes)
+
+
def collect_expat(info_add):
try:
from xml.parsers import expat
@@ -910,10 +920,17 @@ def collect_windows(info_add):
try:
import _winapi
- dll_path = _winapi.GetModuleFileName(sys.dllhandle)
- info_add('windows.dll_path', dll_path)
- except (ImportError, AttributeError):
+ except ImportError:
pass
+ else:
+ try:
+ dll_path = _winapi.GetModuleFileName(sys.dllhandle)
+ info_add('windows.dll_path', dll_path)
+ except AttributeError:
+ pass
+
+ call_func(info_add, 'windows.ansi_code_page', _winapi, 'GetACP')
+ call_func(info_add, 'windows.oem_code_page', _winapi, 'GetOEMCP')
# windows.version_caption: "wmic os get Caption,Version /value" command
import subprocess
@@ -1051,6 +1068,7 @@ def collect_info(info):
collect_tkinter,
collect_windows,
collect_zlib,
+ collect_zstd,
collect_libregrtest_utils,
# Collecting from tests should be last as they have side effects.
diff --git a/Lib/test/re_tests.py b/Lib/test/re_tests.py
index 85b026736ca..e50f5d52bbd 100755
--- a/Lib/test/re_tests.py
+++ b/Lib/test/re_tests.py
@@ -531,7 +531,7 @@ xyzabc
(r'a[ ]*?\ (\d+).*', 'a 10', SUCCEED, 'found', 'a 10'),
(r'a[ ]*?\ (\d+).*', 'a 10', SUCCEED, 'found', 'a 10'),
# bug 127259: \Z shouldn't depend on multiline mode
- (r'(?ms).*?x\s*\Z(.*)','xx\nx\n', SUCCEED, 'g1', ''),
+ (r'(?ms).*?x\s*\z(.*)','xx\nx\n', SUCCEED, 'g1', ''),
# bug 128899: uppercase literals under the ignorecase flag
(r'(?i)M+', 'MMM', SUCCEED, 'found', 'MMM'),
(r'(?i)m+', 'MMM', SUCCEED, 'found', 'MMM'),
diff --git a/Lib/test/subprocessdata/fd_status.py b/Lib/test/subprocessdata/fd_status.py
index d12bd95abee..90e785981ae 100644
--- a/Lib/test/subprocessdata/fd_status.py
+++ b/Lib/test/subprocessdata/fd_status.py
@@ -2,7 +2,7 @@
file descriptors on stdout.
Usage:
-fd_stats.py: check all file descriptors
+fd_status.py: check all file descriptors (up to 255)
fd_status.py fd1 fd2 ...: check only specified file descriptors
"""
@@ -18,7 +18,7 @@ if __name__ == "__main__":
_MAXFD = os.sysconf("SC_OPEN_MAX")
except:
_MAXFD = 256
- test_fds = range(0, _MAXFD)
+ test_fds = range(0, min(_MAXFD, 256))
else:
test_fds = map(int, sys.argv[1:])
for fd in test_fds:
diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py
index 24984ad81ff..fd39d3f7c95 100644
--- a/Lib/test/support/__init__.py
+++ b/Lib/test/support/__init__.py
@@ -33,7 +33,7 @@ __all__ = [
"is_resource_enabled", "requires", "requires_freebsd_version",
"requires_gil_enabled", "requires_linux_version", "requires_mac_ver",
"check_syntax_error",
- "requires_gzip", "requires_bz2", "requires_lzma",
+ "requires_gzip", "requires_bz2", "requires_lzma", "requires_zstd",
"bigmemtest", "bigaddrspacetest", "cpython_only", "get_attribute",
"requires_IEEE_754", "requires_zlib",
"has_fork_support", "requires_fork",
@@ -46,6 +46,7 @@ __all__ = [
# sys
"MS_WINDOWS", "is_jython", "is_android", "is_emscripten", "is_wasi",
"is_apple_mobile", "check_impl_detail", "unix_shell", "setswitchinterval",
+ "support_remote_exec_only",
# os
"get_pagesize",
# network
@@ -527,6 +528,13 @@ def requires_lzma(reason='requires lzma'):
lzma = None
return unittest.skipUnless(lzma, reason)
+def requires_zstd(reason='requires zstd'):
+ try:
+ from compression import zstd
+ except ImportError:
+ zstd = None
+ return unittest.skipUnless(zstd, reason)
+
def has_no_debug_ranges():
try:
import _testcapi
@@ -689,9 +697,11 @@ def sortdict(dict):
return "{%s}" % withcommas
-def run_code(code: str) -> dict[str, object]:
+def run_code(code: str, extra_names: dict[str, object] | None = None) -> dict[str, object]:
"""Run a piece of code after dedenting it, and return its global namespace."""
ns = {}
+ if extra_names:
+ ns.update(extra_names)
exec(textwrap.dedent(code), ns)
return ns
@@ -936,6 +946,31 @@ def check_sizeof(test, o, size):
% (type(o), result, size)
test.assertEqual(result, size, msg)
+def subTests(arg_names, arg_values, /, *, _do_cleanups=False):
+ """Run multiple subtests with different parameters.
+ """
+ single_param = False
+ if isinstance(arg_names, str):
+ arg_names = arg_names.replace(',',' ').split()
+ if len(arg_names) == 1:
+ single_param = True
+ arg_values = tuple(arg_values)
+ def decorator(func):
+ if isinstance(func, type):
+ raise TypeError('subTests() can only decorate methods, not classes')
+ @functools.wraps(func)
+ def wrapper(self, /, *args, **kwargs):
+ for values in arg_values:
+ if single_param:
+ values = (values,)
+ subtest_kwargs = dict(zip(arg_names, values))
+ with self.subTest(**subtest_kwargs):
+ func(self, *args, **kwargs, **subtest_kwargs)
+ if _do_cleanups:
+ self.doCleanups()
+ return wrapper
+ return decorator
+
#=======================================================================
# Decorator/context manager for running a code in a different locale,
# correctly resetting it afterwards.
@@ -1075,7 +1110,7 @@ def set_memlimit(limit: str) -> None:
global real_max_memuse
memlimit = _parse_memlimit(limit)
if memlimit < _2G - 1:
- raise ValueError('Memory limit {limit!r} too low to be useful')
+ raise ValueError(f'Memory limit {limit!r} too low to be useful')
real_max_memuse = memlimit
memlimit = min(memlimit, MAX_Py_ssize_t)
@@ -1092,7 +1127,6 @@ class _MemoryWatchdog:
self.started = False
def start(self):
- import warnings
try:
f = open(self.procfile, 'r')
except OSError as e:
@@ -2299,6 +2333,7 @@ def check_disallow_instantiation(testcase, tp, *args, **kwds):
qualname = f"{name}"
msg = f"cannot create '{re.escape(qualname)}' instances"
testcase.assertRaisesRegex(TypeError, msg, tp, *args, **kwds)
+ testcase.assertRaisesRegex(TypeError, msg, tp.__new__, tp, *args, **kwds)
def get_recursion_depth():
"""Get the recursion depth of the caller function.
@@ -2350,7 +2385,7 @@ def infinite_recursion(max_depth=None):
# very deep recursion.
max_depth = 20_000
elif max_depth < 3:
- raise ValueError("max_depth must be at least 3, got {max_depth}")
+ raise ValueError(f"max_depth must be at least 3, got {max_depth}")
depth = get_recursion_depth()
depth = max(depth - 1, 1) # Ignore infinite_recursion() frame.
limit = depth + max_depth
@@ -2648,13 +2683,9 @@ skip_on_s390x = unittest.skipIf(is_s390x, 'skipped on s390x')
Py_TRACE_REFS = hasattr(sys, 'getobjects')
-try:
- from _testinternalcapi import jit_enabled
-except ImportError:
- requires_jit_enabled = requires_jit_disabled = unittest.skip("requires _testinternalcapi")
-else:
- requires_jit_enabled = unittest.skipUnless(jit_enabled(), "requires JIT enabled")
- requires_jit_disabled = unittest.skipIf(jit_enabled(), "requires JIT disabled")
+_JIT_ENABLED = sys._jit.is_enabled()
+requires_jit_enabled = unittest.skipUnless(_JIT_ENABLED, "requires JIT enabled")
+requires_jit_disabled = unittest.skipIf(_JIT_ENABLED, "requires JIT disabled")
_BASE_COPY_SRC_DIR_IGNORED_NAMES = frozenset({
@@ -2723,7 +2754,7 @@ def iter_builtin_types():
# Fall back to making a best-effort guess.
if hasattr(object, '__flags__'):
# Look for any type object with the Py_TPFLAGS_STATIC_BUILTIN flag set.
- import datetime
+ import datetime # noqa: F401
seen = set()
for cls, subs in walk_class_hierarchy(object):
if cls in seen:
@@ -2855,36 +2886,59 @@ def iter_slot_wrappers(cls):
@contextlib.contextmanager
-def no_color():
+def force_color(color: bool):
import _colorize
from .os_helper import EnvironmentVarGuard
with (
- swap_attr(_colorize, "can_colorize", lambda file=None: False),
+ swap_attr(_colorize, "can_colorize", lambda file=None: color),
EnvironmentVarGuard() as env,
):
env.unset("FORCE_COLOR", "NO_COLOR", "PYTHON_COLORS")
- env.set("NO_COLOR", "1")
+ env.set("FORCE_COLOR" if color else "NO_COLOR", "1")
yield
+def force_colorized(func):
+ """Force the terminal to be colorized."""
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ with force_color(True):
+ return func(*args, **kwargs)
+ return wrapper
+
+
def force_not_colorized(func):
- """Force the terminal not to be colorized."""
+ """Force the terminal NOT to be colorized."""
@functools.wraps(func)
def wrapper(*args, **kwargs):
- with no_color():
+ with force_color(False):
return func(*args, **kwargs)
return wrapper
+def force_colorized_test_class(cls):
+ """Force the terminal to be colorized for the entire test class."""
+ original_setUpClass = cls.setUpClass
+
+ @classmethod
+ @functools.wraps(cls.setUpClass)
+ def new_setUpClass(cls):
+ cls.enterClassContext(force_color(True))
+ original_setUpClass()
+
+ cls.setUpClass = new_setUpClass
+ return cls
+
+
def force_not_colorized_test_class(cls):
- """Force the terminal not to be colorized for the entire test class."""
+ """Force the terminal NOT to be colorized for the entire test class."""
original_setUpClass = cls.setUpClass
@classmethod
@functools.wraps(cls.setUpClass)
def new_setUpClass(cls):
- cls.enterClassContext(no_color())
+ cls.enterClassContext(force_color(False))
original_setUpClass()
cls.setUpClass = new_setUpClass
@@ -2901,12 +2955,6 @@ def make_clean_env() -> dict[str, str]:
return clean_env
-def initialized_with_pyrepl():
- """Detect whether PyREPL was used during Python initialization."""
- # If the main module has a __file__ attribute it's a Python module, which means PyREPL.
- return hasattr(sys.modules["__main__"], "__file__")
-
-
WINDOWS_STATUS = {
0xC0000005: "STATUS_ACCESS_VIOLATION",
0xC00000FD: "STATUS_STACK_OVERFLOW",
@@ -3023,6 +3071,27 @@ def is_libssl_fips_mode():
return False # more of a maybe, unless we add this to the _ssl module.
return get_fips_mode() != 0
+def _supports_remote_attaching():
+ PROCESS_VM_READV_SUPPORTED = False
+
+ try:
+ from _remote_debugging import PROCESS_VM_READV_SUPPORTED
+ except ImportError:
+ pass
+
+ return PROCESS_VM_READV_SUPPORTED
+
+def _support_remote_exec_only_impl():
+ if not sys.is_remote_debug_enabled():
+ return unittest.skip("Remote debugging is not enabled")
+ if sys.platform not in ("darwin", "linux", "win32"):
+ return unittest.skip("Test only runs on Linux, Windows and macOS")
+ if sys.platform == "linux" and not _supports_remote_attaching():
+ return unittest.skip("Test only runs on Linux with process_vm_readv support")
+ return _id
+
+def support_remote_exec_only(test):
+ return _support_remote_exec_only_impl()(test)
class EqualToForwardRef:
"""Helper to ease use of annotationlib.ForwardRef in tests.
diff --git a/Lib/test/support/interpreters/channels.py b/Lib/test/support/channels.py
index d2bd93d77f7..b2de24d9d3e 100644
--- a/Lib/test/support/interpreters/channels.py
+++ b/Lib/test/support/channels.py
@@ -2,14 +2,14 @@
import time
import _interpchannels as _channels
-from . import _crossinterp
+from concurrent.interpreters import _crossinterp
# aliases:
from _interpchannels import (
- ChannelError, ChannelNotFoundError, ChannelClosedError,
- ChannelEmptyError, ChannelNotEmptyError,
+ ChannelError, ChannelNotFoundError, ChannelClosedError, # noqa: F401
+ ChannelEmptyError, ChannelNotEmptyError, # noqa: F401
)
-from ._crossinterp import (
+from concurrent.interpreters._crossinterp import (
UNBOUND_ERROR, UNBOUND_REMOVE,
)
@@ -55,15 +55,23 @@ def create(*, unbounditems=UNBOUND):
"""
unbound = _serialize_unbound(unbounditems)
unboundop, = unbound
- cid = _channels.create(unboundop)
- recv, send = RecvChannel(cid), SendChannel(cid, _unbound=unbound)
+ cid = _channels.create(unboundop, -1)
+ recv, send = RecvChannel(cid), SendChannel(cid)
+ send._set_unbound(unboundop, unbounditems)
return recv, send
def list_all():
"""Return a list of (recv, send) for all open channels."""
- return [(RecvChannel(cid), SendChannel(cid, _unbound=unbound))
- for cid, unbound in _channels.list_all()]
+ channels = []
+ for cid, unboundop, _ in _channels.list_all():
+ chan = _, send = RecvChannel(cid), SendChannel(cid)
+ if not hasattr(send, '_unbound'):
+ send._set_unbound(unboundop)
+ else:
+ assert send._unbound[0] == unboundop
+ channels.append(chan)
+ return channels
class _ChannelEnd:
@@ -175,16 +183,33 @@ class SendChannel(_ChannelEnd):
_end = 'send'
- def __new__(cls, cid, *, _unbound=None):
- if _unbound is None:
- try:
- op = _channels.get_channel_defaults(cid)
- _unbound = (op,)
- except ChannelNotFoundError:
- _unbound = _serialize_unbound(UNBOUND)
- self = super().__new__(cls, cid)
- self._unbound = _unbound
- return self
+# def __new__(cls, cid, *, _unbound=None):
+# if _unbound is None:
+# try:
+# op = _channels.get_channel_defaults(cid)
+# _unbound = (op,)
+# except ChannelNotFoundError:
+# _unbound = _serialize_unbound(UNBOUND)
+# self = super().__new__(cls, cid)
+# self._unbound = _unbound
+# return self
+
+ def _set_unbound(self, op, items=None):
+ assert not hasattr(self, '_unbound')
+ if items is None:
+ items = _resolve_unbound(op)
+ unbound = (op, items)
+ self._unbound = unbound
+ return unbound
+
+ @property
+ def unbounditems(self):
+ try:
+ _, items = self._unbound
+ except AttributeError:
+ op, _ = _channels.get_channel_defaults(self._id)  # NOTE(review): was get_queue_defaults, a likely copy-paste from queues; confirm the return shape still unpacks to (op, _)
+ _, items = self._set_unbound(op)
+ return items
@property
def is_closed(self):
@@ -192,61 +217,61 @@ class SendChannel(_ChannelEnd):
return info.closed or info.closing
def send(self, obj, timeout=None, *,
- unbound=None,
+ unbounditems=None,
):
"""Send the object (i.e. its data) to the channel's receiving end.
This blocks until the object is received.
"""
- if unbound is None:
- unboundop, = self._unbound
+ if unbounditems is None:
+ unboundop = -1
else:
- unboundop, = _serialize_unbound(unbound)
+ unboundop, = _serialize_unbound(unbounditems)
_channels.send(self._id, obj, unboundop, timeout=timeout, blocking=True)
def send_nowait(self, obj, *,
- unbound=None,
+ unbounditems=None,
):
"""Send the object to the channel's receiving end.
If the object is immediately received then return True
(else False). Otherwise this is the same as send().
"""
- if unbound is None:
- unboundop, = self._unbound
+ if unbounditems is None:
+ unboundop = -1
else:
- unboundop, = _serialize_unbound(unbound)
+ unboundop, = _serialize_unbound(unbounditems)
# XXX Note that at the moment channel_send() only ever returns
# None. This should be fixed when channel_send_wait() is added.
# See bpo-32604 and gh-19829.
return _channels.send(self._id, obj, unboundop, blocking=False)
def send_buffer(self, obj, timeout=None, *,
- unbound=None,
+ unbounditems=None,
):
"""Send the object's buffer to the channel's receiving end.
This blocks until the object is received.
"""
- if unbound is None:
- unboundop, = self._unbound
+ if unbounditems is None:
+ unboundop = -1
else:
- unboundop, = _serialize_unbound(unbound)
+ unboundop, = _serialize_unbound(unbounditems)
_channels.send_buffer(self._id, obj, unboundop,
timeout=timeout, blocking=True)
def send_buffer_nowait(self, obj, *,
- unbound=None,
+ unbounditems=None,
):
"""Send the object's buffer to the channel's receiving end.
If the object is immediately received then return True
(else False). Otherwise this is the same as send().
"""
- if unbound is None:
- unboundop, = self._unbound
+ if unbounditems is None:
+ unboundop = -1
else:
- unboundop, = _serialize_unbound(unbound)
+ unboundop, = _serialize_unbound(unbounditems)
return _channels.send_buffer(self._id, obj, unboundop, blocking=False)
def close(self):
diff --git a/Lib/test/support/hashlib_helper.py b/Lib/test/support/hashlib_helper.py
index 5043f08dd93..7032257b068 100644
--- a/Lib/test/support/hashlib_helper.py
+++ b/Lib/test/support/hashlib_helper.py
@@ -23,6 +23,22 @@ def requires_builtin_hmac():
return unittest.skipIf(_hmac is None, "requires _hmac")
+def _missing_hash(digestname, implementation=None, *, exc=None):
+ parts = ["missing", implementation, f"hash algorithm: {digestname!r}"]
+ msg = " ".join(filter(None, parts))
+ raise unittest.SkipTest(msg) from exc
+
+
+def _openssl_availabillity(digestname, *, usedforsecurity):
+ try:
+ _hashlib.new(digestname, usedforsecurity=usedforsecurity)
+ except AttributeError:
+ assert _hashlib is None
+ _missing_hash(digestname, "OpenSSL")
+ except ValueError as exc:
+ _missing_hash(digestname, "OpenSSL", exc=exc)
+
+
def _decorate_func_or_class(func_or_class, decorator_func):
if not isinstance(func_or_class, type):
return decorator_func(func_or_class)
@@ -71,8 +87,7 @@ def requires_hashdigest(digestname, openssl=None, usedforsecurity=True):
try:
test_availability()
except ValueError as exc:
- msg = f"missing hash algorithm: {digestname!r}"
- raise unittest.SkipTest(msg) from exc
+ _missing_hash(digestname, exc=exc)
return func(*args, **kwargs)
return wrapper
@@ -87,14 +102,44 @@ def requires_openssl_hashdigest(digestname, *, usedforsecurity=True):
The hashing algorithm may be missing or blocked by a strict crypto policy.
"""
def decorator_func(func):
- @requires_hashlib()
+ @requires_hashlib() # avoid checking at each call
@functools.wraps(func)
def wrapper(*args, **kwargs):
+ _openssl_availabillity(digestname, usedforsecurity=usedforsecurity)
+ return func(*args, **kwargs)
+ return wrapper
+
+ def decorator(func_or_class):
+ return _decorate_func_or_class(func_or_class, decorator_func)
+ return decorator
+
+
+def find_openssl_hashdigest_constructor(digestname, *, usedforsecurity=True):
+ """Find the OpenSSL hash function constructor by its name."""
+ assert isinstance(digestname, str), digestname
+ _openssl_availabillity(digestname, usedforsecurity=usedforsecurity)
+ # This returns a function of the form _hashlib.openssl_<name> and
+ # not a lambda function as it is rejected by _hashlib.hmac_new().
+ return getattr(_hashlib, f"openssl_{digestname}")
+
+
+def requires_builtin_hashdigest(
+ module_name, digestname, *, usedforsecurity=True
+):
+ """Decorator raising SkipTest if a HACL* hashing algorithm is missing.
+
+ - The *module_name* is the C extension module name based on HACL*.
+ - The *digestname* is one of its members, e.g., 'md5'.
+ """
+ def decorator_func(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ module = import_module(module_name)
try:
- _hashlib.new(digestname, usedforsecurity=usedforsecurity)
- except ValueError:
- msg = f"missing OpenSSL hash algorithm: {digestname!r}"
- raise unittest.SkipTest(msg)
+ getattr(module, digestname)
+ except AttributeError:
+ fullname = f'{module_name}.{digestname}'
+ _missing_hash(fullname, implementation="HACL")
return func(*args, **kwargs)
return wrapper
@@ -103,6 +148,168 @@ def requires_openssl_hashdigest(digestname, *, usedforsecurity=True):
return decorator
+def find_builtin_hashdigest_constructor(
+ module_name, digestname, *, usedforsecurity=True
+):
+ """Find the HACL* hash function constructor.
+
+ - The *module_name* is the C extension module name based on HACL*.
+ - The *digestname* is one of its members, e.g., 'md5'.
+ """
+ module = import_module(module_name)
+ try:
+ constructor = getattr(module, digestname)
+ constructor(b'', usedforsecurity=usedforsecurity)
+ except (AttributeError, TypeError, ValueError):
+ _missing_hash(f'{module_name}.{digestname}', implementation="HACL")
+ return constructor
+
+
+class HashFunctionsTrait:
+ """Mixin trait class containing hash functions.
+
+ This class is assumed to have all unittest.TestCase methods but should
+ not directly inherit from it to prevent the test suite being run on it.
+
+ Subclasses should implement the hash functions by returning an object
+ that can be recognized as a valid digestmod parameter for both hashlib
+ and HMAC. In particular, it cannot be a lambda function as it will not
+ be recognized by hashlib (it will still be accepted by the pure Python
+ implementation of HMAC).
+ """
+
+ ALGORITHMS = [
+ 'md5', 'sha1',
+ 'sha224', 'sha256', 'sha384', 'sha512',
+ 'sha3_224', 'sha3_256', 'sha3_384', 'sha3_512',
+ ]
+
+ # Default 'usedforsecurity' to use when looking up a hash function.
+ usedforsecurity = True
+
+ def _find_constructor(self, name):
+ # By default, a missing algorithm skips the test that uses it.
+ self.assertIn(name, self.ALGORITHMS)
+ self.skipTest(f"missing hash function: {name}")
+
+ @property
+ def md5(self):
+ return self._find_constructor("md5")
+
+ @property
+ def sha1(self):
+ return self._find_constructor("sha1")
+
+ @property
+ def sha224(self):
+ return self._find_constructor("sha224")
+
+ @property
+ def sha256(self):
+ return self._find_constructor("sha256")
+
+ @property
+ def sha384(self):
+ return self._find_constructor("sha384")
+
+ @property
+ def sha512(self):
+ return self._find_constructor("sha512")
+
+ @property
+ def sha3_224(self):
+ return self._find_constructor("sha3_224")
+
+ @property
+ def sha3_256(self):
+ return self._find_constructor("sha3_256")
+
+ @property
+ def sha3_384(self):
+ return self._find_constructor("sha3_384")
+
+ @property
+ def sha3_512(self):
+ return self._find_constructor("sha3_512")
+
+
+class NamedHashFunctionsTrait(HashFunctionsTrait):
+ """Trait containing named hash functions.
+
+ Hash functions are available if and only if they are available in hashlib.
+ """
+
+ def _find_constructor(self, name):
+ self.assertIn(name, self.ALGORITHMS)
+ return name
+
+
+class OpenSSLHashFunctionsTrait(HashFunctionsTrait):
+ """Trait containing OpenSSL hash functions.
+
+ Hash functions are available if and only if they are available in _hashlib.
+ """
+
+ def _find_constructor(self, name):
+ self.assertIn(name, self.ALGORITHMS)
+ return find_openssl_hashdigest_constructor(
+ name, usedforsecurity=self.usedforsecurity
+ )
+
+
+class BuiltinHashFunctionsTrait(HashFunctionsTrait):
+ """Trait containing HACL* hash functions.
+
+ Hash functions are available if and only if they are available in C.
+ In particular, HACL* HMAC-MD5 may be available even though HACL* md5
+ is not since the former is unconditionally built.
+ """
+
+ def _find_constructor_in(self, module, name):
+ self.assertIn(name, self.ALGORITHMS)
+ return find_builtin_hashdigest_constructor(module, name)
+
+ @property
+ def md5(self):
+ return self._find_constructor_in("_md5", "md5")
+
+ @property
+ def sha1(self):
+ return self._find_constructor_in("_sha1", "sha1")
+
+ @property
+ def sha224(self):
+ return self._find_constructor_in("_sha2", "sha224")
+
+ @property
+ def sha256(self):
+ return self._find_constructor_in("_sha2", "sha256")
+
+ @property
+ def sha384(self):
+ return self._find_constructor_in("_sha2", "sha384")
+
+ @property
+ def sha512(self):
+ return self._find_constructor_in("_sha2", "sha512")
+
+ @property
+ def sha3_224(self):
+ return self._find_constructor_in("_sha3", "sha3_224")
+
+ @property
+ def sha3_256(self):
+ return self._find_constructor_in("_sha3", "sha3_256")
+
+ @property
+ def sha3_384(self):
+ return self._find_constructor_in("_sha3", "sha3_384")
+
+ @property
+ def sha3_512(self):
+ return self._find_constructor_in("_sha3", "sha3_512")
+
+
def find_gil_minsize(modules_names, default=2048):
"""Get the largest GIL_MINSIZE value for the given cryptographic modules.
diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py
index edb734d294f..0af63501f93 100644
--- a/Lib/test/support/import_helper.py
+++ b/Lib/test/support/import_helper.py
@@ -438,5 +438,5 @@ def ensure_module_imported(name, *, clearnone=True):
if sys.modules.get(name) is not None:
mod = sys.modules[name]
else:
- mod, _, _ = _force_import(name, False, True, clearnone)
+ mod, _, _ = _ensure_module(name, False, True, clearnone)
return mod
diff --git a/Lib/test/support/strace_helper.py b/Lib/test/support/strace_helper.py
index 798d6c68869..cf95f7bdc7d 100644
--- a/Lib/test/support/strace_helper.py
+++ b/Lib/test/support/strace_helper.py
@@ -38,7 +38,7 @@ class StraceResult:
This assumes the program under inspection doesn't print any non-utf8
strings which would mix into the strace output."""
- decoded_events = self.event_bytes.decode('utf-8')
+ decoded_events = self.event_bytes.decode('utf-8', 'surrogateescape')
matches = [
_syscall_regex.match(event)
for event in decoded_events.splitlines()
@@ -178,7 +178,10 @@ def get_syscalls(code, strace_flags, prelude="", cleanup="",
# Moderately expensive (spawns a subprocess), so share results when possible.
@cache
def _can_strace():
- res = strace_python("import sys; sys.exit(0)", [], check=False)
+ res = strace_python("import sys; sys.exit(0)",
+ # --trace option needs strace 5.5 (gh-133741)
+ ["--trace=%process"],
+ check=False)
if res.strace_returncode == 0 and res.python_returncode == 0:
assert res.events(), "Should have parsed multiple calls"
return True
diff --git a/Lib/test/support/warnings_helper.py b/Lib/test/support/warnings_helper.py
index a6e43dff200..5f6f14afd74 100644
--- a/Lib/test/support/warnings_helper.py
+++ b/Lib/test/support/warnings_helper.py
@@ -23,8 +23,7 @@ def check_syntax_warning(testcase, statement, errtext='',
testcase.assertEqual(len(warns), 1, warns)
warn, = warns
- testcase.assertTrue(issubclass(warn.category, SyntaxWarning),
- warn.category)
+ testcase.assertIsSubclass(warn.category, SyntaxWarning)
if errtext:
testcase.assertRegex(str(warn.message), errtext)
testcase.assertEqual(warn.filename, '<testcase>')
diff --git a/Lib/test/test__interpchannels.py b/Lib/test/test__interpchannels.py
index e4c1ad85451..858d31a73cf 100644
--- a/Lib/test/test__interpchannels.py
+++ b/Lib/test/test__interpchannels.py
@@ -9,7 +9,7 @@ import unittest
from test.support import import_helper, skip_if_sanitizer
_channels = import_helper.import_module('_interpchannels')
-from test.support.interpreters import _crossinterp
+from concurrent.interpreters import _crossinterp
from test.test__interpreters import (
_interpreters,
_run_output,
@@ -247,7 +247,7 @@ def _run_action(cid, action, end, state):
def clean_up_channels():
- for cid, _ in _channels.list_all():
+ for cid, _, _ in _channels.list_all():
try:
_channels.destroy(cid)
except _channels.ChannelNotFoundError:
@@ -373,11 +373,11 @@ class ChannelTests(TestBase):
self.assertIsInstance(cid, _channels.ChannelID)
def test_sequential_ids(self):
- before = [cid for cid, _ in _channels.list_all()]
+ before = [cid for cid, _, _ in _channels.list_all()]
id1 = _channels.create(REPLACE)
id2 = _channels.create(REPLACE)
id3 = _channels.create(REPLACE)
- after = [cid for cid, _ in _channels.list_all()]
+ after = [cid for cid, _, _ in _channels.list_all()]
self.assertEqual(id2, int(id1) + 1)
self.assertEqual(id3, int(id2) + 1)
diff --git a/Lib/test/test__interpreters.py b/Lib/test/test__interpreters.py
index 0c43f46300f..a32d5d81d2b 100644
--- a/Lib/test/test__interpreters.py
+++ b/Lib/test/test__interpreters.py
@@ -474,15 +474,32 @@ class CommonTests(TestBase):
def test_signatures(self):
# See https://github.com/python/cpython/issues/126654
- msg = "expected 'shared' to be a dict"
+ msg = r'_interpreters.exec\(\) argument 3 must be dict, not int'
with self.assertRaisesRegex(TypeError, msg):
_interpreters.exec(self.id, 'a', 1)
with self.assertRaisesRegex(TypeError, msg):
_interpreters.exec(self.id, 'a', shared=1)
+ msg = r'_interpreters.run_string\(\) argument 3 must be dict, not int'
with self.assertRaisesRegex(TypeError, msg):
_interpreters.run_string(self.id, 'a', shared=1)
+ msg = r'_interpreters.run_func\(\) argument 3 must be dict, not int'
with self.assertRaisesRegex(TypeError, msg):
_interpreters.run_func(self.id, lambda: None, shared=1)
+ # See https://github.com/python/cpython/issues/135855
+ msg = r'_interpreters.set___main___attrs\(\) argument 2 must be dict, not int'
+ with self.assertRaisesRegex(TypeError, msg):
+ _interpreters.set___main___attrs(self.id, 1)
+
+ def test_invalid_shared_none(self):
+ msg = r'must be dict, not None'
+ with self.assertRaisesRegex(TypeError, msg):
+ _interpreters.exec(self.id, 'a', shared=None)
+ with self.assertRaisesRegex(TypeError, msg):
+ _interpreters.run_string(self.id, 'a', shared=None)
+ with self.assertRaisesRegex(TypeError, msg):
+ _interpreters.run_func(self.id, lambda: None, shared=None)
+ with self.assertRaisesRegex(TypeError, msg):
+ _interpreters.set___main___attrs(self.id, None)
def test_invalid_shared_encoding(self):
# See https://github.com/python/cpython/issues/127196
@@ -952,7 +969,8 @@ class RunFailedTests(TestBase):
""")
with self.subTest('script'):
- self.assert_run_failed(SyntaxError, script)
+ with self.assertRaises(SyntaxError):
+ _interpreters.run_string(self.id, script)
with self.subTest('module'):
modname = 'spam_spam_spam'
@@ -1019,12 +1037,19 @@ class RunFuncTests(TestBase):
with open(w, 'w', encoding="utf-8") as spipe:
with contextlib.redirect_stdout(spipe):
print('it worked!', end='')
+ failed = None
def f():
- _interpreters.set___main___attrs(self.id, dict(w=w))
- _interpreters.run_func(self.id, script)
+ nonlocal failed
+ try:
+ _interpreters.set___main___attrs(self.id, dict(w=w))
+ _interpreters.run_func(self.id, script)
+ except Exception as exc:
+ failed = exc
t = threading.Thread(target=f)
t.start()
t.join()
+ if failed:
+ raise Exception from failed
with open(r, encoding="utf-8") as outfile:
out = outfile.read()
@@ -1053,18 +1078,16 @@ class RunFuncTests(TestBase):
spam = True
def script():
assert spam
-
with self.assertRaises(ValueError):
_interpreters.run_func(self.id, script)
- # XXX This hasn't been fixed yet.
- @unittest.expectedFailure
def test_return_value(self):
def script():
return 'spam'
with self.assertRaises(ValueError):
_interpreters.run_func(self.id, script)
+# @unittest.skip("we're not quite there yet")
def test_args(self):
with self.subTest('args'):
def script(a, b=0):
diff --git a/Lib/test/test__osx_support.py b/Lib/test/test__osx_support.py
index 53aa26620a6..0813c4804c1 100644
--- a/Lib/test/test__osx_support.py
+++ b/Lib/test/test__osx_support.py
@@ -66,8 +66,8 @@ class Test_OSXSupport(unittest.TestCase):
'cc not found - check xcode-select')
def test__get_system_version(self):
- self.assertTrue(platform.mac_ver()[0].startswith(
- _osx_support._get_system_version()))
+ self.assertStartsWith(platform.mac_ver()[0],
+ _osx_support._get_system_version())
def test__remove_original_values(self):
config_vars = {
diff --git a/Lib/test/test_abstract_numbers.py b/Lib/test/test_abstract_numbers.py
index 72232b670cd..cf071d2c933 100644
--- a/Lib/test/test_abstract_numbers.py
+++ b/Lib/test/test_abstract_numbers.py
@@ -24,11 +24,11 @@ def concretize(cls):
class TestNumbers(unittest.TestCase):
def test_int(self):
- self.assertTrue(issubclass(int, Integral))
- self.assertTrue(issubclass(int, Rational))
- self.assertTrue(issubclass(int, Real))
- self.assertTrue(issubclass(int, Complex))
- self.assertTrue(issubclass(int, Number))
+ self.assertIsSubclass(int, Integral)
+ self.assertIsSubclass(int, Rational)
+ self.assertIsSubclass(int, Real)
+ self.assertIsSubclass(int, Complex)
+ self.assertIsSubclass(int, Number)
self.assertEqual(7, int(7).real)
self.assertEqual(0, int(7).imag)
@@ -38,11 +38,11 @@ class TestNumbers(unittest.TestCase):
self.assertEqual(1, int(7).denominator)
def test_float(self):
- self.assertFalse(issubclass(float, Integral))
- self.assertFalse(issubclass(float, Rational))
- self.assertTrue(issubclass(float, Real))
- self.assertTrue(issubclass(float, Complex))
- self.assertTrue(issubclass(float, Number))
+ self.assertNotIsSubclass(float, Integral)
+ self.assertNotIsSubclass(float, Rational)
+ self.assertIsSubclass(float, Real)
+ self.assertIsSubclass(float, Complex)
+ self.assertIsSubclass(float, Number)
self.assertEqual(7.3, float(7.3).real)
self.assertEqual(0, float(7.3).imag)
@@ -50,11 +50,11 @@ class TestNumbers(unittest.TestCase):
self.assertEqual(-7.3, float(-7.3).conjugate())
def test_complex(self):
- self.assertFalse(issubclass(complex, Integral))
- self.assertFalse(issubclass(complex, Rational))
- self.assertFalse(issubclass(complex, Real))
- self.assertTrue(issubclass(complex, Complex))
- self.assertTrue(issubclass(complex, Number))
+ self.assertNotIsSubclass(complex, Integral)
+ self.assertNotIsSubclass(complex, Rational)
+ self.assertNotIsSubclass(complex, Real)
+ self.assertIsSubclass(complex, Complex)
+ self.assertIsSubclass(complex, Number)
c1, c2 = complex(3, 2), complex(4,1)
# XXX: This is not ideal, but see the comment in math_trunc().
diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py
index be55f044b15..ae0e73f08c5 100644
--- a/Lib/test/test_annotationlib.py
+++ b/Lib/test/test_annotationlib.py
@@ -1,18 +1,19 @@
"""Tests for the annotations module."""
+import textwrap
import annotationlib
import builtins
import collections
import functools
import itertools
import pickle
+from string.templatelib import Template
import typing
import unittest
from annotationlib import (
Format,
ForwardRef,
get_annotations,
- get_annotate_function,
annotations_to_string,
type_repr,
)
@@ -121,6 +122,28 @@ class TestForwardRefFormat(unittest.TestCase):
self.assertIsInstance(gamma_anno, ForwardRef)
self.assertEqual(gamma_anno, support.EqualToForwardRef("some < obj", owner=f))
+ def test_partially_nonexistent_union(self):
+ # Test unions with '|' syntax equal unions with typing.Union[] with some forwardrefs
+ class UnionForwardrefs:
+ pipe: str | undefined
+ union: Union[str, undefined]
+
+ annos = get_annotations(UnionForwardrefs, format=Format.FORWARDREF)
+
+ pipe = annos["pipe"]
+ self.assertIsInstance(pipe, ForwardRef)
+ self.assertEqual(
+ pipe.evaluate(globals={"undefined": int}),
+ str | int,
+ )
+ union = annos["union"]
+ self.assertIsInstance(union, Union)
+ arg1, arg2 = typing.get_args(union)
+ self.assertIs(arg1, str)
+ self.assertEqual(
+ arg2, support.EqualToForwardRef("undefined", is_class=True, owner=UnionForwardrefs)
+ )
+
class TestStringFormat(unittest.TestCase):
def test_closure(self):
@@ -251,6 +274,126 @@ class TestStringFormat(unittest.TestCase):
},
)
+ def test_template_str(self):
+ def f(
+ x: t"{a}",
+ y: list[t"{a}"],
+ z: t"{a:b} {c!r} {d!s:t}",
+ a: t"a{b}c{d}e{f}g",
+ b: t"{a:{1}}",
+ c: t"{a | b * c}",
+ ): pass
+
+ annos = get_annotations(f, format=Format.STRING)
+ self.assertEqual(annos, {
+ "x": "t'{a}'",
+ "y": "list[t'{a}']",
+ "z": "t'{a:b} {c!r} {d!s:t}'",
+ "a": "t'a{b}c{d}e{f}g'",
+ # interpolations in the format spec are eagerly evaluated so we can't recover the source
+ "b": "t'{a:1}'",
+ "c": "t'{a | b * c}'",
+ })
+
+ def g(
+ x: t"{a}",
+ ): ...
+
+ annos = get_annotations(g, format=Format.FORWARDREF)
+ templ = annos["x"]
+ # Template and Interpolation don't have __eq__ so we have to compare manually
+ self.assertIsInstance(templ, Template)
+ self.assertEqual(templ.strings, ("", ""))
+ self.assertEqual(len(templ.interpolations), 1)
+ interp = templ.interpolations[0]
+ self.assertEqual(interp.value, support.EqualToForwardRef("a", owner=g))
+ self.assertEqual(interp.expression, "a")
+ self.assertIsNone(interp.conversion)
+ self.assertEqual(interp.format_spec, "")
+
+ def test_getitem(self):
+ def f(x: undef1[str, undef2]):
+ pass
+ anno = get_annotations(f, format=Format.STRING)
+ self.assertEqual(anno, {"x": "undef1[str, undef2]"})
+
+ anno = get_annotations(f, format=Format.FORWARDREF)
+ fwdref = anno["x"]
+ self.assertIsInstance(fwdref, ForwardRef)
+ self.assertEqual(
+ fwdref.evaluate(globals={"undef1": dict, "undef2": float}), dict[str, float]
+ )
+
+ def test_slice(self):
+ def f(x: a[b:c]):
+ pass
+ anno = get_annotations(f, format=Format.STRING)
+ self.assertEqual(anno, {"x": "a[b:c]"})
+
+ def f(x: a[b:c, d:e]):
+ pass
+ anno = get_annotations(f, format=Format.STRING)
+ self.assertEqual(anno, {"x": "a[b:c, d:e]"})
+
+ obj = slice(1, 1, 1)
+ def f(x: obj):
+ pass
+ anno = get_annotations(f, format=Format.STRING)
+ self.assertEqual(anno, {"x": "obj"})
+
+ def test_literals(self):
+ def f(
+ a: 1,
+ b: 1.0,
+ c: "hello",
+ d: b"hello",
+ e: True,
+ f: None,
+ g: ...,
+ h: 1j,
+ ):
+ pass
+
+ anno = get_annotations(f, format=Format.STRING)
+ self.assertEqual(
+ anno,
+ {
+ "a": "1",
+ "b": "1.0",
+ "c": 'hello',
+ "d": "b'hello'",
+ "e": "True",
+ "f": "None",
+ "g": "...",
+ "h": "1j",
+ },
+ )
+
+ def test_displays(self):
+ # Simple case first
+ def f(x: a[[int, str], float]):
+ pass
+ anno = get_annotations(f, format=Format.STRING)
+ self.assertEqual(anno, {"x": "a[[int, str], float]"})
+
+ def g(
+ w: a[[int, str], float],
+ x: a[{int}, 3],
+ y: a[{int: str}, 4],
+ z: a[(int, str), 5],
+ ):
+ pass
+ anno = get_annotations(g, format=Format.STRING)
+ self.assertEqual(
+ anno,
+ {
+ "w": "a[[int, str], float]",
+ "x": "a[{int}, 3]",
+ "y": "a[{int: str}, 4]",
+ "z": "a[(int, str), 5]",
+ },
+ )
+
def test_nested_expressions(self):
def f(
nested: list[Annotated[set[int], "set of ints", 4j]],
@@ -296,6 +439,17 @@ class TestStringFormat(unittest.TestCase):
with self.assertRaisesRegex(TypeError, format_msg):
get_annotations(f, format=Format.STRING)
+ def test_shenanigans(self):
+ # In cases like this we can't reconstruct the source; test that we do something
+ # halfway reasonable.
+ def f(x: x | (1).__class__, y: (1).__class__):
+ pass
+
+ self.assertEqual(
+ get_annotations(f, format=Format.STRING),
+ {"x": "x | <class 'int'>", "y": "<class 'int'>"},
+ )
+
class TestGetAnnotations(unittest.TestCase):
def test_builtin_type(self):
@@ -661,6 +815,70 @@ class TestGetAnnotations(unittest.TestCase):
{"x": int},
)
+ def test_stringized_annotation_permutations(self):
+ def define_class(name, has_future, has_annos, base_text, extra_names=None):
+ lines = []
+ if has_future:
+ lines.append("from __future__ import annotations")
+ lines.append(f"class {name}({base_text}):")
+ if has_annos:
+ lines.append(f" {name}_attr: int")
+ else:
+ lines.append(" pass")
+ code = "\n".join(lines)
+ ns = support.run_code(code, extra_names=extra_names)
+ return ns[name]
+
+ def check_annotations(cls, has_future, has_annos):
+ if has_annos:
+ if has_future:
+ anno = "int"
+ else:
+ anno = int
+ self.assertEqual(get_annotations(cls), {f"{cls.__name__}_attr": anno})
+ else:
+ self.assertEqual(get_annotations(cls), {})
+
+ for meta_future, base_future, child_future, meta_has_annos, base_has_annos, child_has_annos in itertools.product(
+ (False, True),
+ (False, True),
+ (False, True),
+ (False, True),
+ (False, True),
+ (False, True),
+ ):
+ with self.subTest(
+ meta_future=meta_future,
+ base_future=base_future,
+ child_future=child_future,
+ meta_has_annos=meta_has_annos,
+ base_has_annos=base_has_annos,
+ child_has_annos=child_has_annos,
+ ):
+ meta = define_class(
+ "Meta",
+ has_future=meta_future,
+ has_annos=meta_has_annos,
+ base_text="type",
+ )
+ base = define_class(
+ "Base",
+ has_future=base_future,
+ has_annos=base_has_annos,
+ base_text="metaclass=Meta",
+ extra_names={"Meta": meta},
+ )
+ child = define_class(
+ "Child",
+ has_future=child_future,
+ has_annos=child_has_annos,
+ base_text="Base",
+ extra_names={"Base": base},
+ )
+ check_annotations(meta, meta_future, meta_has_annos)
+ check_annotations(base, base_future, base_has_annos)
+ check_annotations(child, child_future, child_has_annos)
+
def test_modify_annotations(self):
def f(x: int):
pass
@@ -901,6 +1119,73 @@ class TestGetAnnotations(unittest.TestCase):
set(results.generic_func.__type_params__),
)
+ def test_partial_evaluation(self):
+ def f(
+ x: builtins.undef,
+ y: list[int],
+ z: 1 + int,
+ a: builtins.int,
+ b: [builtins.undef, builtins.int],
+ ):
+ pass
+
+ self.assertEqual(
+ get_annotations(f, format=Format.FORWARDREF),
+ {
+ "x": support.EqualToForwardRef("builtins.undef", owner=f),
+ "y": list[int],
+ "z": support.EqualToForwardRef("1 + int", owner=f),
+ "a": int,
+ "b": [
+ support.EqualToForwardRef("builtins.undef", owner=f),
+ # We can't resolve this because we have to evaluate the whole annotation
+ support.EqualToForwardRef("builtins.int", owner=f),
+ ],
+ },
+ )
+
+ self.assertEqual(
+ get_annotations(f, format=Format.STRING),
+ {
+ "x": "builtins.undef",
+ "y": "list[int]",
+ "z": "1 + int",
+ "a": "builtins.int",
+ "b": "[builtins.undef, builtins.int]",
+ },
+ )
+
+ def test_partial_evaluation_error(self):
+ def f(x: range[1]):
+ pass
+ with self.assertRaisesRegex(
+ TypeError, "type 'range' is not subscriptable"
+ ):
+ f.__annotations__
+
+ self.assertEqual(
+ get_annotations(f, format=Format.FORWARDREF),
+ {
+ "x": support.EqualToForwardRef("range[1]", owner=f),
+ },
+ )
+
+ def test_partial_evaluation_cell(self):
+ obj = object()
+
+ class RaisesAttributeError:
+ attriberr: obj.missing
+
+ anno = get_annotations(RaisesAttributeError, format=Format.FORWARDREF)
+ self.assertEqual(
+ anno,
+ {
+ "attriberr": support.EqualToForwardRef(
+ "obj.missing", is_class=True, owner=RaisesAttributeError
+ )
+ },
+ )
+
class TestCallEvaluateFunction(unittest.TestCase):
def test_evaluation(self):
@@ -933,13 +1218,13 @@ class MetaclassTests(unittest.TestCase):
b: float
self.assertEqual(get_annotations(Meta), {"a": int})
- self.assertEqual(get_annotate_function(Meta)(Format.VALUE), {"a": int})
+ self.assertEqual(Meta.__annotate__(Format.VALUE), {"a": int})
self.assertEqual(get_annotations(X), {})
- self.assertIs(get_annotate_function(X), None)
+ self.assertIs(X.__annotate__, None)
self.assertEqual(get_annotations(Y), {"b": float})
- self.assertEqual(get_annotate_function(Y)(Format.VALUE), {"b": float})
+ self.assertEqual(Y.__annotate__(Format.VALUE), {"b": float})
def test_unannotated_meta(self):
class Meta(type):
@@ -952,13 +1237,13 @@ class MetaclassTests(unittest.TestCase):
pass
self.assertEqual(get_annotations(Meta), {})
- self.assertIs(get_annotate_function(Meta), None)
+ self.assertIs(Meta.__annotate__, None)
self.assertEqual(get_annotations(Y), {})
- self.assertIs(get_annotate_function(Y), None)
+ self.assertIs(Y.__annotate__, None)
self.assertEqual(get_annotations(X), {"a": str})
- self.assertEqual(get_annotate_function(X)(Format.VALUE), {"a": str})
+ self.assertEqual(X.__annotate__(Format.VALUE), {"a": str})
def test_ordering(self):
# Based on a sample by David Ellis
@@ -996,7 +1281,7 @@ class MetaclassTests(unittest.TestCase):
for c in classes:
with self.subTest(c=c):
self.assertEqual(get_annotations(c), c.expected_annotations)
- annotate_func = get_annotate_function(c)
+ annotate_func = getattr(c, "__annotate__", None)
if c.expected_annotations:
self.assertEqual(
annotate_func(Format.VALUE), c.expected_annotations
@@ -1005,25 +1290,39 @@ class MetaclassTests(unittest.TestCase):
self.assertIs(annotate_func, None)
-class TestGetAnnotateFunction(unittest.TestCase):
- def test_static_class(self):
- self.assertIsNone(get_annotate_function(object))
- self.assertIsNone(get_annotate_function(int))
-
- def test_unannotated_class(self):
- class C:
- pass
+class TestGetAnnotateFromClassNamespace(unittest.TestCase):
+ def test_with_metaclass(self):
+ class Meta(type):
+ def __new__(mcls, name, bases, ns):
+ annotate = annotationlib.get_annotate_from_class_namespace(ns)
+ expected = ns["expected_annotate"]
+ with self.subTest(name=name):
+ if expected:
+ self.assertIsNotNone(annotate)
+ else:
+ self.assertIsNone(annotate)
+ return super().__new__(mcls, name, bases, ns)
+
+ class HasAnnotations(metaclass=Meta):
+ expected_annotate = True
+ a: int
- self.assertIsNone(get_annotate_function(C))
+ class NoAnnotations(metaclass=Meta):
+ expected_annotate = False
- D = type("D", (), {})
- self.assertIsNone(get_annotate_function(D))
+ class CustomAnnotate(metaclass=Meta):
+ expected_annotate = True
+ def __annotate__(format):
+ return {}
- def test_annotated_class(self):
- class C:
- a: int
+ code = """
+ from __future__ import annotations
- self.assertEqual(get_annotate_function(C)(Format.VALUE), {"a": int})
+ class HasFutureAnnotations(metaclass=Meta):
+ expected_annotate = False
+ a: int
+ """
+ exec(textwrap.dedent(code), {"Meta": Meta})
class TestTypeRepr(unittest.TestCase):
@@ -1240,6 +1539,38 @@ class TestForwardRefClass(unittest.TestCase):
with self.assertRaises(TypeError):
pickle.dumps(fr, proto)
+ def test_evaluate_string_format(self):
+ fr = ForwardRef("set[Any]")
+ self.assertEqual(fr.evaluate(format=Format.STRING), "set[Any]")
+
+ def test_evaluate_forwardref_format(self):
+ fr = ForwardRef("undef")
+ evaluated = fr.evaluate(format=Format.FORWARDREF)
+ self.assertIs(fr, evaluated)
+
+ fr = ForwardRef("set[undefined]")
+ evaluated = fr.evaluate(format=Format.FORWARDREF)
+ self.assertEqual(
+ evaluated,
+ set[support.EqualToForwardRef("undefined")],
+ )
+
+ fr = ForwardRef("a + b")
+ self.assertEqual(
+ fr.evaluate(format=Format.FORWARDREF),
+ support.EqualToForwardRef("a + b"),
+ )
+ self.assertEqual(
+ fr.evaluate(format=Format.FORWARDREF, locals={"a": 1, "b": 2}),
+ 3,
+ )
+
+ fr = ForwardRef('"a" + 1')
+ self.assertEqual(
+ fr.evaluate(format=Format.FORWARDREF),
+ support.EqualToForwardRef('"a" + 1'),
+ )
+
def test_evaluate_with_type_params(self):
class Gen[T]:
alias = int
@@ -1319,9 +1650,11 @@ class TestForwardRefClass(unittest.TestCase):
with support.swap_attr(builtins, "int", dict):
self.assertIs(ForwardRef("int").evaluate(), dict)
- with self.assertRaises(NameError):
+ with self.assertRaises(NameError, msg="name 'doesntexist' is not defined") as exc:
ForwardRef("doesntexist").evaluate()
+ self.assertEqual(exc.exception.name, "doesntexist")
+
def test_fwdref_invalid_syntax(self):
fr = ForwardRef("if")
with self.assertRaises(SyntaxError):
diff --git a/Lib/test/test_argparse.py b/Lib/test/test_argparse.py
index c5a1f31aa52..08ff41368d9 100644
--- a/Lib/test/test_argparse.py
+++ b/Lib/test/test_argparse.py
@@ -5469,11 +5469,60 @@ class TestHelpMetavarTypeFormatter(HelpTestCase):
version = ''
-class TestHelpUsageLongSubparserCommand(TestCase):
- """Test that subparser commands are formatted correctly in help"""
+class TestHelpCustomHelpFormatter(TestCase):
maxDiff = None
- def test_parent_help(self):
+ def test_custom_formatter_function(self):
+ def custom_formatter(prog):
+ return argparse.RawTextHelpFormatter(prog, indent_increment=5)
+
+ parser = argparse.ArgumentParser(
+ prog='PROG',
+ prefix_chars='-+',
+ formatter_class=custom_formatter
+ )
+ parser.add_argument('+f', '++foo', help="foo help")
+ parser.add_argument('spam', help="spam help")
+
+ parser_help = parser.format_help()
+ self.assertEqual(parser_help, textwrap.dedent('''\
+ usage: PROG [-h] [+f FOO] spam
+
+ positional arguments:
+ spam spam help
+
+ options:
+ -h, --help show this help message and exit
+ +f, ++foo FOO foo help
+ '''))
+
+ def test_custom_formatter_class(self):
+ class CustomFormatter(argparse.RawTextHelpFormatter):
+ def __init__(self, prog):
+ super().__init__(prog, indent_increment=5)
+
+ parser = argparse.ArgumentParser(
+ prog='PROG',
+ prefix_chars='-+',
+ formatter_class=CustomFormatter
+ )
+ parser.add_argument('+f', '++foo', help="foo help")
+ parser.add_argument('spam', help="spam help")
+
+ parser_help = parser.format_help()
+ self.assertEqual(parser_help, textwrap.dedent('''\
+ usage: PROG [-h] [+f FOO] spam
+
+ positional arguments:
+ spam spam help
+
+ options:
+ -h, --help show this help message and exit
+ +f, ++foo FOO foo help
+ '''))
+
+ def test_usage_long_subparser_command(self):
+ """Test that subparser commands are formatted correctly in help"""
def custom_formatter(prog):
return argparse.RawTextHelpFormatter(prog, max_help_position=50)
@@ -6756,7 +6805,7 @@ class TestImportStar(TestCase):
def test(self):
for name in argparse.__all__:
- self.assertTrue(hasattr(argparse, name))
+ self.assertHasAttr(argparse, name)
def test_all_exports_everything_but_modules(self):
items = [
@@ -6973,7 +7022,7 @@ class TestProgName(TestCase):
def check_usage(self, expected, *args, **kwargs):
res = script_helper.assert_python_ok('-Xutf8', *args, '-h', **kwargs)
- self.assertEqual(res.out.splitlines()[0].decode(),
+ self.assertEqual(os.fsdecode(res.out.splitlines()[0]),
f'usage: {expected} [-h]')
def test_script(self, compiled=False):
@@ -7053,12 +7102,13 @@ class TestTranslations(TestTranslationsBase):
class TestColorized(TestCase):
+ maxDiff = None
def setUp(self):
super().setUp()
# Ensure color even if ran with NO_COLOR=1
_colorize.can_colorize = lambda *args, **kwargs: True
- self.ansi = _colorize.ANSIColors()
+ self.theme = _colorize.get_theme(force_color=True).argparse
def test_argparse_color(self):
# Arrange: create a parser with a bit of everything
@@ -7120,13 +7170,17 @@ class TestColorized(TestCase):
sub2 = subparsers.add_parser("sub2", deprecated=True, help="sub2 help")
sub2.add_argument("--baz", choices=("X", "Y", "Z"), help="baz help")
- heading = self.ansi.BOLD_BLUE
- label, label_b = self.ansi.YELLOW, self.ansi.BOLD_YELLOW
- long, long_b = self.ansi.CYAN, self.ansi.BOLD_CYAN
- pos, pos_b = short, short_b = self.ansi.GREEN, self.ansi.BOLD_GREEN
- sub = self.ansi.BOLD_GREEN
- prog = self.ansi.BOLD_MAGENTA
- reset = self.ansi.RESET
+ prog = self.theme.prog
+ heading = self.theme.heading
+ long = self.theme.summary_long_option
+ short = self.theme.summary_short_option
+ label = self.theme.summary_label
+ pos = self.theme.summary_action
+ long_b = self.theme.long_option
+ short_b = self.theme.short_option
+ label_b = self.theme.label
+ pos_b = self.theme.action
+ reset = self.theme.reset
# Act
help_text = parser.format_help()
@@ -7171,9 +7225,9 @@ class TestColorized(TestCase):
{heading}subcommands:{reset}
valid subcommands
- {sub}{{sub1,sub2}}{reset} additional help
- {sub}sub1{reset} sub1 help
- {sub}sub2{reset} sub2 help
+ {pos_b}{{sub1,sub2}}{reset} additional help
+ {pos_b}sub1{reset} sub1 help
+ {pos_b}sub2{reset} sub2 help
"""
),
)
@@ -7187,10 +7241,10 @@ class TestColorized(TestCase):
prog="PROG",
usage="[prefix] %(prog)s [suffix]",
)
- heading = self.ansi.BOLD_BLUE
- prog = self.ansi.BOLD_MAGENTA
- reset = self.ansi.RESET
- usage = self.ansi.MAGENTA
+ heading = self.theme.heading
+ prog = self.theme.prog
+ reset = self.theme.reset
+ usage = self.theme.prog_extra
# Act
help_text = parser.format_help()
@@ -7207,6 +7261,79 @@ class TestColorized(TestCase):
),
)
+ def test_custom_formatter_function(self):
+ def custom_formatter(prog):
+ return argparse.RawTextHelpFormatter(prog, indent_increment=5)
+
+ parser = argparse.ArgumentParser(
+ prog="PROG",
+ prefix_chars="-+",
+ formatter_class=custom_formatter,
+ color=True,
+ )
+ parser.add_argument('+f', '++foo', help="foo help")
+ parser.add_argument('spam', help="spam help")
+
+ prog = self.theme.prog
+ heading = self.theme.heading
+ short = self.theme.summary_short_option
+ label = self.theme.summary_label
+ pos = self.theme.summary_action
+ long_b = self.theme.long_option
+ short_b = self.theme.short_option
+ label_b = self.theme.label
+ pos_b = self.theme.action
+ reset = self.theme.reset
+
+ parser_help = parser.format_help()
+ self.assertEqual(parser_help, textwrap.dedent(f'''\
+ {heading}usage: {reset}{prog}PROG{reset} [{short}-h{reset}] [{short}+f {label}FOO{reset}] {pos}spam{reset}
+
+ {heading}positional arguments:{reset}
+ {pos_b}spam{reset} spam help
+
+ {heading}options:{reset}
+ {short_b}-h{reset}, {long_b}--help{reset} show this help message and exit
+ {short_b}+f{reset}, {long_b}++foo{reset} {label_b}FOO{reset} foo help
+ '''))
+
+ def test_custom_formatter_class(self):
+ class CustomFormatter(argparse.RawTextHelpFormatter):
+ def __init__(self, prog):
+ super().__init__(prog, indent_increment=5)
+
+ parser = argparse.ArgumentParser(
+ prog="PROG",
+ prefix_chars="-+",
+ formatter_class=CustomFormatter,
+ color=True,
+ )
+ parser.add_argument('+f', '++foo', help="foo help")
+ parser.add_argument('spam', help="spam help")
+
+ prog = self.theme.prog
+ heading = self.theme.heading
+ short = self.theme.summary_short_option
+ label = self.theme.summary_label
+ pos = self.theme.summary_action
+ long_b = self.theme.long_option
+ short_b = self.theme.short_option
+ label_b = self.theme.label
+ pos_b = self.theme.action
+ reset = self.theme.reset
+
+ parser_help = parser.format_help()
+ self.assertEqual(parser_help, textwrap.dedent(f'''\
+ {heading}usage: {reset}{prog}PROG{reset} [{short}-h{reset}] [{short}+f {label}FOO{reset}] {pos}spam{reset}
+
+ {heading}positional arguments:{reset}
+ {pos_b}spam{reset} spam help
+
+ {heading}options:{reset}
+ {short_b}-h{reset}, {long_b}--help{reset} show this help message and exit
+ {short_b}+f{reset}, {long_b}++foo{reset} {label_b}FOO{reset} foo help
+ '''))
+
def tearDownModule():
# Remove global references to avoid looking like we have refleaks.
diff --git a/Lib/test/test_asdl_parser.py b/Lib/test/test_asdl_parser.py
index 2c198a6b8b2..b9df6568123 100644
--- a/Lib/test/test_asdl_parser.py
+++ b/Lib/test/test_asdl_parser.py
@@ -62,17 +62,17 @@ class TestAsdlParser(unittest.TestCase):
alias = self.types['alias']
self.assertEqual(
str(alias),
- 'Product([Field(identifier, name), Field(identifier, asname, opt=True)], '
+ 'Product([Field(identifier, name), Field(identifier, asname, quantifiers=[OPTIONAL])], '
'[Field(int, lineno), Field(int, col_offset), '
- 'Field(int, end_lineno, opt=True), Field(int, end_col_offset, opt=True)])')
+ 'Field(int, end_lineno, quantifiers=[OPTIONAL]), Field(int, end_col_offset, quantifiers=[OPTIONAL])])')
def test_attributes(self):
stmt = self.types['stmt']
self.assertEqual(len(stmt.attributes), 4)
self.assertEqual(repr(stmt.attributes[0]), 'Field(int, lineno)')
self.assertEqual(repr(stmt.attributes[1]), 'Field(int, col_offset)')
- self.assertEqual(repr(stmt.attributes[2]), 'Field(int, end_lineno, opt=True)')
- self.assertEqual(repr(stmt.attributes[3]), 'Field(int, end_col_offset, opt=True)')
+ self.assertEqual(repr(stmt.attributes[2]), 'Field(int, end_lineno, quantifiers=[OPTIONAL])')
+ self.assertEqual(repr(stmt.attributes[3]), 'Field(int, end_col_offset, quantifiers=[OPTIONAL])')
def test_constructor_fields(self):
ehandler = self.types['excepthandler']
diff --git a/Lib/test/test_ast/test_ast.py b/Lib/test/test_ast/test_ast.py
index 7d64b1c0e3a..cc46529c0ef 100644
--- a/Lib/test/test_ast/test_ast.py
+++ b/Lib/test/test_ast/test_ast.py
@@ -1,16 +1,20 @@
import _ast_unparse
import ast
import builtins
+import contextlib
import copy
import dis
import enum
+import itertools
import os
import re
import sys
+import tempfile
import textwrap
import types
import unittest
import weakref
+from io import StringIO
from pathlib import Path
from textwrap import dedent
try:
@@ -19,9 +23,10 @@ except ImportError:
_testinternalcapi = None
from test import support
-from test.support import os_helper, script_helper
+from test.support import os_helper
from test.support import skip_emscripten_stack_overflow, skip_wasi_stack_overflow
from test.support.ast_helper import ASTTestMixin
+from test.support.import_helper import ensure_lazy_imports
from test.test_ast.utils import to_tuple
from test.test_ast.snippets import (
eval_tests, eval_results, exec_tests, exec_results, single_tests, single_results
@@ -43,6 +48,12 @@ def ast_repr_update_snapshots() -> None:
AST_REPR_DATA_FILE.write_text("\n".join(data))
+class LazyImportTest(unittest.TestCase):
+ @support.cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("ast", {"contextlib", "enum", "inspect", "re", "collections", "argparse"})
+
+
class AST_Tests(unittest.TestCase):
maxDiff = None
@@ -264,12 +275,12 @@ class AST_Tests(unittest.TestCase):
self.assertEqual(alias.end_col_offset, 17)
def test_base_classes(self):
- self.assertTrue(issubclass(ast.For, ast.stmt))
- self.assertTrue(issubclass(ast.Name, ast.expr))
- self.assertTrue(issubclass(ast.stmt, ast.AST))
- self.assertTrue(issubclass(ast.expr, ast.AST))
- self.assertTrue(issubclass(ast.comprehension, ast.AST))
- self.assertTrue(issubclass(ast.Gt, ast.AST))
+ self.assertIsSubclass(ast.For, ast.stmt)
+ self.assertIsSubclass(ast.Name, ast.expr)
+ self.assertIsSubclass(ast.stmt, ast.AST)
+ self.assertIsSubclass(ast.expr, ast.AST)
+ self.assertIsSubclass(ast.comprehension, ast.AST)
+ self.assertIsSubclass(ast.Gt, ast.AST)
def test_field_attr_existence(self):
for name, item in ast.__dict__.items():
@@ -692,6 +703,63 @@ class AST_Tests(unittest.TestCase):
with self.assertRaises(SyntaxError):
ast.parse(code, feature_version=(3, 13))
+ def test_pep758_except_with_single_expr(self):
+ single_expr = textwrap.dedent("""
+ try:
+ ...
+ except{0} TypeError:
+ ...
+ """)
+
+ single_expr_with_as = textwrap.dedent("""
+ try:
+ ...
+ except{0} TypeError as exc:
+ ...
+ """)
+
+ single_tuple_expr = textwrap.dedent("""
+ try:
+ ...
+ except{0} (TypeError,):
+ ...
+ """)
+
+ single_tuple_expr_with_as = textwrap.dedent("""
+ try:
+ ...
+ except{0} (TypeError,) as exc:
+ ...
+ """)
+
+ single_parens_expr = textwrap.dedent("""
+ try:
+ ...
+ except{0} (TypeError):
+ ...
+ """)
+
+ single_parens_expr_with_as = textwrap.dedent("""
+ try:
+ ...
+ except{0} (TypeError) as exc:
+ ...
+ """)
+
+ for code in [
+ single_expr,
+ single_expr_with_as,
+ single_tuple_expr,
+ single_tuple_expr_with_as,
+ single_parens_expr,
+ single_parens_expr_with_as,
+ ]:
+ for star in [True, False]:
+ code = code.format('*' if star else '')
+ with self.subTest(code=code, star=star):
+ ast.parse(code, feature_version=(3, 14))
+ ast.parse(code, feature_version=(3, 13))
+
def test_pep758_except_star_without_parens(self):
code = textwrap.dedent("""
try:
@@ -753,6 +821,17 @@ class AST_Tests(unittest.TestCase):
with self.assertRaisesRegex(ValueError, f"identifier field can't represent '{constant}' constant"):
compile(expr, "<test>", "eval")
+ def test_constant_as_unicode_name(self):
+ constants = [
+ ("True", b"Tru\xe1\xb5\x89"),
+ ("False", b"Fal\xc5\xbfe"),
+ ("None", b"N\xc2\xbane"),
+ ]
+ for constant in constants:
+ with self.assertRaisesRegex(ValueError,
+ f"identifier field can't represent '{constant[0]}' constant"):
+ ast.parse(constant[1], mode="eval")
+
def test_precedence_enum(self):
class _Precedence(enum.IntEnum):
"""Precedence table that originated from python grammar."""
@@ -1022,7 +1101,7 @@ class CopyTests(unittest.TestCase):
def test_replace_interface(self):
for klass in self.iter_ast_classes():
with self.subTest(klass=klass):
- self.assertTrue(hasattr(klass, '__replace__'))
+ self.assertHasAttr(klass, '__replace__')
fields = set(klass._fields)
with self.subTest(klass=klass, fields=fields):
@@ -1236,13 +1315,22 @@ class CopyTests(unittest.TestCase):
self.assertIs(repl.id, 'y')
self.assertIs(repl.ctx, context)
+ def test_replace_accept_missing_field_with_default(self):
+ node = ast.FunctionDef(name="foo", args=ast.arguments())
+ self.assertIs(node.returns, None)
+ self.assertEqual(node.decorator_list, [])
+ node2 = copy.replace(node, name="bar")
+ self.assertEqual(node2.name, "bar")
+ self.assertIs(node2.returns, None)
+ self.assertEqual(node2.decorator_list, [])
+
def test_replace_reject_known_custom_instance_fields_commits(self):
node = ast.parse('x').body[0].value
node.extra = extra = object() # add instance 'extra' field
context = node.ctx
# explicit rejection of known instance fields
- self.assertTrue(hasattr(node, 'extra'))
+ self.assertHasAttr(node, 'extra')
msg = "Name.__replace__ got an unexpected keyword argument 'extra'."
with self.assertRaisesRegex(TypeError, re.escape(msg)):
copy.replace(node, extra=1)
@@ -1284,17 +1372,17 @@ class ASTHelpers_Test(unittest.TestCase):
def test_dump(self):
node = ast.parse('spam(eggs, "and cheese")')
self.assertEqual(ast.dump(node),
- "Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load()), "
- "args=[Name(id='eggs', ctx=Load()), Constant(value='and cheese')]))])"
+ "Module(body=[Expr(value=Call(func=Name(id='spam'), "
+ "args=[Name(id='eggs'), Constant(value='and cheese')]))])"
)
self.assertEqual(ast.dump(node, annotate_fields=False),
- "Module([Expr(Call(Name('spam', Load()), [Name('eggs', Load()), "
+ "Module([Expr(Call(Name('spam'), [Name('eggs'), "
"Constant('and cheese')]))])"
)
self.assertEqual(ast.dump(node, include_attributes=True),
- "Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load(), "
+ "Module(body=[Expr(value=Call(func=Name(id='spam', "
"lineno=1, col_offset=0, end_lineno=1, end_col_offset=4), "
- "args=[Name(id='eggs', ctx=Load(), lineno=1, col_offset=5, "
+ "args=[Name(id='eggs', lineno=1, col_offset=5, "
"end_lineno=1, end_col_offset=9), Constant(value='and cheese', "
"lineno=1, col_offset=11, end_lineno=1, end_col_offset=23)], "
"lineno=1, col_offset=0, end_lineno=1, end_col_offset=24), "
@@ -1308,18 +1396,18 @@ Module(
body=[
Expr(
value=Call(
- func=Name(id='spam', ctx=Load()),
+ func=Name(id='spam'),
args=[
- Name(id='eggs', ctx=Load()),
+ Name(id='eggs'),
Constant(value='and cheese')]))])""")
self.assertEqual(ast.dump(node, annotate_fields=False, indent='\t'), """\
Module(
\t[
\t\tExpr(
\t\t\tCall(
-\t\t\t\tName('spam', Load()),
+\t\t\t\tName('spam'),
\t\t\t\t[
-\t\t\t\t\tName('eggs', Load()),
+\t\t\t\t\tName('eggs'),
\t\t\t\t\tConstant('and cheese')]))])""")
self.assertEqual(ast.dump(node, include_attributes=True, indent=3), """\
Module(
@@ -1328,7 +1416,6 @@ Module(
value=Call(
func=Name(
id='spam',
- ctx=Load(),
lineno=1,
col_offset=0,
end_lineno=1,
@@ -1336,7 +1423,6 @@ Module(
args=[
Name(
id='eggs',
- ctx=Load(),
lineno=1,
col_offset=5,
end_lineno=1,
@@ -1366,23 +1452,23 @@ Module(
)
node = ast.Raise(exc=ast.Name(id='e', ctx=ast.Load()), lineno=3, col_offset=4)
self.assertEqual(ast.dump(node),
- "Raise(exc=Name(id='e', ctx=Load()))"
+ "Raise(exc=Name(id='e'))"
)
self.assertEqual(ast.dump(node, annotate_fields=False),
- "Raise(Name('e', Load()))"
+ "Raise(Name('e'))"
)
self.assertEqual(ast.dump(node, include_attributes=True),
- "Raise(exc=Name(id='e', ctx=Load()), lineno=3, col_offset=4)"
+ "Raise(exc=Name(id='e'), lineno=3, col_offset=4)"
)
self.assertEqual(ast.dump(node, annotate_fields=False, include_attributes=True),
- "Raise(Name('e', Load()), lineno=3, col_offset=4)"
+ "Raise(Name('e'), lineno=3, col_offset=4)"
)
node = ast.Raise(cause=ast.Name(id='e', ctx=ast.Load()))
self.assertEqual(ast.dump(node),
- "Raise(cause=Name(id='e', ctx=Load()))"
+ "Raise(cause=Name(id='e'))"
)
self.assertEqual(ast.dump(node, annotate_fields=False),
- "Raise(cause=Name('e', Load()))"
+ "Raise(cause=Name('e'))"
)
# Arguments:
node = ast.arguments(args=[ast.arg("x")])
@@ -1414,10 +1500,10 @@ Module(
[ast.Name('dataclass', ctx=ast.Load())],
)
self.assertEqual(ast.dump(node),
- "ClassDef(name='T', keywords=[keyword(arg='a', value=Constant(value=None))], decorator_list=[Name(id='dataclass', ctx=Load())])",
+ "ClassDef(name='T', keywords=[keyword(arg='a', value=Constant(value=None))], decorator_list=[Name(id='dataclass')])",
)
self.assertEqual(ast.dump(node, annotate_fields=False),
- "ClassDef('T', [], [keyword('a', Constant(None))], [], [Name('dataclass', Load())])",
+ "ClassDef('T', [], [keyword('a', Constant(None))], [], [Name('dataclass')])",
)
def test_dump_show_empty(self):
@@ -1445,7 +1531,7 @@ Module(
check_node(
# Corner case: there are no real `Name` instances with `id=''`:
ast.Name(id='', ctx=ast.Load()),
- empty="Name(id='', ctx=Load())",
+ empty="Name(id='')",
full="Name(id='', ctx=Load())",
)
@@ -1456,39 +1542,63 @@ Module(
)
check_node(
+ ast.MatchSingleton(value=[]),
+ empty="MatchSingleton(value=[])",
+ full="MatchSingleton(value=[])",
+ )
+
+ check_node(
ast.Constant(value=None),
empty="Constant(value=None)",
full="Constant(value=None)",
)
check_node(
+ ast.Constant(value=[]),
+ empty="Constant(value=[])",
+ full="Constant(value=[])",
+ )
+
+ check_node(
ast.Constant(value=''),
empty="Constant(value='')",
full="Constant(value='')",
)
+ check_node(
+ ast.Interpolation(value=ast.Constant(42), str=None, conversion=-1),
+ empty="Interpolation(value=Constant(value=42), str=None, conversion=-1)",
+ full="Interpolation(value=Constant(value=42), str=None, conversion=-1)",
+ )
+
+ check_node(
+ ast.Interpolation(value=ast.Constant(42), str=[], conversion=-1),
+ empty="Interpolation(value=Constant(value=42), str=[], conversion=-1)",
+ full="Interpolation(value=Constant(value=42), str=[], conversion=-1)",
+ )
+
check_text(
"def a(b: int = 0, *, c): ...",
- empty="Module(body=[FunctionDef(name='a', args=arguments(args=[arg(arg='b', annotation=Name(id='int', ctx=Load()))], kwonlyargs=[arg(arg='c')], kw_defaults=[None], defaults=[Constant(value=0)]), body=[Expr(value=Constant(value=Ellipsis))])])",
+ empty="Module(body=[FunctionDef(name='a', args=arguments(args=[arg(arg='b', annotation=Name(id='int'))], kwonlyargs=[arg(arg='c')], kw_defaults=[None], defaults=[Constant(value=0)]), body=[Expr(value=Constant(value=Ellipsis))])])",
full="Module(body=[FunctionDef(name='a', args=arguments(posonlyargs=[], args=[arg(arg='b', annotation=Name(id='int', ctx=Load()))], kwonlyargs=[arg(arg='c')], kw_defaults=[None], defaults=[Constant(value=0)]), body=[Expr(value=Constant(value=Ellipsis))], decorator_list=[], type_params=[])], type_ignores=[])",
)
check_text(
"def a(b: int = 0, *, c): ...",
- empty="Module(body=[FunctionDef(name='a', args=arguments(args=[arg(arg='b', annotation=Name(id='int', ctx=Load(), lineno=1, col_offset=9, end_lineno=1, end_col_offset=12), lineno=1, col_offset=6, end_lineno=1, end_col_offset=12)], kwonlyargs=[arg(arg='c', lineno=1, col_offset=21, end_lineno=1, end_col_offset=22)], kw_defaults=[None], defaults=[Constant(value=0, lineno=1, col_offset=15, end_lineno=1, end_col_offset=16)]), body=[Expr(value=Constant(value=Ellipsis, lineno=1, col_offset=25, end_lineno=1, end_col_offset=28), lineno=1, col_offset=25, end_lineno=1, end_col_offset=28)], lineno=1, col_offset=0, end_lineno=1, end_col_offset=28)])",
+ empty="Module(body=[FunctionDef(name='a', args=arguments(args=[arg(arg='b', annotation=Name(id='int', lineno=1, col_offset=9, end_lineno=1, end_col_offset=12), lineno=1, col_offset=6, end_lineno=1, end_col_offset=12)], kwonlyargs=[arg(arg='c', lineno=1, col_offset=21, end_lineno=1, end_col_offset=22)], kw_defaults=[None], defaults=[Constant(value=0, lineno=1, col_offset=15, end_lineno=1, end_col_offset=16)]), body=[Expr(value=Constant(value=Ellipsis, lineno=1, col_offset=25, end_lineno=1, end_col_offset=28), lineno=1, col_offset=25, end_lineno=1, end_col_offset=28)], lineno=1, col_offset=0, end_lineno=1, end_col_offset=28)])",
full="Module(body=[FunctionDef(name='a', args=arguments(posonlyargs=[], args=[arg(arg='b', annotation=Name(id='int', ctx=Load(), lineno=1, col_offset=9, end_lineno=1, end_col_offset=12), lineno=1, col_offset=6, end_lineno=1, end_col_offset=12)], kwonlyargs=[arg(arg='c', lineno=1, col_offset=21, end_lineno=1, end_col_offset=22)], kw_defaults=[None], defaults=[Constant(value=0, lineno=1, col_offset=15, end_lineno=1, end_col_offset=16)]), body=[Expr(value=Constant(value=Ellipsis, lineno=1, col_offset=25, end_lineno=1, end_col_offset=28), lineno=1, col_offset=25, end_lineno=1, end_col_offset=28)], decorator_list=[], type_params=[], lineno=1, col_offset=0, end_lineno=1, end_col_offset=28)], type_ignores=[])",
include_attributes=True,
)
check_text(
'spam(eggs, "and cheese")',
- empty="Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load()), args=[Name(id='eggs', ctx=Load()), Constant(value='and cheese')]))])",
+ empty="Module(body=[Expr(value=Call(func=Name(id='spam'), args=[Name(id='eggs'), Constant(value='and cheese')]))])",
full="Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load()), args=[Name(id='eggs', ctx=Load()), Constant(value='and cheese')], keywords=[]))], type_ignores=[])",
)
check_text(
'spam(eggs, text="and cheese")',
- empty="Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load()), args=[Name(id='eggs', ctx=Load())], keywords=[keyword(arg='text', value=Constant(value='and cheese'))]))])",
+ empty="Module(body=[Expr(value=Call(func=Name(id='spam'), args=[Name(id='eggs')], keywords=[keyword(arg='text', value=Constant(value='and cheese'))]))])",
full="Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load()), args=[Name(id='eggs', ctx=Load())], keywords=[keyword(arg='text', value=Constant(value='and cheese'))]))], type_ignores=[])",
)
@@ -1522,12 +1632,12 @@ Module(
self.assertEqual(src, ast.fix_missing_locations(src))
self.maxDiff = None
self.assertEqual(ast.dump(src, include_attributes=True),
- "Module(body=[Expr(value=Call(func=Name(id='write', ctx=Load(), "
+ "Module(body=[Expr(value=Call(func=Name(id='write', "
"lineno=1, col_offset=0, end_lineno=1, end_col_offset=5), "
"args=[Constant(value='spam', lineno=1, col_offset=6, end_lineno=1, "
"end_col_offset=12)], lineno=1, col_offset=0, end_lineno=1, "
"end_col_offset=13), lineno=1, col_offset=0, end_lineno=1, "
- "end_col_offset=13), Expr(value=Call(func=Name(id='spam', ctx=Load(), "
+ "end_col_offset=13), Expr(value=Call(func=Name(id='spam', "
"lineno=1, col_offset=0, end_lineno=1, end_col_offset=0), "
"args=[Constant(value='eggs', lineno=1, col_offset=0, end_lineno=1, "
"end_col_offset=0)], lineno=1, col_offset=0, end_lineno=1, "
@@ -2983,7 +3093,7 @@ class ASTConstructorTests(unittest.TestCase):
with self.assertWarnsRegex(DeprecationWarning,
r"FunctionDef\.__init__ missing 1 required positional argument: 'name'"):
node = ast.FunctionDef(args=args)
- self.assertFalse(hasattr(node, "name"))
+ self.assertNotHasAttr(node, "name")
self.assertEqual(node.decorator_list, [])
node = ast.FunctionDef(name='foo', args=args)
self.assertEqual(node.name, 'foo')
@@ -3175,23 +3285,263 @@ class ModuleStateTests(unittest.TestCase):
self.assertEqual(res, 0)
-class ASTMainTests(unittest.TestCase):
- # Tests `ast.main()` function.
+class CommandLineTests(unittest.TestCase):
+ def setUp(self):
+ self.filename = tempfile.mktemp()
+ self.addCleanup(os_helper.unlink, self.filename)
+
+ @staticmethod
+ def text_normalize(string):
+ return textwrap.dedent(string).strip()
+
+ def set_source(self, content):
+ Path(self.filename).write_text(self.text_normalize(content))
+
+ def invoke_ast(self, *flags):
+ stderr = StringIO()
+ stdout = StringIO()
+ with (
+ contextlib.redirect_stdout(stdout),
+ contextlib.redirect_stderr(stderr),
+ ):
+ ast.main(args=[*flags, self.filename])
+ self.assertEqual(stderr.getvalue(), '')
+ return stdout.getvalue().strip()
+
+ def check_output(self, source, expect, *flags):
+ self.set_source(source)
+ res = self.invoke_ast(*flags)
+ expect = self.text_normalize(expect)
+ self.assertEqual(res, expect)
- def test_cli_file_input(self):
- code = "print(1, 2, 3)"
- expected = ast.dump(ast.parse(code), indent=3)
+ @support.requires_resource('cpu')
+ def test_invocation(self):
+ # test various combinations of parameters
+ base_flags = (
+ ('-m=exec', '--mode=exec'),
+ ('--no-type-comments', '--no-type-comments'),
+ ('-a', '--include-attributes'),
+ ('-i=4', '--indent=4'),
+ ('--feature-version=3.13', '--feature-version=3.13'),
+ ('-O=-1', '--optimize=-1'),
+ ('--show-empty', '--show-empty'),
+ )
+ self.set_source('''
+ print(1, 2, 3)
+ def f(x: int) -> int:
+ x -= 1
+ return x
+ ''')
- with os_helper.temp_dir() as tmp_dir:
- filename = os.path.join(tmp_dir, "test_module.py")
- with open(filename, 'w', encoding='utf-8') as f:
- f.write(code)
- res, _ = script_helper.run_python_until_end("-m", "ast", filename)
+ for r in range(1, len(base_flags) + 1):
+ for choices in itertools.combinations(base_flags, r=r):
+ for args in itertools.product(*choices):
+ with self.subTest(flags=args):
+ self.invoke_ast(*args)
+
+ @support.force_not_colorized
+ def test_help_message(self):
+ for flag in ('-h', '--help', '--unknown'):
+ with self.subTest(flag=flag):
+ output = StringIO()
+ with self.assertRaises(SystemExit):
+ with contextlib.redirect_stderr(output):
+ ast.main(args=flag)
+ self.assertStartsWith(output.getvalue(), 'usage: ')
+
+ def test_exec_mode_flag(self):
+ # test 'python -m ast -m/--mode exec'
+ source = 'x: bool = 1 # type: ignore[assignment]'
+ expect = '''
+ Module(
+ body=[
+ AnnAssign(
+ target=Name(id='x', ctx=Store()),
+ annotation=Name(id='bool'),
+ value=Constant(value=1),
+ simple=1)],
+ type_ignores=[
+ TypeIgnore(lineno=1, tag='[assignment]')])
+ '''
+ for flag in ('-m=exec', '--mode=exec'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_single_mode_flag(self):
+ # test 'python -m ast -m/--mode single'
+ source = 'pass'
+ expect = '''
+ Interactive(
+ body=[
+ Pass()])
+ '''
+ for flag in ('-m=single', '--mode=single'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_eval_mode_flag(self):
+ # test 'python -m ast -m/--mode eval'
+ source = 'print(1, 2, 3)'
+ expect = '''
+ Expression(
+ body=Call(
+ func=Name(id='print'),
+ args=[
+ Constant(value=1),
+ Constant(value=2),
+ Constant(value=3)]))
+ '''
+ for flag in ('-m=eval', '--mode=eval'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_func_type_mode_flag(self):
+ # test 'python -m ast -m/--mode func_type'
+ source = '(int, str) -> list[int]'
+ expect = '''
+ FunctionType(
+ argtypes=[
+ Name(id='int'),
+ Name(id='str')],
+ returns=Subscript(
+ value=Name(id='list'),
+ slice=Name(id='int')))
+ '''
+ for flag in ('-m=func_type', '--mode=func_type'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_no_type_comments_flag(self):
+ # test 'python -m ast --no-type-comments'
+ source = 'x: bool = 1 # type: ignore[assignment]'
+ expect = '''
+ Module(
+ body=[
+ AnnAssign(
+ target=Name(id='x', ctx=Store()),
+ annotation=Name(id='bool'),
+ value=Constant(value=1),
+ simple=1)])
+ '''
+ self.check_output(source, expect, '--no-type-comments')
+
+ def test_include_attributes_flag(self):
+ # test 'python -m ast -a/--include-attributes'
+ source = 'pass'
+ expect = '''
+ Module(
+ body=[
+ Pass(
+ lineno=1,
+ col_offset=0,
+ end_lineno=1,
+ end_col_offset=4)])
+ '''
+ for flag in ('-a', '--include-attributes'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_indent_flag(self):
+ # test 'python -m ast -i/--indent 0'
+ source = 'pass'
+ expect = '''
+ Module(
+ body=[
+ Pass()])
+ '''
+ for flag in ('-i=0', '--indent=0'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_feature_version_flag(self):
+ # test 'python -m ast --feature-version 3.9/3.10'
+ source = '''
+ match x:
+ case 1:
+ pass
+ '''
+ expect = '''
+ Module(
+ body=[
+ Match(
+ subject=Name(id='x'),
+ cases=[
+ match_case(
+ pattern=MatchValue(
+ value=Constant(value=1)),
+ body=[
+ Pass()])])])
+ '''
+ self.check_output(source, expect, '--feature-version=3.10')
+ with self.assertRaises(SyntaxError):
+ self.invoke_ast('--feature-version=3.9')
- self.assertEqual(res.err, b"")
- self.assertEqual(expected.splitlines(),
- res.out.decode("utf8").splitlines())
- self.assertEqual(res.rc, 0)
+ def test_no_optimize_flag(self):
+ # test 'python -m ast -O/--optimize -1/0'
+ source = '''
+ match a:
+ case 1+2j:
+ pass
+ '''
+ expect = '''
+ Module(
+ body=[
+ Match(
+ subject=Name(id='a'),
+ cases=[
+ match_case(
+ pattern=MatchValue(
+ value=BinOp(
+ left=Constant(value=1),
+ op=Add(),
+ right=Constant(value=2j))),
+ body=[
+ Pass()])])])
+ '''
+ for flag in ('-O=-1', '--optimize=-1', '-O=0', '--optimize=0'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_optimize_flag(self):
+ # test 'python -m ast -O/--optimize 1/2'
+ source = '''
+ match a:
+ case 1+2j:
+ pass
+ '''
+ expect = '''
+ Module(
+ body=[
+ Match(
+ subject=Name(id='a'),
+ cases=[
+ match_case(
+ pattern=MatchValue(
+ value=Constant(value=(1+2j))),
+ body=[
+ Pass()])])])
+ '''
+ for flag in ('-O=1', '--optimize=1', '-O=2', '--optimize=2'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_show_empty_flag(self):
+ # test 'python -m ast --show-empty'
+ source = 'print(1, 2, 3)'
+ expect = '''
+ Module(
+ body=[
+ Expr(
+ value=Call(
+ func=Name(id='print', ctx=Load()),
+ args=[
+ Constant(value=1),
+ Constant(value=2),
+ Constant(value=3)],
+ keywords=[]))],
+ type_ignores=[])
+ '''
+ self.check_output(source, expect, '--show-empty')
class ASTOptimiziationTests(unittest.TestCase):
diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py
index 2c44647bf3e..636cb33dd98 100644
--- a/Lib/test/test_asyncgen.py
+++ b/Lib/test/test_asyncgen.py
@@ -2021,6 +2021,15 @@ class TestUnawaitedWarnings(unittest.TestCase):
g.athrow(RuntimeError)
gc_collect()
+ def test_athrow_throws_immediately(self):
+ async def gen():
+ yield 1
+
+ g = gen()
+ msg = "athrow expected at least 1 argument, got 0"
+ with self.assertRaisesRegex(TypeError, msg):
+ g.athrow()
+
def test_aclose(self):
async def gen():
yield 1
diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py
index a2fb1022ae4..9f3b6f9acef 100644
--- a/Lib/test/test_asyncio/test_eager_task_factory.py
+++ b/Lib/test/test_asyncio/test_eager_task_factory.py
@@ -263,6 +263,24 @@ class EagerTaskFactoryLoopTests:
self.run_coro(run())
+ def test_eager_start_false(self):
+ name = None
+
+ async def asyncfn():
+ nonlocal name
+ name = asyncio.current_task().get_name()
+
+ async def main():
+ t = asyncio.get_running_loop().create_task(
+ asyncfn(), eager_start=False, name="example"
+ )
+ self.assertFalse(t.done())
+ self.assertIsNone(name)
+ await t
+ self.assertEqual(name, "example")
+
+ self.run_coro(main())
+
class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
Task = tasks._PyTask
@@ -505,5 +523,24 @@ class EagerCTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
asyncio.current_task = asyncio.tasks.current_task = self._current_task
return super().tearDown()
+
+class DefaultTaskFactoryEagerStart(test_utils.TestCase):
+ def test_eager_start_true_with_default_factory(self):
+ name = None
+
+ async def asyncfn():
+ nonlocal name
+ name = asyncio.current_task().get_name()
+
+ async def main():
+ t = asyncio.get_running_loop().create_task(
+ asyncfn(), eager_start=True, name="example"
+ )
+ self.assertTrue(t.done())
+ self.assertEqual(name, "example")
+ await t
+
+ asyncio.run(main(), loop_factory=asyncio.EventLoop)
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_asyncio/test_futures.py b/Lib/test/test_asyncio/test_futures.py
index 8b51522278a..39bef465bdb 100644
--- a/Lib/test/test_asyncio/test_futures.py
+++ b/Lib/test/test_asyncio/test_futures.py
@@ -413,7 +413,7 @@ class BaseFutureTests:
def test_copy_state(self):
from asyncio.futures import _copy_future_state
- f = self._new_future(loop=self.loop)
+ f = concurrent.futures.Future()
f.set_result(10)
newf = self._new_future(loop=self.loop)
@@ -421,7 +421,7 @@ class BaseFutureTests:
self.assertTrue(newf.done())
self.assertEqual(newf.result(), 10)
- f_exception = self._new_future(loop=self.loop)
+ f_exception = concurrent.futures.Future()
f_exception.set_exception(RuntimeError())
newf_exception = self._new_future(loop=self.loop)
@@ -429,7 +429,7 @@ class BaseFutureTests:
self.assertTrue(newf_exception.done())
self.assertRaises(RuntimeError, newf_exception.result)
- f_cancelled = self._new_future(loop=self.loop)
+ f_cancelled = concurrent.futures.Future()
f_cancelled.cancel()
newf_cancelled = self._new_future(loop=self.loop)
@@ -441,7 +441,7 @@ class BaseFutureTests:
except BaseException as e:
f_exc = e
- f_conexc = self._new_future(loop=self.loop)
+ f_conexc = concurrent.futures.Future()
f_conexc.set_exception(f_exc)
newf_conexc = self._new_future(loop=self.loop)
@@ -454,6 +454,56 @@ class BaseFutureTests:
newf_tb = ''.join(traceback.format_tb(newf_exc.__traceback__))
self.assertEqual(newf_tb.count('raise concurrent.futures.InvalidStateError'), 1)
+ def test_copy_state_from_concurrent_futures(self):
+ """Test _copy_future_state from concurrent.futures.Future.
+
+ This tests the optimized path using _get_snapshot when available.
+ """
+ from asyncio.futures import _copy_future_state
+
+ # Test with a result
+ f_concurrent = concurrent.futures.Future()
+ f_concurrent.set_result(42)
+ f_asyncio = self._new_future(loop=self.loop)
+ _copy_future_state(f_concurrent, f_asyncio)
+ self.assertTrue(f_asyncio.done())
+ self.assertEqual(f_asyncio.result(), 42)
+
+ # Test with an exception
+ f_concurrent_exc = concurrent.futures.Future()
+ f_concurrent_exc.set_exception(ValueError("test exception"))
+ f_asyncio_exc = self._new_future(loop=self.loop)
+ _copy_future_state(f_concurrent_exc, f_asyncio_exc)
+ self.assertTrue(f_asyncio_exc.done())
+ with self.assertRaises(ValueError) as cm:
+ f_asyncio_exc.result()
+ self.assertEqual(str(cm.exception), "test exception")
+
+ # Test with cancelled state
+ f_concurrent_cancelled = concurrent.futures.Future()
+ f_concurrent_cancelled.cancel()
+ f_asyncio_cancelled = self._new_future(loop=self.loop)
+ _copy_future_state(f_concurrent_cancelled, f_asyncio_cancelled)
+ self.assertTrue(f_asyncio_cancelled.cancelled())
+
+ # Test that destination already cancelled prevents copy
+ f_concurrent_result = concurrent.futures.Future()
+ f_concurrent_result.set_result(10)
+ f_asyncio_precancelled = self._new_future(loop=self.loop)
+ f_asyncio_precancelled.cancel()
+ _copy_future_state(f_concurrent_result, f_asyncio_precancelled)
+ self.assertTrue(f_asyncio_precancelled.cancelled())
+
+ # Test exception type conversion
+ f_concurrent_invalid = concurrent.futures.Future()
+ f_concurrent_invalid.set_exception(concurrent.futures.InvalidStateError("invalid"))
+ f_asyncio_invalid = self._new_future(loop=self.loop)
+ _copy_future_state(f_concurrent_invalid, f_asyncio_invalid)
+ self.assertTrue(f_asyncio_invalid.done())
+ with self.assertRaises(asyncio.exceptions.InvalidStateError) as cm:
+ f_asyncio_invalid.result()
+ self.assertEqual(str(cm.exception), "invalid")
+
def test_iter(self):
fut = self._new_future(loop=self.loop)
diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py
index 3bb3e5c4ca0..047f03cbb14 100644
--- a/Lib/test/test_asyncio/test_locks.py
+++ b/Lib/test/test_asyncio/test_locks.py
@@ -14,7 +14,7 @@ STR_RGX_REPR = (
r'(, value:\d)?'
r'(, waiters:\d+)?'
r'(, waiters:\d+\/\d+)?' # barrier
- r')\]>\Z'
+ r')\]>\z'
)
RGX_REPR = re.compile(STR_RGX_REPR)
diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py
index de81936b745..aab6a779170 100644
--- a/Lib/test/test_asyncio/test_selector_events.py
+++ b/Lib/test/test_asyncio/test_selector_events.py
@@ -347,6 +347,18 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
selectors.EVENT_WRITE)])
self.loop._remove_writer.assert_called_with(1)
+ def test_accept_connection_zero_one(self):
+ for backlog in [0, 1]:
+ sock = mock.Mock()
+ sock.accept.return_value = (mock.Mock(), mock.Mock())
+ with self.subTest(backlog):
+ mock_obj = mock.patch.object
+ with mock_obj(self.loop, '_accept_connection2') as accept2_mock:
+ self.loop._accept_connection(
+ mock.Mock(), sock, backlog=backlog)
+ self.loop.run_until_complete(asyncio.sleep(0))
+ self.assertEqual(sock.accept.call_count, backlog + 1)
+
def test_accept_connection_multiple(self):
sock = mock.Mock()
sock.accept.return_value = (mock.Mock(), mock.Mock())
@@ -362,7 +374,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
self.loop._accept_connection(
mock.Mock(), sock, backlog=backlog)
self.loop.run_until_complete(asyncio.sleep(0))
- self.assertEqual(sock.accept.call_count, backlog)
+ self.assertEqual(sock.accept.call_count, backlog + 1)
def test_accept_connection_skip_connectionabortederror(self):
sock = mock.Mock()
@@ -388,7 +400,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
# as in test_accept_connection_multiple avoid task pending
# warnings by using asyncio.sleep(0)
self.loop.run_until_complete(asyncio.sleep(0))
- self.assertEqual(sock.accept.call_count, backlog)
+ self.assertEqual(sock.accept.call_count, backlog + 1)
class SelectorTransportTests(test_utils.TestCase):
diff --git a/Lib/test/test_asyncio/test_ssl.py b/Lib/test/test_asyncio/test_ssl.py
index 986ecc2c5a9..3a7185cd897 100644
--- a/Lib/test/test_asyncio/test_ssl.py
+++ b/Lib/test/test_asyncio/test_ssl.py
@@ -195,9 +195,10 @@ class TestSSL(test_utils.TestCase):
except (BrokenPipeError, ConnectionError):
pass
- def test_create_server_ssl_1(self):
+ @support.bigmemtest(size=25, memuse=90*2**20, dry_run=False)
+ def test_create_server_ssl_1(self, size):
CNT = 0 # number of clients that were successful
- TOTAL_CNT = 25 # total number of clients that test will create
+ TOTAL_CNT = size # total number of clients that test will create
TIMEOUT = support.LONG_TIMEOUT # timeout for this test
A_DATA = b'A' * 1024 * BUF_MULTIPLIER
@@ -1038,9 +1039,10 @@ class TestSSL(test_utils.TestCase):
self.loop.run_until_complete(run_main())
- def test_create_server_ssl_over_ssl(self):
+ @support.bigmemtest(size=25, memuse=90*2**20, dry_run=False)
+ def test_create_server_ssl_over_ssl(self, size):
CNT = 0 # number of clients that were successful
- TOTAL_CNT = 25 # total number of clients that test will create
+ TOTAL_CNT = size # total number of clients that test will create
TIMEOUT = support.LONG_TIMEOUT # timeout for this test
A_DATA = b'A' * 1024 * BUF_MULTIPLIER
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py
index 8d7f1733454..f6f976f213a 100644
--- a/Lib/test/test_asyncio/test_tasks.py
+++ b/Lib/test/test_asyncio/test_tasks.py
@@ -89,8 +89,8 @@ class BaseTaskTests:
Future = None
all_tasks = None
- def new_task(self, loop, coro, name='TestTask', context=None):
- return self.__class__.Task(coro, loop=loop, name=name, context=context)
+ def new_task(self, loop, coro, name='TestTask', context=None, eager_start=None):
+ return self.__class__.Task(coro, loop=loop, name=name, context=context, eager_start=eager_start)
def new_future(self, loop):
return self.__class__.Future(loop=loop)
@@ -2116,6 +2116,46 @@ class BaseTaskTests:
self.assertTrue(outer.cancelled())
self.assertEqual(0, 0 if outer._callbacks is None else len(outer._callbacks))
+ def test_shield_cancel_outer_result(self):
+ mock_handler = mock.Mock()
+ self.loop.set_exception_handler(mock_handler)
+ inner = self.new_future(self.loop)
+ outer = asyncio.shield(inner)
+ test_utils.run_briefly(self.loop)
+ outer.cancel()
+ test_utils.run_briefly(self.loop)
+ inner.set_result(1)
+ test_utils.run_briefly(self.loop)
+ mock_handler.assert_not_called()
+
+ def test_shield_cancel_outer_exception(self):
+ mock_handler = mock.Mock()
+ self.loop.set_exception_handler(mock_handler)
+ inner = self.new_future(self.loop)
+ outer = asyncio.shield(inner)
+ test_utils.run_briefly(self.loop)
+ outer.cancel()
+ test_utils.run_briefly(self.loop)
+ inner.set_exception(Exception('foo'))
+ test_utils.run_briefly(self.loop)
+ mock_handler.assert_called_once()
+
+ def test_shield_duplicate_log_once(self):
+ mock_handler = mock.Mock()
+ self.loop.set_exception_handler(mock_handler)
+ inner = self.new_future(self.loop)
+ outer = asyncio.shield(inner)
+ test_utils.run_briefly(self.loop)
+ outer.cancel()
+ test_utils.run_briefly(self.loop)
+ outer = asyncio.shield(inner)
+ test_utils.run_briefly(self.loop)
+ outer.cancel()
+ test_utils.run_briefly(self.loop)
+ inner.set_exception(Exception('foo'))
+ test_utils.run_briefly(self.loop)
+ mock_handler.assert_called_once()
+
def test_shield_shortcut(self):
fut = self.new_future(self.loop)
fut.set_result(42)
@@ -2686,6 +2726,35 @@ class BaseTaskTests:
self.assertEqual([None, 1, 2], ret)
+ def test_eager_start_true(self):
+ name = None
+
+ async def asyncfn():
+ nonlocal name
+ name = self.current_task().get_name()
+
+ async def main():
+ t = self.new_task(coro=asyncfn(), loop=asyncio.get_running_loop(), eager_start=True, name="example")
+ self.assertTrue(t.done())
+ self.assertEqual(name, "example")
+ await t
+
+ def test_eager_start_false(self):
+ name = None
+
+ async def asyncfn():
+ nonlocal name
+ name = self.current_task().get_name()
+
+ async def main():
+ t = self.new_task(coro=asyncfn(), loop=asyncio.get_running_loop(), eager_start=False, name="example")
+ self.assertFalse(t.done())
+ self.assertIsNone(name)
+ await t
+ self.assertEqual(name, "example")
+
+ asyncio.run(main(), loop_factory=asyncio.EventLoop)
+
def test_get_coro(self):
loop = asyncio.new_event_loop()
coro = coroutine_function()
diff --git a/Lib/test/test_asyncio/test_tools.py b/Lib/test/test_asyncio/test_tools.py
new file mode 100644
index 00000000000..34e94830204
--- /dev/null
+++ b/Lib/test/test_asyncio/test_tools.py
@@ -0,0 +1,1706 @@
+import unittest
+
+from asyncio import tools
+
+from collections import namedtuple
+
+FrameInfo = namedtuple('FrameInfo', ['funcname', 'filename', 'lineno'])
+CoroInfo = namedtuple('CoroInfo', ['call_stack', 'task_name'])
+TaskInfo = namedtuple('TaskInfo', ['task_id', 'task_name', 'coroutine_stack', 'awaited_by'])
+AwaitedInfo = namedtuple('AwaitedInfo', ['thread_id', 'awaited_by'])
+
+
+# mock output of get_all_awaited_by function.
+TEST_INPUTS_TREE = [
+ [
+ # test case containing a task called timer being awaited in two
+ # different subtasks part of a TaskGroup (root1 and root2) which call
+ # awaiter functions.
+ (
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="timer",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("awaiter3", "/path/to/app.py", 130),
+ FrameInfo("awaiter2", "/path/to/app.py", 120),
+ FrameInfo("awaiter", "/path/to/app.py", 110)
+ ],
+ task_name=4
+ ),
+ CoroInfo(
+ call_stack=[
+ FrameInfo("awaiterB3", "/path/to/app.py", 190),
+ FrameInfo("awaiterB2", "/path/to/app.py", 180),
+ FrameInfo("awaiterB", "/path/to/app.py", 170)
+ ],
+ task_name=5
+ ),
+ CoroInfo(
+ call_stack=[
+ FrameInfo("awaiterB3", "/path/to/app.py", 190),
+ FrameInfo("awaiterB2", "/path/to/app.py", 180),
+ FrameInfo("awaiterB", "/path/to/app.py", 170)
+ ],
+ task_name=6
+ ),
+ CoroInfo(
+ call_stack=[
+ FrameInfo("awaiter3", "/path/to/app.py", 130),
+ FrameInfo("awaiter2", "/path/to/app.py", 120),
+ FrameInfo("awaiter", "/path/to/app.py", 110)
+ ],
+ task_name=7
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=8,
+ task_name="root1",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("main", "", 0)
+ ],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=9,
+ task_name="root2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("main", "", 0)
+ ],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="child1_1",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("blocho_caller", "", 0),
+ FrameInfo("bloch", "", 0)
+ ],
+ task_name=8
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=6,
+ task_name="child2_1",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("blocho_caller", "", 0),
+ FrameInfo("bloch", "", 0)
+ ],
+ task_name=8
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=7,
+ task_name="child1_2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("blocho_caller", "", 0),
+ FrameInfo("bloch", "", 0)
+ ],
+ task_name=9
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=5,
+ task_name="child2_2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("blocho_caller", "", 0),
+ FrameInfo("bloch", "", 0)
+ ],
+ task_name=9
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ),
+ (
+ [
+ [
+ "└── (T) Task-1",
+ " └── main",
+ " └── __aexit__",
+ " └── _aexit",
+ " ├── (T) root1",
+ " │ └── bloch",
+ " │ └── blocho_caller",
+ " │ └── __aexit__",
+ " │ └── _aexit",
+ " │ ├── (T) child1_1",
+ " │ │ └── awaiter /path/to/app.py:110",
+ " │ │ └── awaiter2 /path/to/app.py:120",
+ " │ │ └── awaiter3 /path/to/app.py:130",
+ " │ │ └── (T) timer",
+ " │ └── (T) child2_1",
+ " │ └── awaiterB /path/to/app.py:170",
+ " │ └── awaiterB2 /path/to/app.py:180",
+ " │ └── awaiterB3 /path/to/app.py:190",
+ " │ └── (T) timer",
+ " └── (T) root2",
+ " └── bloch",
+ " └── blocho_caller",
+ " └── __aexit__",
+ " └── _aexit",
+ " ├── (T) child1_2",
+ " │ └── awaiter /path/to/app.py:110",
+ " │ └── awaiter2 /path/to/app.py:120",
+ " │ └── awaiter3 /path/to/app.py:130",
+ " │ └── (T) timer",
+ " └── (T) child2_2",
+ " └── awaiterB /path/to/app.py:170",
+ " └── awaiterB2 /path/to/app.py:180",
+ " └── awaiterB3 /path/to/app.py:190",
+ " └── (T) timer",
+ ]
+ ]
+ ),
+ ],
+ [
+ # test case containing two roots
+ (
+ AwaitedInfo(
+ thread_id=9,
+ awaited_by=[
+ TaskInfo(
+ task_id=5,
+ task_name="Task-5",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=6,
+ task_name="Task-6",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main2", "", 0)],
+ task_name=5
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=7,
+ task_name="Task-7",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main2", "", 0)],
+ task_name=5
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=8,
+ task_name="Task-8",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main2", "", 0)],
+ task_name=5
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(
+ thread_id=10,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=2,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=1
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="Task-3",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=1
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="Task-4",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=1
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=11, awaited_by=[]),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ),
+ (
+ [
+ [
+ "└── (T) Task-5",
+ " └── main2",
+ " ├── (T) Task-6",
+ " ├── (T) Task-7",
+ " └── (T) Task-8",
+ ],
+ [
+ "└── (T) Task-1",
+ " └── main",
+ " ├── (T) Task-2",
+ " ├── (T) Task-3",
+ " └── (T) Task-4",
+ ],
+ ]
+ ),
+ ],
+ [
+ # test case containing two roots, one of them without subtasks
+ (
+ [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-5",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ ),
+ AwaitedInfo(
+ thread_id=3,
+ awaited_by=[
+ TaskInfo(
+ task_id=4,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=5,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=4
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=6,
+ task_name="Task-3",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=4
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=7,
+ task_name="Task-4",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=4
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=8, awaited_by=[]),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ]
+ ),
+ (
+ [
+ ["└── (T) Task-5"],
+ [
+ "└── (T) Task-1",
+ " └── main",
+ " ├── (T) Task-2",
+ " ├── (T) Task-3",
+ " └── (T) Task-4",
+ ],
+ ]
+ ),
+ ],
+]
+
+TEST_INPUTS_CYCLES_TREE = [
+ [
+ # this test case contains a cycle: two tasks awaiting each other.
+ (
+ [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="a",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("awaiter2", "", 0)],
+ task_name=4
+ ),
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="b",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("awaiter", "", 0)],
+ task_name=3
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ]
+ ),
+ ([[4, 3, 4]]),
+ ],
+ [
+ # this test case contains two cycles
+ (
+ [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="A",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0),
+ FrameInfo("task_b", "", 0)
+ ],
+ task_name=4
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="B",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0),
+ FrameInfo("task_c", "", 0)
+ ],
+ task_name=5
+ ),
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0),
+ FrameInfo("task_a", "", 0)
+ ],
+ task_name=3
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=5,
+ task_name="C",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0)
+ ],
+ task_name=6
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=6,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0),
+ FrameInfo("task_b", "", 0)
+ ],
+ task_name=4
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ]
+ ),
+ ([[4, 3, 4], [4, 6, 5, 4]]),
+ ],
+]
+
+TEST_INPUTS_TABLE = [
+ [
+ # test case containing a task called timer being awaited in two
+ # different subtasks part of a TaskGroup (root1 and root2) which call
+ # awaiter functions.
+ (
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="timer",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("awaiter3", "", 0),
+ FrameInfo("awaiter2", "", 0),
+ FrameInfo("awaiter", "", 0)
+ ],
+ task_name=4
+ ),
+ CoroInfo(
+ call_stack=[
+ FrameInfo("awaiter1_3", "", 0),
+ FrameInfo("awaiter1_2", "", 0),
+ FrameInfo("awaiter1", "", 0)
+ ],
+ task_name=5
+ ),
+ CoroInfo(
+ call_stack=[
+ FrameInfo("awaiter1_3", "", 0),
+ FrameInfo("awaiter1_2", "", 0),
+ FrameInfo("awaiter1", "", 0)
+ ],
+ task_name=6
+ ),
+ CoroInfo(
+ call_stack=[
+ FrameInfo("awaiter3", "", 0),
+ FrameInfo("awaiter2", "", 0),
+ FrameInfo("awaiter", "", 0)
+ ],
+ task_name=7
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=8,
+ task_name="root1",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("main", "", 0)
+ ],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=9,
+ task_name="root2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("main", "", 0)
+ ],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="child1_1",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("blocho_caller", "", 0),
+ FrameInfo("bloch", "", 0)
+ ],
+ task_name=8
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=6,
+ task_name="child2_1",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("blocho_caller", "", 0),
+ FrameInfo("bloch", "", 0)
+ ],
+ task_name=8
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=7,
+ task_name="child1_2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("blocho_caller", "", 0),
+ FrameInfo("bloch", "", 0)
+ ],
+ task_name=9
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=5,
+ task_name="child2_2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("_aexit", "", 0),
+ FrameInfo("__aexit__", "", 0),
+ FrameInfo("blocho_caller", "", 0),
+ FrameInfo("bloch", "", 0)
+ ],
+ task_name=9
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ),
+ (
+ [
+ [1, "0x2", "Task-1", "", "", "", "0x0"],
+ [
+ 1,
+ "0x3",
+ "timer",
+ "",
+ "awaiter3 -> awaiter2 -> awaiter",
+ "child1_1",
+ "0x4",
+ ],
+ [
+ 1,
+ "0x3",
+ "timer",
+ "",
+ "awaiter1_3 -> awaiter1_2 -> awaiter1",
+ "child2_2",
+ "0x5",
+ ],
+ [
+ 1,
+ "0x3",
+ "timer",
+ "",
+ "awaiter1_3 -> awaiter1_2 -> awaiter1",
+ "child2_1",
+ "0x6",
+ ],
+ [
+ 1,
+ "0x3",
+ "timer",
+ "",
+ "awaiter3 -> awaiter2 -> awaiter",
+ "child1_2",
+ "0x7",
+ ],
+ [
+ 1,
+ "0x8",
+ "root1",
+ "",
+ "_aexit -> __aexit__ -> main",
+ "Task-1",
+ "0x2",
+ ],
+ [
+ 1,
+ "0x9",
+ "root2",
+ "",
+ "_aexit -> __aexit__ -> main",
+ "Task-1",
+ "0x2",
+ ],
+ [
+ 1,
+ "0x4",
+ "child1_1",
+ "",
+ "_aexit -> __aexit__ -> blocho_caller -> bloch",
+ "root1",
+ "0x8",
+ ],
+ [
+ 1,
+ "0x6",
+ "child2_1",
+ "",
+ "_aexit -> __aexit__ -> blocho_caller -> bloch",
+ "root1",
+ "0x8",
+ ],
+ [
+ 1,
+ "0x7",
+ "child1_2",
+ "",
+ "_aexit -> __aexit__ -> blocho_caller -> bloch",
+ "root2",
+ "0x9",
+ ],
+ [
+ 1,
+ "0x5",
+ "child2_2",
+ "",
+ "_aexit -> __aexit__ -> blocho_caller -> bloch",
+ "root2",
+ "0x9",
+ ],
+ ]
+ ),
+ ],
+ [
+ # test case containing two roots
+ (
+ AwaitedInfo(
+ thread_id=9,
+ awaited_by=[
+ TaskInfo(
+ task_id=5,
+ task_name="Task-5",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=6,
+ task_name="Task-6",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main2", "", 0)],
+ task_name=5
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=7,
+ task_name="Task-7",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main2", "", 0)],
+ task_name=5
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=8,
+ task_name="Task-8",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main2", "", 0)],
+ task_name=5
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(
+ thread_id=10,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=2,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=1
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="Task-3",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=1
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="Task-4",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=1
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=11, awaited_by=[]),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ),
+ (
+ [
+ [9, "0x5", "Task-5", "", "", "", "0x0"],
+ [9, "0x6", "Task-6", "", "main2", "Task-5", "0x5"],
+ [9, "0x7", "Task-7", "", "main2", "Task-5", "0x5"],
+ [9, "0x8", "Task-8", "", "main2", "Task-5", "0x5"],
+ [10, "0x1", "Task-1", "", "", "", "0x0"],
+ [10, "0x2", "Task-2", "", "main", "Task-1", "0x1"],
+ [10, "0x3", "Task-3", "", "main", "Task-1", "0x1"],
+ [10, "0x4", "Task-4", "", "main", "Task-1", "0x1"],
+ ]
+ ),
+ ],
+ [
+ # test case containing two roots, one of them without subtasks
+ (
+ [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-5",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ ),
+ AwaitedInfo(
+ thread_id=3,
+ awaited_by=[
+ TaskInfo(
+ task_id=4,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=5,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=4
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=6,
+ task_name="Task-3",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=4
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=7,
+ task_name="Task-4",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=4
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=8, awaited_by=[]),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ]
+ ),
+ (
+ [
+ [1, "0x2", "Task-5", "", "", "", "0x0"],
+ [3, "0x4", "Task-1", "", "", "", "0x0"],
+ [3, "0x5", "Task-2", "", "main", "Task-1", "0x4"],
+ [3, "0x6", "Task-3", "", "main", "Task-1", "0x4"],
+ [3, "0x7", "Task-4", "", "main", "Task-1", "0x4"],
+ ]
+ ),
+ ],
+ # CASES WITH CYCLES
+ [
+ # this test case contains a cycle: two tasks awaiting each other.
+ (
+ [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="a",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("awaiter2", "", 0)],
+ task_name=4
+ ),
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="b",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("awaiter", "", 0)],
+ task_name=3
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ]
+ ),
+ (
+ [
+ [1, "0x2", "Task-1", "", "", "", "0x0"],
+ [1, "0x3", "a", "", "awaiter2", "b", "0x4"],
+ [1, "0x3", "a", "", "main", "Task-1", "0x2"],
+ [1, "0x4", "b", "", "awaiter", "a", "0x3"],
+ ]
+ ),
+ ],
+ [
+ # this test case contains two cycles
+ (
+ [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="A",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0),
+ FrameInfo("task_b", "", 0)
+ ],
+ task_name=4
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="B",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0),
+ FrameInfo("task_c", "", 0)
+ ],
+ task_name=5
+ ),
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0),
+ FrameInfo("task_a", "", 0)
+ ],
+ task_name=3
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=5,
+ task_name="C",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0)
+ ],
+ task_name=6
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=6,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("nested", "", 0),
+ FrameInfo("nested", "", 0),
+ FrameInfo("task_b", "", 0)
+ ],
+ task_name=4
+ )
+ ]
+ )
+ ]
+ ),
+ AwaitedInfo(thread_id=0, awaited_by=[])
+ ]
+ ),
+ (
+ [
+ [1, "0x2", "Task-1", "", "", "", "0x0"],
+ [
+ 1,
+ "0x3",
+ "A",
+ "",
+ "nested -> nested -> task_b",
+ "B",
+ "0x4",
+ ],
+ [
+ 1,
+ "0x4",
+ "B",
+ "",
+ "nested -> nested -> task_c",
+ "C",
+ "0x5",
+ ],
+ [
+ 1,
+ "0x4",
+ "B",
+ "",
+ "nested -> nested -> task_a",
+ "A",
+ "0x3",
+ ],
+ [
+ 1,
+ "0x5",
+ "C",
+ "",
+ "nested -> nested",
+ "Task-2",
+ "0x6",
+ ],
+ [
+ 1,
+ "0x6",
+ "Task-2",
+ "",
+ "nested -> nested -> task_b",
+ "B",
+ "0x4",
+ ],
+ ]
+ ),
+ ],
+]
+
+
+class TestAsyncioToolsTree(unittest.TestCase):
+ def test_asyncio_utils(self):
+ for input_, tree in TEST_INPUTS_TREE:
+ with self.subTest(input_):
+ result = tools.build_async_tree(input_)
+ self.assertEqual(result, tree)
+
+ def test_asyncio_utils_cycles(self):
+ for input_, cycles in TEST_INPUTS_CYCLES_TREE:
+ with self.subTest(input_):
+ try:
+ tools.build_async_tree(input_)
+ except tools.CycleFoundException as e:
+ self.assertEqual(e.cycles, cycles)
+
+
+class TestAsyncioToolsTable(unittest.TestCase):
+ def test_asyncio_utils(self):
+ for input_, table in TEST_INPUTS_TABLE:
+ with self.subTest(input_):
+ result = tools.build_task_table(input_)
+ self.assertEqual(result, table)
+
+
+class TestAsyncioToolsBasic(unittest.TestCase):
+ def test_empty_input_tree(self):
+ """Test build_async_tree with empty input."""
+ result = []
+ expected_output = []
+ self.assertEqual(tools.build_async_tree(result), expected_output)
+
+ def test_empty_input_table(self):
+ """Test build_task_table with empty input."""
+ result = []
+ expected_output = []
+ self.assertEqual(tools.build_task_table(result), expected_output)
+
+ def test_only_independent_tasks_tree(self):
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=10,
+ task_name="taskA",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=11,
+ task_name="taskB",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ expected = [["└── (T) taskA"], ["└── (T) taskB"]]
+ result = tools.build_async_tree(input_)
+ self.assertEqual(sorted(result), sorted(expected))
+
+ def test_only_independent_tasks_table(self):
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=10,
+ task_name="taskA",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=11,
+ task_name="taskB",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ self.assertEqual(
+ tools.build_task_table(input_),
+ [[1, '0xa', 'taskA', '', '', '', '0x0'], [1, '0xb', 'taskB', '', '', '', '0x0']]
+ )
+
+ def test_single_task_tree(self):
+ """Test build_async_tree with a single task and no awaits."""
+ result = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ expected_output = [
+ [
+ "└── (T) Task-1",
+ ]
+ ]
+ self.assertEqual(tools.build_async_tree(result), expected_output)
+
+ def test_single_task_table(self):
+ """Test build_task_table with a single task and no awaits."""
+ result = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ expected_output = [[1, '0x2', 'Task-1', '', '', '', '0x0']]
+ self.assertEqual(tools.build_task_table(result), expected_output)
+
+ def test_cycle_detection(self):
+ """Test build_async_tree raises CycleFoundException for cyclic input."""
+ result = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=3
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=2
+ )
+ ]
+ )
+ ]
+ )
+ ]
+ with self.assertRaises(tools.CycleFoundException) as context:
+ tools.build_async_tree(result)
+ self.assertEqual(context.exception.cycles, [[3, 2, 3]])
+
+ def test_complex_tree(self):
+ """Test build_async_tree with a more complex tree structure."""
+ result = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="Task-3",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=3
+ )
+ ]
+ )
+ ]
+ )
+ ]
+ expected_output = [
+ [
+ "└── (T) Task-1",
+ " └── main",
+ " └── (T) Task-2",
+ " └── main",
+ " └── (T) Task-3",
+ ]
+ ]
+ self.assertEqual(tools.build_async_tree(result), expected_output)
+
+ def test_complex_table(self):
+ """Test build_task_table with a more complex tree structure."""
+ result = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=2,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=4,
+ task_name="Task-3",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("main", "", 0)],
+ task_name=3
+ )
+ ]
+ )
+ ]
+ )
+ ]
+ expected_output = [
+ [1, '0x2', 'Task-1', '', '', '', '0x0'],
+ [1, '0x3', 'Task-2', '', 'main', 'Task-1', '0x2'],
+ [1, '0x4', 'Task-3', '', 'main', 'Task-2', '0x3']
+ ]
+ self.assertEqual(tools.build_task_table(result), expected_output)
+
+ def test_deep_coroutine_chain(self):
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=10,
+ task_name="leaf",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("c1", "", 0),
+ FrameInfo("c2", "", 0),
+ FrameInfo("c3", "", 0),
+ FrameInfo("c4", "", 0),
+ FrameInfo("c5", "", 0)
+ ],
+ task_name=11
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=11,
+ task_name="root",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ expected = [
+ [
+ "└── (T) root",
+ " └── c5",
+ " └── c4",
+ " └── c3",
+ " └── c2",
+ " └── c1",
+ " └── (T) leaf",
+ ]
+ ]
+ result = tools.build_async_tree(input_)
+ self.assertEqual(result, expected)
+
+ def test_multiple_cycles_same_node(self):
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="Task-A",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("call1", "", 0)],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=2,
+ task_name="Task-B",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("call2", "", 0)],
+ task_name=3
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="Task-C",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("call3", "", 0)],
+ task_name=1
+ ),
+ CoroInfo(
+ call_stack=[FrameInfo("call4", "", 0)],
+ task_name=2
+ )
+ ]
+ )
+ ]
+ )
+ ]
+ with self.assertRaises(tools.CycleFoundException) as ctx:
+ tools.build_async_tree(input_)
+ cycles = ctx.exception.cycles
+ self.assertTrue(any(set(c) == {1, 2, 3} for c in cycles))
+
+ def test_table_output_format(self):
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="Task-A",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("foo", "", 0)],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=2,
+ task_name="Task-B",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ table = tools.build_task_table(input_)
+ for row in table:
+ self.assertEqual(len(row), 7)
+ self.assertIsInstance(row[0], int) # thread ID
+ self.assertTrue(
+ isinstance(row[1], str) and row[1].startswith("0x")
+ ) # hex task ID
+ self.assertIsInstance(row[2], str) # task name
+ self.assertIsInstance(row[3], str) # coroutine stack
+ self.assertIsInstance(row[4], str) # coroutine chain
+ self.assertIsInstance(row[5], str) # awaiter name
+ self.assertTrue(
+ isinstance(row[6], str) and row[6].startswith("0x")
+ ) # hex awaiter ID
+
+
+class TestAsyncioToolsEdgeCases(unittest.TestCase):
+
+ def test_task_awaits_self(self):
+ """A task directly awaits itself - should raise a cycle."""
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="Self-Awaiter",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("loopback", "", 0)],
+ task_name=1
+ )
+ ]
+ )
+ ]
+ )
+ ]
+ with self.assertRaises(tools.CycleFoundException) as ctx:
+ tools.build_async_tree(input_)
+ self.assertIn([1, 1], ctx.exception.cycles)
+
+ def test_task_with_missing_awaiter_id(self):
+ """Awaiter ID not in task list - should not crash, just show 'Unknown'."""
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="Task-A",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("coro", "", 0)],
+ task_name=999
+ )
+ ]
+ )
+ ]
+ )
+ ]
+ table = tools.build_task_table(input_)
+ self.assertEqual(len(table), 1)
+ self.assertEqual(table[0][5], "Unknown")
+
+ def test_duplicate_coroutine_frames(self):
+ """Same coroutine frame repeated under a parent - should deduplicate."""
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="Task-1",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("frameA", "", 0)],
+ task_name=2
+ ),
+ CoroInfo(
+ call_stack=[FrameInfo("frameA", "", 0)],
+ task_name=3
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=2,
+ task_name="Task-2",
+ coroutine_stack=[],
+ awaited_by=[]
+ ),
+ TaskInfo(
+ task_id=3,
+ task_name="Task-3",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ tree = tools.build_async_tree(input_)
+ # Each awaiter yields its own tree: the duplicated "frameA" frame appears in both tree[0] and tree[1]
+ flat = "\n".join(tree[0])
+ self.assertIn("frameA", flat)
+ self.assertIn("Task-2", flat)
+ self.assertIn("Task-1", flat)
+
+ flat = "\n".join(tree[1])
+ self.assertIn("frameA", flat)
+ self.assertIn("Task-3", flat)
+ self.assertIn("Task-1", flat)
+
+ def test_task_with_no_name(self):
+ """Task with no name in id2name - should still render with fallback."""
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="root",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[FrameInfo("f1", "", 0)],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=2,
+ task_name=None,
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ # If name is None, fallback to string should not crash
+ tree = tools.build_async_tree(input_)
+ self.assertIn("(T) None", "\n".join(tree[0]))
+
+ def test_tree_rendering_with_custom_emojis(self):
+ """Pass custom emojis to the tree renderer."""
+ input_ = [
+ AwaitedInfo(
+ thread_id=1,
+ awaited_by=[
+ TaskInfo(
+ task_id=1,
+ task_name="MainTask",
+ coroutine_stack=[],
+ awaited_by=[
+ CoroInfo(
+ call_stack=[
+ FrameInfo("f1", "", 0),
+ FrameInfo("f2", "", 0)
+ ],
+ task_name=2
+ )
+ ]
+ ),
+ TaskInfo(
+ task_id=2,
+ task_name="SubTask",
+ coroutine_stack=[],
+ awaited_by=[]
+ )
+ ]
+ )
+ ]
+ tree = tools.build_async_tree(input_, task_emoji="🧵", cor_emoji="🔁")
+ flat = "\n".join(tree[0])
+ self.assertIn("🧵 MainTask", flat)
+ self.assertIn("🔁 f1", flat)
+ self.assertIn("🔁 f2", flat)
+ self.assertIn("🧵 SubTask", flat)
diff --git a/Lib/test/test_audit.py b/Lib/test/test_audit.py
index 2b24b5d7927..077765fcda2 100644
--- a/Lib/test/test_audit.py
+++ b/Lib/test/test_audit.py
@@ -134,7 +134,7 @@ class AuditTest(unittest.TestCase):
self.assertEqual(events[0][0], "socket.gethostname")
self.assertEqual(events[1][0], "socket.__new__")
self.assertEqual(events[2][0], "socket.bind")
- self.assertTrue(events[2][2].endswith("('127.0.0.1', 8080)"))
+ self.assertEndsWith(events[2][2], "('127.0.0.1', 8080)")
def test_gc(self):
returncode, events, stderr = self.run_python("test_gc")
@@ -322,6 +322,14 @@ class AuditTest(unittest.TestCase):
if returncode:
self.fail(stderr)
+ @support.support_remote_exec_only
+ @support.cpython_only
+ def test_sys_remote_exec(self):
+ returncode, events, stderr = self.run_python("test_sys_remote_exec")
+ self.assertTrue(any(["sys.remote_exec" in event for event in events]))
+ self.assertTrue(any(["cpython.remote_debugger_script" in event for event in events]))
+ if returncode:
+ self.fail(stderr)
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_base64.py b/Lib/test/test_base64.py
index 409c8c109e8..ce2e3e3726f 100644
--- a/Lib/test/test_base64.py
+++ b/Lib/test/test_base64.py
@@ -3,8 +3,16 @@ import base64
import binascii
import os
from array import array
+from test.support import cpython_only
from test.support import os_helper
from test.support import script_helper
+from test.support.import_helper import ensure_lazy_imports
+
+
+class LazyImportTest(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("base64", {"re", "getopt"})
class LegacyBase64TestCase(unittest.TestCase):
@@ -804,7 +812,7 @@ class BaseXYTestCase(unittest.TestCase):
self.assertRaises(ValueError, f, 'with non-ascii \xcb')
def test_ErrorHeritage(self):
- self.assertTrue(issubclass(binascii.Error, ValueError))
+ self.assertIsSubclass(binascii.Error, ValueError)
def test_RFC4648_test_cases(self):
# test cases from RFC 4648 section 10
diff --git a/Lib/test/test_baseexception.py b/Lib/test/test_baseexception.py
index e599b02c17d..12d4088842b 100644
--- a/Lib/test/test_baseexception.py
+++ b/Lib/test/test_baseexception.py
@@ -10,13 +10,11 @@ class ExceptionClassTests(unittest.TestCase):
inheritance hierarchy)"""
def test_builtins_new_style(self):
- self.assertTrue(issubclass(Exception, object))
+ self.assertIsSubclass(Exception, object)
def verify_instance_interface(self, ins):
for attr in ("args", "__str__", "__repr__"):
- self.assertTrue(hasattr(ins, attr),
- "%s missing %s attribute" %
- (ins.__class__.__name__, attr))
+ self.assertHasAttr(ins, attr)
def test_inheritance(self):
# Make sure the inheritance hierarchy matches the documentation
@@ -65,7 +63,7 @@ class ExceptionClassTests(unittest.TestCase):
elif last_depth > depth:
while superclasses[-1][0] >= depth:
superclasses.pop()
- self.assertTrue(issubclass(exc, superclasses[-1][1]),
+ self.assertIsSubclass(exc, superclasses[-1][1],
"%s is not a subclass of %s" % (exc.__name__,
superclasses[-1][1].__name__))
try: # Some exceptions require arguments; just skip them
diff --git a/Lib/test/test_binascii.py b/Lib/test/test_binascii.py
index 1f3b6746ce4..7ed7d7c47b6 100644
--- a/Lib/test/test_binascii.py
+++ b/Lib/test/test_binascii.py
@@ -38,13 +38,13 @@ class BinASCIITest(unittest.TestCase):
def test_exceptions(self):
# Check module exceptions
- self.assertTrue(issubclass(binascii.Error, Exception))
- self.assertTrue(issubclass(binascii.Incomplete, Exception))
+ self.assertIsSubclass(binascii.Error, Exception)
+ self.assertIsSubclass(binascii.Incomplete, Exception)
def test_functions(self):
# Check presence of all functions
for name in all_functions:
- self.assertTrue(hasattr(getattr(binascii, name), '__call__'))
+ self.assertHasAttr(getattr(binascii, name), '__call__')
self.assertRaises(TypeError, getattr(binascii, name))
def test_returned_value(self):
diff --git a/Lib/test/test_binop.py b/Lib/test/test_binop.py
index 299af09c498..b224c3d4e60 100644
--- a/Lib/test/test_binop.py
+++ b/Lib/test/test_binop.py
@@ -383,7 +383,7 @@ class OperationOrderTests(unittest.TestCase):
self.assertEqual(op_sequence(le, B, C), ['C.__ge__', 'B.__le__'])
self.assertEqual(op_sequence(le, C, B), ['C.__le__', 'B.__ge__'])
- self.assertTrue(issubclass(V, B))
+ self.assertIsSubclass(V, B)
self.assertEqual(op_sequence(eq, B, V), ['B.__eq__', 'V.__eq__'])
self.assertEqual(op_sequence(le, B, V), ['B.__le__', 'V.__ge__'])
diff --git a/Lib/test/test_buffer.py b/Lib/test/test_buffer.py
index 61921e93e85..19582e75716 100644
--- a/Lib/test/test_buffer.py
+++ b/Lib/test/test_buffer.py
@@ -2879,11 +2879,11 @@ class TestBufferProtocol(unittest.TestCase):
def test_memoryview_repr(self):
m = memoryview(bytearray(9))
r = m.__repr__()
- self.assertTrue(r.startswith("<memory"))
+ self.assertStartsWith(r, "<memory")
m.release()
r = m.__repr__()
- self.assertTrue(r.startswith("<released"))
+ self.assertStartsWith(r, "<released")
def test_memoryview_sequence(self):
diff --git a/Lib/test/test_bufio.py b/Lib/test/test_bufio.py
index dc9a82dc635..cb9cb4d0bc7 100644
--- a/Lib/test/test_bufio.py
+++ b/Lib/test/test_bufio.py
@@ -28,7 +28,7 @@ class BufferSizeTest:
f.write(b"\n")
f.write(s)
f.close()
- f = open(os_helper.TESTFN, "rb")
+ f = self.open(os_helper.TESTFN, "rb")
line = f.readline()
self.assertEqual(line, s + b"\n")
line = f.readline()
diff --git a/Lib/test/test_build_details.py b/Lib/test/test_build_details.py
index 05ce163a337..ba4b8c5aa9b 100644
--- a/Lib/test/test_build_details.py
+++ b/Lib/test/test_build_details.py
@@ -117,12 +117,26 @@ class CPythonBuildDetailsTests(unittest.TestCase, FormatTestsBase):
# Override generic format tests with tests for our specific implemenation.
@needs_installed_python
- @unittest.skipIf(is_android or is_apple_mobile, 'Android and iOS run tests via a custom testbed method that changes sys.executable')
+ @unittest.skipIf(
+ is_android or is_apple_mobile,
+ 'Android and iOS run tests via a custom testbed method that changes sys.executable'
+ )
def test_base_interpreter(self):
value = self.key('base_interpreter')
self.assertEqual(os.path.realpath(value), os.path.realpath(sys.executable))
+ @needs_installed_python
+ @unittest.skipIf(
+ is_android or is_apple_mobile,
+ "Android and iOS run tests via a custom testbed method that doesn't ship headers"
+ )
+ def test_c_api(self):
+ value = self.key('c_api')
+ self.assertTrue(os.path.exists(os.path.join(value['headers'], 'Python.h')))
+ version = sysconfig.get_config_var('VERSION')
+ self.assertTrue(os.path.exists(os.path.join(value['pkgconfig_path'], f'python-{version}.pc')))
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py
index 31597a320d4..14fe3355239 100644
--- a/Lib/test/test_builtin.py
+++ b/Lib/test/test_builtin.py
@@ -393,7 +393,7 @@ class BuiltinTest(ComplexesAreIdenticalMixin, unittest.TestCase):
self.assertRaises(ValueError, chr, -2**1000)
def test_cmp(self):
- self.assertTrue(not hasattr(builtins, "cmp"))
+ self.assertNotHasAttr(builtins, "cmp")
def test_compile(self):
compile('print(1)\n', '', 'exec')
@@ -1120,6 +1120,7 @@ class BuiltinTest(ComplexesAreIdenticalMixin, unittest.TestCase):
self.check_iter_pickle(f1, list(f2), proto)
@support.skip_wasi_stack_overflow()
+ @support.skip_emscripten_stack_overflow()
@support.requires_resource('cpu')
def test_filter_dealloc(self):
# Tests recursive deallocation of nested filter objects using the
@@ -2303,7 +2304,7 @@ class BuiltinTest(ComplexesAreIdenticalMixin, unittest.TestCase):
# tests for object.__format__ really belong elsewhere, but
# there's no good place to put them
x = object().__format__('')
- self.assertTrue(x.startswith('<object object at'))
+ self.assertStartsWith(x, '<object object at')
# first argument to object.__format__ must be string
self.assertRaises(TypeError, object().__format__, 3)
@@ -2990,7 +2991,8 @@ class TestType(unittest.TestCase):
def load_tests(loader, tests, pattern):
from doctest import DocTestSuite
- tests.addTest(DocTestSuite(builtins))
+ if sys.float_repr_style == 'short':
+ tests.addTest(DocTestSuite(builtins))
return tests
if __name__ == "__main__":
diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py
index 82d9916e38d..bb0f8aa99da 100644
--- a/Lib/test/test_bytes.py
+++ b/Lib/test/test_bytes.py
@@ -1974,9 +1974,9 @@ class AssortedBytesTest(unittest.TestCase):
@test.support.requires_docstrings
def test_doc(self):
self.assertIsNotNone(bytearray.__doc__)
- self.assertTrue(bytearray.__doc__.startswith("bytearray("), bytearray.__doc__)
+ self.assertStartsWith(bytearray.__doc__, "bytearray(")
self.assertIsNotNone(bytes.__doc__)
- self.assertTrue(bytes.__doc__.startswith("bytes("), bytes.__doc__)
+ self.assertStartsWith(bytes.__doc__, "bytes(")
def test_from_bytearray(self):
sample = bytes(b"Hello world\n\x80\x81\xfe\xff")
@@ -2107,7 +2107,7 @@ class BytesAsStringTest(FixedStringTest, unittest.TestCase):
class SubclassTest:
def test_basic(self):
- self.assertTrue(issubclass(self.type2test, self.basetype))
+ self.assertIsSubclass(self.type2test, self.basetype)
self.assertIsInstance(self.type2test(), self.basetype)
a, b = b"abcd", b"efgh"
@@ -2155,7 +2155,7 @@ class SubclassTest:
self.assertEqual(a.z, b.z)
self.assertEqual(type(a), type(b))
self.assertEqual(type(a.z), type(b.z))
- self.assertFalse(hasattr(b, 'y'))
+ self.assertNotHasAttr(b, 'y')
def test_copy(self):
a = self.type2test(b"abcd")
@@ -2169,7 +2169,7 @@ class SubclassTest:
self.assertEqual(a.z, b.z)
self.assertEqual(type(a), type(b))
self.assertEqual(type(a.z), type(b.z))
- self.assertFalse(hasattr(b, 'y'))
+ self.assertNotHasAttr(b, 'y')
def test_fromhex(self):
b = self.type2test.fromhex('1a2B30')
diff --git a/Lib/test/test_bz2.py b/Lib/test/test_bz2.py
index f32b24b39ba..3b7897b8a88 100644
--- a/Lib/test/test_bz2.py
+++ b/Lib/test/test_bz2.py
@@ -184,7 +184,7 @@ class BZ2FileTest(BaseTest):
with BZ2File(self.filename) as bz2f:
pdata = bz2f.peek()
self.assertNotEqual(len(pdata), 0)
- self.assertTrue(self.TEXT.startswith(pdata))
+ self.assertStartsWith(self.TEXT, pdata)
self.assertEqual(bz2f.read(), self.TEXT)
def testReadInto(self):
@@ -768,7 +768,7 @@ class BZ2FileTest(BaseTest):
with BZ2File(bio) as bz2f:
pdata = bz2f.peek()
self.assertNotEqual(len(pdata), 0)
- self.assertTrue(self.TEXT.startswith(pdata))
+ self.assertStartsWith(self.TEXT, pdata)
self.assertEqual(bz2f.read(), self.TEXT)
def testWriteBytesIO(self):
diff --git a/Lib/test/test_calendar.py b/Lib/test/test_calendar.py
index 073df310bb4..bc39c86b8cf 100644
--- a/Lib/test/test_calendar.py
+++ b/Lib/test/test_calendar.py
@@ -417,7 +417,7 @@ class OutputTestCase(unittest.TestCase):
self.check_htmlcalendar_encoding('utf-8', 'utf-8')
def test_output_htmlcalendar_encoding_default(self):
- self.check_htmlcalendar_encoding(None, sys.getdefaultencoding())
+ self.check_htmlcalendar_encoding(None, 'utf-8')
def test_yeardatescalendar(self):
def shrink(cal):
@@ -987,6 +987,7 @@ class CommandLineTestCase(unittest.TestCase):
self.assertCLIFails(*args)
self.assertCmdFails(*args)
+ @support.force_not_colorized
def test_help(self):
stdout = self.run_cmd_ok('-h')
self.assertIn(b'usage:', stdout)
@@ -1097,7 +1098,7 @@ class CommandLineTestCase(unittest.TestCase):
output = run('--type', 'text', '2004')
self.assertEqual(output, conv(result_2004_text))
output = run('--type', 'html', '2004')
- self.assertEqual(output[:6], b'<?xml ')
+ self.assertStartsWith(output, b'<?xml ')
self.assertIn(b'<title>Calendar for 2004</title>', output)
def test_html_output_current_year(self):
diff --git a/Lib/test/test_call.py b/Lib/test/test_call.py
index 185ae84dc4d..1c73aaafb71 100644
--- a/Lib/test/test_call.py
+++ b/Lib/test/test_call.py
@@ -695,8 +695,8 @@ class TestPEP590(unittest.TestCase):
UnaffectedType2 = _testcapi.make_vectorcall_class(SuperType)
# Aside: Quickly check that the C helper actually made derived types
- self.assertTrue(issubclass(UnaffectedType1, DerivedType))
- self.assertTrue(issubclass(UnaffectedType2, SuperType))
+ self.assertIsSubclass(UnaffectedType1, DerivedType)
+ self.assertIsSubclass(UnaffectedType2, SuperType)
# Initial state: tp_call
self.assertEqual(instance(), "tp_call")
diff --git a/Lib/test/test_capi/test_bytearray.py b/Lib/test/test_capi/test_bytearray.py
index dfa98de9f00..52565ea34c6 100644
--- a/Lib/test/test_capi/test_bytearray.py
+++ b/Lib/test/test_capi/test_bytearray.py
@@ -66,6 +66,7 @@ class CAPITest(unittest.TestCase):
# Test PyByteArray_FromObject()
fromobject = _testlimitedcapi.bytearray_fromobject
+ self.assertEqual(fromobject(b''), bytearray(b''))
self.assertEqual(fromobject(b'abc'), bytearray(b'abc'))
self.assertEqual(fromobject(bytearray(b'abc')), bytearray(b'abc'))
self.assertEqual(fromobject(ByteArraySubclass(b'abc')), bytearray(b'abc'))
@@ -115,6 +116,7 @@ class CAPITest(unittest.TestCase):
self.assertEqual(concat(b'abc', bytearray(b'def')), bytearray(b'abcdef'))
self.assertEqual(concat(bytearray(b'abc'), b''), bytearray(b'abc'))
self.assertEqual(concat(b'', bytearray(b'def')), bytearray(b'def'))
+ self.assertEqual(concat(bytearray(b''), bytearray(b'')), bytearray(b''))
self.assertEqual(concat(memoryview(b'xabcy')[1:4], b'def'),
bytearray(b'abcdef'))
self.assertEqual(concat(b'abc', memoryview(b'xdefy')[1:4]),
@@ -150,6 +152,10 @@ class CAPITest(unittest.TestCase):
self.assertEqual(resize(ba, 0), 0)
self.assertEqual(ba, bytearray())
+ ba = bytearray(b'')
+ self.assertEqual(resize(ba, 0), 0)
+ self.assertEqual(ba, bytearray())
+
ba = ByteArraySubclass(b'abcdef')
self.assertEqual(resize(ba, 3), 0)
self.assertEqual(ba, bytearray(b'abc'))
diff --git a/Lib/test/test_capi/test_bytes.py b/Lib/test/test_capi/test_bytes.py
index 5b61c733815..bc820bd68d9 100644
--- a/Lib/test/test_capi/test_bytes.py
+++ b/Lib/test/test_capi/test_bytes.py
@@ -22,6 +22,7 @@ class CAPITest(unittest.TestCase):
# Test PyBytes_Check()
check = _testlimitedcapi.bytes_check
self.assertTrue(check(b'abc'))
+ self.assertTrue(check(b''))
self.assertFalse(check('abc'))
self.assertFalse(check(bytearray(b'abc')))
self.assertTrue(check(BytesSubclass(b'abc')))
@@ -36,6 +37,7 @@ class CAPITest(unittest.TestCase):
# Test PyBytes_CheckExact()
check = _testlimitedcapi.bytes_checkexact
self.assertTrue(check(b'abc'))
+ self.assertTrue(check(b''))
self.assertFalse(check('abc'))
self.assertFalse(check(bytearray(b'abc')))
self.assertFalse(check(BytesSubclass(b'abc')))
@@ -79,6 +81,7 @@ class CAPITest(unittest.TestCase):
# Test PyBytes_FromObject()
fromobject = _testlimitedcapi.bytes_fromobject
+ self.assertEqual(fromobject(b''), b'')
self.assertEqual(fromobject(b'abc'), b'abc')
self.assertEqual(fromobject(bytearray(b'abc')), b'abc')
self.assertEqual(fromobject(BytesSubclass(b'abc')), b'abc')
@@ -108,6 +111,7 @@ class CAPITest(unittest.TestCase):
self.assertEqual(asstring(b'abc', 4), b'abc\0')
self.assertEqual(asstring(b'abc\0def', 8), b'abc\0def\0')
+ self.assertEqual(asstring(b'', 1), b'\0')
self.assertRaises(TypeError, asstring, 'abc', 0)
self.assertRaises(TypeError, asstring, object(), 0)
@@ -120,6 +124,7 @@ class CAPITest(unittest.TestCase):
self.assertEqual(asstringandsize(b'abc', 4), (b'abc\0', 3))
self.assertEqual(asstringandsize(b'abc\0def', 8), (b'abc\0def\0', 7))
+ self.assertEqual(asstringandsize(b'', 1), (b'\0', 0))
self.assertEqual(asstringandsize_null(b'abc', 4), b'abc\0')
self.assertRaises(ValueError, asstringandsize_null, b'abc\0def', 8)
self.assertRaises(TypeError, asstringandsize, 'abc', 0)
@@ -134,6 +139,7 @@ class CAPITest(unittest.TestCase):
# Test PyBytes_Repr()
bytes_repr = _testlimitedcapi.bytes_repr
+ self.assertEqual(bytes_repr(b'', 0), r"""b''""")
self.assertEqual(bytes_repr(b'''abc''', 0), r"""b'abc'""")
self.assertEqual(bytes_repr(b'''abc''', 1), r"""b'abc'""")
self.assertEqual(bytes_repr(b'''a'b"c"d''', 0), r"""b'a\'b"c"d'""")
@@ -163,6 +169,7 @@ class CAPITest(unittest.TestCase):
self.assertEqual(concat(b'', bytearray(b'def')), b'def')
self.assertEqual(concat(memoryview(b'xabcy')[1:4], b'def'), b'abcdef')
self.assertEqual(concat(b'abc', memoryview(b'xdefy')[1:4]), b'abcdef')
+ self.assertEqual(concat(b'', b''), b'')
self.assertEqual(concat(b'abc', b'def', True), b'abcdef')
self.assertEqual(concat(b'abc', bytearray(b'def'), True), b'abcdef')
@@ -192,6 +199,7 @@ class CAPITest(unittest.TestCase):
"""Test PyBytes_DecodeEscape()"""
decodeescape = _testlimitedcapi.bytes_decodeescape
+ self.assertEqual(decodeescape(b''), b'')
self.assertEqual(decodeescape(b'abc'), b'abc')
self.assertEqual(decodeescape(br'\t\n\r\x0b\x0c\x00\\\'\"'),
b'''\t\n\r\v\f\0\\'"''')
diff --git a/Lib/test/test_capi/test_config.py b/Lib/test/test_capi/test_config.py
index bf351c4defa..04a27de8d84 100644
--- a/Lib/test/test_capi/test_config.py
+++ b/Lib/test/test_capi/test_config.py
@@ -3,7 +3,6 @@ Tests PyConfig_Get() and PyConfig_Set() C API (PEP 741).
"""
import os
import sys
-import sysconfig
import types
import unittest
from test import support
@@ -57,7 +56,7 @@ class CAPITests(unittest.TestCase):
("home", str | None, None),
("thread_inherit_context", int, None),
("context_aware_warnings", int, None),
- ("import_time", bool, None),
+ ("import_time", int, None),
("inspect", bool, None),
("install_signal_handlers", bool, None),
("int_max_str_digits", int, None),
diff --git a/Lib/test/test_capi/test_float.py b/Lib/test/test_capi/test_float.py
index c832207fef0..f7efe0d0254 100644
--- a/Lib/test/test_capi/test_float.py
+++ b/Lib/test/test_capi/test_float.py
@@ -199,13 +199,13 @@ class CAPIFloatTest(unittest.TestCase):
signaling = 0
quiet = int(not signaling)
if size == 8:
- payload = random.randint(signaling, 1 << 50)
+ payload = random.randint(signaling, 0x7ffffffffffff)
i = (sign << 63) + (0x7ff << 52) + (quiet << 51) + payload
elif size == 4:
- payload = random.randint(signaling, 1 << 21)
+ payload = random.randint(signaling, 0x3fffff)
i = (sign << 31) + (0xff << 23) + (quiet << 22) + payload
elif size == 2:
- payload = random.randint(signaling, 1 << 8)
+ payload = random.randint(signaling, 0x1ff)
i = (sign << 15) + (0x1f << 10) + (quiet << 9) + payload
data = bytes.fromhex(f'{i:x}')
for endian in (BIG_ENDIAN, LITTLE_ENDIAN):
diff --git a/Lib/test/test_capi/test_import.py b/Lib/test/test_capi/test_import.py
index 25136624ca4..57e0316fda8 100644
--- a/Lib/test/test_capi/test_import.py
+++ b/Lib/test/test_capi/test_import.py
@@ -134,7 +134,7 @@ class ImportTests(unittest.TestCase):
# CRASHES importmodule(NULL)
def test_importmodulenoblock(self):
- # Test deprecated PyImport_ImportModuleNoBlock()
+ # Test deprecated (stable ABI only) PyImport_ImportModuleNoBlock()
importmodulenoblock = _testlimitedcapi.PyImport_ImportModuleNoBlock
with check_warnings(('', DeprecationWarning)):
self.check_import_func(importmodulenoblock)
diff --git a/Lib/test/test_capi/test_misc.py b/Lib/test/test_capi/test_misc.py
index 98dc3b42ef0..ef950f5df04 100644
--- a/Lib/test/test_capi/test_misc.py
+++ b/Lib/test/test_capi/test_misc.py
@@ -306,7 +306,7 @@ class CAPITest(unittest.TestCase):
CURRENT_THREAD_REGEX +
r' File .*, line 6 in <module>\n'
r'\n'
- r'Extension modules: _testcapi, _testinternalcapi \(total: 2\)\n')
+ r'Extension modules: _testcapi \(total: 1\)\n')
else:
# Python built with NDEBUG macro defined:
# test _Py_CheckFunctionResult() instead.
@@ -412,10 +412,14 @@ class CAPITest(unittest.TestCase):
L = MyList((L,))
@support.requires_resource('cpu')
+ @support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_trashcan_python_class1(self):
self.do_test_trashcan_python_class(list)
@support.requires_resource('cpu')
+ @support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_trashcan_python_class2(self):
from _testcapi import MyList
self.do_test_trashcan_python_class(MyList)
diff --git a/Lib/test/test_capi/test_object.py b/Lib/test/test_capi/test_object.py
index 54a01ac7c4a..d4056727d07 100644
--- a/Lib/test/test_capi/test_object.py
+++ b/Lib/test/test_capi/test_object.py
@@ -174,6 +174,16 @@ class EnableDeferredRefcountingTest(unittest.TestCase):
self.assertTrue(_testinternalcapi.has_deferred_refcount(silly_list))
+class IsUniquelyReferencedTest(unittest.TestCase):
+ """Test PyUnstable_Object_IsUniquelyReferenced"""
+ def test_is_uniquely_referenced(self):
+ self.assertTrue(_testcapi.is_uniquely_referenced(object()))
+ self.assertTrue(_testcapi.is_uniquely_referenced([]))
+ # Immortals
+ self.assertFalse(_testcapi.is_uniquely_referenced(()))
+ self.assertFalse(_testcapi.is_uniquely_referenced(42))
+ # CRASHES is_uniquely_referenced(NULL)
+
class CAPITest(unittest.TestCase):
def check_negative_refcount(self, code):
# bpo-35059: Check that Py_DECREF() reports the correct filename
@@ -211,6 +221,7 @@ class CAPITest(unittest.TestCase):
"""
self.check_negative_refcount(code)
+ @support.requires_resource('cpu')
def test_decref_delayed(self):
# gh-130519: Test that _PyObject_XDecRefDelayed() and QSBR code path
# handles destructors that are possibly re-entrant or trigger a GC.
diff --git a/Lib/test/test_capi/test_opt.py b/Lib/test/test_capi/test_opt.py
index 7e0c60d5522..e4c9a463855 100644
--- a/Lib/test/test_capi/test_opt.py
+++ b/Lib/test/test_capi/test_opt.py
@@ -407,12 +407,12 @@ class TestUops(unittest.TestCase):
x = 0
for i in range(m):
for j in MyIter(n):
- x += 1000*i + j
+ x += j
return x
- x = testfunc(TIER2_THRESHOLD, TIER2_THRESHOLD)
+ x = testfunc(TIER2_THRESHOLD, 2)
- self.assertEqual(x, sum(range(TIER2_THRESHOLD)) * TIER2_THRESHOLD * 1001)
+ self.assertEqual(x, sum(range(TIER2_THRESHOLD)) * 2)
ex = get_first_executor(testfunc)
self.assertIsNotNone(ex)
@@ -678,7 +678,7 @@ class TestUopsOptimization(unittest.TestCase):
self.assertLessEqual(len(guard_nos_float_count), 1)
# TODO gh-115506: this assertion may change after propagating constants.
# We'll also need to verify that propagation actually occurs.
- self.assertIn("_BINARY_OP_ADD_FLOAT", uops)
+ self.assertIn("_BINARY_OP_ADD_FLOAT__NO_DECREF_INPUTS", uops)
def test_float_subtract_constant_propagation(self):
def testfunc(n):
@@ -700,7 +700,7 @@ class TestUopsOptimization(unittest.TestCase):
self.assertLessEqual(len(guard_nos_float_count), 1)
# TODO gh-115506: this assertion may change after propagating constants.
# We'll also need to verify that propagation actually occurs.
- self.assertIn("_BINARY_OP_SUBTRACT_FLOAT", uops)
+ self.assertIn("_BINARY_OP_SUBTRACT_FLOAT__NO_DECREF_INPUTS", uops)
def test_float_multiply_constant_propagation(self):
def testfunc(n):
@@ -722,7 +722,7 @@ class TestUopsOptimization(unittest.TestCase):
self.assertLessEqual(len(guard_nos_float_count), 1)
# TODO gh-115506: this assertion may change after propagating constants.
# We'll also need to verify that propagation actually occurs.
- self.assertIn("_BINARY_OP_MULTIPLY_FLOAT", uops)
+ self.assertIn("_BINARY_OP_MULTIPLY_FLOAT__NO_DECREF_INPUTS", uops)
def test_add_unicode_propagation(self):
def testfunc(n):
@@ -1183,6 +1183,17 @@ class TestUopsOptimization(unittest.TestCase):
self.assertIsNotNone(ex)
self.assertIn("_RETURN_GENERATOR", get_opnames(ex))
+ def test_for_iter(self):
+ def testfunc(n):
+ t = 0
+ for i in set(range(n)):
+ t += i
+ return t
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD * (TIER2_THRESHOLD - 1) // 2)
+ self.assertIsNotNone(ex)
+ self.assertIn("_FOR_ITER_TIER_TWO", get_opnames(ex))
+
@unittest.skip("Tracing into generators currently isn't supported.")
def test_for_iter_gen(self):
def gen(n):
@@ -1280,8 +1291,8 @@ class TestUopsOptimization(unittest.TestCase):
self.assertIsNotNone(ex)
self.assertEqual(res, TIER2_THRESHOLD * 6 + 1)
call = opnames.index("_CALL_BUILTIN_FAST")
- load_attr_top = opnames.index("_LOAD_ATTR_NONDESCRIPTOR_WITH_VALUES", 0, call)
- load_attr_bottom = opnames.index("_LOAD_ATTR_NONDESCRIPTOR_WITH_VALUES", call)
+ load_attr_top = opnames.index("_POP_TOP_LOAD_CONST_INLINE_BORROW", 0, call)
+ load_attr_bottom = opnames.index("_POP_TOP_LOAD_CONST_INLINE_BORROW", call)
self.assertEqual(opnames[:load_attr_top].count("_GUARD_TYPE_VERSION"), 1)
self.assertEqual(opnames[call:load_attr_bottom].count("_CHECK_VALIDITY"), 2)
@@ -1303,8 +1314,8 @@ class TestUopsOptimization(unittest.TestCase):
self.assertIsNotNone(ex)
self.assertEqual(res, TIER2_THRESHOLD * 2)
call = opnames.index("_CALL_BUILTIN_FAST_WITH_KEYWORDS")
- load_attr_top = opnames.index("_LOAD_ATTR_NONDESCRIPTOR_WITH_VALUES", 0, call)
- load_attr_bottom = opnames.index("_LOAD_ATTR_NONDESCRIPTOR_WITH_VALUES", call)
+ load_attr_top = opnames.index("_POP_TOP_LOAD_CONST_INLINE_BORROW", 0, call)
+ load_attr_bottom = opnames.index("_POP_TOP_LOAD_CONST_INLINE_BORROW", call)
self.assertEqual(opnames[:load_attr_top].count("_GUARD_TYPE_VERSION"), 1)
self.assertEqual(opnames[call:load_attr_bottom].count("_CHECK_VALIDITY"), 2)
@@ -1370,6 +1381,21 @@ class TestUopsOptimization(unittest.TestCase):
# Removed guard
self.assertNotIn("_CHECK_FUNCTION_EXACT_ARGS", uops)
+ def test_method_guards_removed_or_reduced(self):
+ def testfunc(n):
+ result = 0
+ for i in range(n):
+ result += test_bound_method(i)
+ return result
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, sum(range(TIER2_THRESHOLD)))
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertIn("_PUSH_FRAME", uops)
+ # Strength reduced version
+ self.assertIn("_CHECK_FUNCTION_VERSION_INLINE", uops)
+ self.assertNotIn("_CHECK_METHOD_VERSION", uops)
+
def test_jit_error_pops(self):
"""
Tests that the correct number of pops are inserted into the
@@ -1655,13 +1681,11 @@ class TestUopsOptimization(unittest.TestCase):
self.assertIn("_CONTAINS_OP_DICT", uops)
self.assertNotIn("_TO_BOOL_BOOL", uops)
-
def test_remove_guard_for_known_type_str(self):
def f(n):
for i in range(n):
false = i == TIER2_THRESHOLD
empty = "X"[:false]
- empty += "" # Make JIT realize this is a string.
if empty:
return 1
return 0
@@ -1767,11 +1791,12 @@ class TestUopsOptimization(unittest.TestCase):
self.assertNotIn("_GUARD_TOS_UNICODE", uops)
self.assertIn("_BINARY_OP_ADD_UNICODE", uops)
- def test_call_type_1(self):
+ def test_call_type_1_guards_removed(self):
def testfunc(n):
x = 0
for _ in range(n):
- x += type(42) is int
+ foo = eval('42')
+ x += type(foo) is int
return x
res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
@@ -1782,6 +1807,25 @@ class TestUopsOptimization(unittest.TestCase):
self.assertNotIn("_GUARD_NOS_NULL", uops)
self.assertNotIn("_GUARD_CALLABLE_TYPE_1", uops)
+ def test_call_type_1_known_type(self):
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ x += type(42) is int
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ # When the result of type(...) is known, _CALL_TYPE_1 is replaced with
+ # _POP_CALL_ONE_LOAD_CONST_INLINE_BORROW which is optimized away in
+ # remove_unneeded_uops.
+ self.assertNotIn("_CALL_TYPE_1", uops)
+ self.assertNotIn("_POP_CALL_ONE_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_TOP_LOAD_CONST_INLINE_BORROW", uops)
+
def test_call_type_1_result_is_const(self):
def testfunc(n):
x = 0
@@ -1795,7 +1839,6 @@ class TestUopsOptimization(unittest.TestCase):
self.assertEqual(res, TIER2_THRESHOLD)
self.assertIsNotNone(ex)
uops = get_opnames(ex)
- self.assertIn("_CALL_TYPE_1", uops)
self.assertNotIn("_GUARD_IS_NOT_NONE_POP", uops)
def test_call_str_1(self):
@@ -1919,9 +1962,98 @@ class TestUopsOptimization(unittest.TestCase):
_, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
uops = get_opnames(ex)
+ self.assertNotIn("_GUARD_NOS_NULL", uops)
+ self.assertNotIn("_GUARD_CALLABLE_LEN", uops)
+ self.assertIn("_CALL_LEN", uops)
self.assertNotIn("_GUARD_NOS_INT", uops)
self.assertNotIn("_GUARD_TOS_INT", uops)
+
+ def test_call_len_known_length_small_int(self):
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ t = (1, 2, 3, 4, 5)
+ if len(t) == 5:
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ # When the length is < _PY_NSMALLPOSINTS, the len() call is replaced
+ # with just an inline load.
+ self.assertNotIn("_CALL_LEN", uops)
+ self.assertNotIn("_POP_CALL_ONE_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_TOP_LOAD_CONST_INLINE_BORROW", uops)
+
+ def test_call_len_known_length(self):
+ def testfunc(n):
+ class C:
+ t = tuple(range(300))
+
+ x = 0
+ for _ in range(n):
+ if len(C.t) == 300: # comparison + guard removed
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ # When the length is >= _PY_NSMALLPOSINTS, we cannot replace
+ # the len() call with an inline load, but knowing the exact
+ # length allows us to optimize more code, such as conditionals
+ # in this case
self.assertIn("_CALL_LEN", uops)
+ self.assertNotIn("_COMPARE_OP_INT", uops)
+ self.assertNotIn("_GUARD_IS_TRUE_POP", uops)
+
+ def test_get_len_with_const_tuple(self):
+ def testfunc(n):
+ x = 0.0
+ for _ in range(n):
+ match (1, 2, 3, 4):
+ case [_, _, _, _]:
+ x += 1.0
+ return x
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(int(res), TIER2_THRESHOLD)
+ uops = get_opnames(ex)
+ self.assertNotIn("_GUARD_NOS_INT", uops)
+ self.assertNotIn("_GET_LEN", uops)
+ self.assertIn("_LOAD_CONST_INLINE_BORROW", uops)
+
+ def test_get_len_with_non_const_tuple(self):
+ def testfunc(n):
+ x = 0.0
+ for _ in range(n):
+ match object(), object():
+ case [_, _]:
+ x += 1.0
+ return x
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(int(res), TIER2_THRESHOLD)
+ uops = get_opnames(ex)
+ self.assertNotIn("_GUARD_NOS_INT", uops)
+ self.assertNotIn("_GET_LEN", uops)
+ self.assertIn("_LOAD_CONST_INLINE_BORROW", uops)
+
+ def test_get_len_with_non_tuple(self):
+ def testfunc(n):
+ x = 0.0
+ for _ in range(n):
+ match [1, 2, 3, 4]:
+ case [_, _, _, _]:
+ x += 1.0
+ return x
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(int(res), TIER2_THRESHOLD)
+ uops = get_opnames(ex)
+ self.assertNotIn("_GUARD_NOS_INT", uops)
+ self.assertIn("_GET_LEN", uops)
def test_binary_op_subscr_tuple_int(self):
def testfunc(n):
@@ -1940,9 +2072,394 @@ class TestUopsOptimization(unittest.TestCase):
self.assertNotIn("_COMPARE_OP_INT", uops)
self.assertNotIn("_GUARD_IS_TRUE_POP", uops)
+ def test_call_isinstance_guards_removed(self):
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ y = isinstance(42, int)
+ if y:
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertNotIn("_CALL_ISINSTANCE", uops)
+ self.assertNotIn("_GUARD_THIRD_NULL", uops)
+ self.assertNotIn("_GUARD_CALLABLE_ISINSTANCE", uops)
+ self.assertNotIn("_POP_TOP_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_ONE_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_TWO_LOAD_CONST_INLINE_BORROW", uops)
+
+ def test_call_list_append(self):
+ def testfunc(n):
+ a = []
+ for i in range(n):
+ a.append(i)
+ return sum(a)
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, sum(range(TIER2_THRESHOLD)))
+ uops = get_opnames(ex)
+ self.assertIn("_CALL_LIST_APPEND", uops)
+ # We should remove these in the future
+ self.assertIn("_GUARD_NOS_LIST", uops)
+ self.assertIn("_GUARD_CALLABLE_LIST_APPEND", uops)
+
+ def test_call_isinstance_is_true(self):
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ y = isinstance(42, int)
+ if y:
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertNotIn("_CALL_ISINSTANCE", uops)
+ self.assertNotIn("_TO_BOOL_BOOL", uops)
+ self.assertNotIn("_GUARD_IS_TRUE_POP", uops)
+ self.assertNotIn("_POP_TOP_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_ONE_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_TWO_LOAD_CONST_INLINE_BORROW", uops)
+
+ def test_call_isinstance_is_false(self):
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ y = isinstance(42, str)
+ if not y:
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertNotIn("_CALL_ISINSTANCE", uops)
+ self.assertNotIn("_TO_BOOL_BOOL", uops)
+ self.assertNotIn("_GUARD_IS_FALSE_POP", uops)
+ self.assertNotIn("_POP_TOP_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_ONE_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_TWO_LOAD_CONST_INLINE_BORROW", uops)
+
+ def test_call_isinstance_subclass(self):
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ y = isinstance(True, int)
+ if y:
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertNotIn("_CALL_ISINSTANCE", uops)
+ self.assertNotIn("_TO_BOOL_BOOL", uops)
+ self.assertNotIn("_GUARD_IS_TRUE_POP", uops)
+ self.assertNotIn("_POP_TOP_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_ONE_LOAD_CONST_INLINE_BORROW", uops)
+ self.assertNotIn("_POP_CALL_TWO_LOAD_CONST_INLINE_BORROW", uops)
+
+ def test_call_isinstance_unknown_object(self):
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ # The optimizer doesn't know the return type here:
+ bar = eval("42")
+ # This will only narrow to bool:
+ y = isinstance(bar, int)
+ if y:
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertIn("_CALL_ISINSTANCE", uops)
+ self.assertNotIn("_TO_BOOL_BOOL", uops)
+ self.assertIn("_GUARD_IS_TRUE_POP", uops)
+
+ def test_call_isinstance_tuple_of_classes(self):
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ # A tuple of classes is currently not optimized,
+ # so this is only narrowed to bool:
+ y = isinstance(42, (int, str))
+ if y:
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertIn("_CALL_ISINSTANCE", uops)
+ self.assertNotIn("_TO_BOOL_BOOL", uops)
+ self.assertIn("_GUARD_IS_TRUE_POP", uops)
+
+ def test_call_isinstance_metaclass(self):
+ class EvenNumberMeta(type):
+ def __instancecheck__(self, number):
+ return number % 2 == 0
+
+ class EvenNumber(metaclass=EvenNumberMeta):
+ pass
+
+ def testfunc(n):
+ x = 0
+ for _ in range(n):
+ # Only narrowed to bool
+ y = isinstance(42, EvenNumber)
+ if y:
+ x += 1
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertIn("_CALL_ISINSTANCE", uops)
+ self.assertNotIn("_TO_BOOL_BOOL", uops)
+ self.assertIn("_GUARD_IS_TRUE_POP", uops)
+
+ def test_set_type_version_sets_type(self):
+ class C:
+ A = 1
+
+ def testfunc(n):
+ x = 0
+ c = C()
+ for _ in range(n):
+ x += c.A # Guarded.
+ x += type(c).A # Unguarded!
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, 2 * TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertIn("_GUARD_TYPE_VERSION", uops)
+ self.assertNotIn("_CHECK_ATTR_CLASS", uops)
+
+ def test_load_small_int(self):
+ def testfunc(n):
+ x = 0
+ for i in range(n):
+ x += 1
+ return x
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertNotIn("_LOAD_SMALL_INT", uops)
+ self.assertIn("_LOAD_CONST_INLINE_BORROW", uops)
+
+ def test_cached_attributes(self):
+ class C:
+ A = 1
+ def m(self):
+ return 1
+ class D:
+ __slots__ = ()
+ A = 1
+ def m(self):
+ return 1
+ class E(Exception):
+ def m(self):
+ return 1
+ def f(n):
+ x = 0
+ c = C()
+ d = D()
+ e = E()
+ for _ in range(n):
+ x += C.A # _LOAD_ATTR_CLASS
+ x += c.A # _LOAD_ATTR_NONDESCRIPTOR_WITH_VALUES
+ x += d.A # _LOAD_ATTR_NONDESCRIPTOR_NO_DICT
+ x += c.m() # _LOAD_ATTR_METHOD_WITH_VALUES
+ x += d.m() # _LOAD_ATTR_METHOD_NO_DICT
+ x += e.m() # _LOAD_ATTR_METHOD_LAZY_DICT
+ return x
+
+ res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
+ self.assertEqual(res, 6 * TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertNotIn("_LOAD_ATTR_CLASS", uops)
+ self.assertNotIn("_LOAD_ATTR_NONDESCRIPTOR_WITH_VALUES", uops)
+ self.assertNotIn("_LOAD_ATTR_NONDESCRIPTOR_NO_DICT", uops)
+ self.assertNotIn("_LOAD_ATTR_METHOD_WITH_VALUES", uops)
+ self.assertNotIn("_LOAD_ATTR_METHOD_NO_DICT", uops)
+ self.assertNotIn("_LOAD_ATTR_METHOD_LAZY_DICT", uops)
+
+ def test_float_op_refcount_elimination(self):
+ def testfunc(args):
+ a, b, n = args
+ c = 0.0
+ for _ in range(n):
+ c += a + b
+ return c
+
+ res, ex = self._run_with_optimizer(testfunc, (0.1, 0.1, TIER2_THRESHOLD))
+ self.assertAlmostEqual(res, TIER2_THRESHOLD * (0.1 + 0.1))
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertIn("_BINARY_OP_ADD_FLOAT__NO_DECREF_INPUTS", uops)
+
+ def test_remove_guard_for_slice_list(self):
+ def f(n):
+ for i in range(n):
+ false = i == TIER2_THRESHOLD
+ sliced = [1, 2, 3][:false]
+ if sliced:
+ return 1
+ return 0
+
+ res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
+ self.assertEqual(res, 0)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertIn("_TO_BOOL_LIST", uops)
+ self.assertNotIn("_GUARD_TOS_LIST", uops)
+
+ def test_remove_guard_for_slice_tuple(self):
+ def f(n):
+ for i in range(n):
+ false = i == TIER2_THRESHOLD
+ a, b = (1, 2, 3)[: false + 2]
+
+ _, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertIn("_UNPACK_SEQUENCE_TWO_TUPLE", uops)
+ self.assertNotIn("_GUARD_TOS_TUPLE", uops)
+
+ def test_unary_invert_long_type(self):
+ def testfunc(n):
+ for _ in range(n):
+ a = 9397
+ x = ~a + ~a
+
+ testfunc(TIER2_THRESHOLD)
+
+ ex = get_first_executor(testfunc)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+
+ self.assertNotIn("_GUARD_TOS_INT", uops)
+ self.assertNotIn("_GUARD_NOS_INT", uops)
+
+ def test_attr_promotion_failure(self):
+ # We're not testing for any specific uops here, just
+ # testing it doesn't crash.
+ script_helper.assert_python_ok('-c', textwrap.dedent("""
+ import _testinternalcapi
+ import _opcode
+ import email
+
+ def get_first_executor(func):
+ code = func.__code__
+ co_code = code.co_code
+ for i in range(0, len(co_code), 2):
+ try:
+ return _opcode.get_executor(code, i)
+ except ValueError:
+ pass
+ return None
+
+ def testfunc(n):
+ for _ in range(n):
+ email.jit_testing = None
+ prompt = email.jit_testing
+ del email.jit_testing
+
+
+ testfunc(_testinternalcapi.TIER2_THRESHOLD)
+ ex = get_first_executor(testfunc)
+ assert ex is not None
+ """))
+
+ def test_pop_top_specialize_none(self):
+ def testfunc(n):
+ for _ in range(n):
+ global_identity(None)
+
+ testfunc(TIER2_THRESHOLD)
+
+ ex = get_first_executor(testfunc)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+
+ self.assertIn("_POP_TOP_NOP", uops)
+
+ def test_pop_top_specialize_int(self):
+ def testfunc(n):
+ for _ in range(n):
+ global_identity(100000)
+
+ testfunc(TIER2_THRESHOLD)
+
+ ex = get_first_executor(testfunc)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+
+ self.assertIn("_POP_TOP_INT", uops)
+
+ def test_pop_top_specialize_float(self):
+ def testfunc(n):
+ for _ in range(n):
+ global_identity(1e6)
+
+ testfunc(TIER2_THRESHOLD)
+
+ ex = get_first_executor(testfunc)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+
+ self.assertIn("_POP_TOP_FLOAT", uops)
+
+
+ def test_unary_negative_long_float_type(self):
+ def testfunc(n):
+ for _ in range(n):
+ a = 9397
+ f = 9397.0
+ x = -a + -a
+ y = -f + -f
+
+ testfunc(TIER2_THRESHOLD)
+
+ ex = get_first_executor(testfunc)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+
+ self.assertNotIn("_GUARD_TOS_INT", uops)
+ self.assertNotIn("_GUARD_NOS_INT", uops)
+ self.assertNotIn("_GUARD_TOS_FLOAT", uops)
+ self.assertNotIn("_GUARD_NOS_FLOAT", uops)
def global_identity(x):
return x
+class TestObject:
+ def test(self, *args, **kwargs):
+ return args[0]
+
+test_object = TestObject()
+test_bound_method = TestObject.test.__get__(test_object)
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_capi/test_sys.py b/Lib/test/test_capi/test_sys.py
index d3a9b378e77..3793ce2461e 100644
--- a/Lib/test/test_capi/test_sys.py
+++ b/Lib/test/test_capi/test_sys.py
@@ -19,6 +19,68 @@ class CAPITest(unittest.TestCase):
maxDiff = None
+ @unittest.skipIf(_testlimitedcapi is None, 'need _testlimitedcapi module')
+ def test_sys_getattr(self):
+ # Test PySys_GetAttr()
+ sys_getattr = _testlimitedcapi.sys_getattr
+
+ self.assertIs(sys_getattr('stdout'), sys.stdout)
+ with support.swap_attr(sys, '\U0001f40d', 42):
+ self.assertEqual(sys_getattr('\U0001f40d'), 42)
+
+ with self.assertRaisesRegex(RuntimeError, r'lost sys\.nonexistent'):
+ sys_getattr('nonexistent')
+ with self.assertRaisesRegex(RuntimeError, r'lost sys\.\U0001f40d'):
+ sys_getattr('\U0001f40d')
+ self.assertRaises(TypeError, sys_getattr, 1)
+ self.assertRaises(TypeError, sys_getattr, [])
+ # CRASHES sys_getattr(NULL)
+
+ @unittest.skipIf(_testlimitedcapi is None, 'need _testlimitedcapi module')
+ def test_sys_getattrstring(self):
+ # Test PySys_GetAttrString()
+ getattrstring = _testlimitedcapi.sys_getattrstring
+
+ self.assertIs(getattrstring(b'stdout'), sys.stdout)
+ with support.swap_attr(sys, '\U0001f40d', 42):
+ self.assertEqual(getattrstring('\U0001f40d'.encode()), 42)
+
+ with self.assertRaisesRegex(RuntimeError, r'lost sys\.nonexistent'):
+ getattrstring(b'nonexistent')
+ with self.assertRaisesRegex(RuntimeError, r'lost sys\.\U0001f40d'):
+ getattrstring('\U0001f40d'.encode())
+ self.assertRaises(UnicodeDecodeError, getattrstring, b'\xff')
+ # CRASHES getattrstring(NULL)
+
+ @unittest.skipIf(_testlimitedcapi is None, 'need _testlimitedcapi module')
+ def test_sys_getoptionalattr(self):
+ # Test PySys_GetOptionalAttr()
+ getoptionalattr = _testlimitedcapi.sys_getoptionalattr
+
+ self.assertIs(getoptionalattr('stdout'), sys.stdout)
+ with support.swap_attr(sys, '\U0001f40d', 42):
+ self.assertEqual(getoptionalattr('\U0001f40d'), 42)
+
+ self.assertIs(getoptionalattr('nonexistent'), AttributeError)
+ self.assertIs(getoptionalattr('\U0001f40d'), AttributeError)
+ self.assertRaises(TypeError, getoptionalattr, 1)
+ self.assertRaises(TypeError, getoptionalattr, [])
+ # CRASHES getoptionalattr(NULL)
+
+ @unittest.skipIf(_testlimitedcapi is None, 'need _testlimitedcapi module')
+ def test_sys_getoptionalattrstring(self):
+ # Test PySys_GetOptionalAttrString()
+ getoptionalattrstring = _testlimitedcapi.sys_getoptionalattrstring
+
+ self.assertIs(getoptionalattrstring(b'stdout'), sys.stdout)
+ with support.swap_attr(sys, '\U0001f40d', 42):
+ self.assertEqual(getoptionalattrstring('\U0001f40d'.encode()), 42)
+
+ self.assertIs(getoptionalattrstring(b'nonexistent'), AttributeError)
+ self.assertIs(getoptionalattrstring('\U0001f40d'.encode()), AttributeError)
+ self.assertRaises(UnicodeDecodeError, getoptionalattrstring, b'\xff')
+ # CRASHES getoptionalattrstring(NULL)
+
@support.cpython_only
@unittest.skipIf(_testlimitedcapi is None, 'need _testlimitedcapi module')
def test_sys_getobject(self):
@@ -29,7 +91,7 @@ class CAPITest(unittest.TestCase):
with support.swap_attr(sys, '\U0001f40d', 42):
self.assertEqual(getobject('\U0001f40d'.encode()), 42)
- self.assertIs(getobject(b'nonexisting'), AttributeError)
+ self.assertIs(getobject(b'nonexistent'), AttributeError)
with support.catch_unraisable_exception() as cm:
self.assertIs(getobject(b'\xff'), AttributeError)
self.assertEqual(cm.unraisable.exc_type, UnicodeDecodeError)
diff --git a/Lib/test/test_capi/test_type.py b/Lib/test/test_capi/test_type.py
index 3c9974c7387..15fb4a93e2a 100644
--- a/Lib/test/test_capi/test_type.py
+++ b/Lib/test/test_capi/test_type.py
@@ -264,3 +264,13 @@ class TypeTests(unittest.TestCase):
ManualHeapType = _testcapi.ManualHeapType
for i in range(100):
self.assertIsInstance(ManualHeapType(), ManualHeapType)
+
+ def test_extension_managed_dict_type(self):
+ ManagedDictType = _testcapi.ManagedDictType
+ obj = ManagedDictType()
+ obj.foo = 42
+ self.assertEqual(obj.foo, 42)
+ self.assertEqual(obj.__dict__, {'foo': 42})
+ obj.__dict__ = {'bar': 3}
+ self.assertEqual(obj.__dict__, {'bar': 3})
+ self.assertEqual(obj.bar, 3)
diff --git a/Lib/test/test_capi/test_unicode.py b/Lib/test/test_capi/test_unicode.py
index 3408c10f426..6a9c60f3a6d 100644
--- a/Lib/test/test_capi/test_unicode.py
+++ b/Lib/test/test_capi/test_unicode.py
@@ -1739,6 +1739,20 @@ class CAPITest(unittest.TestCase):
# Check that the second call returns the same result
self.assertEqual(getargs_s_hash(s), chr(k).encode() * (i + 1))
+ @support.cpython_only
+ @unittest.skipIf(_testcapi is None, 'need _testcapi module')
+ def test_GET_CACHED_HASH(self):
+ from _testcapi import unicode_GET_CACHED_HASH
+ content_bytes = b'some new string'
+ # avoid parser interning & constant folding
+ obj = str(content_bytes, 'ascii')
+ # impl detail: fresh strings do not have cached hash
+ self.assertEqual(unicode_GET_CACHED_HASH(obj), -1)
+ # impl detail: adding string to a dict caches its hash
+ {obj: obj}
+ # impl detail: ASCII string hashes are equal to bytes ones
+ self.assertEqual(unicode_GET_CACHED_HASH(obj), hash(content_bytes))
+
class PyUnicodeWriterTest(unittest.TestCase):
def create_writer(self, size):
@@ -1776,6 +1790,13 @@ class PyUnicodeWriterTest(unittest.TestCase):
self.assertEqual(writer.finish(),
"ascii-latin1=\xE9-euro=\u20AC.")
+ def test_ascii(self):
+ writer = self.create_writer(0)
+ writer.write_ascii(b"Hello ", -1)
+ writer.write_ascii(b"", 0)
+ writer.write_ascii(b"Python! <truncated>", 6)
+ self.assertEqual(writer.finish(), "Hello Python")
+
def test_invalid_utf8(self):
writer = self.create_writer(0)
with self.assertRaises(UnicodeDecodeError):
diff --git a/Lib/test/test_class.py b/Lib/test/test_class.py
index 4c12d43556f..8c7a62a74ba 100644
--- a/Lib/test/test_class.py
+++ b/Lib/test/test_class.py
@@ -652,6 +652,7 @@ class ClassTests(unittest.TestCase):
a = A(hash(A.f)^(-1))
hash(a.f)
+ @cpython_only
def testSetattrWrapperNameIntern(self):
# Issue #25794: __setattr__ should intern the attribute name
class A:
diff --git a/Lib/test/test_clinic.py b/Lib/test/test_clinic.py
index 0c99620e27c..580d54e0eb0 100644
--- a/Lib/test/test_clinic.py
+++ b/Lib/test/test_clinic.py
@@ -238,11 +238,11 @@ class ClinicWholeFileTest(TestCase):
# The generated output will differ for every run, but we can check that
# it starts with the clinic block, we check that it contains all the
# expected fields, and we check that it contains the checksum line.
- self.assertTrue(out.startswith(dedent("""
+ self.assertStartsWith(out, dedent("""
/*[clinic input]
output print 'I told you once.'
[clinic start generated code]*/
- """)))
+ """))
fields = {
"cpp_endif",
"cpp_if",
@@ -259,9 +259,7 @@ class ClinicWholeFileTest(TestCase):
with self.subTest(field=field):
self.assertIn(field, out)
last_line = out.rstrip().split("\n")[-1]
- self.assertTrue(
- last_line.startswith("/*[clinic end generated code: output=")
- )
+ self.assertStartsWith(last_line, "/*[clinic end generated code: output=")
def test_directive_wrong_arg_number(self):
raw = dedent("""
@@ -2705,8 +2703,7 @@ class ClinicExternalTest(TestCase):
# Note, we cannot check the entire fail msg, because the path to
# the tmp file will change for every run.
_, err = self.expect_failure(fn)
- self.assertTrue(err.endswith(fail_msg),
- f"{err!r} does not end with {fail_msg!r}")
+ self.assertEndsWith(err, fail_msg)
# Then, force regeneration; success expected.
out = self.expect_success("-f", fn)
self.assertEqual(out, "")
@@ -2717,8 +2714,7 @@ class ClinicExternalTest(TestCase):
)
with open(fn, encoding='utf-8') as f:
generated = f.read()
- self.assertTrue(generated.endswith(checksum),
- (generated, checksum))
+ self.assertEndsWith(generated, checksum)
def test_cli_make(self):
c_code = dedent("""
@@ -2835,6 +2831,10 @@ class ClinicExternalTest(TestCase):
"size_t",
"slice_index",
"str",
+ "uint16",
+ "uint32",
+ "uint64",
+ "uint8",
"unicode",
"unsigned_char",
"unsigned_int",
@@ -2863,8 +2863,8 @@ class ClinicExternalTest(TestCase):
# param may change (it's a set, thus unordered). So, let's compare the
# start and end of the expected output, and then assert that the
# converters appear lined up in alphabetical order.
- self.assertTrue(out.startswith(prelude), out)
- self.assertTrue(out.endswith(finale), out)
+ self.assertStartsWith(out, prelude)
+ self.assertEndsWith(out, finale)
out = out.removeprefix(prelude)
out = out.removesuffix(finale)
@@ -2872,10 +2872,7 @@ class ClinicExternalTest(TestCase):
for converter, line in zip(expected_converters, lines):
line = line.lstrip()
with self.subTest(converter=converter):
- self.assertTrue(
- line.startswith(converter),
- f"expected converter {converter!r}, got {line!r}"
- )
+ self.assertStartsWith(line, converter)
def test_cli_fail_converters_and_filename(self):
_, err = self.expect_failure("--converters", "test.c")
@@ -2985,7 +2982,7 @@ class ClinicFunctionalTest(unittest.TestCase):
regex = (
fr"Passing( more than)?( [0-9]+)? positional argument(s)? to "
fr"{re.escape(name)}\(\) is deprecated. Parameters? {pnames} will "
- fr"become( a)? keyword-only parameters? in Python 3\.14"
+ fr"become( a)? keyword-only parameters? in Python 3\.37"
)
self.check_depr(regex, fn, *args, **kwds)
@@ -2998,7 +2995,7 @@ class ClinicFunctionalTest(unittest.TestCase):
regex = (
fr"Passing keyword argument{pl} {pnames} to "
fr"{re.escape(name)}\(\) is deprecated. Parameter{pl} {pnames} "
- fr"will become positional-only in Python 3\.14."
+ fr"will become positional-only in Python 3\.37."
)
self.check_depr(regex, fn, *args, **kwds)
@@ -3782,9 +3779,9 @@ class ClinicFunctionalTest(unittest.TestCase):
fn("a", b="b", c="c", d="d", e="e", f="f", g="g", h="h")
errmsg = (
"Passing more than 1 positional argument to depr_star_multi() is deprecated. "
- "Parameter 'b' will become a keyword-only parameter in Python 3.16. "
- "Parameters 'c' and 'd' will become keyword-only parameters in Python 3.15. "
- "Parameters 'e', 'f' and 'g' will become keyword-only parameters in Python 3.14.")
+ "Parameter 'b' will become a keyword-only parameter in Python 3.39. "
+ "Parameters 'c' and 'd' will become keyword-only parameters in Python 3.38. "
+ "Parameters 'e', 'f' and 'g' will become keyword-only parameters in Python 3.37.")
check = partial(self.check_depr, re.escape(errmsg), fn)
check("a", "b", c="c", d="d", e="e", f="f", g="g", h="h")
check("a", "b", "c", d="d", e="e", f="f", g="g", h="h")
@@ -3883,9 +3880,9 @@ class ClinicFunctionalTest(unittest.TestCase):
fn("a", "b", "c", "d", "e", "f", "g", h="h")
errmsg = (
"Passing keyword arguments 'b', 'c', 'd', 'e', 'f' and 'g' to depr_kwd_multi() is deprecated. "
- "Parameter 'b' will become positional-only in Python 3.14. "
- "Parameters 'c' and 'd' will become positional-only in Python 3.15. "
- "Parameters 'e', 'f' and 'g' will become positional-only in Python 3.16.")
+ "Parameter 'b' will become positional-only in Python 3.37. "
+ "Parameters 'c' and 'd' will become positional-only in Python 3.38. "
+ "Parameters 'e', 'f' and 'g' will become positional-only in Python 3.39.")
check = partial(self.check_depr, re.escape(errmsg), fn)
check("a", "b", "c", "d", "e", "f", g="g", h="h")
check("a", "b", "c", "d", "e", f="f", g="g", h="h")
@@ -3900,8 +3897,8 @@ class ClinicFunctionalTest(unittest.TestCase):
self.assertRaises(TypeError, fn, "a", "b", "c", "d", "e", "f", "g")
errmsg = (
"Passing more than 4 positional arguments to depr_multi() is deprecated. "
- "Parameter 'e' will become a keyword-only parameter in Python 3.15. "
- "Parameter 'f' will become a keyword-only parameter in Python 3.14.")
+ "Parameter 'e' will become a keyword-only parameter in Python 3.38. "
+ "Parameter 'f' will become a keyword-only parameter in Python 3.37.")
check = partial(self.check_depr, re.escape(errmsg), fn)
check("a", "b", "c", "d", "e", "f", g="g")
check("a", "b", "c", "d", "e", f="f", g="g")
@@ -3909,8 +3906,8 @@ class ClinicFunctionalTest(unittest.TestCase):
fn("a", "b", "c", d="d", e="e", f="f", g="g")
errmsg = (
"Passing keyword arguments 'b' and 'c' to depr_multi() is deprecated. "
- "Parameter 'b' will become positional-only in Python 3.14. "
- "Parameter 'c' will become positional-only in Python 3.15.")
+ "Parameter 'b' will become positional-only in Python 3.37. "
+ "Parameter 'c' will become positional-only in Python 3.38.")
check = partial(self.check_depr, re.escape(errmsg), fn)
check("a", "b", c="c", d="d", e="e", f="f", g="g")
check("a", b="b", c="c", d="d", e="e", f="f", g="g")
diff --git a/Lib/test/test_cmd.py b/Lib/test/test_cmd.py
index 46ec82b7049..dbfec42fc21 100644
--- a/Lib/test/test_cmd.py
+++ b/Lib/test/test_cmd.py
@@ -11,9 +11,15 @@ import unittest
import io
import textwrap
from test import support
-from test.support.import_helper import import_module
+from test.support.import_helper import ensure_lazy_imports, import_module
from test.support.pty_helper import run_pty
+class LazyImportTest(unittest.TestCase):
+ @support.cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("cmd", {"inspect", "string"})
+
+
class samplecmdclass(cmd.Cmd):
"""
Instance the sampleclass:
@@ -289,6 +295,30 @@ class CmdTestReadline(unittest.TestCase):
self.assertIn(b'ab_completion_test', output)
self.assertIn(b'tab completion success', output)
+ def test_bang_completion_without_do_shell(self):
+ script = textwrap.dedent("""
+ import cmd
+ class simplecmd(cmd.Cmd):
+ def completedefault(self, text, line, begidx, endidx):
+ return ["hello"]
+
+ def default(self, line):
+ if line.replace(" ", "") == "!hello":
+ print('tab completion success')
+ else:
+ print('tab completion failure')
+ return True
+
+ simplecmd().cmdloop()
+ """)
+
+ # '! h' or '!h' and complete 'ello' to 'hello'
+ for input in [b"! h\t\n", b"!h\t\n"]:
+ with self.subTest(input=input):
+ output = run_pty(script, input)
+ self.assertIn(b'hello', output)
+ self.assertIn(b'tab completion success', output)
+
def load_tests(loader, tests, pattern):
tests.addTest(doctest.DocTestSuite())
return tests
diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py
index 36f87e259e7..c17d749d4a1 100644
--- a/Lib/test/test_cmd_line.py
+++ b/Lib/test/test_cmd_line.py
@@ -39,7 +39,8 @@ class CmdLineTest(unittest.TestCase):
def verify_valid_flag(self, cmd_line):
rc, out, err = assert_python_ok(cmd_line)
- self.assertTrue(out == b'' or out.endswith(b'\n'))
+ if out != b'':
+ self.assertEndsWith(out, b'\n')
self.assertNotIn(b'Traceback', out)
self.assertNotIn(b'Traceback', err)
return out
@@ -89,8 +90,8 @@ class CmdLineTest(unittest.TestCase):
version = ('Python %d.%d' % sys.version_info[:2]).encode("ascii")
for switch in '-V', '--version', '-VV':
rc, out, err = assert_python_ok(switch)
- self.assertFalse(err.startswith(version))
- self.assertTrue(out.startswith(version))
+ self.assertNotStartsWith(err, version)
+ self.assertStartsWith(out, version)
def test_verbose(self):
# -v causes imports to write to stderr. If the write to
@@ -380,7 +381,7 @@ class CmdLineTest(unittest.TestCase):
p.stdin.flush()
data, rc = _kill_python_and_exit_code(p)
self.assertEqual(rc, 0)
- self.assertTrue(data.startswith(b'x'), data)
+ self.assertStartsWith(data, b'x')
def test_large_PYTHONPATH(self):
path1 = "ABCDE" * 100
@@ -972,10 +973,25 @@ class CmdLineTest(unittest.TestCase):
@unittest.skipUnless(support.MS_WINDOWS, 'Test only applicable on Windows')
def test_python_legacy_windows_stdio(self):
- code = "import sys; print(sys.stdin.encoding, sys.stdout.encoding)"
- expected = 'cp'
- rc, out, err = assert_python_ok('-c', code, PYTHONLEGACYWINDOWSSTDIO='1')
- self.assertIn(expected.encode(), out)
+ # Test that _WindowsConsoleIO is used when PYTHONLEGACYWINDOWSSTDIO
+ # is not set.
+ # We cannot use PIPE because it prevents creating a new console.
+ # So we use exit code.
+ code = "import sys; sys.exit(type(sys.stdout.buffer.raw).__name__ != '_WindowsConsoleIO')"
+ env = os.environ.copy()
+ env["PYTHONLEGACYWINDOWSSTDIO"] = ""
+ p = subprocess.run([sys.executable, "-c", code],
+ creationflags=subprocess.CREATE_NEW_CONSOLE,
+ env=env)
+ self.assertEqual(p.returncode, 0)
+
+ # Then test that FileIO is used when PYTHONLEGACYWINDOWSSTDIO is set.
+ code = "import sys; sys.exit(type(sys.stdout.buffer.raw).__name__ != 'FileIO')"
+ env["PYTHONLEGACYWINDOWSSTDIO"] = "1"
+ p = subprocess.run([sys.executable, "-c", code],
+ creationflags=subprocess.CREATE_NEW_CONSOLE,
+ env=env)
+ self.assertEqual(p.returncode, 0)
@unittest.skipIf("-fsanitize" in sysconfig.get_config_vars().get('PY_CFLAGS', ()),
"PYTHONMALLOCSTATS doesn't work with ASAN")
@@ -1024,7 +1040,7 @@ class CmdLineTest(unittest.TestCase):
stderr=subprocess.PIPE,
text=True)
err_msg = "Unknown option: --unknown-option\nusage: "
- self.assertTrue(proc.stderr.startswith(err_msg), proc.stderr)
+ self.assertStartsWith(proc.stderr, err_msg)
self.assertNotEqual(proc.returncode, 0)
def test_int_max_str_digits(self):
@@ -1158,6 +1174,24 @@ class CmdLineTest(unittest.TestCase):
res = assert_python_ok('-c', code, PYTHON_CPU_COUNT='default')
self.assertEqual(self.res2int(res), (os.cpu_count(), os.process_cpu_count()))
+ def test_import_time(self):
+ # os is not imported at startup
+ code = 'import os; import os'
+
+ for case in 'importtime', 'importtime=1', 'importtime=true':
+ res = assert_python_ok('-X', case, '-c', code)
+ res_err = res.err.decode('utf-8')
+ self.assertRegex(res_err, r'import time: \s*\d+ \| \s*\d+ \| \s*os')
+ self.assertNotRegex(res_err, r'import time: cached\s* \| cached\s* \| os')
+
+ res = assert_python_ok('-X', 'importtime=2', '-c', code)
+ res_err = res.err.decode('utf-8')
+ self.assertRegex(res_err, r'import time: \s*\d+ \| \s*\d+ \| \s*os')
+ self.assertRegex(res_err, r'import time: cached\s* \| cached\s* \| os')
+
+ assert_python_failure('-X', 'importtime=-1', '-c', code)
+ assert_python_failure('-X', 'importtime=3', '-c', code)
+
def res2int(self, res):
out = res.out.strip().decode("utf-8")
return tuple(int(i) for i in out.split())
diff --git a/Lib/test/test_cmd_line_script.py b/Lib/test/test_cmd_line_script.py
index 53dc9b1a7ef..784c45aa96f 100644
--- a/Lib/test/test_cmd_line_script.py
+++ b/Lib/test/test_cmd_line_script.py
@@ -553,9 +553,9 @@ class CmdLineTest(unittest.TestCase):
exitcode, stdout, stderr = assert_python_failure(script_name)
text = stderr.decode('ascii').split('\n')
self.assertEqual(len(text), 5)
- self.assertTrue(text[0].startswith('Traceback'))
- self.assertTrue(text[1].startswith(' File '))
- self.assertTrue(text[3].startswith('NameError'))
+ self.assertStartsWith(text[0], 'Traceback')
+ self.assertStartsWith(text[1], ' File ')
+ self.assertStartsWith(text[3], 'NameError')
def test_non_ascii(self):
# Apple platforms deny the creation of a file with an invalid UTF-8 name.
@@ -708,9 +708,8 @@ class CmdLineTest(unittest.TestCase):
exitcode, stdout, stderr = assert_python_failure(script_name)
text = io.TextIOWrapper(io.BytesIO(stderr), 'ascii').read()
# It used to crash in https://github.com/python/cpython/issues/111132
- self.assertTrue(text.endswith(
- 'SyntaxError: nonlocal declaration not allowed at module level\n',
- ), text)
+ self.assertEndsWith(text,
+ 'SyntaxError: nonlocal declaration not allowed at module level\n')
def test_consistent_sys_path_for_direct_execution(self):
# This test case ensures that the following all give the same
diff --git a/Lib/test/test_code.py b/Lib/test/test_code.py
index 7cf09ee7847..655f5a9be7f 100644
--- a/Lib/test/test_code.py
+++ b/Lib/test/test_code.py
@@ -220,6 +220,7 @@ try:
import _testinternalcapi
except ModuleNotFoundError:
_testinternalcapi = None
+import test._code_definitions as defs
COPY_FREE_VARS = opmap['COPY_FREE_VARS']
@@ -671,9 +672,82 @@ class CodeTest(unittest.TestCase):
VARARGS = CO_FAST_LOCAL | CO_FAST_ARG_VAR | CO_FAST_ARG_POS
VARKWARGS = CO_FAST_LOCAL | CO_FAST_ARG_VAR | CO_FAST_ARG_KW
- import test._code_definitions as defs
funcs = {
+ defs.simple_script: {},
+ defs.complex_script: {
+ 'obj': CO_FAST_LOCAL,
+ 'pickle': CO_FAST_LOCAL,
+ 'spam_minimal': CO_FAST_LOCAL,
+ 'data': CO_FAST_LOCAL,
+ 'res': CO_FAST_LOCAL,
+ },
+ defs.script_with_globals: {
+ 'obj1': CO_FAST_LOCAL,
+ 'obj2': CO_FAST_LOCAL,
+ },
+ defs.script_with_explicit_empty_return: {},
+ defs.script_with_return: {},
defs.spam_minimal: {},
+ defs.spam_with_builtins: {
+ 'x': CO_FAST_LOCAL,
+ 'values': CO_FAST_LOCAL,
+ 'checks': CO_FAST_LOCAL,
+ 'res': CO_FAST_LOCAL,
+ },
+ defs.spam_with_globals_and_builtins: {
+ 'func1': CO_FAST_LOCAL,
+ 'func2': CO_FAST_LOCAL,
+ 'funcs': CO_FAST_LOCAL,
+ 'checks': CO_FAST_LOCAL,
+ 'res': CO_FAST_LOCAL,
+ },
+ defs.spam_with_global_and_attr_same_name: {},
+ defs.spam_full_args: {
+ 'a': POSONLY,
+ 'b': POSONLY,
+ 'c': POSORKW,
+ 'd': POSORKW,
+ 'e': KWONLY,
+ 'f': KWONLY,
+ 'args': VARARGS,
+ 'kwargs': VARKWARGS,
+ },
+ defs.spam_full_args_with_defaults: {
+ 'a': POSONLY,
+ 'b': POSONLY,
+ 'c': POSORKW,
+ 'd': POSORKW,
+ 'e': KWONLY,
+ 'f': KWONLY,
+ 'args': VARARGS,
+ 'kwargs': VARKWARGS,
+ },
+ defs.spam_args_attrs_and_builtins: {
+ 'a': POSONLY,
+ 'b': POSONLY,
+ 'c': POSORKW,
+ 'd': POSORKW,
+ 'e': KWONLY,
+ 'f': KWONLY,
+ 'args': VARARGS,
+ 'kwargs': VARKWARGS,
+ },
+ defs.spam_returns_arg: {
+ 'x': POSORKW,
+ },
+ defs.spam_raises: {},
+ defs.spam_with_inner_not_closure: {
+ 'eggs': CO_FAST_LOCAL,
+ },
+ defs.spam_with_inner_closure: {
+ 'x': CO_FAST_CELL,
+ 'eggs': CO_FAST_LOCAL,
+ },
+ defs.spam_annotated: {
+ 'a': POSORKW,
+ 'b': POSORKW,
+ 'c': POSORKW,
+ },
defs.spam_full: {
'a': POSONLY,
'b': POSONLY,
@@ -777,6 +851,319 @@ class CodeTest(unittest.TestCase):
kinds = _testinternalcapi.get_co_localskinds(func.__code__)
self.assertEqual(kinds, expected)
+ @unittest.skipIf(_testinternalcapi is None, "missing _testinternalcapi")
+ def test_var_counts(self):
+ self.maxDiff = None
+ def new_var_counts(*,
+ posonly=0,
+ posorkw=0,
+ kwonly=0,
+ varargs=0,
+ varkwargs=0,
+ purelocals=0,
+ argcells=0,
+ othercells=0,
+ freevars=0,
+ globalvars=0,
+ attrs=0,
+ unknown=0,
+ ):
+ nargvars = posonly + posorkw + kwonly + varargs + varkwargs
+ nlocals = nargvars + purelocals + othercells
+ if isinstance(globalvars, int):
+ globalvars = {
+ 'total': globalvars,
+ 'numglobal': 0,
+ 'numbuiltin': 0,
+ 'numunknown': globalvars,
+ }
+ else:
+ g_numunknown = 0
+ if isinstance(globalvars, dict):
+ numglobal = globalvars['numglobal']
+ numbuiltin = globalvars['numbuiltin']
+ size = 2
+ if 'numunknown' in globalvars:
+ g_numunknown = globalvars['numunknown']
+ size += 1
+ assert len(globalvars) == size, globalvars
+ else:
+ assert not isinstance(globalvars, str), repr(globalvars)
+ try:
+ numglobal, numbuiltin = globalvars
+ except ValueError:
+ numglobal, numbuiltin, g_numunknown = globalvars
+ globalvars = {
+ 'total': numglobal + numbuiltin + g_numunknown,
+ 'numglobal': numglobal,
+ 'numbuiltin': numbuiltin,
+ 'numunknown': g_numunknown,
+ }
+ unbound = globalvars['total'] + attrs + unknown
+ return {
+ 'total': nlocals + freevars + unbound,
+ 'locals': {
+ 'total': nlocals,
+ 'args': {
+ 'total': nargvars,
+ 'numposonly': posonly,
+ 'numposorkw': posorkw,
+ 'numkwonly': kwonly,
+ 'varargs': varargs,
+ 'varkwargs': varkwargs,
+ },
+ 'numpure': purelocals,
+ 'cells': {
+ 'total': argcells + othercells,
+ 'numargs': argcells,
+ 'numothers': othercells,
+ },
+ 'hidden': {
+ 'total': 0,
+ 'numpure': 0,
+ 'numcells': 0,
+ },
+ },
+ 'numfree': freevars,
+ 'unbound': {
+ 'total': unbound,
+ 'globals': globalvars,
+ 'numattrs': attrs,
+ 'numunknown': unknown,
+ },
+ }
+
+ funcs = {
+ defs.simple_script: new_var_counts(),
+ defs.complex_script: new_var_counts(
+ purelocals=5,
+ globalvars=1,
+ attrs=2,
+ ),
+ defs.script_with_globals: new_var_counts(
+ purelocals=2,
+ globalvars=1,
+ ),
+ defs.script_with_explicit_empty_return: new_var_counts(),
+ defs.script_with_return: new_var_counts(),
+ defs.spam_minimal: new_var_counts(),
+ defs.spam_minimal: new_var_counts(),
+ defs.spam_with_builtins: new_var_counts(
+ purelocals=4,
+ globalvars=4,
+ ),
+ defs.spam_with_globals_and_builtins: new_var_counts(
+ purelocals=5,
+ globalvars=6,
+ ),
+ defs.spam_with_global_and_attr_same_name: new_var_counts(
+ globalvars=2,
+ attrs=1,
+ ),
+ defs.spam_full_args: new_var_counts(
+ posonly=2,
+ posorkw=2,
+ kwonly=2,
+ varargs=1,
+ varkwargs=1,
+ ),
+ defs.spam_full_args_with_defaults: new_var_counts(
+ posonly=2,
+ posorkw=2,
+ kwonly=2,
+ varargs=1,
+ varkwargs=1,
+ ),
+ defs.spam_args_attrs_and_builtins: new_var_counts(
+ posonly=2,
+ posorkw=2,
+ kwonly=2,
+ varargs=1,
+ varkwargs=1,
+ attrs=1,
+ ),
+ defs.spam_returns_arg: new_var_counts(
+ posorkw=1,
+ ),
+ defs.spam_raises: new_var_counts(
+ globalvars=1,
+ ),
+ defs.spam_with_inner_not_closure: new_var_counts(
+ purelocals=1,
+ ),
+ defs.spam_with_inner_closure: new_var_counts(
+ othercells=1,
+ purelocals=1,
+ ),
+ defs.spam_annotated: new_var_counts(
+ posorkw=3,
+ ),
+ defs.spam_full: new_var_counts(
+ posonly=2,
+ posorkw=2,
+ kwonly=2,
+ varargs=1,
+ varkwargs=1,
+ purelocals=4,
+ globalvars=3,
+ attrs=1,
+ ),
+ defs.spam: new_var_counts(
+ posorkw=1,
+ ),
+ defs.spam_N: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ ),
+ defs.spam_C: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ othercells=1,
+ ),
+ defs.spam_NN: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ ),
+ defs.spam_NC: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ othercells=1,
+ ),
+ defs.spam_CN: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ othercells=1,
+ ),
+ defs.spam_CC: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ othercells=1,
+ ),
+ defs.eggs_nested: new_var_counts(
+ posorkw=1,
+ ),
+ defs.eggs_closure: new_var_counts(
+ posorkw=1,
+ freevars=2,
+ ),
+ defs.eggs_nested_N: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ ),
+ defs.eggs_nested_C: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ freevars=2,
+ ),
+ defs.eggs_closure_N: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ freevars=2,
+ ),
+ defs.eggs_closure_C: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ othercells=1,
+ freevars=2,
+ ),
+ defs.ham_nested: new_var_counts(
+ posorkw=1,
+ ),
+ defs.ham_closure: new_var_counts(
+ posorkw=1,
+ freevars=3,
+ ),
+ defs.ham_C_nested: new_var_counts(
+ posorkw=1,
+ ),
+ defs.ham_C_closure: new_var_counts(
+ posorkw=1,
+ freevars=4,
+ ),
+ }
+ assert len(funcs) == len(defs.FUNCTIONS), (len(funcs), len(defs.FUNCTIONS))
+ for func in defs.FUNCTIONS:
+ with self.subTest(func):
+ expected = funcs[func]
+ counts = _testinternalcapi.get_code_var_counts(func.__code__)
+ self.assertEqual(counts, expected)
+
+ func = defs.spam_with_globals_and_builtins
+ with self.subTest(f'{func} code'):
+ expected = new_var_counts(
+ purelocals=5,
+ globalvars=6,
+ )
+ counts = _testinternalcapi.get_code_var_counts(func.__code__)
+ self.assertEqual(counts, expected)
+
+ with self.subTest(f'{func} with own globals and builtins'):
+ expected = new_var_counts(
+ purelocals=5,
+ globalvars=(2, 4),
+ )
+ counts = _testinternalcapi.get_code_var_counts(func)
+ self.assertEqual(counts, expected)
+
+ with self.subTest(f'{func} without globals'):
+ expected = new_var_counts(
+ purelocals=5,
+ globalvars=(0, 4, 2),
+ )
+ counts = _testinternalcapi.get_code_var_counts(func, globalsns={})
+ self.assertEqual(counts, expected)
+
+ with self.subTest(f'{func} without both'):
+ expected = new_var_counts(
+ purelocals=5,
+ globalvars=6,
+ )
+ counts = _testinternalcapi.get_code_var_counts(func, globalsns={},
+ builtinsns={})
+ self.assertEqual(counts, expected)
+
+ with self.subTest(f'{func} without builtins'):
+ expected = new_var_counts(
+ purelocals=5,
+ globalvars=(2, 0, 4),
+ )
+ counts = _testinternalcapi.get_code_var_counts(func, builtinsns={})
+ self.assertEqual(counts, expected)
+
+ @unittest.skipIf(_testinternalcapi is None, "missing _testinternalcapi")
+ def test_stateless(self):
+ self.maxDiff = None
+
+ STATELESS_FUNCTIONS = [
+ *defs.STATELESS_FUNCTIONS,
+ # stateless with defaults
+ defs.spam_full_args_with_defaults,
+ ]
+
+ for func in defs.STATELESS_CODE:
+ with self.subTest((func, '(code)')):
+ _testinternalcapi.verify_stateless_code(func.__code__)
+ for func in STATELESS_FUNCTIONS:
+ with self.subTest((func, '(func)')):
+ _testinternalcapi.verify_stateless_code(func)
+
+ for func in defs.FUNCTIONS:
+ if func not in defs.STATELESS_CODE:
+ with self.subTest((func, '(code)')):
+ with self.assertRaises(Exception):
+ _testinternalcapi.verify_stateless_code(func.__code__)
+
+ if func not in STATELESS_FUNCTIONS:
+ with self.subTest((func, '(func)')):
+ with self.assertRaises(Exception):
+ _testinternalcapi.verify_stateless_code(func)
+
def isinterned(s):
return s is sys.intern(('_' + s + '_')[1:-1])
diff --git a/Lib/test/test_code_module.py b/Lib/test/test_code_module.py
index 57fb130070b..3642b47c2c1 100644
--- a/Lib/test/test_code_module.py
+++ b/Lib/test/test_code_module.py
@@ -133,7 +133,7 @@ class TestInteractiveConsole(unittest.TestCase, MockSys):
output = ''.join(''.join(call[1]) for call in self.stderr.method_calls)
output = output[output.index('(InteractiveConsole)'):]
output = output[output.index('\n') + 1:]
- self.assertTrue(output.startswith('UnicodeEncodeError: '), output)
+ self.assertStartsWith(output, 'UnicodeEncodeError: ')
self.assertIs(self.sysmod.last_type, UnicodeEncodeError)
self.assertIs(type(self.sysmod.last_value), UnicodeEncodeError)
self.assertIsNone(self.sysmod.last_traceback)
diff --git a/Lib/test/test_codeccallbacks.py b/Lib/test/test_codeccallbacks.py
index 86e5e5c1474..65d54d1004d 100644
--- a/Lib/test/test_codeccallbacks.py
+++ b/Lib/test/test_codeccallbacks.py
@@ -1125,7 +1125,7 @@ class CodecCallbackTest(unittest.TestCase):
text = 'abc<def>ghi'*n
text.translate(charmap)
- def test_mutatingdecodehandler(self):
+ def test_mutating_decode_handler(self):
baddata = [
("ascii", b"\xff"),
("utf-7", b"++"),
@@ -1160,6 +1160,42 @@ class CodecCallbackTest(unittest.TestCase):
for (encoding, data) in baddata:
self.assertEqual(data.decode(encoding, "test.mutating"), "\u4242")
+ def test_mutating_decode_handler_unicode_escape(self):
+ decode = codecs.unicode_escape_decode
+ def mutating(exc):
+ if isinstance(exc, UnicodeDecodeError):
+ r = data.get(exc.object[:exc.end])
+ if r is not None:
+ exc.object = r[0] + exc.object[exc.end:]
+ return ('\u0404', r[1])
+ raise AssertionError("don't know how to handle %r" % exc)
+
+ codecs.register_error('test.mutating2', mutating)
+ data = {
+ br'\x0': (b'\\', 0),
+ br'\x3': (b'xxx\\', 3),
+ br'\x5': (b'x\\', 1),
+ }
+ def check(input, expected, msg):
+ with self.assertWarns(DeprecationWarning) as cm:
+ self.assertEqual(decode(input, 'test.mutating2'), (expected, len(input)))
+ self.assertIn(msg, str(cm.warning))
+
+ check(br'\x0n\z', '\u0404\n\\z', r'"\z" is an invalid escape sequence')
+ check(br'\x0n\501', '\u0404\n\u0141', r'"\501" is an invalid octal escape sequence')
+ check(br'\x0z', '\u0404\\z', r'"\z" is an invalid escape sequence')
+
+ check(br'\x3n\zr', '\u0404\n\\zr', r'"\z" is an invalid escape sequence')
+ check(br'\x3zr', '\u0404\\zr', r'"\z" is an invalid escape sequence')
+ check(br'\x3z5', '\u0404\\z5', r'"\z" is an invalid escape sequence')
+ check(memoryview(br'\x3z5x')[:-1], '\u0404\\z5', r'"\z" is an invalid escape sequence')
+ check(memoryview(br'\x3z5xy')[:-2], '\u0404\\z5', r'"\z" is an invalid escape sequence')
+
+ check(br'\x5n\z', '\u0404\n\\z', r'"\z" is an invalid escape sequence')
+ check(br'\x5n\501', '\u0404\n\u0141', r'"\501" is an invalid octal escape sequence')
+ check(br'\x5z', '\u0404\\z', r'"\z" is an invalid escape sequence')
+ check(memoryview(br'\x5zy')[:-1], '\u0404\\z', r'"\z" is an invalid escape sequence')
+
# issue32583
def test_crashing_decode_handler(self):
# better generating one more character to fill the extra space slot
diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py
index 94fcf98e757..d8666f7290e 100644
--- a/Lib/test/test_codecs.py
+++ b/Lib/test/test_codecs.py
@@ -1,8 +1,10 @@
import codecs
import contextlib
import copy
+import importlib
import io
import pickle
+import os
import sys
import unittest
import encodings
@@ -1196,23 +1198,39 @@ class EscapeDecodeTest(unittest.TestCase):
check(br"[\1010]", b"[A0]")
check(br"[\x41]", b"[A]")
check(br"[\x410]", b"[A0]")
+
+ def test_warnings(self):
+ decode = codecs.escape_decode
+ check = coding_checker(self, decode)
for i in range(97, 123):
b = bytes([i])
if b not in b'abfnrtvx':
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\%c" is an invalid escape sequence' % i):
check(b"\\" + b, b"\\" + b)
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\%c" is an invalid escape sequence' % (i-32)):
check(b"\\" + b.upper(), b"\\" + b.upper())
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\8" is an invalid escape sequence'):
check(br"\8", b"\\8")
with self.assertWarns(DeprecationWarning):
check(br"\9", b"\\9")
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\\xfa" is an invalid escape sequence') as cm:
check(b"\\\xfa", b"\\\xfa")
for i in range(0o400, 0o1000):
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\%o" is an invalid octal escape sequence' % i):
check(rb'\%o' % i, bytes([i & 0o377]))
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\z" is an invalid escape sequence'):
+ self.assertEqual(decode(br'\x\z', 'ignore'), (b'\\z', 4))
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\501" is an invalid octal escape sequence'):
+ self.assertEqual(decode(br'\x\501', 'ignore'), (b'A', 6))
+
def test_errors(self):
decode = codecs.escape_decode
self.assertRaises(ValueError, decode, br"\x")
@@ -2661,24 +2679,40 @@ class UnicodeEscapeTest(ReadTest, unittest.TestCase):
check(br"[\x410]", "[A0]")
check(br"\u20ac", "\u20ac")
check(br"\U0001d120", "\U0001d120")
+
+ def test_decode_warnings(self):
+ decode = codecs.unicode_escape_decode
+ check = coding_checker(self, decode)
for i in range(97, 123):
b = bytes([i])
if b not in b'abfnrtuvx':
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\%c" is an invalid escape sequence' % i):
check(b"\\" + b, "\\" + chr(i))
if b.upper() not in b'UN':
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\%c" is an invalid escape sequence' % (i-32)):
check(b"\\" + b.upper(), "\\" + chr(i-32))
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\8" is an invalid escape sequence'):
check(br"\8", "\\8")
with self.assertWarns(DeprecationWarning):
check(br"\9", "\\9")
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\\xfa" is an invalid escape sequence') as cm:
check(b"\\\xfa", "\\\xfa")
for i in range(0o400, 0o1000):
- with self.assertWarns(DeprecationWarning):
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\%o" is an invalid octal escape sequence' % i):
check(rb'\%o' % i, chr(i))
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\z" is an invalid escape sequence'):
+ self.assertEqual(decode(br'\x\z', 'ignore'), ('\\z', 4))
+ with self.assertWarnsRegex(DeprecationWarning,
+ r'"\\501" is an invalid octal escape sequence'):
+ self.assertEqual(decode(br'\x\501', 'ignore'), ('\u0141', 6))
+
def test_decode_errors(self):
decode = codecs.unicode_escape_decode
for c, d in (b'x', 2), (b'u', 4), (b'U', 4):
@@ -3075,6 +3109,13 @@ class TransformCodecTest(unittest.TestCase):
info = codecs.lookup(alias)
self.assertEqual(info.name, expected_name)
+ def test_alias_modules_exist(self):
+ encodings_dir = os.path.dirname(encodings.__file__)
+ for value in encodings.aliases.aliases.values():
+ codec_mod = f"encodings.{value}"
+ self.assertIsNotNone(importlib.util.find_spec(codec_mod),
+ f"Codec module not found: {codec_mod}")
+
def test_quopri_stateless(self):
# Should encode with quotetabs=True
encoded = codecs.encode(b"space tab\teol \n", "quopri-codec")
@@ -3762,7 +3803,7 @@ class LocaleCodecTest(unittest.TestCase):
with self.assertRaises(RuntimeError) as cm:
self.decode(encoded, errors)
errmsg = str(cm.exception)
- self.assertTrue(errmsg.startswith("decode error: "), errmsg)
+ self.assertStartsWith(errmsg, "decode error: ")
else:
decoded = self.decode(encoded, errors)
self.assertEqual(decoded, expected)
diff --git a/Lib/test/test_codeop.py b/Lib/test/test_codeop.py
index 0eefc22d11b..ed10bd3dcb6 100644
--- a/Lib/test/test_codeop.py
+++ b/Lib/test/test_codeop.py
@@ -322,7 +322,7 @@ class CodeopTests(unittest.TestCase):
dedent("""\
def foo(x,x):
pass
- """), "duplicate argument 'x' in function definition")
+ """), "duplicate parameter 'x' in function definition")
diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py
index 1e93530398b..d9d61e5c205 100644
--- a/Lib/test/test_collections.py
+++ b/Lib/test/test_collections.py
@@ -542,6 +542,8 @@ class TestNamedTuple(unittest.TestCase):
self.assertEqual(Dot(1)._replace(d=999), (999,))
self.assertEqual(Dot(1)._fields, ('d',))
+ @support.requires_resource('cpu')
+ def test_large_size(self):
n = support.exceeds_recursion_limit()
names = list(set(''.join([choice(string.ascii_letters)
for j in range(10)]) for i in range(n)))
diff --git a/Lib/test/test_compileall.py b/Lib/test/test_compileall.py
index a580a240d9f..8384c183dd9 100644
--- a/Lib/test/test_compileall.py
+++ b/Lib/test/test_compileall.py
@@ -316,7 +316,7 @@ class CompileallTestsBase:
self.assertTrue(mods)
for mod in mods:
- self.assertTrue(mod.startswith(self.directory), mod)
+ self.assertStartsWith(mod, self.directory)
modcode = importlib.util.cache_from_source(mod)
modpath = mod[len(self.directory+os.sep):]
_, _, err = script_helper.assert_python_failure(modcode)
diff --git a/Lib/test/test_compiler_assemble.py b/Lib/test/test_compiler_assemble.py
index c4962e35999..99a11e99d56 100644
--- a/Lib/test/test_compiler_assemble.py
+++ b/Lib/test/test_compiler_assemble.py
@@ -146,4 +146,4 @@ class IsolatedAssembleTests(AssemblerTestCase):
L1 to L2 -> L2 [0]
L2 to L3 -> L3 [1] lasti
""")
- self.assertTrue(output.getvalue().endswith(exc_table))
+ self.assertEndsWith(output.getvalue(), exc_table)
diff --git a/Lib/test/test_concurrent_futures/test_future.py b/Lib/test/test_concurrent_futures/test_future.py
index 4066ea1ee4b..06b11a3bacf 100644
--- a/Lib/test/test_concurrent_futures/test_future.py
+++ b/Lib/test/test_concurrent_futures/test_future.py
@@ -6,6 +6,7 @@ from concurrent.futures._base import (
PENDING, RUNNING, CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED, Future)
from test import support
+from test.support import threading_helper
from .util import (
PENDING_FUTURE, RUNNING_FUTURE, CANCELLED_FUTURE,
@@ -282,6 +283,62 @@ class FutureTests(BaseTestCase):
self.assertEqual(f.exception(), e)
+ def test_get_snapshot(self):
+ """Test the _get_snapshot method for atomic state retrieval."""
+ # Test with a pending future
+ f = Future()
+ done, cancelled, result, exception = f._get_snapshot()
+ self.assertFalse(done)
+ self.assertFalse(cancelled)
+ self.assertIsNone(result)
+ self.assertIsNone(exception)
+
+ # Test with a finished future (successful result)
+ f = Future()
+ f.set_result(42)
+ done, cancelled, result, exception = f._get_snapshot()
+ self.assertTrue(done)
+ self.assertFalse(cancelled)
+ self.assertEqual(result, 42)
+ self.assertIsNone(exception)
+
+ # Test with a finished future (exception)
+ f = Future()
+ exc = ValueError("test error")
+ f.set_exception(exc)
+ done, cancelled, result, exception = f._get_snapshot()
+ self.assertTrue(done)
+ self.assertFalse(cancelled)
+ self.assertIsNone(result)
+ self.assertIs(exception, exc)
+
+ # Test with a cancelled future
+ f = Future()
+ f.cancel()
+ done, cancelled, result, exception = f._get_snapshot()
+ self.assertTrue(done)
+ self.assertTrue(cancelled)
+ self.assertIsNone(result)
+ self.assertIsNone(exception)
+
+ # Test concurrent access (basic thread safety check)
+ f = Future()
+ f.set_result(100)
+ results = []
+
+ def get_snapshot():
+ for _ in range(1000):
+ snapshot = f._get_snapshot()
+ results.append(snapshot)
+
+ threads = [threading.Thread(target=get_snapshot) for _ in range(4)]
+ with threading_helper.start_threads(threads):
+ pass
+ # All snapshots should be identical for a finished future
+ expected = (True, False, 100, None)
+ for result in results:
+ self.assertEqual(result, expected)
+
def setUpModule():
setup_module()
diff --git a/Lib/test/test_concurrent_futures/test_init.py b/Lib/test/test_concurrent_futures/test_init.py
index df640929309..6b8484c0d5f 100644
--- a/Lib/test/test_concurrent_futures/test_init.py
+++ b/Lib/test/test_concurrent_futures/test_init.py
@@ -20,6 +20,10 @@ INITIALIZER_STATUS = 'uninitialized'
def init(x):
global INITIALIZER_STATUS
INITIALIZER_STATUS = x
+ # InterpreterPoolInitializerTest.test_initializer fails
+ # if we don't have a LOAD_GLOBAL. (It could be any global.)
+ # We will address this separately.
+ INITIALIZER_STATUS
def get_init_status():
return INITIALIZER_STATUS
diff --git a/Lib/test/test_concurrent_futures/test_interpreter_pool.py b/Lib/test/test_concurrent_futures/test_interpreter_pool.py
index f6c62ae4b20..844dfdd6fc9 100644
--- a/Lib/test/test_concurrent_futures/test_interpreter_pool.py
+++ b/Lib/test/test_concurrent_futures/test_interpreter_pool.py
@@ -2,35 +2,78 @@ import asyncio
import contextlib
import io
import os
-import pickle
+import sys
import time
import unittest
-from concurrent.futures.interpreter import (
- ExecutionFailed, BrokenInterpreterPool,
-)
+from concurrent.futures.interpreter import BrokenInterpreterPool
+from concurrent import interpreters
+from concurrent.interpreters import _queues as queues
import _interpreters
from test import support
+from test.support import os_helper
+from test.support import script_helper
import test.test_asyncio.utils as testasyncio_utils
-from test.support.interpreters import queues
from .executor import ExecutorTest, mul
from .util import BaseTestCase, InterpreterPoolMixin, setup_module
+WINDOWS = sys.platform.startswith('win')
+
+
+@contextlib.contextmanager
+def nonblocking(fd):
+ blocking = os.get_blocking(fd)
+ if blocking:
+ os.set_blocking(fd, False)
+ try:
+ yield
+ finally:
+ if blocking:
+ os.set_blocking(fd, blocking)
+
+
+def read_file_with_timeout(fd, nbytes, timeout):
+ with nonblocking(fd):
+ end = time.time() + timeout
+ try:
+ return os.read(fd, nbytes)
+ except BlockingIOError:
+ pass
+ while time.time() < end:
+ try:
+ return os.read(fd, nbytes)
+ except BlockingIOError:
+ continue
+ else:
+ raise TimeoutError('nothing to read')
+
+
+if not WINDOWS:
+ import select
+ def read_file_with_timeout(fd, nbytes, timeout):
+ r, _, _ = select.select([fd], [], [], timeout)
+ if fd not in r:
+ raise TimeoutError('nothing to read')
+ return os.read(fd, nbytes)
+
+
def noop():
pass
def write_msg(fd, msg):
+ import os
os.write(fd, msg + b'\0')
-def read_msg(fd):
+def read_msg(fd, timeout=10.0):
msg = b''
- while ch := os.read(fd, 1):
- if ch == b'\0':
- return msg
+ ch = read_file_with_timeout(fd, 1, timeout)
+ while ch != b'\0':
msg += ch
+ ch = os.read(fd, 1)
+ return msg
def get_current_name():
@@ -113,6 +156,38 @@ class InterpreterPoolExecutorTest(
self.assertEqual(before, b'\0')
self.assertEqual(after, msg)
+ def test_init_with___main___global(self):
+ # See https://github.com/python/cpython/pull/133957#issuecomment-2927415311.
+ text = """if True:
+ from concurrent.futures import InterpreterPoolExecutor
+
+ INITIALIZER_STATUS = 'uninitialized'
+
+ def init(x):
+ global INITIALIZER_STATUS
+ INITIALIZER_STATUS = x
+ INITIALIZER_STATUS
+
+ def get_init_status():
+ return INITIALIZER_STATUS
+
+ if __name__ == "__main__":
+ exe = InterpreterPoolExecutor(initializer=init,
+ initargs=('initialized',))
+ fut = exe.submit(get_init_status)
+ print(fut.result()) # 'initialized'
+ exe.shutdown(wait=True)
+ print(INITIALIZER_STATUS) # 'uninitialized'
+ """
+ with os_helper.temp_dir() as tempdir:
+ filename = script_helper.make_script(tempdir, 'my-script', text)
+ res = script_helper.assert_python_ok(filename)
+ stdout = res.out.decode('utf-8').strip()
+ self.assertEqual(stdout.splitlines(), [
+ 'initialized',
+ 'uninitialized',
+ ])
+
def test_init_closure(self):
count = 0
def init1():
@@ -121,10 +196,19 @@ class InterpreterPoolExecutorTest(
nonlocal count
count += 1
- with self.assertRaises(pickle.PicklingError):
- self.executor_type(initializer=init1)
- with self.assertRaises(pickle.PicklingError):
- self.executor_type(initializer=init2)
+ with contextlib.redirect_stderr(io.StringIO()) as stderr:
+ with self.executor_type(initializer=init1) as executor:
+ fut = executor.submit(lambda: None)
+ self.assertIn('NotShareableError', stderr.getvalue())
+ with self.assertRaises(BrokenInterpreterPool):
+ fut.result()
+
+ with contextlib.redirect_stderr(io.StringIO()) as stderr:
+ with self.executor_type(initializer=init2) as executor:
+ fut = executor.submit(lambda: None)
+ self.assertIn('NotShareableError', stderr.getvalue())
+ with self.assertRaises(BrokenInterpreterPool):
+ fut.result()
def test_init_instance_method(self):
class Spam:
@@ -132,26 +216,12 @@ class InterpreterPoolExecutorTest(
raise NotImplementedError
spam = Spam()
- with self.assertRaises(pickle.PicklingError):
- self.executor_type(initializer=spam.initializer)
-
- def test_init_shared(self):
- msg = b'eggs'
- r, w = self.pipe()
- script = f"""if True:
- import os
- if __name__ != '__main__':
- import __main__
- spam = __main__.spam
- os.write({w}, spam + b'\\0')
- """
-
- executor = self.executor_type(shared={'spam': msg})
- fut = executor.submit(exec, script)
- fut.result()
- after = read_msg(r)
-
- self.assertEqual(after, msg)
+ with contextlib.redirect_stderr(io.StringIO()) as stderr:
+ with self.executor_type(initializer=spam.initializer) as executor:
+ fut = executor.submit(lambda: None)
+ self.assertIn('NotShareableError', stderr.getvalue())
+ with self.assertRaises(BrokenInterpreterPool):
+ fut.result()
@unittest.expectedFailure
def test_init_exception_in_script(self):
@@ -178,8 +248,6 @@ class InterpreterPoolExecutorTest(
stderr = stderr.getvalue()
self.assertIn('ExecutionFailed: Exception: spam', stderr)
self.assertIn('Uncaught in the interpreter:', stderr)
- self.assertIn('The above exception was the direct cause of the following exception:',
- stderr)
@unittest.expectedFailure
def test_submit_script(self):
@@ -208,10 +276,14 @@ class InterpreterPoolExecutorTest(
return spam
executor = self.executor_type()
- with self.assertRaises(pickle.PicklingError):
- executor.submit(task1)
- with self.assertRaises(pickle.PicklingError):
- executor.submit(task2)
+
+ fut = executor.submit(task1)
+ with self.assertRaises(_interpreters.NotShareableError):
+ fut.result()
+
+ fut = executor.submit(task2)
+ with self.assertRaises(_interpreters.NotShareableError):
+ fut.result()
def test_submit_local_instance(self):
class Spam:
@@ -219,8 +291,9 @@ class InterpreterPoolExecutorTest(
self.value = True
executor = self.executor_type()
- with self.assertRaises(pickle.PicklingError):
- executor.submit(Spam)
+ fut = executor.submit(Spam)
+ with self.assertRaises(_interpreters.NotShareableError):
+ fut.result()
def test_submit_instance_method(self):
class Spam:
@@ -229,8 +302,9 @@ class InterpreterPoolExecutorTest(
spam = Spam()
executor = self.executor_type()
- with self.assertRaises(pickle.PicklingError):
- executor.submit(spam.run)
+ fut = executor.submit(spam.run)
+ with self.assertRaises(_interpreters.NotShareableError):
+ fut.result()
def test_submit_func_globals(self):
executor = self.executor_type()
@@ -242,13 +316,14 @@ class InterpreterPoolExecutorTest(
@unittest.expectedFailure
def test_submit_exception_in_script(self):
+ # Scripts are not supported currently.
fut = self.executor.submit('raise Exception("spam")')
with self.assertRaises(Exception) as captured:
fut.result()
self.assertIs(type(captured.exception), Exception)
self.assertEqual(str(captured.exception), 'spam')
cause = captured.exception.__cause__
- self.assertIs(type(cause), ExecutionFailed)
+ self.assertIs(type(cause), interpreters.ExecutionFailed)
for attr in ('__name__', '__qualname__', '__module__'):
self.assertEqual(getattr(cause.excinfo.type, attr),
getattr(Exception, attr))
@@ -261,7 +336,7 @@ class InterpreterPoolExecutorTest(
self.assertIs(type(captured.exception), Exception)
self.assertEqual(str(captured.exception), 'spam')
cause = captured.exception.__cause__
- self.assertIs(type(cause), ExecutionFailed)
+ self.assertIs(type(cause), interpreters.ExecutionFailed)
for attr in ('__name__', '__qualname__', '__module__'):
self.assertEqual(getattr(cause.excinfo.type, attr),
getattr(Exception, attr))
@@ -269,16 +344,93 @@ class InterpreterPoolExecutorTest(
def test_saturation(self):
blocker = queues.create()
- executor = self.executor_type(4, shared=dict(blocker=blocker))
+ executor = self.executor_type(4)
for i in range(15 * executor._max_workers):
- executor.submit(exec, 'import __main__; __main__.blocker.get()')
- #executor.submit('blocker.get()')
+ executor.submit(blocker.get)
self.assertEqual(len(executor._threads), executor._max_workers)
for i in range(15 * executor._max_workers):
blocker.put_nowait(None)
executor.shutdown(wait=True)
+ def test_blocking(self):
+ # There is no guarantee that a worker will be created for every
+ # submitted task. That's because there's a race between:
+ #
+ # * a new worker thread, created when task A was just submitted,
+ # becoming non-idle when it picks up task A
+ # * after task B is added to the queue, a new worker thread
+ # is started only if there are no idle workers
+ # (the check in ThreadPoolExecutor._adjust_thread_count())
+ #
+ # That means we must not block waiting for *all* tasks to report
+ # "ready" before we unblock the known-ready workers.
+ ready = queues.create()
+ blocker = queues.create()
+
+ def run(taskid, ready, blocker):
+ # There can't be any globals here.
+ ready.put_nowait(taskid)
+ blocker.get() # blocking
+
+ numtasks = 10
+ futures = []
+ with self.executor_type() as executor:
+ # Request the jobs.
+ for i in range(numtasks):
+ fut = executor.submit(run, i, ready, blocker)
+ futures.append(fut)
+ pending = numtasks
+ while pending > 0:
+ # Wait for any to be ready.
+ done = 0
+ for _ in range(pending):
+ try:
+ ready.get(timeout=1) # blocking
+ except interpreters.QueueEmpty:
+ pass
+ else:
+ done += 1
+ pending -= done
+ # Unblock the workers.
+ for _ in range(done):
+ blocker.put_nowait(None)
+
+ def test_blocking_with_limited_workers(self):
+ # This is essentially the same as test_blocking,
+ # but we explicitly force a limited number of workers,
+ # instead of it happening implicitly sometimes due to a race.
+ ready = queues.create()
+ blocker = queues.create()
+
+ def run(taskid, ready, blocker):
+ # There can't be any globals here.
+ ready.put_nowait(taskid)
+ blocker.get() # blocking
+
+ numtasks = 10
+ futures = []
+ with self.executor_type(4) as executor:
+ # Request the jobs.
+ for i in range(numtasks):
+ fut = executor.submit(run, i, ready, blocker)
+ futures.append(fut)
+ pending = numtasks
+ while pending > 0:
+ # Wait for any to be ready.
+ done = 0
+ for _ in range(pending):
+ try:
+ ready.get(timeout=1) # blocking
+ except interpreters.QueueEmpty:
+ pass
+ else:
+ done += 1
+ pending -= done
+ # Unblock the workers.
+ for _ in range(done):
+ blocker.put_nowait(None)
+
@support.requires_gil_enabled("gh-117344: test is flaky without the GIL")
def test_idle_thread_reuse(self):
executor = self.executor_type()
@@ -289,12 +441,21 @@ class InterpreterPoolExecutorTest(
executor.shutdown(wait=True)
def test_pickle_errors_propagate(self):
- # GH-125864: Pickle errors happen before the script tries to execute, so the
- # queue used to wait infinitely.
-
+ # GH-125864: Pickle errors happen before the script tries to execute,
+ # so the queue used to wait infinitely.
fut = self.executor.submit(PickleShenanigans(0))
- with self.assertRaisesRegex(RuntimeError, "gotcha"):
+ expected = interpreters.NotShareableError
+ with self.assertRaisesRegex(expected, 'args not shareable') as cm:
fut.result()
+ self.assertRegex(str(cm.exception.__cause__), 'unpickled')
+
+ def test_no_stale_references(self):
+ # Weak references don't cross between interpreters.
+ raise unittest.SkipTest('not applicable')
+
+ def test_free_reference(self):
+ # Weak references don't cross between interpreters.
+ raise unittest.SkipTest('not applicable')
class AsyncioTest(InterpretersMixin, testasyncio_utils.TestCase):
diff --git a/Lib/test/test_concurrent_futures/test_shutdown.py b/Lib/test/test_concurrent_futures/test_shutdown.py
index 7a4065afd46..99b315b47e2 100644
--- a/Lib/test/test_concurrent_futures/test_shutdown.py
+++ b/Lib/test/test_concurrent_futures/test_shutdown.py
@@ -330,6 +330,64 @@ class ProcessPoolShutdownTest(ExecutorShutdownTest):
# shutdown.
assert all([r == abs(v) for r, v in zip(res, range(-5, 5))])
+ @classmethod
+ def _failing_task_gh_132969(cls, n):
+ raise ValueError("failing task")
+
+ @classmethod
+ def _good_task_gh_132969(cls, n):
+ time.sleep(0.1 * n)
+ return n
+
+ def _run_test_issue_gh_132969(self, max_workers):
+ # max_workers=2 will repro exception
+ # max_workers=4 will repro exception and then hang
+
+ # Repro conditions
+ # max_tasks_per_child=1
+ # a task ends abnormally
+ # shutdown(wait=False) is called
+ start_method = self.get_context().get_start_method()
+ if (start_method == "fork" or
+ (start_method == "forkserver" and sys.platform.startswith("win"))):
+ self.skipTest(f"Skipping test for {start_method = }")
+ executor = futures.ProcessPoolExecutor(
+ max_workers=max_workers,
+ max_tasks_per_child=1,
+ mp_context=self.get_context())
+ f1 = executor.submit(ProcessPoolShutdownTest._good_task_gh_132969, 1)
+ f2 = executor.submit(ProcessPoolShutdownTest._failing_task_gh_132969, 2)
+ f3 = executor.submit(ProcessPoolShutdownTest._good_task_gh_132969, 3)
+ result = 0
+ try:
+ result += f1.result()
+ result += f2.result()
+ result += f3.result()
+ except ValueError:
+ # stop processing results upon first exception
+ pass
+
+ # Ensure that the executor cleans up after called
+ # shutdown with wait=False
+ executor_manager_thread = executor._executor_manager_thread
+ executor.shutdown(wait=False)
+ time.sleep(0.2)
+ executor_manager_thread.join()
+ return result
+
+ def test_shutdown_gh_132969_case_1(self):
+ # gh-132969: test that exception "object of type 'NoneType' has no len()"
+ # is not raised when shutdown(wait=False) is called.
+ result = self._run_test_issue_gh_132969(2)
+ self.assertEqual(result, 1)
+
+ def test_shutdown_gh_132969_case_2(self):
+ # gh-132969: test that process does not hang and
+ # exception "object of type 'NoneType' has no len()" is not raised
+ # when shutdown(wait=False) is called.
+ result = self._run_test_issue_gh_132969(4)
+ self.assertEqual(result, 1)
+
create_executor_tests(globals(), ProcessPoolShutdownTest,
executor_mixins=(ProcessPoolForkMixin,
diff --git a/Lib/test/test_configparser.py b/Lib/test/test_configparser.py
index 23904d17d32..e7364e18742 100644
--- a/Lib/test/test_configparser.py
+++ b/Lib/test/test_configparser.py
@@ -986,12 +986,12 @@ class ConfigParserTestCase(BasicTestCase, unittest.TestCase):
def test_defaults_keyword(self):
"""bpo-23835 fix for ConfigParser"""
- cf = self.newconfig(defaults={1: 2.4})
- self.assertEqual(cf[self.default_section]['1'], '2.4')
- self.assertAlmostEqual(cf[self.default_section].getfloat('1'), 2.4)
- cf = self.newconfig(defaults={"A": 5.2})
- self.assertEqual(cf[self.default_section]['a'], '5.2')
- self.assertAlmostEqual(cf[self.default_section].getfloat('a'), 5.2)
+ cf = self.newconfig(defaults={1: 2.5})
+ self.assertEqual(cf[self.default_section]['1'], '2.5')
+ self.assertAlmostEqual(cf[self.default_section].getfloat('1'), 2.5)
+ cf = self.newconfig(defaults={"A": 5.25})
+ self.assertEqual(cf[self.default_section]['a'], '5.25')
+ self.assertAlmostEqual(cf[self.default_section].getfloat('a'), 5.25)
class ConfigParserTestCaseNoInterpolation(BasicTestCase, unittest.TestCase):
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py
index cf651959803..6a3329fa5aa 100644
--- a/Lib/test/test_contextlib.py
+++ b/Lib/test/test_contextlib.py
@@ -48,23 +48,23 @@ class TestAbstractContextManager(unittest.TestCase):
def __exit__(self, exc_type, exc_value, traceback):
return None
- self.assertTrue(issubclass(ManagerFromScratch, AbstractContextManager))
+ self.assertIsSubclass(ManagerFromScratch, AbstractContextManager)
class DefaultEnter(AbstractContextManager):
def __exit__(self, *args):
super().__exit__(*args)
- self.assertTrue(issubclass(DefaultEnter, AbstractContextManager))
+ self.assertIsSubclass(DefaultEnter, AbstractContextManager)
class NoEnter(ManagerFromScratch):
__enter__ = None
- self.assertFalse(issubclass(NoEnter, AbstractContextManager))
+ self.assertNotIsSubclass(NoEnter, AbstractContextManager)
class NoExit(ManagerFromScratch):
__exit__ = None
- self.assertFalse(issubclass(NoExit, AbstractContextManager))
+ self.assertNotIsSubclass(NoExit, AbstractContextManager)
class ContextManagerTestCase(unittest.TestCase):
diff --git a/Lib/test/test_contextlib_async.py b/Lib/test/test_contextlib_async.py
index 7750186e56a..dcd00720379 100644
--- a/Lib/test/test_contextlib_async.py
+++ b/Lib/test/test_contextlib_async.py
@@ -77,23 +77,23 @@ class TestAbstractAsyncContextManager(unittest.TestCase):
async def __aexit__(self, exc_type, exc_value, traceback):
return None
- self.assertTrue(issubclass(ManagerFromScratch, AbstractAsyncContextManager))
+ self.assertIsSubclass(ManagerFromScratch, AbstractAsyncContextManager)
class DefaultEnter(AbstractAsyncContextManager):
async def __aexit__(self, *args):
await super().__aexit__(*args)
- self.assertTrue(issubclass(DefaultEnter, AbstractAsyncContextManager))
+ self.assertIsSubclass(DefaultEnter, AbstractAsyncContextManager)
class NoneAenter(ManagerFromScratch):
__aenter__ = None
- self.assertFalse(issubclass(NoneAenter, AbstractAsyncContextManager))
+ self.assertNotIsSubclass(NoneAenter, AbstractAsyncContextManager)
class NoneAexit(ManagerFromScratch):
__aexit__ = None
- self.assertFalse(issubclass(NoneAexit, AbstractAsyncContextManager))
+ self.assertNotIsSubclass(NoneAexit, AbstractAsyncContextManager)
class AsyncContextManagerTestCase(unittest.TestCase):
diff --git a/Lib/test/test_copy.py b/Lib/test/test_copy.py
index d76341417e9..467ec09d99e 100644
--- a/Lib/test/test_copy.py
+++ b/Lib/test/test_copy.py
@@ -19,7 +19,7 @@ class TestCopy(unittest.TestCase):
def test_exceptions(self):
self.assertIs(copy.Error, copy.error)
- self.assertTrue(issubclass(copy.Error, Exception))
+ self.assertIsSubclass(copy.Error, Exception)
# The copy() method
@@ -372,6 +372,7 @@ class TestCopy(unittest.TestCase):
self.assertIsNot(x[0], y[0])
@support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_deepcopy_reflexive_list(self):
x = []
x.append(x)
@@ -400,6 +401,7 @@ class TestCopy(unittest.TestCase):
self.assertIs(x, y)
@support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_deepcopy_reflexive_tuple(self):
x = ([],)
x[0].append(x)
@@ -418,6 +420,7 @@ class TestCopy(unittest.TestCase):
self.assertIsNot(x["foo"], y["foo"])
@support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_deepcopy_reflexive_dict(self):
x = {}
x['foo'] = x
diff --git a/Lib/test/test_coroutines.py b/Lib/test/test_coroutines.py
index 761cb230277..4755046fe19 100644
--- a/Lib/test/test_coroutines.py
+++ b/Lib/test/test_coroutines.py
@@ -527,7 +527,7 @@ class CoroutineTest(unittest.TestCase):
def test_gen_1(self):
def gen(): yield
- self.assertFalse(hasattr(gen, '__await__'))
+ self.assertNotHasAttr(gen, '__await__')
def test_func_1(self):
async def foo():
diff --git a/Lib/test/test_cprofile.py b/Lib/test/test_cprofile.py
index 192c8eab26e..57e818b1c68 100644
--- a/Lib/test/test_cprofile.py
+++ b/Lib/test/test_cprofile.py
@@ -125,21 +125,22 @@ class CProfileTest(ProfileTest):
"""
gh-106152
generator.throw() should trigger a call in cProfile
- In the any() call below, there should be two entries for the generator:
- * one for the call to __next__ which gets a True and terminates any
- * one when the generator is garbage collected which will effectively
- do a throw.
"""
+
+ def gen():
+ yield
+
pr = self.profilerclass()
pr.enable()
- any(a == 1 for a in (1, 2))
+ g = gen()
+ try:
+ g.throw(SyntaxError)
+ except SyntaxError:
+ pass
pr.disable()
pr.create_stats()
- for func, (cc, nc, _, _, _) in pr.stats.items():
- if func[2] == "<genexpr>":
- self.assertEqual(cc, 1)
- self.assertEqual(nc, 1)
+ self.assertTrue(any("throw" in func[2] for func in pr.stats.keys()))
def test_bad_descriptor(self):
# gh-132250
diff --git a/Lib/test/test_crossinterp.py b/Lib/test/test_crossinterp.py
index 32d6fd4e94b..2fa0077a09b 100644
--- a/Lib/test/test_crossinterp.py
+++ b/Lib/test/test_crossinterp.py
@@ -1,10 +1,9 @@
import contextlib
-import importlib
-import importlib.util
import itertools
import sys
import types
import unittest
+import warnings
from test.support import import_helper
@@ -16,13 +15,281 @@ from test import _code_definitions as code_defs
from test import _crossinterp_definitions as defs
-BUILTIN_TYPES = [o for _, o in __builtins__.items()
- if isinstance(o, type)]
-EXCEPTION_TYPES = [cls for cls in BUILTIN_TYPES
+@contextlib.contextmanager
+def ignore_byteswarning():
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', category=BytesWarning)
+ yield
+
+
+# builtin types
+
+BUILTINS_TYPES = [o for _, o in __builtins__.items() if isinstance(o, type)]
+EXCEPTION_TYPES = [cls for cls in BUILTINS_TYPES
if issubclass(cls, BaseException)]
OTHER_TYPES = [o for n, o in vars(types).items()
if (isinstance(o, type) and
- n not in ('DynamicClassAttribute', '_GeneratorWrapper'))]
+ n not in ('DynamicClassAttribute', '_GeneratorWrapper'))]
+BUILTIN_TYPES = [
+ *BUILTINS_TYPES,
+ *OTHER_TYPES,
+]
+
+# builtin exceptions
+
+try:
+ raise Exception
+except Exception as exc:
+ CAUGHT = exc
+EXCEPTIONS_WITH_SPECIAL_SIG = {
+ BaseExceptionGroup: (lambda msg: (msg, [CAUGHT])),
+ ExceptionGroup: (lambda msg: (msg, [CAUGHT])),
+ UnicodeError: (lambda msg: (None, msg, None, None, None)),
+ UnicodeEncodeError: (lambda msg: ('utf-8', '', 1, 3, msg)),
+ UnicodeDecodeError: (lambda msg: ('utf-8', b'', 1, 3, msg)),
+ UnicodeTranslateError: (lambda msg: ('', 1, 3, msg)),
+}
+BUILTIN_EXCEPTIONS = [
+ *(cls(*sig('error!')) for cls, sig in EXCEPTIONS_WITH_SPECIAL_SIG.items()),
+ *(cls('error!') for cls in EXCEPTION_TYPES
+ if cls not in EXCEPTIONS_WITH_SPECIAL_SIG),
+]
+
+# other builtin objects
+
+METHOD = defs.SpamOkay().okay
+BUILTIN_METHOD = [].append
+METHOD_DESCRIPTOR_WRAPPER = str.join
+METHOD_WRAPPER = object().__str__
+WRAPPER_DESCRIPTOR = object.__init__
+BUILTIN_WRAPPERS = {
+ METHOD: types.MethodType,
+ BUILTIN_METHOD: types.BuiltinMethodType,
+ dict.__dict__['fromkeys']: types.ClassMethodDescriptorType,
+ types.FunctionType.__code__: types.GetSetDescriptorType,
+ types.FunctionType.__globals__: types.MemberDescriptorType,
+ METHOD_DESCRIPTOR_WRAPPER: types.MethodDescriptorType,
+ METHOD_WRAPPER: types.MethodWrapperType,
+ WRAPPER_DESCRIPTOR: types.WrapperDescriptorType,
+ staticmethod(defs.SpamOkay.okay): None,
+ classmethod(defs.SpamOkay.okay): None,
+ property(defs.SpamOkay.okay): None,
+}
+BUILTIN_FUNCTIONS = [
+ # types.BuiltinFunctionType
+ len,
+ sys.is_finalizing,
+ sys.exit,
+ _testinternalcapi.get_crossinterp_data,
+]
+assert 'emptymod' not in sys.modules
+with import_helper.ready_to_import('emptymod', ''):
+ import emptymod as EMPTYMOD
+MODULES = [
+ sys,
+ defs,
+ unittest,
+ EMPTYMOD,
+]
+OBJECT = object()
+EXCEPTION = Exception()
+LAMBDA = (lambda: None)
+BUILTIN_SIMPLE = [
+ OBJECT,
+ # singletons
+ None,
+ True,
+ False,
+ Ellipsis,
+ NotImplemented,
+ # bytes
+ *(i.to_bytes(2, 'little', signed=True)
+ for i in range(-1, 258)),
+ # str
+ 'hello world',
+ '你好世界',
+ '',
+ # int
+ sys.maxsize + 1,
+ sys.maxsize,
+ -sys.maxsize - 1,
+ -sys.maxsize - 2,
+ *range(-1, 258),
+ 2**1000,
+ # float
+ 0.0,
+ 1.1,
+ -1.0,
+ 0.12345678,
+ -0.12345678,
+]
+TUPLE_EXCEPTION = (0, 1.0, EXCEPTION)
+TUPLE_OBJECT = (0, 1.0, OBJECT)
+TUPLE_NESTED_EXCEPTION = (0, 1.0, (EXCEPTION,))
+TUPLE_NESTED_OBJECT = (0, 1.0, (OBJECT,))
+MEMORYVIEW_EMPTY = memoryview(b'')
+MEMORYVIEW_NOT_EMPTY = memoryview(b'spam'*42)
+MAPPING_PROXY_EMPTY = types.MappingProxyType({})
+BUILTIN_CONTAINERS = [
+ # tuple (flat)
+ (),
+ (1,),
+ ("hello", "world", ),
+ (1, True, "hello"),
+ TUPLE_EXCEPTION,
+ TUPLE_OBJECT,
+ # tuple (nested)
+ ((1,),),
+ ((1, 2), (3, 4)),
+ ((1, 2), (3, 4), (5, 6)),
+ TUPLE_NESTED_EXCEPTION,
+ TUPLE_NESTED_OBJECT,
+ # buffer
+ MEMORYVIEW_EMPTY,
+ MEMORYVIEW_NOT_EMPTY,
+ # list
+ [],
+ [1, 2, 3],
+ [[1], (2,), {3: 4}],
+ # dict
+ {},
+ {1: 7, 2: 8, 3: 9},
+ {1: [1], 2: (2,), 3: {3: 4}},
+ # set
+ set(),
+ {1, 2, 3},
+ {frozenset({1}), (2,)},
+ # frozenset
+ frozenset([]),
+ frozenset({frozenset({1}), (2,)}),
+ # bytearray
+ bytearray(b''),
+ # other
+ MAPPING_PROXY_EMPTY,
+ types.SimpleNamespace(),
+]
+ns = {}
+exec("""
+try:
+ raise Exception
+except Exception as exc:
+ TRACEBACK = exc.__traceback__
+ FRAME = TRACEBACK.tb_frame
+""", ns, ns)
+BUILTIN_OTHER = [
+ # types.CellType
+ types.CellType(),
+ # types.FrameType
+ ns['FRAME'],
+ # types.TracebackType
+ ns['TRACEBACK'],
+]
+del ns
+
+# user-defined objects
+
+USER_TOP_INSTANCES = [c(*a) for c, a in defs.TOP_CLASSES.items()]
+USER_NESTED_INSTANCES = [c(*a) for c, a in defs.NESTED_CLASSES.items()]
+USER_INSTANCES = [
+ *USER_TOP_INSTANCES,
+ *USER_NESTED_INSTANCES,
+]
+USER_EXCEPTIONS = [
+ defs.MimimalError('error!'),
+]
+
+# shareable objects
+
+TUPLES_WITHOUT_EQUALITY = [
+ TUPLE_EXCEPTION,
+ TUPLE_OBJECT,
+ TUPLE_NESTED_EXCEPTION,
+ TUPLE_NESTED_OBJECT,
+]
+_UNSHAREABLE_SIMPLE = [
+ Ellipsis,
+ NotImplemented,
+ OBJECT,
+ sys.maxsize + 1,
+ -sys.maxsize - 2,
+ 2**1000,
+]
+with ignore_byteswarning():
+ _SHAREABLE_SIMPLE = [o for o in BUILTIN_SIMPLE
+ if o not in _UNSHAREABLE_SIMPLE]
+ _SHAREABLE_CONTAINERS = [
+ *(o for o in BUILTIN_CONTAINERS if type(o) is memoryview),
+ *(o for o in BUILTIN_CONTAINERS
+ if type(o) is tuple and o not in TUPLES_WITHOUT_EQUALITY),
+ ]
+ _UNSHAREABLE_CONTAINERS = [o for o in BUILTIN_CONTAINERS
+ if o not in _SHAREABLE_CONTAINERS]
+SHAREABLE = [
+ *_SHAREABLE_SIMPLE,
+ *_SHAREABLE_CONTAINERS,
+]
+NOT_SHAREABLE = [
+ *_UNSHAREABLE_SIMPLE,
+ *_UNSHAREABLE_CONTAINERS,
+ *BUILTIN_TYPES,
+ *BUILTIN_WRAPPERS,
+ *BUILTIN_EXCEPTIONS,
+ *BUILTIN_FUNCTIONS,
+ *MODULES,
+ *BUILTIN_OTHER,
+ # types.CodeType
+ *(f.__code__ for f in defs.FUNCTIONS),
+ *(f.__code__ for f in defs.FUNCTION_LIKE),
+ # types.FunctionType
+ *defs.FUNCTIONS,
+ defs.SpamOkay.okay,
+ LAMBDA,
+ *defs.FUNCTION_LIKE,
+ # coroutines and generators
+ *defs.FUNCTION_LIKE_APPLIED,
+ # user classes
+ *defs.CLASSES,
+ *USER_INSTANCES,
+ # user exceptions
+ *USER_EXCEPTIONS,
+]
+
+# pickleable objects
+
+PICKLEABLE = [
+ *BUILTIN_SIMPLE,
+ *(o for o in BUILTIN_CONTAINERS if o not in [
+ MEMORYVIEW_EMPTY,
+ MEMORYVIEW_NOT_EMPTY,
+ MAPPING_PROXY_EMPTY,
+ ] or type(o) is dict),
+ *BUILTINS_TYPES,
+ *BUILTIN_EXCEPTIONS,
+ *BUILTIN_FUNCTIONS,
+ *defs.TOP_FUNCTIONS,
+ defs.SpamOkay.okay,
+ *defs.FUNCTION_LIKE,
+ *defs.TOP_CLASSES,
+ *USER_TOP_INSTANCES,
+ *USER_EXCEPTIONS,
+ # from OTHER_TYPES
+ types.NoneType,
+ types.EllipsisType,
+ types.NotImplementedType,
+ types.GenericAlias,
+ types.UnionType,
+ types.SimpleNamespace,
+ # from BUILTIN_WRAPPERS
+ METHOD,
+ BUILTIN_METHOD,
+ METHOD_DESCRIPTOR_WRAPPER,
+ METHOD_WRAPPER,
+ WRAPPER_DESCRIPTOR,
+]
+assert not any(isinstance(o, types.MappingProxyType) for o in PICKLEABLE)
+
+
+# helpers
DEFS = defs
with open(code_defs.__file__) as infile:
@@ -111,6 +378,77 @@ class _GetXIDataTests(unittest.TestCase):
MODE = None
+ def assert_functions_equal(self, func1, func2):
+ assert type(func1) is types.FunctionType, repr(func1)
+ assert type(func2) is types.FunctionType, repr(func2)
+ self.assertEqual(func1.__name__, func2.__name__)
+ self.assertEqual(func1.__code__, func2.__code__)
+ self.assertEqual(func1.__defaults__, func2.__defaults__)
+ self.assertEqual(func1.__kwdefaults__, func2.__kwdefaults__)
+ # We don't worry about __globals__ for now.
+
+ def assert_exc_args_equal(self, exc1, exc2):
+ args1 = exc1.args
+ args2 = exc2.args
+ if isinstance(exc1, ExceptionGroup):
+ self.assertIs(type(args1), type(args2))
+ self.assertEqual(len(args1), 2)
+ self.assertEqual(len(args1), len(args2))
+ self.assertEqual(args1[0], args2[0])
+ group1 = args1[1]
+ group2 = args2[1]
+ self.assertEqual(len(group1), len(group2))
+ for grouped1, grouped2 in zip(group1, group2):
+ # Currently the "extra" attrs are not preserved
+ # (via __reduce__).
+ self.assertIs(type(grouped1), type(grouped2))
+ self.assert_exc_equal(grouped1, grouped2)
+ else:
+ self.assertEqual(args1, args2)
+
+ def assert_exc_equal(self, exc1, exc2):
+ self.assertIs(type(exc1), type(exc2))
+
+ if type(exc1).__eq__ is not object.__eq__:
+ self.assertEqual(exc1, exc2)
+
+ self.assert_exc_args_equal(exc1, exc2)
+ # XXX For now we do not preserve tracebacks.
+ if exc1.__traceback__ is not None:
+ self.assertEqual(exc1.__traceback__, exc2.__traceback__)
+ self.assertEqual(
+ getattr(exc1, '__notes__', None),
+ getattr(exc2, '__notes__', None),
+ )
+ # We assume there are no cycles.
+ if exc1.__cause__ is None:
+ self.assertIs(exc1.__cause__, exc2.__cause__)
+ else:
+ self.assert_exc_equal(exc1.__cause__, exc2.__cause__)
+ if exc1.__context__ is None:
+ self.assertIs(exc1.__context__, exc2.__context__)
+ else:
+ self.assert_exc_equal(exc1.__context__, exc2.__context__)
+
+ def assert_equal_or_equalish(self, obj, expected):
+ cls = type(expected)
+ if cls.__eq__ is not object.__eq__:
+ self.assertEqual(obj, expected)
+ elif cls is types.FunctionType:
+ self.assert_functions_equal(obj, expected)
+ elif isinstance(expected, BaseException):
+ self.assert_exc_equal(obj, expected)
+ elif cls is types.MethodType:
+ raise NotImplementedError(cls)
+ elif cls is types.BuiltinMethodType:
+ raise NotImplementedError(cls)
+ elif cls is types.MethodWrapperType:
+ raise NotImplementedError(cls)
+ elif cls.__bases__ == (object,):
+ self.assertEqual(obj.__dict__, expected.__dict__)
+ else:
+ raise NotImplementedError(cls)
+
def get_xidata(self, obj, *, mode=None):
mode = self._resolve_mode(mode)
return _testinternalcapi.get_crossinterp_data(obj, mode)
@@ -126,35 +464,37 @@ class _GetXIDataTests(unittest.TestCase):
def assert_roundtrip_identical(self, values, *, mode=None):
mode = self._resolve_mode(mode)
for obj in values:
- with self.subTest(obj):
+ with self.subTest(repr(obj)):
got = self._get_roundtrip(obj, mode)
self.assertIs(got, obj)
def assert_roundtrip_equal(self, values, *, mode=None, expecttype=None):
mode = self._resolve_mode(mode)
for obj in values:
- with self.subTest(obj):
+ with self.subTest(repr(obj)):
got = self._get_roundtrip(obj, mode)
- self.assertEqual(got, obj)
+ if got is obj:
+ continue
self.assertIs(type(got),
type(obj) if expecttype is None else expecttype)
+ self.assert_equal_or_equalish(got, obj)
def assert_roundtrip_equal_not_identical(self, values, *,
mode=None, expecttype=None):
mode = self._resolve_mode(mode)
for obj in values:
- with self.subTest(obj):
+ with self.subTest(repr(obj)):
got = self._get_roundtrip(obj, mode)
self.assertIsNot(got, obj)
self.assertIs(type(got),
type(obj) if expecttype is None else expecttype)
- self.assertEqual(got, obj)
+ self.assert_equal_or_equalish(got, obj)
def assert_roundtrip_not_equal(self, values, *,
mode=None, expecttype=None):
mode = self._resolve_mode(mode)
for obj in values:
- with self.subTest(obj):
+ with self.subTest(repr(obj)):
got = self._get_roundtrip(obj, mode)
self.assertIsNot(got, obj)
self.assertIs(type(got),
@@ -164,7 +504,7 @@ class _GetXIDataTests(unittest.TestCase):
def assert_not_shareable(self, values, exctype=None, *, mode=None):
mode = self._resolve_mode(mode)
for obj in values:
- with self.subTest(obj):
+ with self.subTest(repr(obj)):
with self.assertRaises(NotShareableError) as cm:
_testinternalcapi.get_crossinterp_data(obj, mode)
if exctype is not None:
@@ -182,49 +522,26 @@ class PickleTests(_GetXIDataTests):
MODE = 'pickle'
def test_shareable(self):
- self.assert_roundtrip_equal([
- # singletons
- None,
- True,
- False,
- # bytes
- *(i.to_bytes(2, 'little', signed=True)
- for i in range(-1, 258)),
- # str
- 'hello world',
- '你好世界',
- '',
- # int
- sys.maxsize,
- -sys.maxsize - 1,
- *range(-1, 258),
- # float
- 0.0,
- 1.1,
- -1.0,
- 0.12345678,
- -0.12345678,
- # tuple
- (),
- (1,),
- ("hello", "world", ),
- (1, True, "hello"),
- ((1,),),
- ((1, 2), (3, 4)),
- ((1, 2), (3, 4), (5, 6)),
- ])
- # not shareable using xidata
- self.assert_roundtrip_equal([
- # int
- sys.maxsize + 1,
- -sys.maxsize - 2,
- 2**1000,
- # tuple
- (0, 1.0, []),
- (0, 1.0, {}),
- (0, 1.0, ([],)),
- (0, 1.0, ({},)),
- ])
+ with ignore_byteswarning():
+ for obj in SHAREABLE:
+ if obj in PICKLEABLE:
+ self.assert_roundtrip_equal([obj])
+ else:
+ self.assert_not_shareable([obj])
+
+ def test_not_shareable(self):
+ with ignore_byteswarning():
+ for obj in NOT_SHAREABLE:
+ if type(obj) is types.MappingProxyType:
+ self.assert_not_shareable([obj])
+ elif obj in PICKLEABLE:
+ with self.subTest(repr(obj)):
+ # We don't worry about checking the actual value.
+ # The other tests should cover that well enough.
+ got = self.get_roundtrip(obj)
+ self.assertIs(type(got), type(obj))
+ else:
+ self.assert_not_shareable([obj])
def test_list(self):
self.assert_roundtrip_equal_not_identical([
@@ -266,7 +583,7 @@ class PickleTests(_GetXIDataTests):
if cls not in defs.CLASSES_WITHOUT_EQUALITY:
continue
instances.append(cls(*args))
- self.assert_roundtrip_not_equal(instances)
+ self.assert_roundtrip_equal(instances)
def assert_class_defs_other_pickle(self, defs, mod):
# Pickle relative to a different module than the original.
@@ -286,7 +603,7 @@ class PickleTests(_GetXIDataTests):
instances = []
for cls, args in defs.TOP_CLASSES.items():
- with self.subTest(cls):
+ with self.subTest(repr(cls)):
setattr(mod, cls.__name__, cls)
xid = self.get_xidata(cls)
inst = cls(*args)
@@ -295,7 +612,7 @@ class PickleTests(_GetXIDataTests):
(cls, xid, inst, instxid))
for cls, xid, inst, instxid in instances:
- with self.subTest(cls):
+ with self.subTest(repr(cls)):
delattr(mod, cls.__name__)
if fail:
with self.assertRaises(NotShareableError):
@@ -403,13 +720,13 @@ class PickleTests(_GetXIDataTests):
def assert_func_defs_other_pickle(self, defs, mod):
# Pickle relative to a different module than the original.
for func in defs.TOP_FUNCTIONS:
- assert not hasattr(mod, func.__name__), (cls, getattr(mod, func.__name__))
+ assert not hasattr(mod, func.__name__), (getattr(mod, func.__name__),)
self.assert_not_shareable(defs.TOP_FUNCTIONS)
def assert_func_defs_other_unpickle(self, defs, mod, *, fail=False):
# Unpickle relative to a different module than the original.
for func in defs.TOP_FUNCTIONS:
- assert not hasattr(mod, func.__name__), (cls, getattr(mod, func.__name__))
+ assert not hasattr(mod, func.__name__), (getattr(mod, func.__name__),)
captured = []
for func in defs.TOP_FUNCTIONS:
@@ -434,7 +751,7 @@ class PickleTests(_GetXIDataTests):
self.assert_not_shareable(defs.TOP_FUNCTIONS)
def test_user_function_normal(self):
-# self.assert_roundtrip_equal(defs.TOP_FUNCTIONS)
+ self.assert_roundtrip_equal(defs.TOP_FUNCTIONS)
self.assert_func_defs_same(defs)
def test_user_func_in___main__(self):
@@ -505,7 +822,7 @@ class PickleTests(_GetXIDataTests):
# exceptions
def test_user_exception_normal(self):
- self.assert_roundtrip_not_equal([
+ self.assert_roundtrip_equal([
defs.MimimalError('error!'),
])
self.assert_roundtrip_equal_not_identical([
@@ -521,7 +838,7 @@ class PickleTests(_GetXIDataTests):
special = {
BaseExceptionGroup: (msg, [caught]),
ExceptionGroup: (msg, [caught]),
-# UnicodeError: (None, msg, None, None, None),
+ UnicodeError: (None, msg, None, None, None),
UnicodeEncodeError: ('utf-8', '', 1, 3, msg),
UnicodeDecodeError: ('utf-8', b'', 1, 3, msg),
UnicodeTranslateError: ('', 1, 3, msg),
@@ -531,7 +848,7 @@ class PickleTests(_GetXIDataTests):
args = special.get(cls) or (msg,)
exceptions.append(cls(*args))
- self.assert_roundtrip_not_equal(exceptions)
+ self.assert_roundtrip_equal(exceptions)
class MarshalTests(_GetXIDataTests):
@@ -576,7 +893,7 @@ class MarshalTests(_GetXIDataTests):
'',
])
self.assert_not_shareable([
- object(),
+ OBJECT,
types.SimpleNamespace(),
])
@@ -647,10 +964,7 @@ class MarshalTests(_GetXIDataTests):
shareable = [
StopIteration,
]
- types = [
- *BUILTIN_TYPES,
- *OTHER_TYPES,
- ]
+ types = BUILTIN_TYPES
self.assert_not_shareable(cls for cls in types
if cls not in shareable)
self.assert_roundtrip_identical(cls for cls in types
@@ -725,10 +1039,236 @@ class MarshalTests(_GetXIDataTests):
])
+class CodeTests(_GetXIDataTests):
+
+ MODE = 'code'
+
+ def test_function_code(self):
+ self.assert_roundtrip_equal_not_identical([
+ *(f.__code__ for f in defs.FUNCTIONS),
+ *(f.__code__ for f in defs.FUNCTION_LIKE),
+ ])
+
+ def test_functions(self):
+ self.assert_not_shareable([
+ *defs.FUNCTIONS,
+ *defs.FUNCTION_LIKE,
+ ])
+
+ def test_other_objects(self):
+ self.assert_not_shareable([
+ None,
+ True,
+ False,
+ Ellipsis,
+ NotImplemented,
+ 9999,
+ 'spam',
+ b'spam',
+ (),
+ [],
+ {},
+ object(),
+ ])
+
+
+class ShareableFuncTests(_GetXIDataTests):
+
+ MODE = 'func'
+
+ def test_stateless(self):
+ self.assert_roundtrip_equal([
+ *defs.STATELESS_FUNCTIONS,
+ # Generators can be stateless too.
+ *defs.FUNCTION_LIKE,
+ ])
+
+ def test_not_stateless(self):
+ self.assert_not_shareable([
+ *(f for f in defs.FUNCTIONS
+ if f not in defs.STATELESS_FUNCTIONS),
+ ])
+
+ def test_other_objects(self):
+ self.assert_not_shareable([
+ None,
+ True,
+ False,
+ Ellipsis,
+ NotImplemented,
+ 9999,
+ 'spam',
+ b'spam',
+ (),
+ [],
+ {},
+ object(),
+ ])
+
+
+class PureShareableScriptTests(_GetXIDataTests):
+
+ MODE = 'script-pure'
+
+ VALID_SCRIPTS = [
+ '',
+ 'spam',
+ '# a comment',
+ 'print("spam")',
+ 'raise Exception("spam")',
+ """if True:
+ do_something()
+ """,
+ """if True:
+ def spam(x):
+ return x
+ class Spam:
+ def eggs(self):
+ return 42
+ x = Spam().eggs()
+ raise ValueError(spam(x))
+ """,
+ ]
+ INVALID_SCRIPTS = [
+ ' pass', # IndentationError
+ '----', # SyntaxError
+ """if True:
+ def spam():
+ # no body
+ spam()
+ """, # IndentationError
+ ]
+
+ def test_valid_str(self):
+ self.assert_roundtrip_not_equal([
+ *self.VALID_SCRIPTS,
+ ], expecttype=types.CodeType)
+
+ def test_invalid_str(self):
+ self.assert_not_shareable([
+ *self.INVALID_SCRIPTS,
+ ])
+
+ def test_valid_bytes(self):
+ self.assert_roundtrip_not_equal([
+ *(s.encode('utf8') for s in self.VALID_SCRIPTS),
+ ], expecttype=types.CodeType)
+
+ def test_invalid_bytes(self):
+ self.assert_not_shareable([
+ *(s.encode('utf8') for s in self.INVALID_SCRIPTS),
+ ])
+
+ def test_pure_script_code(self):
+ self.assert_roundtrip_equal_not_identical([
+ *(f.__code__ for f in defs.PURE_SCRIPT_FUNCTIONS),
+ ])
+
+ def test_impure_script_code(self):
+ self.assert_not_shareable([
+ *(f.__code__ for f in defs.SCRIPT_FUNCTIONS
+ if f not in defs.PURE_SCRIPT_FUNCTIONS),
+ ])
+
+ def test_other_code(self):
+ self.assert_not_shareable([
+ *(f.__code__ for f in defs.FUNCTIONS
+ if f not in defs.SCRIPT_FUNCTIONS),
+ *(f.__code__ for f in defs.FUNCTION_LIKE),
+ ])
+
+ def test_pure_script_function(self):
+ self.assert_roundtrip_not_equal([
+ *defs.PURE_SCRIPT_FUNCTIONS,
+ ], expecttype=types.CodeType)
+
+ def test_impure_script_function(self):
+ self.assert_not_shareable([
+ *(f for f in defs.SCRIPT_FUNCTIONS
+ if f not in defs.PURE_SCRIPT_FUNCTIONS),
+ ])
+
+ def test_other_function(self):
+ self.assert_not_shareable([
+ *(f for f in defs.FUNCTIONS
+ if f not in defs.SCRIPT_FUNCTIONS),
+ *defs.FUNCTION_LIKE,
+ ])
+
+ def test_other_objects(self):
+ self.assert_not_shareable([
+ None,
+ True,
+ False,
+ Ellipsis,
+ NotImplemented,
+ (),
+ [],
+ {},
+ object(),
+ ])
+
+
+class ShareableScriptTests(PureShareableScriptTests):
+
+ MODE = 'script'
+
+ def test_impure_script_code(self):
+ self.assert_roundtrip_equal_not_identical([
+ *(f.__code__ for f in defs.SCRIPT_FUNCTIONS
+ if f not in defs.PURE_SCRIPT_FUNCTIONS),
+ ])
+
+ def test_impure_script_function(self):
+ self.assert_roundtrip_not_equal([
+ *(f for f in defs.SCRIPT_FUNCTIONS
+ if f not in defs.PURE_SCRIPT_FUNCTIONS),
+ ], expecttype=types.CodeType)
+
+
+class ShareableFallbackTests(_GetXIDataTests):
+
+ MODE = 'fallback'
+
+ def test_shareable(self):
+ self.assert_roundtrip_equal(SHAREABLE)
+
+ def test_not_shareable(self):
+ okay = [
+ *PICKLEABLE,
+ *defs.STATELESS_FUNCTIONS,
+ LAMBDA,
+ ]
+ ignored = [
+ *TUPLES_WITHOUT_EQUALITY,
+ OBJECT,
+ METHOD,
+ BUILTIN_METHOD,
+ METHOD_WRAPPER,
+ ]
+ with ignore_byteswarning():
+ self.assert_roundtrip_equal([
+ *(o for o in NOT_SHAREABLE
+ if o in okay and o not in ignored
+ and o is not MAPPING_PROXY_EMPTY),
+ ])
+ self.assert_roundtrip_not_equal([
+ *(o for o in NOT_SHAREABLE
+ if o in ignored and o is not MAPPING_PROXY_EMPTY),
+ ])
+ self.assert_not_shareable([
+ *(o for o in NOT_SHAREABLE if o not in okay),
+ MAPPING_PROXY_EMPTY,
+ ])
+
+
class ShareableTypeTests(_GetXIDataTests):
MODE = 'xidata'
+ def test_shareable(self):
+ self.assert_roundtrip_equal(SHAREABLE)
+
def test_singletons(self):
self.assert_roundtrip_identical([
None,
@@ -796,8 +1336,8 @@ class ShareableTypeTests(_GetXIDataTests):
def test_tuples_containing_non_shareable_types(self):
non_shareables = [
- Exception(),
- object(),
+ EXCEPTION,
+ OBJECT,
]
for s in non_shareables:
value = tuple([0, 1.0, s])
@@ -812,21 +1352,31 @@ class ShareableTypeTests(_GetXIDataTests):
# The rest are not shareable.
+ def test_not_shareable(self):
+ self.assert_not_shareable(NOT_SHAREABLE)
+
def test_object(self):
self.assert_not_shareable([
object(),
])
+ def test_code(self):
+ # types.CodeType
+ self.assert_not_shareable([
+ *(f.__code__ for f in defs.FUNCTIONS),
+ *(f.__code__ for f in defs.FUNCTION_LIKE),
+ ])
+
def test_function_object(self):
for func in defs.FUNCTIONS:
assert type(func) is types.FunctionType, func
assert type(defs.SpamOkay.okay) is types.FunctionType, func
- assert type(lambda: None) is types.LambdaType
+ assert type(LAMBDA) is types.LambdaType
self.assert_not_shareable([
*defs.FUNCTIONS,
defs.SpamOkay.okay,
- (lambda: None),
+ LAMBDA,
])
def test_builtin_function(self):
@@ -891,10 +1441,7 @@ class ShareableTypeTests(_GetXIDataTests):
self.assert_not_shareable(instances)
def test_builtin_type(self):
- self.assert_not_shareable([
- *BUILTIN_TYPES,
- *OTHER_TYPES,
- ])
+ self.assert_not_shareable(BUILTIN_TYPES)
def test_exception(self):
self.assert_not_shareable([
@@ -933,14 +1480,8 @@ class ShareableTypeTests(_GetXIDataTests):
""", ns, ns)
self.assert_not_shareable([
- types.MappingProxyType({}),
+ MAPPING_PROXY_EMPTY,
types.SimpleNamespace(),
- # types.CodeType
- defs.spam_minimal.__code__,
- defs.spam_full.__code__,
- defs.spam_CC.__code__,
- defs.eggs_closure_C.__code__,
- defs.ham_C_closure.__code__,
# types.CellType
types.CellType(),
# types.FrameType
diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py
index 4af8f7f480e..60feab225a1 100644
--- a/Lib/test/test_csv.py
+++ b/Lib/test/test_csv.py
@@ -10,7 +10,8 @@ import csv
import gc
import pickle
from test import support
-from test.support import import_helper, check_disallow_instantiation
+from test.support import cpython_only, import_helper, check_disallow_instantiation
+from test.support.import_helper import ensure_lazy_imports
from itertools import permutations
from textwrap import dedent
from collections import OrderedDict
@@ -1121,19 +1122,22 @@ class TestDialectValidity(unittest.TestCase):
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"quotechar" must be a 1-character string')
+ '"quotechar" must be a unicode character or None, '
+ 'not a string of length 0')
mydialect.quotechar = "''"
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"quotechar" must be a 1-character string')
+ '"quotechar" must be a unicode character or None, '
+ 'not a string of length 2')
mydialect.quotechar = 4
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"quotechar" must be string or None, not int')
+ '"quotechar" must be a unicode character or None, '
+ 'not int')
def test_delimiter(self):
class mydialect(csv.Dialect):
@@ -1150,31 +1154,32 @@ class TestDialectValidity(unittest.TestCase):
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"delimiter" must be a 1-character string')
+ '"delimiter" must be a unicode character, '
+ 'not a string of length 3')
mydialect.delimiter = ""
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"delimiter" must be a 1-character string')
+ '"delimiter" must be a unicode character, not a string of length 0')
mydialect.delimiter = b","
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"delimiter" must be string, not bytes')
+ '"delimiter" must be a unicode character, not bytes')
mydialect.delimiter = 4
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"delimiter" must be string, not int')
+ '"delimiter" must be a unicode character, not int')
mydialect.delimiter = None
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"delimiter" must be string, not NoneType')
+ '"delimiter" must be a unicode character, not NoneType')
def test_escapechar(self):
class mydialect(csv.Dialect):
@@ -1188,20 +1193,32 @@ class TestDialectValidity(unittest.TestCase):
self.assertEqual(d.escapechar, "\\")
mydialect.escapechar = ""
- with self.assertRaisesRegex(csv.Error, '"escapechar" must be a 1-character string'):
+ with self.assertRaises(csv.Error) as cm:
mydialect()
+ self.assertEqual(str(cm.exception),
+ '"escapechar" must be a unicode character or None, '
+ 'not a string of length 0')
mydialect.escapechar = "**"
- with self.assertRaisesRegex(csv.Error, '"escapechar" must be a 1-character string'):
+ with self.assertRaises(csv.Error) as cm:
mydialect()
+ self.assertEqual(str(cm.exception),
+ '"escapechar" must be a unicode character or None, '
+ 'not a string of length 2')
mydialect.escapechar = b"*"
- with self.assertRaisesRegex(csv.Error, '"escapechar" must be string or None, not bytes'):
+ with self.assertRaises(csv.Error) as cm:
mydialect()
+ self.assertEqual(str(cm.exception),
+ '"escapechar" must be a unicode character or None, '
+ 'not bytes')
mydialect.escapechar = 4
- with self.assertRaisesRegex(csv.Error, '"escapechar" must be string or None, not int'):
+ with self.assertRaises(csv.Error) as cm:
mydialect()
+ self.assertEqual(str(cm.exception),
+ '"escapechar" must be a unicode character or None, '
+ 'not int')
def test_lineterminator(self):
class mydialect(csv.Dialect):
@@ -1222,7 +1239,13 @@ class TestDialectValidity(unittest.TestCase):
with self.assertRaises(csv.Error) as cm:
mydialect()
self.assertEqual(str(cm.exception),
- '"lineterminator" must be a string')
+ '"lineterminator" must be a string, not int')
+
+ mydialect.lineterminator = None
+ with self.assertRaises(csv.Error) as cm:
+ mydialect()
+ self.assertEqual(str(cm.exception),
+ '"lineterminator" must be a string, not NoneType')
def test_invalid_chars(self):
def create_invalid(field_name, value, **kwargs):
@@ -1565,6 +1588,10 @@ class MiscTestCase(unittest.TestCase):
def test__all__(self):
support.check__all__(self, csv, ('csv', '_csv'))
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("csv", {"re"})
+
def test_subclassable(self):
# issue 44089
class Foo(csv.Error): ...
diff --git a/Lib/test/test_ctypes/_support.py b/Lib/test/test_ctypes/_support.py
index 946d654a19a..700657a4e41 100644
--- a/Lib/test/test_ctypes/_support.py
+++ b/Lib/test/test_ctypes/_support.py
@@ -3,7 +3,6 @@
import ctypes
from _ctypes import Structure, Union, _Pointer, Array, _SimpleCData, CFuncPtr
import sys
-from test import support
_CData = Structure.__base__
diff --git a/Lib/test/test_ctypes/test_aligned_structures.py b/Lib/test/test_ctypes/test_aligned_structures.py
index 0c563ab8055..50b4d729b9d 100644
--- a/Lib/test/test_ctypes/test_aligned_structures.py
+++ b/Lib/test/test_ctypes/test_aligned_structures.py
@@ -316,6 +316,7 @@ class TestAlignedStructures(unittest.TestCase, StructCheckMixin):
class Main(sbase):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [
("a", c_ubyte),
("b", Inner),
diff --git a/Lib/test/test_ctypes/test_bitfields.py b/Lib/test/test_ctypes/test_bitfields.py
index dc81e752567..518f838219e 100644
--- a/Lib/test/test_ctypes/test_bitfields.py
+++ b/Lib/test/test_ctypes/test_bitfields.py
@@ -430,6 +430,7 @@ class BitFieldTest(unittest.TestCase, StructCheckMixin):
def test_gh_84039(self):
class Bad(Structure):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [
("a0", c_uint8, 1),
("a1", c_uint8, 1),
@@ -443,9 +444,9 @@ class BitFieldTest(unittest.TestCase, StructCheckMixin):
("b1", c_uint16, 12),
]
-
class GoodA(Structure):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [
("a0", c_uint8, 1),
("a1", c_uint8, 1),
@@ -460,6 +461,7 @@ class BitFieldTest(unittest.TestCase, StructCheckMixin):
class Good(Structure):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [
("a", GoodA),
("b0", c_uint16, 4),
@@ -475,6 +477,7 @@ class BitFieldTest(unittest.TestCase, StructCheckMixin):
def test_gh_73939(self):
class MyStructure(Structure):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [
("P", c_uint16),
("L", c_uint16, 9),
diff --git a/Lib/test/test_ctypes/test_byteswap.py b/Lib/test/test_ctypes/test_byteswap.py
index 9f9904282e4..f14e1aa32e1 100644
--- a/Lib/test/test_ctypes/test_byteswap.py
+++ b/Lib/test/test_ctypes/test_byteswap.py
@@ -1,5 +1,4 @@
import binascii
-import ctypes
import math
import struct
import sys
@@ -269,6 +268,7 @@ class Test(unittest.TestCase, StructCheckMixin):
class S(base):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [("b", c_byte),
("h", c_short),
@@ -296,6 +296,7 @@ class Test(unittest.TestCase, StructCheckMixin):
class S(Structure):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [("b", c_byte),
("h", c_short),
diff --git a/Lib/test/test_ctypes/test_generated_structs.py b/Lib/test/test_ctypes/test_generated_structs.py
index 9a8102219d8..1cb46a82701 100644
--- a/Lib/test/test_ctypes/test_generated_structs.py
+++ b/Lib/test/test_ctypes/test_generated_structs.py
@@ -10,7 +10,7 @@ Run this module to regenerate the files:
"""
import unittest
-from test.support import import_helper, verbose
+from test.support import import_helper
import re
from dataclasses import dataclass
from functools import cached_property
@@ -125,18 +125,21 @@ class Nested(Structure):
class Packed1(Structure):
_fields_ = [('a', c_int8), ('b', c_int64)]
_pack_ = 1
+ _layout_ = 'ms'
@register()
class Packed2(Structure):
_fields_ = [('a', c_int8), ('b', c_int64)]
_pack_ = 2
+ _layout_ = 'ms'
@register()
class Packed3(Structure):
_fields_ = [('a', c_int8), ('b', c_int64)]
_pack_ = 4
+ _layout_ = 'ms'
@register()
@@ -155,6 +158,7 @@ class Packed4(Structure):
_fields_ = [('a', c_int8), ('b', c_int64)]
_pack_ = 8
+ _layout_ = 'ms'
@register()
class X86_32EdgeCase(Structure):
@@ -366,6 +370,7 @@ class Example_gh_95496(Structure):
@register()
class Example_gh_84039_bad(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("a0", c_uint8, 1),
("a1", c_uint8, 1),
("a2", c_uint8, 1),
@@ -380,6 +385,7 @@ class Example_gh_84039_bad(Structure):
@register()
class Example_gh_84039_good_a(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("a0", c_uint8, 1),
("a1", c_uint8, 1),
("a2", c_uint8, 1),
@@ -392,6 +398,7 @@ class Example_gh_84039_good_a(Structure):
@register()
class Example_gh_84039_good(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("a", Example_gh_84039_good_a),
("b0", c_uint16, 4),
("b1", c_uint16, 12)]
@@ -399,6 +406,7 @@ class Example_gh_84039_good(Structure):
@register()
class Example_gh_73939(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("P", c_uint16),
("L", c_uint16, 9),
("Pro", c_uint16, 1),
@@ -419,6 +427,7 @@ class Example_gh_86098(Structure):
@register()
class Example_gh_86098_pack(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("a", c_uint8, 8),
("b", c_uint8, 8),
("c", c_uint32, 16)]
@@ -528,7 +537,7 @@ def dump_ctype(tp, struct_or_union_tag='', variable_name='', semi=''):
pushes.append(f'#pragma pack(push, {pack})')
pops.append(f'#pragma pack(pop)')
layout = getattr(tp, '_layout_', None)
- if layout == 'ms' or pack:
+ if layout == 'ms':
# The 'ms_struct' attribute only works on x86 and PowerPC
requires.add(
'defined(MS_WIN32) || ('
diff --git a/Lib/test/test_ctypes/test_incomplete.py b/Lib/test/test_ctypes/test_incomplete.py
index fefdfe9102e..3189fcd1bd1 100644
--- a/Lib/test/test_ctypes/test_incomplete.py
+++ b/Lib/test/test_ctypes/test_incomplete.py
@@ -1,6 +1,5 @@
import ctypes
import unittest
-import warnings
from ctypes import Structure, POINTER, pointer, c_char_p
# String-based "incomplete pointers" were implemented in ctypes 0.6.3 (2003, when
@@ -21,9 +20,7 @@ class TestSetPointerType(unittest.TestCase):
_fields_ = [("name", c_char_p),
("next", lpcell)]
- with warnings.catch_warnings():
- warnings.simplefilter('ignore', DeprecationWarning)
- ctypes.SetPointerType(lpcell, cell)
+ lpcell.set_type(cell)
self.assertIs(POINTER(cell), lpcell)
@@ -50,10 +47,9 @@ class TestSetPointerType(unittest.TestCase):
_fields_ = [("name", c_char_p),
("next", lpcell)]
- with self.assertWarns(DeprecationWarning):
- ctypes.SetPointerType(lpcell, cell)
-
+ lpcell.set_type(cell)
self.assertIs(POINTER(cell), lpcell)
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_ctypes/test_parameters.py b/Lib/test/test_ctypes/test_parameters.py
index f89521cf8b3..46f8ff93efa 100644
--- a/Lib/test/test_ctypes/test_parameters.py
+++ b/Lib/test/test_ctypes/test_parameters.py
@@ -1,3 +1,4 @@
+import sys
import unittest
import test.support
from ctypes import (CDLL, PyDLL, ArgumentError,
@@ -240,7 +241,8 @@ class SimpleTypesTestCase(unittest.TestCase):
self.assertRegex(repr(c_ulonglong.from_param(20000)), r"^<cparam '[LIQ]' \(20000\)>$")
self.assertEqual(repr(c_float.from_param(1.5)), "<cparam 'f' (1.5)>")
self.assertEqual(repr(c_double.from_param(1.5)), "<cparam 'd' (1.5)>")
- self.assertEqual(repr(c_double.from_param(1e300)), "<cparam 'd' (1e+300)>")
+ if sys.float_repr_style == 'short':
+ self.assertEqual(repr(c_double.from_param(1e300)), "<cparam 'd' (1e+300)>")
self.assertRegex(repr(c_longdouble.from_param(1.5)), r"^<cparam ('d' \(1.5\)|'g' at 0x[A-Fa-f0-9]+)>$")
self.assertRegex(repr(c_char_p.from_param(b'hihi')), r"^<cparam 'z' \(0x[A-Fa-f0-9]+\)>$")
self.assertRegex(repr(c_wchar_p.from_param('hihi')), r"^<cparam 'Z' \(0x[A-Fa-f0-9]+\)>$")
diff --git a/Lib/test/test_ctypes/test_pep3118.py b/Lib/test/test_ctypes/test_pep3118.py
index 06b2ccecade..11a0744f5a8 100644
--- a/Lib/test/test_ctypes/test_pep3118.py
+++ b/Lib/test/test_ctypes/test_pep3118.py
@@ -81,6 +81,7 @@ class Point(Structure):
class PackedPoint(Structure):
_pack_ = 2
+ _layout_ = 'ms'
_fields_ = [("x", c_long), ("y", c_long)]
class PointMidPad(Structure):
@@ -88,6 +89,7 @@ class PointMidPad(Structure):
class PackedPointMidPad(Structure):
_pack_ = 2
+ _layout_ = 'ms'
_fields_ = [("x", c_byte), ("y", c_uint64)]
class PointEndPad(Structure):
@@ -95,6 +97,7 @@ class PointEndPad(Structure):
class PackedPointEndPad(Structure):
_pack_ = 2
+ _layout_ = 'ms'
_fields_ = [("x", c_uint64), ("y", c_byte)]
class Point2(Structure):
diff --git a/Lib/test/test_ctypes/test_structunion.py b/Lib/test/test_ctypes/test_structunion.py
index 8d8b7e5e995..5b21d48d99c 100644
--- a/Lib/test/test_ctypes/test_structunion.py
+++ b/Lib/test/test_ctypes/test_structunion.py
@@ -11,6 +11,8 @@ from ._support import (_CData, PyCStructType, UnionType,
Py_TPFLAGS_DISALLOW_INSTANTIATION,
Py_TPFLAGS_IMMUTABLETYPE)
from struct import calcsize
+import contextlib
+from test.support import MS_WINDOWS
class StructUnionTestBase:
@@ -335,6 +337,22 @@ class StructUnionTestBase:
self.assertIn("from_address", dir(type(self.cls)))
self.assertIn("in_dll", dir(type(self.cls)))
+ def test_pack_layout_switch(self):
+ # Setting _pack_ implicitly sets default layout to MSVC;
+ # this is deprecated on non-Windows platforms.
+ if MS_WINDOWS:
+ warn_context = contextlib.nullcontext()
+ else:
+ warn_context = self.assertWarns(DeprecationWarning)
+ with warn_context:
+ class X(self.cls):
+ _pack_ = 1
+ # _layout_ missing
+ _fields_ = [('a', c_int8, 1), ('b', c_int16, 2)]
+
+ # Check MSVC layout (bitfields of different types aren't combined)
+ self.check_sizeof(X, struct_size=3, union_size=2)
+
class StructureTestCase(unittest.TestCase, StructUnionTestBase):
cls = Structure
diff --git a/Lib/test/test_ctypes/test_structures.py b/Lib/test/test_ctypes/test_structures.py
index 221319642e8..92d4851d739 100644
--- a/Lib/test/test_ctypes/test_structures.py
+++ b/Lib/test/test_ctypes/test_structures.py
@@ -25,6 +25,7 @@ class StructureTestCase(unittest.TestCase, StructCheckMixin):
_fields_ = [("a", c_byte),
("b", c_longlong)]
_pack_ = 1
+ _layout_ = 'ms'
self.check_struct(X)
self.assertEqual(sizeof(X), 9)
@@ -34,6 +35,7 @@ class StructureTestCase(unittest.TestCase, StructCheckMixin):
_fields_ = [("a", c_byte),
("b", c_longlong)]
_pack_ = 2
+ _layout_ = 'ms'
self.check_struct(X)
self.assertEqual(sizeof(X), 10)
self.assertEqual(X.b.offset, 2)
@@ -45,6 +47,7 @@ class StructureTestCase(unittest.TestCase, StructCheckMixin):
_fields_ = [("a", c_byte),
("b", c_longlong)]
_pack_ = 4
+ _layout_ = 'ms'
self.check_struct(X)
self.assertEqual(sizeof(X), min(4, longlong_align) + longlong_size)
self.assertEqual(X.b.offset, min(4, longlong_align))
@@ -53,27 +56,33 @@ class StructureTestCase(unittest.TestCase, StructCheckMixin):
_fields_ = [("a", c_byte),
("b", c_longlong)]
_pack_ = 8
+ _layout_ = 'ms'
self.check_struct(X)
self.assertEqual(sizeof(X), min(8, longlong_align) + longlong_size)
self.assertEqual(X.b.offset, min(8, longlong_align))
-
- d = {"_fields_": [("a", "b"),
- ("b", "q")],
- "_pack_": -1}
- self.assertRaises(ValueError, type(Structure), "X", (Structure,), d)
+ with self.assertRaises(ValueError):
+ class X(Structure):
+ _fields_ = [("a", "b"), ("b", "q")]
+ _pack_ = -1
+ _layout_ = "ms"
@support.cpython_only
def test_packed_c_limits(self):
# Issue 15989
import _testcapi
- d = {"_fields_": [("a", c_byte)],
- "_pack_": _testcapi.INT_MAX + 1}
- self.assertRaises(ValueError, type(Structure), "X", (Structure,), d)
- d = {"_fields_": [("a", c_byte)],
- "_pack_": _testcapi.UINT_MAX + 2}
- self.assertRaises(ValueError, type(Structure), "X", (Structure,), d)
+ with self.assertRaises(ValueError):
+ class X(Structure):
+ _fields_ = [("a", c_byte)]
+ _pack_ = _testcapi.INT_MAX + 1
+ _layout_ = "ms"
+
+ with self.assertRaises(ValueError):
+ class X(Structure):
+ _fields_ = [("a", c_byte)]
+ _pack_ = _testcapi.UINT_MAX + 2
+ _layout_ = "ms"
def test_initializers(self):
class Person(Structure):
diff --git a/Lib/test/test_ctypes/test_unaligned_structures.py b/Lib/test/test_ctypes/test_unaligned_structures.py
index 58a00597ef5..b5fb4c0df77 100644
--- a/Lib/test/test_ctypes/test_unaligned_structures.py
+++ b/Lib/test/test_ctypes/test_unaligned_structures.py
@@ -19,10 +19,12 @@ for typ in [c_short, c_int, c_long, c_longlong,
c_ushort, c_uint, c_ulong, c_ulonglong]:
class X(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("pad", c_byte),
("value", typ)]
class Y(SwappedStructure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("pad", c_byte),
("value", typ)]
structures.append(X)
diff --git a/Lib/test/test_curses.py b/Lib/test/test_curses.py
index 6fe0e7fd4b7..d5ca7f2ca1a 100644
--- a/Lib/test/test_curses.py
+++ b/Lib/test/test_curses.py
@@ -8,7 +8,8 @@ import unittest
from unittest.mock import MagicMock
from test.support import (requires, verbose, SaveSignals, cpython_only,
- check_disallow_instantiation, MISSING_C_DOCSTRINGS)
+ check_disallow_instantiation, MISSING_C_DOCSTRINGS,
+ gc_collect)
from test.support.import_helper import import_module
# Optionally test curses module. This currently requires that the
@@ -51,12 +52,6 @@ def requires_colors(test):
term = os.environ.get('TERM')
SHORT_MAX = 0x7fff
-DEFAULT_PAIR_CONTENTS = [
- (curses.COLOR_WHITE, curses.COLOR_BLACK),
- (0, 0),
- (-1, -1),
- (15, 0), # for xterm-256color (15 is for BRIGHT WHITE)
-]
# If newterm was supported we could use it instead of initscr and not exit
@unittest.skipIf(not term or term == 'unknown',
@@ -135,6 +130,9 @@ class TestCurses(unittest.TestCase):
curses.use_env(False)
curses.use_env(True)
+ def test_error(self):
+ self.assertIsSubclass(curses.error, Exception)
+
def test_create_windows(self):
win = curses.newwin(5, 10)
self.assertEqual(win.getbegyx(), (0, 0))
@@ -187,6 +185,14 @@ class TestCurses(unittest.TestCase):
self.assertEqual(win3.getparyx(), (2, 1))
self.assertEqual(win3.getmaxyx(), (6, 11))
+ def test_subwindows_references(self):
+ win = curses.newwin(5, 10)
+ win2 = win.subwin(3, 7)
+ del win
+ gc_collect()
+ del win2
+ gc_collect()
+
def test_move_cursor(self):
stdscr = self.stdscr
win = stdscr.subwin(10, 15, 2, 5)
@@ -948,8 +954,6 @@ class TestCurses(unittest.TestCase):
@requires_colors
def test_pair_content(self):
- if not hasattr(curses, 'use_default_colors'):
- self.assertIn(curses.pair_content(0), DEFAULT_PAIR_CONTENTS)
curses.pair_content(0)
maxpair = self.get_pair_limit() - 1
if maxpair > 0:
@@ -994,13 +998,27 @@ class TestCurses(unittest.TestCase):
@requires_curses_func('use_default_colors')
@requires_colors
def test_use_default_colors(self):
- old = curses.pair_content(0)
try:
curses.use_default_colors()
except curses.error:
self.skipTest('cannot change color (use_default_colors() failed)')
self.assertEqual(curses.pair_content(0), (-1, -1))
- self.assertIn(old, DEFAULT_PAIR_CONTENTS)
+
+ @requires_curses_func('assume_default_colors')
+ @requires_colors
+ def test_assume_default_colors(self):
+ try:
+ curses.assume_default_colors(-1, -1)
+ except curses.error:
+ self.skipTest('cannot change color (assume_default_colors() failed)')
+ self.assertEqual(curses.pair_content(0), (-1, -1))
+ curses.assume_default_colors(curses.COLOR_YELLOW, curses.COLOR_BLUE)
+ self.assertEqual(curses.pair_content(0), (curses.COLOR_YELLOW, curses.COLOR_BLUE))
+ curses.assume_default_colors(curses.COLOR_RED, -1)
+ self.assertEqual(curses.pair_content(0), (curses.COLOR_RED, -1))
+ curses.assume_default_colors(-1, curses.COLOR_GREEN)
+ self.assertEqual(curses.pair_content(0), (-1, curses.COLOR_GREEN))
+ curses.assume_default_colors(-1, -1)
def test_keyname(self):
# TODO: key_name()
@@ -1242,7 +1260,7 @@ class TestAscii(unittest.TestCase):
def test_controlnames(self):
for name in curses.ascii.controlnames:
- self.assertTrue(hasattr(curses.ascii, name), name)
+ self.assertHasAttr(curses.ascii, name)
def test_ctypes(self):
def check(func, expected):
diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py
index 99fefb57fd0..e98a8f284ce 100644
--- a/Lib/test/test_dataclasses/__init__.py
+++ b/Lib/test/test_dataclasses/__init__.py
@@ -5,6 +5,7 @@
from dataclasses import *
import abc
+import annotationlib
import io
import pickle
import inspect
@@ -12,6 +13,7 @@ import builtins
import types
import weakref
import traceback
+import sys
import textwrap
import unittest
from unittest.mock import Mock
@@ -25,6 +27,7 @@ import typing # Needed for the string "typing.ClassVar[int]" to work as an
import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
from test import support
+from test.support import import_helper
# Just any custom exception we can catch.
class CustomError(Exception): pass
@@ -117,7 +120,7 @@ class TestCase(unittest.TestCase):
for param in inspect.signature(dataclass).parameters:
if param == 'cls':
continue
- self.assertTrue(hasattr(Some.__dataclass_params__, param), msg=param)
+ self.assertHasAttr(Some.__dataclass_params__, param)
def test_named_init_params(self):
@dataclass
@@ -668,7 +671,7 @@ class TestCase(unittest.TestCase):
self.assertEqual(the_fields[0].name, 'x')
self.assertEqual(the_fields[0].type, int)
- self.assertFalse(hasattr(C, 'x'))
+ self.assertNotHasAttr(C, 'x')
self.assertTrue (the_fields[0].init)
self.assertTrue (the_fields[0].repr)
self.assertEqual(the_fields[1].name, 'y')
@@ -678,7 +681,7 @@ class TestCase(unittest.TestCase):
self.assertTrue (the_fields[1].repr)
self.assertEqual(the_fields[2].name, 'z')
self.assertEqual(the_fields[2].type, str)
- self.assertFalse(hasattr(C, 'z'))
+ self.assertNotHasAttr(C, 'z')
self.assertTrue (the_fields[2].init)
self.assertFalse(the_fields[2].repr)
@@ -729,8 +732,8 @@ class TestCase(unittest.TestCase):
z: object = default
t: int = field(default=100)
- self.assertFalse(hasattr(C, 'x'))
- self.assertFalse(hasattr(C, 'y'))
+ self.assertNotHasAttr(C, 'x')
+ self.assertNotHasAttr(C, 'y')
self.assertIs (C.z, default)
self.assertEqual(C.t, 100)
@@ -2909,10 +2912,10 @@ class TestFrozen(unittest.TestCase):
pass
c = C()
- self.assertFalse(hasattr(c, 'i'))
+ self.assertNotHasAttr(c, 'i')
with self.assertRaises(FrozenInstanceError):
c.i = 5
- self.assertFalse(hasattr(c, 'i'))
+ self.assertNotHasAttr(c, 'i')
with self.assertRaises(FrozenInstanceError):
del c.i
@@ -3141,7 +3144,7 @@ class TestFrozen(unittest.TestCase):
del s.y
self.assertEqual(s.y, 10)
del s.cached
- self.assertFalse(hasattr(s, 'cached'))
+ self.assertNotHasAttr(s, 'cached')
with self.assertRaises(AttributeError) as cm:
del s.cached
self.assertNotIsInstance(cm.exception, FrozenInstanceError)
@@ -3155,12 +3158,12 @@ class TestFrozen(unittest.TestCase):
pass
s = S()
- self.assertFalse(hasattr(s, 'x'))
+ self.assertNotHasAttr(s, 'x')
s.x = 5
self.assertEqual(s.x, 5)
del s.x
- self.assertFalse(hasattr(s, 'x'))
+ self.assertNotHasAttr(s, 'x')
with self.assertRaises(AttributeError) as cm:
del s.x
self.assertNotIsInstance(cm.exception, FrozenInstanceError)
@@ -3390,8 +3393,8 @@ class TestSlots(unittest.TestCase):
B = dataclass(A, slots=True)
self.assertIsNot(A, B)
- self.assertFalse(hasattr(A, "__slots__"))
- self.assertTrue(hasattr(B, "__slots__"))
+ self.assertNotHasAttr(A, "__slots__")
+ self.assertHasAttr(B, "__slots__")
# Can't be local to test_frozen_pickle.
@dataclass(frozen=True, slots=True)
@@ -3754,7 +3757,6 @@ class TestSlots(unittest.TestCase):
@support.cpython_only
def test_dataclass_slot_dict_ctype(self):
# https://github.com/python/cpython/issues/123935
- from test.support import import_helper
# Skips test if `_testcapi` is not present:
_testcapi = import_helper.import_module('_testcapi')
@@ -4246,16 +4248,56 @@ class TestMakeDataclass(unittest.TestCase):
C = make_dataclass('Point', ['x', 'y', 'z'])
c = C(1, 2, 3)
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
- self.assertEqual(C.__annotations__, {'x': 'typing.Any',
- 'y': 'typing.Any',
- 'z': 'typing.Any'})
+ self.assertEqual(C.__annotations__, {'x': typing.Any,
+ 'y': typing.Any,
+ 'z': typing.Any})
C = make_dataclass('Point', ['x', ('y', int), 'z'])
c = C(1, 2, 3)
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
- self.assertEqual(C.__annotations__, {'x': 'typing.Any',
+ self.assertEqual(C.__annotations__, {'x': typing.Any,
'y': int,
- 'z': 'typing.Any'})
+ 'z': typing.Any})
+
+ def test_no_types_get_annotations(self):
+ C = make_dataclass('C', ['x', ('y', int), 'z'])
+
+ self.assertEqual(
+ annotationlib.get_annotations(C, format=annotationlib.Format.VALUE),
+ {'x': typing.Any, 'y': int, 'z': typing.Any},
+ )
+ self.assertEqual(
+ annotationlib.get_annotations(
+ C, format=annotationlib.Format.FORWARDREF),
+ {'x': typing.Any, 'y': int, 'z': typing.Any},
+ )
+ self.assertEqual(
+ annotationlib.get_annotations(
+ C, format=annotationlib.Format.STRING),
+ {'x': 'typing.Any', 'y': 'int', 'z': 'typing.Any'},
+ )
+
+ def test_no_types_no_typing_import(self):
+ with import_helper.CleanImport('typing'):
+ self.assertNotIn('typing', sys.modules)
+ C = make_dataclass('C', ['x', ('y', int)])
+
+ self.assertNotIn('typing', sys.modules)
+ self.assertEqual(
+ C.__annotate__(annotationlib.Format.FORWARDREF),
+ {
+ 'x': annotationlib.ForwardRef('Any', module='typing'),
+ 'y': int,
+ },
+ )
+ self.assertNotIn('typing', sys.modules)
+
+ for field in fields(C):
+ if field.name == "x":
+ self.assertEqual(field.type, annotationlib.ForwardRef('Any', module='typing'))
+ else:
+ self.assertEqual(field.name, "y")
+ self.assertIs(field.type, int)
def test_module_attr(self):
self.assertEqual(ByMakeDataClass.__module__, __name__)
diff --git a/Lib/test/test_dbm.py b/Lib/test/test_dbm.py
index 4be7c5649da..7e8d78b8940 100644
--- a/Lib/test/test_dbm.py
+++ b/Lib/test/test_dbm.py
@@ -66,7 +66,7 @@ class AnyDBMTestCase:
return keys
def test_error(self):
- self.assertTrue(issubclass(self.module.error, OSError))
+ self.assertIsSubclass(self.module.error, OSError)
def test_anydbm_not_existing(self):
self.assertRaises(dbm.error, dbm.open, _fname)
@@ -135,6 +135,67 @@ class AnyDBMTestCase:
assert(f[key] == b"Python:")
f.close()
+ def test_anydbm_readonly_reorganize(self):
+ self.init_db()
+ with dbm.open(_fname, 'r') as d:
+ # Early stopping.
+ if not hasattr(d, 'reorganize'):
+ self.skipTest("method reorganize not available this dbm submodule")
+
+ self.assertRaises(dbm.error, lambda: d.reorganize())
+
+ def test_anydbm_reorganize_not_changed_content(self):
+ self.init_db()
+ with dbm.open(_fname, 'c') as d:
+ # Early stopping.
+ if not hasattr(d, 'reorganize'):
+ self.skipTest("method reorganize not available this dbm submodule")
+
+ keys_before = sorted(d.keys())
+ values_before = [d[k] for k in keys_before]
+ d.reorganize()
+ keys_after = sorted(d.keys())
+ values_after = [d[k] for k in keys_before]
+ self.assertEqual(keys_before, keys_after)
+ self.assertEqual(values_before, values_after)
+
+ def test_anydbm_reorganize_decreased_size(self):
+
+ def _calculate_db_size(db_path):
+ if os.path.isfile(db_path):
+ return os.path.getsize(db_path)
+ total_size = 0
+ for root, _, filenames in os.walk(db_path):
+ for filename in filenames:
+ file_path = os.path.join(root, filename)
+ total_size += os.path.getsize(file_path)
+ return total_size
+
+ # This test requires relatively large databases to reliably show difference in size before and after reorganizing.
+ with dbm.open(_fname, 'n') as f:
+ # Early stopping.
+ if not hasattr(f, 'reorganize'):
+ self.skipTest("method reorganize not available this dbm submodule")
+
+ for k in self._dict:
+ f[k.encode('ascii')] = self._dict[k] * 100000
+ db_keys = list(f.keys())
+
+ # Make sure to calculate size of database only after file is closed to ensure file content are flushed to disk.
+ size_before = _calculate_db_size(os.path.dirname(_fname))
+
+ # Delete some elements from the start of the database.
+ keys_to_delete = db_keys[:len(db_keys) // 2]
+ with dbm.open(_fname, 'c') as f:
+ for k in keys_to_delete:
+ del f[k]
+ f.reorganize()
+
+ # Make sure to calculate size of database only after file is closed to ensure file content are flushed to disk.
+ size_after = _calculate_db_size(os.path.dirname(_fname))
+
+ self.assertLess(size_after, size_before)
+
def test_open_with_bytes(self):
dbm.open(os.fsencode(_fname), "c").close()
diff --git a/Lib/test/test_dbm_gnu.py b/Lib/test/test_dbm_gnu.py
index 66268c42a30..e0b988b7b95 100644
--- a/Lib/test/test_dbm_gnu.py
+++ b/Lib/test/test_dbm_gnu.py
@@ -74,12 +74,12 @@ class TestGdbm(unittest.TestCase):
# Test the flag parameter open() by trying all supported flag modes.
all = set(gdbm.open_flags)
# Test standard flags (presumably "crwn").
- modes = all - set('fsu')
+ modes = all - set('fsum')
for mode in sorted(modes): # put "c" mode first
self.g = gdbm.open(filename, mode)
self.g.close()
- # Test additional flags (presumably "fsu").
+ # Test additional flags (presumably "fsum").
flags = all - set('crwn')
for mode in modes:
for flag in flags:
@@ -217,6 +217,29 @@ class TestGdbm(unittest.TestCase):
create_empty_file(os.path.join(d, 'test'))
self.assertRaises(gdbm.error, gdbm.open, filename, 'r')
+ @unittest.skipUnless('m' in gdbm.open_flags, "requires 'm' in open_flags")
+ def test_nommap_no_crash(self):
+ self.g = g = gdbm.open(filename, 'nm')
+ os.truncate(filename, 0)
+
+ g.get(b'a', b'c')
+ g.keys()
+ g.firstkey()
+ g.nextkey(b'a')
+ with self.assertRaises(KeyError):
+ g[b'a']
+ with self.assertRaises(gdbm.error):
+ len(g)
+
+ with self.assertRaises(gdbm.error):
+ g[b'a'] = b'c'
+ with self.assertRaises(gdbm.error):
+ del g[b'a']
+ with self.assertRaises(gdbm.error):
+ g.setdefault(b'a', b'c')
+ with self.assertRaises(gdbm.error):
+ g.reorganize()
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_dbm_sqlite3.py b/Lib/test/test_dbm_sqlite3.py
index 2e1f2d32924..9216da8a63f 100644
--- a/Lib/test/test_dbm_sqlite3.py
+++ b/Lib/test/test_dbm_sqlite3.py
@@ -36,7 +36,7 @@ class URI(unittest.TestCase):
)
for path, normalized in dataset:
with self.subTest(path=path, normalized=normalized):
- self.assertTrue(_normalize_uri(path).endswith(normalized))
+ self.assertEndsWith(_normalize_uri(path), normalized)
@unittest.skipUnless(sys.platform == "win32", "requires Windows")
def test_uri_windows(self):
@@ -55,7 +55,7 @@ class URI(unittest.TestCase):
with self.subTest(path=path, normalized=normalized):
if not Path(path).is_absolute():
self.skipTest(f"skipping relative path: {path!r}")
- self.assertTrue(_normalize_uri(path).endswith(normalized))
+ self.assertEndsWith(_normalize_uri(path), normalized)
class ReadOnly(_SQLiteDbmTests):
diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py
index 9e298401dc3..ef64b878805 100644
--- a/Lib/test/test_decimal.py
+++ b/Lib/test/test_decimal.py
@@ -28,7 +28,6 @@ import logging
import math
import os, sys
import operator
-import warnings
import pickle, copy
import unittest
import numbers
@@ -982,6 +981,7 @@ class FormatTest:
('.0f', '0e-2', '0'),
('.0f', '3.14159265', '3'),
('.1f', '3.14159265', '3.1'),
+ ('.01f', '3.14159265', '3.1'), # leading zero in precision
('.4f', '3.14159265', '3.1416'),
('.6f', '3.14159265', '3.141593'),
('.7f', '3.14159265', '3.1415926'), # round-half-even!
@@ -1067,6 +1067,7 @@ class FormatTest:
('8,', '123456', ' 123,456'),
('08,', '123456', '0,123,456'), # special case: extra 0 needed
('+08,', '123456', '+123,456'), # but not if there's a sign
+ ('008,', '123456', '0,123,456'), # leading zero in width
(' 08,', '123456', ' 123,456'),
('08,', '-123456', '-123,456'),
('+09,', '123456', '+0,123,456'),
diff --git a/Lib/test/test_deque.py b/Lib/test/test_deque.py
index 4679f297fd7..4e1a489205a 100644
--- a/Lib/test/test_deque.py
+++ b/Lib/test/test_deque.py
@@ -838,7 +838,7 @@ class TestSubclass(unittest.TestCase):
self.assertEqual(list(d), list(e))
self.assertEqual(e.x, d.x)
self.assertEqual(e.z, d.z)
- self.assertFalse(hasattr(e, 'y'))
+ self.assertNotHasAttr(e, 'y')
def test_pickle_recursive(self):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py
index 76937432a43..f6ec2cf5ce8 100644
--- a/Lib/test/test_descr.py
+++ b/Lib/test/test_descr.py
@@ -409,7 +409,7 @@ class ClassPropertiesAndMethods(unittest.TestCase):
def test_python_dicts(self):
# Testing Python subclass of dict...
- self.assertTrue(issubclass(dict, dict))
+ self.assertIsSubclass(dict, dict)
self.assertIsInstance({}, dict)
d = dict()
self.assertEqual(d, {})
@@ -433,7 +433,7 @@ class ClassPropertiesAndMethods(unittest.TestCase):
self.state = state
def getstate(self):
return self.state
- self.assertTrue(issubclass(C, dict))
+ self.assertIsSubclass(C, dict)
a1 = C(12)
self.assertEqual(a1.state, 12)
a2 = C(foo=1, bar=2)
@@ -1048,15 +1048,15 @@ class ClassPropertiesAndMethods(unittest.TestCase):
m = types.ModuleType("m")
self.assertTrue(m.__class__ is types.ModuleType)
- self.assertFalse(hasattr(m, "a"))
+ self.assertNotHasAttr(m, "a")
m.__class__ = SubType
self.assertTrue(m.__class__ is SubType)
- self.assertTrue(hasattr(m, "a"))
+ self.assertHasAttr(m, "a")
m.__class__ = types.ModuleType
self.assertTrue(m.__class__ is types.ModuleType)
- self.assertFalse(hasattr(m, "a"))
+ self.assertNotHasAttr(m, "a")
# Make sure that builtin immutable objects don't support __class__
# assignment, because the object instances may be interned.
@@ -1780,7 +1780,7 @@ class ClassPropertiesAndMethods(unittest.TestCase):
class E: # *not* subclassing from C
foo = C.foo
self.assertEqual(E().foo.__func__, C.foo) # i.e., unbound
- self.assertTrue(repr(C.foo.__get__(C())).startswith("<bound method "))
+ self.assertStartsWith(repr(C.foo.__get__(C())), "<bound method ")
def test_compattr(self):
# Testing computed attributes...
@@ -2058,7 +2058,7 @@ class ClassPropertiesAndMethods(unittest.TestCase):
class E(object):
foo = C.foo
self.assertEqual(E().foo.__func__, C.foo) # i.e., unbound
- self.assertTrue(repr(C.foo.__get__(C(1))).startswith("<bound method "))
+ self.assertStartsWith(repr(C.foo.__get__(C(1))), "<bound method ")
@support.impl_detail("testing error message from implementation")
def test_methods_in_c(self):
@@ -3943,6 +3943,7 @@ class ClassPropertiesAndMethods(unittest.TestCase):
del C.__del__
@unittest.skipIf(support.is_emscripten, "Seems to works in Pyodide?")
+ @support.skip_wasi_stack_overflow()
def test_slots_trash(self):
# Testing slot trash...
# Deallocating deeply nested slotted trash caused stack overflows
@@ -4113,6 +4114,34 @@ class ClassPropertiesAndMethods(unittest.TestCase):
else:
self.fail("shouldn't be able to create inheritance cycles")
+ def test_assign_bases_many_subclasses(self):
+ # This is intended to check that typeobject.c:queue_slot_update() can
+ # handle updating many subclasses when a slot method is re-assigned.
+ class A:
+ x = 'hello'
+ def __call__(self):
+ return 123
+ def __getitem__(self, index):
+ return None
+
+ class X:
+ x = 'bye'
+
+ class B(A):
+ pass
+
+ subclasses = []
+ for i in range(1000):
+ sc = type(f'Sub{i}', (B,), {})
+ subclasses.append(sc)
+
+ self.assertEqual(subclasses[0]()(), 123)
+ self.assertEqual(subclasses[0]().x, 'hello')
+ B.__bases__ = (X,)
+ with self.assertRaises(TypeError):
+ subclasses[0]()()
+ self.assertEqual(subclasses[0]().x, 'bye')
+
def test_builtin_bases(self):
# Make sure all the builtin types can have their base queried without
# segfaulting. See issue #5787.
@@ -4523,6 +4552,7 @@ class ClassPropertiesAndMethods(unittest.TestCase):
del o
@support.skip_wasi_stack_overflow()
+ @support.skip_emscripten_stack_overflow()
@support.requires_resource('cpu')
def test_wrapper_segfault(self):
# SF 927248: deeply nested wrappers could cause stack overflow
@@ -4867,6 +4897,7 @@ class ClassPropertiesAndMethods(unittest.TestCase):
deque.append(thing, thing)
@support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_repr_as_str(self):
# Issue #11603: crash or infinite loop when rebinding __str__ as
# __repr__.
@@ -5194,8 +5225,8 @@ class DictProxyTests(unittest.TestCase):
# We can't blindly compare with the repr of another dict as ordering
# of keys and values is arbitrary and may differ.
r = repr(self.C.__dict__)
- self.assertTrue(r.startswith('mappingproxy('), r)
- self.assertTrue(r.endswith(')'), r)
+ self.assertStartsWith(r, 'mappingproxy(')
+ self.assertEndsWith(r, ')')
for k, v in self.C.__dict__.items():
self.assertIn('{!r}: {!r}'.format(k, v), r)
diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py
index 3104cbc66cb..60c62430370 100644
--- a/Lib/test/test_dict.py
+++ b/Lib/test/test_dict.py
@@ -290,6 +290,38 @@ class DictTest(unittest.TestCase):
['Cannot convert dictionary update sequence element #0 to a sequence'],
)
+ def test_update_shared_keys(self):
+ class MyClass: pass
+
+ # Subclass str to enable us to create an object during the
+ # dict.update() call.
+ class MyStr(str):
+ def __hash__(self):
+ return super().__hash__()
+
+ def __eq__(self, other):
+ # Create an object that shares the same PyDictKeysObject as
+ # obj.__dict__.
+ obj2 = MyClass()
+ obj2.a = "a"
+ obj2.b = "b"
+ obj2.c = "c"
+ return super().__eq__(other)
+
+ obj = MyClass()
+ obj.a = "a"
+ obj.b = "b"
+
+ x = {}
+ x[MyStr("a")] = MyStr("a")
+
+ # gh-132617: this previously raised "dict mutated during update" error
+ x.update(obj.__dict__)
+
+ self.assertEqual(x, {
+ MyStr("a"): "a",
+ "b": "b",
+ })
def test_fromkeys(self):
self.assertEqual(dict.fromkeys('abc'), {'a':None, 'b':None, 'c':None})
@@ -338,17 +370,34 @@ class DictTest(unittest.TestCase):
self.assertRaises(Exc, baddict2.fromkeys, [1])
# test fast path for dictionary inputs
+ res = dict(zip(range(6), [0]*6))
d = dict(zip(range(6), range(6)))
- self.assertEqual(dict.fromkeys(d, 0), dict(zip(range(6), [0]*6)))
-
+ self.assertEqual(dict.fromkeys(d, 0), res)
+ # test fast path for set inputs
+ d = set(range(6))
+ self.assertEqual(dict.fromkeys(d, 0), res)
+ # test slow path for other iterable inputs
+ d = list(range(6))
+ self.assertEqual(dict.fromkeys(d, 0), res)
+
+ # test fast path when object's constructor returns large non-empty dict
class baddict3(dict):
def __new__(cls):
return d
- d = {i : i for i in range(10)}
+ d = {i : i for i in range(1000)}
res = d.copy()
res.update(a=None, b=None, c=None)
self.assertEqual(baddict3.fromkeys({"a", "b", "c"}), res)
+ # test slow path when object is a proper subclass of dict
+ class baddict4(dict):
+ def __init__(self):
+ dict.__init__(self, d)
+ d = {i : i for i in range(1000)}
+ res = d.copy()
+ res.update(a=None, b=None, c=None)
+ self.assertEqual(baddict4.fromkeys({"a", "b", "c"}), res)
+
def test_copy(self):
d = {1: 1, 2: 2, 3: 3}
self.assertIsNot(d.copy(), d)
@@ -770,8 +819,8 @@ class DictTest(unittest.TestCase):
def test_missing(self):
# Make sure dict doesn't have a __missing__ method
- self.assertFalse(hasattr(dict, "__missing__"))
- self.assertFalse(hasattr({}, "__missing__"))
+ self.assertNotHasAttr(dict, "__missing__")
+ self.assertNotHasAttr({}, "__missing__")
# Test several cases:
# (D) subclass defines __missing__ method returning a value
# (E) subclass defines __missing__ method raising RuntimeError
@@ -1022,10 +1071,8 @@ class DictTest(unittest.TestCase):
a = C()
a.x = 1
d = a.__dict__
- before_resize = sys.getsizeof(d)
d[2] = 2 # split table is resized to a generic combined table
- self.assertGreater(sys.getsizeof(d), before_resize)
self.assertEqual(list(d), ['x', 2])
def test_iterator_pickling(self):
diff --git a/Lib/test/test_difflib.py b/Lib/test/test_difflib.py
index 9e217249be7..6ac584a08d1 100644
--- a/Lib/test/test_difflib.py
+++ b/Lib/test/test_difflib.py
@@ -255,21 +255,21 @@ class TestSFpatches(unittest.TestCase):
html_diff = difflib.HtmlDiff()
output = html_diff.make_file(patch914575_from1.splitlines(),
patch914575_to1.splitlines())
- self.assertIn('content="text/html; charset=utf-8"', output)
+ self.assertIn('charset="utf-8"', output)
def test_make_file_iso88591_charset(self):
html_diff = difflib.HtmlDiff()
output = html_diff.make_file(patch914575_from1.splitlines(),
patch914575_to1.splitlines(),
charset='iso-8859-1')
- self.assertIn('content="text/html; charset=iso-8859-1"', output)
+ self.assertIn('charset="iso-8859-1"', output)
def test_make_file_usascii_charset_with_nonascii_input(self):
html_diff = difflib.HtmlDiff()
output = html_diff.make_file(patch914575_nonascii_from1.splitlines(),
patch914575_nonascii_to1.splitlines(),
charset='us-ascii')
- self.assertIn('content="text/html; charset=us-ascii"', output)
+ self.assertIn('charset="us-ascii"', output)
self.assertIn('&#305;mpl&#305;c&#305;t', output)
class TestDiffer(unittest.TestCase):
diff --git a/Lib/test/test_difflib_expect.html b/Lib/test/test_difflib_expect.html
index 9f33a9e9c9c..2346a6f9f8d 100644
--- a/Lib/test/test_difflib_expect.html
+++ b/Lib/test/test_difflib_expect.html
@@ -1,22 +1,42 @@
-<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
- "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
-
-<html>
-
+<!DOCTYPE html>
+<html lang="en">
<head>
- <meta http-equiv="Content-Type"
- content="text/html; charset=utf-8" />
- <title></title>
- <style type="text/css">
+ <meta charset="utf-8">
+ <meta name="viewport" content="width=device-width, initial-scale=1">
+ <title>Diff comparison</title>
+ <style>
:root {color-scheme: light dark}
- table.diff {font-family: Menlo, Consolas, Monaco, Liberation Mono, Lucida Console, monospace; border:medium}
- .diff_header {background-color:#e0e0e0}
- td.diff_header {text-align:right}
- .diff_next {background-color:#c0c0c0}
+ table.diff {
+ font-family: Menlo, Consolas, Monaco, Liberation Mono, Lucida Console, monospace;
+ border: medium;
+ }
+ .diff_header {
+ background-color: #e0e0e0;
+ font-weight: bold;
+ }
+ td.diff_header {
+ text-align: right;
+ padding: 0 8px;
+ }
+ .diff_next {
+ background-color: #c0c0c0;
+ padding: 4px 0;
+ }
.diff_add {background-color:palegreen}
.diff_chg {background-color:#ffff77}
.diff_sub {background-color:#ffaaaa}
+ table.diff[summary="Legends"] {
+ margin-top: 20px;
+ border: 1px solid #ccc;
+ }
+ table.diff[summary="Legends"] th {
+ background-color: #e0e0e0;
+ padding: 4px 8px;
+ }
+ table.diff[summary="Legends"] td {
+ padding: 4px 8px;
+ }
@media (prefers-color-scheme: dark) {
.diff_header {background-color:#666}
@@ -24,6 +44,8 @@
.diff_add {background-color:darkgreen}
.diff_chg {background-color:#847415}
.diff_sub {background-color:darkred}
+ table.diff[summary="Legends"] {border-color:#555}
+ table.diff[summary="Legends"] th{background-color:#666}
}
</style>
</head>
diff --git a/Lib/test/test_dis.py b/Lib/test/test_dis.py
index f2586fcee57..355990ed58e 100644
--- a/Lib/test/test_dis.py
+++ b/Lib/test/test_dis.py
@@ -606,7 +606,7 @@ dis_asyncwith = """\
POP_TOP
L1: RESUME 0
-%4d LOAD_FAST_BORROW 0 (c)
+%4d LOAD_FAST 0 (c)
COPY 1
LOAD_SPECIAL 3 (__aexit__)
SWAP 2
@@ -851,7 +851,7 @@ Disassembly of <code object <genexpr> at 0x..., file "%s", line %d>:
%4d RETURN_GENERATOR
POP_TOP
L1: RESUME 0
- LOAD_FAST_BORROW 0 (.0)
+ LOAD_FAST 0 (.0)
GET_ITER
L2: FOR_ITER 14 (to L3)
STORE_FAST 1 (z)
@@ -902,7 +902,7 @@ dis_loop_test_quickened_code = """\
%3d RESUME_CHECK 0
%3d BUILD_LIST 0
- LOAD_CONST_MORTAL 2 ((1, 2, 3))
+ LOAD_CONST 2 ((1, 2, 3))
LIST_EXTEND 1
LOAD_SMALL_INT 3
BINARY_OP 5 (*)
@@ -918,7 +918,7 @@ dis_loop_test_quickened_code = """\
%3d L2: END_FOR
POP_ITER
- LOAD_CONST_IMMORTAL 1 (None)
+ LOAD_CONST 1 (None)
RETURN_VALUE
""" % (loop_test.__code__.co_firstlineno,
loop_test.__code__.co_firstlineno + 1,
@@ -1304,7 +1304,7 @@ class DisTests(DisTestBase):
load_attr_quicken = """\
0 RESUME_CHECK 0
- 1 LOAD_CONST_IMMORTAL 0 ('a')
+ 1 LOAD_CONST 0 ('a')
LOAD_ATTR_SLOT 0 (__class__)
RETURN_VALUE
"""
@@ -1336,7 +1336,7 @@ class DisTests(DisTestBase):
# Loop can trigger a quicken where the loop is located
self.code_quicken(loop_test)
got = self.get_disassembly(loop_test, adaptive=True)
- jit = import_helper.import_module("_testinternalcapi").jit_enabled()
+ jit = sys._jit.is_enabled()
expected = dis_loop_test_quickened_code.format("JIT" if jit else "NO_JIT")
self.do_disassembly_compare(got, expected)
@@ -1821,7 +1821,7 @@ expected_opinfo_jumpy = [
make_inst(opname='LOAD_SMALL_INT', arg=10, argval=10, argrepr='', offset=12, start_offset=12, starts_line=False, line_number=3),
make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=14, start_offset=14, starts_line=False, line_number=3, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
make_inst(opname='GET_ITER', arg=None, argval=None, argrepr='', offset=22, start_offset=22, starts_line=False, line_number=3),
- make_inst(opname='FOR_ITER', arg=32, argval=92, argrepr='to L4', offset=24, start_offset=24, starts_line=False, line_number=3, label=1, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='FOR_ITER', arg=33, argval=94, argrepr='to L4', offset=24, start_offset=24, starts_line=False, line_number=3, label=1, cache_info=[('counter', 1, b'\x00\x00')]),
make_inst(opname='STORE_FAST', arg=0, argval='i', argrepr='i', offset=28, start_offset=28, starts_line=False, line_number=3),
make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=30, start_offset=30, starts_line=True, line_number=4, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=40, start_offset=40, starts_line=False, line_number=4),
@@ -1840,110 +1840,111 @@ expected_opinfo_jumpy = [
make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=82, start_offset=82, starts_line=False, line_number=7),
make_inst(opname='JUMP_BACKWARD', arg=32, argval=24, argrepr='to L1', offset=84, start_offset=84, starts_line=False, line_number=7, cache_info=[('counter', 1, b'\x00\x00')]),
make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=88, start_offset=88, starts_line=True, line_number=8, label=3),
- make_inst(opname='JUMP_FORWARD', arg=13, argval=118, argrepr='to L5', offset=90, start_offset=90, starts_line=False, line_number=8),
- make_inst(opname='END_FOR', arg=None, argval=None, argrepr='', offset=92, start_offset=92, starts_line=True, line_number=3, label=4),
- make_inst(opname='POP_ITER', arg=None, argval=None, argrepr='', offset=94, start_offset=94, starts_line=False, line_number=3),
- make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=96, start_offset=96, starts_line=True, line_number=10, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
- make_inst(opname='LOAD_CONST', arg=1, argval='I can haz else clause?', argrepr="'I can haz else clause?'", offset=106, start_offset=106, starts_line=False, line_number=10),
- make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=108, start_offset=108, starts_line=False, line_number=10, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=116, start_offset=116, starts_line=False, line_number=10),
- make_inst(opname='LOAD_FAST_CHECK', arg=0, argval='i', argrepr='i', offset=118, start_offset=118, starts_line=True, line_number=11, label=5),
- make_inst(opname='TO_BOOL', arg=None, argval=None, argrepr='', offset=120, start_offset=120, starts_line=False, line_number=11, cache_info=[('counter', 1, b'\x00\x00'), ('version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_JUMP_IF_FALSE', arg=40, argval=212, argrepr='to L8', offset=128, start_offset=128, starts_line=False, line_number=11, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=132, start_offset=132, starts_line=False, line_number=11),
- make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=134, start_offset=134, starts_line=True, line_number=12, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
- make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=144, start_offset=144, starts_line=False, line_number=12),
- make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=146, start_offset=146, starts_line=False, line_number=12, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=154, start_offset=154, starts_line=False, line_number=12),
- make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=156, start_offset=156, starts_line=True, line_number=13),
- make_inst(opname='LOAD_SMALL_INT', arg=1, argval=1, argrepr='', offset=158, start_offset=158, starts_line=False, line_number=13),
- make_inst(opname='BINARY_OP', arg=23, argval=23, argrepr='-=', offset=160, start_offset=160, starts_line=False, line_number=13, cache_info=[('counter', 1, b'\x00\x00'), ('descr', 4, b'\x00\x00\x00\x00\x00\x00\x00\x00')]),
- make_inst(opname='STORE_FAST', arg=0, argval='i', argrepr='i', offset=172, start_offset=172, starts_line=False, line_number=13),
- make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=174, start_offset=174, starts_line=True, line_number=14),
- make_inst(opname='LOAD_SMALL_INT', arg=6, argval=6, argrepr='', offset=176, start_offset=176, starts_line=False, line_number=14),
- make_inst(opname='COMPARE_OP', arg=148, argval='>', argrepr='bool(>)', offset=178, start_offset=178, starts_line=False, line_number=14, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='POP_JUMP_IF_FALSE', arg=3, argval=192, argrepr='to L6', offset=182, start_offset=182, starts_line=False, line_number=14, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=186, start_offset=186, starts_line=False, line_number=14),
- make_inst(opname='JUMP_BACKWARD', arg=37, argval=118, argrepr='to L5', offset=188, start_offset=188, starts_line=True, line_number=15, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=192, start_offset=192, starts_line=True, line_number=16, label=6),
- make_inst(opname='LOAD_SMALL_INT', arg=4, argval=4, argrepr='', offset=194, start_offset=194, starts_line=False, line_number=16),
- make_inst(opname='COMPARE_OP', arg=18, argval='<', argrepr='bool(<)', offset=196, start_offset=196, starts_line=False, line_number=16, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='POP_JUMP_IF_TRUE', arg=3, argval=210, argrepr='to L7', offset=200, start_offset=200, starts_line=False, line_number=16, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=204, start_offset=204, starts_line=False, line_number=16),
- make_inst(opname='JUMP_BACKWARD', arg=46, argval=118, argrepr='to L5', offset=206, start_offset=206, starts_line=False, line_number=16, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='JUMP_FORWARD', arg=11, argval=234, argrepr='to L9', offset=210, start_offset=210, starts_line=True, line_number=17, label=7),
- make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=212, start_offset=212, starts_line=True, line_number=19, label=8, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
- make_inst(opname='LOAD_CONST', arg=2, argval='Who let lolcatz into this test suite?', argrepr="'Who let lolcatz into this test suite?'", offset=222, start_offset=222, starts_line=False, line_number=19),
- make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=224, start_offset=224, starts_line=False, line_number=19, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=232, start_offset=232, starts_line=False, line_number=19),
- make_inst(opname='NOP', arg=None, argval=None, argrepr='', offset=234, start_offset=234, starts_line=True, line_number=20, label=9),
- make_inst(opname='LOAD_SMALL_INT', arg=1, argval=1, argrepr='', offset=236, start_offset=236, starts_line=True, line_number=21),
- make_inst(opname='LOAD_SMALL_INT', arg=0, argval=0, argrepr='', offset=238, start_offset=238, starts_line=False, line_number=21),
- make_inst(opname='BINARY_OP', arg=11, argval=11, argrepr='/', offset=240, start_offset=240, starts_line=False, line_number=21, cache_info=[('counter', 1, b'\x00\x00'), ('descr', 4, b'\x00\x00\x00\x00\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=252, start_offset=252, starts_line=False, line_number=21),
- make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=254, start_offset=254, starts_line=True, line_number=25),
- make_inst(opname='COPY', arg=1, argval=1, argrepr='', offset=256, start_offset=256, starts_line=False, line_number=25),
- make_inst(opname='LOAD_SPECIAL', arg=1, argval=1, argrepr='__exit__', offset=258, start_offset=258, starts_line=False, line_number=25),
- make_inst(opname='SWAP', arg=2, argval=2, argrepr='', offset=260, start_offset=260, starts_line=False, line_number=25),
- make_inst(opname='SWAP', arg=3, argval=3, argrepr='', offset=262, start_offset=262, starts_line=False, line_number=25),
- make_inst(opname='LOAD_SPECIAL', arg=0, argval=0, argrepr='__enter__', offset=264, start_offset=264, starts_line=False, line_number=25),
- make_inst(opname='CALL', arg=0, argval=0, argrepr='', offset=266, start_offset=266, starts_line=False, line_number=25, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='STORE_FAST', arg=1, argval='dodgy', argrepr='dodgy', offset=274, start_offset=274, starts_line=False, line_number=25),
- make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=276, start_offset=276, starts_line=True, line_number=26, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
- make_inst(opname='LOAD_CONST', arg=3, argval='Never reach this', argrepr="'Never reach this'", offset=286, start_offset=286, starts_line=False, line_number=26),
- make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=288, start_offset=288, starts_line=False, line_number=26, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=296, start_offset=296, starts_line=False, line_number=26),
- make_inst(opname='LOAD_CONST', arg=4, argval=None, argrepr='None', offset=298, start_offset=298, starts_line=True, line_number=25),
- make_inst(opname='LOAD_CONST', arg=4, argval=None, argrepr='None', offset=300, start_offset=300, starts_line=False, line_number=25),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=90, start_offset=90, starts_line=False, line_number=8),
+ make_inst(opname='JUMP_FORWARD', arg=13, argval=120, argrepr='to L5', offset=92, start_offset=92, starts_line=False, line_number=8),
+ make_inst(opname='END_FOR', arg=None, argval=None, argrepr='', offset=94, start_offset=94, starts_line=True, line_number=3, label=4),
+ make_inst(opname='POP_ITER', arg=None, argval=None, argrepr='', offset=96, start_offset=96, starts_line=False, line_number=3),
+ make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=98, start_offset=98, starts_line=True, line_number=10, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
+ make_inst(opname='LOAD_CONST', arg=1, argval='I can haz else clause?', argrepr="'I can haz else clause?'", offset=108, start_offset=108, starts_line=False, line_number=10),
+ make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=110, start_offset=110, starts_line=False, line_number=10, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=118, start_offset=118, starts_line=False, line_number=10),
+ make_inst(opname='LOAD_FAST_CHECK', arg=0, argval='i', argrepr='i', offset=120, start_offset=120, starts_line=True, line_number=11, label=5),
+ make_inst(opname='TO_BOOL', arg=None, argval=None, argrepr='', offset=122, start_offset=122, starts_line=False, line_number=11, cache_info=[('counter', 1, b'\x00\x00'), ('version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_JUMP_IF_FALSE', arg=40, argval=214, argrepr='to L8', offset=130, start_offset=130, starts_line=False, line_number=11, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=134, start_offset=134, starts_line=False, line_number=11),
+ make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=136, start_offset=136, starts_line=True, line_number=12, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
+ make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=146, start_offset=146, starts_line=False, line_number=12),
+ make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=148, start_offset=148, starts_line=False, line_number=12, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=156, start_offset=156, starts_line=False, line_number=12),
+ make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=158, start_offset=158, starts_line=True, line_number=13),
+ make_inst(opname='LOAD_SMALL_INT', arg=1, argval=1, argrepr='', offset=160, start_offset=160, starts_line=False, line_number=13),
+ make_inst(opname='BINARY_OP', arg=23, argval=23, argrepr='-=', offset=162, start_offset=162, starts_line=False, line_number=13, cache_info=[('counter', 1, b'\x00\x00'), ('descr', 4, b'\x00\x00\x00\x00\x00\x00\x00\x00')]),
+ make_inst(opname='STORE_FAST', arg=0, argval='i', argrepr='i', offset=174, start_offset=174, starts_line=False, line_number=13),
+ make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=176, start_offset=176, starts_line=True, line_number=14),
+ make_inst(opname='LOAD_SMALL_INT', arg=6, argval=6, argrepr='', offset=178, start_offset=178, starts_line=False, line_number=14),
+ make_inst(opname='COMPARE_OP', arg=148, argval='>', argrepr='bool(>)', offset=180, start_offset=180, starts_line=False, line_number=14, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='POP_JUMP_IF_FALSE', arg=3, argval=194, argrepr='to L6', offset=184, start_offset=184, starts_line=False, line_number=14, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=188, start_offset=188, starts_line=False, line_number=14),
+ make_inst(opname='JUMP_BACKWARD', arg=37, argval=120, argrepr='to L5', offset=190, start_offset=190, starts_line=True, line_number=15, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=194, start_offset=194, starts_line=True, line_number=16, label=6),
+ make_inst(opname='LOAD_SMALL_INT', arg=4, argval=4, argrepr='', offset=196, start_offset=196, starts_line=False, line_number=16),
+ make_inst(opname='COMPARE_OP', arg=18, argval='<', argrepr='bool(<)', offset=198, start_offset=198, starts_line=False, line_number=16, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='POP_JUMP_IF_TRUE', arg=3, argval=212, argrepr='to L7', offset=202, start_offset=202, starts_line=False, line_number=16, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=206, start_offset=206, starts_line=False, line_number=16),
+ make_inst(opname='JUMP_BACKWARD', arg=46, argval=120, argrepr='to L5', offset=208, start_offset=208, starts_line=False, line_number=16, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='JUMP_FORWARD', arg=11, argval=236, argrepr='to L9', offset=212, start_offset=212, starts_line=True, line_number=17, label=7),
+ make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=214, start_offset=214, starts_line=True, line_number=19, label=8, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
+ make_inst(opname='LOAD_CONST', arg=2, argval='Who let lolcatz into this test suite?', argrepr="'Who let lolcatz into this test suite?'", offset=224, start_offset=224, starts_line=False, line_number=19),
+ make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=226, start_offset=226, starts_line=False, line_number=19, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=234, start_offset=234, starts_line=False, line_number=19),
+ make_inst(opname='NOP', arg=None, argval=None, argrepr='', offset=236, start_offset=236, starts_line=True, line_number=20, label=9),
+ make_inst(opname='LOAD_SMALL_INT', arg=1, argval=1, argrepr='', offset=238, start_offset=238, starts_line=True, line_number=21),
+ make_inst(opname='LOAD_SMALL_INT', arg=0, argval=0, argrepr='', offset=240, start_offset=240, starts_line=False, line_number=21),
+ make_inst(opname='BINARY_OP', arg=11, argval=11, argrepr='/', offset=242, start_offset=242, starts_line=False, line_number=21, cache_info=[('counter', 1, b'\x00\x00'), ('descr', 4, b'\x00\x00\x00\x00\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=254, start_offset=254, starts_line=False, line_number=21),
+ make_inst(opname='LOAD_FAST_BORROW', arg=0, argval='i', argrepr='i', offset=256, start_offset=256, starts_line=True, line_number=25),
+ make_inst(opname='COPY', arg=1, argval=1, argrepr='', offset=258, start_offset=258, starts_line=False, line_number=25),
+ make_inst(opname='LOAD_SPECIAL', arg=1, argval=1, argrepr='__exit__', offset=260, start_offset=260, starts_line=False, line_number=25),
+ make_inst(opname='SWAP', arg=2, argval=2, argrepr='', offset=262, start_offset=262, starts_line=False, line_number=25),
+ make_inst(opname='SWAP', arg=3, argval=3, argrepr='', offset=264, start_offset=264, starts_line=False, line_number=25),
+ make_inst(opname='LOAD_SPECIAL', arg=0, argval=0, argrepr='__enter__', offset=266, start_offset=266, starts_line=False, line_number=25),
+ make_inst(opname='CALL', arg=0, argval=0, argrepr='', offset=268, start_offset=268, starts_line=False, line_number=25, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='STORE_FAST', arg=1, argval='dodgy', argrepr='dodgy', offset=276, start_offset=276, starts_line=False, line_number=25),
+ make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=278, start_offset=278, starts_line=True, line_number=26, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
+ make_inst(opname='LOAD_CONST', arg=3, argval='Never reach this', argrepr="'Never reach this'", offset=288, start_offset=288, starts_line=False, line_number=26),
+ make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=290, start_offset=290, starts_line=False, line_number=26, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=298, start_offset=298, starts_line=False, line_number=26),
+ make_inst(opname='LOAD_CONST', arg=4, argval=None, argrepr='None', offset=300, start_offset=300, starts_line=True, line_number=25),
make_inst(opname='LOAD_CONST', arg=4, argval=None, argrepr='None', offset=302, start_offset=302, starts_line=False, line_number=25),
- make_inst(opname='CALL', arg=3, argval=3, argrepr='', offset=304, start_offset=304, starts_line=False, line_number=25, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=312, start_offset=312, starts_line=False, line_number=25),
- make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=314, start_offset=314, starts_line=True, line_number=28, label=10, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
- make_inst(opname='LOAD_CONST', arg=6, argval="OK, now we're done", argrepr='"OK, now we\'re done"', offset=324, start_offset=324, starts_line=False, line_number=28),
- make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=326, start_offset=326, starts_line=False, line_number=28, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=334, start_offset=334, starts_line=False, line_number=28),
- make_inst(opname='LOAD_CONST', arg=4, argval=None, argrepr='None', offset=336, start_offset=336, starts_line=False, line_number=28),
- make_inst(opname='RETURN_VALUE', arg=None, argval=None, argrepr='', offset=338, start_offset=338, starts_line=False, line_number=28),
- make_inst(opname='PUSH_EXC_INFO', arg=None, argval=None, argrepr='', offset=340, start_offset=340, starts_line=True, line_number=25),
- make_inst(opname='WITH_EXCEPT_START', arg=None, argval=None, argrepr='', offset=342, start_offset=342, starts_line=False, line_number=25),
- make_inst(opname='TO_BOOL', arg=None, argval=None, argrepr='', offset=344, start_offset=344, starts_line=False, line_number=25, cache_info=[('counter', 1, b'\x00\x00'), ('version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_JUMP_IF_TRUE', arg=2, argval=360, argrepr='to L11', offset=352, start_offset=352, starts_line=False, line_number=25, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=356, start_offset=356, starts_line=False, line_number=25),
- make_inst(opname='RERAISE', arg=2, argval=2, argrepr='', offset=358, start_offset=358, starts_line=False, line_number=25),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=360, start_offset=360, starts_line=False, line_number=25, label=11),
- make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=362, start_offset=362, starts_line=False, line_number=25),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=364, start_offset=364, starts_line=False, line_number=25),
+ make_inst(opname='LOAD_CONST', arg=4, argval=None, argrepr='None', offset=304, start_offset=304, starts_line=False, line_number=25),
+ make_inst(opname='CALL', arg=3, argval=3, argrepr='', offset=306, start_offset=306, starts_line=False, line_number=25, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=314, start_offset=314, starts_line=False, line_number=25),
+ make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=316, start_offset=316, starts_line=True, line_number=28, label=10, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
+ make_inst(opname='LOAD_CONST', arg=6, argval="OK, now we're done", argrepr='"OK, now we\'re done"', offset=326, start_offset=326, starts_line=False, line_number=28),
+ make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=328, start_offset=328, starts_line=False, line_number=28, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=336, start_offset=336, starts_line=False, line_number=28),
+ make_inst(opname='LOAD_CONST', arg=4, argval=None, argrepr='None', offset=338, start_offset=338, starts_line=False, line_number=28),
+ make_inst(opname='RETURN_VALUE', arg=None, argval=None, argrepr='', offset=340, start_offset=340, starts_line=False, line_number=28),
+ make_inst(opname='PUSH_EXC_INFO', arg=None, argval=None, argrepr='', offset=342, start_offset=342, starts_line=True, line_number=25),
+ make_inst(opname='WITH_EXCEPT_START', arg=None, argval=None, argrepr='', offset=344, start_offset=344, starts_line=False, line_number=25),
+ make_inst(opname='TO_BOOL', arg=None, argval=None, argrepr='', offset=346, start_offset=346, starts_line=False, line_number=25, cache_info=[('counter', 1, b'\x00\x00'), ('version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_JUMP_IF_TRUE', arg=2, argval=362, argrepr='to L11', offset=354, start_offset=354, starts_line=False, line_number=25, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=358, start_offset=358, starts_line=False, line_number=25),
+ make_inst(opname='RERAISE', arg=2, argval=2, argrepr='', offset=360, start_offset=360, starts_line=False, line_number=25),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=362, start_offset=362, starts_line=False, line_number=25, label=11),
+ make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=364, start_offset=364, starts_line=False, line_number=25),
make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=366, start_offset=366, starts_line=False, line_number=25),
make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=368, start_offset=368, starts_line=False, line_number=25),
- make_inst(opname='JUMP_BACKWARD_NO_INTERRUPT', arg=29, argval=314, argrepr='to L10', offset=370, start_offset=370, starts_line=False, line_number=25),
- make_inst(opname='COPY', arg=3, argval=3, argrepr='', offset=372, start_offset=372, starts_line=True, line_number=None),
- make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=374, start_offset=374, starts_line=False, line_number=None),
- make_inst(opname='RERAISE', arg=1, argval=1, argrepr='', offset=376, start_offset=376, starts_line=False, line_number=None),
- make_inst(opname='PUSH_EXC_INFO', arg=None, argval=None, argrepr='', offset=378, start_offset=378, starts_line=False, line_number=None),
- make_inst(opname='LOAD_GLOBAL', arg=4, argval='ZeroDivisionError', argrepr='ZeroDivisionError', offset=380, start_offset=380, starts_line=True, line_number=22, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
- make_inst(opname='CHECK_EXC_MATCH', arg=None, argval=None, argrepr='', offset=390, start_offset=390, starts_line=False, line_number=22),
- make_inst(opname='POP_JUMP_IF_FALSE', arg=15, argval=426, argrepr='to L12', offset=392, start_offset=392, starts_line=False, line_number=22, cache_info=[('counter', 1, b'\x00\x00')]),
- make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=396, start_offset=396, starts_line=False, line_number=22),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=398, start_offset=398, starts_line=False, line_number=22),
- make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=400, start_offset=400, starts_line=True, line_number=23, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
- make_inst(opname='LOAD_CONST', arg=5, argval='Here we go, here we go, here we go...', argrepr="'Here we go, here we go, here we go...'", offset=410, start_offset=410, starts_line=False, line_number=23),
- make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=412, start_offset=412, starts_line=False, line_number=23, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=420, start_offset=420, starts_line=False, line_number=23),
- make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=422, start_offset=422, starts_line=False, line_number=23),
- make_inst(opname='JUMP_BACKWARD_NO_INTERRUPT', arg=56, argval=314, argrepr='to L10', offset=424, start_offset=424, starts_line=False, line_number=23),
- make_inst(opname='RERAISE', arg=0, argval=0, argrepr='', offset=426, start_offset=426, starts_line=True, line_number=22, label=12),
- make_inst(opname='COPY', arg=3, argval=3, argrepr='', offset=428, start_offset=428, starts_line=True, line_number=None),
- make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=430, start_offset=430, starts_line=False, line_number=None),
- make_inst(opname='RERAISE', arg=1, argval=1, argrepr='', offset=432, start_offset=432, starts_line=False, line_number=None),
- make_inst(opname='PUSH_EXC_INFO', arg=None, argval=None, argrepr='', offset=434, start_offset=434, starts_line=False, line_number=None),
- make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=436, start_offset=436, starts_line=True, line_number=28, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
- make_inst(opname='LOAD_CONST', arg=6, argval="OK, now we're done", argrepr='"OK, now we\'re done"', offset=446, start_offset=446, starts_line=False, line_number=28),
- make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=448, start_offset=448, starts_line=False, line_number=28, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
- make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=456, start_offset=456, starts_line=False, line_number=28),
- make_inst(opname='RERAISE', arg=0, argval=0, argrepr='', offset=458, start_offset=458, starts_line=False, line_number=28),
- make_inst(opname='COPY', arg=3, argval=3, argrepr='', offset=460, start_offset=460, starts_line=True, line_number=None),
- make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=462, start_offset=462, starts_line=False, line_number=None),
- make_inst(opname='RERAISE', arg=1, argval=1, argrepr='', offset=464, start_offset=464, starts_line=False, line_number=None),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=370, start_offset=370, starts_line=False, line_number=25),
+ make_inst(opname='JUMP_BACKWARD_NO_INTERRUPT', arg=29, argval=316, argrepr='to L10', offset=372, start_offset=372, starts_line=False, line_number=25),
+ make_inst(opname='COPY', arg=3, argval=3, argrepr='', offset=374, start_offset=374, starts_line=True, line_number=None),
+ make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=376, start_offset=376, starts_line=False, line_number=None),
+ make_inst(opname='RERAISE', arg=1, argval=1, argrepr='', offset=378, start_offset=378, starts_line=False, line_number=None),
+ make_inst(opname='PUSH_EXC_INFO', arg=None, argval=None, argrepr='', offset=380, start_offset=380, starts_line=False, line_number=None),
+ make_inst(opname='LOAD_GLOBAL', arg=4, argval='ZeroDivisionError', argrepr='ZeroDivisionError', offset=382, start_offset=382, starts_line=True, line_number=22, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
+ make_inst(opname='CHECK_EXC_MATCH', arg=None, argval=None, argrepr='', offset=392, start_offset=392, starts_line=False, line_number=22),
+ make_inst(opname='POP_JUMP_IF_FALSE', arg=15, argval=428, argrepr='to L12', offset=394, start_offset=394, starts_line=False, line_number=22, cache_info=[('counter', 1, b'\x00\x00')]),
+ make_inst(opname='NOT_TAKEN', arg=None, argval=None, argrepr='', offset=398, start_offset=398, starts_line=False, line_number=22),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=400, start_offset=400, starts_line=False, line_number=22),
+ make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=402, start_offset=402, starts_line=True, line_number=23, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
+ make_inst(opname='LOAD_CONST', arg=5, argval='Here we go, here we go, here we go...', argrepr="'Here we go, here we go, here we go...'", offset=412, start_offset=412, starts_line=False, line_number=23),
+ make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=414, start_offset=414, starts_line=False, line_number=23, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=422, start_offset=422, starts_line=False, line_number=23),
+ make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=424, start_offset=424, starts_line=False, line_number=23),
+ make_inst(opname='JUMP_BACKWARD_NO_INTERRUPT', arg=56, argval=316, argrepr='to L10', offset=426, start_offset=426, starts_line=False, line_number=23),
+ make_inst(opname='RERAISE', arg=0, argval=0, argrepr='', offset=428, start_offset=428, starts_line=True, line_number=22, label=12),
+ make_inst(opname='COPY', arg=3, argval=3, argrepr='', offset=430, start_offset=430, starts_line=True, line_number=None),
+ make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=432, start_offset=432, starts_line=False, line_number=None),
+ make_inst(opname='RERAISE', arg=1, argval=1, argrepr='', offset=434, start_offset=434, starts_line=False, line_number=None),
+ make_inst(opname='PUSH_EXC_INFO', arg=None, argval=None, argrepr='', offset=436, start_offset=436, starts_line=False, line_number=None),
+ make_inst(opname='LOAD_GLOBAL', arg=3, argval='print', argrepr='print + NULL', offset=438, start_offset=438, starts_line=True, line_number=28, cache_info=[('counter', 1, b'\x00\x00'), ('index', 1, b'\x00\x00'), ('module_keys_version', 1, b'\x00\x00'), ('builtin_keys_version', 1, b'\x00\x00')]),
+ make_inst(opname='LOAD_CONST', arg=6, argval="OK, now we're done", argrepr='"OK, now we\'re done"', offset=448, start_offset=448, starts_line=False, line_number=28),
+ make_inst(opname='CALL', arg=1, argval=1, argrepr='', offset=450, start_offset=450, starts_line=False, line_number=28, cache_info=[('counter', 1, b'\x00\x00'), ('func_version', 2, b'\x00\x00\x00\x00')]),
+ make_inst(opname='POP_TOP', arg=None, argval=None, argrepr='', offset=458, start_offset=458, starts_line=False, line_number=28),
+ make_inst(opname='RERAISE', arg=0, argval=0, argrepr='', offset=460, start_offset=460, starts_line=False, line_number=28),
+ make_inst(opname='COPY', arg=3, argval=3, argrepr='', offset=462, start_offset=462, starts_line=True, line_number=None),
+ make_inst(opname='POP_EXCEPT', arg=None, argval=None, argrepr='', offset=464, start_offset=464, starts_line=False, line_number=None),
+ make_inst(opname='RERAISE', arg=1, argval=1, argrepr='', offset=466, start_offset=466, starts_line=False, line_number=None),
]
# One last piece of inspect fodder to check the default line number handling
diff --git a/Lib/test/test_doctest/sample_doctest_errors.py b/Lib/test/test_doctest/sample_doctest_errors.py
new file mode 100644
index 00000000000..4a6f07af2d4
--- /dev/null
+++ b/Lib/test/test_doctest/sample_doctest_errors.py
@@ -0,0 +1,46 @@
+"""This is a sample module used for testing doctest.
+
+This module includes various scenarios involving errors.
+
+>>> 2 + 2
+5
+>>> 1/0
+1
+"""
+
+def g():
+ [][0] # line 12
+
+def errors():
+ """
+ >>> 2 + 2
+ 5
+ >>> 1/0
+ 1
+ >>> def f():
+ ... 2 + '2'
+ ...
+ >>> f()
+ 1
+ >>> g()
+ 1
+ """
+
+def syntax_error():
+ """
+ >>> 2+*3
+ 5
+ """
+
+__test__ = {
+ 'bad': """
+ >>> 2 + 2
+ 5
+ >>> 1/0
+ 1
+ """,
+}
+
+def test_suite():
+ import doctest
+ return doctest.DocTestSuite()
diff --git a/Lib/test/test_doctest/test_doctest.py b/Lib/test/test_doctest/test_doctest.py
index a4a49298bab..72763d4a013 100644
--- a/Lib/test/test_doctest/test_doctest.py
+++ b/Lib/test/test_doctest/test_doctest.py
@@ -2267,14 +2267,24 @@ def test_DocTestSuite():
>>> import unittest
>>> import test.test_doctest.sample_doctest
>>> suite = doctest.DocTestSuite(test.test_doctest.sample_doctest)
- >>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=4>
+ >>> result = suite.run(unittest.TestResult())
+ >>> result
+ <unittest.result.TestResult run=9 errors=2 failures=2>
+ >>> for tst, _ in result.failures:
+ ... print(tst)
+ bad (test.test_doctest.sample_doctest.__test__) [0]
+ foo (test.test_doctest.sample_doctest) [0]
+ >>> for tst, _ in result.errors:
+ ... print(tst)
+ test_silly_setup (test.test_doctest.sample_doctest) [1]
+ y_is_one (test.test_doctest.sample_doctest) [0]
We can also supply the module by name:
>>> suite = doctest.DocTestSuite('test.test_doctest.sample_doctest')
- >>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=4>
+ >>> result = suite.run(unittest.TestResult())
+ >>> result
+ <unittest.result.TestResult run=9 errors=2 failures=2>
The module need not contain any doctest examples:
@@ -2296,13 +2306,26 @@ def test_DocTestSuite():
>>> result
<unittest.result.TestResult run=6 errors=0 failures=2>
>>> len(result.skipped)
- 2
+ 7
+ >>> for tst, _ in result.skipped:
+ ... print(tst)
+ double_skip (test.test_doctest.sample_doctest_skip) [0]
+ double_skip (test.test_doctest.sample_doctest_skip) [1]
+ double_skip (test.test_doctest.sample_doctest_skip)
+ partial_skip_fail (test.test_doctest.sample_doctest_skip) [0]
+ partial_skip_pass (test.test_doctest.sample_doctest_skip) [0]
+ single_skip (test.test_doctest.sample_doctest_skip) [0]
+ single_skip (test.test_doctest.sample_doctest_skip)
+ >>> for tst, _ in result.failures:
+ ... print(tst)
+ no_skip_fail (test.test_doctest.sample_doctest_skip) [0]
+ partial_skip_fail (test.test_doctest.sample_doctest_skip) [1]
We can use the current module:
>>> suite = test.test_doctest.sample_doctest.test_suite()
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=4>
+ <unittest.result.TestResult run=9 errors=2 failures=2>
We can also provide a DocTestFinder:
@@ -2310,7 +2333,7 @@ def test_DocTestSuite():
>>> suite = doctest.DocTestSuite('test.test_doctest.sample_doctest',
... test_finder=finder)
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=4>
+ <unittest.result.TestResult run=9 errors=2 failures=2>
The DocTestFinder need not return any tests:
@@ -2326,7 +2349,7 @@ def test_DocTestSuite():
>>> suite = doctest.DocTestSuite('test.test_doctest.sample_doctest', globs={})
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=5>
+ <unittest.result.TestResult run=9 errors=3 failures=2>
Alternatively, we can provide extra globals. Here we'll make an
error go away by providing an extra global variable:
@@ -2334,7 +2357,7 @@ def test_DocTestSuite():
>>> suite = doctest.DocTestSuite('test.test_doctest.sample_doctest',
... extraglobs={'y': 1})
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=3>
+ <unittest.result.TestResult run=9 errors=1 failures=2>
You can pass option flags. Here we'll cause an extra error
by disabling the blank-line feature:
@@ -2342,7 +2365,7 @@ def test_DocTestSuite():
>>> suite = doctest.DocTestSuite('test.test_doctest.sample_doctest',
... optionflags=doctest.DONT_ACCEPT_BLANKLINE)
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=5>
+ <unittest.result.TestResult run=9 errors=2 failures=3>
You can supply setUp and tearDown functions:
@@ -2359,7 +2382,7 @@ def test_DocTestSuite():
>>> suite = doctest.DocTestSuite('test.test_doctest.sample_doctest',
... setUp=setUp, tearDown=tearDown)
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=3>
+ <unittest.result.TestResult run=9 errors=1 failures=2>
But the tearDown restores sanity:
@@ -2377,13 +2400,115 @@ def test_DocTestSuite():
>>> suite = doctest.DocTestSuite('test.test_doctest.sample_doctest', setUp=setUp)
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=3>
+ <unittest.result.TestResult run=9 errors=1 failures=2>
Here, we didn't need to use a tearDown function because we
modified the test globals, which are a copy of the
sample_doctest module dictionary. The test globals are
automatically cleared for us after a test.
- """
+ """
+
+def test_DocTestSuite_errors():
+ """Tests for error reporting in DocTestSuite.
+
+ >>> import unittest
+ >>> import test.test_doctest.sample_doctest_errors as mod
+ >>> suite = doctest.DocTestSuite(mod)
+ >>> result = suite.run(unittest.TestResult())
+ >>> result
+ <unittest.result.TestResult run=4 errors=6 failures=3>
+ >>> print(result.failures[0][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line 5, in test.test_doctest.sample_doctest_errors
+ >...>> 2 + 2
+ AssertionError: Failed example:
+ 2 + 2
+ Expected:
+ 5
+ Got:
+ 4
+ <BLANKLINE>
+ >>> print(result.failures[1][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line None, in test.test_doctest.sample_doctest_errors.__test__.bad
+ AssertionError: Failed example:
+ 2 + 2
+ Expected:
+ 5
+ Got:
+ 4
+ <BLANKLINE>
+ >>> print(result.failures[2][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line 16, in test.test_doctest.sample_doctest_errors.errors
+ >...>> 2 + 2
+ AssertionError: Failed example:
+ 2 + 2
+ Expected:
+ 5
+ Got:
+ 4
+ <BLANKLINE>
+ >>> print(result.errors[0][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line 7, in test.test_doctest.sample_doctest_errors
+ >...>> 1/0
+ File "<doctest test.test_doctest.sample_doctest_errors[1]>", line 1, in <module>
+ 1/0
+ ~^~
+ ZeroDivisionError: division by zero
+ <BLANKLINE>
+ >>> print(result.errors[1][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line None, in test.test_doctest.sample_doctest_errors.__test__.bad
+ File "<doctest test.test_doctest.sample_doctest_errors.__test__.bad[1]>", line 1, in <module>
+ 1/0
+ ~^~
+ ZeroDivisionError: division by zero
+ <BLANKLINE>
+ >>> print(result.errors[2][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line 18, in test.test_doctest.sample_doctest_errors.errors
+ >...>> 1/0
+ File "<doctest test.test_doctest.sample_doctest_errors.errors[1]>", line 1, in <module>
+ 1/0
+ ~^~
+ ZeroDivisionError: division by zero
+ <BLANKLINE>
+ >>> print(result.errors[3][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line 23, in test.test_doctest.sample_doctest_errors.errors
+ >...>> f()
+ File "<doctest test.test_doctest.sample_doctest_errors.errors[3]>", line 1, in <module>
+ f()
+ ~^^
+ File "<doctest test.test_doctest.sample_doctest_errors.errors[2]>", line 2, in f
+ 2 + '2'
+ ~~^~~~~
+ TypeError: ...
+ <BLANKLINE>
+ >>> print(result.errors[4][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line 25, in test.test_doctest.sample_doctest_errors.errors
+ >...>> g()
+ File "<doctest test.test_doctest.sample_doctest_errors.errors[4]>", line 1, in <module>
+ g()
+ ~^^
+ File "...sample_doctest_errors.py", line 12, in g
+ [][0] # line 12
+ ~~^^^
+ IndexError: list index out of range
+ <BLANKLINE>
+ >>> print(result.errors[5][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...sample_doctest_errors.py", line 31, in test.test_doctest.sample_doctest_errors.syntax_error
+ >...>> 2+*3
+ File "<doctest test.test_doctest.sample_doctest_errors.syntax_error[0]>", line 1
+ 2+*3
+ ^
+ SyntaxError: invalid syntax
+ <BLANKLINE>
+ """
def test_DocFileSuite():
"""We can test tests found in text files using a DocFileSuite.
@@ -2396,7 +2521,7 @@ def test_DocFileSuite():
... 'test_doctest2.txt',
... 'test_doctest4.txt')
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=3 errors=0 failures=2>
+ <unittest.result.TestResult run=3 errors=2 failures=0>
The test files are looked for in the directory containing the
calling module. A package keyword argument can be provided to
@@ -2408,14 +2533,14 @@ def test_DocFileSuite():
... 'test_doctest4.txt',
... package='test.test_doctest')
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=3 errors=0 failures=2>
+ <unittest.result.TestResult run=3 errors=2 failures=0>
'/' should be used as a path separator. It will be converted
to a native separator at run time:
>>> suite = doctest.DocFileSuite('../test_doctest/test_doctest.txt')
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=1 errors=0 failures=1>
+ <unittest.result.TestResult run=1 errors=1 failures=0>
If DocFileSuite is used from an interactive session, then files
are resolved relative to the directory of sys.argv[0]:
@@ -2441,7 +2566,7 @@ def test_DocFileSuite():
>>> suite = doctest.DocFileSuite(test_file, module_relative=False)
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=1 errors=0 failures=1>
+ <unittest.result.TestResult run=1 errors=1 failures=0>
It is an error to specify `package` when `module_relative=False`:
@@ -2455,12 +2580,19 @@ def test_DocFileSuite():
>>> suite = doctest.DocFileSuite('test_doctest.txt',
... 'test_doctest4.txt',
- ... 'test_doctest_skip.txt')
+ ... 'test_doctest_skip.txt',
+ ... 'test_doctest_skip2.txt')
>>> result = suite.run(unittest.TestResult())
>>> result
- <unittest.result.TestResult run=3 errors=0 failures=1>
- >>> len(result.skipped)
- 1
+ <unittest.result.TestResult run=4 errors=1 failures=0>
+ >>> len(result.skipped)
+ 4
+ >>> for tst, _ in result.skipped: # doctest: +ELLIPSIS
+ ... print('=', tst)
+ = ...test_doctest_skip.txt [0]
+ = ...test_doctest_skip.txt [1]
+ = ...test_doctest_skip.txt
+ = ...test_doctest_skip2.txt [0]
You can specify initial global variables:
@@ -2469,7 +2601,7 @@ def test_DocFileSuite():
... 'test_doctest4.txt',
... globs={'favorite_color': 'blue'})
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=3 errors=0 failures=1>
+ <unittest.result.TestResult run=3 errors=1 failures=0>
In this case, we supplied a missing favorite color. You can
provide doctest options:
@@ -2480,7 +2612,7 @@ def test_DocFileSuite():
... optionflags=doctest.DONT_ACCEPT_BLANKLINE,
... globs={'favorite_color': 'blue'})
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=3 errors=0 failures=2>
+ <unittest.result.TestResult run=3 errors=1 failures=1>
And, you can provide setUp and tearDown functions:
@@ -2499,7 +2631,7 @@ def test_DocFileSuite():
... 'test_doctest4.txt',
... setUp=setUp, tearDown=tearDown)
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=3 errors=0 failures=1>
+ <unittest.result.TestResult run=3 errors=1 failures=0>
But the tearDown restores sanity:
@@ -2541,9 +2673,60 @@ def test_DocFileSuite():
... 'test_doctest4.txt',
... encoding='utf-8')
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=3 errors=0 failures=2>
+ <unittest.result.TestResult run=3 errors=2 failures=0>
+ """
- """
+def test_DocFileSuite_errors():
+ """Tests for error reporting in DocFileSuite.
+
+ >>> import unittest
+ >>> suite = doctest.DocFileSuite('test_doctest_errors.txt')
+ >>> result = suite.run(unittest.TestResult())
+ >>> result
+ <unittest.result.TestResult run=1 errors=3 failures=1>
+ >>> print(result.failures[0][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...test_doctest_errors.txt", line 4, in test_doctest_errors.txt
+ >...>> 2 + 2
+ AssertionError: Failed example:
+ 2 + 2
+ Expected:
+ 5
+ Got:
+ 4
+ <BLANKLINE>
+ >>> print(result.errors[0][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...test_doctest_errors.txt", line 6, in test_doctest_errors.txt
+ >...>> 1/0
+ File "<doctest test_doctest_errors.txt[1]>", line 1, in <module>
+ 1/0
+ ~^~
+ ZeroDivisionError: division by zero
+ <BLANKLINE>
+ >>> print(result.errors[1][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...test_doctest_errors.txt", line 11, in test_doctest_errors.txt
+ >...>> f()
+ File "<doctest test_doctest_errors.txt[3]>", line 1, in <module>
+ f()
+ ~^^
+ File "<doctest test_doctest_errors.txt[2]>", line 2, in f
+ 2 + '2'
+ ~~^~~~~
+ TypeError: ...
+ <BLANKLINE>
+ >>> print(result.errors[2][1]) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ File "...test_doctest_errors.txt", line 13, in test_doctest_errors.txt
+ >...>> 2+*3
+ File "<doctest test_doctest_errors.txt[4]>", line 1
+ 2+*3
+ ^
+ SyntaxError: invalid syntax
+ <BLANKLINE>
+
+ """
def test_trailing_space_in_test():
"""
@@ -2612,14 +2795,26 @@ def test_unittest_reportflags():
... optionflags=doctest.DONT_ACCEPT_BLANKLINE)
>>> import unittest
>>> result = suite.run(unittest.TestResult())
+ >>> result
+ <unittest.result.TestResult run=1 errors=1 failures=1>
>>> print(result.failures[0][1]) # doctest: +ELLIPSIS
- Traceback ...
- Failed example:
- favorite_color
- ...
- Failed example:
+ Traceback (most recent call last):
+ File ...
+ >...>> if 1:
+ AssertionError: Failed example:
if 1:
- ...
+ print('a')
+ print()
+ print('b')
+ Expected:
+ a
+ <BLANKLINE>
+ b
+ Got:
+ a
+ <BLANKLINE>
+ b
+ <BLANKLINE>
Note that we see both failures displayed.
@@ -2628,16 +2823,8 @@ def test_unittest_reportflags():
Now, when we run the test:
- >>> result = suite.run(unittest.TestResult())
- >>> print(result.failures[0][1]) # doctest: +ELLIPSIS
- Traceback ...
- Failed example:
- favorite_color
- Exception raised:
- ...
- NameError: name 'favorite_color' is not defined
- <BLANKLINE>
- <BLANKLINE>
+ >>> suite.run(unittest.TestResult())
+ <unittest.result.TestResult run=1 errors=1 failures=0>
We get only the first failure.
@@ -2647,19 +2834,20 @@ def test_unittest_reportflags():
>>> suite = doctest.DocFileSuite('test_doctest.txt',
... optionflags=doctest.DONT_ACCEPT_BLANKLINE | doctest.REPORT_NDIFF)
- Then the default eporting options are ignored:
+ Then the default reporting options are ignored:
>>> result = suite.run(unittest.TestResult())
+ >>> result
+ <unittest.result.TestResult run=1 errors=1 failures=1>
*NOTE*: These doctest are intentionally not placed in raw string to depict
the trailing whitespace using `\x20` in the diff below.
>>> print(result.failures[0][1]) # doctest: +ELLIPSIS
Traceback ...
- Failed example:
- favorite_color
- ...
- Failed example:
+ File ...
+ >...>> if 1:
+ AssertionError: Failed example:
if 1:
print('a')
print()
@@ -2670,7 +2858,6 @@ def test_unittest_reportflags():
+\x20
b
<BLANKLINE>
- <BLANKLINE>
Test runners can restore the formatting flags after they run:
@@ -2860,6 +3047,57 @@ Test the verbose output:
>>> _colorize.COLORIZE = save_colorize
"""
+def test_testfile_errors(): r"""
+Tests for error reporting in the testfile() function.
+
+ >>> doctest.testfile('test_doctest_errors.txt', verbose=False) # doctest: +ELLIPSIS
+ **********************************************************************
+ File "...test_doctest_errors.txt", line 4, in test_doctest_errors.txt
+ Failed example:
+ 2 + 2
+ Expected:
+ 5
+ Got:
+ 4
+ **********************************************************************
+ File "...test_doctest_errors.txt", line 6, in test_doctest_errors.txt
+ Failed example:
+ 1/0
+ Exception raised:
+ Traceback (most recent call last):
+ File "<doctest test_doctest_errors.txt[1]>", line 1, in <module>
+ 1/0
+ ~^~
+ ZeroDivisionError: division by zero
+ **********************************************************************
+ File "...test_doctest_errors.txt", line 11, in test_doctest_errors.txt
+ Failed example:
+ f()
+ Exception raised:
+ Traceback (most recent call last):
+ File "<doctest test_doctest_errors.txt[3]>", line 1, in <module>
+ f()
+ ~^^
+ File "<doctest test_doctest_errors.txt[2]>", line 2, in f
+ 2 + '2'
+ ~~^~~~~
+ TypeError: ...
+ **********************************************************************
+ File "...test_doctest_errors.txt", line 13, in test_doctest_errors.txt
+ Failed example:
+ 2+*3
+ Exception raised:
+ File "<doctest test_doctest_errors.txt[4]>", line 1
+ 2+*3
+ ^
+ SyntaxError: invalid syntax
+ **********************************************************************
+ 1 item had failures:
+ 4 of 5 in test_doctest_errors.txt
+ ***Test Failed*** 4 failures.
+ TestResults(failed=4, attempted=5)
+"""
+
class TestImporter(importlib.abc.MetaPathFinder):
def find_spec(self, fullname, path, target=None):
@@ -2990,6 +3228,110 @@ out of the binary module.
TestResults(failed=0, attempted=0)
"""
+def test_testmod_errors(): r"""
+Tests for error reporting in the testmod() function.
+
+ >>> import test.test_doctest.sample_doctest_errors as mod
+ >>> doctest.testmod(mod, verbose=False) # doctest: +ELLIPSIS
+ **********************************************************************
+ File "...sample_doctest_errors.py", line 5, in test.test_doctest.sample_doctest_errors
+ Failed example:
+ 2 + 2
+ Expected:
+ 5
+ Got:
+ 4
+ **********************************************************************
+ File "...sample_doctest_errors.py", line 7, in test.test_doctest.sample_doctest_errors
+ Failed example:
+ 1/0
+ Exception raised:
+ Traceback (most recent call last):
+ File "<doctest test.test_doctest.sample_doctest_errors[1]>", line 1, in <module>
+ 1/0
+ ~^~
+ ZeroDivisionError: division by zero
+ **********************************************************************
+ File "...sample_doctest_errors.py", line ?, in test.test_doctest.sample_doctest_errors.__test__.bad
+ Failed example:
+ 2 + 2
+ Expected:
+ 5
+ Got:
+ 4
+ **********************************************************************
+ File "...sample_doctest_errors.py", line ?, in test.test_doctest.sample_doctest_errors.__test__.bad
+ Failed example:
+ 1/0
+ Exception raised:
+ Traceback (most recent call last):
+ File "<doctest test.test_doctest.sample_doctest_errors.__test__.bad[1]>", line 1, in <module>
+ 1/0
+ ~^~
+ ZeroDivisionError: division by zero
+ **********************************************************************
+ File "...sample_doctest_errors.py", line 16, in test.test_doctest.sample_doctest_errors.errors
+ Failed example:
+ 2 + 2
+ Expected:
+ 5
+ Got:
+ 4
+ **********************************************************************
+ File "...sample_doctest_errors.py", line 18, in test.test_doctest.sample_doctest_errors.errors
+ Failed example:
+ 1/0
+ Exception raised:
+ Traceback (most recent call last):
+ File "<doctest test.test_doctest.sample_doctest_errors.errors[1]>", line 1, in <module>
+ 1/0
+ ~^~
+ ZeroDivisionError: division by zero
+ **********************************************************************
+ File "...sample_doctest_errors.py", line 23, in test.test_doctest.sample_doctest_errors.errors
+ Failed example:
+ f()
+ Exception raised:
+ Traceback (most recent call last):
+ File "<doctest test.test_doctest.sample_doctest_errors.errors[3]>", line 1, in <module>
+ f()
+ ~^^
+ File "<doctest test.test_doctest.sample_doctest_errors.errors[2]>", line 2, in f
+ 2 + '2'
+ ~~^~~~~
+ TypeError: ...
+ **********************************************************************
+ File "...sample_doctest_errors.py", line 25, in test.test_doctest.sample_doctest_errors.errors
+ Failed example:
+ g()
+ Exception raised:
+ Traceback (most recent call last):
+ File "<doctest test.test_doctest.sample_doctest_errors.errors[4]>", line 1, in <module>
+ g()
+ ~^^
+ File "...sample_doctest_errors.py", line 12, in g
+ [][0] # line 12
+ ~~^^^
+ IndexError: list index out of range
+ **********************************************************************
+ File "...sample_doctest_errors.py", line 31, in test.test_doctest.sample_doctest_errors.syntax_error
+ Failed example:
+ 2+*3
+ Exception raised:
+ File "<doctest test.test_doctest.sample_doctest_errors.syntax_error[0]>", line 1
+ 2+*3
+ ^
+ SyntaxError: invalid syntax
+ **********************************************************************
+ 4 items had failures:
+ 2 of 2 in test.test_doctest.sample_doctest_errors
+ 2 of 2 in test.test_doctest.sample_doctest_errors.__test__.bad
+ 4 of 5 in test.test_doctest.sample_doctest_errors.errors
+ 1 of 1 in test.test_doctest.sample_doctest_errors.syntax_error
+ ***Test Failed*** 9 failures.
+ TestResults(failed=9, attempted=10)
+"""
+
try:
os.fsencode("foo-bär@baz.py")
supports_unicode = True
@@ -3021,11 +3363,6 @@ Check doctest with a non-ascii filename:
raise Exception('clé')
Exception raised:
Traceback (most recent call last):
- File ...
- exec(compile(example.source, filename, "single",
- ~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- compileflags, True), test.globs)
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<doctest foo-bär@baz[0]>", line 1, in <module>
raise Exception('clé')
Exception: clé
@@ -3318,9 +3655,9 @@ def test_run_doctestsuite_multiple_times():
>>> import test.test_doctest.sample_doctest
>>> suite = doctest.DocTestSuite(test.test_doctest.sample_doctest)
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=4>
+ <unittest.result.TestResult run=9 errors=2 failures=2>
>>> suite.run(unittest.TestResult())
- <unittest.result.TestResult run=9 errors=0 failures=4>
+ <unittest.result.TestResult run=9 errors=2 failures=2>
"""
diff --git a/Lib/test/test_doctest/test_doctest_errors.txt b/Lib/test/test_doctest/test_doctest_errors.txt
new file mode 100644
index 00000000000..93c3c106e60
--- /dev/null
+++ b/Lib/test/test_doctest/test_doctest_errors.txt
@@ -0,0 +1,14 @@
+This is a sample doctest in a text file, in which all examples fail
+or raise an exception.
+
+ >>> 2 + 2
+ 5
+ >>> 1/0
+ 1
+ >>> def f():
+ ... 2 + '2'
+ ...
+ >>> f()
+ 1
+ >>> 2+*3
+ 5
diff --git a/Lib/test/test_doctest/test_doctest_skip.txt b/Lib/test/test_doctest/test_doctest_skip.txt
index f340e2b8141..06c23d06e60 100644
--- a/Lib/test/test_doctest/test_doctest_skip.txt
+++ b/Lib/test/test_doctest/test_doctest_skip.txt
@@ -2,3 +2,5 @@ This is a sample doctest in a text file, in which all examples are skipped.
>>> 2 + 2 # doctest: +SKIP
5
+ >>> 2 + 2 # doctest: +SKIP
+ 4
diff --git a/Lib/test/test_doctest/test_doctest_skip2.txt b/Lib/test/test_doctest/test_doctest_skip2.txt
new file mode 100644
index 00000000000..85e4938c346
--- /dev/null
+++ b/Lib/test/test_doctest/test_doctest_skip2.txt
@@ -0,0 +1,6 @@
+This is a sample doctest in a text file, in which some examples are skipped.
+
+ >>> 2 + 2 # doctest: +SKIP
+ 5
+ >>> 2 + 2
+ 4
diff --git a/Lib/test/test_dynamicclassattribute.py b/Lib/test/test_dynamicclassattribute.py
index 9f694d9eb46..b19be33c72f 100644
--- a/Lib/test/test_dynamicclassattribute.py
+++ b/Lib/test/test_dynamicclassattribute.py
@@ -104,8 +104,8 @@ class PropertyTests(unittest.TestCase):
self.assertEqual(base.spam, 10)
self.assertEqual(base._spam, 10)
delattr(base, "spam")
- self.assertTrue(not hasattr(base, "spam"))
- self.assertTrue(not hasattr(base, "_spam"))
+ self.assertNotHasAttr(base, "spam")
+ self.assertNotHasAttr(base, "_spam")
base.spam = 20
self.assertEqual(base.spam, 20)
self.assertEqual(base._spam, 20)
diff --git a/Lib/test/test_email/test__header_value_parser.py b/Lib/test/test_email/test__header_value_parser.py
index ac12c3b2306..179e236ecdf 100644
--- a/Lib/test/test_email/test__header_value_parser.py
+++ b/Lib/test/test_email/test__header_value_parser.py
@@ -463,6 +463,19 @@ class TestParser(TestParserMixin, TestEmailBase):
[errors.NonPrintableDefect], ')')
self.assertEqual(ptext.defects[0].non_printables[0], '\x00')
+ def test_get_qp_ctext_close_paren_only(self):
+ self._test_get_x(parser.get_qp_ctext,
+ ')', '', ' ', [], ')')
+
+ def test_get_qp_ctext_open_paren_only(self):
+ self._test_get_x(parser.get_qp_ctext,
+ '(', '', ' ', [], '(')
+
+ def test_get_qp_ctext_no_end_char(self):
+ self._test_get_x(parser.get_qp_ctext,
+ '', '', ' ', [], '')
+
+
# get_qcontent
def test_get_qcontent_only(self):
@@ -503,6 +516,14 @@ class TestParser(TestParserMixin, TestEmailBase):
[errors.NonPrintableDefect], '"')
self.assertEqual(ptext.defects[0].non_printables[0], '\x00')
+ def test_get_qcontent_empty(self):
+ self._test_get_x(parser.get_qcontent,
+ '"', '', '', [], '"')
+
+ def test_get_qcontent_no_end_char(self):
+ self._test_get_x(parser.get_qcontent,
+ '', '', '', [], '')
+
# get_atext
def test_get_atext_only(self):
@@ -1283,6 +1304,18 @@ class TestParser(TestParserMixin, TestEmailBase):
self._test_get_x(parser.get_dtext,
'foo[bar', 'foo', 'foo', [], '[bar')
+ def test_get_dtext_open_bracket_only(self):
+ self._test_get_x(parser.get_dtext,
+ '[', '', '', [], '[')
+
+ def test_get_dtext_close_bracket_only(self):
+ self._test_get_x(parser.get_dtext,
+ ']', '', '', [], ']')
+
+ def test_get_dtext_empty(self):
+ self._test_get_x(parser.get_dtext,
+ '', '', '', [], '')
+
# get_domain_literal
def test_get_domain_literal_only(self):
@@ -2458,6 +2491,38 @@ class TestParser(TestParserMixin, TestEmailBase):
self.assertEqual(address.all_mailboxes[0].domain, 'example.com')
self.assertEqual(address.all_mailboxes[0].addr_spec, '"example example"@example.com')
+ def test_get_address_with_invalid_domain(self):
+ address = self._test_get_x(parser.get_address,
+ '<T@[',
+ '<T@[]>',
+ '<T@[]>',
+ [errors.InvalidHeaderDefect, # missing trailing '>' on angle-addr
+ errors.InvalidHeaderDefect, # end of input inside domain-literal
+ ],
+ '')
+ self.assertEqual(address.token_type, 'address')
+ self.assertEqual(len(address.mailboxes), 0)
+ self.assertEqual(len(address.all_mailboxes), 1)
+ self.assertEqual(address.all_mailboxes[0].domain, '[]')
+ self.assertEqual(address.all_mailboxes[0].local_part, 'T')
+ self.assertEqual(address.all_mailboxes[0].token_type, 'invalid-mailbox')
+ self.assertEqual(address[0].token_type, 'invalid-mailbox')
+
+ address = self._test_get_x(parser.get_address,
+ '!an??:=m==fr2@[C',
+ '!an??:=m==fr2@[C];',
+ '!an??:=m==fr2@[C];',
+ [errors.InvalidHeaderDefect, # end of header in group
+ errors.InvalidHeaderDefect, # end of input inside domain-literal
+ ],
+ '')
+ self.assertEqual(address.token_type, 'address')
+ self.assertEqual(len(address.mailboxes), 0)
+ self.assertEqual(len(address.all_mailboxes), 1)
+ self.assertEqual(address.all_mailboxes[0].domain, '[C]')
+ self.assertEqual(address.all_mailboxes[0].local_part, '=m==fr2')
+ self.assertEqual(address.all_mailboxes[0].token_type, 'invalid-mailbox')
+ self.assertEqual(address[0].token_type, 'group')
# get_address_list
@@ -2732,6 +2797,19 @@ class TestParser(TestParserMixin, TestEmailBase):
)
self.assertEqual(message_id.token_type, 'message-id')
+ def test_parse_message_id_with_invalid_domain(self):
+ message_id = self._test_parse_x(
+ parser.parse_message_id,
+ "<T@[",
+ "<T@[]>",
+ "<T@[]>",
+ [errors.ObsoleteHeaderDefect] + [errors.InvalidHeaderDefect] * 2,
+ [],
+ )
+ self.assertEqual(message_id.token_type, 'message-id')
+ self.assertEqual(str(message_id.all_defects[-1]),
+ "end of input inside domain-literal")
+
def test_parse_message_id_with_remaining(self):
message_id = self._test_parse_x(
parser.parse_message_id,
diff --git a/Lib/test/test_email/test_email.py b/Lib/test/test_email/test_email.py
index 7b14305f997..b8116d073a2 100644
--- a/Lib/test/test_email/test_email.py
+++ b/Lib/test/test_email/test_email.py
@@ -389,6 +389,24 @@ class TestMessageAPI(TestEmailBase):
msg = email.message_from_string("Content-Type: blarg; baz; boo\n")
self.assertEqual(msg.get_param('baz'), '')
+ def test_continuation_sorting_part_order(self):
+ msg = email.message_from_string(
+ "Content-Disposition: attachment; "
+ "filename*=\"ignored\"; "
+ "filename*0*=\"utf-8''foo%20\"; "
+ "filename*1*=\"bar.txt\"\n"
+ )
+ filename = msg.get_filename()
+ self.assertEqual(filename, 'foo bar.txt')
+
+ def test_sorting_no_continuations(self):
+ msg = email.message_from_string(
+ "Content-Disposition: attachment; "
+ "filename*=\"bar.txt\"; "
+ )
+ filename = msg.get_filename()
+ self.assertEqual(filename, 'bar.txt')
+
def test_missing_filename(self):
msg = email.message_from_string("From: foo\n")
self.assertEqual(msg.get_filename(), None)
@@ -2550,6 +2568,18 @@ Re: =?mac-iceland?q?r=8Aksm=9Arg=8Cs?= baz foo bar =?mac-iceland?q?r=8Aksm?=
self.assertEqual(str(make_header(decode_header(s))),
'"Müller T" <T.Mueller@xxx.com>')
+ def test_unencoded_ascii(self):
+ # bpo-22833/gh-67022: returns [(str, None)] rather than [(bytes, None)]
+ s = 'header without encoded words'
+ self.assertEqual(decode_header(s),
+ [('header without encoded words', None)])
+
+ def test_unencoded_utf8(self):
+ # bpo-22833/gh-67022: returns [(str, None)] rather than [(bytes, None)]
+ s = 'header with unexpected non ASCII caract\xe8res'
+ self.assertEqual(decode_header(s),
+ [('header with unexpected non ASCII caract\xe8res', None)])
+
# Test the MIMEMessage class
class TestMIMEMessage(TestEmailBase):
diff --git a/Lib/test/test_email/test_utils.py b/Lib/test/test_email/test_utils.py
index 4e6201e13c8..c9d09098b50 100644
--- a/Lib/test/test_email/test_utils.py
+++ b/Lib/test/test_email/test_utils.py
@@ -4,6 +4,16 @@ import test.support
import time
import unittest
+from test.support import cpython_only
+from test.support.import_helper import ensure_lazy_imports
+
+
+class TestImportTime(unittest.TestCase):
+
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("email.utils", {"random", "socket"})
+
class DateTimeTests(unittest.TestCase):
diff --git a/Lib/test/test_embed.py b/Lib/test/test_embed.py
index e06e684408c..89f4aebe28f 100644
--- a/Lib/test/test_embed.py
+++ b/Lib/test/test_embed.py
@@ -296,7 +296,7 @@ class EmbeddingTests(EmbeddingTestsMixin, unittest.TestCase):
if MS_WINDOWS:
expected_path = self.test_exe
else:
- expected_path = os.path.join(os.getcwd(), "spam")
+ expected_path = os.path.join(os.getcwd(), "_testembed")
expected_output = f"sys.executable: {expected_path}\n"
self.assertIn(expected_output, out)
self.assertEqual(err, '')
@@ -585,7 +585,7 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase):
'faulthandler': False,
'tracemalloc': 0,
'perf_profiling': 0,
- 'import_time': False,
+ 'import_time': 0,
'thread_inherit_context': DEFAULT_THREAD_INHERIT_CONTEXT,
'context_aware_warnings': DEFAULT_CONTEXT_AWARE_WARNINGS,
'code_debug_ranges': True,
@@ -969,7 +969,6 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase):
'utf8_mode': True,
}
config = {
- 'program_name': './globalvar',
'site_import': False,
'bytes_warning': True,
'warnoptions': ['default::BytesWarning'],
@@ -998,7 +997,7 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase):
'hash_seed': 123,
'tracemalloc': 2,
'perf_profiling': 0,
- 'import_time': True,
+ 'import_time': 2,
'code_debug_ranges': False,
'show_ref_count': True,
'malloc_stats': True,
@@ -1064,7 +1063,7 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase):
'use_hash_seed': True,
'hash_seed': 42,
'tracemalloc': 2,
- 'import_time': True,
+ 'import_time': 1,
'code_debug_ranges': False,
'malloc_stats': True,
'inspect': True,
@@ -1100,7 +1099,7 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase):
'use_hash_seed': True,
'hash_seed': 42,
'tracemalloc': 2,
- 'import_time': True,
+ 'import_time': 1,
'code_debug_ranges': False,
'malloc_stats': True,
'inspect': True,
@@ -1916,6 +1915,10 @@ class AuditingTests(EmbeddingTestsMixin, unittest.TestCase):
self.run_embedded_interpreter("test_get_incomplete_frame")
+ def test_gilstate_after_finalization(self):
+ self.run_embedded_interpreter("test_gilstate_after_finalization")
+
+
class MiscTests(EmbeddingTestsMixin, unittest.TestCase):
def test_unicode_id_init(self):
# bpo-42882: Test that _PyUnicode_FromId() works
diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py
index 68cedc666a5..bbc7630fa83 100644
--- a/Lib/test/test_enum.py
+++ b/Lib/test/test_enum.py
@@ -19,7 +19,8 @@ from io import StringIO
from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL
from test import support
from test.support import ALWAYS_EQ, REPO_ROOT
-from test.support import threading_helper
+from test.support import threading_helper, cpython_only
+from test.support.import_helper import ensure_lazy_imports
from datetime import timedelta
python_version = sys.version_info[:2]
@@ -35,7 +36,7 @@ def load_tests(loader, tests, ignore):
optionflags=doctest.ELLIPSIS|doctest.NORMALIZE_WHITESPACE,
))
howto_tests = os.path.join(REPO_ROOT, 'Doc/howto/enum.rst')
- if os.path.exists(howto_tests):
+ if os.path.exists(howto_tests) and sys.float_repr_style == 'short':
tests.addTests(doctest.DocFileSuite(
howto_tests,
module_relative=False,
@@ -433,9 +434,9 @@ class _EnumTests:
def spam(cls):
pass
#
- self.assertTrue(hasattr(Season, 'spam'))
+ self.assertHasAttr(Season, 'spam')
del Season.spam
- self.assertFalse(hasattr(Season, 'spam'))
+ self.assertNotHasAttr(Season, 'spam')
#
with self.assertRaises(AttributeError):
del Season.SPRING
@@ -2651,12 +2652,12 @@ class TestSpecial(unittest.TestCase):
OneDay = day_1
OneWeek = week_1
OneMonth = month_1
- self.assertFalse(hasattr(Period, '_ignore_'))
- self.assertFalse(hasattr(Period, 'Period'))
- self.assertFalse(hasattr(Period, 'i'))
- self.assertTrue(isinstance(Period.day_1, timedelta))
- self.assertTrue(Period.month_1 is Period.day_30)
- self.assertTrue(Period.week_4 is Period.day_28)
+ self.assertNotHasAttr(Period, '_ignore_')
+ self.assertNotHasAttr(Period, 'Period')
+ self.assertNotHasAttr(Period, 'i')
+ self.assertIsInstance(Period.day_1, timedelta)
+ self.assertIs(Period.month_1, Period.day_30)
+ self.assertIs(Period.week_4, Period.day_28)
def test_nonhash_value(self):
class AutoNumberInAList(Enum):
@@ -2876,7 +2877,7 @@ class TestSpecial(unittest.TestCase):
self.assertEqual(str(ReformedColor.BLUE), 'blue')
self.assertEqual(ReformedColor.RED.behavior(), 'booyah')
self.assertEqual(ConfusedColor.RED.social(), "what's up?")
- self.assertTrue(issubclass(ReformedColor, int))
+ self.assertIsSubclass(ReformedColor, int)
def test_multiple_inherited_mixin(self):
@unique
@@ -5288,6 +5289,10 @@ class MiscTestCase(unittest.TestCase):
def test__all__(self):
support.check__all__(self, enum, not_exported={'bin', 'show_flag_values'})
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("enum", {"functools", "warnings", "inspect", "re"})
+
def test_doc_1(self):
class Single(Enum):
ONE = 1
diff --git a/Lib/test/test_errno.py b/Lib/test/test_errno.py
index 5c437e9ccea..e7f185c6b1a 100644
--- a/Lib/test/test_errno.py
+++ b/Lib/test/test_errno.py
@@ -12,14 +12,12 @@ class ErrnoAttributeTests(unittest.TestCase):
def test_for_improper_attributes(self):
# No unexpected attributes should be on the module.
for error_code in std_c_errors:
- self.assertTrue(hasattr(errno, error_code),
- "errno is missing %s" % error_code)
+ self.assertHasAttr(errno, error_code)
def test_using_errorcode(self):
# Every key value in errno.errorcode should be on the module.
for value in errno.errorcode.values():
- self.assertTrue(hasattr(errno, value),
- 'no %s attr in errno' % value)
+ self.assertHasAttr(errno, value)
class ErrorcodeTests(unittest.TestCase):
diff --git a/Lib/test/test_exception_group.py b/Lib/test/test_exception_group.py
index 92bbf791764..5df2c41c6b5 100644
--- a/Lib/test/test_exception_group.py
+++ b/Lib/test/test_exception_group.py
@@ -1,13 +1,13 @@
import collections.abc
import types
import unittest
-from test.support import skip_emscripten_stack_overflow, exceeds_recursion_limit
+from test.support import skip_emscripten_stack_overflow, skip_wasi_stack_overflow, exceeds_recursion_limit
class TestExceptionGroupTypeHierarchy(unittest.TestCase):
def test_exception_group_types(self):
- self.assertTrue(issubclass(ExceptionGroup, Exception))
- self.assertTrue(issubclass(ExceptionGroup, BaseExceptionGroup))
- self.assertTrue(issubclass(BaseExceptionGroup, BaseException))
+ self.assertIsSubclass(ExceptionGroup, Exception)
+ self.assertIsSubclass(ExceptionGroup, BaseExceptionGroup)
+ self.assertIsSubclass(BaseExceptionGroup, BaseException)
def test_exception_is_not_generic_type(self):
with self.assertRaisesRegex(TypeError, 'Exception'):
@@ -465,12 +465,14 @@ class DeepRecursionInSplitAndSubgroup(unittest.TestCase):
return e
@skip_emscripten_stack_overflow()
+ @skip_wasi_stack_overflow()
def test_deep_split(self):
e = self.make_deep_eg()
with self.assertRaises(RecursionError):
e.split(TypeError)
@skip_emscripten_stack_overflow()
+ @skip_wasi_stack_overflow()
def test_deep_subgroup(self):
e = self.make_deep_eg()
with self.assertRaises(RecursionError):
@@ -812,8 +814,8 @@ class NestedExceptionGroupSplitTest(ExceptionGroupSplitTestBase):
eg = ExceptionGroup("eg", [ValueError(1), TypeError(2)])
eg.__notes__ = 123
match, rest = eg.split(TypeError)
- self.assertFalse(hasattr(match, '__notes__'))
- self.assertFalse(hasattr(rest, '__notes__'))
+ self.assertNotHasAttr(match, '__notes__')
+ self.assertNotHasAttr(rest, '__notes__')
def test_drive_invalid_return_value(self):
class MyEg(ExceptionGroup):
diff --git a/Lib/test/test_exceptions.py b/Lib/test/test_exceptions.py
index d177e3dc0f5..57d0656487d 100644
--- a/Lib/test/test_exceptions.py
+++ b/Lib/test/test_exceptions.py
@@ -357,7 +357,7 @@ class ExceptionTests(unittest.TestCase):
except TypeError as err:
co = err.__traceback__.tb_frame.f_code
self.assertEqual(co.co_name, "test_capi1")
- self.assertTrue(co.co_filename.endswith('test_exceptions.py'))
+ self.assertEndsWith(co.co_filename, 'test_exceptions.py')
else:
self.fail("Expected exception")
@@ -369,7 +369,7 @@ class ExceptionTests(unittest.TestCase):
tb = err.__traceback__.tb_next
co = tb.tb_frame.f_code
self.assertEqual(co.co_name, "__init__")
- self.assertTrue(co.co_filename.endswith('test_exceptions.py'))
+ self.assertEndsWith(co.co_filename, 'test_exceptions.py')
co2 = tb.tb_frame.f_back.f_code
self.assertEqual(co2.co_name, "test_capi2")
else:
@@ -598,7 +598,7 @@ class ExceptionTests(unittest.TestCase):
def test_notes(self):
for e in [BaseException(1), Exception(2), ValueError(3)]:
with self.subTest(e=e):
- self.assertFalse(hasattr(e, '__notes__'))
+ self.assertNotHasAttr(e, '__notes__')
e.add_note("My Note")
self.assertEqual(e.__notes__, ["My Note"])
@@ -610,7 +610,7 @@ class ExceptionTests(unittest.TestCase):
self.assertEqual(e.__notes__, ["My Note", "Your Note"])
del e.__notes__
- self.assertFalse(hasattr(e, '__notes__'))
+ self.assertNotHasAttr(e, '__notes__')
e.add_note("Our Note")
self.assertEqual(e.__notes__, ["Our Note"])
@@ -1429,6 +1429,7 @@ class ExceptionTests(unittest.TestCase):
self.assertIn("maximum recursion depth exceeded", str(exc))
@support.skip_wasi_stack_overflow()
+ @support.skip_emscripten_stack_overflow()
@cpython_only
@support.requires_resource('cpu')
def test_trashcan_recursion(self):
@@ -1444,6 +1445,7 @@ class ExceptionTests(unittest.TestCase):
foo()
support.gc_collect()
+ @support.skip_emscripten_stack_overflow()
@cpython_only
def test_recursion_normalizing_exception(self):
import_module("_testinternalcapi")
@@ -1521,6 +1523,7 @@ class ExceptionTests(unittest.TestCase):
self.assertIn(b'Done.', out)
+ @support.skip_emscripten_stack_overflow()
def test_recursion_in_except_handler(self):
def set_relative_recursion_limit(n):
@@ -1626,7 +1629,7 @@ class ExceptionTests(unittest.TestCase):
# test basic usage of PyErr_NewException
error1 = _testcapi.make_exception_with_doc("_testcapi.error1")
self.assertIs(type(error1), type)
- self.assertTrue(issubclass(error1, Exception))
+ self.assertIsSubclass(error1, Exception)
self.assertIsNone(error1.__doc__)
# test with given docstring
@@ -1636,21 +1639,21 @@ class ExceptionTests(unittest.TestCase):
# test with explicit base (without docstring)
error3 = _testcapi.make_exception_with_doc("_testcapi.error3",
base=error2)
- self.assertTrue(issubclass(error3, error2))
+ self.assertIsSubclass(error3, error2)
# test with explicit base tuple
class C(object):
pass
error4 = _testcapi.make_exception_with_doc("_testcapi.error4", doc4,
(error3, C))
- self.assertTrue(issubclass(error4, error3))
- self.assertTrue(issubclass(error4, C))
+ self.assertIsSubclass(error4, error3)
+ self.assertIsSubclass(error4, C)
self.assertEqual(error4.__doc__, doc4)
# test with explicit dictionary
error5 = _testcapi.make_exception_with_doc("_testcapi.error5", "",
error4, {'a': 1})
- self.assertTrue(issubclass(error5, error4))
+ self.assertIsSubclass(error5, error4)
self.assertEqual(error5.a, 1)
self.assertEqual(error5.__doc__, "")
@@ -1743,7 +1746,7 @@ class ExceptionTests(unittest.TestCase):
self.assertIn("<exception str() failed>", report)
else:
self.assertIn("test message", report)
- self.assertTrue(report.endswith("\n"))
+ self.assertEndsWith(report, "\n")
@cpython_only
# Python built with Py_TRACE_REFS fail with a fatal error in
diff --git a/Lib/test/test_external_inspection.py b/Lib/test/test_external_inspection.py
index aa05db972f0..0f31c225e68 100644
--- a/Lib/test/test_external_inspection.py
+++ b/Lib/test/test_external_inspection.py
@@ -4,7 +4,10 @@ import textwrap
import importlib
import sys
import socket
-from test.support import os_helper, SHORT_TIMEOUT, busy_retry
+import threading
+from asyncio import staggered, taskgroups, base_events, tasks
+from unittest.mock import ANY
+from test.support import os_helper, SHORT_TIMEOUT, busy_retry, requires_gil_enabled
from test.support.script_helper import make_script
from test.support.socket_helper import find_unused_port
@@ -13,33 +16,60 @@ import subprocess
PROCESS_VM_READV_SUPPORTED = False
try:
- from _testexternalinspection import PROCESS_VM_READV_SUPPORTED
- from _testexternalinspection import get_stack_trace
- from _testexternalinspection import get_async_stack_trace
- from _testexternalinspection import get_all_awaited_by
+ from _remote_debugging import PROCESS_VM_READV_SUPPORTED
+ from _remote_debugging import RemoteUnwinder
+ from _remote_debugging import FrameInfo, CoroInfo, TaskInfo
except ImportError:
raise unittest.SkipTest(
- "Test only runs when _testexternalinspection is available")
+ "Test only runs when _remote_debugging is available"
+ )
+
def _make_test_script(script_dir, script_basename, source):
to_return = make_script(script_dir, script_basename, source)
importlib.invalidate_caches()
return to_return
-skip_if_not_supported = unittest.skipIf((sys.platform != "darwin"
- and sys.platform != "linux"
- and sys.platform != "win32"),
- "Test only runs on Linux, Windows and MacOS")
+
+skip_if_not_supported = unittest.skipIf(
+ (
+ sys.platform != "darwin"
+ and sys.platform != "linux"
+ and sys.platform != "win32"
+ ),
+ "Test only runs on Linux, Windows and MacOS",
+)
+
+
+def get_stack_trace(pid):
+ unwinder = RemoteUnwinder(pid, all_threads=True, debug=True)
+ return unwinder.get_stack_trace()
+
+
+def get_async_stack_trace(pid):
+ unwinder = RemoteUnwinder(pid, debug=True)
+ return unwinder.get_async_stack_trace()
+
+
+def get_all_awaited_by(pid):
+ unwinder = RemoteUnwinder(pid, debug=True)
+ return unwinder.get_all_awaited_by()
+
+
class TestGetStackTrace(unittest.TestCase):
+ maxDiff = None
@skip_if_not_supported
- @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
- "Test only runs on Linux with process_vm_readv support")
+ @unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+ )
def test_remote_stack_trace(self):
# Spawn a process with some realistic Python code
port = find_unused_port()
- script = textwrap.dedent(f"""\
- import time, sys, socket
+ script = textwrap.dedent(
+ f"""\
+ import time, sys, socket, threading
# Connect to the test process
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('localhost', {port}))
@@ -48,15 +78,18 @@ class TestGetStackTrace(unittest.TestCase):
for x in range(100):
if x == 50:
baz()
+
def baz():
foo()
def foo():
- sock.sendall(b"ready")
- time.sleep(1000)
+ sock.sendall(b"ready:thread\\n"); time.sleep(10_000) # same line number
- bar()
- """)
+ t = threading.Thread(target=bar)
+ t.start()
+ sock.sendall(b"ready:main\\n"); t.join() # same line number
+ """
+ )
stack_trace = None
with os_helper.temp_dir() as work_dir:
script_dir = os.path.join(work_dir, "script_pkg")
@@ -65,21 +98,27 @@ class TestGetStackTrace(unittest.TestCase):
# Create a socket server to communicate with the target process
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(('localhost', port))
+ server_socket.bind(("localhost", port))
server_socket.settimeout(SHORT_TIMEOUT)
server_socket.listen(1)
- script_name = _make_test_script(script_dir, 'script', script)
+ script_name = _make_test_script(script_dir, "script", script)
client_socket = None
try:
p = subprocess.Popen([sys.executable, script_name])
client_socket, _ = server_socket.accept()
server_socket.close()
- response = client_socket.recv(1024)
- self.assertEqual(response, b"ready")
+ response = b""
+ while (
+ b"ready:main" not in response
+ or b"ready:thread" not in response
+ ):
+ response += client_socket.recv(1024)
stack_trace = get_stack_trace(p.pid)
except PermissionError:
- self.skipTest("Insufficient permissions to read the stack trace")
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
+ )
finally:
if client_socket is not None:
client_socket.close()
@@ -87,22 +126,34 @@ class TestGetStackTrace(unittest.TestCase):
p.terminate()
p.wait(timeout=SHORT_TIMEOUT)
-
- expected_stack_trace = [
- 'foo',
- 'baz',
- 'bar',
- '<module>'
+ thread_expected_stack_trace = [
+ FrameInfo([script_name, 15, "foo"]),
+ FrameInfo([script_name, 12, "baz"]),
+ FrameInfo([script_name, 9, "bar"]),
+ FrameInfo([threading.__file__, ANY, "Thread.run"]),
]
- self.assertEqual(stack_trace, expected_stack_trace)
+        # It is possible that there are more threads, so we check that the
+        # expected stack traces are in the result (looking at you, Windows!)
+ self.assertIn((ANY, thread_expected_stack_trace), stack_trace)
+
+ # Check that the main thread stack trace is in the result
+ frame = FrameInfo([script_name, 19, "<module>"])
+ for _, stack in stack_trace:
+ if frame in stack:
+ break
+ else:
+ self.fail("Main thread stack trace not found in result")
@skip_if_not_supported
- @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
- "Test only runs on Linux with process_vm_readv support")
+ @unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+ )
def test_async_remote_stack_trace(self):
# Spawn a process with some realistic Python code
port = find_unused_port()
- script = textwrap.dedent(f"""\
+ script = textwrap.dedent(
+ f"""\
import asyncio
import time
import sys
@@ -112,8 +163,7 @@ class TestGetStackTrace(unittest.TestCase):
sock.connect(('localhost', {port}))
def c5():
- sock.sendall(b"ready")
- time.sleep(10000)
+ sock.sendall(b"ready"); time.sleep(10_000) # same line number
async def c4():
await asyncio.sleep(0)
@@ -142,7 +192,8 @@ class TestGetStackTrace(unittest.TestCase):
return loop
asyncio.run(main(), loop_factory={{TASK_FACTORY}})
- """)
+ """
+ )
stack_trace = None
for task_factory_variant in "asyncio.new_event_loop", "new_eager_loop":
with (
@@ -151,19 +202,23 @@ class TestGetStackTrace(unittest.TestCase):
):
script_dir = os.path.join(work_dir, "script_pkg")
os.mkdir(script_dir)
- server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(('localhost', port))
+ server_socket = socket.socket(
+ socket.AF_INET, socket.SOCK_STREAM
+ )
+ server_socket.setsockopt(
+ socket.SOL_SOCKET, socket.SO_REUSEADDR, 1
+ )
+ server_socket.bind(("localhost", port))
server_socket.settimeout(SHORT_TIMEOUT)
server_socket.listen(1)
script_name = _make_test_script(
- script_dir, 'script',
- script.format(TASK_FACTORY=task_factory_variant))
+ script_dir,
+ "script",
+ script.format(TASK_FACTORY=task_factory_variant),
+ )
client_socket = None
try:
- p = subprocess.Popen(
- [sys.executable, script_name]
- )
+ p = subprocess.Popen([sys.executable, script_name])
client_socket, _ = server_socket.accept()
server_socket.close()
response = client_socket.recv(1024)
@@ -171,7 +226,8 @@ class TestGetStackTrace(unittest.TestCase):
stack_trace = get_async_stack_trace(p.pid)
except PermissionError:
self.skipTest(
- "Insufficient permissions to read the stack trace")
+ "Insufficient permissions to read the stack trace"
+ )
finally:
if client_socket is not None:
client_socket.close()
@@ -182,25 +238,63 @@ class TestGetStackTrace(unittest.TestCase):
# sets are unordered, so we want to sort "awaited_by"s
stack_trace[2].sort(key=lambda x: x[1])
- root_task = "Task-1"
expected_stack_trace = [
- ["c5", "c4", "c3", "c2"],
+ [
+ FrameInfo([script_name, 10, "c5"]),
+ FrameInfo([script_name, 14, "c4"]),
+ FrameInfo([script_name, 17, "c3"]),
+ FrameInfo([script_name, 20, "c2"]),
+ ],
"c2_root",
[
- [["main"], root_task, []],
- [["c1"], "sub_main_1", [[["main"], root_task, []]]],
- [["c1"], "sub_main_2", [[["main"], root_task, []]]],
+ CoroInfo(
+ [
+ [
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup._aexit",
+ ]
+ ),
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup.__aexit__",
+ ]
+ ),
+ FrameInfo([script_name, 26, "main"]),
+ ],
+ "Task-1",
+ ]
+ ),
+ CoroInfo(
+ [
+ [FrameInfo([script_name, 23, "c1"])],
+ "sub_main_1",
+ ]
+ ),
+ CoroInfo(
+ [
+ [FrameInfo([script_name, 23, "c1"])],
+ "sub_main_2",
+ ]
+ ),
],
]
self.assertEqual(stack_trace, expected_stack_trace)
@skip_if_not_supported
- @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
- "Test only runs on Linux with process_vm_readv support")
+ @unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+ )
def test_asyncgen_remote_stack_trace(self):
# Spawn a process with some realistic Python code
port = find_unused_port()
- script = textwrap.dedent(f"""\
+ script = textwrap.dedent(
+ f"""\
import asyncio
import time
import sys
@@ -210,8 +304,7 @@ class TestGetStackTrace(unittest.TestCase):
sock.connect(('localhost', {port}))
async def gen_nested_call():
- sock.sendall(b"ready")
- time.sleep(10000)
+ sock.sendall(b"ready"); time.sleep(10_000) # same line number
async def gen():
for num in range(2):
@@ -224,7 +317,8 @@ class TestGetStackTrace(unittest.TestCase):
pass
asyncio.run(main())
- """)
+ """
+ )
stack_trace = None
with os_helper.temp_dir() as work_dir:
script_dir = os.path.join(work_dir, "script_pkg")
@@ -232,10 +326,10 @@ class TestGetStackTrace(unittest.TestCase):
# Create a socket server to communicate with the target process
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(('localhost', port))
+ server_socket.bind(("localhost", port))
server_socket.settimeout(SHORT_TIMEOUT)
server_socket.listen(1)
- script_name = _make_test_script(script_dir, 'script', script)
+ script_name = _make_test_script(script_dir, "script", script)
client_socket = None
try:
p = subprocess.Popen([sys.executable, script_name])
@@ -245,7 +339,9 @@ class TestGetStackTrace(unittest.TestCase):
self.assertEqual(response, b"ready")
stack_trace = get_async_stack_trace(p.pid)
except PermissionError:
- self.skipTest("Insufficient permissions to read the stack trace")
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
+ )
finally:
if client_socket is not None:
client_socket.close()
@@ -257,17 +353,26 @@ class TestGetStackTrace(unittest.TestCase):
stack_trace[2].sort(key=lambda x: x[1])
expected_stack_trace = [
- ['gen_nested_call', 'gen', 'main'], 'Task-1', []
+ [
+ FrameInfo([script_name, 10, "gen_nested_call"]),
+ FrameInfo([script_name, 16, "gen"]),
+ FrameInfo([script_name, 19, "main"]),
+ ],
+ "Task-1",
+ [],
]
self.assertEqual(stack_trace, expected_stack_trace)
@skip_if_not_supported
- @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
- "Test only runs on Linux with process_vm_readv support")
+ @unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+ )
def test_async_gather_remote_stack_trace(self):
# Spawn a process with some realistic Python code
port = find_unused_port()
- script = textwrap.dedent(f"""\
+ script = textwrap.dedent(
+ f"""\
import asyncio
import time
import sys
@@ -278,8 +383,7 @@ class TestGetStackTrace(unittest.TestCase):
async def deep():
await asyncio.sleep(0)
- sock.sendall(b"ready")
- time.sleep(10000)
+ sock.sendall(b"ready"); time.sleep(10_000) # same line number
async def c1():
await asyncio.sleep(0)
@@ -292,7 +396,8 @@ class TestGetStackTrace(unittest.TestCase):
await asyncio.gather(c1(), c2())
asyncio.run(main())
- """)
+ """
+ )
stack_trace = None
with os_helper.temp_dir() as work_dir:
script_dir = os.path.join(work_dir, "script_pkg")
@@ -300,10 +405,10 @@ class TestGetStackTrace(unittest.TestCase):
# Create a socket server to communicate with the target process
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(('localhost', port))
+ server_socket.bind(("localhost", port))
server_socket.settimeout(SHORT_TIMEOUT)
server_socket.listen(1)
- script_name = _make_test_script(script_dir, 'script', script)
+ script_name = _make_test_script(script_dir, "script", script)
client_socket = None
try:
p = subprocess.Popen([sys.executable, script_name])
@@ -314,7 +419,8 @@ class TestGetStackTrace(unittest.TestCase):
stack_trace = get_async_stack_trace(p.pid)
except PermissionError:
self.skipTest(
- "Insufficient permissions to read the stack trace")
+ "Insufficient permissions to read the stack trace"
+ )
finally:
if client_socket is not None:
client_socket.close()
@@ -325,18 +431,26 @@ class TestGetStackTrace(unittest.TestCase):
# sets are unordered, so we want to sort "awaited_by"s
stack_trace[2].sort(key=lambda x: x[1])
- expected_stack_trace = [
- ['deep', 'c1'], 'Task-2', [[['main'], 'Task-1', []]]
+ expected_stack_trace = [
+ [
+ FrameInfo([script_name, 11, "deep"]),
+ FrameInfo([script_name, 15, "c1"]),
+ ],
+ "Task-2",
+ [CoroInfo([[FrameInfo([script_name, 21, "main"])], "Task-1"])],
]
self.assertEqual(stack_trace, expected_stack_trace)
@skip_if_not_supported
- @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
- "Test only runs on Linux with process_vm_readv support")
+ @unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+ )
def test_async_staggered_race_remote_stack_trace(self):
# Spawn a process with some realistic Python code
port = find_unused_port()
- script = textwrap.dedent(f"""\
+ script = textwrap.dedent(
+ f"""\
import asyncio.staggered
import time
import sys
@@ -347,15 +461,14 @@ class TestGetStackTrace(unittest.TestCase):
async def deep():
await asyncio.sleep(0)
- sock.sendall(b"ready")
- time.sleep(10000)
+ sock.sendall(b"ready"); time.sleep(10_000) # same line number
async def c1():
await asyncio.sleep(0)
await deep()
async def c2():
- await asyncio.sleep(10000)
+ await asyncio.sleep(10_000)
async def main():
await asyncio.staggered.staggered_race(
@@ -364,7 +477,8 @@ class TestGetStackTrace(unittest.TestCase):
)
asyncio.run(main())
- """)
+ """
+ )
stack_trace = None
with os_helper.temp_dir() as work_dir:
script_dir = os.path.join(work_dir, "script_pkg")
@@ -372,10 +486,10 @@ class TestGetStackTrace(unittest.TestCase):
# Create a socket server to communicate with the target process
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(('localhost', port))
+ server_socket.bind(("localhost", port))
server_socket.settimeout(SHORT_TIMEOUT)
server_socket.listen(1)
- script_name = _make_test_script(script_dir, 'script', script)
+ script_name = _make_test_script(script_dir, "script", script)
client_socket = None
try:
p = subprocess.Popen([sys.executable, script_name])
@@ -386,7 +500,8 @@ class TestGetStackTrace(unittest.TestCase):
stack_trace = get_async_stack_trace(p.pid)
except PermissionError:
self.skipTest(
- "Insufficient permissions to read the stack trace")
+ "Insufficient permissions to read the stack trace"
+ )
finally:
if client_socket is not None:
client_socket.close()
@@ -396,18 +511,44 @@ class TestGetStackTrace(unittest.TestCase):
# sets are unordered, so we want to sort "awaited_by"s
stack_trace[2].sort(key=lambda x: x[1])
-
- expected_stack_trace = [
- ['deep', 'c1', 'run_one_coro'], 'Task-2', [[['main'], 'Task-1', []]]
+ expected_stack_trace = [
+ [
+ FrameInfo([script_name, 11, "deep"]),
+ FrameInfo([script_name, 15, "c1"]),
+ FrameInfo(
+ [
+ staggered.__file__,
+ ANY,
+ "staggered_race.<locals>.run_one_coro",
+ ]
+ ),
+ ],
+ "Task-2",
+ [
+ CoroInfo(
+ [
+ [
+ FrameInfo(
+ [staggered.__file__, ANY, "staggered_race"]
+ ),
+ FrameInfo([script_name, 21, "main"]),
+ ],
+ "Task-1",
+ ]
+ )
+ ],
]
self.assertEqual(stack_trace, expected_stack_trace)
@skip_if_not_supported
- @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
- "Test only runs on Linux with process_vm_readv support")
+ @unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+ )
def test_async_global_awaited_by(self):
port = find_unused_port()
- script = textwrap.dedent(f"""\
+ script = textwrap.dedent(
+ f"""\
import asyncio
import os
import random
@@ -443,6 +584,8 @@ class TestGetStackTrace(unittest.TestCase):
assert message == data.decode()
writer.close()
await writer.wait_closed()
+ # Signal we are ready to sleep
+ sock.sendall(b"ready")
await asyncio.sleep(SHORT_TIMEOUT)
async def echo_client_spam(server):
@@ -452,8 +595,10 @@ class TestGetStackTrace(unittest.TestCase):
random.shuffle(msg)
tg.create_task(echo_client("".join(msg)))
await asyncio.sleep(0)
- # at least a 1000 tasks created
- sock.sendall(b"ready")
+            # At least 1000 tasks are created. Each task signals when it
+            # is ready, to avoid the race caused by the fact that tasks
+            # are awaited in tg.__aexit__ and we cannot signal at the
+            # exact moment that happens otherwise.
# at this point all client tasks completed without assertion errors
# let's wrap up the test
server.close()
@@ -468,7 +613,8 @@ class TestGetStackTrace(unittest.TestCase):
tg.create_task(echo_client_spam(server), name="echo client spam")
asyncio.run(main())
- """)
+ """
+ )
stack_trace = None
with os_helper.temp_dir() as work_dir:
script_dir = os.path.join(work_dir, "script_pkg")
@@ -476,17 +622,19 @@ class TestGetStackTrace(unittest.TestCase):
# Create a socket server to communicate with the target process
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(('localhost', port))
+ server_socket.bind(("localhost", port))
server_socket.settimeout(SHORT_TIMEOUT)
server_socket.listen(1)
- script_name = _make_test_script(script_dir, 'script', script)
+ script_name = _make_test_script(script_dir, "script", script)
client_socket = None
try:
p = subprocess.Popen([sys.executable, script_name])
client_socket, _ = server_socket.accept()
server_socket.close()
- response = client_socket.recv(1024)
- self.assertEqual(response, b"ready")
+ for _ in range(1000):
+ expected_response = b"ready"
+ response = client_socket.recv(len(expected_response))
+ self.assertEqual(response, expected_response)
for _ in busy_retry(SHORT_TIMEOUT):
try:
all_awaited_by = get_all_awaited_by(p.pid)
@@ -497,7 +645,9 @@ class TestGetStackTrace(unittest.TestCase):
msg = str(re)
if msg.startswith("Task list appears corrupted"):
continue
- elif msg.startswith("Invalid linked list structure reading remote memory"):
+ elif msg.startswith(
+ "Invalid linked list structure reading remote memory"
+ ):
continue
elif msg.startswith("Unknown error reading memory"):
continue
@@ -516,22 +666,174 @@ class TestGetStackTrace(unittest.TestCase):
# expected: at least 1000 pending tasks
self.assertGreaterEqual(len(entries), 1000)
# the first three tasks stem from the code structure
- self.assertIn(('Task-1', []), entries)
- self.assertIn(('server task', [[['main'], 'Task-1', []]]), entries)
- self.assertIn(('echo client spam', [[['main'], 'Task-1', []]]), entries)
+ main_stack = [
+ FrameInfo([taskgroups.__file__, ANY, "TaskGroup._aexit"]),
+ FrameInfo(
+ [taskgroups.__file__, ANY, "TaskGroup.__aexit__"]
+ ),
+ FrameInfo([script_name, 60, "main"]),
+ ]
+ self.assertIn(
+ TaskInfo(
+ [ANY, "Task-1", [CoroInfo([main_stack, ANY])], []]
+ ),
+ entries,
+ )
+ self.assertIn(
+ TaskInfo(
+ [
+ ANY,
+ "server task",
+ [
+ CoroInfo(
+ [
+ [
+ FrameInfo(
+ [
+ base_events.__file__,
+ ANY,
+ "Server.serve_forever",
+ ]
+ )
+ ],
+ ANY,
+ ]
+ )
+ ],
+ [
+ CoroInfo(
+ [
+ [
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup._aexit",
+ ]
+ ),
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup.__aexit__",
+ ]
+ ),
+ FrameInfo(
+ [script_name, ANY, "main"]
+ ),
+ ],
+ ANY,
+ ]
+ )
+ ],
+ ]
+ ),
+ entries,
+ )
+ self.assertIn(
+ TaskInfo(
+ [
+ ANY,
+ "Task-4",
+ [
+ CoroInfo(
+ [
+ [
+ FrameInfo(
+ [tasks.__file__, ANY, "sleep"]
+ ),
+ FrameInfo(
+ [
+ script_name,
+ 38,
+ "echo_client",
+ ]
+ ),
+ ],
+ ANY,
+ ]
+ )
+ ],
+ [
+ CoroInfo(
+ [
+ [
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup._aexit",
+ ]
+ ),
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup.__aexit__",
+ ]
+ ),
+ FrameInfo(
+ [
+ script_name,
+ 41,
+ "echo_client_spam",
+ ]
+ ),
+ ],
+ ANY,
+ ]
+ )
+ ],
+ ]
+ ),
+ entries,
+ )
- expected_stack = [[['echo_client_spam'], 'echo client spam', [[['main'], 'Task-1', []]]]]
- tasks_with_stack = [task for task in entries if task[1] == expected_stack]
- self.assertGreaterEqual(len(tasks_with_stack), 1000)
+ expected_awaited_by = [
+ CoroInfo(
+ [
+ [
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup._aexit",
+ ]
+ ),
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup.__aexit__",
+ ]
+ ),
+ FrameInfo(
+ [script_name, 41, "echo_client_spam"]
+ ),
+ ],
+ ANY,
+ ]
+ )
+ ]
+ tasks_with_awaited = [
+ task
+ for task in entries
+ if task.awaited_by == expected_awaited_by
+ ]
+ self.assertGreaterEqual(len(tasks_with_awaited), 1000)
# the final task will have some random number, but it should for
# sure be one of the echo client spam horde (In windows this is not true
# for some reason)
if sys.platform != "win32":
- self.assertEqual([[['echo_client_spam'], 'echo client spam', [[['main'], 'Task-1', []]]]], entries[-1][1])
+ self.assertEqual(
+ tasks_with_awaited[-1].awaited_by,
+ entries[-1].awaited_by,
+ )
except PermissionError:
self.skipTest(
- "Insufficient permissions to read the stack trace")
+ "Insufficient permissions to read the stack trace"
+ )
finally:
if client_socket is not None:
client_socket.close()
@@ -540,12 +842,160 @@ class TestGetStackTrace(unittest.TestCase):
p.wait(timeout=SHORT_TIMEOUT)
@skip_if_not_supported
- @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
- "Test only runs on Linux with process_vm_readv support")
+ @unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+ )
def test_self_trace(self):
stack_trace = get_stack_trace(os.getpid())
- print(stack_trace)
- self.assertEqual(stack_trace[0], "test_self_trace")
+            # It is possible that there are more threads, so we check that the
+            # expected stack traces are in the result (looking at you, Windows!)
+ this_tread_stack = None
+ for thread_id, stack in stack_trace:
+ if thread_id == threading.get_native_id():
+ this_tread_stack = stack
+ break
+ self.assertIsNotNone(this_tread_stack)
+ self.assertEqual(
+ stack[:2],
+ [
+ FrameInfo(
+ [
+ __file__,
+ get_stack_trace.__code__.co_firstlineno + 2,
+ "get_stack_trace",
+ ]
+ ),
+ FrameInfo(
+ [
+ __file__,
+ self.test_self_trace.__code__.co_firstlineno + 6,
+ "TestGetStackTrace.test_self_trace",
+ ]
+ ),
+ ],
+ )
+
+ @skip_if_not_supported
+ @unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+ )
+ @requires_gil_enabled("Free threaded builds don't have an 'active thread'")
+ def test_only_active_thread(self):
+ # Test that only_active_thread parameter works correctly
+ port = find_unused_port()
+ script = textwrap.dedent(
+ f"""\
+ import time, sys, socket, threading
+
+ # Connect to the test process
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.connect(('localhost', {port}))
+
+ def worker_thread(name, barrier, ready_event):
+ barrier.wait() # Synchronize thread start
+ ready_event.wait() # Wait for main thread signal
+ # Sleep to keep thread alive
+ time.sleep(10_000)
+
+ def main_work():
+ # Do busy work to hold the GIL
+ sock.sendall(b"working\\n")
+ count = 0
+ while count < 100000000:
+ count += 1
+ if count % 10000000 == 0:
+ pass # Keep main thread busy
+ sock.sendall(b"done\\n")
+
+ # Create synchronization primitives
+ num_threads = 3
+ barrier = threading.Barrier(num_threads + 1) # +1 for main thread
+ ready_event = threading.Event()
+
+ # Start worker threads
+ threads = []
+ for i in range(num_threads):
+ t = threading.Thread(target=worker_thread, args=(f"Worker-{{i}}", barrier, ready_event))
+ t.start()
+ threads.append(t)
+
+ # Wait for all threads to be ready
+ barrier.wait()
+
+ # Signal ready to parent process
+ sock.sendall(b"ready\\n")
+
+ # Signal threads to start waiting
+ ready_event.set()
+
+ # Give threads time to start sleeping
+ time.sleep(0.1)
+
+ # Now do busy work to hold the GIL
+ main_work()
+ """
+ )
+
+ with os_helper.temp_dir() as work_dir:
+ script_dir = os.path.join(work_dir, "script_pkg")
+ os.mkdir(script_dir)
+
+ # Create a socket server to communicate with the target process
+ server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ server_socket.bind(("localhost", port))
+ server_socket.settimeout(SHORT_TIMEOUT)
+ server_socket.listen(1)
+
+ script_name = _make_test_script(script_dir, "script", script)
+ client_socket = None
+ try:
+ p = subprocess.Popen([sys.executable, script_name])
+ client_socket, _ = server_socket.accept()
+ server_socket.close()
+
+ # Wait for ready signal
+ response = b""
+ while b"ready" not in response:
+ response += client_socket.recv(1024)
+
+ # Wait for the main thread to start its busy work
+ while b"working" not in response:
+ response += client_socket.recv(1024)
+
+ # Get stack trace with all threads
+ unwinder_all = RemoteUnwinder(p.pid, all_threads=True)
+ all_traces = unwinder_all.get_stack_trace()
+
+ # Get stack trace with only GIL holder
+ unwinder_gil = RemoteUnwinder(p.pid, only_active_thread=True)
+ gil_traces = unwinder_gil.get_stack_trace()
+
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
+ )
+ finally:
+ if client_socket is not None:
+ client_socket.close()
+ p.kill()
+ p.terminate()
+ p.wait(timeout=SHORT_TIMEOUT)
+
+ # Verify we got multiple threads in all_traces
+ self.assertGreater(len(all_traces), 1, "Should have multiple threads")
+
+ # Verify we got exactly one thread in gil_traces
+ self.assertEqual(len(gil_traces), 1, "Should have exactly one GIL holder")
+
+ # The GIL holder should be in the all_traces list
+ gil_thread_id = gil_traces[0][0]
+ all_thread_ids = [trace[0] for trace in all_traces]
+ self.assertIn(gil_thread_id, all_thread_ids,
+ "GIL holder should be among all threads")
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_faulthandler.py b/Lib/test/test_faulthandler.py
index 371c63adce9..2fb963f52e5 100644
--- a/Lib/test/test_faulthandler.py
+++ b/Lib/test/test_faulthandler.py
@@ -166,29 +166,6 @@ class FaultHandlerTests(unittest.TestCase):
fatal_error = 'Windows fatal exception: %s' % name_regex
self.check_error(code, line_number, fatal_error, **kw)
- @unittest.skipIf(sys.platform.startswith('aix'),
- "the first page of memory is a mapped read-only on AIX")
- def test_read_null(self):
- if not MS_WINDOWS:
- self.check_fatal_error("""
- import faulthandler
- faulthandler.enable()
- faulthandler._read_null()
- """,
- 3,
- # Issue #12700: Read NULL raises SIGILL on Mac OS X Lion
- '(?:Segmentation fault'
- '|Bus error'
- '|Illegal instruction)')
- else:
- self.check_windows_exception("""
- import faulthandler
- faulthandler.enable()
- faulthandler._read_null()
- """,
- 3,
- 'access violation')
-
@skip_segfault_on_android
def test_sigsegv(self):
self.check_fatal_error("""
diff --git a/Lib/test/test_fcntl.py b/Lib/test/test_fcntl.py
index b84c98ef3a2..7140a7b4f29 100644
--- a/Lib/test/test_fcntl.py
+++ b/Lib/test/test_fcntl.py
@@ -11,7 +11,7 @@ from test.support import (
cpython_only, get_pagesize, is_apple, requires_subprocess, verbose
)
from test.support.import_helper import import_module
-from test.support.os_helper import TESTFN, unlink
+from test.support.os_helper import TESTFN, unlink, make_bad_fd
# Skip test if no fcntl module.
@@ -228,6 +228,63 @@ class TestFcntl(unittest.TestCase):
os.close(test_pipe_r)
os.close(test_pipe_w)
+ def _check_fcntl_not_mutate_len(self, nbytes=None):
+ self.f = open(TESTFN, 'wb')
+ buf = struct.pack('ii', fcntl.F_OWNER_PID, os.getpid())
+ if nbytes is not None:
+ buf += b' ' * (nbytes - len(buf))
+ else:
+ nbytes = len(buf)
+ save_buf = bytes(buf)
+ r = fcntl.fcntl(self.f, fcntl.F_SETOWN_EX, buf)
+ self.assertIsInstance(r, bytes)
+ self.assertEqual(len(r), len(save_buf))
+ self.assertEqual(buf, save_buf)
+ type, pid = memoryview(r).cast('i')[:2]
+ self.assertEqual(type, fcntl.F_OWNER_PID)
+ self.assertEqual(pid, os.getpid())
+
+ buf = b' ' * nbytes
+ r = fcntl.fcntl(self.f, fcntl.F_GETOWN_EX, buf)
+ self.assertIsInstance(r, bytes)
+ self.assertEqual(len(r), len(save_buf))
+ self.assertEqual(buf, b' ' * nbytes)
+ type, pid = memoryview(r).cast('i')[:2]
+ self.assertEqual(type, fcntl.F_OWNER_PID)
+ self.assertEqual(pid, os.getpid())
+
+ buf = memoryview(b' ' * nbytes)
+ r = fcntl.fcntl(self.f, fcntl.F_GETOWN_EX, buf)
+ self.assertIsInstance(r, bytes)
+ self.assertEqual(len(r), len(save_buf))
+ self.assertEqual(bytes(buf), b' ' * nbytes)
+ type, pid = memoryview(r).cast('i')[:2]
+ self.assertEqual(type, fcntl.F_OWNER_PID)
+ self.assertEqual(pid, os.getpid())
+
+ @unittest.skipUnless(
+ hasattr(fcntl, "F_SETOWN_EX") and hasattr(fcntl, "F_GETOWN_EX"),
+ "requires F_SETOWN_EX and F_GETOWN_EX")
+ def test_fcntl_small_buffer(self):
+ self._check_fcntl_not_mutate_len()
+
+ @unittest.skipUnless(
+ hasattr(fcntl, "F_SETOWN_EX") and hasattr(fcntl, "F_GETOWN_EX"),
+ "requires F_SETOWN_EX and F_GETOWN_EX")
+ def test_fcntl_large_buffer(self):
+ self._check_fcntl_not_mutate_len(2024)
+
+ @unittest.skipUnless(hasattr(fcntl, 'F_DUPFD'), 'need fcntl.F_DUPFD')
+ def test_bad_fd(self):
+ # gh-134744: Test error handling
+ fd = make_bad_fd()
+ with self.assertRaises(OSError):
+ fcntl.fcntl(fd, fcntl.F_DUPFD, 0)
+ with self.assertRaises(OSError):
+ fcntl.fcntl(fd, fcntl.F_DUPFD, b'\0' * 10)
+ with self.assertRaises(OSError):
+ fcntl.fcntl(fd, fcntl.F_DUPFD, b'\0' * 2048)
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_fileinput.py b/Lib/test/test_fileinput.py
index b340ef7ed16..6524baabe7f 100644
--- a/Lib/test/test_fileinput.py
+++ b/Lib/test/test_fileinput.py
@@ -245,7 +245,7 @@ class FileInputTests(BaseTests, unittest.TestCase):
orig_stdin = sys.stdin
try:
sys.stdin = BytesIO(b'spam, bacon, sausage, and spam')
- self.assertFalse(hasattr(sys.stdin, 'buffer'))
+ self.assertNotHasAttr(sys.stdin, 'buffer')
fi = FileInput(files=['-'], mode='rb')
lines = list(fi)
self.assertEqual(lines, [b'spam, bacon, sausage, and spam'])
diff --git a/Lib/test/test_fileio.py b/Lib/test/test_fileio.py
index 5a0f033ebb8..e3d54f6315a 100644
--- a/Lib/test/test_fileio.py
+++ b/Lib/test/test_fileio.py
@@ -591,7 +591,7 @@ class OtherFileTests:
try:
f.write(b"abc")
f.close()
- with open(TESTFN_ASCII, "rb") as f:
+ with self.open(TESTFN_ASCII, "rb") as f:
self.assertEqual(f.read(), b"abc")
finally:
os.unlink(TESTFN_ASCII)
@@ -608,7 +608,7 @@ class OtherFileTests:
try:
f.write(b"abc")
f.close()
- with open(TESTFN_UNICODE, "rb") as f:
+ with self.open(TESTFN_UNICODE, "rb") as f:
self.assertEqual(f.read(), b"abc")
finally:
os.unlink(TESTFN_UNICODE)
@@ -692,13 +692,13 @@ class OtherFileTests:
def testAppend(self):
try:
- f = open(TESTFN, 'wb')
+ f = self.FileIO(TESTFN, 'wb')
f.write(b'spam')
f.close()
- f = open(TESTFN, 'ab')
+ f = self.FileIO(TESTFN, 'ab')
f.write(b'eggs')
f.close()
- f = open(TESTFN, 'rb')
+ f = self.FileIO(TESTFN, 'rb')
d = f.read()
f.close()
self.assertEqual(d, b'spameggs')
@@ -734,6 +734,7 @@ class OtherFileTests:
class COtherFileTests(OtherFileTests, unittest.TestCase):
FileIO = _io.FileIO
modulename = '_io'
+ open = _io.open
@cpython_only
def testInvalidFd_overflow(self):
@@ -755,6 +756,7 @@ class COtherFileTests(OtherFileTests, unittest.TestCase):
class PyOtherFileTests(OtherFileTests, unittest.TestCase):
FileIO = _pyio.FileIO
modulename = '_pyio'
+ open = _pyio.open
def test_open_code(self):
# Check that the default behaviour of open_code matches
diff --git a/Lib/test/test_float.py b/Lib/test/test_float.py
index 237d7b5d35e..00518abcb11 100644
--- a/Lib/test/test_float.py
+++ b/Lib/test/test_float.py
@@ -795,6 +795,8 @@ class FormatTestCase(unittest.TestCase):
self.assertRaises(ValueError, format, x, '.6,n')
@support.requires_IEEE_754
+ @unittest.skipUnless(sys.float_repr_style == 'short',
+ "applies only when using short float repr style")
def test_format_testfile(self):
with open(format_testfile, encoding="utf-8") as testfile:
for line in testfile:
diff --git a/Lib/test/test_fnmatch.py b/Lib/test/test_fnmatch.py
index d4163cfe782..5daaf3b3fdd 100644
--- a/Lib/test/test_fnmatch.py
+++ b/Lib/test/test_fnmatch.py
@@ -218,24 +218,24 @@ class TranslateTestCase(unittest.TestCase):
def test_translate(self):
import re
- self.assertEqual(translate('*'), r'(?s:.*)\Z')
- self.assertEqual(translate('?'), r'(?s:.)\Z')
- self.assertEqual(translate('a?b*'), r'(?s:a.b.*)\Z')
- self.assertEqual(translate('[abc]'), r'(?s:[abc])\Z')
- self.assertEqual(translate('[]]'), r'(?s:[]])\Z')
- self.assertEqual(translate('[!x]'), r'(?s:[^x])\Z')
- self.assertEqual(translate('[^x]'), r'(?s:[\^x])\Z')
- self.assertEqual(translate('[x'), r'(?s:\[x)\Z')
+ self.assertEqual(translate('*'), r'(?s:.*)\z')
+ self.assertEqual(translate('?'), r'(?s:.)\z')
+ self.assertEqual(translate('a?b*'), r'(?s:a.b.*)\z')
+ self.assertEqual(translate('[abc]'), r'(?s:[abc])\z')
+ self.assertEqual(translate('[]]'), r'(?s:[]])\z')
+ self.assertEqual(translate('[!x]'), r'(?s:[^x])\z')
+ self.assertEqual(translate('[^x]'), r'(?s:[\^x])\z')
+ self.assertEqual(translate('[x'), r'(?s:\[x)\z')
# from the docs
- self.assertEqual(translate('*.txt'), r'(?s:.*\.txt)\Z')
+ self.assertEqual(translate('*.txt'), r'(?s:.*\.txt)\z')
# squash consecutive stars
- self.assertEqual(translate('*********'), r'(?s:.*)\Z')
- self.assertEqual(translate('A*********'), r'(?s:A.*)\Z')
- self.assertEqual(translate('*********A'), r'(?s:.*A)\Z')
- self.assertEqual(translate('A*********?[?]?'), r'(?s:A.*.[?].)\Z')
+ self.assertEqual(translate('*********'), r'(?s:.*)\z')
+ self.assertEqual(translate('A*********'), r'(?s:A.*)\z')
+ self.assertEqual(translate('*********A'), r'(?s:.*A)\z')
+ self.assertEqual(translate('A*********?[?]?'), r'(?s:A.*.[?].)\z')
# fancy translation to prevent exponential-time match failure
t = translate('**a*a****a')
- self.assertEqual(t, r'(?s:(?>.*?a)(?>.*?a).*a)\Z')
+ self.assertEqual(t, r'(?s:(?>.*?a)(?>.*?a).*a)\z')
# and try pasting multiple translate results - it's an undocumented
# feature that this works
r1 = translate('**a**a**a*')
@@ -249,27 +249,27 @@ class TranslateTestCase(unittest.TestCase):
def test_translate_wildcards(self):
for pattern, expect in [
- ('ab*', r'(?s:ab.*)\Z'),
- ('ab*cd', r'(?s:ab.*cd)\Z'),
- ('ab*cd*', r'(?s:ab(?>.*?cd).*)\Z'),
- ('ab*cd*12', r'(?s:ab(?>.*?cd).*12)\Z'),
- ('ab*cd*12*', r'(?s:ab(?>.*?cd)(?>.*?12).*)\Z'),
- ('ab*cd*12*34', r'(?s:ab(?>.*?cd)(?>.*?12).*34)\Z'),
- ('ab*cd*12*34*', r'(?s:ab(?>.*?cd)(?>.*?12)(?>.*?34).*)\Z'),
+ ('ab*', r'(?s:ab.*)\z'),
+ ('ab*cd', r'(?s:ab.*cd)\z'),
+ ('ab*cd*', r'(?s:ab(?>.*?cd).*)\z'),
+ ('ab*cd*12', r'(?s:ab(?>.*?cd).*12)\z'),
+ ('ab*cd*12*', r'(?s:ab(?>.*?cd)(?>.*?12).*)\z'),
+ ('ab*cd*12*34', r'(?s:ab(?>.*?cd)(?>.*?12).*34)\z'),
+ ('ab*cd*12*34*', r'(?s:ab(?>.*?cd)(?>.*?12)(?>.*?34).*)\z'),
]:
with self.subTest(pattern):
translated = translate(pattern)
self.assertEqual(translated, expect, pattern)
for pattern, expect in [
- ('*ab', r'(?s:.*ab)\Z'),
- ('*ab*', r'(?s:(?>.*?ab).*)\Z'),
- ('*ab*cd', r'(?s:(?>.*?ab).*cd)\Z'),
- ('*ab*cd*', r'(?s:(?>.*?ab)(?>.*?cd).*)\Z'),
- ('*ab*cd*12', r'(?s:(?>.*?ab)(?>.*?cd).*12)\Z'),
- ('*ab*cd*12*', r'(?s:(?>.*?ab)(?>.*?cd)(?>.*?12).*)\Z'),
- ('*ab*cd*12*34', r'(?s:(?>.*?ab)(?>.*?cd)(?>.*?12).*34)\Z'),
- ('*ab*cd*12*34*', r'(?s:(?>.*?ab)(?>.*?cd)(?>.*?12)(?>.*?34).*)\Z'),
+ ('*ab', r'(?s:.*ab)\z'),
+ ('*ab*', r'(?s:(?>.*?ab).*)\z'),
+ ('*ab*cd', r'(?s:(?>.*?ab).*cd)\z'),
+ ('*ab*cd*', r'(?s:(?>.*?ab)(?>.*?cd).*)\z'),
+ ('*ab*cd*12', r'(?s:(?>.*?ab)(?>.*?cd).*12)\z'),
+ ('*ab*cd*12*', r'(?s:(?>.*?ab)(?>.*?cd)(?>.*?12).*)\z'),
+ ('*ab*cd*12*34', r'(?s:(?>.*?ab)(?>.*?cd)(?>.*?12).*34)\z'),
+ ('*ab*cd*12*34*', r'(?s:(?>.*?ab)(?>.*?cd)(?>.*?12)(?>.*?34).*)\z'),
]:
with self.subTest(pattern):
translated = translate(pattern)
@@ -277,28 +277,28 @@ class TranslateTestCase(unittest.TestCase):
def test_translate_expressions(self):
for pattern, expect in [
- ('[', r'(?s:\[)\Z'),
- ('[!', r'(?s:\[!)\Z'),
- ('[]', r'(?s:\[\])\Z'),
- ('[abc', r'(?s:\[abc)\Z'),
- ('[!abc', r'(?s:\[!abc)\Z'),
- ('[abc]', r'(?s:[abc])\Z'),
- ('[!abc]', r'(?s:[^abc])\Z'),
- ('[!abc][!def]', r'(?s:[^abc][^def])\Z'),
+ ('[', r'(?s:\[)\z'),
+ ('[!', r'(?s:\[!)\z'),
+ ('[]', r'(?s:\[\])\z'),
+ ('[abc', r'(?s:\[abc)\z'),
+ ('[!abc', r'(?s:\[!abc)\z'),
+ ('[abc]', r'(?s:[abc])\z'),
+ ('[!abc]', r'(?s:[^abc])\z'),
+ ('[!abc][!def]', r'(?s:[^abc][^def])\z'),
# with [[
- ('[[', r'(?s:\[\[)\Z'),
- ('[[a', r'(?s:\[\[a)\Z'),
- ('[[]', r'(?s:[\[])\Z'),
- ('[[]a', r'(?s:[\[]a)\Z'),
- ('[[]]', r'(?s:[\[]\])\Z'),
- ('[[]a]', r'(?s:[\[]a\])\Z'),
- ('[[a]', r'(?s:[\[a])\Z'),
- ('[[a]]', r'(?s:[\[a]\])\Z'),
- ('[[a]b', r'(?s:[\[a]b)\Z'),
+ ('[[', r'(?s:\[\[)\z'),
+ ('[[a', r'(?s:\[\[a)\z'),
+ ('[[]', r'(?s:[\[])\z'),
+ ('[[]a', r'(?s:[\[]a)\z'),
+ ('[[]]', r'(?s:[\[]\])\z'),
+ ('[[]a]', r'(?s:[\[]a\])\z'),
+ ('[[a]', r'(?s:[\[a])\z'),
+ ('[[a]]', r'(?s:[\[a]\])\z'),
+ ('[[a]b', r'(?s:[\[a]b)\z'),
# backslashes
- ('[\\', r'(?s:\[\\)\Z'),
- (r'[\]', r'(?s:[\\])\Z'),
- (r'[\\]', r'(?s:[\\\\])\Z'),
+ ('[\\', r'(?s:\[\\)\z'),
+ (r'[\]', r'(?s:[\\])\z'),
+ (r'[\\]', r'(?s:[\\\\])\z'),
]:
with self.subTest(pattern):
translated = translate(pattern)
diff --git a/Lib/test/test_format.py b/Lib/test/test_format.py
index c7cc32e0949..1f626d87fa6 100644
--- a/Lib/test/test_format.py
+++ b/Lib/test/test_format.py
@@ -346,12 +346,12 @@ class FormatTest(unittest.TestCase):
testcommon(b"%s", memoryview(b"abc"), b"abc")
# %a will give the equivalent of
# repr(some_obj).encode('ascii', 'backslashreplace')
- testcommon(b"%a", 3.14, b"3.14")
+ testcommon(b"%a", 3.25, b"3.25")
testcommon(b"%a", b"ghi", b"b'ghi'")
testcommon(b"%a", "jkl", b"'jkl'")
testcommon(b"%a", "\u0544", b"'\\u0544'")
# %r is an alias for %a
- testcommon(b"%r", 3.14, b"3.14")
+ testcommon(b"%r", 3.25, b"3.25")
testcommon(b"%r", b"ghi", b"b'ghi'")
testcommon(b"%r", "jkl", b"'jkl'")
testcommon(b"%r", "\u0544", b"'\\u0544'")
@@ -407,19 +407,19 @@ class FormatTest(unittest.TestCase):
self.assertEqual(format("abc", "\u2007<5"), "abc\u2007\u2007")
self.assertEqual(format(123, "\u2007<5"), "123\u2007\u2007")
- self.assertEqual(format(12.3, "\u2007<6"), "12.3\u2007\u2007")
+ self.assertEqual(format(12.5, "\u2007<6"), "12.5\u2007\u2007")
self.assertEqual(format(0j, "\u2007<4"), "0j\u2007\u2007")
self.assertEqual(format(1+2j, "\u2007<8"), "(1+2j)\u2007\u2007")
self.assertEqual(format("abc", "\u2007>5"), "\u2007\u2007abc")
self.assertEqual(format(123, "\u2007>5"), "\u2007\u2007123")
- self.assertEqual(format(12.3, "\u2007>6"), "\u2007\u200712.3")
+ self.assertEqual(format(12.5, "\u2007>6"), "\u2007\u200712.5")
self.assertEqual(format(1+2j, "\u2007>8"), "\u2007\u2007(1+2j)")
self.assertEqual(format(0j, "\u2007>4"), "\u2007\u20070j")
self.assertEqual(format("abc", "\u2007^5"), "\u2007abc\u2007")
self.assertEqual(format(123, "\u2007^5"), "\u2007123\u2007")
- self.assertEqual(format(12.3, "\u2007^6"), "\u200712.3\u2007")
+ self.assertEqual(format(12.5, "\u2007^6"), "\u200712.5\u2007")
self.assertEqual(format(1+2j, "\u2007^8"), "\u2007(1+2j)\u2007")
self.assertEqual(format(0j, "\u2007^4"), "\u20070j\u2007")
diff --git a/Lib/test/test_fractions.py b/Lib/test/test_fractions.py
index 84faa636064..d1d2739856c 100644
--- a/Lib/test/test_fractions.py
+++ b/Lib/test/test_fractions.py
@@ -1,7 +1,7 @@
"""Tests for Lib/fractions.py."""
from decimal import Decimal
-from test.support import requires_IEEE_754
+from test.support import requires_IEEE_754, adjust_int_max_str_digits
import math
import numbers
import operator
@@ -395,12 +395,14 @@ class FractionTest(unittest.TestCase):
def testFromString(self):
self.assertEqual((5, 1), _components(F("5")))
+ self.assertEqual((5, 1), _components(F("005")))
self.assertEqual((3, 2), _components(F("3/2")))
self.assertEqual((3, 2), _components(F("3 / 2")))
self.assertEqual((3, 2), _components(F(" \n +3/2")))
self.assertEqual((-3, 2), _components(F("-3/2 ")))
- self.assertEqual((13, 2), _components(F(" 013/02 \n ")))
+ self.assertEqual((13, 2), _components(F(" 0013/002 \n ")))
self.assertEqual((16, 5), _components(F(" 3.2 ")))
+ self.assertEqual((16, 5), _components(F("003.2")))
self.assertEqual((-16, 5), _components(F(" -3.2 ")))
self.assertEqual((-3, 1), _components(F(" -3. ")))
self.assertEqual((3, 5), _components(F(" .6 ")))
@@ -419,116 +421,102 @@ class FractionTest(unittest.TestCase):
self.assertRaisesMessage(
ZeroDivisionError, "Fraction(3, 0)",
F, "3/0")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '3/'",
- F, "3/")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '/2'",
- F, "/2")
- self.assertRaisesMessage(
- # Denominators don't need a sign.
- ValueError, "Invalid literal for Fraction: '3/+2'",
- F, "3/+2")
- self.assertRaisesMessage(
- # Imitate float's parsing.
- ValueError, "Invalid literal for Fraction: '+ 3/2'",
- F, "+ 3/2")
- self.assertRaisesMessage(
- # Avoid treating '.' as a regex special character.
- ValueError, "Invalid literal for Fraction: '3a2'",
- F, "3a2")
- self.assertRaisesMessage(
- # Don't accept combinations of decimals and rationals.
- ValueError, "Invalid literal for Fraction: '3/7.2'",
- F, "3/7.2")
- self.assertRaisesMessage(
- # Don't accept combinations of decimals and rationals.
- ValueError, "Invalid literal for Fraction: '3.2/7'",
- F, "3.2/7")
- self.assertRaisesMessage(
- # Allow 3. and .3, but not .
- ValueError, "Invalid literal for Fraction: '.'",
- F, ".")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '_'",
- F, "_")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '_1'",
- F, "_1")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1__2'",
- F, "1__2")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '/_'",
- F, "/_")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1_/'",
- F, "1_/")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '_1/'",
- F, "_1/")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1__2/'",
- F, "1__2/")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1/_'",
- F, "1/_")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1/_1'",
- F, "1/_1")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1/1__2'",
- F, "1/1__2")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1._111'",
- F, "1._111")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1.1__1'",
- F, "1.1__1")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1.1e+_1'",
- F, "1.1e+_1")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1.1e+1__1'",
- F, "1.1e+1__1")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '123.dd'",
- F, "123.dd")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '123.5_dd'",
- F, "123.5_dd")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: 'dd.5'",
- F, "dd.5")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '7_dd'",
- F, "7_dd")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1/dd'",
- F, "1/dd")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1/123_dd'",
- F, "1/123_dd")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '789edd'",
- F, "789edd")
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '789e2_dd'",
- F, "789e2_dd")
+
+ def check_invalid(s):
+ msg = "Invalid literal for Fraction: " + repr(s)
+ self.assertRaisesMessage(ValueError, msg, F, s)
+
+ check_invalid("3/")
+ check_invalid("/2")
+ # Denominators don't need a sign.
+ check_invalid("3/+2")
+ check_invalid("3/-2")
+ # Imitate float's parsing.
+ check_invalid("+ 3/2")
+ check_invalid("- 3/2")
+ # Avoid treating '.' as a regex special character.
+ check_invalid("3a2")
+ # Don't accept combinations of decimals and rationals.
+ check_invalid("3/7.2")
+ check_invalid("3.2/7")
+ # No space around dot.
+ check_invalid("3 .2")
+ check_invalid("3. 2")
+ # No space around e.
+ check_invalid("3.2 e1")
+ check_invalid("3.2e 1")
+        # Fractional part doesn't need a sign.
+ check_invalid("3.+2")
+ check_invalid("3.-2")
+ # Only accept base 10.
+ check_invalid("0x10")
+ check_invalid("0x10/1")
+ check_invalid("1/0x10")
+ check_invalid("0x10.")
+ check_invalid("0x10.1")
+ check_invalid("1.0x10")
+ check_invalid("1.0e0x10")
+ # Only accept decimal digits.
+ check_invalid("³")
+ check_invalid("³/2")
+ check_invalid("3/²")
+ check_invalid("³.2")
+ check_invalid("3.²")
+ check_invalid("3.2e²")
+ check_invalid("¼")
+ # Allow 3. and .3, but not .
+ check_invalid(".")
+ check_invalid("_")
+ check_invalid("_1")
+ check_invalid("1__2")
+ check_invalid("/_")
+ check_invalid("1_/")
+ check_invalid("_1/")
+ check_invalid("1__2/")
+ check_invalid("1/_")
+ check_invalid("1/_1")
+ check_invalid("1/1__2")
+ check_invalid("1._111")
+ check_invalid("1.1__1")
+ check_invalid("1.1e+_1")
+ check_invalid("1.1e+1__1")
+ check_invalid("123.dd")
+ check_invalid("123.5_dd")
+ check_invalid("dd.5")
+ check_invalid("7_dd")
+ check_invalid("1/dd")
+ check_invalid("1/123_dd")
+ check_invalid("789edd")
+ check_invalid("789e2_dd")
# Test catastrophic backtracking.
val = "9"*50 + "_"
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '" + val + "'",
- F, val)
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1/" + val + "'",
- F, "1/" + val)
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1." + val + "'",
- F, "1." + val)
- self.assertRaisesMessage(
- ValueError, "Invalid literal for Fraction: '1.1+e" + val + "'",
- F, "1.1+e" + val)
+ check_invalid(val)
+ check_invalid("1/" + val)
+ check_invalid("1." + val)
+ check_invalid("." + val)
+ check_invalid("1.1+e" + val)
+ check_invalid("1.1e" + val)
+
+ def test_limit_int(self):
+ maxdigits = 5000
+ with adjust_int_max_str_digits(maxdigits):
+ msg = 'Exceeds the limit'
+ val = '1' * maxdigits
+ num = (10**maxdigits - 1)//9
+ self.assertEqual((num, 1), _components(F(val)))
+ self.assertRaisesRegex(ValueError, msg, F, val + '1')
+ self.assertEqual((num, 2), _components(F(val + '/2')))
+ self.assertRaisesRegex(ValueError, msg, F, val + '1/2')
+ self.assertEqual((1, num), _components(F('1/' + val)))
+ self.assertRaisesRegex(ValueError, msg, F, '1/1' + val)
+ self.assertEqual(((10**(maxdigits+1) - 1)//9, 10**maxdigits),
+ _components(F('1.' + val)))
+ self.assertRaisesRegex(ValueError, msg, F, '1.1' + val)
+ self.assertEqual((num, 10**maxdigits), _components(F('.' + val)))
+ self.assertRaisesRegex(ValueError, msg, F, '.1' + val)
+ self.assertRaisesRegex(ValueError, msg, F, '1.1e1' + val)
+ self.assertEqual((11, 10), _components(F('1.1e' + '0' * maxdigits)))
+ self.assertRaisesRegex(ValueError, msg, F, '1.1e' + '0' * (maxdigits+1))
def testImmutable(self):
r = F(7, 3)
@@ -1530,6 +1518,8 @@ class FractionTest(unittest.TestCase):
(F(51, 1000), '.1f', '0.1'),
(F(149, 1000), '.1f', '0.1'),
(F(151, 1000), '.1f', '0.2'),
+ (F(22, 7), '.02f', '3.14'), # issue gh-130662
+ (F(22, 7), '005.02f', '03.14'),
]
for fraction, spec, expected in testcases:
with self.subTest(fraction=fraction, spec=spec):
@@ -1628,12 +1618,6 @@ class FractionTest(unittest.TestCase):
'=010%',
'>00.2f',
'>00f',
- # Too many zeros - minimum width should not have leading zeros
- '006f',
- # Leading zeros in precision
- '.010f',
- '.02f',
- '.000f',
# Missing precision
'.e',
'.f',
diff --git a/Lib/test/test_free_threading/test_dict.py b/Lib/test/test_free_threading/test_dict.py
index 476cc3178d8..5d5d4e226ca 100644
--- a/Lib/test/test_free_threading/test_dict.py
+++ b/Lib/test/test_free_threading/test_dict.py
@@ -228,6 +228,22 @@ class TestDict(TestCase):
self.assertEqual(count, 0)
+ def test_racing_object_get_set_dict(self):
+ e = Exception()
+
+ def writer():
+ for i in range(10000):
+ e.__dict__ = {1:2}
+
+ def reader():
+ for i in range(10000):
+ e.__dict__
+
+ t1 = Thread(target=writer)
+ t2 = Thread(target=reader)
+
+ with threading_helper.start_threads([t1, t2]):
+ pass
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_free_threading/test_functools.py b/Lib/test/test_free_threading/test_functools.py
new file mode 100644
index 00000000000..a442fe056ce
--- /dev/null
+++ b/Lib/test/test_free_threading/test_functools.py
@@ -0,0 +1,75 @@
+import random
+import unittest
+
+from functools import lru_cache
+from threading import Barrier, Thread
+
+from test.support import threading_helper
+
+@threading_helper.requires_working_threading()
+class TestLRUCache(unittest.TestCase):
+
+ def _test_concurrent_operations(self, maxsize):
+ num_threads = 10
+ b = Barrier(num_threads)
+ @lru_cache(maxsize=maxsize)
+ def func(arg=0):
+ return object()
+
+
+ def thread_func():
+ b.wait()
+ for i in range(1000):
+ r = random.randint(0, 1000)
+ if i < 800:
+ func(i)
+ elif i < 900:
+ func.cache_info()
+ else:
+ func.cache_clear()
+
+ threads = []
+ for i in range(num_threads):
+ t = Thread(target=thread_func)
+ threads.append(t)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ def test_concurrent_operations_unbounded(self):
+ self._test_concurrent_operations(maxsize=None)
+
+ def test_concurrent_operations_bounded(self):
+ self._test_concurrent_operations(maxsize=128)
+
+ def _test_reentrant_cache_clear(self, maxsize):
+ num_threads = 10
+ b = Barrier(num_threads)
+ @lru_cache(maxsize=maxsize)
+ def func(arg=0):
+ func.cache_clear()
+ return object()
+
+
+ def thread_func():
+ b.wait()
+ for i in range(1000):
+ func(random.randint(0, 10000))
+
+ threads = []
+ for i in range(num_threads):
+ t = Thread(target=thread_func)
+ threads.append(t)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ def test_reentrant_cache_clear_unbounded(self):
+ self._test_reentrant_cache_clear(maxsize=None)
+
+ def test_reentrant_cache_clear_bounded(self):
+ self._test_reentrant_cache_clear(maxsize=128)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/test/test_free_threading/test_generators.py b/Lib/test/test_free_threading/test_generators.py
new file mode 100644
index 00000000000..d01675eb38b
--- /dev/null
+++ b/Lib/test/test_free_threading/test_generators.py
@@ -0,0 +1,51 @@
+import concurrent.futures
+import unittest
+from threading import Barrier
+from unittest import TestCase
+import random
+import time
+
+from test.support import threading_helper, Py_GIL_DISABLED
+
+threading_helper.requires_working_threading(module=True)
+
+
+def random_sleep():
+ delay_us = random.randint(50, 100)
+ time.sleep(delay_us * 1e-6)
+
+def random_string():
+ return ''.join(random.choice('0123456789ABCDEF') for _ in range(10))
+
+def set_gen_name(g, b):
+ b.wait()
+ random_sleep()
+ g.__name__ = random_string()
+ return g.__name__
+
+def set_gen_qualname(g, b):
+ b.wait()
+ random_sleep()
+ g.__qualname__ = random_string()
+ return g.__qualname__
+
+
+@unittest.skipUnless(Py_GIL_DISABLED, "Enable only in FT build")
+class TestFTGenerators(TestCase):
+ NUM_THREADS = 4
+
+ def concurrent_write_with_func(self, func):
+ gen = (x for x in range(42))
+ for j in range(1000):
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.NUM_THREADS) as executor:
+ b = Barrier(self.NUM_THREADS)
+ futures = {executor.submit(func, gen, b): i for i in range(self.NUM_THREADS)}
+ for fut in concurrent.futures.as_completed(futures):
+ gen_name = fut.result()
+ self.assertEqual(len(gen_name), 10)
+
+ def test_concurrent_write(self):
+ with self.subTest(func=set_gen_name):
+ self.concurrent_write_with_func(func=set_gen_name)
+ with self.subTest(func=set_gen_qualname):
+ self.concurrent_write_with_func(func=set_gen_qualname)
diff --git a/Lib/test/test_free_threading/test_heapq.py b/Lib/test/test_free_threading/test_heapq.py
new file mode 100644
index 00000000000..ee7adfb2b78
--- /dev/null
+++ b/Lib/test/test_free_threading/test_heapq.py
@@ -0,0 +1,267 @@
+import unittest
+
+import heapq
+
+from enum import Enum
+from threading import Thread, Barrier, Lock
+from random import shuffle, randint
+
+from test.support import threading_helper
+from test import test_heapq
+
+
+NTHREADS = 10
+OBJECT_COUNT = 5_000
+
+
+class Heap(Enum):
+ MIN = 1
+ MAX = 2
+
+
+@threading_helper.requires_working_threading()
+class TestHeapq(unittest.TestCase):
+ def setUp(self):
+ self.test_heapq = test_heapq.TestHeapPython()
+
+ def test_racing_heapify(self):
+ heap = list(range(OBJECT_COUNT))
+ shuffle(heap)
+
+ self.run_concurrently(
+ worker_func=heapq.heapify, args=(heap,), nthreads=NTHREADS
+ )
+ self.test_heapq.check_invariant(heap)
+
+ def test_racing_heappush(self):
+ heap = []
+
+ def heappush_func(heap):
+ for item in reversed(range(OBJECT_COUNT)):
+ heapq.heappush(heap, item)
+
+ self.run_concurrently(
+ worker_func=heappush_func, args=(heap,), nthreads=NTHREADS
+ )
+ self.test_heapq.check_invariant(heap)
+
+ def test_racing_heappop(self):
+ heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
+
+ # Each thread pops (OBJECT_COUNT / NTHREADS) items
+ self.assertEqual(OBJECT_COUNT % NTHREADS, 0)
+ per_thread_pop_count = OBJECT_COUNT // NTHREADS
+
+ def heappop_func(heap, pop_count):
+ local_list = []
+ for _ in range(pop_count):
+ item = heapq.heappop(heap)
+ local_list.append(item)
+
+ # Each local list should be sorted
+ self.assertTrue(self.is_sorted_ascending(local_list))
+
+ self.run_concurrently(
+ worker_func=heappop_func,
+ args=(heap, per_thread_pop_count),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(heap), 0)
+
+ def test_racing_heappushpop(self):
+ heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
+ pushpop_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
+
+ def heappushpop_func(heap, pushpop_items):
+ for item in pushpop_items:
+ popped_item = heapq.heappushpop(heap, item)
+ self.assertTrue(popped_item <= item)
+
+ self.run_concurrently(
+ worker_func=heappushpop_func,
+ args=(heap, pushpop_items),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(heap), OBJECT_COUNT)
+ self.test_heapq.check_invariant(heap)
+
+ def test_racing_heapreplace(self):
+ heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
+ replace_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
+
+ def heapreplace_func(heap, replace_items):
+ for item in replace_items:
+ heapq.heapreplace(heap, item)
+
+ self.run_concurrently(
+ worker_func=heapreplace_func,
+ args=(heap, replace_items),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(heap), OBJECT_COUNT)
+ self.test_heapq.check_invariant(heap)
+
+ def test_racing_heapify_max(self):
+ max_heap = list(range(OBJECT_COUNT))
+ shuffle(max_heap)
+
+ self.run_concurrently(
+ worker_func=heapq.heapify_max, args=(max_heap,), nthreads=NTHREADS
+ )
+ self.test_heapq.check_max_invariant(max_heap)
+
+ def test_racing_heappush_max(self):
+ max_heap = []
+
+ def heappush_max_func(max_heap):
+ for item in range(OBJECT_COUNT):
+ heapq.heappush_max(max_heap, item)
+
+ self.run_concurrently(
+ worker_func=heappush_max_func, args=(max_heap,), nthreads=NTHREADS
+ )
+ self.test_heapq.check_max_invariant(max_heap)
+
+ def test_racing_heappop_max(self):
+ max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
+
+ # Each thread pops (OBJECT_COUNT / NTHREADS) items
+ self.assertEqual(OBJECT_COUNT % NTHREADS, 0)
+ per_thread_pop_count = OBJECT_COUNT // NTHREADS
+
+ def heappop_max_func(max_heap, pop_count):
+ local_list = []
+ for _ in range(pop_count):
+ item = heapq.heappop_max(max_heap)
+ local_list.append(item)
+
+ # Each local list should be sorted
+ self.assertTrue(self.is_sorted_descending(local_list))
+
+ self.run_concurrently(
+ worker_func=heappop_max_func,
+ args=(max_heap, per_thread_pop_count),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(max_heap), 0)
+
+ def test_racing_heappushpop_max(self):
+ max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
+ pushpop_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
+
+ def heappushpop_max_func(max_heap, pushpop_items):
+ for item in pushpop_items:
+ popped_item = heapq.heappushpop_max(max_heap, item)
+ self.assertTrue(popped_item >= item)
+
+ self.run_concurrently(
+ worker_func=heappushpop_max_func,
+ args=(max_heap, pushpop_items),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(max_heap), OBJECT_COUNT)
+ self.test_heapq.check_max_invariant(max_heap)
+
+ def test_racing_heapreplace_max(self):
+ max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
+ replace_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)
+
+ def heapreplace_max_func(max_heap, replace_items):
+ for item in replace_items:
+ heapq.heapreplace_max(max_heap, item)
+
+ self.run_concurrently(
+ worker_func=heapreplace_max_func,
+ args=(max_heap, replace_items),
+ nthreads=NTHREADS,
+ )
+ self.assertEqual(len(max_heap), OBJECT_COUNT)
+ self.test_heapq.check_max_invariant(max_heap)
+
+ def test_lock_free_list_read(self):
+ n, n_threads = 1_000, 10
+ l = []
+ barrier = Barrier(n_threads * 2)
+
+ count = 0
+ lock = Lock()
+
+ def worker():
+ with lock:
+ nonlocal count
+ x = count
+ count += 1
+
+ barrier.wait()
+ for i in range(n):
+ if x % 2:
+ heapq.heappush(l, 1)
+ heapq.heappop(l)
+ else:
+ try:
+ l[0]
+ except IndexError:
+ pass
+
+ self.run_concurrently(worker, (), n_threads * 2)
+
+ @staticmethod
+ def is_sorted_ascending(lst):
+ """
+ Check if the list is sorted in ascending order (non-decreasing).
+ """
+ return all(lst[i - 1] <= lst[i] for i in range(1, len(lst)))
+
+ @staticmethod
+ def is_sorted_descending(lst):
+ """
+ Check if the list is sorted in descending order (non-increasing).
+ """
+ return all(lst[i - 1] >= lst[i] for i in range(1, len(lst)))
+
+ @staticmethod
+ def create_heap(size, heap_kind):
+ """
+ Create a min/max heap where elements are in the range (0, size - 1) and
+ shuffled before heapify.
+ """
+ heap = list(range(size))
+ shuffle(heap)
+ if heap_kind == Heap.MIN:
+ heapq.heapify(heap)
+ else:
+ heapq.heapify_max(heap)
+
+ return heap
+
+ @staticmethod
+ def create_random_list(a, b, size):
+ """
+ Create a list of random numbers between a and b (inclusive).
+ """
+ return [randint(a, b) for _ in range(size)]
+
+ def run_concurrently(self, worker_func, args, nthreads):
+ """
+ Run the worker function concurrently in multiple threads.
+ """
+ barrier = Barrier(nthreads)
+
+ def wrapper_func(*args):
+ # Wait for all threads to reach this point before proceeding.
+ barrier.wait()
+ worker_func(*args)
+
+ with threading_helper.catch_threading_exception() as cm:
+ workers = (
+ Thread(target=wrapper_func, args=args) for _ in range(nthreads)
+ )
+ with threading_helper.start_threads(workers):
+ pass
+
+ # Worker threads should not raise any exceptions
+ self.assertIsNone(cm.exc_value)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/test/test_free_threading/test_io.py b/Lib/test/test_free_threading/test_io.py
new file mode 100644
index 00000000000..f9bec740ddf
--- /dev/null
+++ b/Lib/test/test_free_threading/test_io.py
@@ -0,0 +1,109 @@
+import threading
+from unittest import TestCase
+from test.support import threading_helper
+from random import randint
+from io import BytesIO
+from sys import getsizeof
+
+
+class TestBytesIO(TestCase):
+ # Test pretty much everything that can break under free-threading.
+ # Non-deterministic, but at least one of these things will fail if
+ # BytesIO object is not free-thread safe.
+
+ def check(self, funcs, *args):
+ barrier = threading.Barrier(len(funcs))
+ threads = []
+
+ for func in funcs:
+ thread = threading.Thread(target=func, args=(barrier, *args))
+
+ threads.append(thread)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ @threading_helper.requires_working_threading()
+ @threading_helper.reap_threads
+ def test_free_threading(self):
+ """Test for segfaults and aborts."""
+
+ def write(barrier, b, *ignore):
+ barrier.wait()
+ try: b.write(b'0' * randint(100, 1000))
+ except ValueError: pass # ignore write fail to closed file
+
+ def writelines(barrier, b, *ignore):
+ barrier.wait()
+ b.write(b'0\n' * randint(100, 1000))
+
+ def truncate(barrier, b, *ignore):
+ barrier.wait()
+ try: b.truncate(0)
+ except: BufferError # ignore exported buffer
+
+ def read(barrier, b, *ignore):
+ barrier.wait()
+ b.read()
+
+ def read1(barrier, b, *ignore):
+ barrier.wait()
+ b.read1()
+
+ def readline(barrier, b, *ignore):
+ barrier.wait()
+ b.readline()
+
+ def readlines(barrier, b, *ignore):
+ barrier.wait()
+ b.readlines()
+
+ def readinto(barrier, b, into, *ignore):
+ barrier.wait()
+ b.readinto(into)
+
+ def close(barrier, b, *ignore):
+ barrier.wait()
+ b.close()
+
+ def getvalue(barrier, b, *ignore):
+ barrier.wait()
+ b.getvalue()
+
+ def getbuffer(barrier, b, *ignore):
+ barrier.wait()
+ b.getbuffer()
+
+ def iter(barrier, b, *ignore):
+ barrier.wait()
+ list(b)
+
+ def getstate(barrier, b, *ignore):
+ barrier.wait()
+ b.__getstate__()
+
+ def setstate(barrier, b, st, *ignore):
+ barrier.wait()
+ b.__setstate__(st)
+
+ def sizeof(barrier, b, *ignore):
+ barrier.wait()
+ getsizeof(b)
+
+ self.check([write] * 10, BytesIO())
+ self.check([writelines] * 10, BytesIO())
+ self.check([write] * 10 + [truncate] * 10, BytesIO())
+ self.check([truncate] + [read] * 10, BytesIO(b'0\n'*204800))
+ self.check([truncate] + [read1] * 10, BytesIO(b'0\n'*204800))
+ self.check([truncate] + [readline] * 10, BytesIO(b'0\n'*20480))
+ self.check([truncate] + [readlines] * 10, BytesIO(b'0\n'*20480))
+ self.check([truncate] + [readinto] * 10, BytesIO(b'0\n'*204800), bytearray(b'0\n'*204800))
+ self.check([close] + [write] * 10, BytesIO())
+ self.check([truncate] + [getvalue] * 10, BytesIO(b'0\n'*204800))
+ self.check([truncate] + [getbuffer] * 10, BytesIO(b'0\n'*204800))
+ self.check([truncate] + [iter] * 10, BytesIO(b'0\n'*20480))
+ self.check([truncate] + [getstate] * 10, BytesIO(b'0\n'*204800))
+ self.check([truncate] + [setstate] * 10, BytesIO(b'0\n'*204800), (b'123', 0, None))
+ self.check([truncate] + [sizeof] * 10, BytesIO(b'0\n'*204800))
+
+ # no tests for seek or tell because they don't break anything
diff --git a/Lib/test/test_free_threading/test_itertools.py b/Lib/test/test_free_threading/test_itertools.py
new file mode 100644
index 00000000000..9d366041917
--- /dev/null
+++ b/Lib/test/test_free_threading/test_itertools.py
@@ -0,0 +1,95 @@
+import unittest
+from threading import Thread, Barrier
+from itertools import batched, chain, cycle
+from test.support import threading_helper
+
+
+threading_helper.requires_working_threading(module=True)
+
+class ItertoolsThreading(unittest.TestCase):
+
+ @threading_helper.reap_threads
+ def test_batched(self):
+ number_of_threads = 10
+ number_of_iterations = 20
+ barrier = Barrier(number_of_threads)
+ def work(it):
+ barrier.wait()
+ while True:
+ try:
+ next(it)
+ except StopIteration:
+ break
+
+ data = tuple(range(1000))
+ for it in range(number_of_iterations):
+ batch_iterator = batched(data, 2)
+ worker_threads = []
+ for ii in range(number_of_threads):
+ worker_threads.append(
+ Thread(target=work, args=[batch_iterator]))
+
+ with threading_helper.start_threads(worker_threads):
+ pass
+
+ barrier.reset()
+
+ @threading_helper.reap_threads
+ def test_cycle(self):
+ number_of_threads = 6
+ number_of_iterations = 10
+ number_of_cycles = 400
+
+ barrier = Barrier(number_of_threads)
+ def work(it):
+ barrier.wait()
+ for _ in range(number_of_cycles):
+ try:
+ next(it)
+ except StopIteration:
+ pass
+
+ data = (1, 2, 3, 4)
+ for it in range(number_of_iterations):
+ cycle_iterator = cycle(data)
+ worker_threads = []
+ for ii in range(number_of_threads):
+ worker_threads.append(
+ Thread(target=work, args=[cycle_iterator]))
+
+ with threading_helper.start_threads(worker_threads):
+ pass
+
+ barrier.reset()
+
+ @threading_helper.reap_threads
+ def test_chain(self):
+ number_of_threads = 6
+ number_of_iterations = 20
+
+ barrier = Barrier(number_of_threads)
+ def work(it):
+ barrier.wait()
+ while True:
+ try:
+ next(it)
+ except StopIteration:
+ break
+
+ data = [(1, )] * 200
+ for it in range(number_of_iterations):
+ chain_iterator = chain(*data)
+ worker_threads = []
+ for ii in range(number_of_threads):
+ worker_threads.append(
+ Thread(target=work, args=[chain_iterator]))
+
+ with threading_helper.start_threads(worker_threads):
+ pass
+
+ barrier.reset()
+
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/test/test_free_threading/test_itertools_batched.py b/Lib/test/test_free_threading/test_itertools_batched.py
deleted file mode 100644
index a754b4f9ea9..00000000000
--- a/Lib/test/test_free_threading/test_itertools_batched.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import unittest
-from threading import Thread, Barrier
-from itertools import batched
-from test.support import threading_helper
-
-
-threading_helper.requires_working_threading(module=True)
-
-class EnumerateThreading(unittest.TestCase):
-
- @threading_helper.reap_threads
- def test_threading(self):
- number_of_threads = 10
- number_of_iterations = 20
- barrier = Barrier(number_of_threads)
- def work(it):
- barrier.wait()
- while True:
- try:
- _ = next(it)
- except StopIteration:
- break
-
- data = tuple(range(1000))
- for it in range(number_of_iterations):
- batch_iterator = batched(data, 2)
- worker_threads = []
- for ii in range(number_of_threads):
- worker_threads.append(
- Thread(target=work, args=[batch_iterator]))
-
- with threading_helper.start_threads(worker_threads):
- pass
-
- barrier.reset()
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/Lib/test/test_free_threading/test_itertools_combinatoric.py b/Lib/test/test_free_threading/test_itertools_combinatoric.py
new file mode 100644
index 00000000000..5b3b88deedd
--- /dev/null
+++ b/Lib/test/test_free_threading/test_itertools_combinatoric.py
@@ -0,0 +1,51 @@
+import unittest
+from threading import Thread, Barrier
+from itertools import combinations, product
+from test.support import threading_helper
+
+
+threading_helper.requires_working_threading(module=True)
+
+def test_concurrent_iteration(iterator, number_of_threads):
+ barrier = Barrier(number_of_threads)
+ def iterator_worker(it):
+ barrier.wait()
+ while True:
+ try:
+ _ = next(it)
+ except StopIteration:
+ return
+
+ worker_threads = []
+ for ii in range(number_of_threads):
+ worker_threads.append(
+ Thread(target=iterator_worker, args=[iterator]))
+
+ with threading_helper.start_threads(worker_threads):
+ pass
+
+ barrier.reset()
+
+class ItertoolsThreading(unittest.TestCase):
+
+ @threading_helper.reap_threads
+ def test_combinations(self):
+ number_of_threads = 10
+ number_of_iterations = 24
+
+ for it in range(number_of_iterations):
+ iterator = combinations((1, 2, 3, 4, 5), 2)
+ test_concurrent_iteration(iterator, number_of_threads)
+
+ @threading_helper.reap_threads
+ def test_product(self):
+ number_of_threads = 10
+ number_of_iterations = 24
+
+ for it in range(number_of_iterations):
+ iterator = product((1, 2, 3, 4, 5), (10, 20, 30))
+ test_concurrent_iteration(iterator, number_of_threads)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/test/test_fstring.py b/Lib/test/test_fstring.py
index dd58e032a8b..58a30c8e6ac 100644
--- a/Lib/test/test_fstring.py
+++ b/Lib/test/test_fstring.py
@@ -1336,9 +1336,9 @@ x = (
def test_conversions(self):
self.assertEqual(f'{3.14:10.10}', ' 3.14')
- self.assertEqual(f'{3.14!s:10.10}', '3.14 ')
- self.assertEqual(f'{3.14!r:10.10}', '3.14 ')
- self.assertEqual(f'{3.14!a:10.10}', '3.14 ')
+ self.assertEqual(f'{1.25!s:10.10}', '1.25 ')
+ self.assertEqual(f'{1.25!r:10.10}', '1.25 ')
+ self.assertEqual(f'{1.25!a:10.10}', '1.25 ')
self.assertEqual(f'{"a"}', 'a')
self.assertEqual(f'{"a"!r}', "'a'")
@@ -1347,7 +1347,7 @@ x = (
# Conversions can have trailing whitespace after them since it
# does not provide any significance
self.assertEqual(f"{3!s }", "3")
- self.assertEqual(f'{3.14!s :10.10}', '3.14 ')
+ self.assertEqual(f'{1.25!s :10.10}', '1.25 ')
# Not a conversion.
self.assertEqual(f'{"a!r"}', "a!r")
@@ -1380,7 +1380,7 @@ x = (
for conv in ' s', ' s ':
self.assertAllRaise(SyntaxError,
"f-string: conversion type must come right after the"
- " exclamanation mark",
+ " exclamation mark",
["f'{3!" + conv + "}'"])
self.assertAllRaise(SyntaxError,
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index e3b449f2d24..f7e09fd771e 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -21,8 +21,10 @@ from weakref import proxy
import contextlib
from inspect import Signature
+from test.support import ALWAYS_EQ
from test.support import import_helper
from test.support import threading_helper
+from test.support import cpython_only
from test.support import EqualToForwardRef
import functools
@@ -63,6 +65,14 @@ class BadTuple(tuple):
class MyDict(dict):
pass
+class TestImportTime(unittest.TestCase):
+
+ @cpython_only
+ def test_lazy_import(self):
+ import_helper.ensure_lazy_imports(
+ "functools", {"os", "weakref", "typing", "annotationlib", "warnings"}
+ )
+
class TestPartial:
@@ -235,6 +245,13 @@ class TestPartial:
actual_args, actual_kwds = p('x', 'y')
self.assertEqual(actual_args, ('x', 0, 'y', 1))
self.assertEqual(actual_kwds, {})
+ # Checks via `is` and not `eq`
+ # thus ALWAYS_EQ isn't treated as Placeholder
+ p = self.partial(capture, ALWAYS_EQ)
+ actual_args, actual_kwds = p()
+ self.assertEqual(len(actual_args), 1)
+ self.assertIs(actual_args[0], ALWAYS_EQ)
+ self.assertEqual(actual_kwds, {})
def test_placeholders_optimization(self):
PH = self.module.Placeholder
@@ -251,6 +268,17 @@ class TestPartial:
self.assertEqual(p2.args, (PH, 0))
self.assertEqual(p2(1), ((1, 0), {}))
+ def test_placeholders_kw_restriction(self):
+ PH = self.module.Placeholder
+ with self.assertRaisesRegex(TypeError, "Placeholder"):
+ self.partial(capture, a=PH)
+ # Passes, as checks via `is` and not `eq`
+ p = self.partial(capture, a=ALWAYS_EQ)
+ actual_args, actual_kwds = p()
+ self.assertEqual(actual_args, ())
+ self.assertEqual(len(actual_kwds), 1)
+ self.assertIs(actual_kwds['a'], ALWAYS_EQ)
+
def test_construct_placeholder_singleton(self):
PH = self.module.Placeholder
tp = type(PH)
diff --git a/Lib/test/test_future_stmt/test_future.py b/Lib/test/test_future_stmt/test_future.py
index 42c6cb3fefa..71f1e616116 100644
--- a/Lib/test/test_future_stmt/test_future.py
+++ b/Lib/test/test_future_stmt/test_future.py
@@ -422,6 +422,11 @@ class AnnotationsFutureTestCase(unittest.TestCase):
eq('(((a)))', 'a')
eq('(((a, b)))', '(a, b)')
eq("1 + 2 + 3")
+ eq("t''")
+ eq("t'{a + b}'")
+ eq("t'{a!s}'")
+ eq("t'{a:b}'")
+ eq("t'{a:b=}'")
def test_fstring_debug_annotations(self):
# f-strings with '=' don't round trip very well, so set the expected
diff --git a/Lib/test/test_gc.py b/Lib/test/test_gc.py
index b5140057a69..b4cbfb6d774 100644
--- a/Lib/test/test_gc.py
+++ b/Lib/test/test_gc.py
@@ -7,7 +7,7 @@ from test.support import (verbose, refcount_test,
Py_GIL_DISABLED)
from test.support.import_helper import import_module
from test.support.os_helper import temp_dir, TESTFN, unlink
-from test.support.script_helper import assert_python_ok, make_script
+from test.support.script_helper import assert_python_ok, make_script, run_test_script
from test.support import threading_helper, gc_threshold
import gc
@@ -300,7 +300,7 @@ class GCTests(unittest.TestCase):
# We're mostly just checking that this doesn't crash.
rc, stdout, stderr = assert_python_ok("-c", code)
self.assertEqual(rc, 0)
- self.assertRegex(stdout, rb"""\A\s*func=<function at \S+>\s*\Z""")
+ self.assertRegex(stdout, rb"""\A\s*func=<function at \S+>\s*\z""")
self.assertFalse(stderr)
@refcount_test
@@ -914,7 +914,7 @@ class GCTests(unittest.TestCase):
gc.collect()
self.assertEqual(len(Lazarus.resurrected_instances), 1)
instance = Lazarus.resurrected_instances.pop()
- self.assertTrue(hasattr(instance, "cargo"))
+ self.assertHasAttr(instance, "cargo")
self.assertEqual(id(instance.cargo), cargo_id)
gc.collect()
@@ -1127,64 +1127,14 @@ class GCTests(unittest.TestCase):
class IncrementalGCTests(unittest.TestCase):
-
- def setUp(self):
- # Reenable GC as it is disabled module-wide
- gc.enable()
-
- def tearDown(self):
- gc.disable()
-
@unittest.skipIf(_testinternalcapi is None, "requires _testinternalcapi")
@requires_gil_enabled("Free threading does not support incremental GC")
- # Use small increments to emulate longer running process in a shorter time
- @gc_threshold(200, 10)
def test_incremental_gc_handles_fast_cycle_creation(self):
-
- class LinkedList:
-
- #Use slots to reduce number of implicit objects
- __slots__ = "next", "prev", "surprise"
-
- def __init__(self, next=None, prev=None):
- self.next = next
- if next is not None:
- next.prev = self
- self.prev = prev
- if prev is not None:
- prev.next = self
-
- def make_ll(depth):
- head = LinkedList()
- for i in range(depth):
- head = LinkedList(head, head.prev)
- return head
-
- head = make_ll(1000)
- count = 1000
-
- # There will be some objects we aren't counting,
- # e.g. the gc stats dicts. This test checks
- # that the counts don't grow, so we try to
- # correct for the uncounted objects
- # This is just an estimate.
- CORRECTION = 20
-
- enabled = gc.isenabled()
- gc.enable()
- olds = []
- initial_heap_size = _testinternalcapi.get_tracked_heap_size()
- for i in range(20_000):
- newhead = make_ll(20)
- count += 20
- newhead.surprise = head
- olds.append(newhead)
- if len(olds) == 20:
- new_objects = _testinternalcapi.get_tracked_heap_size() - initial_heap_size
- self.assertLess(new_objects, 27_000, f"Heap growing. Reached limit after {i} iterations")
- del olds[:]
- if not enabled:
- gc.disable()
+ # Run this test in a fresh process. The number of alive objects (which can
+ # be from unit tests run before this one) can influence how quickly cyclic
+ # garbage is found.
+ script = support.findfile("_test_gc_fast_cycles.py")
+ run_test_script(script)
class GCCallbackTests(unittest.TestCase):
diff --git a/Lib/test/test_generated_cases.py b/Lib/test/test_generated_cases.py
index d481cb07f75..eb01328b6ea 100644
--- a/Lib/test/test_generated_cases.py
+++ b/Lib/test/test_generated_cases.py
@@ -1,11 +1,9 @@
import contextlib
import os
-import re
import sys
import tempfile
import unittest
-from io import StringIO
from test import support
from test import test_tools
@@ -31,12 +29,11 @@ skip_if_different_mount_drives()
test_tools.skip_if_missing("cases_generator")
with test_tools.imports_under_tool("cases_generator"):
- from analyzer import analyze_forest, StackItem
+ from analyzer import StackItem
from cwriter import CWriter
import parser
from stack import Local, Stack
import tier1_generator
- import opcode_metadata_generator
import optimizer_generator
@@ -59,14 +56,14 @@ class TestEffects(unittest.TestCase):
def test_effect_sizes(self):
stack = Stack()
inputs = [
- x := StackItem("x", None, "1"),
- y := StackItem("y", None, "oparg"),
- z := StackItem("z", None, "oparg*2"),
+ x := StackItem("x", "1"),
+ y := StackItem("y", "oparg"),
+ z := StackItem("z", "oparg*2"),
]
outputs = [
- StackItem("x", None, "1"),
- StackItem("b", None, "oparg*4"),
- StackItem("c", None, "1"),
+ StackItem("x", "1"),
+ StackItem("b", "oparg*4"),
+ StackItem("c", "1"),
]
null = CWriter.null()
stack.pop(z, null)
@@ -1106,32 +1103,6 @@ class TestGeneratedCases(unittest.TestCase):
"""
self.run_cases_test(input, output)
- def test_pointer_to_stackref(self):
- input = """
- inst(OP, (arg: _PyStackRef * -- out)) {
- out = *arg;
- DEAD(arg);
- }
- """
- output = """
- TARGET(OP) {
- #if Py_TAIL_CALL_INTERP
- int opcode = OP;
- (void)(opcode);
- #endif
- frame->instr_ptr = next_instr;
- next_instr += 1;
- INSTRUCTION_STATS(OP);
- _PyStackRef *arg;
- _PyStackRef out;
- arg = (_PyStackRef *)stack_pointer[-1].bits;
- out = *arg;
- stack_pointer[-1] = out;
- DISPATCH();
- }
- """
- self.run_cases_test(input, output)
-
def test_unused_cached_value(self):
input = """
op(FIRST, (arg1 -- out)) {
@@ -2005,8 +1976,8 @@ class TestGeneratedAbstractCases(unittest.TestCase):
"""
output = """
case OP: {
- JitOptSymbol *arg1;
- JitOptSymbol *out;
+ JitOptRef arg1;
+ JitOptRef out;
arg1 = stack_pointer[-1];
out = EGGS(arg1);
stack_pointer[-1] = out;
@@ -2014,7 +1985,7 @@ class TestGeneratedAbstractCases(unittest.TestCase):
}
case OP2: {
- JitOptSymbol *out;
+ JitOptRef out;
out = sym_new_not_null(ctx);
stack_pointer[-1] = out;
break;
@@ -2039,14 +2010,14 @@ class TestGeneratedAbstractCases(unittest.TestCase):
"""
output = """
case OP: {
- JitOptSymbol *out;
+ JitOptRef out;
out = sym_new_not_null(ctx);
stack_pointer[-1] = out;
break;
}
case OP2: {
- JitOptSymbol *out;
+ JitOptRef out;
out = NULL;
stack_pointer[-1] = out;
break;
@@ -2069,6 +2040,386 @@ class TestGeneratedAbstractCases(unittest.TestCase):
with self.assertRaisesRegex(AssertionError, "All abstract uops"):
self.run_cases_test(input, input2, output)
+ def test_validate_uop_input_length_mismatch(self):
+ input = """
+ op(OP, (arg1 -- out)) {
+ SPAM();
+ }
+ """
+ input2 = """
+ op(OP, (arg1, arg2 -- out)) {
+ }
+ """
+ output = """
+ """
+ with self.assertRaisesRegex(SyntaxError,
+ "Must have the same number of inputs"):
+ self.run_cases_test(input, input2, output)
+
+ def test_validate_uop_output_length_mismatch(self):
+ input = """
+ op(OP, (arg1 -- out)) {
+ SPAM();
+ }
+ """
+ input2 = """
+ op(OP, (arg1 -- out1, out2)) {
+ }
+ """
+ output = """
+ """
+ with self.assertRaisesRegex(SyntaxError,
+ "Must have the same number of outputs"):
+ self.run_cases_test(input, input2, output)
+
+ def test_validate_uop_input_name_mismatch(self):
+ input = """
+ op(OP, (foo -- out)) {
+ SPAM();
+ }
+ """
+ input2 = """
+ op(OP, (bar -- out)) {
+ }
+ """
+ output = """
+ """
+ with self.assertRaisesRegex(SyntaxError,
+ "Inputs must have equal names"):
+ self.run_cases_test(input, input2, output)
+
+ def test_validate_uop_output_name_mismatch(self):
+ input = """
+ op(OP, (arg1 -- foo)) {
+ SPAM();
+ }
+ """
+ input2 = """
+ op(OP, (arg1 -- bar)) {
+ }
+ """
+ output = """
+ """
+ with self.assertRaisesRegex(SyntaxError,
+ "Outputs must have equal names"):
+ self.run_cases_test(input, input2, output)
+
+ def test_validate_uop_unused_input(self):
+ input = """
+ op(OP, (unused -- )) {
+ }
+ """
+ input2 = """
+ op(OP, (foo -- )) {
+ }
+ """
+ output = """
+ case OP: {
+ stack_pointer += -1;
+ assert(WITHIN_STACK_BOUNDS());
+ break;
+ }
+ """
+ self.run_cases_test(input, input2, output)
+
+ input = """
+ op(OP, (foo -- )) {
+ }
+ """
+ input2 = """
+ op(OP, (unused -- )) {
+ }
+ """
+ output = """
+ case OP: {
+ stack_pointer += -1;
+ assert(WITHIN_STACK_BOUNDS());
+ break;
+ }
+ """
+ self.run_cases_test(input, input2, output)
+
+ def test_validate_uop_unused_output(self):
+ input = """
+ op(OP, ( -- unused)) {
+ }
+ """
+ input2 = """
+ op(OP, ( -- foo)) {
+ foo = NULL;
+ }
+ """
+ output = """
+ case OP: {
+ JitOptRef foo;
+ foo = NULL;
+ stack_pointer[0] = foo;
+ stack_pointer += 1;
+ assert(WITHIN_STACK_BOUNDS());
+ break;
+ }
+ """
+ self.run_cases_test(input, input2, output)
+
+ input = """
+ op(OP, ( -- foo)) {
+ foo = NULL;
+ }
+ """
+ input2 = """
+ op(OP, ( -- unused)) {
+ }
+ """
+ output = """
+ case OP: {
+ stack_pointer += 1;
+ assert(WITHIN_STACK_BOUNDS());
+ break;
+ }
+ """
+ self.run_cases_test(input, input2, output)
+
+ def test_validate_uop_input_size_mismatch(self):
+ input = """
+ op(OP, (arg1[2] -- )) {
+ }
+ """
+ input2 = """
+ op(OP, (arg1[4] -- )) {
+ }
+ """
+ output = """
+ """
+ with self.assertRaisesRegex(SyntaxError,
+ "Inputs must have equal sizes"):
+ self.run_cases_test(input, input2, output)
+
+ def test_validate_uop_output_size_mismatch(self):
+ input = """
+ op(OP, ( -- out[2])) {
+ }
+ """
+ input2 = """
+ op(OP, ( -- out[4])) {
+ }
+ """
+ output = """
+ """
+ with self.assertRaisesRegex(SyntaxError,
+ "Outputs must have equal sizes"):
+ self.run_cases_test(input, input2, output)
+
+ def test_validate_uop_unused_size_mismatch(self):
+ input = """
+ op(OP, (foo[2] -- )) {
+ }
+ """
+ input2 = """
+ op(OP, (unused[4] -- )) {
+ }
+ """
+ output = """
+ """
+ with self.assertRaisesRegex(SyntaxError,
+ "Inputs must have equal sizes"):
+ self.run_cases_test(input, input2, output)
+
+ def test_pure_uop_body_copied_in(self):
+ # Note: any non-escaping call works.
+ # In this case, we use PyStackRef_IsNone.
+ input = """
+ pure op(OP, (foo -- res)) {
+ res = PyStackRef_IsNone(foo);
+ }
+ """
+ input2 = """
+ op(OP, (foo -- res)) {
+ REPLACE_OPCODE_IF_EVALUATES_PURE(foo);
+ res = sym_new_known(ctx, foo);
+ }
+ """
+ output = """
+ case OP: {
+ JitOptRef foo;
+ JitOptRef res;
+ foo = stack_pointer[-1];
+ if (
+ sym_is_safe_const(ctx, foo)
+ ) {
+ JitOptRef foo_sym = foo;
+ _PyStackRef foo = sym_get_const_as_stackref(ctx, foo_sym);
+ _PyStackRef res_stackref;
+ /* Start of uop copied from bytecodes for constant evaluation */
+ res_stackref = PyStackRef_IsNone(foo);
+ /* End of uop copied from bytecodes for constant evaluation */
+ res = sym_new_const_steal(ctx, PyStackRef_AsPyObjectSteal(res_stackref));
+ stack_pointer[-1] = res;
+ break;
+ }
+ res = sym_new_known(ctx, foo);
+ stack_pointer[-1] = res;
+ break;
+ }
+ """
+ self.run_cases_test(input, input2, output)
+
+ def test_pure_uop_body_copied_in_deopt(self):
+ # Note: any non-escaping call works.
+ # In this case, we use PyStackRef_IsNone.
+ input = """
+ pure op(OP, (foo -- res)) {
+ DEOPT_IF(PyStackRef_IsNull(foo));
+ res = foo;
+ }
+ """
+ input2 = """
+ op(OP, (foo -- res)) {
+ REPLACE_OPCODE_IF_EVALUATES_PURE(foo);
+ res = foo;
+ }
+ """
+ output = """
+ case OP: {
+ JitOptRef foo;
+ JitOptRef res;
+ foo = stack_pointer[-1];
+ if (
+ sym_is_safe_const(ctx, foo)
+ ) {
+ JitOptRef foo_sym = foo;
+ _PyStackRef foo = sym_get_const_as_stackref(ctx, foo_sym);
+ _PyStackRef res_stackref;
+ /* Start of uop copied from bytecodes for constant evaluation */
+ if (PyStackRef_IsNull(foo)) {
+ ctx->done = true;
+ break;
+ }
+ res_stackref = foo;
+ /* End of uop copied from bytecodes for constant evaluation */
+ res = sym_new_const_steal(ctx, PyStackRef_AsPyObjectSteal(res_stackref));
+ stack_pointer[-1] = res;
+ break;
+ }
+ res = foo;
+ stack_pointer[-1] = res;
+ break;
+ }
+ """
+ self.run_cases_test(input, input2, output)
+
+ def test_pure_uop_body_copied_in_error_if(self):
+ # Note: any non-escaping call works.
+ # In this case, we use PyStackRef_IsNone.
+ input = """
+ pure op(OP, (foo -- res)) {
+ ERROR_IF(PyStackRef_IsNull(foo));
+ res = foo;
+ }
+ """
+ input2 = """
+ op(OP, (foo -- res)) {
+ REPLACE_OPCODE_IF_EVALUATES_PURE(foo);
+ res = foo;
+ }
+ """
+ output = """
+ case OP: {
+ JitOptRef foo;
+ JitOptRef res;
+ foo = stack_pointer[-1];
+ if (
+ sym_is_safe_const(ctx, foo)
+ ) {
+ JitOptRef foo_sym = foo;
+ _PyStackRef foo = sym_get_const_as_stackref(ctx, foo_sym);
+ _PyStackRef res_stackref;
+ /* Start of uop copied from bytecodes for constant evaluation */
+ if (PyStackRef_IsNull(foo)) {
+ goto error;
+ }
+ res_stackref = foo;
+ /* End of uop copied from bytecodes for constant evaluation */
+ res = sym_new_const_steal(ctx, PyStackRef_AsPyObjectSteal(res_stackref));
+ stack_pointer[-1] = res;
+ break;
+ }
+ res = foo;
+ stack_pointer[-1] = res;
+ break;
+ }
+ """
+ self.run_cases_test(input, input2, output)
+
+
+ def test_replace_opcode_uop_body_copied_in_complex(self):
+ input = """
+ pure op(OP, (foo -- res)) {
+ if (foo) {
+ res = PyStackRef_IsNone(foo);
+ }
+ else {
+ res = 1;
+ }
+ }
+ """
+ input2 = """
+ op(OP, (foo -- res)) {
+ REPLACE_OPCODE_IF_EVALUATES_PURE(foo);
+ res = sym_new_known(ctx, foo);
+ }
+ """
+ output = """
+ case OP: {
+ JitOptRef foo;
+ JitOptRef res;
+ foo = stack_pointer[-1];
+ if (
+ sym_is_safe_const(ctx, foo)
+ ) {
+ JitOptRef foo_sym = foo;
+ _PyStackRef foo = sym_get_const_as_stackref(ctx, foo_sym);
+ _PyStackRef res_stackref;
+ /* Start of uop copied from bytecodes for constant evaluation */
+ if (foo) {
+ res_stackref = PyStackRef_IsNone(foo);
+ }
+ else {
+ res_stackref = 1;
+ }
+ /* End of uop copied from bytecodes for constant evaluation */
+ res = sym_new_const_steal(ctx, PyStackRef_AsPyObjectSteal(res_stackref));
+ stack_pointer[-1] = res;
+ break;
+ }
+ res = sym_new_known(ctx, foo);
+ stack_pointer[-1] = res;
+ break;
+ }
+ """
+ self.run_cases_test(input, input2, output)
+
+ def test_replace_opocode_uop_reject_array_effects(self):
+ input = """
+ pure op(OP, (foo[2] -- res)) {
+ if (foo) {
+ res = PyStackRef_IsNone(foo);
+ }
+ else {
+ res = 1;
+ }
+ }
+ """
+ input2 = """
+ op(OP, (foo[2] -- res)) {
+ REPLACE_OPCODE_IF_EVALUATES_PURE(foo);
+ res = sym_new_unknown(ctx);
+ }
+ """
+ output = """
+ """
+ with self.assertRaisesRegex(SyntaxError,
+ "Pure evaluation cannot take array-like inputs"):
+ self.run_cases_test(input, input2, output)
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_genericalias.py b/Lib/test/test_genericalias.py
index 5c13897b8d9..7601cb00ff6 100644
--- a/Lib/test/test_genericalias.py
+++ b/Lib/test/test_genericalias.py
@@ -61,6 +61,7 @@ try:
from tkinter import Event
except ImportError:
Event = None
+from string.templatelib import Template, Interpolation
from typing import TypeVar
T = TypeVar('T')
@@ -137,7 +138,12 @@ class BaseTest(unittest.TestCase):
Future, _WorkItem,
Morsel,
DictReader, DictWriter,
- array]
+ array,
+ staticmethod,
+ classmethod,
+ Template,
+ Interpolation,
+ ]
if ctypes is not None:
generic_types.extend((ctypes.Array, ctypes.LibraryLoader, ctypes.py_object))
if ValueProxy is not None:
@@ -230,13 +236,13 @@ class BaseTest(unittest.TestCase):
self.assertEqual(repr(x2), 'tuple[*tuple[int, str]]')
x3 = tuple[*tuple[int, ...]]
self.assertEqual(repr(x3), 'tuple[*tuple[int, ...]]')
- self.assertTrue(repr(MyList[int]).endswith('.BaseTest.test_repr.<locals>.MyList[int]'))
+ self.assertEndsWith(repr(MyList[int]), '.BaseTest.test_repr.<locals>.MyList[int]')
self.assertEqual(repr(list[str]()), '[]') # instances should keep their normal repr
# gh-105488
- self.assertTrue(repr(MyGeneric[int]).endswith('MyGeneric[int]'))
- self.assertTrue(repr(MyGeneric[[]]).endswith('MyGeneric[[]]'))
- self.assertTrue(repr(MyGeneric[[int, str]]).endswith('MyGeneric[[int, str]]'))
+ self.assertEndsWith(repr(MyGeneric[int]), 'MyGeneric[int]')
+ self.assertEndsWith(repr(MyGeneric[[]]), 'MyGeneric[[]]')
+ self.assertEndsWith(repr(MyGeneric[[int, str]]), 'MyGeneric[[int, str]]')
def test_exposed_type(self):
import types
@@ -356,7 +362,7 @@ class BaseTest(unittest.TestCase):
def test_issubclass(self):
class L(list): ...
- self.assertTrue(issubclass(L, list))
+ self.assertIsSubclass(L, list)
with self.assertRaises(TypeError):
issubclass(L, list[str])
diff --git a/Lib/test/test_genericpath.py b/Lib/test/test_genericpath.py
index 6c3abe602f5..16c3268fefb 100644
--- a/Lib/test/test_genericpath.py
+++ b/Lib/test/test_genericpath.py
@@ -8,7 +8,7 @@ import sys
import unittest
import warnings
from test.support import (
- is_apple, is_emscripten, os_helper, warnings_helper
+ is_apple, os_helper, warnings_helper
)
from test.support.script_helper import assert_python_ok
from test.support.os_helper import FakePath
@@ -92,8 +92,8 @@ class GenericTest:
for s1 in testlist:
for s2 in testlist:
p = commonprefix([s1, s2])
- self.assertTrue(s1.startswith(p))
- self.assertTrue(s2.startswith(p))
+ self.assertStartsWith(s1, p)
+ self.assertStartsWith(s2, p)
if s1 != s2:
n = len(p)
self.assertNotEqual(s1[n:n+1], s2[n:n+1])
diff --git a/Lib/test/test_getpass.py b/Lib/test/test_getpass.py
index 80dda2caaa3..ab36535a1cf 100644
--- a/Lib/test/test_getpass.py
+++ b/Lib/test/test_getpass.py
@@ -161,6 +161,45 @@ class UnixGetpassTest(unittest.TestCase):
self.assertIn('Warning', stderr.getvalue())
self.assertIn('Password:', stderr.getvalue())
+ def test_echo_char_replaces_input_with_asterisks(self):
+ mock_result = '*************'
+ with mock.patch('os.open') as os_open, \
+ mock.patch('io.FileIO'), \
+ mock.patch('io.TextIOWrapper') as textio, \
+ mock.patch('termios.tcgetattr'), \
+ mock.patch('termios.tcsetattr'), \
+ mock.patch('getpass._raw_input') as mock_input:
+ os_open.return_value = 3
+ mock_input.return_value = mock_result
+
+ result = getpass.unix_getpass(echo_char='*')
+ mock_input.assert_called_once_with('Password: ', textio(),
+ input=textio(), echo_char='*')
+ self.assertEqual(result, mock_result)
+
+ def test_raw_input_with_echo_char(self):
+ passwd = 'my1pa$$word!'
+ mock_input = StringIO(f'{passwd}\n')
+ mock_output = StringIO()
+ with mock.patch('sys.stdin', mock_input), \
+ mock.patch('sys.stdout', mock_output):
+ result = getpass._raw_input('Password: ', mock_output, mock_input,
+ '*')
+ self.assertEqual(result, passwd)
+ self.assertEqual('Password: ************', mock_output.getvalue())
+
+ def test_control_chars_with_echo_char(self):
+ passwd = 'pass\twd\b'
+ expect_result = 'pass\tw'
+ mock_input = StringIO(f'{passwd}\n')
+ mock_output = StringIO()
+ with mock.patch('sys.stdin', mock_input), \
+ mock.patch('sys.stdout', mock_output):
+ result = getpass._raw_input('Password: ', mock_output, mock_input,
+ '*')
+ self.assertEqual(result, expect_result)
+ self.assertEqual('Password: *******\x08 \x08', mock_output.getvalue())
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_gettext.py b/Lib/test/test_gettext.py
index 585ed08ea14..33b7d75e3ff 100644
--- a/Lib/test/test_gettext.py
+++ b/Lib/test/test_gettext.py
@@ -6,7 +6,8 @@ import unittest.mock
from functools import partial
from test import support
-from test.support import os_helper
+from test.support import cpython_only, os_helper
+from test.support.import_helper import ensure_lazy_imports
# TODO:
@@ -931,6 +932,10 @@ class MiscTestCase(unittest.TestCase):
support.check__all__(self, gettext,
not_exported={'c2py', 'ENOENT'})
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("gettext", {"re", "warnings", "locale"})
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_glob.py b/Lib/test/test_glob.py
index 6e5fc2939c6..d0ed5129253 100644
--- a/Lib/test/test_glob.py
+++ b/Lib/test/test_glob.py
@@ -459,59 +459,59 @@ class GlobTests(unittest.TestCase):
def test_translate(self):
def fn(pat):
return glob.translate(pat, seps='/')
- self.assertEqual(fn('foo'), r'(?s:foo)\Z')
- self.assertEqual(fn('foo/bar'), r'(?s:foo/bar)\Z')
- self.assertEqual(fn('*'), r'(?s:[^/.][^/]*)\Z')
- self.assertEqual(fn('?'), r'(?s:(?!\.)[^/])\Z')
- self.assertEqual(fn('a*'), r'(?s:a[^/]*)\Z')
- self.assertEqual(fn('*a'), r'(?s:(?!\.)[^/]*a)\Z')
- self.assertEqual(fn('.*'), r'(?s:\.[^/]*)\Z')
- self.assertEqual(fn('?aa'), r'(?s:(?!\.)[^/]aa)\Z')
- self.assertEqual(fn('aa?'), r'(?s:aa[^/])\Z')
- self.assertEqual(fn('aa[ab]'), r'(?s:aa[ab])\Z')
- self.assertEqual(fn('**'), r'(?s:(?!\.)[^/]*)\Z')
- self.assertEqual(fn('***'), r'(?s:(?!\.)[^/]*)\Z')
- self.assertEqual(fn('a**'), r'(?s:a[^/]*)\Z')
- self.assertEqual(fn('**b'), r'(?s:(?!\.)[^/]*b)\Z')
+ self.assertEqual(fn('foo'), r'(?s:foo)\z')
+ self.assertEqual(fn('foo/bar'), r'(?s:foo/bar)\z')
+ self.assertEqual(fn('*'), r'(?s:[^/.][^/]*)\z')
+ self.assertEqual(fn('?'), r'(?s:(?!\.)[^/])\z')
+ self.assertEqual(fn('a*'), r'(?s:a[^/]*)\z')
+ self.assertEqual(fn('*a'), r'(?s:(?!\.)[^/]*a)\z')
+ self.assertEqual(fn('.*'), r'(?s:\.[^/]*)\z')
+ self.assertEqual(fn('?aa'), r'(?s:(?!\.)[^/]aa)\z')
+ self.assertEqual(fn('aa?'), r'(?s:aa[^/])\z')
+ self.assertEqual(fn('aa[ab]'), r'(?s:aa[ab])\z')
+ self.assertEqual(fn('**'), r'(?s:(?!\.)[^/]*)\z')
+ self.assertEqual(fn('***'), r'(?s:(?!\.)[^/]*)\z')
+ self.assertEqual(fn('a**'), r'(?s:a[^/]*)\z')
+ self.assertEqual(fn('**b'), r'(?s:(?!\.)[^/]*b)\z')
self.assertEqual(fn('/**/*/*.*/**'),
- r'(?s:/(?!\.)[^/]*/[^/.][^/]*/(?!\.)[^/]*\.[^/]*/(?!\.)[^/]*)\Z')
+ r'(?s:/(?!\.)[^/]*/[^/.][^/]*/(?!\.)[^/]*\.[^/]*/(?!\.)[^/]*)\z')
def test_translate_include_hidden(self):
def fn(pat):
return glob.translate(pat, include_hidden=True, seps='/')
- self.assertEqual(fn('foo'), r'(?s:foo)\Z')
- self.assertEqual(fn('foo/bar'), r'(?s:foo/bar)\Z')
- self.assertEqual(fn('*'), r'(?s:[^/]+)\Z')
- self.assertEqual(fn('?'), r'(?s:[^/])\Z')
- self.assertEqual(fn('a*'), r'(?s:a[^/]*)\Z')
- self.assertEqual(fn('*a'), r'(?s:[^/]*a)\Z')
- self.assertEqual(fn('.*'), r'(?s:\.[^/]*)\Z')
- self.assertEqual(fn('?aa'), r'(?s:[^/]aa)\Z')
- self.assertEqual(fn('aa?'), r'(?s:aa[^/])\Z')
- self.assertEqual(fn('aa[ab]'), r'(?s:aa[ab])\Z')
- self.assertEqual(fn('**'), r'(?s:[^/]*)\Z')
- self.assertEqual(fn('***'), r'(?s:[^/]*)\Z')
- self.assertEqual(fn('a**'), r'(?s:a[^/]*)\Z')
- self.assertEqual(fn('**b'), r'(?s:[^/]*b)\Z')
- self.assertEqual(fn('/**/*/*.*/**'), r'(?s:/[^/]*/[^/]+/[^/]*\.[^/]*/[^/]*)\Z')
+ self.assertEqual(fn('foo'), r'(?s:foo)\z')
+ self.assertEqual(fn('foo/bar'), r'(?s:foo/bar)\z')
+ self.assertEqual(fn('*'), r'(?s:[^/]+)\z')
+ self.assertEqual(fn('?'), r'(?s:[^/])\z')
+ self.assertEqual(fn('a*'), r'(?s:a[^/]*)\z')
+ self.assertEqual(fn('*a'), r'(?s:[^/]*a)\z')
+ self.assertEqual(fn('.*'), r'(?s:\.[^/]*)\z')
+ self.assertEqual(fn('?aa'), r'(?s:[^/]aa)\z')
+ self.assertEqual(fn('aa?'), r'(?s:aa[^/])\z')
+ self.assertEqual(fn('aa[ab]'), r'(?s:aa[ab])\z')
+ self.assertEqual(fn('**'), r'(?s:[^/]*)\z')
+ self.assertEqual(fn('***'), r'(?s:[^/]*)\z')
+ self.assertEqual(fn('a**'), r'(?s:a[^/]*)\z')
+ self.assertEqual(fn('**b'), r'(?s:[^/]*b)\z')
+ self.assertEqual(fn('/**/*/*.*/**'), r'(?s:/[^/]*/[^/]+/[^/]*\.[^/]*/[^/]*)\z')
def test_translate_recursive(self):
def fn(pat):
return glob.translate(pat, recursive=True, include_hidden=True, seps='/')
- self.assertEqual(fn('*'), r'(?s:[^/]+)\Z')
- self.assertEqual(fn('?'), r'(?s:[^/])\Z')
- self.assertEqual(fn('**'), r'(?s:.*)\Z')
- self.assertEqual(fn('**/**'), r'(?s:.*)\Z')
- self.assertEqual(fn('***'), r'(?s:[^/]*)\Z')
- self.assertEqual(fn('a**'), r'(?s:a[^/]*)\Z')
- self.assertEqual(fn('**b'), r'(?s:[^/]*b)\Z')
- self.assertEqual(fn('/**/*/*.*/**'), r'(?s:/(?:.+/)?[^/]+/[^/]*\.[^/]*/.*)\Z')
+ self.assertEqual(fn('*'), r'(?s:[^/]+)\z')
+ self.assertEqual(fn('?'), r'(?s:[^/])\z')
+ self.assertEqual(fn('**'), r'(?s:.*)\z')
+ self.assertEqual(fn('**/**'), r'(?s:.*)\z')
+ self.assertEqual(fn('***'), r'(?s:[^/]*)\z')
+ self.assertEqual(fn('a**'), r'(?s:a[^/]*)\z')
+ self.assertEqual(fn('**b'), r'(?s:[^/]*b)\z')
+ self.assertEqual(fn('/**/*/*.*/**'), r'(?s:/(?:.+/)?[^/]+/[^/]*\.[^/]*/.*)\z')
def test_translate_seps(self):
def fn(pat):
return glob.translate(pat, recursive=True, include_hidden=True, seps=['/', '\\'])
- self.assertEqual(fn('foo/bar\\baz'), r'(?s:foo[/\\]bar[/\\]baz)\Z')
- self.assertEqual(fn('**/*'), r'(?s:(?:.+[/\\])?[^/\\]+)\Z')
+ self.assertEqual(fn('foo/bar\\baz'), r'(?s:foo[/\\]bar[/\\]baz)\z')
+ self.assertEqual(fn('**/*'), r'(?s:(?:.+[/\\])?[^/\\]+)\z')
if __name__ == "__main__":
diff --git a/Lib/test/test_grammar.py b/Lib/test/test_grammar.py
index c39565144bf..7f5d48b9c63 100644
--- a/Lib/test/test_grammar.py
+++ b/Lib/test/test_grammar.py
@@ -1,7 +1,7 @@
# Python test set -- part 1, grammar.
# This just tests whether the parser accepts them all.
-from test.support import check_syntax_error
+from test.support import check_syntax_error, skip_wasi_stack_overflow
from test.support import import_helper
import annotationlib
import inspect
@@ -249,6 +249,18 @@ the \'lazy\' dog.\n\
compile(s, "<test>", "exec")
self.assertIn("was never closed", str(cm.exception))
+ @skip_wasi_stack_overflow()
+ def test_max_level(self):
+ # Macro defined in Parser/lexer/state.h
+ MAXLEVEL = 200
+
+ result = eval("(" * MAXLEVEL + ")" * MAXLEVEL)
+ self.assertEqual(result, ())
+
+ with self.assertRaises(SyntaxError) as cm:
+ eval("(" * (MAXLEVEL + 1) + ")" * (MAXLEVEL + 1))
+ self.assertStartsWith(str(cm.exception), 'too many nested parentheses')
+
var_annot_global: int # a global annotated is necessary for test_var_annot
diff --git a/Lib/test/test_gzip.py b/Lib/test/test_gzip.py
index fa5de7c190e..a12ff5662a7 100644
--- a/Lib/test/test_gzip.py
+++ b/Lib/test/test_gzip.py
@@ -9,7 +9,6 @@ import os
import struct
import sys
import unittest
-import warnings
from subprocess import PIPE, Popen
from test.support import catch_unraisable_exception
from test.support import import_helper
@@ -331,13 +330,13 @@ class TestGzip(BaseTest):
def test_1647484(self):
for mode in ('wb', 'rb'):
with gzip.GzipFile(self.filename, mode) as f:
- self.assertTrue(hasattr(f, "name"))
+ self.assertHasAttr(f, "name")
self.assertEqual(f.name, self.filename)
def test_paddedfile_getattr(self):
self.test_write()
with gzip.GzipFile(self.filename, 'rb') as f:
- self.assertTrue(hasattr(f.fileobj, "name"))
+ self.assertHasAttr(f.fileobj, "name")
self.assertEqual(f.fileobj.name, self.filename)
def test_mtime(self):
@@ -345,7 +344,7 @@ class TestGzip(BaseTest):
with gzip.GzipFile(self.filename, 'w', mtime = mtime) as fWrite:
fWrite.write(data1)
with gzip.GzipFile(self.filename) as fRead:
- self.assertTrue(hasattr(fRead, 'mtime'))
+ self.assertHasAttr(fRead, 'mtime')
self.assertIsNone(fRead.mtime)
dataRead = fRead.read()
self.assertEqual(dataRead, data1)
@@ -460,7 +459,7 @@ class TestGzip(BaseTest):
self.assertEqual(d, data1 * 50, "Incorrect data in file")
def test_gzip_BadGzipFile_exception(self):
- self.assertTrue(issubclass(gzip.BadGzipFile, OSError))
+ self.assertIsSubclass(gzip.BadGzipFile, OSError)
def test_bad_gzip_file(self):
with open(self.filename, 'wb') as file:
diff --git a/Lib/test/test_hashlib.py b/Lib/test/test_hashlib.py
index 5e3356a02f3..5bad483ae9d 100644
--- a/Lib/test/test_hashlib.py
+++ b/Lib/test/test_hashlib.py
@@ -12,12 +12,12 @@ import io
import itertools
import logging
import os
+import re
import sys
import sysconfig
import tempfile
import threading
import unittest
-import warnings
from test import support
from test.support import _4G, bigmemtest
from test.support import hashlib_helper
@@ -98,6 +98,14 @@ def read_vectors(hash_name):
yield parts
+DEPRECATED_STRING_PARAMETER = re.escape(
+ "the 'string' keyword parameter is deprecated since "
+ "Python 3.15 and slated for removal in Python 3.19; "
+ "use the 'data' keyword parameter or pass the data "
+ "to hash as a positional argument instead"
+)
+
+
class HashLibTestCase(unittest.TestCase):
supported_hash_names = ( 'md5', 'MD5', 'sha1', 'SHA1',
'sha224', 'SHA224', 'sha256', 'SHA256',
@@ -141,19 +149,18 @@ class HashLibTestCase(unittest.TestCase):
# of hashlib.new given the algorithm name.
for algorithm, constructors in self.constructors_to_test.items():
constructors.add(getattr(hashlib, algorithm))
- def _test_algorithm_via_hashlib_new(data=None, _alg=algorithm, **kwargs):
- if data is None:
- return hashlib.new(_alg, **kwargs)
- return hashlib.new(_alg, data, **kwargs)
- constructors.add(_test_algorithm_via_hashlib_new)
+ def c(*args, __algorithm_name=algorithm, **kwargs):
+ return hashlib.new(__algorithm_name, *args, **kwargs)
+ c.__name__ = f'do_test_algorithm_via_hashlib_new_{algorithm}'
+ constructors.add(c)
_hashlib = self._conditional_import_module('_hashlib')
self._hashlib = _hashlib
if _hashlib:
# These algorithms should always be present when this module
# is compiled. If not, something was compiled wrong.
- self.assertTrue(hasattr(_hashlib, 'openssl_md5'))
- self.assertTrue(hasattr(_hashlib, 'openssl_sha1'))
+ self.assertHasAttr(_hashlib, 'openssl_md5')
+ self.assertHasAttr(_hashlib, 'openssl_sha1')
for algorithm, constructors in self.constructors_to_test.items():
constructor = getattr(_hashlib, 'openssl_'+algorithm, None)
if constructor:
@@ -201,6 +208,11 @@ class HashLibTestCase(unittest.TestCase):
return itertools.chain.from_iterable(constructors)
@property
+ def shake_constructors(self):
+ for shake_name in self.shakes:
+ yield from self.constructors_to_test.get(shake_name, ())
+
+ @property
def is_fips_mode(self):
return get_fips_mode()
@@ -250,6 +262,85 @@ class HashLibTestCase(unittest.TestCase):
self._hashlib.new("md5", usedforsecurity=False)
self._hashlib.openssl_md5(usedforsecurity=False)
+ @unittest.skipIf(get_fips_mode(), "skip in FIPS mode")
+ def test_clinic_signature(self):
+ for constructor in self.hash_constructors:
+ with self.subTest(constructor.__name__):
+ constructor(b'')
+ constructor(data=b'')
+ with self.assertWarnsRegex(DeprecationWarning,
+ DEPRECATED_STRING_PARAMETER):
+ constructor(string=b'')
+
+ digest_name = constructor(b'').name
+ with self.subTest(digest_name):
+ hashlib.new(digest_name, b'')
+ hashlib.new(digest_name, data=b'')
+ with self.assertWarnsRegex(DeprecationWarning,
+ DEPRECATED_STRING_PARAMETER):
+ hashlib.new(digest_name, string=b'')
+ # Make sure that _hashlib contains the constructor
+ # to test when using a combination of libcrypto and
+ # interned hash implementations.
+ if self._hashlib and digest_name in self._hashlib._constructors:
+ self._hashlib.new(digest_name, b'')
+ self._hashlib.new(digest_name, data=b'')
+ with self.assertWarnsRegex(DeprecationWarning,
+ DEPRECATED_STRING_PARAMETER):
+ self._hashlib.new(digest_name, string=b'')
+
+ @unittest.skipIf(get_fips_mode(), "skip in FIPS mode")
+ def test_clinic_signature_errors(self):
+ nomsg = b''
+ mymsg = b'msg'
+ conflicting_call = re.escape(
+ "'data' and 'string' are mutually exclusive "
+ "and support for 'string' keyword parameter "
+ "is slated for removal in a future version."
+ )
+ duplicated_param = re.escape("given by name ('data') and position")
+ unexpected_param = re.escape("got an unexpected keyword argument '_'")
+ for args, kwds, errmsg in [
+ # Reject duplicated arguments before unknown keyword arguments.
+ ((nomsg,), dict(data=nomsg, _=nomsg), duplicated_param),
+ ((mymsg,), dict(data=nomsg, _=nomsg), duplicated_param),
+ # Reject duplicated arguments before conflicting ones.
+ *itertools.product(
+ [[nomsg], [mymsg]],
+ [dict(data=nomsg), dict(data=nomsg, string=nomsg)],
+ [duplicated_param]
+ ),
+ # Reject unknown keyword arguments before conflicting ones.
+ *itertools.product(
+ [()],
+ [
+ dict(_=None),
+ dict(data=nomsg, _=None),
+ dict(string=nomsg, _=None),
+ dict(string=nomsg, data=nomsg, _=None),
+ ],
+ [unexpected_param]
+ ),
+ ((nomsg,), dict(_=None), unexpected_param),
+ ((mymsg,), dict(_=None), unexpected_param),
+ # Reject conflicting arguments.
+ [(nomsg,), dict(string=nomsg), conflicting_call],
+ [(mymsg,), dict(string=nomsg), conflicting_call],
+ [(), dict(data=nomsg, string=nomsg), conflicting_call],
+ ]:
+ for constructor in self.hash_constructors:
+ digest_name = constructor(b'').name
+ with self.subTest(constructor.__name__, args=args, kwds=kwds):
+ with self.assertRaisesRegex(TypeError, errmsg):
+ constructor(*args, **kwds)
+ with self.subTest(digest_name, args=args, kwds=kwds):
+ with self.assertRaisesRegex(TypeError, errmsg):
+ hashlib.new(digest_name, *args, **kwds)
+ if (self._hashlib and
+ digest_name in self._hashlib._constructors):
+ with self.assertRaisesRegex(TypeError, errmsg):
+ self._hashlib.new(digest_name, *args, **kwds)
+
def test_unknown_hash(self):
self.assertRaises(ValueError, hashlib.new, 'spam spam spam spam spam')
self.assertRaises(TypeError, hashlib.new, 1)
@@ -284,6 +375,16 @@ class HashLibTestCase(unittest.TestCase):
self.assertIs(constructor, _md5.md5)
self.assertEqual(sorted(builtin_constructor_cache), ['MD5', 'md5'])
+ def test_copy(self):
+ for cons in self.hash_constructors:
+ h1 = cons(os.urandom(16), usedforsecurity=False)
+ h2 = h1.copy()
+ self.assertIs(type(h1), type(h2))
+ self.assertEqual(h1.name, h2.name)
+ size = (16,) if h1.name in self.shakes else ()
+ self.assertEqual(h1.digest(*size), h2.digest(*size))
+ self.assertEqual(h1.hexdigest(*size), h2.hexdigest(*size))
+
def test_hexdigest(self):
for cons in self.hash_constructors:
h = cons(usedforsecurity=False)
@@ -294,21 +395,50 @@ class HashLibTestCase(unittest.TestCase):
self.assertIsInstance(h.digest(), bytes)
self.assertEqual(hexstr(h.digest()), h.hexdigest())
- def test_digest_length_overflow(self):
- # See issue #34922
- large_sizes = (2**29, 2**32-10, 2**32+10, 2**61, 2**64-10, 2**64+10)
- for cons in self.hash_constructors:
- h = cons(usedforsecurity=False)
- if h.name not in self.shakes:
- continue
- if HASH is not None and isinstance(h, HASH):
- # _hashopenssl's take a size_t
- continue
- for digest in h.digest, h.hexdigest:
- self.assertRaises(ValueError, digest, -10)
- for length in large_sizes:
- with self.assertRaises((ValueError, OverflowError)):
- digest(length)
+ def test_shakes_zero_digest_length(self):
+ for constructor in self.shake_constructors:
+ with self.subTest(constructor=constructor):
+ h = constructor(b'abcdef', usedforsecurity=False)
+ self.assertEqual(h.digest(0), b'')
+ self.assertEqual(h.hexdigest(0), '')
+
+ def test_shakes_invalid_digest_length(self):
+ # See https://github.com/python/cpython/issues/79103.
+ for constructor in self.shake_constructors:
+ with self.subTest(constructor=constructor):
+ h = constructor(usedforsecurity=False)
+ # Note: digest() and hexdigest() take a signed input and
+ # raise if it is negative; the rationale is that we use
+ # internally PyBytes_FromStringAndSize() and _Py_strhex()
+ # which both take a Py_ssize_t.
+ for negative_size in (-1, -10, -(1 << 31), -sys.maxsize):
+ self.assertRaises(ValueError, h.digest, negative_size)
+ self.assertRaises(ValueError, h.hexdigest, negative_size)
+
+ def test_shakes_overflow_digest_length(self):
+ # See https://github.com/python/cpython/issues/135759.
+
+ exc_types = (OverflowError, ValueError)
+ # HACL* accepts an 'uint32_t' while OpenSSL accepts a 'size_t'.
+ openssl_overflown_sizes = (sys.maxsize + 1, 2 * sys.maxsize)
+ # https://github.com/python/cpython/issues/79103 restricts
+ # the accepted built-in lengths to 2 ** 29, even if OpenSSL
+ # accepts such lengths.
+ builtin_overflown_sizes = openssl_overflown_sizes + (
+ 2 ** 29, 2 ** 32 - 10, 2 ** 32, 2 ** 32 + 10,
+ 2 ** 61, 2 ** 64 - 10, 2 ** 64, 2 ** 64 + 10,
+ )
+
+ for constructor in self.shake_constructors:
+ with self.subTest(constructor=constructor):
+ h = constructor(usedforsecurity=False)
+ if HASH is not None and isinstance(h, HASH):
+ overflown_sizes = openssl_overflown_sizes
+ else:
+ overflown_sizes = builtin_overflown_sizes
+ for invalid_size in overflown_sizes:
+ self.assertRaises(exc_types, h.digest, invalid_size)
+ self.assertRaises(exc_types, h.hexdigest, invalid_size)
def test_name_attribute(self):
for cons in self.hash_constructors:
@@ -719,8 +849,6 @@ class HashLibTestCase(unittest.TestCase):
self.assertRaises(ValueError, constructor, node_offset=-1)
self.assertRaises(OverflowError, constructor, node_offset=max_offset+1)
- self.assertRaises(TypeError, constructor, data=b'')
- self.assertRaises(TypeError, constructor, string=b'')
self.assertRaises(TypeError, constructor, '')
constructor(
@@ -929,49 +1057,67 @@ class HashLibTestCase(unittest.TestCase):
def test_sha256_gil(self):
gil_minsize = hashlib_helper.find_gil_minsize(['_sha2', '_hashlib'])
+ data = b'1' + b'#' * gil_minsize + b'1'
+ expected = hashlib.sha256(data).hexdigest()
+
m = hashlib.sha256()
m.update(b'1')
m.update(b'#' * gil_minsize)
m.update(b'1')
- self.assertEqual(
- m.hexdigest(),
- '1cfceca95989f51f658e3f3ffe7f1cd43726c9e088c13ee10b46f57cef135b94'
- )
+ self.assertEqual(m.hexdigest(), expected)
- m = hashlib.sha256(b'1' + b'#' * gil_minsize + b'1')
- self.assertEqual(
- m.hexdigest(),
- '1cfceca95989f51f658e3f3ffe7f1cd43726c9e088c13ee10b46f57cef135b94'
- )
+ @threading_helper.reap_threads
+ @threading_helper.requires_working_threading()
+ def test_threaded_hashing_fast(self):
+ # Same as test_threaded_hashing_slow() but only tests some functions
+ # since otherwise test_hashlib.py becomes too slow during development.
+ for name in ['md5', 'sha1', 'sha256', 'sha3_256', 'blake2s']:
+ if constructor := getattr(hashlib, name, None):
+ with self.subTest(name):
+ self.do_test_threaded_hashing(constructor, is_shake=False)
+ if shake_128 := getattr(hashlib, 'shake_128', None):
+ self.do_test_threaded_hashing(shake_128, is_shake=True)
+ @requires_resource('cpu')
@threading_helper.reap_threads
@threading_helper.requires_working_threading()
- def test_threaded_hashing(self):
+ def test_threaded_hashing_slow(self):
+ for algorithm, constructors in self.constructors_to_test.items():
+ is_shake = algorithm in self.shakes
+ for constructor in constructors:
+ with self.subTest(constructor.__name__, is_shake=is_shake):
+ self.do_test_threaded_hashing(constructor, is_shake)
+
+ def do_test_threaded_hashing(self, constructor, is_shake):
# Updating the same hash object from several threads at once
# using data chunk sizes containing the same byte sequences.
#
# If the internal locks are working to prevent multiple
# updates on the same object from running at once, the resulting
# hash will be the same as doing it single threaded upfront.
- hasher = hashlib.sha1()
- num_threads = 5
- smallest_data = b'swineflu'
- data = smallest_data * 200000
- expected_hash = hashlib.sha1(data*num_threads).hexdigest()
-
- def hash_in_chunks(chunk_size):
- index = 0
- while index < len(data):
- hasher.update(data[index:index + chunk_size])
- index += chunk_size
+
+ # The data to hash has length s|M|q^N and the chunk size for the i-th
+ # thread is s|M|q^(N-i), where N is the number of threads, M is a fixed
+ # message of small length, and s >= 1 and q >= 2 are small integers.
+ smallest_size, num_threads, s, q = 8, 5, 2, 10
+
+ smallest_data = os.urandom(smallest_size)
+ data = s * smallest_data * (q ** num_threads)
+
+ h1 = constructor(usedforsecurity=False)
+ h2 = constructor(data * num_threads, usedforsecurity=False)
+
+ def update(chunk_size):
+ for index in range(0, len(data), chunk_size):
+ h1.update(data[index:index + chunk_size])
threads = []
- for threadnum in range(num_threads):
- chunk_size = len(data) // (10 ** threadnum)
+ for thread_num in range(num_threads):
+ # chunk_size = len(data) // (q ** thread_num)
+ chunk_size = s * smallest_size * q ** (num_threads - thread_num)
self.assertGreater(chunk_size, 0)
- self.assertEqual(chunk_size % len(smallest_data), 0)
- thread = threading.Thread(target=hash_in_chunks,
- args=(chunk_size,))
+ self.assertEqual(chunk_size % smallest_size, 0)
+ thread = threading.Thread(target=update, args=(chunk_size,))
threads.append(thread)
for thread in threads:
@@ -979,7 +1125,10 @@ class HashLibTestCase(unittest.TestCase):
for thread in threads:
thread.join()
- self.assertEqual(expected_hash, hasher.hexdigest())
+ if is_shake:
+ self.assertEqual(h1.hexdigest(16), h2.hexdigest(16))
+ else:
+ self.assertEqual(h1.hexdigest(), h2.hexdigest())
def test_get_fips_mode(self):
fips_mode = self.is_fips_mode
diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py
index 1aa8e4e2897..d6623fee9bb 100644
--- a/Lib/test/test_heapq.py
+++ b/Lib/test/test_heapq.py
@@ -13,8 +13,9 @@ c_heapq = import_helper.import_fresh_module('heapq', fresh=['_heapq'])
# _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when
# _heapq is imported, so check them there
-func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace',
- '_heappop_max', '_heapreplace_max', '_heapify_max']
+func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace']
+# Add max-heap variants
+func_names += [func + '_max' for func in func_names]
class TestModules(TestCase):
def test_py_functions(self):
@@ -24,7 +25,7 @@ class TestModules(TestCase):
@skipUnless(c_heapq, 'requires _heapq')
def test_c_functions(self):
for fname in func_names:
- self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
+ self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq', fname)
def load_tests(loader, tests, ignore):
@@ -74,6 +75,34 @@ class TestHeap:
except AttributeError:
pass
+ def test_max_push_pop(self):
+ # 1) Push 256 random numbers and pop them off, verifying all's OK.
+ heap = []
+ data = []
+ self.check_max_invariant(heap)
+ for i in range(256):
+ item = random.random()
+ data.append(item)
+ self.module.heappush_max(heap, item)
+ self.check_max_invariant(heap)
+ results = []
+ while heap:
+ item = self.module.heappop_max(heap)
+ self.check_max_invariant(heap)
+ results.append(item)
+ data_sorted = data[:]
+ data_sorted.sort(reverse=True)
+
+ self.assertEqual(data_sorted, results)
+ # 2) Check that the invariant holds for a sorted array
+ self.check_max_invariant(results)
+
+ self.assertRaises(TypeError, self.module.heappush_max, [])
+
+ exc_types = (AttributeError, TypeError)
+ self.assertRaises(exc_types, self.module.heappush_max, None, None)
+ self.assertRaises(exc_types, self.module.heappop_max, None)
+
def check_invariant(self, heap):
# Check the heap invariant.
for pos, item in enumerate(heap):
@@ -81,6 +110,11 @@ class TestHeap:
parentpos = (pos-1) >> 1
self.assertTrue(heap[parentpos] <= item)
+ def check_max_invariant(self, heap):
+ for pos, item in enumerate(heap[1:], start=1):
+ parentpos = (pos - 1) >> 1
+ self.assertGreaterEqual(heap[parentpos], item)
+
def test_heapify(self):
for size in list(range(30)) + [20000]:
heap = [random.random() for dummy in range(size)]
@@ -89,6 +123,14 @@ class TestHeap:
self.assertRaises(TypeError, self.module.heapify, None)
+ def test_heapify_max(self):
+ for size in list(range(30)) + [20000]:
+ heap = [random.random() for dummy in range(size)]
+ self.module.heapify_max(heap)
+ self.check_max_invariant(heap)
+
+ self.assertRaises(TypeError, self.module.heapify_max, None)
+
def test_naive_nbest(self):
data = [random.randrange(2000) for i in range(1000)]
heap = []
@@ -109,10 +151,7 @@ class TestHeap:
def test_nbest(self):
# Less-naive "N-best" algorithm, much faster (if len(data) is big
- # enough <wink>) than sorting all of data. However, if we had a max
- # heap instead of a min heap, it could go faster still via
- # heapify'ing all of data (linear time), then doing 10 heappops
- # (10 log-time steps).
+ # enough <wink>) than sorting all of data.
data = [random.randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify(heap)
@@ -125,6 +164,17 @@ class TestHeap:
self.assertRaises(TypeError, self.module.heapreplace, None, None)
self.assertRaises(IndexError, self.module.heapreplace, [], None)
+ def test_nbest_maxheap(self):
+ # With a max heap instead of a min heap, the "N-best" algorithm can
+ # go even faster still via heapify'ing all of data (linear time), then
+ # doing 10 heappops (10 log-time steps).
+ data = [random.randrange(2000) for i in range(1000)]
+ heap = data[:]
+ self.module.heapify_max(heap)
+ result = [self.module.heappop_max(heap) for _ in range(10)]
+ result.reverse()
+ self.assertEqual(result, sorted(data)[-10:])
+
def test_nbest_with_pushpop(self):
data = [random.randrange(2000) for i in range(1000)]
heap = data[:10]
@@ -134,6 +184,62 @@ class TestHeap:
self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
self.assertEqual(self.module.heappushpop([], 'x'), 'x')
+ def test_naive_nworst(self):
+ # Max-heap variant of "test_naive_nbest"
+ data = [random.randrange(2000) for i in range(1000)]
+ heap = []
+ for item in data:
+ self.module.heappush_max(heap, item)
+ if len(heap) > 10:
+ self.module.heappop_max(heap)
+ heap.sort()
+ expected = sorted(data)[:10]
+ self.assertEqual(heap, expected)
+
+ def heapiter_max(self, heap):
+ # An iterator returning a max-heap's elements, largest-first.
+ try:
+ while 1:
+ yield self.module.heappop_max(heap)
+ except IndexError:
+ pass
+
+ def test_nworst(self):
+ # Max-heap variant of "test_nbest"
+ data = [random.randrange(2000) for i in range(1000)]
+ heap = data[:10]
+ self.module.heapify_max(heap)
+ for item in data[10:]:
+ if item < heap[0]: # this gets rarer the longer we run
+ self.module.heapreplace_max(heap, item)
+ expected = sorted(data, reverse=True)[-10:]
+ self.assertEqual(list(self.heapiter_max(heap)), expected)
+
+ self.assertRaises(TypeError, self.module.heapreplace_max, None)
+ self.assertRaises(TypeError, self.module.heapreplace_max, None, None)
+ self.assertRaises(IndexError, self.module.heapreplace_max, [], None)
+
+ def test_nworst_minheap(self):
+ # Min-heap variant of "test_nbest_maxheap"
+ data = [random.randrange(2000) for i in range(1000)]
+ heap = data[:]
+ self.module.heapify(heap)
+ result = [self.module.heappop(heap) for _ in range(10)]
+ result.reverse()
+ expected = sorted(data, reverse=True)[-10:]
+ self.assertEqual(result, expected)
+
+ def test_nworst_with_pushpop(self):
+ # Max-heap variant of "test_nbest_with_pushpop"
+ data = [random.randrange(2000) for i in range(1000)]
+ heap = data[:10]
+ self.module.heapify_max(heap)
+ for item in data[10:]:
+ self.module.heappushpop_max(heap, item)
+ expected = sorted(data, reverse=True)[-10:]
+ self.assertEqual(list(self.heapiter_max(heap)), expected)
+ self.assertEqual(self.module.heappushpop_max([], 'x'), 'x')
+
def test_heappushpop(self):
h = []
x = self.module.heappushpop(h, 10)
@@ -153,12 +259,31 @@ class TestHeap:
x = self.module.heappushpop(h, 11)
self.assertEqual((h, x), ([11], 10))
+ def test_heappushpop_max(self):
+ h = []
+ x = self.module.heappushpop_max(h, 10)
+ self.assertTupleEqual((h, x), ([], 10))
+
+ h = [10]
+ x = self.module.heappushpop_max(h, 10.0)
+ self.assertTupleEqual((h, x), ([10], 10.0))
+ self.assertIsInstance(h[0], int)
+ self.assertIsInstance(x, float)
+
+ h = [10]
+ x = self.module.heappushpop_max(h, 11)
+ self.assertTupleEqual((h, x), ([10], 11))
+
+ h = [10]
+ x = self.module.heappushpop_max(h, 9)
+ self.assertTupleEqual((h, x), ([9], 10))
+
def test_heappop_max(self):
- # _heapop_max has an optimization for one-item lists which isn't
+ # heappop_max has an optimization for one-item lists which isn't
# covered in other tests, so test that case explicitly here
h = [3, 2]
- self.assertEqual(self.module._heappop_max(h), 3)
- self.assertEqual(self.module._heappop_max(h), 2)
+ self.assertEqual(self.module.heappop_max(h), 3)
+ self.assertEqual(self.module.heappop_max(h), 2)
def test_heapsort(self):
# Exercise everything with repeated heapsort checks
@@ -175,6 +300,20 @@ class TestHeap:
heap_sorted = [self.module.heappop(heap) for i in range(size)]
self.assertEqual(heap_sorted, sorted(data))
+ def test_heapsort_max(self):
+ for trial in range(100):
+ size = random.randrange(50)
+ data = [random.randrange(25) for i in range(size)]
+ if trial & 1: # Half of the time, use heapify_max
+ heap = data[:]
+ self.module.heapify_max(heap)
+ else: # The rest of the time, use heappush_max
+ heap = []
+ for item in data:
+ self.module.heappush_max(heap, item)
+ heap_sorted = [self.module.heappop_max(heap) for i in range(size)]
+ self.assertEqual(heap_sorted, sorted(data, reverse=True))
+
def test_merge(self):
inputs = []
for i in range(random.randrange(25)):
@@ -377,16 +516,20 @@ class SideEffectLT:
class TestErrorHandling:
def test_non_sequence(self):
- for f in (self.module.heapify, self.module.heappop):
+ for f in (self.module.heapify, self.module.heappop,
+ self.module.heapify_max, self.module.heappop_max):
self.assertRaises((TypeError, AttributeError), f, 10)
for f in (self.module.heappush, self.module.heapreplace,
+ self.module.heappush_max, self.module.heapreplace_max,
self.module.nlargest, self.module.nsmallest):
self.assertRaises((TypeError, AttributeError), f, 10, 10)
def test_len_only(self):
- for f in (self.module.heapify, self.module.heappop):
+ for f in (self.module.heapify, self.module.heappop,
+ self.module.heapify_max, self.module.heappop_max):
self.assertRaises((TypeError, AttributeError), f, LenOnly())
- for f in (self.module.heappush, self.module.heapreplace):
+ for f in (self.module.heappush, self.module.heapreplace,
+ self.module.heappush_max, self.module.heapreplace_max):
self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10)
for f in (self.module.nlargest, self.module.nsmallest):
self.assertRaises(TypeError, f, 2, LenOnly())
@@ -395,7 +538,8 @@ class TestErrorHandling:
seq = [CmpErr(), CmpErr(), CmpErr()]
for f in (self.module.heapify, self.module.heappop):
self.assertRaises(ZeroDivisionError, f, seq)
- for f in (self.module.heappush, self.module.heapreplace):
+ for f in (self.module.heappush, self.module.heapreplace,
+ self.module.heappush_max, self.module.heapreplace_max):
self.assertRaises(ZeroDivisionError, f, seq, 10)
for f in (self.module.nlargest, self.module.nsmallest):
self.assertRaises(ZeroDivisionError, f, 2, seq)
@@ -403,6 +547,8 @@ class TestErrorHandling:
def test_arg_parsing(self):
for f in (self.module.heapify, self.module.heappop,
self.module.heappush, self.module.heapreplace,
+ self.module.heapify_max, self.module.heappop_max,
+ self.module.heappush_max, self.module.heapreplace_max,
self.module.nlargest, self.module.nsmallest):
self.assertRaises((TypeError, AttributeError), f, 10)
@@ -424,6 +570,10 @@ class TestErrorHandling:
# Python version raises IndexError, C version RuntimeError
with self.assertRaises((IndexError, RuntimeError)):
self.module.heappush(heap, SideEffectLT(5, heap))
+ heap = []
+ heap.extend(SideEffectLT(i, heap) for i in range(200))
+ with self.assertRaises((IndexError, RuntimeError)):
+ self.module.heappush_max(heap, SideEffectLT(5, heap))
def test_heappop_mutating_heap(self):
heap = []
@@ -431,8 +581,12 @@ class TestErrorHandling:
# Python version raises IndexError, C version RuntimeError
with self.assertRaises((IndexError, RuntimeError)):
self.module.heappop(heap)
+ heap = []
+ heap.extend(SideEffectLT(i, heap) for i in range(200))
+ with self.assertRaises((IndexError, RuntimeError)):
+ self.module.heappop_max(heap)
- def test_comparison_operator_modifiying_heap(self):
+ def test_comparison_operator_modifying_heap(self):
# See bpo-39421: Strong references need to be taken
# when comparing objects as they can alter the heap
class EvilClass(int):
@@ -444,7 +598,7 @@ class TestErrorHandling:
self.module.heappush(heap, EvilClass(0))
self.assertRaises(IndexError, self.module.heappushpop, heap, 1)
- def test_comparison_operator_modifiying_heap_two_heaps(self):
+ def test_comparison_operator_modifying_heap_two_heaps(self):
class h(int):
def __lt__(self, o):
@@ -464,6 +618,17 @@ class TestErrorHandling:
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))
+ list1, list2 = [], []
+
+ self.module.heappush_max(list1, h(0))
+ self.module.heappush_max(list2, g(0))
+ self.module.heappush_max(list1, g(1))
+ self.module.heappush_max(list2, h(1))
+
+ self.assertRaises((IndexError, RuntimeError), self.module.heappush_max, list1, g(1))
+ self.assertRaises((IndexError, RuntimeError), self.module.heappush_max, list2, h(1))
+
+
class TestErrorHandlingPython(TestErrorHandling, TestCase):
module = py_heapq
diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py
index 70c79437722..ff6e1bce0ef 100644
--- a/Lib/test/test_hmac.py
+++ b/Lib/test/test_hmac.py
@@ -1,3 +1,21 @@
+"""Test suite for HMAC.
+
+Python provides three different implementations of HMAC:
+
+- OpenSSL HMAC using OpenSSL hash functions.
+- HACL* HMAC using HACL* hash functions.
+- Generic Python HMAC using user-defined hash functions.
+
+The generic Python HMAC implementation is able to use OpenSSL
+callables or names, HACL* named hash functions or arbitrary
+objects implementing PEP 247 interface.
+
+In the two first cases, Python HMAC wraps a C HMAC object (either OpenSSL
+or HACL*-based). As a last resort, HMAC is re-implemented in pure Python.
+It is however interesting to test the pure Python implementation against
+the OpenSSL and HACL* hash functions.
+"""
+
import binascii
import functools
import hmac
@@ -10,6 +28,12 @@ import unittest.mock as mock
import warnings
from _operator import _compare_digest as operator_compare_digest
from test.support import check_disallow_instantiation
+from test.support.hashlib_helper import (
+ BuiltinHashFunctionsTrait,
+ HashFunctionsTrait,
+ NamedHashFunctionsTrait,
+ OpenSSLHashFunctionsTrait,
+)
from test.support.import_helper import import_fresh_module, import_module
try:
@@ -382,50 +406,7 @@ class BuiltinAssertersMixin(ThroughBuiltinAPIMixin, AssertersMixin):
pass
-class HashFunctionsTrait:
- """Trait class for 'hashfunc' in hmac_new() and hmac_digest()."""
-
- ALGORITHMS = [
- 'md5', 'sha1',
- 'sha224', 'sha256', 'sha384', 'sha512',
- 'sha3_224', 'sha3_256', 'sha3_384', 'sha3_512',
- ]
-
- # By default, a missing algorithm skips the test that uses it.
- _ = property(lambda self: self.skipTest("missing hash function"))
- md5 = sha1 = _
- sha224 = sha256 = sha384 = sha512 = _
- sha3_224 = sha3_256 = sha3_384 = sha3_512 = _
- del _
-
-
-class WithOpenSSLHashFunctions(HashFunctionsTrait):
- """Test a HMAC implementation with an OpenSSL-based callable 'hashfunc'."""
-
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
-
- for name in cls.ALGORITHMS:
- @property
- @hashlib_helper.requires_openssl_hashdigest(name)
- def func(self, *, __name=name): # __name needed to bind 'name'
- return getattr(_hashlib, f'openssl_{__name}')
- setattr(cls, name, func)
-
-
-class WithNamedHashFunctions(HashFunctionsTrait):
- """Test a HMAC implementation with a named 'hashfunc'."""
-
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
-
- for name in cls.ALGORITHMS:
- setattr(cls, name, name)
-
-
-class RFCTestCaseMixin(AssertersMixin, HashFunctionsTrait):
+class RFCTestCaseMixin(HashFunctionsTrait, AssertersMixin):
"""Test HMAC implementations against RFC 2202/4231 and NIST test vectors.
- Test vectors for MD5 and SHA-1 are taken from RFC 2202.
@@ -739,26 +720,83 @@ class RFCTestCaseMixin(AssertersMixin, HashFunctionsTrait):
)
-class PyRFCTestCase(ThroughObjectMixin, PyAssertersMixin,
- WithOpenSSLHashFunctions, RFCTestCaseMixin,
- unittest.TestCase):
+class PurePythonInitHMAC(PyModuleMixin, HashFunctionsTrait):
+
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ for meth in ['_init_openssl_hmac', '_init_builtin_hmac']:
+ fn = getattr(cls.hmac.HMAC, meth)
+ cm = mock.patch.object(cls.hmac.HMAC, meth, autospec=True, wraps=fn)
+ cls.enterClassContext(cm)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.hmac.HMAC._init_openssl_hmac.assert_not_called()
+ cls.hmac.HMAC._init_builtin_hmac.assert_not_called()
+ # Do not assert that HMAC._init_old() has been called as it's tricky
+ # to determine whether a test for a specific hash function has been
+ # executed or not. On regular builds, it will be called but if a
+ # hash function is not available, it's hard to detect for which
+ # test we should check HMAC._init_old() or not.
+ super().tearDownClass()
+
+
+class PyRFCOpenSSLTestCase(ThroughObjectMixin,
+ PyAssertersMixin,
+ OpenSSLHashFunctionsTrait,
+ RFCTestCaseMixin,
+ PurePythonInitHMAC,
+ unittest.TestCase):
"""Python implementation of HMAC using hmac.HMAC().
- The underlying hash functions are OpenSSL-based.
+ The underlying hash functions are OpenSSL-based but
+ _init_old() is used instead of _init_openssl_hmac().
"""
-class PyDotNewRFCTestCase(ThroughModuleAPIMixin, PyAssertersMixin,
- WithOpenSSLHashFunctions, RFCTestCaseMixin,
- unittest.TestCase):
+class PyRFCBuiltinTestCase(ThroughObjectMixin,
+ PyAssertersMixin,
+ BuiltinHashFunctionsTrait,
+ RFCTestCaseMixin,
+ PurePythonInitHMAC,
+ unittest.TestCase):
+ """Python implementation of HMAC using hmac.HMAC().
+
+ The underlying hash functions are HACL*-based but
+ _init_old() is used instead of _init_builtin_hmac().
+ """
+
+
+class PyDotNewOpenSSLRFCTestCase(ThroughModuleAPIMixin,
+ PyAssertersMixin,
+ OpenSSLHashFunctionsTrait,
+ RFCTestCaseMixin,
+ PurePythonInitHMAC,
+ unittest.TestCase):
+ """Python implementation of HMAC using hmac.new().
+
+ The underlying hash functions are OpenSSL-based but
+ _init_old() is used instead of _init_openssl_hmac().
+ """
+
+
+class PyDotNewBuiltinRFCTestCase(ThroughModuleAPIMixin,
+ PyAssertersMixin,
+ BuiltinHashFunctionsTrait,
+ RFCTestCaseMixin,
+ PurePythonInitHMAC,
+ unittest.TestCase):
"""Python implementation of HMAC using hmac.new().
- The underlying hash functions are OpenSSL-based.
+ The underlying hash functions are HACL*-based but
+ _init_old() is used instead of _init_builtin_hmac().
"""
class OpenSSLRFCTestCase(OpenSSLAssertersMixin,
- WithOpenSSLHashFunctions, RFCTestCaseMixin,
+ OpenSSLHashFunctionsTrait,
+ RFCTestCaseMixin,
unittest.TestCase):
"""OpenSSL implementation of HMAC.
@@ -767,7 +805,8 @@ class OpenSSLRFCTestCase(OpenSSLAssertersMixin,
class BuiltinRFCTestCase(BuiltinAssertersMixin,
- WithNamedHashFunctions, RFCTestCaseMixin,
+ NamedHashFunctionsTrait,
+ RFCTestCaseMixin,
unittest.TestCase):
"""Built-in HACL* implementation of HMAC.
@@ -784,12 +823,6 @@ class BuiltinRFCTestCase(BuiltinAssertersMixin,
self.check_hmac_hexdigest(key, msg, hexdigest, digest_size, func)
-# TODO(picnixz): once we have a HACL* HMAC, we should also test the Python
-# implementation of HMAC with a HACL*-based hash function. For now, we only
-# test it partially via the '_sha2' module, but for completeness we could
-# also test the RFC test vectors against all possible implementations.
-
-
class DigestModTestCaseMixin(CreatorMixin, DigestMixin):
"""Tests for the 'digestmod' parameter for hmac_new() and hmac_digest()."""
diff --git a/Lib/test/test_htmlparser.py b/Lib/test/test_htmlparser.py
index b42a611c62c..65a4bee72b9 100644
--- a/Lib/test/test_htmlparser.py
+++ b/Lib/test/test_htmlparser.py
@@ -5,6 +5,7 @@ import pprint
import unittest
from unittest.mock import patch
+from test import support
class EventCollector(html.parser.HTMLParser):
@@ -317,6 +318,16 @@ text
("endtag", element_lower)],
collector=Collector(convert_charrefs=False))
+ def test_EOF_in_cdata(self):
+ content = """<!-- not a comment --> &not-an-entity-ref;
+ <a href="" /> </p><p> <span></span></style>
+ '</script' + '>'"""
+ s = f'<script>{content}'
+ self._run_check(s, [
+ ("starttag", 'script', []),
+ ("data", content)
+ ])
+
def test_comments(self):
html = ("<!-- I'm a valid comment -->"
'<!--me too!-->'
@@ -348,18 +359,16 @@ text
collector = lambda: EventCollectorCharrefs()
self.assertTrue(collector().convert_charrefs)
charrefs = ['&quot;', '&#34;', '&#x22;', '&quot', '&#34', '&#x22']
- # check charrefs in the middle of the text/attributes
- expected = [('starttag', 'a', [('href', 'foo"zar')]),
- ('data', 'a"z'), ('endtag', 'a')]
+ # check charrefs in the middle of the text
+ expected = [('starttag', 'a', []), ('data', 'a"z'), ('endtag', 'a')]
for charref in charrefs:
- self._run_check('<a href="foo{0}zar">a{0}z</a>'.format(charref),
+ self._run_check('<a>a{0}z</a>'.format(charref),
expected, collector=collector())
- # check charrefs at the beginning/end of the text/attributes
- expected = [('data', '"'),
- ('starttag', 'a', [('x', '"'), ('y', '"X'), ('z', 'X"')]),
+ # check charrefs at the beginning/end of the text
+ expected = [('data', '"'), ('starttag', 'a', []),
('data', '"'), ('endtag', 'a'), ('data', '"')]
for charref in charrefs:
- self._run_check('{0}<a x="{0}" y="{0}X" z="X{0}">'
+ self._run_check('{0}<a>'
'{0}</a>{0}'.format(charref),
expected, collector=collector())
# check charrefs in <script>/<style> elements
@@ -382,6 +391,35 @@ text
self._run_check('no charrefs here', [('data', 'no charrefs here')],
collector=collector())
+ def test_convert_charrefs_in_attribute_values(self):
+ # default value for convert_charrefs is now True
+ collector = lambda: EventCollectorCharrefs()
+ self.assertTrue(collector().convert_charrefs)
+
+ # always unescape terminated entity refs, numeric and hex char refs:
+ # - regardless whether they are at start, middle, end of attribute
+ # - or followed by alphanumeric, non-alphanumeric, or equals char
+ charrefs = ['&cent;', '&#xa2;', '&#xa2', '&#162;', '&#162']
+ expected = [('starttag', 'a',
+ [('x', '¢'), ('x', 'z¢'), ('x', '¢z'),
+ ('x', 'z¢z'), ('x', '¢ z'), ('x', '¢=z')]),
+ ('endtag', 'a')]
+ for charref in charrefs:
+ self._run_check('<a x="{0}" x="z{0}" x="{0}z" '
+ ' x="z{0}z" x="{0} z" x="{0}=z"></a>'
+ .format(charref), expected, collector=collector())
+
+ # only unescape unterminated entity matches if they are not followed by
+ # an alphanumeric or an equals sign
+ charref = '&cent'
+ expected = [('starttag', 'a',
+ [('x', '¢'), ('x', 'z¢'), ('x', '&centz'),
+ ('x', 'z&centz'), ('x', '¢ z'), ('x', '&cent=z')]),
+ ('endtag', 'a')]
+ self._run_check('<a x="{0}" x="z{0}" x="{0}z" '
+ ' x="z{0}z" x="{0} z" x="{0}=z"></a>'
+ .format(charref), expected, collector=collector())
+
# the remaining tests were for the "tolerant" parser (which is now
# the default), and check various kind of broken markup
def test_tolerant_parsing(self):
@@ -393,28 +431,34 @@ text
('data', '<'),
('starttag', 'bc<', [('a', None)]),
('endtag', 'html'),
- ('data', '\n<img src="URL>'),
- ('comment', '/img'),
- ('endtag', 'html<')])
+ ('data', '\n')])
def test_starttag_junk_chars(self):
+ self._run_check("<", [('data', '<')])
+ self._run_check("<>", [('data', '<>')])
+ self._run_check("< >", [('data', '< >')])
+ self._run_check("< ", [('data', '< ')])
self._run_check("</>", [])
+ self._run_check("<$>", [('data', '<$>')])
self._run_check("</$>", [('comment', '$')])
self._run_check("</", [('data', '</')])
- self._run_check("</a", [('data', '</a')])
+ self._run_check("</a", [])
+ self._run_check("</ a>", [('endtag', 'a')])
+ self._run_check("</ a", [('comment', ' a')])
self._run_check("<a<a>", [('starttag', 'a<a', [])])
self._run_check("</a<a>", [('endtag', 'a<a')])
- self._run_check("<!", [('data', '<!')])
- self._run_check("<a", [('data', '<a')])
- self._run_check("<a foo='bar'", [('data', "<a foo='bar'")])
- self._run_check("<a foo='bar", [('data', "<a foo='bar")])
- self._run_check("<a foo='>'", [('data', "<a foo='>'")])
- self._run_check("<a foo='>", [('data', "<a foo='>")])
+ self._run_check("<!", [('comment', '')])
+ self._run_check("<a", [])
+ self._run_check("<a foo='bar'", [])
+ self._run_check("<a foo='bar", [])
+ self._run_check("<a foo='>'", [])
+ self._run_check("<a foo='>", [])
self._run_check("<a$>", [('starttag', 'a$', [])])
self._run_check("<a$b>", [('starttag', 'a$b', [])])
self._run_check("<a$b/>", [('startendtag', 'a$b', [])])
self._run_check("<a$b >", [('starttag', 'a$b', [])])
self._run_check("<a$b />", [('startendtag', 'a$b', [])])
+ self._run_check("</a$b>", [('endtag', 'a$b')])
def test_slashes_in_starttag(self):
self._run_check('<a foo="var"/>', [('startendtag', 'a', [('foo', 'var')])])
@@ -539,52 +583,129 @@ text
for html, expected in data:
self._run_check(html, expected)
- def test_broken_comments(self):
- html = ('<! not really a comment >'
+ def test_eof_in_comments(self):
+ data = [
+ ('<!--', [('comment', '')]),
+ ('<!---', [('comment', '')]),
+ ('<!----', [('comment', '')]),
+ ('<!-----', [('comment', '-')]),
+ ('<!------', [('comment', '--')]),
+ ('<!----!', [('comment', '')]),
+ ('<!---!', [('comment', '-!')]),
+ ('<!---!>', [('comment', '-!>')]),
+ ('<!--foo', [('comment', 'foo')]),
+ ('<!--foo-', [('comment', 'foo')]),
+ ('<!--foo--', [('comment', 'foo')]),
+ ('<!--foo--!', [('comment', 'foo')]),
+ ('<!--<!--', [('comment', '<!')]),
+ ('<!--<!--!', [('comment', '<!')]),
+ ]
+ for html, expected in data:
+ self._run_check(html, expected)
+
+ def test_eof_in_declarations(self):
+ data = [
+ ('<!', [('comment', '')]),
+ ('<!-', [('comment', '-')]),
+ ('<![', [('comment', '[')]),
+ ('<![CDATA[', [('unknown decl', 'CDATA[')]),
+ ('<![CDATA[x', [('unknown decl', 'CDATA[x')]),
+ ('<![CDATA[x]', [('unknown decl', 'CDATA[x]')]),
+ ('<![CDATA[x]]', [('unknown decl', 'CDATA[x]]')]),
+ ('<!DOCTYPE', [('decl', 'DOCTYPE')]),
+ ('<!DOCTYPE ', [('decl', 'DOCTYPE ')]),
+ ('<!DOCTYPE html', [('decl', 'DOCTYPE html')]),
+ ('<!DOCTYPE html ', [('decl', 'DOCTYPE html ')]),
+ ('<!DOCTYPE html PUBLIC', [('decl', 'DOCTYPE html PUBLIC')]),
+ ('<!DOCTYPE html PUBLIC "foo', [('decl', 'DOCTYPE html PUBLIC "foo')]),
+ ('<!DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01//EN" "foo',
+ [('decl', 'DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01//EN" "foo')]),
+ ]
+ for html, expected in data:
+ self._run_check(html, expected)
+
+ def test_bogus_comments(self):
+ html = ('<!ELEMENT br EMPTY>'
+ '<! not really a comment >'
'<! not a comment either -->'
'<! -- close enough -->'
'<!><!<-- this was an empty comment>'
- '<!!! another bogus comment !!!>')
+ '<!!! another bogus comment !!!>'
+ # see #32876
+ '<![with square brackets]!>'
+ '<![\nmultiline\nbogusness\n]!>'
+ '<![more brackets]-[and a hyphen]!>'
+ '<![cdata[should be uppercase]]>'
+ '<![CDATA [whitespaces are not ignored]]>'
+ '<![CDATA]]>' # required '[' after CDATA
+ )
expected = [
+ ('comment', 'ELEMENT br EMPTY'),
('comment', ' not really a comment '),
('comment', ' not a comment either --'),
('comment', ' -- close enough --'),
('comment', ''),
('comment', '<-- this was an empty comment'),
('comment', '!! another bogus comment !!!'),
+ ('comment', '[with square brackets]!'),
+ ('comment', '[\nmultiline\nbogusness\n]!'),
+ ('comment', '[more brackets]-[and a hyphen]!'),
+ ('comment', '[cdata[should be uppercase]]'),
+ ('comment', '[CDATA [whitespaces are not ignored]]'),
+ ('comment', '[CDATA]]'),
]
self._run_check(html, expected)
def test_broken_condcoms(self):
# these condcoms are missing the '--' after '<!' and before the '>'
+ # and they are considered bogus comments according to
+ # "8.2.4.42. Markup declaration open state"
html = ('<![if !(IE)]>broken condcom<![endif]>'
'<![if ! IE]><link href="favicon.tiff"/><![endif]>'
'<![if !IE 6]><img src="firefox.png" /><![endif]>'
'<![if !ie 6]><b>foo</b><![endif]>'
'<![if (!IE)|(lt IE 9)]><img src="mammoth.bmp" /><![endif]>')
- # According to the HTML5 specs sections "8.2.4.44 Bogus comment state"
- # and "8.2.4.45 Markup declaration open state", comment tokens should
- # be emitted instead of 'unknown decl', but calling unknown_decl
- # provides more flexibility.
- # See also Lib/_markupbase.py:parse_declaration
expected = [
- ('unknown decl', 'if !(IE)'),
+ ('comment', '[if !(IE)]'),
('data', 'broken condcom'),
- ('unknown decl', 'endif'),
- ('unknown decl', 'if ! IE'),
+ ('comment', '[endif]'),
+ ('comment', '[if ! IE]'),
('startendtag', 'link', [('href', 'favicon.tiff')]),
- ('unknown decl', 'endif'),
- ('unknown decl', 'if !IE 6'),
+ ('comment', '[endif]'),
+ ('comment', '[if !IE 6]'),
('startendtag', 'img', [('src', 'firefox.png')]),
- ('unknown decl', 'endif'),
- ('unknown decl', 'if !ie 6'),
+ ('comment', '[endif]'),
+ ('comment', '[if !ie 6]'),
('starttag', 'b', []),
('data', 'foo'),
('endtag', 'b'),
- ('unknown decl', 'endif'),
- ('unknown decl', 'if (!IE)|(lt IE 9)'),
+ ('comment', '[endif]'),
+ ('comment', '[if (!IE)|(lt IE 9)]'),
('startendtag', 'img', [('src', 'mammoth.bmp')]),
- ('unknown decl', 'endif')
+ ('comment', '[endif]')
+ ]
+ self._run_check(html, expected)
+
+ def test_cdata_declarations(self):
+ # More tests should be added. See also "8.2.4.42. Markup
+ # declaration open state", "8.2.4.69. CDATA section state",
+ # and issue 32876
+ html = ('<![CDATA[just some plain text]]>')
+ expected = [('unknown decl', 'CDATA[just some plain text')]
+ self._run_check(html, expected)
+
+ def test_cdata_declarations_multiline(self):
+ html = ('<code><![CDATA['
+ ' if (a < b && a > b) {'
+ ' printf("[<marquee>How?</marquee>]");'
+ ' }'
+ ']]></code>')
+ expected = [
+ ('starttag', 'code', []),
+ ('unknown decl',
+ 'CDATA[ if (a < b && a > b) { '
+ 'printf("[<marquee>How?</marquee>]"); }'),
+ ('endtag', 'code')
]
self._run_check(html, expected)
@@ -600,6 +721,26 @@ text
('endtag', 'a'), ('data', ' bar & baz')]
)
+ @support.requires_resource('cpu')
+ def test_eof_no_quadratic_complexity(self):
+ # Each of these examples used to take about an hour.
+ # Now they take a fraction of a second.
+ def check(source):
+ parser = html.parser.HTMLParser()
+ parser.feed(source)
+ parser.close()
+ n = 120_000
+ check("<a " * n)
+ check("<a a=" * n)
+ check("</a " * 14 * n)
+ check("</a a=" * 11 * n)
+ check("<!--" * 4 * n)
+ check("<!" * 60 * n)
+ check("<?" * 19 * n)
+ check("</$" * 15 * n)
+ check("<![CDATA[" * 9 * n)
+ check("<!doctype" * 35 * n)
+
class AttributesTestCase(TestCaseBase):
diff --git a/Lib/test/test_http_cookiejar.py b/Lib/test/test_http_cookiejar.py
index 6bc33b15ec3..04cb440cd4c 100644
--- a/Lib/test/test_http_cookiejar.py
+++ b/Lib/test/test_http_cookiejar.py
@@ -4,6 +4,7 @@ import os
import stat
import sys
import re
+from test import support
from test.support import os_helper
from test.support import warnings_helper
import time
@@ -105,8 +106,7 @@ class DateTimeTests(unittest.TestCase):
self.assertEqual(http2time(s.lower()), test_t, s.lower())
self.assertEqual(http2time(s.upper()), test_t, s.upper())
- def test_http2time_garbage(self):
- for test in [
+ @support.subTests('test', [
'',
'Garbage',
'Mandag 16. September 1996',
@@ -121,10 +121,9 @@ class DateTimeTests(unittest.TestCase):
'08-01-3697739',
'09 Feb 19942632 22:23:32 GMT',
'Wed, 09 Feb 1994834 22:23:32 GMT',
- ]:
- self.assertIsNone(http2time(test),
- "http2time(%s) is not None\n"
- "http2time(test) %s" % (test, http2time(test)))
+ ])
+ def test_http2time_garbage(self, test):
+ self.assertIsNone(http2time(test))
def test_http2time_redos_regression_actually_completes(self):
# LOOSE_HTTP_DATE_RE was vulnerable to malicious input which caused catastrophic backtracking (REDoS).
@@ -149,9 +148,7 @@ class DateTimeTests(unittest.TestCase):
self.assertEqual(parse_date("1994-02-03 19:45:29 +0530"),
(1994, 2, 3, 14, 15, 29))
- def test_iso2time_formats(self):
- # test iso2time for supported dates.
- tests = [
+ @support.subTests('s', [
'1994-02-03 00:00:00 -0000', # ISO 8601 format
'1994-02-03 00:00:00 +0000', # ISO 8601 format
'1994-02-03 00:00:00', # zone is optional
@@ -164,16 +161,15 @@ class DateTimeTests(unittest.TestCase):
# A few tests with extra space at various places
' 1994-02-03 ',
' 1994-02-03T00:00:00 ',
- ]
-
+ ])
+ def test_iso2time_formats(self, s):
+ # test iso2time for supported dates.
test_t = 760233600 # assume broken POSIX counting of seconds
- for s in tests:
- self.assertEqual(iso2time(s), test_t, s)
- self.assertEqual(iso2time(s.lower()), test_t, s.lower())
- self.assertEqual(iso2time(s.upper()), test_t, s.upper())
+ self.assertEqual(iso2time(s), test_t, s)
+ self.assertEqual(iso2time(s.lower()), test_t, s.lower())
+ self.assertEqual(iso2time(s.upper()), test_t, s.upper())
- def test_iso2time_garbage(self):
- for test in [
+ @support.subTests('test', [
'',
'Garbage',
'Thursday, 03-Feb-94 00:00:00 GMT',
@@ -186,9 +182,9 @@ class DateTimeTests(unittest.TestCase):
'01-01-1980 00:00:62',
'01-01-1980T00:00:62',
'19800101T250000Z',
- ]:
- self.assertIsNone(iso2time(test),
- "iso2time(%r)" % test)
+ ])
+ def test_iso2time_garbage(self, test):
+ self.assertIsNone(iso2time(test))
def test_iso2time_performance_regression(self):
# If ISO_DATE_RE regresses to quadratic complexity, this test will take a very long time to succeed.
@@ -199,24 +195,23 @@ class DateTimeTests(unittest.TestCase):
class HeaderTests(unittest.TestCase):
- def test_parse_ns_headers(self):
- # quotes should be stripped
- expected = [[('foo', 'bar'), ('expires', 2209069412), ('version', '0')]]
- for hdr in [
+ @support.subTests('hdr', [
'foo=bar; expires=01 Jan 2040 22:23:32 GMT',
'foo=bar; expires="01 Jan 2040 22:23:32 GMT"',
- ]:
- self.assertEqual(parse_ns_headers([hdr]), expected)
-
- def test_parse_ns_headers_version(self):
-
+ ])
+ def test_parse_ns_headers(self, hdr):
# quotes should be stripped
- expected = [[('foo', 'bar'), ('version', '1')]]
- for hdr in [
+ expected = [[('foo', 'bar'), ('expires', 2209069412), ('version', '0')]]
+ self.assertEqual(parse_ns_headers([hdr]), expected)
+
+ @support.subTests('hdr', [
'foo=bar; version="1"',
'foo=bar; Version="1"',
- ]:
- self.assertEqual(parse_ns_headers([hdr]), expected)
+ ])
+ def test_parse_ns_headers_version(self, hdr):
+ # quotes should be stripped
+ expected = [[('foo', 'bar'), ('version', '1')]]
+ self.assertEqual(parse_ns_headers([hdr]), expected)
def test_parse_ns_headers_special_names(self):
# names such as 'expires' are not special in first name=value pair
@@ -226,8 +221,7 @@ class HeaderTests(unittest.TestCase):
expected = [[("expires", "01 Jan 2040 22:23:32 GMT"), ("version", "0")]]
self.assertEqual(parse_ns_headers([hdr]), expected)
- def test_join_header_words(self):
- for src, expected in [
+ @support.subTests('src,expected', [
([[("foo", None), ("bar", "baz")]], "foo; bar=baz"),
(([]), ""),
(([[]]), ""),
@@ -237,12 +231,11 @@ class HeaderTests(unittest.TestCase):
'n; foo="foo;_", bar=foo_bar'),
([[("n", "m"), ("foo", None)], [("bar", "foo_bar")]],
'n=m; foo, bar=foo_bar'),
- ]:
- with self.subTest(src=src):
- self.assertEqual(join_header_words(src), expected)
+ ])
+ def test_join_header_words(self, src, expected):
+ self.assertEqual(join_header_words(src), expected)
- def test_split_header_words(self):
- tests = [
+ @support.subTests('arg,expect', [
("foo", [[("foo", None)]]),
("foo=bar", [[("foo", "bar")]]),
(" foo ", [[("foo", None)]]),
@@ -259,24 +252,22 @@ class HeaderTests(unittest.TestCase):
(r'foo; bar=baz, spam=, foo="\,\;\"", bar= ',
[[("foo", None), ("bar", "baz")],
[("spam", "")], [("foo", ',;"')], [("bar", "")]]),
- ]
-
- for arg, expect in tests:
- try:
- result = split_header_words([arg])
- except:
- import traceback, io
- f = io.StringIO()
- traceback.print_exc(None, f)
- result = "(error -- traceback follows)\n\n%s" % f.getvalue()
- self.assertEqual(result, expect, """
+ ])
+ def test_split_header_words(self, arg, expect):
+ try:
+ result = split_header_words([arg])
+ except:
+ import traceback, io
+ f = io.StringIO()
+ traceback.print_exc(None, f)
+ result = "(error -- traceback follows)\n\n%s" % f.getvalue()
+ self.assertEqual(result, expect, """
When parsing: '%s'
Expected: '%s'
Got: '%s'
""" % (arg, expect, result))
- def test_roundtrip(self):
- tests = [
+ @support.subTests('arg,expect', [
("foo", "foo"),
("foo=bar", "foo=bar"),
(" foo ", "foo"),
@@ -309,12 +300,11 @@ Got: '%s'
('n; foo="foo;_", bar="foo,_"',
'n; foo="foo;_", bar="foo,_"'),
- ]
-
- for arg, expect in tests:
- input = split_header_words([arg])
- res = join_header_words(input)
- self.assertEqual(res, expect, """
+ ])
+ def test_roundtrip(self, arg, expect):
+ input = split_header_words([arg])
+ res = join_header_words(input)
+ self.assertEqual(res, expect, """
When parsing: '%s'
Expected: '%s'
Got: '%s'
@@ -516,14 +506,7 @@ class CookieTests(unittest.TestCase):
## just the 7 special TLD's listed in their spec. And folks rely on
## that...
- def test_domain_return_ok(self):
- # test optimization: .domain_return_ok() should filter out most
- # domains in the CookieJar before we try to access them (because that
- # may require disk access -- in particular, with MSIECookieJar)
- # This is only a rough check for performance reasons, so it's not too
- # critical as long as it's sufficiently liberal.
- pol = DefaultCookiePolicy()
- for url, domain, ok in [
+ @support.subTests('url,domain,ok', [
("http://foo.bar.com/", "blah.com", False),
("http://foo.bar.com/", "rhubarb.blah.com", False),
("http://foo.bar.com/", "rhubarb.foo.bar.com", False),
@@ -543,11 +526,18 @@ class CookieTests(unittest.TestCase):
("http://foo/", ".local", True),
("http://barfoo.com", ".foo.com", False),
("http://barfoo.com", "foo.com", False),
- ]:
- request = urllib.request.Request(url)
- r = pol.domain_return_ok(domain, request)
- if ok: self.assertTrue(r)
- else: self.assertFalse(r)
+ ])
+ def test_domain_return_ok(self, url, domain, ok):
+ # test optimization: .domain_return_ok() should filter out most
+ # domains in the CookieJar before we try to access them (because that
+ # may require disk access -- in particular, with MSIECookieJar)
+ # This is only a rough check for performance reasons, so it's not too
+ # critical as long as it's sufficiently liberal.
+ pol = DefaultCookiePolicy()
+ request = urllib.request.Request(url)
+ r = pol.domain_return_ok(domain, request)
+ if ok: self.assertTrue(r)
+ else: self.assertFalse(r)
def test_missing_value(self):
# missing = sign in Cookie: header is regarded by Mozilla as a missing
@@ -581,10 +571,7 @@ class CookieTests(unittest.TestCase):
self.assertEqual(interact_netscape(c, "http://www.acme.com/foo/"),
'"spam"; eggs')
- def test_rfc2109_handling(self):
- # RFC 2109 cookies are handled as RFC 2965 or Netscape cookies,
- # dependent on policy settings
- for rfc2109_as_netscape, rfc2965, version in [
+ @support.subTests('rfc2109_as_netscape,rfc2965,version', [
# default according to rfc2965 if not explicitly specified
(None, False, 0),
(None, True, 1),
@@ -593,24 +580,27 @@ class CookieTests(unittest.TestCase):
(False, True, 1),
(True, False, 0),
(True, True, 0),
- ]:
- policy = DefaultCookiePolicy(
- rfc2109_as_netscape=rfc2109_as_netscape,
- rfc2965=rfc2965)
- c = CookieJar(policy)
- interact_netscape(c, "http://www.example.com/", "ni=ni; Version=1")
- try:
- cookie = c._cookies["www.example.com"]["/"]["ni"]
- except KeyError:
- self.assertIsNone(version) # didn't expect a stored cookie
- else:
- self.assertEqual(cookie.version, version)
- # 2965 cookies are unaffected
- interact_2965(c, "http://www.example.com/",
- "foo=bar; Version=1")
- if rfc2965:
- cookie2965 = c._cookies["www.example.com"]["/"]["foo"]
- self.assertEqual(cookie2965.version, 1)
+ ])
+ def test_rfc2109_handling(self, rfc2109_as_netscape, rfc2965, version):
+ # RFC 2109 cookies are handled as RFC 2965 or Netscape cookies,
+ # dependent on policy settings
+ policy = DefaultCookiePolicy(
+ rfc2109_as_netscape=rfc2109_as_netscape,
+ rfc2965=rfc2965)
+ c = CookieJar(policy)
+ interact_netscape(c, "http://www.example.com/", "ni=ni; Version=1")
+ try:
+ cookie = c._cookies["www.example.com"]["/"]["ni"]
+ except KeyError:
+ self.assertIsNone(version) # didn't expect a stored cookie
+ else:
+ self.assertEqual(cookie.version, version)
+ # 2965 cookies are unaffected
+ interact_2965(c, "http://www.example.com/",
+ "foo=bar; Version=1")
+ if rfc2965:
+ cookie2965 = c._cookies["www.example.com"]["/"]["foo"]
+ self.assertEqual(cookie2965.version, 1)
def test_ns_parser(self):
c = CookieJar()
@@ -778,8 +768,7 @@ class CookieTests(unittest.TestCase):
# Cookie is sent back to the same URI.
self.assertEqual(interact_netscape(cj, uri), value)
- def test_escape_path(self):
- cases = [
+ @support.subTests('arg,result', [
# quoted safe
("/foo%2f/bar", "/foo%2F/bar"),
("/foo%2F/bar", "/foo%2F/bar"),
@@ -799,9 +788,9 @@ class CookieTests(unittest.TestCase):
("/foo/bar\u00fc", "/foo/bar%C3%BC"), # UTF-8 encoded
# unicode
("/foo/bar\uabcd", "/foo/bar%EA%AF%8D"), # UTF-8 encoded
- ]
- for arg, result in cases:
- self.assertEqual(escape_path(arg), result)
+ ])
+ def test_escape_path(self, arg, result):
+ self.assertEqual(escape_path(arg), result)
def test_request_path(self):
# with parameters
diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py
index 2cafa4e45a1..2548a7c5f29 100644
--- a/Lib/test/test_httpservers.py
+++ b/Lib/test/test_httpservers.py
@@ -3,16 +3,16 @@
Written by Cody A.W. Somerville <cody-somerville@ubuntu.com>,
Josip Dzolonga, and Michael Otteneder for the 2007/08 GHOP contest.
"""
-from collections import OrderedDict
+
from http.server import BaseHTTPRequestHandler, HTTPServer, HTTPSServer, \
- SimpleHTTPRequestHandler, CGIHTTPRequestHandler
+ SimpleHTTPRequestHandler
from http import server, HTTPStatus
+import contextlib
import os
import socket
import sys
import re
-import base64
import ntpath
import pathlib
import shutil
@@ -21,6 +21,7 @@ import email.utils
import html
import http, http.client
import urllib.parse
+import urllib.request
import tempfile
import time
import datetime
@@ -31,8 +32,10 @@ from io import BytesIO, StringIO
import unittest
from test import support
from test.support import (
- is_apple, import_helper, os_helper, requires_subprocess, threading_helper
+ is_apple, import_helper, os_helper, threading_helper
)
+from test.support.script_helper import kill_python, spawn_python
+from test.support.socket_helper import find_unused_port
try:
import ssl
@@ -522,42 +525,120 @@ class SimpleHTTPServerTestCase(BaseTestCase):
reader.close()
return body
+ def check_list_dir_dirname(self, dirname, quotedname=None):
+ fullpath = os.path.join(self.tempdir, dirname)
+ try:
+ os.mkdir(os.path.join(self.tempdir, dirname))
+ except (OSError, UnicodeEncodeError):
+ self.skipTest(f'Can not create directory {dirname!a} '
+ f'on current file system')
+
+ if quotedname is None:
+ quotedname = urllib.parse.quote(dirname, errors='surrogatepass')
+ response = self.request(self.base_url + '/' + quotedname + '/')
+ body = self.check_status_and_reason(response, HTTPStatus.OK)
+ displaypath = html.escape(f'{self.base_url}/{dirname}/', quote=False)
+ enc = sys.getfilesystemencoding()
+ prefix = f'listing for {displaypath}</'.encode(enc, 'surrogateescape')
+ self.assertIn(prefix + b'title>', body)
+ self.assertIn(prefix + b'h1>', body)
+
+ def check_list_dir_filename(self, filename):
+ fullpath = os.path.join(self.tempdir, filename)
+ content = ascii(fullpath).encode() + (os_helper.TESTFN_UNDECODABLE or b'\xff')
+ try:
+ with open(fullpath, 'wb') as f:
+ f.write(content)
+ except OSError:
+ self.skipTest(f'Can not create file {filename!a} '
+ f'on current file system')
+
+ response = self.request(self.base_url + '/')
+ body = self.check_status_and_reason(response, HTTPStatus.OK)
+ quotedname = urllib.parse.quote(filename, errors='surrogatepass')
+ enc = response.headers.get_content_charset()
+ self.assertIsNotNone(enc)
+ self.assertIn((f'href="{quotedname}"').encode('ascii'), body)
+ displayname = html.escape(filename, quote=False)
+ self.assertIn(f'>{displayname}<'.encode(enc, 'surrogateescape'), body)
+
+ response = self.request(self.base_url + '/' + quotedname)
+ self.check_status_and_reason(response, HTTPStatus.OK, data=content)
+
+ @unittest.skipUnless(os_helper.TESTFN_NONASCII,
+ 'need os_helper.TESTFN_NONASCII')
+ def test_list_dir_nonascii_dirname(self):
+ dirname = os_helper.TESTFN_NONASCII + '.dir'
+ self.check_list_dir_dirname(dirname)
+
+ @unittest.skipUnless(os_helper.TESTFN_NONASCII,
+ 'need os_helper.TESTFN_NONASCII')
+ def test_list_dir_nonascii_filename(self):
+ filename = os_helper.TESTFN_NONASCII + '.txt'
+ self.check_list_dir_filename(filename)
+
@unittest.skipIf(is_apple,
'undecodable name cannot always be decoded on Apple platforms')
@unittest.skipIf(sys.platform == 'win32',
'undecodable name cannot be decoded on win32')
@unittest.skipUnless(os_helper.TESTFN_UNDECODABLE,
'need os_helper.TESTFN_UNDECODABLE')
- def test_undecodable_filename(self):
- enc = sys.getfilesystemencoding()
- filename = os.fsdecode(os_helper.TESTFN_UNDECODABLE) + '.txt'
- with open(os.path.join(self.tempdir, filename), 'wb') as f:
- f.write(os_helper.TESTFN_UNDECODABLE)
- response = self.request(self.base_url + '/')
- if is_apple:
- # On Apple platforms the HFS+ filesystem replaces bytes that
- # aren't valid UTF-8 into a percent-encoded value.
- for name in os.listdir(self.tempdir):
- if name != 'test': # Ignore a filename created in setUp().
- filename = name
- break
- body = self.check_status_and_reason(response, HTTPStatus.OK)
- quotedname = urllib.parse.quote(filename, errors='surrogatepass')
- self.assertIn(('href="%s"' % quotedname)
- .encode(enc, 'surrogateescape'), body)
- self.assertIn(('>%s<' % html.escape(filename, quote=False))
- .encode(enc, 'surrogateescape'), body)
- response = self.request(self.base_url + '/' + quotedname)
- self.check_status_and_reason(response, HTTPStatus.OK,
- data=os_helper.TESTFN_UNDECODABLE)
+ def test_list_dir_undecodable_dirname(self):
+ dirname = os.fsdecode(os_helper.TESTFN_UNDECODABLE) + '.dir'
+ self.check_list_dir_dirname(dirname)
- def test_undecodable_parameter(self):
- # sanity check using a valid parameter
+ @unittest.skipIf(is_apple,
+ 'undecodable name cannot always be decoded on Apple platforms')
+ @unittest.skipIf(sys.platform == 'win32',
+ 'undecodable name cannot be decoded on win32')
+ @unittest.skipUnless(os_helper.TESTFN_UNDECODABLE,
+ 'need os_helper.TESTFN_UNDECODABLE')
+ def test_list_dir_undecodable_filename(self):
+ filename = os.fsdecode(os_helper.TESTFN_UNDECODABLE) + '.txt'
+ self.check_list_dir_filename(filename)
+
+ def test_list_dir_undecodable_dirname2(self):
+ dirname = '\ufffd.dir'
+ self.check_list_dir_dirname(dirname, quotedname='%ff.dir')
+
+ @unittest.skipUnless(os_helper.TESTFN_UNENCODABLE,
+ 'need os_helper.TESTFN_UNENCODABLE')
+ def test_list_dir_unencodable_dirname(self):
+ dirname = os_helper.TESTFN_UNENCODABLE + '.dir'
+ self.check_list_dir_dirname(dirname)
+
+ @unittest.skipUnless(os_helper.TESTFN_UNENCODABLE,
+ 'need os_helper.TESTFN_UNENCODABLE')
+ def test_list_dir_unencodable_filename(self):
+ filename = os_helper.TESTFN_UNENCODABLE + '.txt'
+ self.check_list_dir_filename(filename)
+
+ def test_list_dir_escape_dirname(self):
+ # Characters that need special treating in URL or HTML.
+ for name in ('q?', 'f#', '&amp;', '&amp', '<i>', '"dq"', "'sq'",
+ '%A4', '%E2%82%AC'):
+ with self.subTest(name=name):
+ dirname = name + '.dir'
+ self.check_list_dir_dirname(dirname,
+ quotedname=urllib.parse.quote(dirname, safe='&<>\'"'))
+
+ def test_list_dir_escape_filename(self):
+ # Characters that need special treating in URL or HTML.
+ for name in ('q?', 'f#', '&amp;', '&amp', '<i>', '"dq"', "'sq'",
+ '%A4', '%E2%82%AC'):
+ with self.subTest(name=name):
+ filename = name + '.txt'
+ self.check_list_dir_filename(filename)
+ os_helper.unlink(os.path.join(self.tempdir, filename))
+
+ def test_list_dir_with_query_and_fragment(self):
+ prefix = f'listing for {self.base_url}/</'.encode('latin1')
+ response = self.request(self.base_url + '/#123').read()
+ self.assertIn(prefix + b'title>', response)
+ self.assertIn(prefix + b'h1>', response)
response = self.request(self.base_url + '/?x=123').read()
- self.assertRegex(response, rf'listing for {self.base_url}/\?x=123'.encode('latin1'))
- # now the bogus encoding
- response = self.request(self.base_url + '/?x=%bb').read()
- self.assertRegex(response, rf'listing for {self.base_url}/\?x=\xef\xbf\xbd'.encode('latin1'))
+ self.assertIn(prefix + b'title>', response)
+ self.assertIn(prefix + b'h1>', response)
def test_get_dir_redirect_location_domain_injection_bug(self):
"""Ensure //evil.co/..%2f../../X does not put //evil.co/ in Location.
@@ -615,10 +696,19 @@ class SimpleHTTPServerTestCase(BaseTestCase):
# check for trailing "/" which should return 404. See Issue17324
response = self.request(self.base_url + '/test/')
self.check_status_and_reason(response, HTTPStatus.NOT_FOUND)
+ response = self.request(self.base_url + '/test%2f')
+ self.check_status_and_reason(response, HTTPStatus.NOT_FOUND)
+ response = self.request(self.base_url + '/test%2F')
+ self.check_status_and_reason(response, HTTPStatus.NOT_FOUND)
response = self.request(self.base_url + '/')
self.check_status_and_reason(response, HTTPStatus.OK)
+ response = self.request(self.base_url + '%2f')
+ self.check_status_and_reason(response, HTTPStatus.OK)
+ response = self.request(self.base_url + '%2F')
+ self.check_status_and_reason(response, HTTPStatus.OK)
response = self.request(self.base_url)
self.check_status_and_reason(response, HTTPStatus.MOVED_PERMANENTLY)
+ self.assertEqual(response.getheader("Location"), self.base_url + "/")
self.assertEqual(response.getheader("Content-Length"), "0")
response = self.request(self.base_url + '/?hi=2')
self.check_status_and_reason(response, HTTPStatus.OK)
@@ -724,6 +814,8 @@ class SimpleHTTPServerTestCase(BaseTestCase):
self.check_status_and_reason(response, HTTPStatus.OK)
response = self.request(self.tempdir_name)
self.check_status_and_reason(response, HTTPStatus.MOVED_PERMANENTLY)
+ self.assertEqual(response.getheader("Location"),
+ self.tempdir_name + "/")
response = self.request(self.tempdir_name + '/?hi=2')
self.check_status_and_reason(response, HTTPStatus.OK)
response = self.request(self.tempdir_name + '?hi=1')
@@ -731,350 +823,6 @@ class SimpleHTTPServerTestCase(BaseTestCase):
self.assertEqual(response.getheader("Location"),
self.tempdir_name + "/?hi=1")
- def test_html_escape_filename(self):
- filename = '<test&>.txt'
- fullpath = os.path.join(self.tempdir, filename)
-
- try:
- open(fullpath, 'wb').close()
- except OSError:
- raise unittest.SkipTest('Can not create file %s on current file '
- 'system' % filename)
-
- try:
- response = self.request(self.base_url + '/')
- body = self.check_status_and_reason(response, HTTPStatus.OK)
- enc = response.headers.get_content_charset()
- finally:
- os.unlink(fullpath) # avoid affecting test_undecodable_filename
-
- self.assertIsNotNone(enc)
- html_text = '>%s<' % html.escape(filename, quote=False)
- self.assertIn(html_text.encode(enc), body)
-
-
-cgi_file1 = """\
-#!%s
-
-print("Content-type: text/html")
-print()
-print("Hello World")
-"""
-
-cgi_file2 = """\
-#!%s
-import os
-import sys
-import urllib.parse
-
-print("Content-type: text/html")
-print()
-
-content_length = int(os.environ["CONTENT_LENGTH"])
-query_string = sys.stdin.buffer.read(content_length)
-params = {key.decode("utf-8"): val.decode("utf-8")
- for key, val in urllib.parse.parse_qsl(query_string)}
-
-print("%%s, %%s, %%s" %% (params["spam"], params["eggs"], params["bacon"]))
-"""
-
-cgi_file4 = """\
-#!%s
-import os
-
-print("Content-type: text/html")
-print()
-
-print(os.environ["%s"])
-"""
-
-cgi_file6 = """\
-#!%s
-import os
-
-print("X-ambv: was here")
-print("Content-type: text/html")
-print()
-print("<pre>")
-for k, v in os.environ.items():
- try:
- k.encode('ascii')
- v.encode('ascii')
- except UnicodeEncodeError:
- continue # see: BPO-44647
- print(f"{k}={v}")
-print("</pre>")
-"""
-
-
-@unittest.skipIf(hasattr(os, 'geteuid') and os.geteuid() == 0,
- "This test can't be run reliably as root (issue #13308).")
-@requires_subprocess()
-class CGIHTTPServerTestCase(BaseTestCase):
- class request_handler(NoLogRequestHandler, CGIHTTPRequestHandler):
- _test_case_self = None # populated by each setUp() method call.
-
- def __init__(self, *args, **kwargs):
- with self._test_case_self.assertWarnsRegex(
- DeprecationWarning,
- r'http\.server\.CGIHTTPRequestHandler'):
- # This context also happens to catch and silence the
- # threading DeprecationWarning from os.fork().
- super().__init__(*args, **kwargs)
-
- linesep = os.linesep.encode('ascii')
-
- def setUp(self):
- self.request_handler._test_case_self = self # practical, but yuck.
- BaseTestCase.setUp(self)
- self.cwd = os.getcwd()
- self.parent_dir = tempfile.mkdtemp()
- self.cgi_dir = os.path.join(self.parent_dir, 'cgi-bin')
- self.cgi_child_dir = os.path.join(self.cgi_dir, 'child-dir')
- self.sub_dir_1 = os.path.join(self.parent_dir, 'sub')
- self.sub_dir_2 = os.path.join(self.sub_dir_1, 'dir')
- self.cgi_dir_in_sub_dir = os.path.join(self.sub_dir_2, 'cgi-bin')
- os.mkdir(self.cgi_dir)
- os.mkdir(self.cgi_child_dir)
- os.mkdir(self.sub_dir_1)
- os.mkdir(self.sub_dir_2)
- os.mkdir(self.cgi_dir_in_sub_dir)
- self.nocgi_path = None
- self.file1_path = None
- self.file2_path = None
- self.file3_path = None
- self.file4_path = None
- self.file5_path = None
-
- # The shebang line should be pure ASCII: use symlink if possible.
- # See issue #7668.
- self._pythonexe_symlink = None
- if os_helper.can_symlink():
- self.pythonexe = os.path.join(self.parent_dir, 'python')
- self._pythonexe_symlink = support.PythonSymlink(self.pythonexe).__enter__()
- else:
- self.pythonexe = sys.executable
-
- try:
- # The python executable path is written as the first line of the
- # CGI Python script. The encoding cookie cannot be used, and so the
- # path should be encodable to the default script encoding (utf-8)
- self.pythonexe.encode('utf-8')
- except UnicodeEncodeError:
- self.tearDown()
- self.skipTest("Python executable path is not encodable to utf-8")
-
- self.nocgi_path = os.path.join(self.parent_dir, 'nocgi.py')
- with open(self.nocgi_path, 'w', encoding='utf-8') as fp:
- fp.write(cgi_file1 % self.pythonexe)
- os.chmod(self.nocgi_path, 0o777)
-
- self.file1_path = os.path.join(self.cgi_dir, 'file1.py')
- with open(self.file1_path, 'w', encoding='utf-8') as file1:
- file1.write(cgi_file1 % self.pythonexe)
- os.chmod(self.file1_path, 0o777)
-
- self.file2_path = os.path.join(self.cgi_dir, 'file2.py')
- with open(self.file2_path, 'w', encoding='utf-8') as file2:
- file2.write(cgi_file2 % self.pythonexe)
- os.chmod(self.file2_path, 0o777)
-
- self.file3_path = os.path.join(self.cgi_child_dir, 'file3.py')
- with open(self.file3_path, 'w', encoding='utf-8') as file3:
- file3.write(cgi_file1 % self.pythonexe)
- os.chmod(self.file3_path, 0o777)
-
- self.file4_path = os.path.join(self.cgi_dir, 'file4.py')
- with open(self.file4_path, 'w', encoding='utf-8') as file4:
- file4.write(cgi_file4 % (self.pythonexe, 'QUERY_STRING'))
- os.chmod(self.file4_path, 0o777)
-
- self.file5_path = os.path.join(self.cgi_dir_in_sub_dir, 'file5.py')
- with open(self.file5_path, 'w', encoding='utf-8') as file5:
- file5.write(cgi_file1 % self.pythonexe)
- os.chmod(self.file5_path, 0o777)
-
- self.file6_path = os.path.join(self.cgi_dir, 'file6.py')
- with open(self.file6_path, 'w', encoding='utf-8') as file6:
- file6.write(cgi_file6 % self.pythonexe)
- os.chmod(self.file6_path, 0o777)
-
- os.chdir(self.parent_dir)
-
- def tearDown(self):
- self.request_handler._test_case_self = None
- try:
- os.chdir(self.cwd)
- if self._pythonexe_symlink:
- self._pythonexe_symlink.__exit__(None, None, None)
- if self.nocgi_path:
- os.remove(self.nocgi_path)
- if self.file1_path:
- os.remove(self.file1_path)
- if self.file2_path:
- os.remove(self.file2_path)
- if self.file3_path:
- os.remove(self.file3_path)
- if self.file4_path:
- os.remove(self.file4_path)
- if self.file5_path:
- os.remove(self.file5_path)
- if self.file6_path:
- os.remove(self.file6_path)
- os.rmdir(self.cgi_child_dir)
- os.rmdir(self.cgi_dir)
- os.rmdir(self.cgi_dir_in_sub_dir)
- os.rmdir(self.sub_dir_2)
- os.rmdir(self.sub_dir_1)
- # The 'gmon.out' file can be written in the current working
- # directory if C-level code profiling with gprof is enabled.
- os_helper.unlink(os.path.join(self.parent_dir, 'gmon.out'))
- os.rmdir(self.parent_dir)
- finally:
- BaseTestCase.tearDown(self)
-
- def test_url_collapse_path(self):
- # verify tail is the last portion and head is the rest on proper urls
- test_vectors = {
- '': '//',
- '..': IndexError,
- '/.//..': IndexError,
- '/': '//',
- '//': '//',
- '/\\': '//\\',
- '/.//': '//',
- 'cgi-bin/file1.py': '/cgi-bin/file1.py',
- '/cgi-bin/file1.py': '/cgi-bin/file1.py',
- 'a': '//a',
- '/a': '//a',
- '//a': '//a',
- './a': '//a',
- './C:/': '/C:/',
- '/a/b': '/a/b',
- '/a/b/': '/a/b/',
- '/a/b/.': '/a/b/',
- '/a/b/c/..': '/a/b/',
- '/a/b/c/../d': '/a/b/d',
- '/a/b/c/../d/e/../f': '/a/b/d/f',
- '/a/b/c/../d/e/../../f': '/a/b/f',
- '/a/b/c/../d/e/.././././..//f': '/a/b/f',
- '../a/b/c/../d/e/.././././..//f': IndexError,
- '/a/b/c/../d/e/../../../f': '/a/f',
- '/a/b/c/../d/e/../../../../f': '//f',
- '/a/b/c/../d/e/../../../../../f': IndexError,
- '/a/b/c/../d/e/../../../../f/..': '//',
- '/a/b/c/../d/e/../../../../f/../.': '//',
- }
- for path, expected in test_vectors.items():
- if isinstance(expected, type) and issubclass(expected, Exception):
- self.assertRaises(expected,
- server._url_collapse_path, path)
- else:
- actual = server._url_collapse_path(path)
- self.assertEqual(expected, actual,
- msg='path = %r\nGot: %r\nWanted: %r' %
- (path, actual, expected))
-
- def test_headers_and_content(self):
- res = self.request('/cgi-bin/file1.py')
- self.assertEqual(
- (res.read(), res.getheader('Content-type'), res.status),
- (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK))
-
- def test_issue19435(self):
- res = self.request('///////////nocgi.py/../cgi-bin/nothere.sh')
- self.assertEqual(res.status, HTTPStatus.NOT_FOUND)
-
- def test_post(self):
- params = urllib.parse.urlencode(
- {'spam' : 1, 'eggs' : 'python', 'bacon' : 123456})
- headers = {'Content-type' : 'application/x-www-form-urlencoded'}
- res = self.request('/cgi-bin/file2.py', 'POST', params, headers)
-
- self.assertEqual(res.read(), b'1, python, 123456' + self.linesep)
-
- def test_invaliduri(self):
- res = self.request('/cgi-bin/invalid')
- res.read()
- self.assertEqual(res.status, HTTPStatus.NOT_FOUND)
-
- def test_authorization(self):
- headers = {b'Authorization' : b'Basic ' +
- base64.b64encode(b'username:pass')}
- res = self.request('/cgi-bin/file1.py', 'GET', headers=headers)
- self.assertEqual(
- (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK),
- (res.read(), res.getheader('Content-type'), res.status))
-
- def test_no_leading_slash(self):
- # http://bugs.python.org/issue2254
- res = self.request('cgi-bin/file1.py')
- self.assertEqual(
- (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK),
- (res.read(), res.getheader('Content-type'), res.status))
-
- def test_os_environ_is_not_altered(self):
- signature = "Test CGI Server"
- os.environ['SERVER_SOFTWARE'] = signature
- res = self.request('/cgi-bin/file1.py')
- self.assertEqual(
- (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK),
- (res.read(), res.getheader('Content-type'), res.status))
- self.assertEqual(os.environ['SERVER_SOFTWARE'], signature)
-
- def test_urlquote_decoding_in_cgi_check(self):
- res = self.request('/cgi-bin%2ffile1.py')
- self.assertEqual(
- (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK),
- (res.read(), res.getheader('Content-type'), res.status))
-
- def test_nested_cgi_path_issue21323(self):
- res = self.request('/cgi-bin/child-dir/file3.py')
- self.assertEqual(
- (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK),
- (res.read(), res.getheader('Content-type'), res.status))
-
- def test_query_with_multiple_question_mark(self):
- res = self.request('/cgi-bin/file4.py?a=b?c=d')
- self.assertEqual(
- (b'a=b?c=d' + self.linesep, 'text/html', HTTPStatus.OK),
- (res.read(), res.getheader('Content-type'), res.status))
-
- def test_query_with_continuous_slashes(self):
- res = self.request('/cgi-bin/file4.py?k=aa%2F%2Fbb&//q//p//=//a//b//')
- self.assertEqual(
- (b'k=aa%2F%2Fbb&//q//p//=//a//b//' + self.linesep,
- 'text/html', HTTPStatus.OK),
- (res.read(), res.getheader('Content-type'), res.status))
-
- def test_cgi_path_in_sub_directories(self):
- try:
- CGIHTTPRequestHandler.cgi_directories.append('/sub/dir/cgi-bin')
- res = self.request('/sub/dir/cgi-bin/file5.py')
- self.assertEqual(
- (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK),
- (res.read(), res.getheader('Content-type'), res.status))
- finally:
- CGIHTTPRequestHandler.cgi_directories.remove('/sub/dir/cgi-bin')
-
- def test_accept(self):
- browser_accept = \
- 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8'
- tests = (
- ((('Accept', browser_accept),), browser_accept),
- ((), ''),
- # Hack case to get two values for the one header
- ((('Accept', 'text/html'), ('ACCEPT', 'text/plain')),
- 'text/html,text/plain'),
- )
- for headers, expected in tests:
- headers = OrderedDict(headers)
- with self.subTest(headers):
- res = self.request('/cgi-bin/file6.py', 'GET', headers=headers)
- self.assertEqual(http.HTTPStatus.OK, res.status)
- expected = f"HTTP_ACCEPT={expected}".encode('ascii')
- self.assertIn(expected, res.read())
-
class SocketlessRequestHandler(SimpleHTTPRequestHandler):
def __init__(self, directory=None):
@@ -1095,6 +843,7 @@ class SocketlessRequestHandler(SimpleHTTPRequestHandler):
def log_message(self, format, *args):
pass
+
class RejectingSocketlessRequestHandler(SocketlessRequestHandler):
def handle_expect_100(self):
self.send_error(HTTPStatus.EXPECTATION_FAILED)
@@ -1536,6 +1285,243 @@ class ScriptTestCase(unittest.TestCase):
self.assertEqual(mock_server.address_family, socket.AF_INET)
+class CommandLineTestCase(unittest.TestCase):
+ default_port = 8000
+ default_bind = None
+ default_protocol = 'HTTP/1.0'
+ default_handler = SimpleHTTPRequestHandler
+ default_server = unittest.mock.ANY
+ tls_cert = certdata_file('ssl_cert.pem')
+ tls_key = certdata_file('ssl_key.pem')
+ tls_password = 'somepass'
+ tls_cert_options = ['--tls-cert']
+ tls_key_options = ['--tls-key']
+ tls_password_options = ['--tls-password-file']
+ args = {
+ 'HandlerClass': default_handler,
+ 'ServerClass': default_server,
+ 'protocol': default_protocol,
+ 'port': default_port,
+ 'bind': default_bind,
+ 'tls_cert': None,
+ 'tls_key': None,
+ 'tls_password': None,
+ }
+
+ def setUp(self):
+ super().setUp()
+ self.tls_password_file = tempfile.mktemp()
+ with open(self.tls_password_file, 'wb') as f:
+ f.write(self.tls_password.encode())
+ self.addCleanup(os_helper.unlink, self.tls_password_file)
+
+ def invoke_httpd(self, *args, stdout=None, stderr=None):
+ stdout = StringIO() if stdout is None else stdout
+ stderr = StringIO() if stderr is None else stderr
+ with contextlib.redirect_stdout(stdout), \
+ contextlib.redirect_stderr(stderr):
+ server._main(args)
+ return stdout.getvalue(), stderr.getvalue()
+
+ @mock.patch('http.server.test')
+ def test_port_flag(self, mock_func):
+ ports = [8000, 65535]
+ for port in ports:
+ with self.subTest(port=port):
+ self.invoke_httpd(str(port))
+ call_args = self.args | dict(port=port)
+ mock_func.assert_called_once_with(**call_args)
+ mock_func.reset_mock()
+
+ @mock.patch('http.server.test')
+ def test_directory_flag(self, mock_func):
+ options = ['-d', '--directory']
+ directories = ['.', '/foo', '\\bar', '/',
+ 'C:\\', 'C:\\foo', 'C:\\bar',
+ '/home/user', './foo/foo2', 'D:\\foo\\bar']
+ for flag in options:
+ for directory in directories:
+ with self.subTest(flag=flag, directory=directory):
+ self.invoke_httpd(flag, directory)
+ mock_func.assert_called_once_with(**self.args)
+ mock_func.reset_mock()
+
+ @mock.patch('http.server.test')
+ def test_bind_flag(self, mock_func):
+ options = ['-b', '--bind']
+ bind_addresses = ['localhost', '127.0.0.1', '::1',
+ '0.0.0.0', '8.8.8.8']
+ for flag in options:
+ for bind_address in bind_addresses:
+ with self.subTest(flag=flag, bind_address=bind_address):
+ self.invoke_httpd(flag, bind_address)
+ call_args = self.args | dict(bind=bind_address)
+ mock_func.assert_called_once_with(**call_args)
+ mock_func.reset_mock()
+
+ @mock.patch('http.server.test')
+ def test_protocol_flag(self, mock_func):
+ options = ['-p', '--protocol']
+ protocols = ['HTTP/1.0', 'HTTP/1.1', 'HTTP/2.0', 'HTTP/3.0']
+ for flag in options:
+ for protocol in protocols:
+ with self.subTest(flag=flag, protocol=protocol):
+ self.invoke_httpd(flag, protocol)
+ call_args = self.args | dict(protocol=protocol)
+ mock_func.assert_called_once_with(**call_args)
+ mock_func.reset_mock()
+
+ @unittest.skipIf(ssl is None, "requires ssl")
+ @mock.patch('http.server.test')
+ def test_tls_cert_and_key_flags(self, mock_func):
+ for tls_cert_option in self.tls_cert_options:
+ for tls_key_option in self.tls_key_options:
+ self.invoke_httpd(tls_cert_option, self.tls_cert,
+ tls_key_option, self.tls_key)
+ call_args = self.args | {
+ 'tls_cert': self.tls_cert,
+ 'tls_key': self.tls_key,
+ }
+ mock_func.assert_called_once_with(**call_args)
+ mock_func.reset_mock()
+
+ @unittest.skipIf(ssl is None, "requires ssl")
+ @mock.patch('http.server.test')
+ def test_tls_cert_and_key_and_password_flags(self, mock_func):
+ for tls_cert_option in self.tls_cert_options:
+ for tls_key_option in self.tls_key_options:
+ for tls_password_option in self.tls_password_options:
+ self.invoke_httpd(tls_cert_option,
+ self.tls_cert,
+ tls_key_option,
+ self.tls_key,
+ tls_password_option,
+ self.tls_password_file)
+ call_args = self.args | {
+ 'tls_cert': self.tls_cert,
+ 'tls_key': self.tls_key,
+ 'tls_password': self.tls_password,
+ }
+ mock_func.assert_called_once_with(**call_args)
+ mock_func.reset_mock()
+
+ @unittest.skipIf(ssl is None, "requires ssl")
+ @mock.patch('http.server.test')
+ def test_missing_tls_cert_flag(self, mock_func):
+ for tls_key_option in self.tls_key_options:
+ with self.assertRaises(SystemExit):
+ self.invoke_httpd(tls_key_option, self.tls_key)
+ mock_func.reset_mock()
+
+ for tls_password_option in self.tls_password_options:
+ with self.assertRaises(SystemExit):
+ self.invoke_httpd(tls_password_option, self.tls_password)
+ mock_func.reset_mock()
+
+ @unittest.skipIf(ssl is None, "requires ssl")
+ @mock.patch('http.server.test')
+ def test_invalid_password_file(self, mock_func):
+ non_existent_file = 'non_existent_file'
+ for tls_password_option in self.tls_password_options:
+ for tls_cert_option in self.tls_cert_options:
+ with self.assertRaises(SystemExit):
+ self.invoke_httpd(tls_cert_option,
+ self.tls_cert,
+ tls_password_option,
+ non_existent_file)
+
+ @mock.patch('http.server.test')
+ def test_no_arguments(self, mock_func):
+ self.invoke_httpd()
+ mock_func.assert_called_once_with(**self.args)
+ mock_func.reset_mock()
+
+ @mock.patch('http.server.test')
+ def test_help_flag(self, _):
+ options = ['-h', '--help']
+ for option in options:
+ stdout, stderr = StringIO(), StringIO()
+ with self.assertRaises(SystemExit):
+ self.invoke_httpd(option, stdout=stdout, stderr=stderr)
+ self.assertIn('usage', stdout.getvalue())
+ self.assertEqual(stderr.getvalue(), '')
+
+ @mock.patch('http.server.test')
+ def test_unknown_flag(self, _):
+ stdout, stderr = StringIO(), StringIO()
+ with self.assertRaises(SystemExit):
+ self.invoke_httpd('--unknown-flag', stdout=stdout, stderr=stderr)
+ self.assertEqual(stdout.getvalue(), '')
+ self.assertIn('error', stderr.getvalue())
+
+
+class CommandLineRunTimeTestCase(unittest.TestCase):
+ served_data = os.urandom(32)
+ served_filename = 'served_filename'
+ tls_cert = certdata_file('ssl_cert.pem')
+ tls_key = certdata_file('ssl_key.pem')
+ tls_password = b'somepass'
+ tls_password_file = 'ssl_key_password'
+
+ def setUp(self):
+ super().setUp()
+ server_dir_context = os_helper.temp_cwd()
+ server_dir = self.enterContext(server_dir_context)
+ with open(self.served_filename, 'wb') as f:
+ f.write(self.served_data)
+ with open(self.tls_password_file, 'wb') as f:
+ f.write(self.tls_password)
+
+ def fetch_file(self, path, context=None):
+ req = urllib.request.Request(path, method='GET')
+ with urllib.request.urlopen(req, context=context) as res:
+ return res.read()
+
+ def parse_cli_output(self, output):
+ match = re.search(r'Serving (HTTP|HTTPS) on (.+) port (\d+)', output)
+ if match is None:
+ return None, None, None
+ return match.group(1).lower(), match.group(2), int(match.group(3))
+
+ def wait_for_server(self, proc, protocol, bind, port):
+ """Check that the server has been successfully started."""
+ line = proc.stdout.readline().strip()
+ if support.verbose:
+ print()
+ print('python -m http.server: ', line)
+ return self.parse_cli_output(line) == (protocol, bind, port)
+
+ def test_http_client(self):
+ bind, port = '127.0.0.1', find_unused_port()
+ proc = spawn_python('-u', '-m', 'http.server', str(port), '-b', bind,
+ bufsize=1, text=True)
+ self.addCleanup(kill_python, proc)
+ self.addCleanup(proc.terminate)
+ self.assertTrue(self.wait_for_server(proc, 'http', bind, port))
+ res = self.fetch_file(f'http://{bind}:{port}/{self.served_filename}')
+ self.assertEqual(res, self.served_data)
+
+ @unittest.skipIf(ssl is None, "requires ssl")
+ def test_https_client(self):
+ context = ssl.create_default_context()
+ # allow self-signed certificates
+ context.check_hostname = False
+ context.verify_mode = ssl.CERT_NONE
+
+ bind, port = '127.0.0.1', find_unused_port()
+ proc = spawn_python('-u', '-m', 'http.server', str(port), '-b', bind,
+ '--tls-cert', self.tls_cert,
+ '--tls-key', self.tls_key,
+ '--tls-password-file', self.tls_password_file,
+ bufsize=1, text=True)
+ self.addCleanup(kill_python, proc)
+ self.addCleanup(proc.terminate)
+ self.assertTrue(self.wait_for_server(proc, 'https', bind, port))
+ url = f'https://{bind}:{port}/{self.served_filename}'
+ res = self.fetch_file(url, context=context)
+ self.assertEqual(res, self.served_data)
+
+
def setUpModule():
unittest.addModuleCleanup(os.chdir, os.getcwd())
diff --git a/Lib/test/test_idle.py b/Lib/test/test_idle.py
index 3d8b7ecc0ec..ebf572ac5ca 100644
--- a/Lib/test/test_idle.py
+++ b/Lib/test/test_idle.py
@@ -16,7 +16,7 @@ idlelib.testing = True
# Unittest.main and test.libregrtest.runtest.runtest_inner
# call load_tests, when present here, to discover tests to run.
-from idlelib.idle_test import load_tests
+from idlelib.idle_test import load_tests # noqa: F401
if __name__ == '__main__':
tk.NoDefaultRoot()
diff --git a/Lib/test/test_import/__init__.py b/Lib/test/test_import/__init__.py
index b5f4645847a..6e34094c5aa 100644
--- a/Lib/test/test_import/__init__.py
+++ b/Lib/test/test_import/__init__.py
@@ -1001,7 +1001,7 @@ from not_a_module import symbol
expected_error = error + (
rb" \(consider renaming '.*numpy.py' if it has the "
- rb"same name as a library you intended to import\)\s+\Z"
+ rb"same name as a library you intended to import\)\s+\z"
)
popen = script_helper.spawn_python(os.path.join(tmp, "numpy.py"))
@@ -1022,14 +1022,14 @@ from not_a_module import symbol
f.write("this_script_does_not_attempt_to_import_numpy = True")
expected_error = (
- rb"AttributeError: module 'numpy' has no attribute 'attr'\s+\Z"
+ rb"AttributeError: module 'numpy' has no attribute 'attr'\s+\z"
)
popen = script_helper.spawn_python('-c', 'import numpy; numpy.attr', cwd=tmp)
stdout, stderr = popen.communicate()
self.assertRegex(stdout, expected_error)
expected_error = (
- rb"ImportError: cannot import name 'attr' from 'numpy' \(.*\)\s+\Z"
+ rb"ImportError: cannot import name 'attr' from 'numpy' \(.*\)\s+\z"
)
popen = script_helper.spawn_python('-c', 'from numpy import attr', cwd=tmp)
stdout, stderr = popen.communicate()
diff --git a/Lib/test/test_importlib/import_/test_relative_imports.py b/Lib/test/test_importlib/import_/test_relative_imports.py
index e535d119763..1549cbe96ce 100644
--- a/Lib/test/test_importlib/import_/test_relative_imports.py
+++ b/Lib/test/test_importlib/import_/test_relative_imports.py
@@ -223,6 +223,21 @@ class RelativeImports:
self.__import__('sys', {'__package__': '', '__spec__': None},
level=1)
+ def test_malicious_relative_import(self):
+ # https://github.com/python/cpython/issues/134100
+    # Test to make sure the use-after-free (UAF) bug in the error message
+ import sys
+ loooong = "".ljust(0x23000, "b")
+ name = f"a.{loooong}.c"
+
+ with util.uncache(name):
+ sys.modules[name] = {}
+ with self.assertRaisesRegex(
+ KeyError,
+ r"'a\.b+' not in sys\.modules as expected"
+ ):
+ __import__(f"{loooong}.c", {"__package__": "a"}, level=1)
+
(Frozen_RelativeImports,
Source_RelativeImports
diff --git a/Lib/test/test_importlib/test_locks.py b/Lib/test/test_importlib/test_locks.py
index befac5d62b0..655e5881a15 100644
--- a/Lib/test/test_importlib/test_locks.py
+++ b/Lib/test/test_importlib/test_locks.py
@@ -34,6 +34,7 @@ class ModuleLockAsRLockTests:
# lock status in repr unsupported
test_repr = None
test_locked_repr = None
+ test_repr_count = None
def tearDown(self):
for splitinit in init.values():
diff --git a/Lib/test/test_importlib/test_threaded_import.py b/Lib/test/test_importlib/test_threaded_import.py
index 9af1e4d505c..f78dc399720 100644
--- a/Lib/test/test_importlib/test_threaded_import.py
+++ b/Lib/test/test_importlib/test_threaded_import.py
@@ -135,10 +135,12 @@ class ThreadedImportTests(unittest.TestCase):
if verbose:
print("OK.")
- def test_parallel_module_init(self):
+ @support.bigmemtest(size=50, memuse=76*2**20, dry_run=False)
+ def test_parallel_module_init(self, size):
self.check_parallel_module_init()
- def test_parallel_meta_path(self):
+ @support.bigmemtest(size=50, memuse=76*2**20, dry_run=False)
+ def test_parallel_meta_path(self, size):
finder = Finder()
sys.meta_path.insert(0, finder)
try:
@@ -148,7 +150,8 @@ class ThreadedImportTests(unittest.TestCase):
finally:
sys.meta_path.remove(finder)
- def test_parallel_path_hooks(self):
+ @support.bigmemtest(size=50, memuse=76*2**20, dry_run=False)
+ def test_parallel_path_hooks(self, size):
# Here the Finder instance is only used to check concurrent calls
# to path_hook().
finder = Finder()
@@ -242,13 +245,15 @@ class ThreadedImportTests(unittest.TestCase):
__import__(TESTFN)
del sys.modules[TESTFN]
- def test_concurrent_futures_circular_import(self):
+ @support.bigmemtest(size=1, memuse=1.8*2**30, dry_run=False)
+ def test_concurrent_futures_circular_import(self, size):
# Regression test for bpo-43515
fn = os.path.join(os.path.dirname(__file__),
'partial', 'cfimport.py')
script_helper.assert_python_ok(fn)
- def test_multiprocessing_pool_circular_import(self):
+ @support.bigmemtest(size=1, memuse=1.8*2**30, dry_run=False)
+ def test_multiprocessing_pool_circular_import(self, size):
# Regression test for bpo-41567
fn = os.path.join(os.path.dirname(__file__),
'partial', 'pool_in_threads.py')
diff --git a/Lib/test/test_inspect/test_inspect.py b/Lib/test/test_inspect/test_inspect.py
index 4950af42cfe..79eb103224b 100644
--- a/Lib/test/test_inspect/test_inspect.py
+++ b/Lib/test/test_inspect/test_inspect.py
@@ -786,12 +786,12 @@ class TestRetrievingSourceCode(GetSourceBase):
def test_getfile_builtin_module(self):
with self.assertRaises(TypeError) as e:
inspect.getfile(sys)
- self.assertTrue(str(e.exception).startswith('<module'))
+ self.assertStartsWith(str(e.exception), '<module')
def test_getfile_builtin_class(self):
with self.assertRaises(TypeError) as e:
inspect.getfile(int)
- self.assertTrue(str(e.exception).startswith('<class'))
+ self.assertStartsWith(str(e.exception), '<class')
def test_getfile_builtin_function_or_method(self):
with self.assertRaises(TypeError) as e_abs:
@@ -2949,7 +2949,7 @@ class TestSignatureObject(unittest.TestCase):
pass
sig = inspect.signature(test)
- self.assertTrue(repr(sig).startswith('<Signature'))
+ self.assertStartsWith(repr(sig), '<Signature')
self.assertTrue('(po, /, pk' in repr(sig))
# We need two functions, because it is impossible to represent
@@ -2958,7 +2958,7 @@ class TestSignatureObject(unittest.TestCase):
pass
sig2 = inspect.signature(test2)
- self.assertTrue(repr(sig2).startswith('<Signature'))
+ self.assertStartsWith(repr(sig2), '<Signature')
self.assertTrue('(pod=42, /)' in repr(sig2))
po = sig.parameters['po']
@@ -4997,6 +4997,37 @@ class TestSignatureObject(unittest.TestCase):
with self.assertRaisesRegex(NameError, "undefined"):
signature_func(ida.f)
+ def test_signature_deferred_annotations(self):
+ def f(x: undef):
+ pass
+
+ class C:
+ x: undef
+
+ def __init__(self, x: undef):
+ self.x = x
+
+ sig = inspect.signature(f, annotation_format=Format.FORWARDREF)
+ self.assertEqual(list(sig.parameters), ['x'])
+ sig = inspect.signature(C, annotation_format=Format.FORWARDREF)
+ self.assertEqual(list(sig.parameters), ['x'])
+
+ class CallableWrapper:
+ def __init__(self, func):
+ self.func = func
+ self.__annotate__ = func.__annotate__
+
+ def __call__(self, *args, **kwargs):
+ return self.func(*args, **kwargs)
+
+ @property
+ def __annotations__(self):
+ return self.__annotate__(Format.VALUE)
+
+ cw = CallableWrapper(f)
+ sig = inspect.signature(cw, annotation_format=Format.FORWARDREF)
+ self.assertEqual(list(sig.parameters), ['args', 'kwargs'])
+
def test_signature_none_annotation(self):
class funclike:
# Has to be callable, and have correct
@@ -5102,7 +5133,7 @@ class TestParameterObject(unittest.TestCase):
with self.assertRaisesRegex(ValueError, 'cannot have default values'):
p.replace(kind=inspect.Parameter.VAR_POSITIONAL)
- self.assertTrue(repr(p).startswith('<Parameter'))
+ self.assertStartsWith(repr(p), '<Parameter')
self.assertTrue('"a=42"' in repr(p))
def test_signature_parameter_hashable(self):
@@ -5846,7 +5877,7 @@ class TestSignatureDefinitions(unittest.TestCase):
def test_os_module_has_signatures(self):
unsupported_signature = {'chmod', 'utime'}
unsupported_signature |= {name for name in
- ['get_terminal_size', 'posix_spawn', 'posix_spawnp',
+ ['get_terminal_size', 'link', 'posix_spawn', 'posix_spawnp',
'register_at_fork', 'startfile']
if hasattr(os, name)}
self._test_module_has_signatures(os, unsupported_signature=unsupported_signature)
@@ -6146,12 +6177,14 @@ class TestRepl(unittest.TestCase):
object.
"""
+ # TODO(picnixz): refactor this as it's used by test_repl.py
+
# To run the REPL without using a terminal, spawn python with the command
# line option '-i' and the process name set to '<stdin>'.
# The directory of argv[0] must match the directory of the Python
# executable for the Popen() call to python to succeed as the directory
- # path may be used by Py_GetPath() to build the default module search
- # path.
+ # path may be used by PyConfig_Get("module_search_paths") to build the
+ # default module search path.
stdin_fname = os.path.join(os.path.dirname(sys.executable), "<stdin>")
cmd_line = [stdin_fname, '-E', '-i']
cmd_line.extend(args)
diff --git a/Lib/test/test_int.py b/Lib/test/test_int.py
index 245528ce57a..7a7cb73f673 100644
--- a/Lib/test/test_int.py
+++ b/Lib/test/test_int.py
@@ -836,7 +836,7 @@ class PyLongModuleTests(unittest.TestCase):
n = hibit | getrandbits(bits - 1)
assert n.bit_length() == bits
sn = str(n)
- self.assertFalse(sn.startswith('0'))
+ self.assertNotStartsWith(sn, '0')
self.assertEqual(n, int(sn))
bits <<= 1
diff --git a/Lib/test/test_interpreters/test_api.py b/Lib/test/test_interpreters/test_api.py
index 66c7afce88f..0ee4582b5d1 100644
--- a/Lib/test/test_interpreters/test_api.py
+++ b/Lib/test/test_interpreters/test_api.py
@@ -1,18 +1,23 @@
+import contextlib
import os
import pickle
+import sys
from textwrap import dedent
import threading
import types
import unittest
from test import support
+from test.support import os_helper
+from test.support import script_helper
from test.support import import_helper
# Raise SkipTest if subinterpreters not supported.
_interpreters = import_helper.import_module('_interpreters')
+from concurrent import interpreters
from test.support import Py_GIL_DISABLED
-from test.support import interpreters
from test.support import force_not_colorized
-from test.support.interpreters import (
+import test._crossinterp_definitions as defs
+from concurrent.interpreters import (
InterpreterError, InterpreterNotFoundError, ExecutionFailed,
)
from .utils import (
@@ -29,6 +34,59 @@ WHENCE_STR_XI = 'cross-interpreter C-API'
WHENCE_STR_STDLIB = '_interpreters module'
+def is_pickleable(obj):
+ try:
+ pickle.dumps(obj)
+ except Exception:
+ return False
+ return True
+
+
+@contextlib.contextmanager
+def defined_in___main__(name, script, *, remove=False):
+ import __main__ as mainmod
+ mainns = vars(mainmod)
+ assert name not in mainns
+ exec(script, mainns, mainns)
+ if remove:
+ yield mainns.pop(name)
+ else:
+ try:
+ yield mainns[name]
+ finally:
+ mainns.pop(name, None)
+
+
+def build_excinfo(exctype, msg=None, formatted=None, errdisplay=None):
+ if isinstance(exctype, type):
+ assert issubclass(exctype, BaseException), exctype
+ exctype = types.SimpleNamespace(
+ __name__=exctype.__name__,
+ __qualname__=exctype.__qualname__,
+ __module__=exctype.__module__,
+ )
+ elif isinstance(exctype, str):
+ module, _, name = exctype.rpartition(exctype)
+ if not module and name in __builtins__:
+ module = 'builtins'
+ exctype = types.SimpleNamespace(
+ __name__=name,
+ __qualname__=exctype,
+ __module__=module or None,
+ )
+ else:
+ assert isinstance(exctype, types.SimpleNamespace)
+ assert msg is None or isinstance(msg, str), msg
+ assert formatted is None or isinstance(formatted, str), formatted
+ assert errdisplay is None or isinstance(errdisplay, str), errdisplay
+ return types.SimpleNamespace(
+ type=exctype,
+ msg=msg,
+ formatted=formatted,
+ errdisplay=errdisplay,
+ )
+
+
class ModuleTests(TestBase):
def test_queue_aliases(self):
@@ -75,7 +133,7 @@ class CreateTests(TestBase):
main, = interpreters.list_all()
interp = interpreters.create()
out = _run_output(interp, dedent("""
- from test.support import interpreters
+ from concurrent import interpreters
interp = interpreters.create()
print(interp.id)
"""))
@@ -138,7 +196,7 @@ class GetCurrentTests(TestBase):
main = interpreters.get_main()
interp = interpreters.create()
out = _run_output(interp, dedent("""
- from test.support import interpreters
+ from concurrent import interpreters
cur = interpreters.get_current()
print(cur.id)
"""))
@@ -155,7 +213,7 @@ class GetCurrentTests(TestBase):
with self.subTest('subinterpreter'):
interp = interpreters.create()
out = _run_output(interp, dedent("""
- from test.support import interpreters
+ from concurrent import interpreters
cur = interpreters.get_current()
print(id(cur))
cur = interpreters.get_current()
@@ -167,7 +225,7 @@ class GetCurrentTests(TestBase):
with self.subTest('per-interpreter'):
interp = interpreters.create()
out = _run_output(interp, dedent("""
- from test.support import interpreters
+ from concurrent import interpreters
cur = interpreters.get_current()
print(id(cur))
"""))
@@ -524,7 +582,7 @@ class TestInterpreterClose(TestBase):
main, = interpreters.list_all()
interp = interpreters.create()
out = _run_output(interp, dedent(f"""
- from test.support import interpreters
+ from concurrent import interpreters
interp = interpreters.Interpreter({interp.id})
try:
interp.close()
@@ -541,7 +599,7 @@ class TestInterpreterClose(TestBase):
self.assertEqual(set(interpreters.list_all()),
{main, interp1, interp2})
interp1.exec(dedent(f"""
- from test.support import interpreters
+ from concurrent import interpreters
interp2 = interpreters.Interpreter({interp2.id})
interp2.close()
interp3 = interpreters.create()
@@ -748,7 +806,7 @@ class TestInterpreterExec(TestBase):
ham()
""")
scriptfile = self.make_script('script.py', tempdir, text="""
- from test.support import interpreters
+ from concurrent import interpreters
def script():
import spam
@@ -769,7 +827,7 @@ class TestInterpreterExec(TestBase):
~~~~~~~~~~~^^^^^^^^
{interpmod_line.strip()}
raise ExecutionFailed(excinfo)
- test.support.interpreters.ExecutionFailed: RuntimeError: uh-oh!
+ concurrent.interpreters.ExecutionFailed: RuntimeError: uh-oh!
Uncaught in the interpreter:
@@ -839,9 +897,16 @@ class TestInterpreterExec(TestBase):
interp.exec(10)
def test_bytes_for_script(self):
+ r, w = self.pipe()
+ RAN = b'R'
+ DONE = b'D'
interp = interpreters.create()
- with self.assertRaises(TypeError):
- interp.exec(b'print("spam")')
+ interp.exec(f"""if True:
+ import os
+ os.write({w}, {RAN!r})
+ """)
+ os.write(w, DONE)
+ self.assertEqual(os.read(r, 1), RAN)
def test_with_background_threads_still_running(self):
r_interp, w_interp = self.pipe()
@@ -879,28 +944,46 @@ class TestInterpreterExec(TestBase):
with self.assertRaisesRegex(InterpreterError, 'unrecognized'):
interp.exec('raise Exception("it worked!")')
+ def test_list_comprehension(self):
+ # gh-135450: List comprehensions caused an assertion failure
+ # in _PyCode_CheckNoExternalState()
+ import string
+ r_interp, w_interp = self.pipe()
+
+ interp = interpreters.create()
+ interp.exec(f"""if True:
+ import os
+ comp = [str(i) for i in range(10)]
+ os.write({w_interp}, ''.join(comp).encode())
+ """)
+ self.assertEqual(os.read(r_interp, 10).decode(), string.digits)
+ interp.close()
+
+
# test__interpreters covers the remaining
# Interpreter.exec() behavior.
-def call_func_noop():
- pass
+call_func_noop = defs.spam_minimal
+call_func_ident = defs.spam_returns_arg
+call_func_failure = defs.spam_raises
def call_func_return_shareable():
return (1, None)
-def call_func_return_not_shareable():
- return [1, 2, 3]
+def call_func_return_stateless_func():
+ return (lambda x: x)
-def call_func_failure():
- raise Exception('spam!')
+def call_func_return_pickleable():
+ return [1, 2, 3]
-def call_func_ident(value):
- return value
+def call_func_return_unpickleable():
+ x = 42
+ return (lambda: x)
def get_call_func_closure(value):
@@ -909,6 +992,11 @@ def get_call_func_closure(value):
return call_func_closure
+def call_func_exec_wrapper(script, ns):
+ res = exec(script, ns, ns)
+ return res, ns, id(ns)
+
+
class Spam:
@staticmethod
@@ -1005,86 +1093,663 @@ class TestInterpreterCall(TestBase):
# - preserves info (e.g. SyntaxError)
# - matching error display
- def test_call(self):
+ @contextlib.contextmanager
+ def assert_fails(self, expected):
+ with self.assertRaises(ExecutionFailed) as cm:
+ yield cm
+ uncaught = cm.exception.excinfo
+ self.assertEqual(uncaught.type.__name__, expected.__name__)
+
+ def assert_fails_not_shareable(self):
+ return self.assert_fails(interpreters.NotShareableError)
+
+ def assert_code_equal(self, code1, code2):
+ if code1 == code2:
+ return
+ self.assertEqual(code1.co_name, code2.co_name)
+ self.assertEqual(code1.co_flags, code2.co_flags)
+ self.assertEqual(code1.co_consts, code2.co_consts)
+ self.assertEqual(code1.co_varnames, code2.co_varnames)
+ self.assertEqual(code1.co_cellvars, code2.co_cellvars)
+ self.assertEqual(code1.co_freevars, code2.co_freevars)
+ self.assertEqual(code1.co_names, code2.co_names)
+ self.assertEqual(
+ _testinternalcapi.get_code_var_counts(code1),
+ _testinternalcapi.get_code_var_counts(code2),
+ )
+ self.assertEqual(code1.co_code, code2.co_code)
+
+ def assert_funcs_equal(self, func1, func2):
+ if func1 == func2:
+ return
+ self.assertIs(type(func1), type(func2))
+ self.assertEqual(func1.__name__, func2.__name__)
+ self.assertEqual(func1.__defaults__, func2.__defaults__)
+ self.assertEqual(func1.__kwdefaults__, func2.__kwdefaults__)
+ self.assertEqual(func1.__closure__, func2.__closure__)
+ self.assert_code_equal(func1.__code__, func2.__code__)
+ self.assertEqual(
+ _testinternalcapi.get_code_var_counts(func1),
+ _testinternalcapi.get_code_var_counts(func2),
+ )
+
+ def assert_exceptions_equal(self, exc1, exc2):
+ assert isinstance(exc1, Exception)
+ assert isinstance(exc2, Exception)
+ if exc1 == exc2:
+ return
+ self.assertIs(type(exc1), type(exc2))
+ self.assertEqual(exc1.args, exc2.args)
+
+ def test_stateless_funcs(self):
interp = interpreters.create()
- for i, (callable, args, kwargs) in enumerate([
- (call_func_noop, (), {}),
- (call_func_return_shareable, (), {}),
- (call_func_return_not_shareable, (), {}),
- (Spam.noop, (), {}),
+ func = call_func_noop
+ with self.subTest('no args, no return'):
+ res = interp.call(func)
+ self.assertIsNone(res)
+
+ func = call_func_return_shareable
+ with self.subTest('no args, returns shareable'):
+ res = interp.call(func)
+ self.assertEqual(res, (1, None))
+
+ func = call_func_return_stateless_func
+ expected = (lambda x: x)
+ with self.subTest('no args, returns stateless func'):
+ res = interp.call(func)
+ self.assert_funcs_equal(res, expected)
+
+ func = call_func_return_pickleable
+ with self.subTest('no args, returns pickleable'):
+ res = interp.call(func)
+ self.assertEqual(res, [1, 2, 3])
+
+ func = call_func_return_unpickleable
+ with self.subTest('no args, returns unpickleable'):
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(func)
+
+ def test_stateless_func_returns_arg(self):
+ interp = interpreters.create()
+
+ for arg in [
+ None,
+ 10,
+ 'spam!',
+ b'spam!',
+ (1, 2, 'spam!'),
+ memoryview(b'spam!'),
+ ]:
+ with self.subTest(f'shareable {arg!r}'):
+ assert _interpreters.is_shareable(arg)
+ res = interp.call(defs.spam_returns_arg, arg)
+ self.assertEqual(res, arg)
+
+ for arg in defs.STATELESS_FUNCTIONS:
+ with self.subTest(f'stateless func {arg!r}'):
+ res = interp.call(defs.spam_returns_arg, arg)
+ self.assert_funcs_equal(res, arg)
+
+ for arg in defs.TOP_FUNCTIONS:
+ if arg in defs.STATELESS_FUNCTIONS:
+ continue
+ with self.subTest(f'stateful func {arg!r}'):
+ res = interp.call(defs.spam_returns_arg, arg)
+ self.assert_funcs_equal(res, arg)
+ assert is_pickleable(arg)
+
+ for arg in [
+ Ellipsis,
+ NotImplemented,
+ object(),
+ 2**1000,
+ [1, 2, 3],
+ {'a': 1, 'b': 2},
+ types.SimpleNamespace(x=42),
+ # builtin types
+ object,
+ type,
+ Exception,
+ ModuleNotFoundError,
+ # builtin exceptions
+ Exception('uh-oh!'),
+ ModuleNotFoundError('mymodule'),
+            # builtin functions
+ len,
+ sys.exit,
+ # user classes
+ *defs.TOP_CLASSES,
+ *(c(*a) for c, a in defs.TOP_CLASSES.items()
+ if c not in defs.CLASSES_WITHOUT_EQUALITY),
+ ]:
+ with self.subTest(f'pickleable {arg!r}'):
+ res = interp.call(defs.spam_returns_arg, arg)
+ if type(arg) is object:
+ self.assertIs(type(res), object)
+ elif isinstance(arg, BaseException):
+ self.assert_exceptions_equal(res, arg)
+ else:
+ self.assertEqual(res, arg)
+ assert is_pickleable(arg)
+
+ for arg in [
+ types.MappingProxyType({}),
+ *(f for f in defs.NESTED_FUNCTIONS
+ if f not in defs.STATELESS_FUNCTIONS),
+ ]:
+ with self.subTest(f'unpickleable {arg!r}'):
+ assert not _interpreters.is_shareable(arg)
+ assert not is_pickleable(arg)
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(defs.spam_returns_arg, arg)
+
+ def test_full_args(self):
+ interp = interpreters.create()
+ expected = (1, 2, 3, 4, 5, 6, ('?',), {'g': 7, 'h': 8})
+ func = defs.spam_full_args
+ res = interp.call(func, 1, 2, 3, 4, '?', e=5, f=6, g=7, h=8)
+ self.assertEqual(res, expected)
+
+ def test_full_defaults(self):
+ # pickleable, but not stateless
+ interp = interpreters.create()
+ expected = (-1, -2, -3, -4, -5, -6, (), {'g': 8, 'h': 9})
+ res = interp.call(defs.spam_full_args_with_defaults, g=8, h=9)
+ self.assertEqual(res, expected)
+
+ def test_modified_arg(self):
+ interp = interpreters.create()
+ script = dedent("""
+ a = 7
+ b = 2
+ c = a ** b
+ """)
+ ns = {}
+ expected = {'a': 7, 'b': 2, 'c': 49}
+ res = interp.call(call_func_exec_wrapper, script, ns)
+ obj, resns, resid = res
+ del resns['__builtins__']
+ self.assertIsNone(obj)
+ self.assertEqual(ns, {})
+ self.assertEqual(resns, expected)
+ self.assertNotEqual(resid, id(ns))
+ self.assertNotEqual(resid, id(resns))
+
+ def test_func_in___main___valid(self):
+        # pickleable, already there
+
+ with os_helper.temp_dir() as tempdir:
+ def new_mod(name, text):
+ script_helper.make_script(tempdir, name, dedent(text))
+
+ def run(text):
+ name = 'myscript'
+ text = dedent(f"""
+ import sys
+ sys.path.insert(0, {tempdir!r})
+
+ """) + dedent(text)
+ filename = script_helper.make_script(tempdir, name, text)
+ res = script_helper.assert_python_ok(filename)
+ return res.out.decode('utf-8').strip()
+
+ # no module indirection
+ with self.subTest('no indirection'):
+ text = run(f"""
+ from concurrent import interpreters
+
+ def spam():
+ # This a global var...
+ return __name__
+
+ if __name__ == '__main__':
+ interp = interpreters.create()
+ res = interp.call(spam)
+ print(res)
+ """)
+ self.assertEqual(text, '<fake __main__>')
+
+ # indirect as func, direct interp
+ new_mod('mymod', f"""
+ def run(interp, func):
+ return interp.call(func)
+ """)
+ with self.subTest('indirect as func, direct interp'):
+ text = run(f"""
+ from concurrent import interpreters
+ import mymod
+
+ def spam():
+ # This a global var...
+ return __name__
+
+ if __name__ == '__main__':
+ interp = interpreters.create()
+ res = mymod.run(interp, spam)
+ print(res)
+ """)
+ self.assertEqual(text, '<fake __main__>')
+
+ # indirect as func, indirect interp
+ new_mod('mymod', f"""
+ from concurrent import interpreters
+ def run(func):
+ interp = interpreters.create()
+ return interp.call(func)
+ """)
+ with self.subTest('indirect as func, indirect interp'):
+ text = run(f"""
+ import mymod
+
+ def spam():
+ # This a global var...
+ return __name__
+
+ if __name__ == '__main__':
+ res = mymod.run(spam)
+ print(res)
+ """)
+ self.assertEqual(text, '<fake __main__>')
+
+ def test_func_in___main___invalid(self):
+ interp = interpreters.create()
+
+ funcname = f'{__name__.replace(".", "_")}_spam_okay'
+ script = dedent(f"""
+ def {funcname}():
+ # This a global var...
+ return __name__
+ """)
+
+ with self.subTest('pickleable, added dynamically'):
+ with defined_in___main__(funcname, script) as arg:
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(defs.spam_returns_arg, arg)
+
+ with self.subTest('lying about __main__'):
+ with defined_in___main__(funcname, script, remove=True) as arg:
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(defs.spam_returns_arg, arg)
+
+ def test_func_in___main___hidden(self):
+ # When a top-level function that uses global variables is called
+ # through Interpreter.call(), it will be pickled, sent over,
+ # and unpickled. That requires that it be found in the other
+ # interpreter's __main__ module. However, the original script
+ # that defined the function is only run in the main interpreter,
+ # so pickle.loads() would normally fail.
+ #
+ # We work around this by running the script in the other
+ # interpreter. However, this is a one-off solution for the sake
+ # of unpickling, so we avoid modifying that interpreter's
+ # __main__ module by running the script in a hidden module.
+ #
+ # In this test we verify that the function runs with the hidden
+ # module as its __globals__ when called in the other interpreter,
+ # and that the interpreter's __main__ module is unaffected.
+ text = dedent("""
+ eggs = True
+
+ def spam(*, explicit=False):
+ if explicit:
+ import __main__
+ ns = __main__.__dict__
+ else:
+ # For now we have to have a LOAD_GLOBAL in the
+ # function in order for globals() to actually return
+ # spam.__globals__. Maybe it doesn't go through pickle?
+ # XXX We will fix this later.
+ spam
+ ns = globals()
+
+ func = ns.get('spam')
+ return [
+ id(ns),
+ ns.get('__name__'),
+ ns.get('__file__'),
+ id(func),
+ None if func is None else repr(func),
+ ns.get('eggs'),
+ ns.get('ham'),
+ ]
+
+ if __name__ == "__main__":
+ from concurrent import interpreters
+ interp = interpreters.create()
+
+ ham = True
+ print([
+ [
+ spam(explicit=True),
+ spam(),
+ ],
+ [
+ interp.call(spam, explicit=True),
+ interp.call(spam),
+ ],
+ ])
+ """)
+ with os_helper.temp_dir() as tempdir:
+ filename = script_helper.make_script(tempdir, 'my-script', text)
+ res = script_helper.assert_python_ok(filename)
+ stdout = res.out.decode('utf-8').strip()
+ local, remote = eval(stdout)
+
+ # In the main interpreter.
+ main, unpickled = local
+ nsid, _, _, funcid, func, _, _ = main
+ self.assertEqual(main, [
+ nsid,
+ '__main__',
+ filename,
+ funcid,
+ func,
+ True,
+ True,
+ ])
+ self.assertIsNot(func, None)
+ self.assertRegex(func, '^<function spam at 0x.*>$')
+ self.assertEqual(unpickled, main)
+
+ # In the subinterpreter.
+ main, unpickled = remote
+ nsid1, _, _, funcid1, _, _, _ = main
+ self.assertEqual(main, [
+ nsid1,
+ '__main__',
+ None,
+ funcid1,
+ None,
+ None,
+ None,
+ ])
+ nsid2, _, _, funcid2, func, _, _ = unpickled
+ self.assertEqual(unpickled, [
+ nsid2,
+ '<fake __main__>',
+ filename,
+ funcid2,
+ func,
+ True,
+ None,
+ ])
+ self.assertIsNot(func, None)
+ self.assertRegex(func, '^<function spam at 0x.*>$')
+ self.assertNotEqual(nsid2, nsid1)
+ self.assertNotEqual(funcid2, funcid1)
+
+ def test_func_in___main___uses_globals(self):
+ # See the note in test_func_in___main___hidden about pickle
+ # and the __main__ module.
+ #
+ # Additionally, the solution to that problem must provide
+ # for global variables on which a pickled function might rely.
+ #
+ # To check that, we run a script that has two global functions
+ # and a global variable in the __main__ module. One of the
+ # functions sets the global variable and the other returns
+ # the value.
+ #
+ # The script calls those functions multiple times in another
+ # interpreter, to verify the following:
+ #
+ # * the global variable is properly initialized
+ # * the global variable retains state between calls
+ # * the setter modifies that persistent variable
+ # * the getter uses the variable
+ # * the calls in the other interpreter do not modify
+ # the main interpreter
+ # * those calls don't modify the interpreter's __main__ module
+ # * the functions and variable do not actually show up in the
+ # other interpreter's __main__ module
+ text = dedent("""
+ count = 0
+
+ def inc(x=1):
+ global count
+ count += x
+
+ def get_count():
+ return count
+
+ if __name__ == "__main__":
+ counts = []
+ results = [count, counts]
+
+ from concurrent import interpreters
+ interp = interpreters.create()
+
+ val = interp.call(get_count)
+ counts.append(val)
+
+ interp.call(inc)
+ val = interp.call(get_count)
+ counts.append(val)
+
+ interp.call(inc, 3)
+ val = interp.call(get_count)
+ counts.append(val)
+
+ results.append(count)
+
+ modified = {name: interp.call(eval, f'{name!r} in vars()')
+ for name in ('count', 'inc', 'get_count')}
+ results.append(modified)
+
+ print(results)
+ """)
+ with os_helper.temp_dir() as tempdir:
+ filename = script_helper.make_script(tempdir, 'my-script', text)
+ res = script_helper.assert_python_ok(filename)
+ stdout = res.out.decode('utf-8').strip()
+ before, counts, after, modified = eval(stdout)
+ self.assertEqual(modified, {
+ 'count': False,
+ 'inc': False,
+ 'get_count': False,
+ })
+ self.assertEqual(before, 0)
+ self.assertEqual(after, 0)
+ self.assertEqual(counts, [0, 1, 4])
+
+ def test_raises(self):
+ interp = interpreters.create()
+ with self.assertRaises(ExecutionFailed):
+ interp.call(call_func_failure)
+
+ with self.assert_fails(ValueError):
+ interp.call(call_func_complex, '???', exc=ValueError('spam'))
+
+ def test_call_valid(self):
+ interp = interpreters.create()
+
+ for i, (callable, args, kwargs, expected) in enumerate([
+ (call_func_noop, (), {}, None),
+ (call_func_ident, ('spamspamspam',), {}, 'spamspamspam'),
+ (call_func_return_shareable, (), {}, (1, None)),
+ (call_func_return_pickleable, (), {}, [1, 2, 3]),
+ (Spam.noop, (), {}, None),
+ (Spam.from_values, (), {}, Spam(())),
+ (Spam.from_values, (1, 2, 3), {}, Spam((1, 2, 3))),
+ (Spam, ('???',), {}, Spam('???')),
+ (Spam(101), (), {}, (101, (), {})),
+ (Spam(10101).run, (), {}, (10101, (), {})),
+ (call_func_complex, ('ident', 'spam'), {}, 'spam'),
+ (call_func_complex, ('full-ident', 'spam'), {}, ('spam', (), {})),
+ (call_func_complex, ('full-ident', 'spam', 'ham'), {'eggs': '!!!'},
+ ('spam', ('ham',), {'eggs': '!!!'})),
+ (call_func_complex, ('globals',), {}, __name__),
+ (call_func_complex, ('interpid',), {}, interp.id),
+ (call_func_complex, ('custom', 'spam!'), {}, Spam('spam!')),
]):
with self.subTest(f'success case #{i+1}'):
- res = interp.call(callable)
- self.assertIs(res, None)
+ res = interp.call(callable, *args, **kwargs)
+ self.assertEqual(res, expected)
+
+ def test_call_invalid(self):
+ interp = interpreters.create()
+
+ func = get_call_func_closure
+ with self.subTest(func):
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(func, 42)
+
+ func = get_call_func_closure(42)
+ with self.subTest(func):
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(func)
+
+ func = call_func_complex
+ op = 'closure'
+ with self.subTest(f'{func} ({op})'):
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(func, op, value='~~~')
+
+ op = 'custom-inner'
+ with self.subTest(f'{func} ({op})'):
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(func, op, 'eggs!')
+
+ def test_callable_requires_frame(self):
+ # There are various functions that require a current frame.
+ interp = interpreters.create()
+ for call, expected in [
+ ((eval, '[1, 2, 3]'),
+ [1, 2, 3]),
+ ((eval, 'sum([1, 2, 3])'),
+ 6),
+ ((exec, '...'),
+ None),
+ ]:
+ with self.subTest(str(call)):
+ res = interp.call(*call)
+ self.assertEqual(res, expected)
+
+ result_not_pickleable = [
+ globals,
+ locals,
+ vars,
+ ]
+ for func, expectedtype in {
+ globals: dict,
+ locals: dict,
+ vars: dict,
+ dir: list,
+ }.items():
+ with self.subTest(str(func)):
+ if func in result_not_pickleable:
+ with self.assertRaises(interpreters.NotShareableError):
+ interp.call(func)
+ else:
+ res = interp.call(func)
+ self.assertIsInstance(res, expectedtype)
+ self.assertIn('__builtins__', res)
+
+ def test_globals_from_builtins(self):
+ # The builtins exec(), eval(), globals(), locals(), vars(),
+ # and dir() each runs relative to the target interpreter's
+ # __main__ module, when called directly. However,
+ # globals(), locals(), and vars() don't work when called
+ # directly so we don't check them.
+ from _frozen_importlib import BuiltinImporter
+ interp = interpreters.create()
+
+ names = interp.call(dir)
+ self.assertEqual(names, [
+ '__builtins__',
+ '__doc__',
+ '__loader__',
+ '__name__',
+ '__package__',
+ '__spec__',
+ ])
+
+ values = {name: interp.call(eval, name)
+ for name in names if name != '__builtins__'}
+ self.assertEqual(values, {
+ '__name__': '__main__',
+ '__doc__': None,
+ '__spec__': None, # It wasn't imported, so no module spec?
+ '__package__': None,
+ '__loader__': BuiltinImporter,
+ })
+ with self.assertRaises(ExecutionFailed):
+ interp.call(eval, 'spam'),
+
+ interp.call(exec, f'assert dir() == {names}')
+
+ # Update the interpreter's __main__.
+ interp.prepare_main(spam=42)
+ expected = names + ['spam']
+
+ names = interp.call(dir)
+ self.assertEqual(names, expected)
+
+ value = interp.call(eval, 'spam')
+ self.assertEqual(value, 42)
+
+ interp.call(exec, f'assert dir() == {expected}, dir()')
+
+ def test_globals_from_stateless_func(self):
+ # A stateless func, which doesn't depend on any globals,
+ # doesn't go through pickle, so it runs in __main__.
+ def set_global(name, value):
+ globals()[name] = value
+
+ def get_global(name):
+ return globals().get(name)
+
+ interp = interpreters.create()
+
+ modname = interp.call(get_global, '__name__')
+ self.assertEqual(modname, '__main__')
+
+ res = interp.call(get_global, 'spam')
+ self.assertIsNone(res)
+
+ interp.exec('spam = True')
+ res = interp.call(get_global, 'spam')
+ self.assertTrue(res)
+
+ interp.call(set_global, 'spam', 42)
+ res = interp.call(get_global, 'spam')
+ self.assertEqual(res, 42)
+
+ interp.exec('assert spam == 42, repr(spam)')
+
+ def test_call_in_thread(self):
+ interp = interpreters.create()
for i, (callable, args, kwargs) in enumerate([
- (call_func_ident, ('spamspamspam',), {}),
- (get_call_func_closure, (42,), {}),
- (get_call_func_closure(42), (), {}),
+ (call_func_noop, (), {}),
+ (call_func_return_shareable, (), {}),
+ (call_func_return_pickleable, (), {}),
(Spam.from_values, (), {}),
(Spam.from_values, (1, 2, 3), {}),
- (Spam, ('???'), {}),
(Spam(101), (), {}),
(Spam(10101).run, (), {}),
+ (Spam.noop, (), {}),
(call_func_complex, ('ident', 'spam'), {}),
(call_func_complex, ('full-ident', 'spam'), {}),
(call_func_complex, ('full-ident', 'spam', 'ham'), {'eggs': '!!!'}),
(call_func_complex, ('globals',), {}),
(call_func_complex, ('interpid',), {}),
- (call_func_complex, ('closure',), {'value': '~~~'}),
(call_func_complex, ('custom', 'spam!'), {}),
- (call_func_complex, ('custom-inner', 'eggs!'), {}),
- (call_func_complex, ('???',), {'exc': ValueError('spam')}),
- ]):
- with self.subTest(f'invalid case #{i+1}'):
- with self.assertRaises(Exception):
- if args or kwargs:
- raise Exception((args, kwargs))
- interp.call(callable)
-
- with self.assertRaises(ExecutionFailed):
- interp.call(call_func_failure)
-
- def test_call_in_thread(self):
- interp = interpreters.create()
-
- for i, (callable, args, kwargs) in enumerate([
- (call_func_noop, (), {}),
- (call_func_return_shareable, (), {}),
- (call_func_return_not_shareable, (), {}),
- (Spam.noop, (), {}),
]):
with self.subTest(f'success case #{i+1}'):
with self.captured_thread_exception() as ctx:
- t = interp.call_in_thread(callable)
+ t = interp.call_in_thread(callable, *args, **kwargs)
t.join()
self.assertIsNone(ctx.caught)
for i, (callable, args, kwargs) in enumerate([
- (call_func_ident, ('spamspamspam',), {}),
(get_call_func_closure, (42,), {}),
(get_call_func_closure(42), (), {}),
- (Spam.from_values, (), {}),
- (Spam.from_values, (1, 2, 3), {}),
- (Spam, ('???'), {}),
- (Spam(101), (), {}),
- (Spam(10101).run, (), {}),
- (call_func_complex, ('ident', 'spam'), {}),
- (call_func_complex, ('full-ident', 'spam'), {}),
- (call_func_complex, ('full-ident', 'spam', 'ham'), {'eggs': '!!!'}),
- (call_func_complex, ('globals',), {}),
- (call_func_complex, ('interpid',), {}),
- (call_func_complex, ('closure',), {'value': '~~~'}),
- (call_func_complex, ('custom', 'spam!'), {}),
- (call_func_complex, ('custom-inner', 'eggs!'), {}),
- (call_func_complex, ('???',), {'exc': ValueError('spam')}),
]):
with self.subTest(f'invalid case #{i+1}'):
- if args or kwargs:
- continue
with self.captured_thread_exception() as ctx:
- t = interp.call_in_thread(callable)
+ t = interp.call_in_thread(callable, *args, **kwargs)
t.join()
self.assertIsNotNone(ctx.caught)
@@ -1452,6 +2117,14 @@ class LowLevelTests(TestBase):
self.assertFalse(
self.interp_exists(interpid))
+ with self.subTest('basic C-API'):
+ interpid = _testinternalcapi.create_interpreter()
+ self.assertTrue(
+ self.interp_exists(interpid))
+ _testinternalcapi.destroy_interpreter(interpid, basic=True)
+ self.assertFalse(
+ self.interp_exists(interpid))
+
def test_get_config(self):
# This test overlaps with
# test.test_capi.test_misc.InterpreterConfigTests.
@@ -1585,18 +2258,14 @@ class LowLevelTests(TestBase):
with results:
exc = _interpreters.exec(interpid, script)
out = results.stdout()
- self.assertEqual(out, '')
- self.assert_ns_equal(exc, types.SimpleNamespace(
- type=types.SimpleNamespace(
- __name__='Exception',
- __qualname__='Exception',
- __module__='builtins',
- ),
- msg='uh-oh!',
+ expected = build_excinfo(
+ Exception, 'uh-oh!',
# We check these in other tests.
formatted=exc.formatted,
errdisplay=exc.errdisplay,
- ))
+ )
+ self.assertEqual(out, '')
+ self.assert_ns_equal(exc, expected)
with self.subTest('from C-API'):
with self.interpreter_from_capi() as interpid:
@@ -1608,25 +2277,50 @@ class LowLevelTests(TestBase):
self.assertEqual(exc.msg, 'it worked!')
def test_call(self):
- with self.subTest('no args'):
- interpid = _interpreters.create()
- exc = _interpreters.call(interpid, call_func_return_shareable)
- self.assertIs(exc, None)
+ interpid = _interpreters.create()
+
+ # Here we focus on basic args and return values.
+ # See TestInterpreterCall for full operational coverage,
+ # including supported callables.
+
+ with self.subTest('no args, return None'):
+ func = defs.spam_minimal
+ res, exc = _interpreters.call(interpid, func)
+ self.assertIsNone(exc)
+ self.assertIsNone(res)
+
+ with self.subTest('empty args, return None'):
+ func = defs.spam_minimal
+ res, exc = _interpreters.call(interpid, func, (), {})
+ self.assertIsNone(exc)
+ self.assertIsNone(res)
+
+ with self.subTest('no args, return non-None'):
+ func = defs.script_with_return
+ res, exc = _interpreters.call(interpid, func)
+ self.assertIsNone(exc)
+ self.assertIs(res, True)
+
+ with self.subTest('full args, return non-None'):
+ expected = (1, 2, 3, 4, 5, 6, (7, 8), {'g': 9, 'h': 0})
+ func = defs.spam_full_args
+ args = (1, 2, 3, 4, 7, 8)
+ kwargs = dict(e=5, f=6, g=9, h=0)
+ res, exc = _interpreters.call(interpid, func, args, kwargs)
+ self.assertIsNone(exc)
+ self.assertEqual(res, expected)
with self.subTest('uncaught exception'):
- interpid = _interpreters.create()
- exc = _interpreters.call(interpid, call_func_failure)
- self.assertEqual(exc, types.SimpleNamespace(
- type=types.SimpleNamespace(
- __name__='Exception',
- __qualname__='Exception',
- __module__='builtins',
- ),
- msg='spam!',
+ func = defs.spam_raises
+ res, exc = _interpreters.call(interpid, func)
+ expected = build_excinfo(
+ Exception, 'spam!',
# We check these in other tests.
formatted=exc.formatted,
errdisplay=exc.errdisplay,
- ))
+ )
+ self.assertIsNone(res)
+ self.assertEqual(exc, expected)
@requires_test_modules
def test_set___main___attrs(self):
diff --git a/Lib/test/test_interpreters/test_channels.py b/Lib/test/test_interpreters/test_channels.py
index eada18f99d0..109ddf34453 100644
--- a/Lib/test/test_interpreters/test_channels.py
+++ b/Lib/test/test_interpreters/test_channels.py
@@ -8,8 +8,8 @@ import time
from test.support import import_helper
# Raise SkipTest if subinterpreters not supported.
_channels = import_helper.import_module('_interpchannels')
-from test.support import interpreters
-from test.support.interpreters import channels
+from concurrent import interpreters
+from test.support import channels
from .utils import _run_output, TestBase
@@ -171,7 +171,7 @@ class TestSendRecv(TestBase):
def test_send_recv_same_interpreter(self):
interp = interpreters.create()
interp.exec(dedent("""
- from test.support.interpreters import channels
+ from test.support import channels
r, s = channels.create()
orig = b'spam'
s.send_nowait(orig)
@@ -244,7 +244,7 @@ class TestSendRecv(TestBase):
def test_send_recv_nowait_same_interpreter(self):
interp = interpreters.create()
interp.exec(dedent("""
- from test.support.interpreters import channels
+ from test.support import channels
r, s = channels.create()
orig = b'spam'
s.send_nowait(orig)
@@ -377,17 +377,17 @@ class TestSendRecv(TestBase):
if not unbound:
extraargs = ''
elif unbound is channels.UNBOUND:
- extraargs = ', unbound=channels.UNBOUND'
+ extraargs = ', unbounditems=channels.UNBOUND'
elif unbound is channels.UNBOUND_ERROR:
- extraargs = ', unbound=channels.UNBOUND_ERROR'
+ extraargs = ', unbounditems=channels.UNBOUND_ERROR'
elif unbound is channels.UNBOUND_REMOVE:
- extraargs = ', unbound=channels.UNBOUND_REMOVE'
+ extraargs = ', unbounditems=channels.UNBOUND_REMOVE'
else:
raise NotImplementedError(repr(unbound))
interp = interpreters.create()
_run_output(interp, dedent(f"""
- from test.support.interpreters import channels
+ from test.support import channels
sch = channels.SendChannel({sch.id})
obj1 = b'spam'
obj2 = b'eggs'
@@ -454,11 +454,11 @@ class TestSendRecv(TestBase):
with self.assertRaises(channels.ChannelEmptyError):
rch.recv_nowait()
- sch.send_nowait(b'ham', unbound=channels.UNBOUND_REMOVE)
+ sch.send_nowait(b'ham', unbounditems=channels.UNBOUND_REMOVE)
self.assertEqual(_channels.get_count(rch.id), 1)
interp = common(rch, sch, channels.UNBOUND_REMOVE, 1)
self.assertEqual(_channels.get_count(rch.id), 3)
- sch.send_nowait(42, unbound=channels.UNBOUND_REMOVE)
+ sch.send_nowait(42, unbounditems=channels.UNBOUND_REMOVE)
self.assertEqual(_channels.get_count(rch.id), 4)
del interp
self.assertEqual(_channels.get_count(rch.id), 2)
@@ -482,13 +482,13 @@ class TestSendRecv(TestBase):
self.assertEqual(_channels.get_count(rch.id), 0)
_run_output(interp, dedent(f"""
- from test.support.interpreters import channels
+ from test.support import channels
sch = channels.SendChannel({sch.id})
- sch.send_nowait(1, unbound=channels.UNBOUND)
- sch.send_nowait(2, unbound=channels.UNBOUND_ERROR)
+ sch.send_nowait(1, unbounditems=channels.UNBOUND)
+ sch.send_nowait(2, unbounditems=channels.UNBOUND_ERROR)
sch.send_nowait(3)
- sch.send_nowait(4, unbound=channels.UNBOUND_REMOVE)
- sch.send_nowait(5, unbound=channels.UNBOUND)
+ sch.send_nowait(4, unbounditems=channels.UNBOUND_REMOVE)
+ sch.send_nowait(5, unbounditems=channels.UNBOUND)
"""))
self.assertEqual(_channels.get_count(rch.id), 5)
@@ -518,15 +518,15 @@ class TestSendRecv(TestBase):
sch.send_nowait(1)
_run_output(interp1, dedent(f"""
- from test.support.interpreters import channels
+ from test.support import channels
rch = channels.RecvChannel({rch.id})
sch = channels.SendChannel({sch.id})
obj1 = rch.recv()
- sch.send_nowait(2, unbound=channels.UNBOUND)
- sch.send_nowait(obj1, unbound=channels.UNBOUND_REMOVE)
+ sch.send_nowait(2, unbounditems=channels.UNBOUND)
+ sch.send_nowait(obj1, unbounditems=channels.UNBOUND_REMOVE)
"""))
_run_output(interp2, dedent(f"""
- from test.support.interpreters import channels
+ from test.support import channels
rch = channels.RecvChannel({rch.id})
sch = channels.SendChannel({sch.id})
obj2 = rch.recv()
@@ -535,21 +535,21 @@ class TestSendRecv(TestBase):
self.assertEqual(_channels.get_count(rch.id), 0)
sch.send_nowait(3)
_run_output(interp1, dedent("""
- sch.send_nowait(4, unbound=channels.UNBOUND)
+ sch.send_nowait(4, unbounditems=channels.UNBOUND)
# interp closed here
- sch.send_nowait(5, unbound=channels.UNBOUND_REMOVE)
- sch.send_nowait(6, unbound=channels.UNBOUND)
+ sch.send_nowait(5, unbounditems=channels.UNBOUND_REMOVE)
+ sch.send_nowait(6, unbounditems=channels.UNBOUND)
"""))
_run_output(interp2, dedent("""
- sch.send_nowait(7, unbound=channels.UNBOUND_ERROR)
+ sch.send_nowait(7, unbounditems=channels.UNBOUND_ERROR)
# interp closed here
- sch.send_nowait(obj1, unbound=channels.UNBOUND_ERROR)
- sch.send_nowait(obj2, unbound=channels.UNBOUND_REMOVE)
- sch.send_nowait(8, unbound=channels.UNBOUND)
+ sch.send_nowait(obj1, unbounditems=channels.UNBOUND_ERROR)
+ sch.send_nowait(obj2, unbounditems=channels.UNBOUND_REMOVE)
+ sch.send_nowait(8, unbounditems=channels.UNBOUND)
"""))
_run_output(interp1, dedent("""
- sch.send_nowait(9, unbound=channels.UNBOUND_REMOVE)
- sch.send_nowait(10, unbound=channels.UNBOUND)
+ sch.send_nowait(9, unbounditems=channels.UNBOUND_REMOVE)
+ sch.send_nowait(10, unbounditems=channels.UNBOUND)
"""))
self.assertEqual(_channels.get_count(rch.id), 10)
diff --git a/Lib/test/test_interpreters/test_lifecycle.py b/Lib/test/test_interpreters/test_lifecycle.py
index ac24f6568ac..15537ac6cc8 100644
--- a/Lib/test/test_interpreters/test_lifecycle.py
+++ b/Lib/test/test_interpreters/test_lifecycle.py
@@ -119,7 +119,7 @@ class StartupTests(TestBase):
# The main interpreter's sys.path[0] should be used by subinterpreters.
script = '''
import sys
- from test.support import interpreters
+ from concurrent import interpreters
orig = sys.path[0]
@@ -170,7 +170,7 @@ class FinalizationTests(TestBase):
# is reported, even when subinterpreters get cleaned up at the end.
import subprocess
argv = [sys.executable, '-c', '''if True:
- from test.support import interpreters
+ from concurrent import interpreters
interp = interpreters.create()
raise Exception
''']
diff --git a/Lib/test/test_interpreters/test_queues.py b/Lib/test/test_interpreters/test_queues.py
index 18f83d097eb..cb17340f581 100644
--- a/Lib/test/test_interpreters/test_queues.py
+++ b/Lib/test/test_interpreters/test_queues.py
@@ -7,8 +7,8 @@ import unittest
from test.support import import_helper, Py_DEBUG
# Raise SkipTest if subinterpreters not supported.
_queues = import_helper.import_module('_interpqueues')
-from test.support import interpreters
-from test.support.interpreters import queues, _crossinterp
+from concurrent import interpreters
+from concurrent.interpreters import _queues as queues, _crossinterp
from .utils import _run_output, TestBase as _TestBase
@@ -42,7 +42,7 @@ class LowLevelTests(TestBase):
importlib.reload(queues)
def test_create_destroy(self):
- qid = _queues.create(2, 0, REPLACE)
+ qid = _queues.create(2, REPLACE, -1)
_queues.destroy(qid)
self.assertEqual(get_num_queues(), 0)
with self.assertRaises(queues.QueueNotFoundError):
@@ -56,7 +56,7 @@ class LowLevelTests(TestBase):
'-c',
dedent(f"""
import {_queues.__name__} as _queues
- _queues.create(2, 0, {REPLACE})
+ _queues.create(2, {REPLACE}, -1)
"""),
)
self.assertEqual(stdout, '')
@@ -67,13 +67,13 @@ class LowLevelTests(TestBase):
def test_bind_release(self):
with self.subTest('typical'):
- qid = _queues.create(2, 0, REPLACE)
+ qid = _queues.create(2, REPLACE, -1)
_queues.bind(qid)
_queues.release(qid)
self.assertEqual(get_num_queues(), 0)
with self.subTest('bind too much'):
- qid = _queues.create(2, 0, REPLACE)
+ qid = _queues.create(2, REPLACE, -1)
_queues.bind(qid)
_queues.bind(qid)
_queues.release(qid)
@@ -81,7 +81,7 @@ class LowLevelTests(TestBase):
self.assertEqual(get_num_queues(), 0)
with self.subTest('nested'):
- qid = _queues.create(2, 0, REPLACE)
+ qid = _queues.create(2, REPLACE, -1)
_queues.bind(qid)
_queues.bind(qid)
_queues.release(qid)
@@ -89,7 +89,7 @@ class LowLevelTests(TestBase):
self.assertEqual(get_num_queues(), 0)
with self.subTest('release without binding'):
- qid = _queues.create(2, 0, REPLACE)
+ qid = _queues.create(2, REPLACE, -1)
with self.assertRaises(queues.QueueError):
_queues.release(qid)
@@ -126,19 +126,19 @@ class QueueTests(TestBase):
interp = interpreters.create()
interp.exec(dedent(f"""
- from test.support.interpreters import queues
+ from concurrent.interpreters import _queues as queues
queue1 = queues.Queue({queue1.id})
"""));
with self.subTest('same interpreter'):
queue2 = queues.create()
- queue1.put(queue2, syncobj=True)
+ queue1.put(queue2)
queue3 = queue1.get()
self.assertIs(queue3, queue2)
with self.subTest('from current interpreter'):
queue4 = queues.create()
- queue1.put(queue4, syncobj=True)
+ queue1.put(queue4)
out = _run_output(interp, dedent("""
queue4 = queue1.get()
print(queue4.id)
@@ -149,7 +149,7 @@ class QueueTests(TestBase):
with self.subTest('from subinterpreter'):
out = _run_output(interp, dedent("""
queue5 = queues.create()
- queue1.put(queue5, syncobj=True)
+ queue1.put(queue5)
print(queue5.id)
"""))
qid = int(out)
@@ -198,7 +198,7 @@ class TestQueueOps(TestBase):
def test_empty(self):
queue = queues.create()
before = queue.empty()
- queue.put(None, syncobj=True)
+ queue.put(None)
during = queue.empty()
queue.get()
after = queue.empty()
@@ -208,18 +208,64 @@ class TestQueueOps(TestBase):
self.assertIs(after, True)
def test_full(self):
- expected = [False, False, False, True, False, False, False]
- actual = []
- queue = queues.create(3)
- for _ in range(3):
- actual.append(queue.full())
- queue.put(None, syncobj=True)
- actual.append(queue.full())
- for _ in range(3):
- queue.get()
- actual.append(queue.full())
+ for maxsize in [1, 3, 11]:
+ with self.subTest(f'maxsize={maxsize}'):
+ num_to_add = maxsize
+ expected = [False] * (num_to_add * 2 + 3)
+ expected[maxsize] = True
+ expected[maxsize + 1] = True
+
+ queue = queues.create(maxsize)
+ actual = []
+ empty = [queue.empty()]
+
+ for _ in range(num_to_add):
+ actual.append(queue.full())
+ queue.put_nowait(None)
+ actual.append(queue.full())
+ with self.assertRaises(queues.QueueFull):
+ queue.put_nowait(None)
+ empty.append(queue.empty())
+
+ for _ in range(num_to_add):
+ actual.append(queue.full())
+ queue.get_nowait()
+ actual.append(queue.full())
+ with self.assertRaises(queues.QueueEmpty):
+ queue.get_nowait()
+ actual.append(queue.full())
+ empty.append(queue.empty())
- self.assertEqual(actual, expected)
+ self.assertEqual(actual, expected)
+ self.assertEqual(empty, [True, False, True])
+
+ # no max size
+ for args in [(), (0,), (-1,), (-10,)]:
+ with self.subTest(f'maxsize={args[0]}' if args else '<default>'):
+ num_to_add = 13
+ expected = [False] * (num_to_add * 2 + 3)
+
+ queue = queues.create(*args)
+ actual = []
+ empty = [queue.empty()]
+
+ for _ in range(num_to_add):
+ actual.append(queue.full())
+ queue.put_nowait(None)
+ actual.append(queue.full())
+ empty.append(queue.empty())
+
+ for _ in range(num_to_add):
+ actual.append(queue.full())
+ queue.get_nowait()
+ actual.append(queue.full())
+ with self.assertRaises(queues.QueueEmpty):
+ queue.get_nowait()
+ actual.append(queue.full())
+ empty.append(queue.empty())
+
+ self.assertEqual(actual, expected)
+ self.assertEqual(empty, [True, False, True])
def test_qsize(self):
expected = [0, 1, 2, 3, 2, 3, 2, 1, 0, 1, 0]
@@ -227,16 +273,16 @@ class TestQueueOps(TestBase):
queue = queues.create()
for _ in range(3):
actual.append(queue.qsize())
- queue.put(None, syncobj=True)
+ queue.put(None)
actual.append(queue.qsize())
queue.get()
actual.append(queue.qsize())
- queue.put(None, syncobj=True)
+ queue.put(None)
actual.append(queue.qsize())
for _ in range(3):
queue.get()
actual.append(queue.qsize())
- queue.put(None, syncobj=True)
+ queue.put(None)
actual.append(queue.qsize())
queue.get()
actual.append(queue.qsize())
@@ -245,70 +291,32 @@ class TestQueueOps(TestBase):
def test_put_get_main(self):
expected = list(range(20))
- for syncobj in (True, False):
- kwds = dict(syncobj=syncobj)
- with self.subTest(f'syncobj={syncobj}'):
- queue = queues.create()
- for i in range(20):
- queue.put(i, **kwds)
- actual = [queue.get() for _ in range(20)]
+ queue = queues.create()
+ for i in range(20):
+ queue.put(i)
+ actual = [queue.get() for _ in range(20)]
- self.assertEqual(actual, expected)
+ self.assertEqual(actual, expected)
def test_put_timeout(self):
- for syncobj in (True, False):
- kwds = dict(syncobj=syncobj)
- with self.subTest(f'syncobj={syncobj}'):
- queue = queues.create(2)
- queue.put(None, **kwds)
- queue.put(None, **kwds)
- with self.assertRaises(queues.QueueFull):
- queue.put(None, timeout=0.1, **kwds)
- queue.get()
- queue.put(None, **kwds)
+ queue = queues.create(2)
+ queue.put(None)
+ queue.put(None)
+ with self.assertRaises(queues.QueueFull):
+ queue.put(None, timeout=0.1)
+ queue.get()
+ queue.put(None)
def test_put_nowait(self):
- for syncobj in (True, False):
- kwds = dict(syncobj=syncobj)
- with self.subTest(f'syncobj={syncobj}'):
- queue = queues.create(2)
- queue.put_nowait(None, **kwds)
- queue.put_nowait(None, **kwds)
- with self.assertRaises(queues.QueueFull):
- queue.put_nowait(None, **kwds)
- queue.get()
- queue.put_nowait(None, **kwds)
-
- def test_put_syncobj(self):
- for obj in [
- None,
- True,
- 10,
- 'spam',
- b'spam',
- (0, 'a'),
- ]:
- with self.subTest(repr(obj)):
- queue = queues.create()
-
- queue.put(obj, syncobj=True)
- obj2 = queue.get()
- self.assertEqual(obj2, obj)
-
- queue.put(obj, syncobj=True)
- obj2 = queue.get_nowait()
- self.assertEqual(obj2, obj)
-
- for obj in [
- [1, 2, 3],
- {'a': 13, 'b': 17},
- ]:
- with self.subTest(repr(obj)):
- queue = queues.create()
- with self.assertRaises(interpreters.NotShareableError):
- queue.put(obj, syncobj=True)
+ queue = queues.create(2)
+ queue.put_nowait(None)
+ queue.put_nowait(None)
+ with self.assertRaises(queues.QueueFull):
+ queue.put_nowait(None)
+ queue.get()
+ queue.put_nowait(None)
- def test_put_not_syncobj(self):
+ def test_put_full_fallback(self):
for obj in [
None,
True,
@@ -323,11 +331,11 @@ class TestQueueOps(TestBase):
with self.subTest(repr(obj)):
queue = queues.create()
- queue.put(obj, syncobj=False)
+ queue.put(obj)
obj2 = queue.get()
self.assertEqual(obj2, obj)
- queue.put(obj, syncobj=False)
+ queue.put(obj)
obj2 = queue.get_nowait()
self.assertEqual(obj2, obj)
@@ -341,24 +349,9 @@ class TestQueueOps(TestBase):
with self.assertRaises(queues.QueueEmpty):
queue.get_nowait()
- def test_put_get_default_syncobj(self):
- expected = list(range(20))
- queue = queues.create(syncobj=True)
- for methname in ('get', 'get_nowait'):
- with self.subTest(f'{methname}()'):
- get = getattr(queue, methname)
- for i in range(20):
- queue.put(i)
- actual = [get() for _ in range(20)]
- self.assertEqual(actual, expected)
-
- obj = [1, 2, 3] # lists are not shareable
- with self.assertRaises(interpreters.NotShareableError):
- queue.put(obj)
-
- def test_put_get_default_not_syncobj(self):
+ def test_put_get_full_fallback(self):
expected = list(range(20))
- queue = queues.create(syncobj=False)
+ queue = queues.create()
for methname in ('get', 'get_nowait'):
with self.subTest(f'{methname}()'):
get = getattr(queue, methname)
@@ -377,14 +370,14 @@ class TestQueueOps(TestBase):
def test_put_get_same_interpreter(self):
interp = interpreters.create()
interp.exec(dedent("""
- from test.support.interpreters import queues
+ from concurrent.interpreters import _queues as queues
queue = queues.create()
"""))
for methname in ('get', 'get_nowait'):
with self.subTest(f'{methname}()'):
interp.exec(dedent(f"""
orig = b'spam'
- queue.put(orig, syncobj=True)
+ queue.put(orig)
obj = queue.{methname}()
assert obj == orig, 'expected: obj == orig'
assert obj is not orig, 'expected: obj is not orig'
@@ -399,12 +392,12 @@ class TestQueueOps(TestBase):
for methname in ('get', 'get_nowait'):
with self.subTest(f'{methname}()'):
obj1 = b'spam'
- queue1.put(obj1, syncobj=True)
+ queue1.put(obj1)
out = _run_output(
interp,
dedent(f"""
- from test.support.interpreters import queues
+ from concurrent.interpreters import _queues as queues
queue1 = queues.Queue({queue1.id})
queue2 = queues.Queue({queue2.id})
assert queue1.qsize() == 1, 'expected: queue1.qsize() == 1'
@@ -416,7 +409,7 @@ class TestQueueOps(TestBase):
obj2 = b'eggs'
print(id(obj2))
assert queue2.qsize() == 0, 'expected: queue2.qsize() == 0'
- queue2.put(obj2, syncobj=True)
+ queue2.put(obj2)
assert queue2.qsize() == 1, 'expected: queue2.qsize() == 1'
"""))
self.assertEqual(len(queues.list_all()), 2)
@@ -433,22 +426,22 @@ class TestQueueOps(TestBase):
if not unbound:
extraargs = ''
elif unbound is queues.UNBOUND:
- extraargs = ', unbound=queues.UNBOUND'
+ extraargs = ', unbounditems=queues.UNBOUND'
elif unbound is queues.UNBOUND_ERROR:
- extraargs = ', unbound=queues.UNBOUND_ERROR'
+ extraargs = ', unbounditems=queues.UNBOUND_ERROR'
elif unbound is queues.UNBOUND_REMOVE:
- extraargs = ', unbound=queues.UNBOUND_REMOVE'
+ extraargs = ', unbounditems=queues.UNBOUND_REMOVE'
else:
raise NotImplementedError(repr(unbound))
interp = interpreters.create()
_run_output(interp, dedent(f"""
- from test.support.interpreters import queues
+ from concurrent.interpreters import _queues as queues
queue = queues.Queue({queue.id})
obj1 = b'spam'
obj2 = b'eggs'
- queue.put(obj1, syncobj=True{extraargs})
- queue.put(obj2, syncobj=True{extraargs})
+ queue.put(obj1{extraargs})
+ queue.put(obj2{extraargs})
"""))
self.assertEqual(queue.qsize(), presize + 2)
@@ -501,11 +494,11 @@ class TestQueueOps(TestBase):
with self.assertRaises(queues.QueueEmpty):
queue.get_nowait()
- queue.put(b'ham', unbound=queues.UNBOUND_REMOVE)
+ queue.put(b'ham', unbounditems=queues.UNBOUND_REMOVE)
self.assertEqual(queue.qsize(), 1)
interp = common(queue, queues.UNBOUND_REMOVE, 1)
self.assertEqual(queue.qsize(), 3)
- queue.put(42, unbound=queues.UNBOUND_REMOVE)
+ queue.put(42, unbounditems=queues.UNBOUND_REMOVE)
self.assertEqual(queue.qsize(), 4)
del interp
self.assertEqual(queue.qsize(), 2)
@@ -521,13 +514,13 @@ class TestQueueOps(TestBase):
queue = queues.create()
interp = interpreters.create()
_run_output(interp, dedent(f"""
- from test.support.interpreters import queues
+ from concurrent.interpreters import _queues as queues
queue = queues.Queue({queue.id})
- queue.put(1, syncobj=True, unbound=queues.UNBOUND)
- queue.put(2, syncobj=True, unbound=queues.UNBOUND_ERROR)
- queue.put(3, syncobj=True)
- queue.put(4, syncobj=True, unbound=queues.UNBOUND_REMOVE)
- queue.put(5, syncobj=True, unbound=queues.UNBOUND)
+ queue.put(1, unbounditems=queues.UNBOUND)
+ queue.put(2, unbounditems=queues.UNBOUND_ERROR)
+ queue.put(3)
+ queue.put(4, unbounditems=queues.UNBOUND_REMOVE)
+ queue.put(5, unbounditems=queues.UNBOUND)
"""))
self.assertEqual(queue.qsize(), 5)
@@ -555,16 +548,16 @@ class TestQueueOps(TestBase):
interp1 = interpreters.create()
interp2 = interpreters.create()
- queue.put(1, syncobj=True)
+ queue.put(1)
_run_output(interp1, dedent(f"""
- from test.support.interpreters import queues
+ from concurrent.interpreters import _queues as queues
queue = queues.Queue({queue.id})
obj1 = queue.get()
- queue.put(2, syncobj=True, unbound=queues.UNBOUND)
- queue.put(obj1, syncobj=True, unbound=queues.UNBOUND_REMOVE)
+ queue.put(2, unbounditems=queues.UNBOUND)
+ queue.put(obj1, unbounditems=queues.UNBOUND_REMOVE)
"""))
_run_output(interp2, dedent(f"""
- from test.support.interpreters import queues
+ from concurrent.interpreters import _queues as queues
queue = queues.Queue({queue.id})
obj2 = queue.get()
obj1 = queue.get()
@@ -572,21 +565,21 @@ class TestQueueOps(TestBase):
self.assertEqual(queue.qsize(), 0)
queue.put(3)
_run_output(interp1, dedent("""
- queue.put(4, syncobj=True, unbound=queues.UNBOUND)
+ queue.put(4, unbounditems=queues.UNBOUND)
# interp closed here
- queue.put(5, syncobj=True, unbound=queues.UNBOUND_REMOVE)
- queue.put(6, syncobj=True, unbound=queues.UNBOUND)
+ queue.put(5, unbounditems=queues.UNBOUND_REMOVE)
+ queue.put(6, unbounditems=queues.UNBOUND)
"""))
_run_output(interp2, dedent("""
- queue.put(7, syncobj=True, unbound=queues.UNBOUND_ERROR)
+ queue.put(7, unbounditems=queues.UNBOUND_ERROR)
# interp closed here
- queue.put(obj1, syncobj=True, unbound=queues.UNBOUND_ERROR)
- queue.put(obj2, syncobj=True, unbound=queues.UNBOUND_REMOVE)
- queue.put(8, syncobj=True, unbound=queues.UNBOUND)
+ queue.put(obj1, unbounditems=queues.UNBOUND_ERROR)
+ queue.put(obj2, unbounditems=queues.UNBOUND_REMOVE)
+ queue.put(8, unbounditems=queues.UNBOUND)
"""))
_run_output(interp1, dedent("""
- queue.put(9, syncobj=True, unbound=queues.UNBOUND_REMOVE)
- queue.put(10, syncobj=True, unbound=queues.UNBOUND)
+ queue.put(9, unbounditems=queues.UNBOUND_REMOVE)
+ queue.put(10, unbounditems=queues.UNBOUND)
"""))
self.assertEqual(queue.qsize(), 10)
@@ -642,12 +635,12 @@ class TestQueueOps(TestBase):
break
except queues.QueueEmpty:
continue
- queue2.put(obj, syncobj=True)
+ queue2.put(obj)
t = threading.Thread(target=f)
t.start()
orig = b'spam'
- queue1.put(orig, syncobj=True)
+ queue1.put(orig)
obj = queue2.get()
t.join()
diff --git a/Lib/test/test_interpreters/test_stress.py b/Lib/test/test_interpreters/test_stress.py
index 56bfc172199..e25e67a0d4f 100644
--- a/Lib/test/test_interpreters/test_stress.py
+++ b/Lib/test/test_interpreters/test_stress.py
@@ -6,7 +6,7 @@ from test.support import import_helper
from test.support import threading_helper
# Raise SkipTest if subinterpreters not supported.
import_helper.import_module('_interpreters')
-from test.support import interpreters
+from concurrent import interpreters
from .utils import TestBase
@@ -21,21 +21,29 @@ class StressTests(TestBase):
for _ in range(100):
interp = interpreters.create()
alive.append(interp)
+ del alive
+ support.gc_collect()
- @support.requires_resource('cpu')
- @threading_helper.requires_working_threading()
- def test_create_many_threaded(self):
+ @support.bigmemtest(size=200, memuse=32*2**20, dry_run=False)
+ def test_create_many_threaded(self, size):
alive = []
+ start = threading.Event()
def task():
+ # try to create all interpreters simultaneously
+ if not start.wait(support.SHORT_TIMEOUT):
+ raise TimeoutError
interp = interpreters.create()
alive.append(interp)
- threads = (threading.Thread(target=task) for _ in range(200))
+ threads = [threading.Thread(target=task) for _ in range(size)]
with threading_helper.start_threads(threads):
- pass
+ start.set()
+ del alive
+ support.gc_collect()
- @support.requires_resource('cpu')
@threading_helper.requires_working_threading()
- def test_many_threads_running_interp_in_other_interp(self):
+ @support.bigmemtest(size=200, memuse=34*2**20, dry_run=False)
+ def test_many_threads_running_interp_in_other_interp(self, size):
+ start = threading.Event()
interp = interpreters.create()
script = f"""if True:
@@ -47,6 +55,9 @@ class StressTests(TestBase):
interp = interpreters.create()
alreadyrunning = (f'{interpreters.InterpreterError}: '
'interpreter already running')
+ # try to run all interpreters simultaneously
+ if not start.wait(support.SHORT_TIMEOUT):
+ raise TimeoutError
success = False
while not success:
try:
@@ -58,9 +69,10 @@ class StressTests(TestBase):
else:
success = True
- threads = (threading.Thread(target=run) for _ in range(200))
+ threads = [threading.Thread(target=run) for _ in range(size)]
with threading_helper.start_threads(threads):
- pass
+ start.set()
+ support.gc_collect()
if __name__ == '__main__':
diff --git a/Lib/test/test_interpreters/utils.py b/Lib/test/test_interpreters/utils.py
index fc4ad662e03..ae09aa457b4 100644
--- a/Lib/test/test_interpreters/utils.py
+++ b/Lib/test/test_interpreters/utils.py
@@ -12,7 +12,6 @@ from textwrap import dedent
import threading
import types
import unittest
-import warnings
from test import support
@@ -22,7 +21,7 @@ try:
import _interpreters
except ImportError as exc:
raise unittest.SkipTest(str(exc))
-from test.support import interpreters
+from concurrent import interpreters
try:
diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py
index 545643aa455..0c921ffbc25 100644
--- a/Lib/test/test_io.py
+++ b/Lib/test/test_io.py
@@ -572,7 +572,7 @@ class IOTest(unittest.TestCase):
for [test, abilities] in tests:
with self.subTest(test):
if test == pipe_writer and not threading_helper.can_start_thread:
- skipTest()
+ self.skipTest("Need threads")
with test() as obj:
do_test(test, obj, abilities)
@@ -902,7 +902,7 @@ class IOTest(unittest.TestCase):
self.BytesIO()
)
for obj in test:
- self.assertTrue(hasattr(obj, "__dict__"))
+ self.assertHasAttr(obj, "__dict__")
def test_opener(self):
with self.open(os_helper.TESTFN, "w", encoding="utf-8") as f:
@@ -918,7 +918,7 @@ class IOTest(unittest.TestCase):
def badopener(fname, flags):
return -1
with self.assertRaises(ValueError) as cm:
- open('non-existent', 'r', opener=badopener)
+ self.open('non-existent', 'r', opener=badopener)
self.assertEqual(str(cm.exception), 'opener returned -1')
def test_bad_opener_other_negative(self):
@@ -926,7 +926,7 @@ class IOTest(unittest.TestCase):
def badopener(fname, flags):
return -2
with self.assertRaises(ValueError) as cm:
- open('non-existent', 'r', opener=badopener)
+ self.open('non-existent', 'r', opener=badopener)
self.assertEqual(str(cm.exception), 'opener returned -2')
def test_opener_invalid_fd(self):
@@ -1062,6 +1062,37 @@ class IOTest(unittest.TestCase):
# Silence destructor error
R.flush = lambda self: None
+ @threading_helper.requires_working_threading()
+ def test_write_readline_races(self):
+ # gh-134908: Concurrent iteration over a file caused races
+ thread_count = 2
+ write_count = 100
+ read_count = 100
+
+ def writer(file, barrier):
+ barrier.wait()
+ for _ in range(write_count):
+ file.write("x")
+
+ def reader(file, barrier):
+ barrier.wait()
+ for _ in range(read_count):
+ for line in file:
+ self.assertEqual(line, "")
+
+ with self.open(os_helper.TESTFN, "w+") as f:
+ barrier = threading.Barrier(thread_count + 1)
+ reader = threading.Thread(target=reader, args=(f, barrier))
+ writers = [threading.Thread(target=writer, args=(f, barrier))
+ for _ in range(thread_count)]
+ with threading_helper.catch_threading_exception() as cm:
+ with threading_helper.start_threads(writers + [reader]):
+ pass
+ self.assertIsNone(cm.exc_type)
+
+ self.assertEqual(os.stat(os_helper.TESTFN).st_size,
+ write_count * thread_count)
+
class CIOTest(IOTest):
@@ -1117,7 +1148,7 @@ class TestIOCTypes(unittest.TestCase):
def check_subs(types, base):
for tp in types:
with self.subTest(tp=tp, base=base):
- self.assertTrue(issubclass(tp, base))
+ self.assertIsSubclass(tp, base)
def recursive_check(d):
for k, v in d.items():
@@ -1373,6 +1404,28 @@ class CommonBufferedTests:
with self.assertRaises(AttributeError):
buf.raw = x
+ def test_pickling_subclass(self):
+ global MyBufferedIO
+ class MyBufferedIO(self.tp):
+ def __init__(self, raw, tag):
+ super().__init__(raw)
+ self.tag = tag
+ def __getstate__(self):
+ return self.tag, self.raw.getvalue()
+ def __setstate__(slf, state):
+ tag, value = state
+ slf.__init__(self.BytesIO(value), tag)
+
+ raw = self.BytesIO(b'data')
+ buf = MyBufferedIO(raw, tag='ham')
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ with self.subTest(protocol=proto):
+ pickled = pickle.dumps(buf, proto)
+ newbuf = pickle.loads(pickled)
+ self.assertEqual(newbuf.raw.getvalue(), b'data')
+ self.assertEqual(newbuf.tag, 'ham')
+ del MyBufferedIO
+
class SizeofTest:
@@ -1848,7 +1901,7 @@ class BufferedWriterTest(unittest.TestCase, CommonBufferedTests):
flushed = b"".join(writer._write_stack)
# At least (total - 8) bytes were implicitly flushed, perhaps more
# depending on the implementation.
- self.assertTrue(flushed.startswith(contents[:-8]), flushed)
+ self.assertStartsWith(flushed, contents[:-8])
def check_writes(self, intermediate_func):
# Lots of writes, test the flushed output is as expected.
@@ -1918,7 +1971,7 @@ class BufferedWriterTest(unittest.TestCase, CommonBufferedTests):
self.assertEqual(bufio.write(b"ABCDEFGHI"), 9)
s = raw.pop_written()
# Previously buffered bytes were flushed
- self.assertTrue(s.startswith(b"01234567A"), s)
+ self.assertStartsWith(s, b"01234567A")
def test_write_and_rewind(self):
raw = self.BytesIO()
@@ -2214,7 +2267,7 @@ class BufferedRWPairTest(unittest.TestCase):
def test_peek(self):
pair = self.tp(self.BytesIO(b"abcdef"), self.MockRawIO())
- self.assertTrue(pair.peek(3).startswith(b"abc"))
+ self.assertStartsWith(pair.peek(3), b"abc")
self.assertEqual(pair.read(3), b"abc")
def test_readable(self):
@@ -3950,6 +4003,28 @@ class TextIOWrapperTest(unittest.TestCase):
f.write(res)
self.assertEqual(res + f.readline(), 'foo\nbar\n')
+ def test_pickling_subclass(self):
+ global MyTextIO
+ class MyTextIO(self.TextIOWrapper):
+ def __init__(self, raw, tag):
+ super().__init__(raw)
+ self.tag = tag
+ def __getstate__(self):
+ return self.tag, self.buffer.getvalue()
+ def __setstate__(slf, state):
+ tag, value = state
+ slf.__init__(self.BytesIO(value), tag)
+
+ raw = self.BytesIO(b'data')
+ txt = MyTextIO(raw, 'ham')
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ with self.subTest(protocol=proto):
+ pickled = pickle.dumps(txt, proto)
+ newtxt = pickle.loads(pickled)
+ self.assertEqual(newtxt.buffer.getvalue(), b'data')
+ self.assertEqual(newtxt.tag, 'ham')
+ del MyTextIO
+
@unittest.skipUnless(hasattr(os, "pipe"), "requires os.pipe()")
def test_read_non_blocking(self):
import os
@@ -4373,7 +4448,7 @@ class MiscIOTest(unittest.TestCase):
self._check_abc_inheritance(io)
def _check_warn_on_dealloc(self, *args, **kwargs):
- f = open(*args, **kwargs)
+ f = self.open(*args, **kwargs)
r = repr(f)
with self.assertWarns(ResourceWarning) as cm:
f = None
@@ -4402,7 +4477,7 @@ class MiscIOTest(unittest.TestCase):
r, w = os.pipe()
fds += r, w
with warnings_helper.check_no_resource_warning(self):
- open(r, *args, closefd=False, **kwargs)
+ self.open(r, *args, closefd=False, **kwargs)
@unittest.skipUnless(hasattr(os, "pipe"), "requires os.pipe()")
def test_warn_on_dealloc_fd(self):
@@ -4574,10 +4649,8 @@ class MiscIOTest(unittest.TestCase):
proc = assert_python_ok('-X', 'warn_default_encoding', '-c', code)
warnings = proc.err.splitlines()
self.assertEqual(len(warnings), 2)
- self.assertTrue(
- warnings[0].startswith(b"<string>:5: EncodingWarning: "))
- self.assertTrue(
- warnings[1].startswith(b"<string>:8: EncodingWarning: "))
+ self.assertStartsWith(warnings[0], b"<string>:5: EncodingWarning: ")
+ self.assertStartsWith(warnings[1], b"<string>:8: EncodingWarning: ")
def test_text_encoding(self):
# PEP 597, bpo-47000. io.text_encoding() returns "locale" or "utf-8"
@@ -4790,7 +4863,7 @@ class SignalsTest(unittest.TestCase):
os.read(r, len(data) * 100)
exc = cm.exception
if isinstance(exc, RuntimeError):
- self.assertTrue(str(exc).startswith("reentrant call"), str(exc))
+ self.assertStartsWith(str(exc), "reentrant call")
finally:
signal.alarm(0)
wio.close()
diff --git a/Lib/test/test_ioctl.py b/Lib/test/test_ioctl.py
index 7a986048bda..277d2fc99ea 100644
--- a/Lib/test/test_ioctl.py
+++ b/Lib/test/test_ioctl.py
@@ -5,7 +5,7 @@ import sys
import threading
import unittest
from test import support
-from test.support import threading_helper
+from test.support import os_helper, threading_helper
from test.support.import_helper import import_module
fcntl = import_module('fcntl')
termios = import_module('termios')
@@ -127,9 +127,8 @@ class IoctlTestsTty(unittest.TestCase):
self._check_ioctl_not_mutate_len(1024)
def test_ioctl_mutate_2048(self):
- # Test with a larger buffer, just for the record.
self._check_ioctl_mutate_len(2048)
- self.assertRaises(ValueError, self._check_ioctl_not_mutate_len, 2048)
+ self._check_ioctl_not_mutate_len(1024)
@unittest.skipUnless(hasattr(os, 'openpty'), "need os.openpty()")
@@ -202,6 +201,17 @@ class IoctlTestsPty(unittest.TestCase):
new_winsz = struct.unpack("HHHH", result)
self.assertEqual(new_winsz[:2], (20, 40))
+ @unittest.skipUnless(hasattr(fcntl, 'FICLONE'), 'need fcntl.FICLONE')
+ def test_bad_fd(self):
+ # gh-134744: Test error handling
+ fd = os_helper.make_bad_fd()
+ with self.assertRaises(OSError):
+ fcntl.ioctl(fd, fcntl.FICLONE, fd)
+ with self.assertRaises(OSError):
+ fcntl.ioctl(fd, fcntl.FICLONE, b'\0' * 10)
+ with self.assertRaises(OSError):
+ fcntl.ioctl(fd, fcntl.FICLONE, b'\0' * 2048)
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_ipaddress.py b/Lib/test/test_ipaddress.py
index d04012d1afd..db1c38243e2 100644
--- a/Lib/test/test_ipaddress.py
+++ b/Lib/test/test_ipaddress.py
@@ -397,6 +397,19 @@ class AddressTestCase_v6(BaseTestCase, CommonTestMixin_v6):
# A trailing IPv4 address is two parts
assertBadSplit("10:9:8:7:6:5:4:3:42.42.42.42%scope")
+ def test_bad_address_split_v6_too_long(self):
+ def assertBadSplit(addr):
+ msg = r"At most 45 characters expected in '%s"
+ with self.assertAddressError(msg, re.escape(addr[:45])):
+ ipaddress.IPv6Address(addr)
+
+ # Long IPv6 address
+ long_addr = ("0:" * 10000) + "0"
+ assertBadSplit(long_addr)
+ assertBadSplit(long_addr + "%zoneid")
+ assertBadSplit(long_addr + ":255.255.255.255")
+ assertBadSplit(long_addr + ":ffff:255.255.255.255")
+
def test_bad_address_split_v6_too_many_parts(self):
def assertBadSplit(addr):
msg = "Exactly 8 parts expected without '::' in %r"
@@ -2178,6 +2191,11 @@ class IpaddrUnitTest(unittest.TestCase):
self.assertEqual(ipaddress.ip_address('FFFF::192.0.2.1'),
ipaddress.ip_address('FFFF::c000:201'))
+ self.assertEqual(ipaddress.ip_address('0000:0000:0000:0000:0000:FFFF:192.168.255.255'),
+ ipaddress.ip_address('::ffff:c0a8:ffff'))
+ self.assertEqual(ipaddress.ip_address('FFFF:0000:0000:0000:0000:0000:192.168.255.255'),
+ ipaddress.ip_address('ffff::c0a8:ffff'))
+
self.assertEqual(ipaddress.ip_address('::FFFF:192.0.2.1%scope'),
ipaddress.ip_address('::FFFF:c000:201%scope'))
self.assertEqual(ipaddress.ip_address('FFFF::192.0.2.1%scope'),
@@ -2190,6 +2208,10 @@ class IpaddrUnitTest(unittest.TestCase):
ipaddress.ip_address('::FFFF:c000:201%scope'))
self.assertNotEqual(ipaddress.ip_address('FFFF::192.0.2.1'),
ipaddress.ip_address('FFFF::c000:201%scope'))
+ self.assertEqual(ipaddress.ip_address('0000:0000:0000:0000:0000:FFFF:192.168.255.255%scope'),
+ ipaddress.ip_address('::ffff:c0a8:ffff%scope'))
+ self.assertEqual(ipaddress.ip_address('FFFF:0000:0000:0000:0000:0000:192.168.255.255%scope'),
+ ipaddress.ip_address('ffff::c0a8:ffff%scope'))
def testIPVersion(self):
self.assertEqual(ipaddress.IPv4Address.version, 4)
@@ -2599,6 +2621,10 @@ class IpaddrUnitTest(unittest.TestCase):
'::7:6:5:4:3:2:0': '0:7:6:5:4:3:2:0/128',
'7:6:5:4:3:2:1::': '7:6:5:4:3:2:1:0/128',
'0:6:5:4:3:2:1::': '0:6:5:4:3:2:1:0/128',
+ '0000:0000:0000:0000:0000:0000:255.255.255.255': '::ffff:ffff/128',
+ '0000:0000:0000:0000:0000:ffff:255.255.255.255': '::ffff:255.255.255.255/128',
+ 'ffff:ffff:ffff:ffff:ffff:ffff:255.255.255.255':
+ 'ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/128',
}
for uncompressed, compressed in list(test_addresses.items()):
self.assertEqual(compressed, str(ipaddress.IPv6Interface(
@@ -2762,6 +2788,34 @@ class IpaddrUnitTest(unittest.TestCase):
ipv6_address2 = ipaddress.IPv6Interface("2001:658:22a:cafe:200:0:0:2")
self.assertNotEqual(ipv6_address1.__hash__(), ipv6_address2.__hash__())
+ # issue 134062 Hash collisions in IPv4Network and IPv6Network
+ def testNetworkV4HashCollisions(self):
+ self.assertNotEqual(
+ ipaddress.IPv4Network("192.168.1.255/32").__hash__(),
+ ipaddress.IPv4Network("192.168.1.0/24").__hash__()
+ )
+ self.assertNotEqual(
+ ipaddress.IPv4Network("172.24.255.0/24").__hash__(),
+ ipaddress.IPv4Network("172.24.0.0/16").__hash__()
+ )
+ self.assertNotEqual(
+ ipaddress.IPv4Network("192.168.1.87/32").__hash__(),
+ ipaddress.IPv4Network("192.168.1.86/31").__hash__()
+ )
+
+ # issue 134062 Hash collisions in IPv4Network and IPv6Network
+ def testNetworkV6HashCollisions(self):
+ self.assertNotEqual(
+ ipaddress.IPv6Network("fe80::/64").__hash__(),
+ ipaddress.IPv6Network("fe80::ffff:ffff:ffff:0/112").__hash__()
+ )
+ self.assertNotEqual(
+ ipaddress.IPv4Network("10.0.0.0/8").__hash__(),
+ ipaddress.IPv6Network(
+ "ffff:ffff:ffff:ffff:ffff:ffff:aff:0/112"
+ ).__hash__()
+ )
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_isinstance.py b/Lib/test/test_isinstance.py
index daad00e8643..f440fc28ee7 100644
--- a/Lib/test/test_isinstance.py
+++ b/Lib/test/test_isinstance.py
@@ -318,6 +318,7 @@ class TestIsInstanceIsSubclass(unittest.TestCase):
self.assertRaises(RecursionError, isinstance, 1, X())
@support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_infinite_recursion_via_bases_tuple(self):
"""Regression test for bpo-30570."""
class Failure(object):
@@ -328,6 +329,7 @@ class TestIsInstanceIsSubclass(unittest.TestCase):
issubclass(Failure(), int)
@support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_infinite_cycle_in_bases(self):
"""Regression test for bpo-30570."""
class X:
diff --git a/Lib/test/test_iter.py b/Lib/test/test_iter.py
index 1b9f3cf7624..18e4b676c53 100644
--- a/Lib/test/test_iter.py
+++ b/Lib/test/test_iter.py
@@ -1147,7 +1147,7 @@ class TestCase(unittest.TestCase):
def test_exception_locations(self):
# The location of an exception raised from __init__ or
- # __next__ should should be the iterator expression
+ # __next__ should be the iterator expression
def init_raises():
try:
diff --git a/Lib/test/test_json/test_dump.py b/Lib/test/test_json/test_dump.py
index 13b40020781..39470754003 100644
--- a/Lib/test/test_json/test_dump.py
+++ b/Lib/test/test_json/test_dump.py
@@ -22,6 +22,14 @@ class TestDump:
self.assertIn('valid_key', o)
self.assertNotIn(b'invalid_key', o)
+ def test_dump_skipkeys_indent_empty(self):
+ v = {b'invalid_key': False}
+ self.assertEqual(self.json.dumps(v, skipkeys=True, indent=4), '{}')
+
+ def test_skipkeys_indent(self):
+ v = {b'invalid_key': False, 'valid_key': True}
+ self.assertEqual(self.json.dumps(v, skipkeys=True, indent=4), '{\n "valid_key": true\n}')
+
def test_encode_truefalse(self):
self.assertEqual(self.dumps(
{True: False, False: True}, sort_keys=True),
diff --git a/Lib/test/test_json/test_fail.py b/Lib/test/test_json/test_fail.py
index 7c1696cc66d..79c44af2fbf 100644
--- a/Lib/test/test_json/test_fail.py
+++ b/Lib/test/test_json/test_fail.py
@@ -102,7 +102,7 @@ class TestFail:
with self.assertRaisesRegex(TypeError,
'Object of type module is not JSON serializable') as cm:
self.dumps(sys)
- self.assertFalse(hasattr(cm.exception, '__notes__'))
+ self.assertNotHasAttr(cm.exception, '__notes__')
with self.assertRaises(TypeError) as cm:
self.dumps([1, [2, 3, sys]])
diff --git a/Lib/test/test_json/test_recursion.py b/Lib/test/test_json/test_recursion.py
index d82093f3895..5d7b56ff9ad 100644
--- a/Lib/test/test_json/test_recursion.py
+++ b/Lib/test/test_json/test_recursion.py
@@ -69,6 +69,7 @@ class TestRecursion:
@support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_highly_nested_objects_decoding(self):
very_deep = 200000
# test that loading highly-nested objects doesn't segfault when C
@@ -85,6 +86,7 @@ class TestRecursion:
@support.skip_wasi_stack_overflow()
@support.skip_emscripten_stack_overflow()
+ @support.requires_resource('cpu')
def test_highly_nested_objects_encoding(self):
# See #12051
l, d = [], {}
@@ -98,6 +100,7 @@ class TestRecursion:
self.dumps(d)
@support.skip_emscripten_stack_overflow()
+ @support.skip_wasi_stack_overflow()
def test_endless_recursion(self):
# See #12051
class EndlessJSONEncoder(self.json.JSONEncoder):
diff --git a/Lib/test/test_json/test_tool.py b/Lib/test/test_json/test_tool.py
index ba9c42f758e..30f9bb33316 100644
--- a/Lib/test/test_json/test_tool.py
+++ b/Lib/test/test_json/test_tool.py
@@ -6,9 +6,11 @@ import unittest
import subprocess
from test import support
-from test.support import force_not_colorized, os_helper
+from test.support import force_colorized, force_not_colorized, os_helper
from test.support.script_helper import assert_python_ok
+from _colorize import get_theme
+
@support.requires_subprocess()
class TestMain(unittest.TestCase):
@@ -158,7 +160,7 @@ class TestMain(unittest.TestCase):
rc, out, err = assert_python_ok('-m', self.module, '-h',
PYTHON_COLORS='0')
self.assertEqual(rc, 0)
- self.assertTrue(out.startswith(b'usage: '))
+ self.assertStartsWith(out, b'usage: ')
self.assertEqual(err, b'')
def test_sort_keys_flag(self):
@@ -246,34 +248,39 @@ class TestMain(unittest.TestCase):
proc.communicate(b'"{}"')
self.assertEqual(proc.returncode, errno.EPIPE)
+ @force_colorized
def test_colors(self):
infile = os_helper.TESTFN
self.addCleanup(os.remove, infile)
+ t = get_theme().syntax
+ ob = "{"
+ cb = "}"
+
cases = (
- ('{}', b'{}'),
- ('[]', b'[]'),
- ('null', b'\x1b[1;36mnull\x1b[0m'),
- ('true', b'\x1b[1;36mtrue\x1b[0m'),
- ('false', b'\x1b[1;36mfalse\x1b[0m'),
- ('NaN', b'NaN'),
- ('Infinity', b'Infinity'),
- ('-Infinity', b'-Infinity'),
- ('"foo"', b'\x1b[1;32m"foo"\x1b[0m'),
- (r'" \"foo\" "', b'\x1b[1;32m" \\"foo\\" "\x1b[0m'),
- ('"α"', b'\x1b[1;32m"\\u03b1"\x1b[0m'),
- ('123', b'123'),
- ('-1.2345e+23', b'-1.2345e+23'),
+ ('{}', '{}'),
+ ('[]', '[]'),
+ ('null', f'{t.keyword}null{t.reset}'),
+ ('true', f'{t.keyword}true{t.reset}'),
+ ('false', f'{t.keyword}false{t.reset}'),
+ ('NaN', f'{t.number}NaN{t.reset}'),
+ ('Infinity', f'{t.number}Infinity{t.reset}'),
+ ('-Infinity', f'{t.number}-Infinity{t.reset}'),
+ ('"foo"', f'{t.string}"foo"{t.reset}'),
+ (r'" \"foo\" "', f'{t.string}" \\"foo\\" "{t.reset}'),
+ ('"α"', f'{t.string}"\\u03b1"{t.reset}'),
+ ('123', f'{t.number}123{t.reset}'),
+ ('-1.25e+23', f'{t.number}-1.25e+23{t.reset}'),
(r'{"\\": ""}',
- b'''\
-{
- \x1b[94m"\\\\"\x1b[0m: \x1b[1;32m""\x1b[0m
-}'''),
+ f'''\
+{ob}
+ {t.definition}"\\\\"{t.reset}: {t.string}""{t.reset}
+{cb}'''),
(r'{"\\\\": ""}',
- b'''\
-{
- \x1b[94m"\\\\\\\\"\x1b[0m: \x1b[1;32m""\x1b[0m
-}'''),
+ f'''\
+{ob}
+ {t.definition}"\\\\\\\\"{t.reset}: {t.string}""{t.reset}
+{cb}'''),
('''\
{
"foo": "bar",
@@ -281,30 +288,32 @@ class TestMain(unittest.TestCase):
"qux": [true, false, null],
"xyz": [NaN, -Infinity, Infinity]
}''',
- b'''\
-{
- \x1b[94m"foo"\x1b[0m: \x1b[1;32m"bar"\x1b[0m,
- \x1b[94m"baz"\x1b[0m: 1234,
- \x1b[94m"qux"\x1b[0m: [
- \x1b[1;36mtrue\x1b[0m,
- \x1b[1;36mfalse\x1b[0m,
- \x1b[1;36mnull\x1b[0m
+ f'''\
+{ob}
+ {t.definition}"foo"{t.reset}: {t.string}"bar"{t.reset},
+ {t.definition}"baz"{t.reset}: {t.number}1234{t.reset},
+ {t.definition}"qux"{t.reset}: [
+ {t.keyword}true{t.reset},
+ {t.keyword}false{t.reset},
+ {t.keyword}null{t.reset}
],
- \x1b[94m"xyz"\x1b[0m: [
- NaN,
- -Infinity,
- Infinity
+ {t.definition}"xyz"{t.reset}: [
+ {t.number}NaN{t.reset},
+ {t.number}-Infinity{t.reset},
+ {t.number}Infinity{t.reset}
]
-}'''),
+{cb}'''),
)
for input_, expected in cases:
with self.subTest(input=input_):
with open(infile, "w", encoding="utf-8") as fp:
fp.write(input_)
- _, stdout, _ = assert_python_ok('-m', self.module, infile,
- PYTHON_COLORS='1')
- stdout = stdout.replace(b'\r\n', b'\n') # normalize line endings
+ _, stdout_b, _ = assert_python_ok(
+ '-m', self.module, infile, FORCE_COLOR='1', __isolated='1'
+ )
+ stdout = stdout_b.decode()
+ stdout = stdout.replace('\r\n', '\n') # normalize line endings
stdout = stdout.strip()
self.assertEqual(stdout, expected)
diff --git a/Lib/test/test_launcher.py b/Lib/test/test_launcher.py
index 173fc743cf6..caa1603c78e 100644
--- a/Lib/test/test_launcher.py
+++ b/Lib/test/test_launcher.py
@@ -443,7 +443,7 @@ class TestLauncher(unittest.TestCase, RunPyMixin):
except subprocess.CalledProcessError:
raise unittest.SkipTest("requires at least one Python 3.x install")
self.assertEqual("PythonCore", data["env.company"])
- self.assertTrue(data["env.tag"].startswith("3."), data["env.tag"])
+ self.assertStartsWith(data["env.tag"], "3.")
def test_search_major_3_32(self):
try:
@@ -453,8 +453,8 @@ class TestLauncher(unittest.TestCase, RunPyMixin):
raise unittest.SkipTest("requires at least one 32-bit Python 3.x install")
raise
self.assertEqual("PythonCore", data["env.company"])
- self.assertTrue(data["env.tag"].startswith("3."), data["env.tag"])
- self.assertTrue(data["env.tag"].endswith("-32"), data["env.tag"])
+ self.assertStartsWith(data["env.tag"], "3.")
+ self.assertEndsWith(data["env.tag"], "-32")
def test_search_major_2(self):
try:
@@ -463,7 +463,7 @@ class TestLauncher(unittest.TestCase, RunPyMixin):
if not is_installed("2.7"):
raise unittest.SkipTest("requires at least one Python 2.x install")
self.assertEqual("PythonCore", data["env.company"])
- self.assertTrue(data["env.tag"].startswith("2."), data["env.tag"])
+ self.assertStartsWith(data["env.tag"], "2.")
def test_py_default(self):
with self.py_ini(TEST_PY_DEFAULTS):
diff --git a/Lib/test/test_linecache.py b/Lib/test/test_linecache.py
index e4aa41ebb43..02f65338428 100644
--- a/Lib/test/test_linecache.py
+++ b/Lib/test/test_linecache.py
@@ -4,10 +4,12 @@ import linecache
import unittest
import os.path
import tempfile
+import threading
import tokenize
from importlib.machinery import ModuleSpec
from test import support
from test.support import os_helper
+from test.support import threading_helper
from test.support.script_helper import assert_python_ok
@@ -374,5 +376,40 @@ class LineCacheInvalidationTests(unittest.TestCase):
self.assertIn(self.unchanged_file, linecache.cache)
+class MultiThreadingTest(unittest.TestCase):
+ @threading_helper.reap_threads
+ @threading_helper.requires_working_threading()
+ def test_read_write_safety(self):
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ filenames = []
+ for i in range(10):
+ name = os.path.join(tmpdirname, f"test_{i}.py")
+ with open(name, "w") as h:
+ h.write("import time\n")
+ h.write("import system\n")
+ filenames.append(name)
+
+ def linecache_get_line(b):
+ b.wait()
+ for _ in range(100):
+ for name in filenames:
+ linecache.getline(name, 1)
+
+ def check(funcs):
+ barrier = threading.Barrier(len(funcs))
+ threads = []
+
+ for func in funcs:
+ thread = threading.Thread(target=func, args=(barrier,))
+
+ threads.append(thread)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ check([linecache_get_line] * 20)
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_list.py b/Lib/test/test_list.py
index 6894fba2ad1..223f34fb696 100644
--- a/Lib/test/test_list.py
+++ b/Lib/test/test_list.py
@@ -365,5 +365,20 @@ class ListTest(list_tests.CommonTest):
rc, _, _ = assert_python_ok("-c", code)
self.assertEqual(rc, 0)
+ def test_list_overwrite_local(self):
+ """Test that overwriting the last reference to the
+ iterable doesn't prematurely free the iterable"""
+
+ def foo(x):
+ self.assertEqual(sys.getrefcount(x), 1)
+ r = 0
+ for i in x:
+ r += i
+ x = None
+ return r
+
+ self.assertEqual(foo(list(range(10))), 45)
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_listcomps.py b/Lib/test/test_listcomps.py
index cffdeeacc5d..70148dc30fc 100644
--- a/Lib/test/test_listcomps.py
+++ b/Lib/test/test_listcomps.py
@@ -716,7 +716,7 @@ class ListComprehensionTest(unittest.TestCase):
def test_exception_locations(self):
# The location of an exception raised from __init__ or
- # __next__ should should be the iterator expression
+ # __next__ should be the iterator expression
def init_raises():
try:
diff --git a/Lib/test/test_locale.py b/Lib/test/test_locale.py
index 528ceef5281..55b502e52ca 100644
--- a/Lib/test/test_locale.py
+++ b/Lib/test/test_locale.py
@@ -1,13 +1,18 @@
from decimal import Decimal
-from test.support import verbose, is_android, linked_to_musl, os_helper
+from test.support import cpython_only, verbose, is_android, linked_to_musl, os_helper
from test.support.warnings_helper import check_warnings
-from test.support.import_helper import import_fresh_module
+from test.support.import_helper import ensure_lazy_imports, import_fresh_module
from unittest import mock
import unittest
import locale
import sys
import codecs
+class LazyImportTest(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("locale", {"re", "warnings"})
+
class BaseLocalizedTest(unittest.TestCase):
#
@@ -382,6 +387,10 @@ class NormalizeTest(unittest.TestCase):
self.check('c', 'C')
self.check('posix', 'C')
+ def test_c_utf8(self):
+ self.check('c.utf8', 'C.UTF-8')
+ self.check('C.UTF-8', 'C.UTF-8')
+
def test_english(self):
self.check('en', 'en_US.ISO8859-1')
self.check('EN', 'en_US.ISO8859-1')
diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py
index de9108288a7..3819965ed2c 100644
--- a/Lib/test/test_logging.py
+++ b/Lib/test/test_logging.py
@@ -61,7 +61,7 @@ import warnings
import weakref
from http.server import HTTPServer, BaseHTTPRequestHandler
-from unittest.mock import patch
+from unittest.mock import call, Mock, patch
from urllib.parse import urlparse, parse_qs
from socketserver import (ThreadingUDPServer, DatagramRequestHandler,
ThreadingTCPServer, StreamRequestHandler)
@@ -1036,7 +1036,7 @@ class TestTCPServer(ControlMixin, ThreadingTCPServer):
"""
allow_reuse_address = True
- allow_reuse_port = True
+ allow_reuse_port = False
def __init__(self, addr, handler, poll_interval=0.5,
bind_and_activate=True):
@@ -5572,12 +5572,19 @@ class BasicConfigTest(unittest.TestCase):
assertRaises = self.assertRaises
handlers = [logging.StreamHandler()]
stream = sys.stderr
+ formatter = logging.Formatter()
assertRaises(ValueError, logging.basicConfig, filename='test.log',
stream=stream)
assertRaises(ValueError, logging.basicConfig, filename='test.log',
handlers=handlers)
assertRaises(ValueError, logging.basicConfig, stream=stream,
handlers=handlers)
+ assertRaises(ValueError, logging.basicConfig, formatter=formatter,
+ format='%(message)s')
+ assertRaises(ValueError, logging.basicConfig, formatter=formatter,
+ datefmt='%H:%M:%S')
+ assertRaises(ValueError, logging.basicConfig, formatter=formatter,
+ style='%')
# Issue 23207: test for invalid kwargs
assertRaises(ValueError, logging.basicConfig, loglevel=logging.INFO)
# Should pop both filename and filemode even if filename is None
@@ -5712,6 +5719,20 @@ class BasicConfigTest(unittest.TestCase):
# didn't write anything due to the encoding error
self.assertEqual(data, r'')
+ def test_formatter_given(self):
+ mock_formatter = Mock()
+ mock_handler = Mock(formatter=None)
+ with patch("logging.Formatter") as mock_formatter_init:
+ logging.basicConfig(formatter=mock_formatter, handlers=[mock_handler])
+ self.assertEqual(mock_handler.setFormatter.call_args_list, [call(mock_formatter)])
+ self.assertEqual(mock_formatter_init.call_count, 0)
+
+ def test_formatter_not_given(self):
+ mock_handler = Mock(formatter=None)
+ with patch("logging.Formatter") as mock_formatter_init:
+ logging.basicConfig(handlers=[mock_handler])
+ self.assertEqual(mock_formatter_init.call_count, 1)
+
@support.requires_working_socket()
def test_log_taskName(self):
async def log_record():
@@ -6740,7 +6761,7 @@ class TimedRotatingFileHandlerTest(BaseFileTest):
rotator = rotators[i]
candidates = rotator.getFilesToDelete()
self.assertEqual(len(candidates), n_files - backupCount, candidates)
- matcher = re.compile(r"^\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\Z")
+ matcher = re.compile(r"^\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\z")
for c in candidates:
d, fn = os.path.split(c)
self.assertStartsWith(fn, prefix+'.')
diff --git a/Lib/test/test_lzma.py b/Lib/test/test_lzma.py
index 9ffb93e797d..e93c3c37354 100644
--- a/Lib/test/test_lzma.py
+++ b/Lib/test/test_lzma.py
@@ -1025,12 +1025,12 @@ class FileTestCase(unittest.TestCase):
with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
result = f.peek()
self.assertGreater(len(result), 0)
- self.assertTrue(INPUT.startswith(result))
+ self.assertStartsWith(INPUT, result)
self.assertEqual(f.read(), INPUT)
with LZMAFile(BytesIO(COMPRESSED_XZ)) as f:
result = f.peek(10)
self.assertGreater(len(result), 0)
- self.assertTrue(INPUT.startswith(result))
+ self.assertStartsWith(INPUT, result)
self.assertEqual(f.read(), INPUT)
def test_peek_bad_args(self):
diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py
index 913a60bf9e0..46cb54647b1 100644
--- a/Lib/test/test_math.py
+++ b/Lib/test/test_math.py
@@ -475,6 +475,19 @@ class MathTests(unittest.TestCase):
# similarly, copysign(2., NAN) could be 2. or -2.
self.assertEqual(abs(math.copysign(2., NAN)), 2.)
+ def test_signbit(self):
+ self.assertRaises(TypeError, math.signbit)
+ self.assertRaises(TypeError, math.signbit, '1.0')
+
+ # C11, §7.12.3.6 requires signbit() to return a nonzero value
+ # if and only if the sign of its argument value is negative,
+ # but in practice, we are only interested in a boolean value.
+ self.assertIsInstance(math.signbit(1.0), bool)
+
+ for arg in [0., 1., INF, NAN]:
+ self.assertFalse(math.signbit(arg))
+ self.assertTrue(math.signbit(-arg))
+
def testCos(self):
self.assertRaises(TypeError, math.cos)
self.ftest('cos(-pi/2)', math.cos(-math.pi/2), 0, abs_tol=math.ulp(1))
@@ -1214,6 +1227,12 @@ class MathTests(unittest.TestCase):
self.assertEqual(math.ldexp(NINF, n), NINF)
self.assertTrue(math.isnan(math.ldexp(NAN, n)))
+ @requires_IEEE_754
+ def testLdexp_denormal(self):
+ # Denormal output incorrectly rounded (truncated)
+ # on some Windows.
+ self.assertEqual(math.ldexp(6993274598585239, -1126), 1e-323)
+
def testLog(self):
self.assertRaises(TypeError, math.log)
self.assertRaises(TypeError, math.log, 1, 2, 3)
@@ -1381,7 +1400,6 @@ class MathTests(unittest.TestCase):
args = ((-5, -5, 10), (1.5, 4611686018427387904, 2305843009213693952))
self.assertEqual(sumprod(*args), 0.0)
-
@requires_IEEE_754
@unittest.skipIf(HAVE_DOUBLE_ROUNDING,
"sumprod() accuracy not guaranteed on machines with double rounding")
@@ -1967,6 +1985,28 @@ class MathTests(unittest.TestCase):
self.assertFalse(math.isfinite(float("inf")))
self.assertFalse(math.isfinite(float("-inf")))
+ def testIsnormal(self):
+ self.assertTrue(math.isnormal(1.25))
+ self.assertTrue(math.isnormal(-1.0))
+ self.assertFalse(math.isnormal(0.0))
+ self.assertFalse(math.isnormal(-0.0))
+ self.assertFalse(math.isnormal(INF))
+ self.assertFalse(math.isnormal(NINF))
+ self.assertFalse(math.isnormal(NAN))
+ self.assertFalse(math.isnormal(FLOAT_MIN/2))
+ self.assertFalse(math.isnormal(-FLOAT_MIN/2))
+
+ def testIssubnormal(self):
+ self.assertFalse(math.issubnormal(1.25))
+ self.assertFalse(math.issubnormal(-1.0))
+ self.assertFalse(math.issubnormal(0.0))
+ self.assertFalse(math.issubnormal(-0.0))
+ self.assertFalse(math.issubnormal(INF))
+ self.assertFalse(math.issubnormal(NINF))
+ self.assertFalse(math.issubnormal(NAN))
+ self.assertTrue(math.issubnormal(FLOAT_MIN/2))
+ self.assertTrue(math.issubnormal(-FLOAT_MIN/2))
+
def testIsnan(self):
self.assertTrue(math.isnan(float("nan")))
self.assertTrue(math.isnan(float("-nan")))
@@ -2458,7 +2498,6 @@ class MathTests(unittest.TestCase):
with self.assertRaises(ValueError):
math.nextafter(1.0, INF, steps=-1)
-
@requires_IEEE_754
def test_ulp(self):
self.assertEqual(math.ulp(1.0), sys.float_info.epsilon)
diff --git a/Lib/test/test_memoryio.py b/Lib/test/test_memoryio.py
index 95629ed862d..63998a86c45 100644
--- a/Lib/test/test_memoryio.py
+++ b/Lib/test/test_memoryio.py
@@ -265,8 +265,8 @@ class MemoryTestMixin:
memio = self.ioclass(buf * 10)
self.assertEqual(iter(memio), memio)
- self.assertTrue(hasattr(memio, '__iter__'))
- self.assertTrue(hasattr(memio, '__next__'))
+ self.assertHasAttr(memio, '__iter__')
+ self.assertHasAttr(memio, '__next__')
i = 0
for line in memio:
self.assertEqual(line, buf)
diff --git a/Lib/test/test_memoryview.py b/Lib/test/test_memoryview.py
index 61b068c630c..64f440f180b 100644
--- a/Lib/test/test_memoryview.py
+++ b/Lib/test/test_memoryview.py
@@ -743,19 +743,21 @@ class RacingTest(unittest.TestCase):
from multiprocessing.managers import SharedMemoryManager
except ImportError:
self.skipTest("Test requires multiprocessing")
- from threading import Thread
+ from threading import Thread, Event
- n = 100
+ start = Event()
with SharedMemoryManager() as smm:
obj = smm.ShareableList(range(100))
- threads = []
- for _ in range(n):
- # Issue gh-127085, the `ShareableList.count` is just a convenient way to mess the `exports`
- # counter of `memoryview`, this issue has no direct relation with `ShareableList`.
- threads.append(Thread(target=obj.count, args=(1,)))
-
+ def test():
+ # Issue gh-127085, the `ShareableList.count` is just a
+ # convenient way to mess the `exports` counter of `memoryview`,
+ # this issue has no direct relation with `ShareableList`.
+ start.wait(support.SHORT_TIMEOUT)
+ for i in range(10):
+ obj.count(1)
+ threads = [Thread(target=test) for _ in range(10)]
with threading_helper.start_threads(threads):
- pass
+ start.set()
del obj
diff --git a/Lib/test/test_mimetypes.py b/Lib/test/test_mimetypes.py
index dad5dbde7cd..fb57d5e5544 100644
--- a/Lib/test/test_mimetypes.py
+++ b/Lib/test/test_mimetypes.py
@@ -6,7 +6,8 @@ import sys
import unittest.mock
from platform import win32_edition
from test import support
-from test.support import os_helper
+from test.support import cpython_only, force_not_colorized, os_helper
+from test.support.import_helper import ensure_lazy_imports
try:
import _winapi
@@ -435,8 +436,13 @@ class MiscTestCase(unittest.TestCase):
def test__all__(self):
support.check__all__(self, mimetypes)
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("mimetypes", {"os", "posixpath", "urllib.parse", "argparse"})
+
class CommandLineTest(unittest.TestCase):
+ @force_not_colorized
def test_parse_args(self):
args, help_text = mimetypes._parse_args("-h")
self.assertTrue(help_text.startswith("usage: "))
diff --git a/Lib/test/test_minidom.py b/Lib/test/test_minidom.py
index 6679c0a4fbe..4f25e9c2a03 100644
--- a/Lib/test/test_minidom.py
+++ b/Lib/test/test_minidom.py
@@ -102,41 +102,38 @@ class MinidomTest(unittest.TestCase):
elem = root.childNodes[0]
nelem = dom.createElement("element")
root.insertBefore(nelem, elem)
- self.confirm(len(root.childNodes) == 2
- and root.childNodes.length == 2
- and root.childNodes[0] is nelem
- and root.childNodes.item(0) is nelem
- and root.childNodes[1] is elem
- and root.childNodes.item(1) is elem
- and root.firstChild is nelem
- and root.lastChild is elem
- and root.toxml() == "<doc><element/><foo/></doc>"
- , "testInsertBefore -- node properly placed in tree")
+ self.assertEqual(len(root.childNodes), 2)
+ self.assertEqual(root.childNodes.length, 2)
+ self.assertIs(root.childNodes[0], nelem)
+ self.assertIs(root.childNodes.item(0), nelem)
+ self.assertIs(root.childNodes[1], elem)
+ self.assertIs(root.childNodes.item(1), elem)
+ self.assertIs(root.firstChild, nelem)
+ self.assertIs(root.lastChild, elem)
+ self.assertEqual(root.toxml(), "<doc><element/><foo/></doc>")
nelem = dom.createElement("element")
root.insertBefore(nelem, None)
- self.confirm(len(root.childNodes) == 3
- and root.childNodes.length == 3
- and root.childNodes[1] is elem
- and root.childNodes.item(1) is elem
- and root.childNodes[2] is nelem
- and root.childNodes.item(2) is nelem
- and root.lastChild is nelem
- and nelem.previousSibling is elem
- and root.toxml() == "<doc><element/><foo/><element/></doc>"
- , "testInsertBefore -- node properly placed in tree")
+ self.assertEqual(len(root.childNodes), 3)
+ self.assertEqual(root.childNodes.length, 3)
+ self.assertIs(root.childNodes[1], elem)
+ self.assertIs(root.childNodes.item(1), elem)
+ self.assertIs(root.childNodes[2], nelem)
+ self.assertIs(root.childNodes.item(2), nelem)
+ self.assertIs(root.lastChild, nelem)
+ self.assertIs(nelem.previousSibling, elem)
+ self.assertEqual(root.toxml(), "<doc><element/><foo/><element/></doc>")
nelem2 = dom.createElement("bar")
root.insertBefore(nelem2, nelem)
- self.confirm(len(root.childNodes) == 4
- and root.childNodes.length == 4
- and root.childNodes[2] is nelem2
- and root.childNodes.item(2) is nelem2
- and root.childNodes[3] is nelem
- and root.childNodes.item(3) is nelem
- and nelem2.nextSibling is nelem
- and nelem.previousSibling is nelem2
- and root.toxml() ==
- "<doc><element/><foo/><bar/><element/></doc>"
- , "testInsertBefore -- node properly placed in tree")
+ self.assertEqual(len(root.childNodes), 4)
+ self.assertEqual(root.childNodes.length, 4)
+ self.assertIs(root.childNodes[2], nelem2)
+ self.assertIs(root.childNodes.item(2), nelem2)
+ self.assertIs(root.childNodes[3], nelem)
+ self.assertIs(root.childNodes.item(3), nelem)
+ self.assertIs(nelem2.nextSibling, nelem)
+ self.assertIs(nelem.previousSibling, nelem2)
+ self.assertEqual(root.toxml(),
+ "<doc><element/><foo/><bar/><element/></doc>")
dom.unlink()
def _create_fragment_test_nodes(self):
@@ -342,8 +339,8 @@ class MinidomTest(unittest.TestCase):
self.assertRaises(xml.dom.NotFoundErr, child.removeAttributeNode,
None)
self.assertIs(node, child.removeAttributeNode(node))
- self.confirm(len(child.attributes) == 0
- and child.getAttributeNode("spam") is None)
+ self.assertEqual(len(child.attributes), 0)
+ self.assertIsNone(child.getAttributeNode("spam"))
dom2 = Document()
child2 = dom2.appendChild(dom2.createElement("foo"))
node2 = child2.getAttributeNode("spam")
@@ -366,33 +363,34 @@ class MinidomTest(unittest.TestCase):
# Set this attribute to be an ID and make sure that doesn't change
# when changing the value:
el.setIdAttribute("spam")
- self.confirm(len(el.attributes) == 1
- and el.attributes["spam"].value == "bam"
- and el.attributes["spam"].nodeValue == "bam"
- and el.getAttribute("spam") == "bam"
- and el.getAttributeNode("spam").isId)
+ self.assertEqual(len(el.attributes), 1)
+ self.assertEqual(el.attributes["spam"].value, "bam")
+ self.assertEqual(el.attributes["spam"].nodeValue, "bam")
+ self.assertEqual(el.getAttribute("spam"), "bam")
+ self.assertTrue(el.getAttributeNode("spam").isId)
el.attributes["spam"] = "ham"
- self.confirm(len(el.attributes) == 1
- and el.attributes["spam"].value == "ham"
- and el.attributes["spam"].nodeValue == "ham"
- and el.getAttribute("spam") == "ham"
- and el.attributes["spam"].isId)
+ self.assertEqual(len(el.attributes), 1)
+ self.assertEqual(el.attributes["spam"].value, "ham")
+ self.assertEqual(el.attributes["spam"].nodeValue, "ham")
+ self.assertEqual(el.getAttribute("spam"), "ham")
+ self.assertTrue(el.attributes["spam"].isId)
el.setAttribute("spam2", "bam")
- self.confirm(len(el.attributes) == 2
- and el.attributes["spam"].value == "ham"
- and el.attributes["spam"].nodeValue == "ham"
- and el.getAttribute("spam") == "ham"
- and el.attributes["spam2"].value == "bam"
- and el.attributes["spam2"].nodeValue == "bam"
- and el.getAttribute("spam2") == "bam")
+ self.assertEqual(len(el.attributes), 2)
+ self.assertEqual(el.attributes["spam"].value, "ham")
+ self.assertEqual(el.attributes["spam"].nodeValue, "ham")
+ self.assertEqual(el.getAttribute("spam"), "ham")
+ self.assertEqual(el.attributes["spam2"].value, "bam")
+ self.assertEqual(el.attributes["spam2"].nodeValue, "bam")
+ self.assertEqual(el.getAttribute("spam2"), "bam")
el.attributes["spam2"] = "bam2"
- self.confirm(len(el.attributes) == 2
- and el.attributes["spam"].value == "ham"
- and el.attributes["spam"].nodeValue == "ham"
- and el.getAttribute("spam") == "ham"
- and el.attributes["spam2"].value == "bam2"
- and el.attributes["spam2"].nodeValue == "bam2"
- and el.getAttribute("spam2") == "bam2")
+
+ self.assertEqual(len(el.attributes), 2)
+ self.assertEqual(el.attributes["spam"].value, "ham")
+ self.assertEqual(el.attributes["spam"].nodeValue, "ham")
+ self.assertEqual(el.getAttribute("spam"), "ham")
+ self.assertEqual(el.attributes["spam2"].value, "bam2")
+ self.assertEqual(el.attributes["spam2"].nodeValue, "bam2")
+ self.assertEqual(el.getAttribute("spam2"), "bam2")
dom.unlink()
def testGetAttrList(self):
@@ -448,12 +446,12 @@ class MinidomTest(unittest.TestCase):
dom = parseString(d)
elems = dom.getElementsByTagNameNS("http://pyxml.sf.net/minidom",
"myelem")
- self.confirm(len(elems) == 1
- and elems[0].namespaceURI == "http://pyxml.sf.net/minidom"
- and elems[0].localName == "myelem"
- and elems[0].prefix == "minidom"
- and elems[0].tagName == "minidom:myelem"
- and elems[0].nodeName == "minidom:myelem")
+ self.assertEqual(len(elems), 1)
+ self.assertEqual(elems[0].namespaceURI, "http://pyxml.sf.net/minidom")
+ self.assertEqual(elems[0].localName, "myelem")
+ self.assertEqual(elems[0].prefix, "minidom")
+ self.assertEqual(elems[0].tagName, "minidom:myelem")
+ self.assertEqual(elems[0].nodeName, "minidom:myelem")
dom.unlink()
def get_empty_nodelist_from_elements_by_tagName_ns_helper(self, doc, nsuri,
@@ -602,17 +600,17 @@ class MinidomTest(unittest.TestCase):
def testProcessingInstruction(self):
dom = parseString('<e><?mypi \t\n data \t\n ?></e>')
pi = dom.documentElement.firstChild
- self.confirm(pi.target == "mypi"
- and pi.data == "data \t\n "
- and pi.nodeName == "mypi"
- and pi.nodeType == Node.PROCESSING_INSTRUCTION_NODE
- and pi.attributes is None
- and not pi.hasChildNodes()
- and len(pi.childNodes) == 0
- and pi.firstChild is None
- and pi.lastChild is None
- and pi.localName is None
- and pi.namespaceURI == xml.dom.EMPTY_NAMESPACE)
+ self.assertEqual(pi.target, "mypi")
+ self.assertEqual(pi.data, "data \t\n ")
+ self.assertEqual(pi.nodeName, "mypi")
+ self.assertEqual(pi.nodeType, Node.PROCESSING_INSTRUCTION_NODE)
+ self.assertIsNone(pi.attributes)
+ self.assertFalse(pi.hasChildNodes())
+ self.assertEqual(len(pi.childNodes), 0)
+ self.assertIsNone(pi.firstChild)
+ self.assertIsNone(pi.lastChild)
+ self.assertIsNone(pi.localName)
+ self.assertEqual(pi.namespaceURI, xml.dom.EMPTY_NAMESPACE)
def testProcessingInstructionRepr(self):
dom = parseString('<e><?mypi \t\n data \t\n ?></e>')
@@ -718,19 +716,16 @@ class MinidomTest(unittest.TestCase):
keys2 = list(attrs2.keys())
keys1.sort()
keys2.sort()
- self.assertEqual(keys1, keys2,
- "clone of element has same attribute keys")
+ self.assertEqual(keys1, keys2)
for i in range(len(keys1)):
a1 = attrs1.item(i)
a2 = attrs2.item(i)
- self.confirm(a1 is not a2
- and a1.value == a2.value
- and a1.nodeValue == a2.nodeValue
- and a1.namespaceURI == a2.namespaceURI
- and a1.localName == a2.localName
- , "clone of attribute node has proper attribute values")
- self.assertIs(a2.ownerElement, e2,
- "clone of attribute node correctly owned")
+ self.assertIsNot(a1, a2)
+ self.assertEqual(a1.value, a2.value)
+ self.assertEqual(a1.nodeValue, a2.nodeValue)
+ self.assertEqual(a1.namespaceURI, a2.namespaceURI)
+ self.assertEqual(a1.localName, a2.localName)
+ self.assertIs(a2.ownerElement, e2)
def _setupCloneElement(self, deep):
dom = parseString("<doc attr='value'><foo/></doc>")
@@ -746,20 +741,19 @@ class MinidomTest(unittest.TestCase):
def testCloneElementShallow(self):
dom, clone = self._setupCloneElement(0)
- self.confirm(len(clone.childNodes) == 0
- and clone.childNodes.length == 0
- and clone.parentNode is None
- and clone.toxml() == '<doc attr="value"/>'
- , "testCloneElementShallow")
+ self.assertEqual(len(clone.childNodes), 0)
+ self.assertEqual(clone.childNodes.length, 0)
+ self.assertIsNone(clone.parentNode)
+ self.assertEqual(clone.toxml(), '<doc attr="value"/>')
+
dom.unlink()
def testCloneElementDeep(self):
dom, clone = self._setupCloneElement(1)
- self.confirm(len(clone.childNodes) == 1
- and clone.childNodes.length == 1
- and clone.parentNode is None
- and clone.toxml() == '<doc attr="value"><foo/></doc>'
- , "testCloneElementDeep")
+ self.assertEqual(len(clone.childNodes), 1)
+ self.assertEqual(clone.childNodes.length, 1)
+ self.assertIsNone(clone.parentNode)
+ self.assertEqual(clone.toxml(), '<doc attr="value"><foo/></doc>')
dom.unlink()
def testCloneDocumentShallow(self):
diff --git a/Lib/test/test_monitoring.py b/Lib/test/test_monitoring.py
index 263e4e6f394..a932ac80117 100644
--- a/Lib/test/test_monitoring.py
+++ b/Lib/test/test_monitoring.py
@@ -2157,6 +2157,21 @@ class TestRegressions(MonitoringTestBase, unittest.TestCase):
sys.monitoring.restart_events()
sys.monitoring.set_events(0, 0)
+ def test_134879(self):
+ # gh-134879
+ # Specialized FOR_ITER not incrementing index
+ def foo():
+ t = 0
+ for i in [1,2,3,4]:
+ t += i
+ self.assertEqual(t, 10)
+
+ sys.monitoring.use_tool_id(0, "test")
+ self.addCleanup(sys.monitoring.free_tool_id, 0)
+ sys.monitoring.set_local_events(0, foo.__code__, E.BRANCH_LEFT | E.BRANCH_RIGHT)
+ foo()
+ sys.monitoring.set_local_events(0, foo.__code__, 0)
+
class TestOptimizer(MonitoringTestBase, unittest.TestCase):
diff --git a/Lib/test/test_netrc.py b/Lib/test/test_netrc.py
index 81e11a293cc..9d720f62710 100644
--- a/Lib/test/test_netrc.py
+++ b/Lib/test/test_netrc.py
@@ -1,11 +1,7 @@
import netrc, os, unittest, sys, textwrap
+from test import support
from test.support import os_helper
-try:
- import pwd
-except ImportError:
- pwd = None
-
temp_filename = os_helper.TESTFN
class NetrcTestCase(unittest.TestCase):
@@ -269,9 +265,14 @@ class NetrcTestCase(unittest.TestCase):
machine bar.domain.com login foo password pass
""", '#pass')
+ @unittest.skipUnless(support.is_wasi, 'WASI only test')
+ def test_security_on_WASI(self):
+ self.assertFalse(netrc._can_security_check())
+ self.assertEqual(netrc._getpwuid(0), 'uid 0')
+ self.assertEqual(netrc._getpwuid(123456), 'uid 123456')
@unittest.skipUnless(os.name == 'posix', 'POSIX only test')
- @unittest.skipIf(pwd is None, 'security check requires pwd module')
+ @unittest.skipUnless(hasattr(os, 'getuid'), "os.getuid is required")
@os_helper.skip_unless_working_chmod
def test_security(self):
# This test is incomplete since we are normally not run as root and
diff --git a/Lib/test/test_ntpath.py b/Lib/test/test_ntpath.py
index c10387b58e3..22f6403d482 100644
--- a/Lib/test/test_ntpath.py
+++ b/Lib/test/test_ntpath.py
@@ -6,8 +6,9 @@ import subprocess
import sys
import unittest
import warnings
-from test.support import cpython_only, os_helper
-from test.support import TestFailed, is_emscripten
+from ntpath import ALLOW_MISSING
+from test import support
+from test.support import TestFailed, cpython_only, os_helper
from test.support.os_helper import FakePath
from test import test_genericpath
from tempfile import TemporaryFile
@@ -77,6 +78,10 @@ def tester(fn, wantResult):
%(str(fn), str(wantResult), repr(gotResult)))
+def _parameterize(*parameters):
+ return support.subTests('kwargs', parameters, _do_cleanups=True)
+
+
class NtpathTestCase(unittest.TestCase):
def assertPathEqual(self, path1, path2):
if path1 == path2 or _norm(path1) == _norm(path2):
@@ -124,6 +129,22 @@ class TestNtpath(NtpathTestCase):
tester('ntpath.splitdrive("//?/UNC/server/share/dir")',
("//?/UNC/server/share", "/dir"))
+ def test_splitdrive_invalid_paths(self):
+ splitdrive = ntpath.splitdrive
+ self.assertEqual(splitdrive('\\\\ser\x00ver\\sha\x00re\\di\x00r'),
+ ('\\\\ser\x00ver\\sha\x00re', '\\di\x00r'))
+ self.assertEqual(splitdrive(b'\\\\ser\x00ver\\sha\x00re\\di\x00r'),
+ (b'\\\\ser\x00ver\\sha\x00re', b'\\di\x00r'))
+ self.assertEqual(splitdrive("\\\\\udfff\\\udffe\\\udffd"),
+ ('\\\\\udfff\\\udffe', '\\\udffd'))
+ if sys.platform == 'win32':
+ self.assertRaises(UnicodeDecodeError, splitdrive, b'\\\\\xff\\share\\dir')
+ self.assertRaises(UnicodeDecodeError, splitdrive, b'\\\\server\\\xff\\dir')
+ self.assertRaises(UnicodeDecodeError, splitdrive, b'\\\\server\\share\\\xff')
+ else:
+ self.assertEqual(splitdrive(b'\\\\\xff\\\xfe\\\xfd'),
+ (b'\\\\\xff\\\xfe', b'\\\xfd'))
+
def test_splitroot(self):
tester("ntpath.splitroot('')", ('', '', ''))
tester("ntpath.splitroot('foo')", ('', '', 'foo'))
@@ -214,6 +235,22 @@ class TestNtpath(NtpathTestCase):
tester('ntpath.splitroot(" :/foo")', (" :", "/", "foo"))
tester('ntpath.splitroot("/:/foo")', ("", "/", ":/foo"))
+ def test_splitroot_invalid_paths(self):
+ splitroot = ntpath.splitroot
+ self.assertEqual(splitroot('\\\\ser\x00ver\\sha\x00re\\di\x00r'),
+ ('\\\\ser\x00ver\\sha\x00re', '\\', 'di\x00r'))
+ self.assertEqual(splitroot(b'\\\\ser\x00ver\\sha\x00re\\di\x00r'),
+ (b'\\\\ser\x00ver\\sha\x00re', b'\\', b'di\x00r'))
+ self.assertEqual(splitroot("\\\\\udfff\\\udffe\\\udffd"),
+ ('\\\\\udfff\\\udffe', '\\', '\udffd'))
+ if sys.platform == 'win32':
+ self.assertRaises(UnicodeDecodeError, splitroot, b'\\\\\xff\\share\\dir')
+ self.assertRaises(UnicodeDecodeError, splitroot, b'\\\\server\\\xff\\dir')
+ self.assertRaises(UnicodeDecodeError, splitroot, b'\\\\server\\share\\\xff')
+ else:
+ self.assertEqual(splitroot(b'\\\\\xff\\\xfe\\\xfd'),
+ (b'\\\\\xff\\\xfe', b'\\', b'\xfd'))
+
def test_split(self):
tester('ntpath.split("c:\\foo\\bar")', ('c:\\foo', 'bar'))
tester('ntpath.split("\\\\conky\\mountpoint\\foo\\bar")',
@@ -226,6 +263,21 @@ class TestNtpath(NtpathTestCase):
tester('ntpath.split("c:/")', ('c:/', ''))
tester('ntpath.split("//conky/mountpoint/")', ('//conky/mountpoint/', ''))
+ def test_split_invalid_paths(self):
+ split = ntpath.split
+ self.assertEqual(split('c:\\fo\x00o\\ba\x00r'),
+ ('c:\\fo\x00o', 'ba\x00r'))
+ self.assertEqual(split(b'c:\\fo\x00o\\ba\x00r'),
+ (b'c:\\fo\x00o', b'ba\x00r'))
+ self.assertEqual(split('c:\\\udfff\\\udffe'),
+ ('c:\\\udfff', '\udffe'))
+ if sys.platform == 'win32':
+ self.assertRaises(UnicodeDecodeError, split, b'c:\\\xff\\bar')
+ self.assertRaises(UnicodeDecodeError, split, b'c:\\foo\\\xff')
+ else:
+ self.assertEqual(split(b'c:\\\xff\\\xfe'),
+ (b'c:\\\xff', b'\xfe'))
+
def test_isabs(self):
tester('ntpath.isabs("foo\\bar")', 0)
tester('ntpath.isabs("foo/bar")', 0)
@@ -333,6 +385,30 @@ class TestNtpath(NtpathTestCase):
tester("ntpath.join('D:a', './c:b')", 'D:a\\.\\c:b')
tester("ntpath.join('D:/a', './c:b')", 'D:\\a\\.\\c:b')
+ def test_normcase(self):
+ normcase = ntpath.normcase
+ self.assertEqual(normcase(''), '')
+ self.assertEqual(normcase(b''), b'')
+ self.assertEqual(normcase('ABC'), 'abc')
+ self.assertEqual(normcase(b'ABC'), b'abc')
+ self.assertEqual(normcase('\xc4\u0141\u03a8'), '\xe4\u0142\u03c8')
+ expected = '\u03c9\u2126' if sys.platform == 'win32' else '\u03c9\u03c9'
+ self.assertEqual(normcase('\u03a9\u2126'), expected)
+ if sys.platform == 'win32' or sys.getfilesystemencoding() == 'utf-8':
+ self.assertEqual(normcase('\xc4\u0141\u03a8'.encode()),
+ '\xe4\u0142\u03c8'.encode())
+ self.assertEqual(normcase('\u03a9\u2126'.encode()),
+ expected.encode())
+
+ def test_normcase_invalid_paths(self):
+ normcase = ntpath.normcase
+ self.assertEqual(normcase('abc\x00def'), 'abc\x00def')
+ self.assertEqual(normcase(b'abc\x00def'), b'abc\x00def')
+ self.assertEqual(normcase('\udfff'), '\udfff')
+ if sys.platform == 'win32':
+ path = b'ABC' + bytes(range(128, 256))
+ self.assertEqual(normcase(path), path.lower())
+
def test_normpath(self):
tester("ntpath.normpath('A//////././//.//B')", r'A\B')
tester("ntpath.normpath('A/./B')", r'A\B')
@@ -381,6 +457,21 @@ class TestNtpath(NtpathTestCase):
tester("ntpath.normpath('\\\\')", '\\\\')
tester("ntpath.normpath('//?/UNC/server/share/..')", '\\\\?\\UNC\\server\\share\\')
+ def test_normpath_invalid_paths(self):
+ normpath = ntpath.normpath
+ self.assertEqual(normpath('fo\x00o'), 'fo\x00o')
+ self.assertEqual(normpath(b'fo\x00o'), b'fo\x00o')
+ self.assertEqual(normpath('fo\x00o\\..\\bar'), 'bar')
+ self.assertEqual(normpath(b'fo\x00o\\..\\bar'), b'bar')
+ self.assertEqual(normpath('\udfff'), '\udfff')
+ self.assertEqual(normpath('\udfff\\..\\foo'), 'foo')
+ if sys.platform == 'win32':
+ self.assertRaises(UnicodeDecodeError, normpath, b'\xff')
+ self.assertRaises(UnicodeDecodeError, normpath, b'\xff\\..\\foo')
+ else:
+ self.assertEqual(normpath(b'\xff'), b'\xff')
+ self.assertEqual(normpath(b'\xff\\..\\foo'), b'foo')
+
def test_realpath_curdir(self):
expected = ntpath.normpath(os.getcwd())
tester("ntpath.realpath('.')", expected)
@@ -389,6 +480,27 @@ class TestNtpath(NtpathTestCase):
tester("ntpath.realpath('.\\.')", expected)
tester("ntpath.realpath('\\'.join(['.'] * 100))", expected)
+ def test_realpath_curdir_strict(self):
+ expected = ntpath.normpath(os.getcwd())
+ tester("ntpath.realpath('.', strict=True)", expected)
+ tester("ntpath.realpath('./.', strict=True)", expected)
+ tester("ntpath.realpath('/'.join(['.'] * 100), strict=True)", expected)
+ tester("ntpath.realpath('.\\.', strict=True)", expected)
+ tester("ntpath.realpath('\\'.join(['.'] * 100), strict=True)", expected)
+
+ def test_realpath_curdir_missing_ok(self):
+ expected = ntpath.normpath(os.getcwd())
+ tester("ntpath.realpath('.', strict=ALLOW_MISSING)",
+ expected)
+ tester("ntpath.realpath('./.', strict=ALLOW_MISSING)",
+ expected)
+ tester("ntpath.realpath('/'.join(['.'] * 100), strict=ALLOW_MISSING)",
+ expected)
+ tester("ntpath.realpath('.\\.', strict=ALLOW_MISSING)",
+ expected)
+ tester("ntpath.realpath('\\'.join(['.'] * 100), strict=ALLOW_MISSING)",
+ expected)
+
def test_realpath_pardir(self):
expected = ntpath.normpath(os.getcwd())
tester("ntpath.realpath('..')", ntpath.dirname(expected))
@@ -401,28 +513,59 @@ class TestNtpath(NtpathTestCase):
tester("ntpath.realpath('\\'.join(['..'] * 50))",
ntpath.splitdrive(expected)[0] + '\\')
+ def test_realpath_pardir_strict(self):
+ expected = ntpath.normpath(os.getcwd())
+ tester("ntpath.realpath('..', strict=True)", ntpath.dirname(expected))
+ tester("ntpath.realpath('../..', strict=True)",
+ ntpath.dirname(ntpath.dirname(expected)))
+ tester("ntpath.realpath('/'.join(['..'] * 50), strict=True)",
+ ntpath.splitdrive(expected)[0] + '\\')
+ tester("ntpath.realpath('..\\..', strict=True)",
+ ntpath.dirname(ntpath.dirname(expected)))
+ tester("ntpath.realpath('\\'.join(['..'] * 50), strict=True)",
+ ntpath.splitdrive(expected)[0] + '\\')
+
+ def test_realpath_pardir_missing_ok(self):
+ expected = ntpath.normpath(os.getcwd())
+ tester("ntpath.realpath('..', strict=ALLOW_MISSING)",
+ ntpath.dirname(expected))
+ tester("ntpath.realpath('../..', strict=ALLOW_MISSING)",
+ ntpath.dirname(ntpath.dirname(expected)))
+ tester("ntpath.realpath('/'.join(['..'] * 50), strict=ALLOW_MISSING)",
+ ntpath.splitdrive(expected)[0] + '\\')
+ tester("ntpath.realpath('..\\..', strict=ALLOW_MISSING)",
+ ntpath.dirname(ntpath.dirname(expected)))
+ tester("ntpath.realpath('\\'.join(['..'] * 50), strict=ALLOW_MISSING)",
+ ntpath.splitdrive(expected)[0] + '\\')
+
@os_helper.skip_unless_symlink
@unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
- def test_realpath_basic(self):
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_basic(self, kwargs):
ABSTFN = ntpath.abspath(os_helper.TESTFN)
open(ABSTFN, "wb").close()
self.addCleanup(os_helper.unlink, ABSTFN)
self.addCleanup(os_helper.unlink, ABSTFN + "1")
os.symlink(ABSTFN, ABSTFN + "1")
- self.assertPathEqual(ntpath.realpath(ABSTFN + "1"), ABSTFN)
- self.assertPathEqual(ntpath.realpath(os.fsencode(ABSTFN + "1")),
+ self.assertPathEqual(ntpath.realpath(ABSTFN + "1", **kwargs), ABSTFN)
+ self.assertPathEqual(ntpath.realpath(os.fsencode(ABSTFN + "1"), **kwargs),
os.fsencode(ABSTFN))
# gh-88013: call ntpath.realpath with binary drive name may raise a
# TypeError. The drive should not exist to reproduce the bug.
drives = {f"{c}:\\" for c in string.ascii_uppercase} - set(os.listdrives())
d = drives.pop().encode()
- self.assertEqual(ntpath.realpath(d), d)
+ self.assertEqual(ntpath.realpath(d, strict=False), d)
# gh-106242: Embedded nulls and non-strict fallback to abspath
- self.assertEqual(ABSTFN + "\0spam",
- ntpath.realpath(os_helper.TESTFN + "\0spam", strict=False))
+ if kwargs:
+ with self.assertRaises(OSError):
+ ntpath.realpath(os_helper.TESTFN + "\0spam",
+ **kwargs)
+ else:
+ self.assertEqual(ABSTFN + "\0spam",
+ ntpath.realpath(os_helper.TESTFN + "\0spam", **kwargs))
@os_helper.skip_unless_symlink
@unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
@@ -434,19 +577,77 @@ class TestNtpath(NtpathTestCase):
self.addCleanup(os_helper.unlink, ABSTFN)
self.assertRaises(FileNotFoundError, ntpath.realpath, ABSTFN, strict=True)
self.assertRaises(FileNotFoundError, ntpath.realpath, ABSTFN + "2", strict=True)
+
+ @unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
+ def test_realpath_invalid_paths(self):
+ realpath = ntpath.realpath
+ ABSTFN = ntpath.abspath(os_helper.TESTFN)
+ ABSTFNb = os.fsencode(ABSTFN)
+ path = ABSTFN + '\x00'
+ # gh-106242: Embedded nulls and non-strict fallback to abspath
+ self.assertEqual(realpath(path, strict=False), path)
# gh-106242: Embedded nulls should raise OSError (not ValueError)
- self.assertRaises(OSError, ntpath.realpath, ABSTFN + "\0spam", strict=True)
+ self.assertRaises(OSError, realpath, path, strict=True)
+ self.assertRaises(OSError, realpath, path, strict=ALLOW_MISSING)
+ path = ABSTFNb + b'\x00'
+ self.assertEqual(realpath(path, strict=False), path)
+ self.assertRaises(OSError, realpath, path, strict=True)
+ self.assertRaises(OSError, realpath, path, strict=ALLOW_MISSING)
+ path = ABSTFN + '\\nonexistent\\x\x00'
+ self.assertEqual(realpath(path, strict=False), path)
+ self.assertRaises(OSError, realpath, path, strict=True)
+ self.assertRaises(OSError, realpath, path, strict=ALLOW_MISSING)
+ path = ABSTFNb + b'\\nonexistent\\x\x00'
+ self.assertEqual(realpath(path, strict=False), path)
+ self.assertRaises(OSError, realpath, path, strict=True)
+ self.assertRaises(OSError, realpath, path, strict=ALLOW_MISSING)
+ path = ABSTFN + '\x00\\..'
+ self.assertEqual(realpath(path, strict=False), os.getcwd())
+ self.assertEqual(realpath(path, strict=True), os.getcwd())
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), os.getcwd())
+ path = ABSTFNb + b'\x00\\..'
+ self.assertEqual(realpath(path, strict=False), os.getcwdb())
+ self.assertEqual(realpath(path, strict=True), os.getcwdb())
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), os.getcwdb())
+ path = ABSTFN + '\\nonexistent\\x\x00\\..'
+ self.assertEqual(realpath(path, strict=False), ABSTFN + '\\nonexistent')
+ self.assertRaises(OSError, realpath, path, strict=True)
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), ABSTFN + '\\nonexistent')
+ path = ABSTFNb + b'\\nonexistent\\x\x00\\..'
+ self.assertEqual(realpath(path, strict=False), ABSTFNb + b'\\nonexistent')
+ self.assertRaises(OSError, realpath, path, strict=True)
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), ABSTFNb + b'\\nonexistent')
+
+ @unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_invalid_unicode_paths(self, kwargs):
+ realpath = ntpath.realpath
+ ABSTFN = ntpath.abspath(os_helper.TESTFN)
+ ABSTFNb = os.fsencode(ABSTFN)
+ path = ABSTFNb + b'\xff'
+ self.assertRaises(UnicodeDecodeError, realpath, path, **kwargs)
+ self.assertRaises(UnicodeDecodeError, realpath, path, **kwargs)
+ path = ABSTFNb + b'\\nonexistent\\\xff'
+ self.assertRaises(UnicodeDecodeError, realpath, path, **kwargs)
+ self.assertRaises(UnicodeDecodeError, realpath, path, **kwargs)
+ path = ABSTFNb + b'\xff\\..'
+ self.assertRaises(UnicodeDecodeError, realpath, path, **kwargs)
+ self.assertRaises(UnicodeDecodeError, realpath, path, **kwargs)
+ path = ABSTFNb + b'\\nonexistent\\\xff\\..'
+ self.assertRaises(UnicodeDecodeError, realpath, path, **kwargs)
+ self.assertRaises(UnicodeDecodeError, realpath, path, **kwargs)
@os_helper.skip_unless_symlink
@unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
- def test_realpath_relative(self):
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_relative(self, kwargs):
ABSTFN = ntpath.abspath(os_helper.TESTFN)
open(ABSTFN, "wb").close()
self.addCleanup(os_helper.unlink, ABSTFN)
self.addCleanup(os_helper.unlink, ABSTFN + "1")
os.symlink(ABSTFN, ntpath.relpath(ABSTFN + "1"))
- self.assertPathEqual(ntpath.realpath(ABSTFN + "1"), ABSTFN)
+ self.assertPathEqual(ntpath.realpath(ABSTFN + "1", **kwargs), ABSTFN)
@os_helper.skip_unless_symlink
@unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
@@ -598,7 +799,62 @@ class TestNtpath(NtpathTestCase):
@os_helper.skip_unless_symlink
@unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
- def test_realpath_symlink_prefix(self):
+ def test_realpath_symlink_loops_raise(self):
+ # Symlink loops raise OSError in ALLOW_MISSING mode
+ ABSTFN = ntpath.abspath(os_helper.TESTFN)
+ self.addCleanup(os_helper.unlink, ABSTFN)
+ self.addCleanup(os_helper.unlink, ABSTFN + "1")
+ self.addCleanup(os_helper.unlink, ABSTFN + "2")
+ self.addCleanup(os_helper.unlink, ABSTFN + "y")
+ self.addCleanup(os_helper.unlink, ABSTFN + "c")
+ self.addCleanup(os_helper.unlink, ABSTFN + "a")
+ self.addCleanup(os_helper.unlink, ABSTFN + "x")
+
+ os.symlink(ABSTFN, ABSTFN)
+ self.assertRaises(OSError, ntpath.realpath, ABSTFN, strict=ALLOW_MISSING)
+
+ os.symlink(ABSTFN + "1", ABSTFN + "2")
+ os.symlink(ABSTFN + "2", ABSTFN + "1")
+ self.assertRaises(OSError, ntpath.realpath, ABSTFN + "1",
+ strict=ALLOW_MISSING)
+ self.assertRaises(OSError, ntpath.realpath, ABSTFN + "2",
+ strict=ALLOW_MISSING)
+ self.assertRaises(OSError, ntpath.realpath, ABSTFN + "1\\x",
+ strict=ALLOW_MISSING)
+
+ # Windows eliminates '..' components before resolving links;
+ # realpath is not expected to raise if this removes the loop.
+ self.assertPathEqual(ntpath.realpath(ABSTFN + "1\\.."),
+ ntpath.dirname(ABSTFN))
+ self.assertPathEqual(ntpath.realpath(ABSTFN + "1\\..\\x"),
+ ntpath.dirname(ABSTFN) + "\\x")
+
+ os.symlink(ABSTFN + "x", ABSTFN + "y")
+ self.assertPathEqual(ntpath.realpath(ABSTFN + "1\\..\\"
+ + ntpath.basename(ABSTFN) + "y"),
+ ABSTFN + "x")
+ self.assertRaises(
+ OSError, ntpath.realpath,
+ ABSTFN + "1\\..\\" + ntpath.basename(ABSTFN) + "1",
+ strict=ALLOW_MISSING)
+
+ os.symlink(ntpath.basename(ABSTFN) + "a\\b", ABSTFN + "a")
+ self.assertRaises(OSError, ntpath.realpath, ABSTFN + "a",
+ strict=ALLOW_MISSING)
+
+ os.symlink("..\\" + ntpath.basename(ntpath.dirname(ABSTFN))
+ + "\\" + ntpath.basename(ABSTFN) + "c", ABSTFN + "c")
+ self.assertRaises(OSError, ntpath.realpath, ABSTFN + "c",
+ strict=ALLOW_MISSING)
+
+ # Test using relative path as well.
+ self.assertRaises(OSError, ntpath.realpath, ntpath.basename(ABSTFN),
+ strict=ALLOW_MISSING)
+
+ @os_helper.skip_unless_symlink
+ @unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_symlink_prefix(self, kwargs):
ABSTFN = ntpath.abspath(os_helper.TESTFN)
self.addCleanup(os_helper.unlink, ABSTFN + "3")
self.addCleanup(os_helper.unlink, "\\\\?\\" + ABSTFN + "3.")
@@ -613,9 +869,9 @@ class TestNtpath(NtpathTestCase):
f.write(b'1')
os.symlink("\\\\?\\" + ABSTFN + "3.", ABSTFN + "3.link")
- self.assertPathEqual(ntpath.realpath(ABSTFN + "3link"),
+ self.assertPathEqual(ntpath.realpath(ABSTFN + "3link", **kwargs),
ABSTFN + "3")
- self.assertPathEqual(ntpath.realpath(ABSTFN + "3.link"),
+ self.assertPathEqual(ntpath.realpath(ABSTFN + "3.link", **kwargs),
"\\\\?\\" + ABSTFN + "3.")
# Resolved paths should be usable to open target files
@@ -625,14 +881,17 @@ class TestNtpath(NtpathTestCase):
self.assertEqual(f.read(), b'1')
# When the prefix is included, it is not stripped
- self.assertPathEqual(ntpath.realpath("\\\\?\\" + ABSTFN + "3link"),
+ self.assertPathEqual(ntpath.realpath("\\\\?\\" + ABSTFN + "3link", **kwargs),
"\\\\?\\" + ABSTFN + "3")
- self.assertPathEqual(ntpath.realpath("\\\\?\\" + ABSTFN + "3.link"),
+ self.assertPathEqual(ntpath.realpath("\\\\?\\" + ABSTFN + "3.link", **kwargs),
"\\\\?\\" + ABSTFN + "3.")
@unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
def test_realpath_nul(self):
tester("ntpath.realpath('NUL')", r'\\.\NUL')
+ tester("ntpath.realpath('NUL', strict=False)", r'\\.\NUL')
+ tester("ntpath.realpath('NUL', strict=True)", r'\\.\NUL')
+ tester("ntpath.realpath('NUL', strict=ALLOW_MISSING)", r'\\.\NUL')
@unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
@unittest.skipUnless(HAVE_GETSHORTPATHNAME, 'need _getshortpathname')
@@ -656,12 +915,20 @@ class TestNtpath(NtpathTestCase):
self.assertPathEqual(test_file_long, ntpath.realpath(test_file_short))
- with os_helper.change_cwd(test_dir_long):
- self.assertPathEqual(test_file_long, ntpath.realpath("file.txt"))
- with os_helper.change_cwd(test_dir_long.lower()):
- self.assertPathEqual(test_file_long, ntpath.realpath("file.txt"))
- with os_helper.change_cwd(test_dir_short):
- self.assertPathEqual(test_file_long, ntpath.realpath("file.txt"))
+ for kwargs in {}, {'strict': True}, {'strict': ALLOW_MISSING}:
+ with self.subTest(**kwargs):
+ with os_helper.change_cwd(test_dir_long):
+ self.assertPathEqual(
+ test_file_long,
+ ntpath.realpath("file.txt", **kwargs))
+ with os_helper.change_cwd(test_dir_long.lower()):
+ self.assertPathEqual(
+ test_file_long,
+ ntpath.realpath("file.txt", **kwargs))
+ with os_helper.change_cwd(test_dir_short):
+ self.assertPathEqual(
+ test_file_long,
+ ntpath.realpath("file.txt", **kwargs))
@unittest.skipUnless(HAVE_GETFINALPATHNAME, 'need _getfinalpathname')
def test_realpath_permission(self):
@@ -682,12 +949,15 @@ class TestNtpath(NtpathTestCase):
# Automatic generation of short names may be disabled on
# NTFS volumes for the sake of performance.
# They're not supported at all on ReFS and exFAT.
- subprocess.run(
+ p = subprocess.run(
# Try to set the short name manually.
['fsutil.exe', 'file', 'setShortName', test_file, 'LONGFI~1.TXT'],
creationflags=subprocess.DETACHED_PROCESS
)
+ if p.returncode:
+ raise unittest.SkipTest('failed to set short name')
+
try:
self.assertPathEqual(test_file, ntpath.realpath(test_file_short))
except AssertionError:
@@ -812,8 +1082,6 @@ class TestNtpath(NtpathTestCase):
tester('ntpath.abspath("C:/nul")', "\\\\.\\nul")
tester('ntpath.abspath("C:\\nul")', "\\\\.\\nul")
self.assertTrue(ntpath.isabs(ntpath.abspath("C:spam")))
- self.assertEqual(ntpath.abspath("C:\x00"), ntpath.join(ntpath.abspath("C:"), "\x00"))
- self.assertEqual(ntpath.abspath("\x00:spam"), "\x00:\\spam")
tester('ntpath.abspath("//..")', "\\\\")
tester('ntpath.abspath("//../")', "\\\\..\\")
tester('ntpath.abspath("//../..")', "\\\\..\\")
@@ -847,6 +1115,26 @@ class TestNtpath(NtpathTestCase):
drive, _ = ntpath.splitdrive(cwd_dir)
tester('ntpath.abspath("/abc/")', drive + "\\abc")
+ def test_abspath_invalid_paths(self):
+ abspath = ntpath.abspath
+ if sys.platform == 'win32':
+ self.assertEqual(abspath("C:\x00"), ntpath.join(abspath("C:"), "\x00"))
+ self.assertEqual(abspath(b"C:\x00"), ntpath.join(abspath(b"C:"), b"\x00"))
+ self.assertEqual(abspath("\x00:spam"), "\x00:\\spam")
+ self.assertEqual(abspath(b"\x00:spam"), b"\x00:\\spam")
+ self.assertEqual(abspath('c:\\fo\x00o'), 'c:\\fo\x00o')
+ self.assertEqual(abspath(b'c:\\fo\x00o'), b'c:\\fo\x00o')
+ self.assertEqual(abspath('c:\\fo\x00o\\..\\bar'), 'c:\\bar')
+ self.assertEqual(abspath(b'c:\\fo\x00o\\..\\bar'), b'c:\\bar')
+ self.assertEqual(abspath('c:\\\udfff'), 'c:\\\udfff')
+ self.assertEqual(abspath('c:\\\udfff\\..\\foo'), 'c:\\foo')
+ if sys.platform == 'win32':
+ self.assertRaises(UnicodeDecodeError, abspath, b'c:\\\xff')
+ self.assertRaises(UnicodeDecodeError, abspath, b'c:\\\xff\\..\\foo')
+ else:
+ self.assertEqual(abspath(b'c:\\\xff'), b'c:\\\xff')
+ self.assertEqual(abspath(b'c:\\\xff\\..\\foo'), b'c:\\foo')
+
def test_relpath(self):
tester('ntpath.relpath("a")', 'a')
tester('ntpath.relpath(ntpath.abspath("a"))', 'a')
@@ -989,6 +1277,18 @@ class TestNtpath(NtpathTestCase):
self.assertTrue(ntpath.ismount(b"\\\\localhost\\c$"))
self.assertTrue(ntpath.ismount(b"\\\\localhost\\c$\\"))
+ def test_ismount_invalid_paths(self):
+ ismount = ntpath.ismount
+ self.assertFalse(ismount("c:\\\udfff"))
+ if sys.platform == 'win32':
+ self.assertRaises(ValueError, ismount, "c:\\\x00")
+ self.assertRaises(ValueError, ismount, b"c:\\\x00")
+ self.assertRaises(UnicodeDecodeError, ismount, b"c:\\\xff")
+ else:
+ self.assertFalse(ismount("c:\\\x00"))
+ self.assertFalse(ismount(b"c:\\\x00"))
+ self.assertFalse(ismount(b"c:\\\xff"))
+
def test_isreserved(self):
self.assertFalse(ntpath.isreserved(''))
self.assertFalse(ntpath.isreserved('.'))
@@ -1095,6 +1395,13 @@ class TestNtpath(NtpathTestCase):
self.assertFalse(ntpath.isjunction('tmpdir'))
self.assertPathEqual(ntpath.realpath('testjunc'), ntpath.realpath('tmpdir'))
+ def test_isfile_invalid_paths(self):
+ isfile = ntpath.isfile
+ self.assertIs(isfile('/tmp\udfffabcds'), False)
+ self.assertIs(isfile(b'/tmp\xffabcds'), False)
+ self.assertIs(isfile('/tmp\x00abcds'), False)
+ self.assertIs(isfile(b'/tmp\x00abcds'), False)
+
@unittest.skipIf(sys.platform != 'win32', "drive letters are a windows concept")
def test_isfile_driveletter(self):
drive = os.environ.get('SystemDrive')
@@ -1195,9 +1502,6 @@ class PathLikeTests(NtpathTestCase):
def test_path_normcase(self):
self._check_function(self.path.normcase)
- if sys.platform == 'win32':
- self.assertEqual(ntpath.normcase('\u03a9\u2126'), 'ωΩ')
- self.assertEqual(ntpath.normcase('abc\x00def'), 'abc\x00def')
def test_path_isabs(self):
self._check_function(self.path.isabs)
diff --git a/Lib/test/test_opcache.py b/Lib/test/test_opcache.py
index a45aafc63fa..30baa090486 100644
--- a/Lib/test/test_opcache.py
+++ b/Lib/test/test_opcache.py
@@ -1810,20 +1810,6 @@ class TestSpecializer(TestBase):
self.assert_specialized(compare_op_str, "COMPARE_OP_STR")
self.assert_no_opcode(compare_op_str, "COMPARE_OP")
- @cpython_only
- @requires_specialization_ft
- def test_load_const(self):
- def load_const():
- def unused(): pass
- # Currently, the empty tuple is immortal, and the otherwise
- # unused nested function's code object is mortal. This test will
- # have to use different values if either of that changes.
- return ()
-
- load_const()
- self.assert_specialized(load_const, "LOAD_CONST_IMMORTAL")
- self.assert_specialized(load_const, "LOAD_CONST_MORTAL")
- self.assert_no_opcode(load_const, "LOAD_CONST")
@cpython_only
@requires_specialization_ft
diff --git a/Lib/test/test_optparse.py b/Lib/test/test_optparse.py
index 8655a0537a5..e476e472780 100644
--- a/Lib/test/test_optparse.py
+++ b/Lib/test/test_optparse.py
@@ -14,8 +14,9 @@ import unittest
from io import StringIO
from test import support
-from test.support import os_helper
+from test.support import cpython_only, os_helper
from test.support.i18n_helper import TestTranslationsBase, update_translation_snapshots
+from test.support.import_helper import ensure_lazy_imports
import optparse
from optparse import make_option, Option, \
@@ -614,9 +615,9 @@ Options:
self.parser.add_option(
"-p", "--prob",
help="blow up with probability PROB [default: %default]")
- self.parser.set_defaults(prob=0.43)
+ self.parser.set_defaults(prob=0.25)
expected_help = self.help_prefix + \
- " -p PROB, --prob=PROB blow up with probability PROB [default: 0.43]\n"
+ " -p PROB, --prob=PROB blow up with probability PROB [default: 0.25]\n"
self.assertHelp(self.parser, expected_help)
def test_alt_expand(self):
@@ -1655,6 +1656,10 @@ class MiscTestCase(unittest.TestCase):
not_exported = {'check_builtin', 'AmbiguousOptionError', 'NO_DEFAULT'}
support.check__all__(self, optparse, not_exported=not_exported)
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("optparse", {"textwrap"})
+
class TestTranslations(TestTranslationsBase):
def test_translations(self):
diff --git a/Lib/test/test_ordered_dict.py b/Lib/test/test_ordered_dict.py
index 9f131a9110d..4204a6a47d2 100644
--- a/Lib/test/test_ordered_dict.py
+++ b/Lib/test/test_ordered_dict.py
@@ -147,7 +147,7 @@ class OrderedDictTests:
def test_abc(self):
OrderedDict = self.OrderedDict
self.assertIsInstance(OrderedDict(), MutableMapping)
- self.assertTrue(issubclass(OrderedDict, MutableMapping))
+ self.assertIsSubclass(OrderedDict, MutableMapping)
def test_clear(self):
OrderedDict = self.OrderedDict
@@ -314,14 +314,14 @@ class OrderedDictTests:
check(dup)
self.assertIs(dup.x, od.x)
self.assertIs(dup.z, od.z)
- self.assertFalse(hasattr(dup, 'y'))
+ self.assertNotHasAttr(dup, 'y')
dup = copy.deepcopy(od)
check(dup)
self.assertEqual(dup.x, od.x)
self.assertIsNot(dup.x, od.x)
self.assertEqual(dup.z, od.z)
self.assertIsNot(dup.z, od.z)
- self.assertFalse(hasattr(dup, 'y'))
+ self.assertNotHasAttr(dup, 'y')
# pickle directly pulls the module, so we have to fake it
with replaced_module('collections', self.module):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
@@ -330,7 +330,7 @@ class OrderedDictTests:
check(dup)
self.assertEqual(dup.x, od.x)
self.assertEqual(dup.z, od.z)
- self.assertFalse(hasattr(dup, 'y'))
+ self.assertNotHasAttr(dup, 'y')
check(eval(repr(od)))
update_test = OrderedDict()
update_test.update(od)
diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py
index 333179a71e3..5217037ae9d 100644
--- a/Lib/test/test_os.py
+++ b/Lib/test/test_os.py
@@ -13,7 +13,6 @@ import itertools
import locale
import os
import pickle
-import platform
import select
import selectors
import shutil
@@ -818,7 +817,7 @@ class StatAttributeTests(unittest.TestCase):
self.assertEqual(ctx.exception.errno, errno.EBADF)
def check_file_attributes(self, result):
- self.assertTrue(hasattr(result, 'st_file_attributes'))
+ self.assertHasAttr(result, 'st_file_attributes')
self.assertTrue(isinstance(result.st_file_attributes, int))
self.assertTrue(0 <= result.st_file_attributes <= 0xFFFFFFFF)
@@ -1919,6 +1918,10 @@ class MakedirTests(unittest.TestCase):
support.is_wasi,
"WASI's umask is a stub."
)
+ @unittest.skipIf(
+ support.is_emscripten,
+ "TODO: Fails in buildbot; see #135783"
+ )
def test_mode(self):
with os_helper.temp_umask(0o002):
base = os_helper.TESTFN
@@ -2181,7 +2184,7 @@ class GetRandomTests(unittest.TestCase):
self.assertEqual(empty, b'')
def test_getrandom_random(self):
- self.assertTrue(hasattr(os, 'GRND_RANDOM'))
+ self.assertHasAttr(os, 'GRND_RANDOM')
# Don't test os.getrandom(1, os.GRND_RANDOM) to not consume the rare
# resource /dev/random
@@ -4291,13 +4294,8 @@ class EventfdTests(unittest.TestCase):
@unittest.skipIf(sys.platform == "android", "gh-124873: Test is flaky on Android")
@support.requires_linux_version(2, 6, 30)
class TimerfdTests(unittest.TestCase):
- # 1 ms accuracy is reliably achievable on every platform except Android
- # emulators, where we allow 10 ms (gh-108277).
- if sys.platform == "android" and platform.android_ver().is_emulator:
- CLOCK_RES_PLACES = 2
- else:
- CLOCK_RES_PLACES = 3
-
+ # gh-126112: Use 10 ms to tolerate slow buildbots
+ CLOCK_RES_PLACES = 2 # 10 ms
CLOCK_RES = 10 ** -CLOCK_RES_PLACES
CLOCK_RES_NS = 10 ** (9 - CLOCK_RES_PLACES)
@@ -5431,8 +5429,8 @@ class TestPEP519(unittest.TestCase):
def test_pathlike(self):
self.assertEqual('#feelthegil', self.fspath(FakePath('#feelthegil')))
- self.assertTrue(issubclass(FakePath, os.PathLike))
- self.assertTrue(isinstance(FakePath('x'), os.PathLike))
+ self.assertIsSubclass(FakePath, os.PathLike)
+ self.assertIsInstance(FakePath('x'), os.PathLike)
def test_garbage_in_exception_out(self):
vapor = type('blah', (), {})
@@ -5458,8 +5456,8 @@ class TestPEP519(unittest.TestCase):
# true on abstract implementation.
class A(os.PathLike):
pass
- self.assertFalse(issubclass(FakePath, A))
- self.assertTrue(issubclass(FakePath, os.PathLike))
+ self.assertNotIsSubclass(FakePath, A)
+ self.assertIsSubclass(FakePath, os.PathLike)
def test_pathlike_class_getitem(self):
self.assertIsInstance(os.PathLike[bytes], types.GenericAlias)
@@ -5469,7 +5467,7 @@ class TestPEP519(unittest.TestCase):
__slots__ = ()
def __fspath__(self):
return ''
- self.assertFalse(hasattr(A(), '__dict__'))
+ self.assertNotHasAttr(A(), '__dict__')
def test_fspath_set_to_None(self):
class Foo:
diff --git a/Lib/test/test_pathlib/support/lexical_path.py b/Lib/test/test_pathlib/support/lexical_path.py
index f29a521af9b..fd7fbf283a6 100644
--- a/Lib/test/test_pathlib/support/lexical_path.py
+++ b/Lib/test/test_pathlib/support/lexical_path.py
@@ -9,9 +9,10 @@ import posixpath
from . import is_pypi
if is_pypi:
- from pathlib_abc import _JoinablePath
+ from pathlib_abc import vfspath, _JoinablePath
else:
from pathlib.types import _JoinablePath
+ from pathlib._os import vfspath
class LexicalPath(_JoinablePath):
@@ -22,20 +23,20 @@ class LexicalPath(_JoinablePath):
self._segments = pathsegments
def __hash__(self):
- return hash(str(self))
+ return hash(vfspath(self))
def __eq__(self, other):
if not isinstance(other, LexicalPath):
return NotImplemented
- return str(self) == str(other)
+ return vfspath(self) == vfspath(other)
- def __str__(self):
+ def __vfspath__(self):
if not self._segments:
return ''
return self.parser.join(*self._segments)
def __repr__(self):
- return f'{type(self).__name__}({str(self)!r})'
+ return f'{type(self).__name__}({vfspath(self)!r})'
def with_segments(self, *pathsegments):
return type(self)(*pathsegments)
diff --git a/Lib/test/test_pathlib/support/local_path.py b/Lib/test/test_pathlib/support/local_path.py
index d481fd45ead..c1423c545bf 100644
--- a/Lib/test/test_pathlib/support/local_path.py
+++ b/Lib/test/test_pathlib/support/local_path.py
@@ -97,7 +97,7 @@ class LocalPathInfo(PathInfo):
__slots__ = ('_path', '_exists', '_is_dir', '_is_file', '_is_symlink')
def __init__(self, path):
- self._path = str(path)
+ self._path = os.fspath(path)
self._exists = None
self._is_dir = None
self._is_file = None
@@ -139,14 +139,12 @@ class ReadableLocalPath(_ReadablePath, LexicalPath):
Simple implementation of a ReadablePath class for local filesystem paths.
"""
__slots__ = ('info',)
+ __fspath__ = LexicalPath.__vfspath__
def __init__(self, *pathsegments):
super().__init__(*pathsegments)
self.info = LocalPathInfo(self)
- def __fspath__(self):
- return str(self)
-
def __open_rb__(self, buffering=-1):
return open(self, 'rb')
@@ -163,9 +161,7 @@ class WritableLocalPath(_WritablePath, LexicalPath):
"""
__slots__ = ()
-
- def __fspath__(self):
- return str(self)
+ __fspath__ = LexicalPath.__vfspath__
def __open_wb__(self, buffering=-1):
return open(self, 'wb')
diff --git a/Lib/test/test_pathlib/support/zip_path.py b/Lib/test/test_pathlib/support/zip_path.py
index 2905260c9df..21e1d07423a 100644
--- a/Lib/test/test_pathlib/support/zip_path.py
+++ b/Lib/test/test_pathlib/support/zip_path.py
@@ -16,9 +16,10 @@ from stat import S_IFMT, S_ISDIR, S_ISREG, S_ISLNK
from . import is_pypi
if is_pypi:
- from pathlib_abc import PathInfo, _ReadablePath, _WritablePath
+ from pathlib_abc import vfspath, PathInfo, _ReadablePath, _WritablePath
else:
from pathlib.types import PathInfo, _ReadablePath, _WritablePath
+ from pathlib._os import vfspath
class ZipPathGround:
@@ -34,16 +35,16 @@ class ZipPathGround:
root.zip_file.close()
def create_file(self, path, data=b''):
- path.zip_file.writestr(str(path), data)
+ path.zip_file.writestr(vfspath(path), data)
def create_dir(self, path):
- zip_info = zipfile.ZipInfo(str(path) + '/')
+ zip_info = zipfile.ZipInfo(vfspath(path) + '/')
zip_info.external_attr |= stat.S_IFDIR << 16
zip_info.external_attr |= stat.FILE_ATTRIBUTE_DIRECTORY
path.zip_file.writestr(zip_info, '')
def create_symlink(self, path, target):
- zip_info = zipfile.ZipInfo(str(path))
+ zip_info = zipfile.ZipInfo(vfspath(path))
zip_info.external_attr = stat.S_IFLNK << 16
path.zip_file.writestr(zip_info, target.encode())
@@ -62,28 +63,28 @@ class ZipPathGround:
self.create_symlink(p.joinpath('brokenLinkLoop'), 'brokenLinkLoop')
def readtext(self, p):
- with p.zip_file.open(str(p), 'r') as f:
+ with p.zip_file.open(vfspath(p), 'r') as f:
f = io.TextIOWrapper(f, encoding='utf-8')
return f.read()
def readbytes(self, p):
- with p.zip_file.open(str(p), 'r') as f:
+ with p.zip_file.open(vfspath(p), 'r') as f:
return f.read()
readlink = readtext
def isdir(self, p):
- path_str = str(p) + "/"
+ path_str = vfspath(p) + "/"
return path_str in p.zip_file.NameToInfo
def isfile(self, p):
- info = p.zip_file.NameToInfo.get(str(p))
+ info = p.zip_file.NameToInfo.get(vfspath(p))
if info is None:
return False
return not stat.S_ISLNK(info.external_attr >> 16)
def islink(self, p):
- info = p.zip_file.NameToInfo.get(str(p))
+ info = p.zip_file.NameToInfo.get(vfspath(p))
if info is None:
return False
return stat.S_ISLNK(info.external_attr >> 16)
@@ -240,20 +241,20 @@ class ReadableZipPath(_ReadablePath):
zip_file.filelist = ZipFileList(zip_file)
def __hash__(self):
- return hash((str(self), self.zip_file))
+ return hash((vfspath(self), self.zip_file))
def __eq__(self, other):
if not isinstance(other, ReadableZipPath):
return NotImplemented
- return str(self) == str(other) and self.zip_file is other.zip_file
+ return vfspath(self) == vfspath(other) and self.zip_file is other.zip_file
- def __str__(self):
+ def __vfspath__(self):
if not self._segments:
return ''
return self.parser.join(*self._segments)
def __repr__(self):
- return f'{type(self).__name__}({str(self)!r}, zip_file={self.zip_file!r})'
+ return f'{type(self).__name__}({vfspath(self)!r}, zip_file={self.zip_file!r})'
def with_segments(self, *pathsegments):
return type(self)(*pathsegments, zip_file=self.zip_file)
@@ -261,7 +262,7 @@ class ReadableZipPath(_ReadablePath):
@property
def info(self):
tree = self.zip_file.filelist.tree
- return tree.resolve(str(self), follow_symlinks=False)
+ return tree.resolve(vfspath(self), follow_symlinks=False)
def __open_rb__(self, buffering=-1):
info = self.info.resolve()
@@ -301,36 +302,36 @@ class WritableZipPath(_WritablePath):
self.zip_file = zip_file
def __hash__(self):
- return hash((str(self), self.zip_file))
+ return hash((vfspath(self), self.zip_file))
def __eq__(self, other):
if not isinstance(other, WritableZipPath):
return NotImplemented
- return str(self) == str(other) and self.zip_file is other.zip_file
+ return vfspath(self) == vfspath(other) and self.zip_file is other.zip_file
- def __str__(self):
+ def __vfspath__(self):
if not self._segments:
return ''
return self.parser.join(*self._segments)
def __repr__(self):
- return f'{type(self).__name__}({str(self)!r}, zip_file={self.zip_file!r})'
+ return f'{type(self).__name__}({vfspath(self)!r}, zip_file={self.zip_file!r})'
def with_segments(self, *pathsegments):
return type(self)(*pathsegments, zip_file=self.zip_file)
def __open_wb__(self, buffering=-1):
- return self.zip_file.open(str(self), 'w')
+ return self.zip_file.open(vfspath(self), 'w')
def mkdir(self, mode=0o777):
- zinfo = zipfile.ZipInfo(str(self) + '/')
+ zinfo = zipfile.ZipInfo(vfspath(self) + '/')
zinfo.external_attr |= stat.S_IFDIR << 16
zinfo.external_attr |= stat.FILE_ATTRIBUTE_DIRECTORY
self.zip_file.writestr(zinfo, '')
def symlink_to(self, target, target_is_directory=False):
- zinfo = zipfile.ZipInfo(str(self))
+ zinfo = zipfile.ZipInfo(vfspath(self))
zinfo.external_attr = stat.S_IFLNK << 16
if target_is_directory:
zinfo.external_attr |= 0x10
- self.zip_file.writestr(zinfo, str(target))
+ self.zip_file.writestr(zinfo, vfspath(target))
diff --git a/Lib/test/test_pathlib/test_join_windows.py b/Lib/test/test_pathlib/test_join_windows.py
index 2cc634f25ef..f30c80605f7 100644
--- a/Lib/test/test_pathlib/test_join_windows.py
+++ b/Lib/test/test_pathlib/test_join_windows.py
@@ -8,6 +8,11 @@ import unittest
from .support import is_pypi
from .support.lexical_path import LexicalWindowsPath
+if is_pypi:
+ from pathlib_abc import vfspath
+else:
+ from pathlib._os import vfspath
+
class JoinTestBase:
def test_join(self):
@@ -70,17 +75,17 @@ class JoinTestBase:
self.assertEqual(p / './dd:s', P(r'C:/a/b\./dd:s'))
self.assertEqual(p / 'E:d:s', P('E:d:s'))
- def test_str(self):
+ def test_vfspath(self):
p = self.cls(r'a\b\c')
- self.assertEqual(str(p), 'a\\b\\c')
+ self.assertEqual(vfspath(p), 'a\\b\\c')
p = self.cls(r'c:\a\b\c')
- self.assertEqual(str(p), 'c:\\a\\b\\c')
+ self.assertEqual(vfspath(p), 'c:\\a\\b\\c')
p = self.cls('\\\\a\\b\\')
- self.assertEqual(str(p), '\\\\a\\b\\')
+ self.assertEqual(vfspath(p), '\\\\a\\b\\')
p = self.cls(r'\\a\b\c')
- self.assertEqual(str(p), '\\\\a\\b\\c')
+ self.assertEqual(vfspath(p), '\\\\a\\b\\c')
p = self.cls(r'\\a\b\c\d')
- self.assertEqual(str(p), '\\\\a\\b\\c\\d')
+ self.assertEqual(vfspath(p), '\\\\a\\b\\c\\d')
def test_parts(self):
P = self.cls
diff --git a/Lib/test/test_pathlib/test_pathlib.py b/Lib/test/test_pathlib/test_pathlib.py
index 41a79d0dceb..b2e2cdb3338 100644
--- a/Lib/test/test_pathlib/test_pathlib.py
+++ b/Lib/test/test_pathlib/test_pathlib.py
@@ -16,10 +16,11 @@ from unittest import mock
from urllib.request import pathname2url
from test.support import import_helper
+from test.support import cpython_only
from test.support import is_emscripten, is_wasi
from test.support import infinite_recursion
from test.support import os_helper
-from test.support.os_helper import TESTFN, FakePath
+from test.support.os_helper import TESTFN, FS_NONASCII, FakePath
try:
import fcntl
except ImportError:
@@ -76,8 +77,14 @@ def needs_symlinks(fn):
class UnsupportedOperationTest(unittest.TestCase):
def test_is_notimplemented(self):
- self.assertTrue(issubclass(pathlib.UnsupportedOperation, NotImplementedError))
- self.assertTrue(isinstance(pathlib.UnsupportedOperation(), NotImplementedError))
+ self.assertIsSubclass(pathlib.UnsupportedOperation, NotImplementedError)
+ self.assertIsInstance(pathlib.UnsupportedOperation(), NotImplementedError)
+
+
+class LazyImportTest(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ import_helper.ensure_lazy_imports("pathlib", {"shutil"})
#
@@ -293,8 +300,8 @@ class PurePathTest(unittest.TestCase):
clsname = p.__class__.__name__
r = repr(p)
# The repr() is in the form ClassName("forward-slashes path").
- self.assertTrue(r.startswith(clsname + '('), r)
- self.assertTrue(r.endswith(')'), r)
+ self.assertStartsWith(r, clsname + '(')
+ self.assertEndsWith(r, ')')
inner = r[len(clsname) + 1 : -1]
self.assertEqual(eval(inner), p.as_posix())
@@ -763,12 +770,16 @@ class PurePathTest(unittest.TestCase):
self.assertEqual(self.make_uri(P('c:/')), 'file:///c:/')
self.assertEqual(self.make_uri(P('c:/a/b.c')), 'file:///c:/a/b.c')
self.assertEqual(self.make_uri(P('c:/a/b%#c')), 'file:///c:/a/b%25%23c')
- self.assertEqual(self.make_uri(P('c:/a/b\xe9')), 'file:///c:/a/b%C3%A9')
self.assertEqual(self.make_uri(P('//some/share/')), 'file://some/share/')
self.assertEqual(self.make_uri(P('//some/share/a/b.c')),
'file://some/share/a/b.c')
- self.assertEqual(self.make_uri(P('//some/share/a/b%#c\xe9')),
- 'file://some/share/a/b%25%23c%C3%A9')
+
+ from urllib.parse import quote_from_bytes
+ QUOTED_FS_NONASCII = quote_from_bytes(os.fsencode(FS_NONASCII))
+ self.assertEqual(self.make_uri(P('c:/a/b' + FS_NONASCII)),
+ 'file:///c:/a/b' + QUOTED_FS_NONASCII)
+ self.assertEqual(self.make_uri(P('//some/share/a/b%#c' + FS_NONASCII)),
+ 'file://some/share/a/b%25%23c' + QUOTED_FS_NONASCII)
@needs_windows
def test_ordering_windows(self):
@@ -2943,7 +2954,13 @@ class PathTest(PurePathTest):
else:
# ".." segments are normalized first on Windows, so this path is stat()able.
self.assertEqual(set(p.glob("xyzzy/..")), { P(self.base, "xyzzy", "..") })
- self.assertEqual(set(p.glob("/".join([".."] * 50))), { P(self.base, *[".."] * 50)})
+ if sys.platform == "emscripten":
+ # Emscripten will return ELOOP if there are 49 or more ..'s.
+ # Can remove when https://github.com/emscripten-core/emscripten/pull/24591 is merged.
+ NDOTDOTS = 48
+ else:
+ NDOTDOTS = 50
+ self.assertEqual(set(p.glob("/".join([".."] * NDOTDOTS))), { P(self.base, *[".."] * NDOTDOTS)})
def test_glob_inaccessible(self):
P = self.cls
@@ -3290,7 +3307,6 @@ class PathTest(PurePathTest):
self.assertEqual(P.from_uri('file:////foo/bar'), P('//foo/bar'))
self.assertEqual(P.from_uri('file://localhost/foo/bar'), P('/foo/bar'))
if not is_wasi:
- self.assertEqual(P.from_uri('file://127.0.0.1/foo/bar'), P('/foo/bar'))
self.assertEqual(P.from_uri(f'file://{socket.gethostname()}/foo/bar'),
P('/foo/bar'))
self.assertRaises(ValueError, P.from_uri, 'foo/bar')
diff --git a/Lib/test/test_pdb.py b/Lib/test/test_pdb.py
index be365a5a3dd..6b74e21ad73 100644
--- a/Lib/test/test_pdb.py
+++ b/Lib/test/test_pdb.py
@@ -1,7 +1,9 @@
# A test suite for pdb; not very comprehensive at the moment.
+import _colorize
import doctest
import gc
+import io
import os
import pdb
import sys
@@ -18,7 +20,7 @@ from asyncio.events import _set_event_loop_policy
from contextlib import ExitStack, redirect_stdout
from io import StringIO
from test import support
-from test.support import force_not_colorized, has_socket_support, os_helper
+from test.support import has_socket_support, os_helper
from test.support.import_helper import import_module
from test.support.pty_helper import run_pty, FakeInput
from test.support.script_helper import kill_python
@@ -3446,6 +3448,7 @@ def test_pdb_issue_gh_65052():
"""
+@support.force_not_colorized_test_class
@support.requires_subprocess()
class PdbTestCase(unittest.TestCase):
def tearDown(self):
@@ -3740,7 +3743,6 @@ def bœr():
self.assertNotIn(b'Error', stdout,
"Got an error running test script under PDB")
- @force_not_colorized
def test_issue16180(self):
# A syntax error in the debuggee.
script = "def f: pass\n"
@@ -3754,7 +3756,6 @@ def bœr():
'Fail to handle a syntax error in the debuggee.'
.format(expected, stderr))
- @force_not_colorized
def test_issue84583(self):
# A syntax error from ast.literal_eval should not make pdb exit.
script = "import ast; ast.literal_eval('')\n"
@@ -4688,6 +4689,40 @@ class PdbTestInline(unittest.TestCase):
self.assertIn("42", stdout)
+@support.force_colorized_test_class
+class PdbTestColorize(unittest.TestCase):
+ def setUp(self):
+ self._original_can_colorize = _colorize.can_colorize
+ # Force colorize to be enabled because we are sending data
+ # to a StringIO
+ _colorize.can_colorize = lambda *args, **kwargs: True
+
+ def tearDown(self):
+ _colorize.can_colorize = self._original_can_colorize
+
+ def test_code_display(self):
+ output = io.StringIO()
+ p = pdb.Pdb(stdout=output, colorize=True)
+ p.set_trace(commands=['ll', 'c'])
+ self.assertIn("\x1b", output.getvalue())
+
+ output = io.StringIO()
+ p = pdb.Pdb(stdout=output, colorize=False)
+ p.set_trace(commands=['ll', 'c'])
+ self.assertNotIn("\x1b", output.getvalue())
+
+ output = io.StringIO()
+ p = pdb.Pdb(stdout=output)
+ p.set_trace(commands=['ll', 'c'])
+ self.assertNotIn("\x1b", output.getvalue())
+
+ def test_stack_entry(self):
+ output = io.StringIO()
+ p = pdb.Pdb(stdout=output, colorize=True)
+ p.set_trace(commands=['w', 'c'])
+ self.assertIn("\x1b", output.getvalue())
+
+
@support.force_not_colorized_test_class
@support.requires_subprocess()
class TestREPLSession(unittest.TestCase):
@@ -4711,9 +4746,12 @@ class TestREPLSession(unittest.TestCase):
self.assertEqual(p.returncode, 0)
+@support.force_not_colorized_test_class
@support.requires_subprocess()
class PdbTestReadline(unittest.TestCase):
- def setUpClass():
+
+ @classmethod
+ def setUpClass(cls):
# Ensure that the readline module is loaded
# If this fails, the test is skipped because SkipTest will be raised
readline = import_module('readline')
@@ -4812,14 +4850,37 @@ class PdbTestReadline(unittest.TestCase):
self.assertIn(b'I love Python', output)
+ @unittest.skipIf(sys.platform.startswith('freebsd'),
+ '\\x08 is not interpreted as backspace on FreeBSD')
+ def test_multiline_auto_indent(self):
+ script = textwrap.dedent("""
+ import pdb; pdb.Pdb().set_trace()
+ """)
+
+ input = b"def f(x):\n"
+ input += b"if x > 0:\n"
+ input += b"x += 1\n"
+ input += b"return x\n"
+ # We need to do backspaces to remove the auto-indentation
+ input += b"\x08\x08\x08\x08else:\n"
+ input += b"return -x\n"
+ input += b"\n"
+ input += b"f(-21-21)\n"
+ input += b"c\n"
+
+ output = run_pty(script, input)
+
+ self.assertIn(b'42', output)
+
def test_multiline_completion(self):
script = textwrap.dedent("""
import pdb; pdb.Pdb().set_trace()
""")
input = b"def func():\n"
- # Complete: \treturn 40 + 2
- input += b"\tret\t 40 + 2\n"
+ # Auto-indent
+ # Complete: return 40 + 2
+ input += b"ret\t 40 + 2\n"
input += b"\n"
# Complete: func()
input += b"fun\t()\n"
@@ -4829,6 +4890,8 @@ class PdbTestReadline(unittest.TestCase):
self.assertIn(b'42', output)
+ @unittest.skipIf(sys.platform.startswith('freebsd'),
+ '\\x08 is not interpreted as backspace on FreeBSD')
def test_multiline_indent_completion(self):
script = textwrap.dedent("""
import pdb; pdb.Pdb().set_trace()
@@ -4839,12 +4902,13 @@ class PdbTestReadline(unittest.TestCase):
# if the completion is not working as expected
input = textwrap.dedent("""\
def func():
- \ta = 1
- \ta += 1
- \ta += 1
- \tif a > 0:
- a += 1
- \t\treturn a
+ a = 1
+ \x08\ta += 1
+ \x08\x08\ta += 1
+ \x08\x08\x08\ta += 1
+ \x08\x08\x08\x08\tif a > 0:
+ a += 1
+ \x08\x08\x08\x08return a
func()
c
@@ -4852,7 +4916,7 @@ class PdbTestReadline(unittest.TestCase):
output = run_pty(script, input)
- self.assertIn(b'4', output)
+ self.assertIn(b'5', output)
self.assertNotIn(b'Error', output)
def test_interact_completion(self):
diff --git a/Lib/test/test_peepholer.py b/Lib/test/test_peepholer.py
index 565e42b04a6..3d7300e1480 100644
--- a/Lib/test/test_peepholer.py
+++ b/Lib/test/test_peepholer.py
@@ -1,4 +1,5 @@
import dis
+import gc
from itertools import combinations, product
import opcode
import sys
@@ -11,7 +12,7 @@ except ImportError:
from test import support
from test.support.bytecode_helper import (
- BytecodeTestCase, CfgOptimizationTestCase, CompilationStepTestCase)
+ BytecodeTestCase, CfgOptimizationTestCase)
def compile_pattern_with_fast_locals(pattern):
@@ -315,7 +316,7 @@ class TestTranforms(BytecodeTestCase):
return -(1.0-1.0)
for instr in dis.get_instructions(negzero):
- self.assertFalse(instr.opname.startswith('UNARY_'))
+ self.assertNotStartsWith(instr.opname, 'UNARY_')
self.check_lnotab(negzero)
def test_constant_folding_binop(self):
@@ -717,9 +718,9 @@ class TestTranforms(BytecodeTestCase):
self.assertEqual(format('x = %d!', 1234), 'x = 1234!')
self.assertEqual(format('x = %x!', 1234), 'x = 4d2!')
self.assertEqual(format('x = %f!', 1234), 'x = 1234.000000!')
- self.assertEqual(format('x = %s!', 1234.5678901), 'x = 1234.5678901!')
- self.assertEqual(format('x = %f!', 1234.5678901), 'x = 1234.567890!')
- self.assertEqual(format('x = %d!', 1234.5678901), 'x = 1234!')
+ self.assertEqual(format('x = %s!', 1234.0000625), 'x = 1234.0000625!')
+ self.assertEqual(format('x = %f!', 1234.0000625), 'x = 1234.000063!')
+ self.assertEqual(format('x = %d!', 1234.0000625), 'x = 1234!')
self.assertEqual(format('x = %s%% %%%%', 1234), 'x = 1234% %%')
self.assertEqual(format('x = %s!', '%% %s'), 'x = %% %s!')
self.assertEqual(format('x = %s, y = %d', 12, 34), 'x = 12, y = 34')
@@ -2472,6 +2473,13 @@ class OptimizeLoadFastTestCase(DirectCfgOptimizerTests):
]
self.check(insts, insts)
+ insts = [
+ ("LOAD_FAST", 0, 1),
+ ("DELETE_FAST", 0, 2),
+ ("POP_TOP", None, 3),
+ ]
+ self.check(insts, insts)
+
def test_unoptimized_if_aliased(self):
insts = [
("LOAD_FAST", 0, 1),
@@ -2606,6 +2614,114 @@ class OptimizeLoadFastTestCase(DirectCfgOptimizerTests):
]
self.cfg_optimization_test(insts, expected, consts=[None])
+ def test_format_simple(self):
+ # FORMAT_SIMPLE will leave its operand on the stack if it's a unicode
+ # object. We treat it conservatively and assume that it always leaves
+ # its operand on the stack.
+ insts = [
+ ("LOAD_FAST", 0, 1),
+ ("FORMAT_SIMPLE", None, 2),
+ ("STORE_FAST", 1, 3),
+ ]
+ self.check(insts, insts)
+
+ insts = [
+ ("LOAD_FAST", 0, 1),
+ ("FORMAT_SIMPLE", None, 2),
+ ("POP_TOP", None, 3),
+ ]
+ expected = [
+ ("LOAD_FAST_BORROW", 0, 1),
+ ("FORMAT_SIMPLE", None, 2),
+ ("POP_TOP", None, 3),
+ ]
+ self.check(insts, expected)
+
+ def test_set_function_attribute(self):
+ # SET_FUNCTION_ATTRIBUTE leaves the function on the stack
+ insts = [
+ ("LOAD_CONST", 0, 1),
+ ("LOAD_FAST", 0, 2),
+ ("SET_FUNCTION_ATTRIBUTE", 2, 3),
+ ("STORE_FAST", 1, 4),
+ ("LOAD_CONST", 0, 5),
+ ("RETURN_VALUE", None, 6)
+ ]
+ self.cfg_optimization_test(insts, insts, consts=[None])
+
+ insts = [
+ ("LOAD_CONST", 0, 1),
+ ("LOAD_FAST", 0, 2),
+ ("SET_FUNCTION_ATTRIBUTE", 2, 3),
+ ("RETURN_VALUE", None, 4)
+ ]
+ expected = [
+ ("LOAD_CONST", 0, 1),
+ ("LOAD_FAST_BORROW", 0, 2),
+ ("SET_FUNCTION_ATTRIBUTE", 2, 3),
+ ("RETURN_VALUE", None, 4)
+ ]
+ self.cfg_optimization_test(insts, expected, consts=[None])
+
+ def test_get_yield_from_iter(self):
+ # GET_YIELD_FROM_ITER may leave its operand on the stack
+ insts = [
+ ("LOAD_FAST", 0, 1),
+ ("GET_YIELD_FROM_ITER", None, 2),
+ ("LOAD_CONST", 0, 3),
+ send := self.Label(),
+ ("SEND", end := self.Label(), 5),
+ ("YIELD_VALUE", 1, 6),
+ ("RESUME", 2, 7),
+ ("JUMP", send, 8),
+ end,
+ ("END_SEND", None, 9),
+ ("LOAD_CONST", 0, 10),
+ ("RETURN_VALUE", None, 11),
+ ]
+ self.cfg_optimization_test(insts, insts, consts=[None])
+
+ def test_push_exc_info(self):
+ insts = [
+ ("LOAD_FAST", 0, 1),
+ ("PUSH_EXC_INFO", None, 2),
+ ]
+ self.check(insts, insts)
+
+ def test_load_special(self):
+ # LOAD_SPECIAL may leave self on the stack
+ insts = [
+ ("LOAD_FAST", 0, 1),
+ ("LOAD_SPECIAL", 0, 2),
+ ("STORE_FAST", 1, 3),
+ ]
+ self.check(insts, insts)
+
+
+ def test_del_in_finally(self):
+ # This loads `obj` onto the stack, executes `del obj`, then returns the
+ # `obj` from the stack. See gh-133371 for more details.
+ def create_obj():
+ obj = [42]
+ try:
+ return obj
+ finally:
+ del obj
+
+ obj = create_obj()
+ # The crash in the linked issue happens while running GC during
+ # interpreter finalization, so run it here manually.
+ gc.collect()
+ self.assertEqual(obj, [42])
+
+ def test_format_simple_unicode(self):
+ # Repro from gh-134889
+ def f():
+ var = f"{1}"
+ var = f"{var}"
+ return var
+ self.assertEqual(f(), "1")
+
if __name__ == "__main__":
diff --git a/Lib/test/test_peg_generator/test_c_parser.py b/Lib/test/test_peg_generator/test_c_parser.py
index 1095e7303c1..aa01a9b8f7e 100644
--- a/Lib/test/test_peg_generator/test_c_parser.py
+++ b/Lib/test/test_peg_generator/test_c_parser.py
@@ -387,10 +387,10 @@ class TestCParser(unittest.TestCase):
test_source = """
stmt = "with (\\n a as b,\\n c as d\\n): pass"
the_ast = parse.parse_string(stmt, mode=1)
- self.assertTrue(ast_dump(the_ast).startswith(
+ self.assertStartsWith(ast_dump(the_ast),
"Module(body=[With(items=[withitem(context_expr=Name(id='a', ctx=Load()), optional_vars=Name(id='b', ctx=Store())), "
"withitem(context_expr=Name(id='c', ctx=Load()), optional_vars=Name(id='d', ctx=Store()))]"
- ))
+ )
"""
self.run_test(grammar_source, test_source)
diff --git a/Lib/test/test_peg_generator/test_pegen.py b/Lib/test/test_peg_generator/test_pegen.py
index d8606521345..d912c558123 100644
--- a/Lib/test/test_peg_generator/test_pegen.py
+++ b/Lib/test/test_peg_generator/test_pegen.py
@@ -91,10 +91,8 @@ class TestPegen(unittest.TestCase):
"""
rules = parse_string(grammar, GrammarParser).rules
self.assertEqual(str(rules["start"]), "start: ','.thing+ NEWLINE")
- self.assertTrue(
- repr(rules["start"]).startswith(
- "Rule('start', None, Rhs([Alt([NamedItem(None, Gather(StringLeaf(\"','\"), NameLeaf('thing'"
- )
+ self.assertStartsWith(repr(rules["start"]),
+ "Rule('start', None, Rhs([Alt([NamedItem(None, Gather(StringLeaf(\"','\"), NameLeaf('thing'"
)
self.assertEqual(str(rules["thing"]), "thing: NUMBER")
parser_class = make_parser(grammar)
diff --git a/Lib/test/test_perf_profiler.py b/Lib/test/test_perf_profiler.py
index c176e505155..7529c853f9c 100644
--- a/Lib/test/test_perf_profiler.py
+++ b/Lib/test/test_perf_profiler.py
@@ -93,9 +93,7 @@ class TestPerfTrampoline(unittest.TestCase):
perf_line, f"Could not find {expected_symbol} in perf file"
)
perf_addr = perf_line.split(" ")[0]
- self.assertFalse(
- perf_addr.startswith("0x"), "Address should not be prefixed with 0x"
- )
+ self.assertNotStartsWith(perf_addr, "0x")
self.assertTrue(
set(perf_addr).issubset(string.hexdigits),
"Address should contain only hex characters",
@@ -508,9 +506,12 @@ def _is_perf_version_at_least(major, minor):
# The output of perf --version looks like "perf version 6.7-3" but
# it can also be perf version "perf version 5.15.143", or even include
# a commit hash in the version string, like "6.12.9.g242e6068fd5c"
+ #
+ # PermissionError is raised if perf does not exist on the Windows Subsystem
+ # for Linux, see #134987
try:
output = subprocess.check_output(["perf", "--version"], text=True)
- except (subprocess.CalledProcessError, FileNotFoundError):
+ except (subprocess.CalledProcessError, FileNotFoundError, PermissionError):
return False
version = output.split()[2]
version = version.split("-")[0]
diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py
index 296d4b882e1..e2384b33345 100644
--- a/Lib/test/test_pickle.py
+++ b/Lib/test/test_pickle.py
@@ -15,7 +15,8 @@ from textwrap import dedent
import doctest
import unittest
from test import support
-from test.support import import_helper, os_helper
+from test.support import cpython_only, import_helper, os_helper
+from test.support.import_helper import ensure_lazy_imports
from test.pickletester import AbstractHookTests
from test.pickletester import AbstractUnpickleTests
@@ -36,6 +37,12 @@ except ImportError:
has_c_implementation = False
+class LazyImportTest(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("pickle", {"re"})
+
+
class PyPickleTests(AbstractPickleModuleTests, unittest.TestCase):
dump = staticmethod(pickle._dump)
dumps = staticmethod(pickle._dumps)
@@ -604,10 +611,10 @@ class CompatPickleTests(unittest.TestCase):
with self.subTest(((module3, name3), (module2, name2))):
if (module2, name2) == ('exceptions', 'OSError'):
attr = getattribute(module3, name3)
- self.assertTrue(issubclass(attr, OSError))
+ self.assertIsSubclass(attr, OSError)
elif (module2, name2) == ('exceptions', 'ImportError'):
attr = getattribute(module3, name3)
- self.assertTrue(issubclass(attr, ImportError))
+ self.assertIsSubclass(attr, ImportError)
else:
module, name = mapping(module2, name2)
if module3[:1] != '_':
@@ -745,6 +752,7 @@ class CommandLineTest(unittest.TestCase):
expect = self.text_normalize(expect)
self.assertListEqual(res.splitlines(), expect.splitlines())
+ @support.force_not_colorized
def test_unknown_flag(self):
stderr = io.StringIO()
with self.assertRaises(SystemExit):
diff --git a/Lib/test/test_platform.py b/Lib/test/test_platform.py
index 6ba630ad527..479649053ab 100644
--- a/Lib/test/test_platform.py
+++ b/Lib/test/test_platform.py
@@ -1,5 +1,8 @@
-import os
+import contextlib
import copy
+import io
+import itertools
+import os
import pickle
import platform
import subprocess
@@ -130,6 +133,22 @@ class PlatformTest(unittest.TestCase):
for terse in (False, True):
res = platform.platform(aliased, terse)
+ def test__platform(self):
+ for src, res in [
+ ('foo bar', 'foo_bar'),
+ (
+ '1/2\\3:4;5"6(7)8(7)6"5;4:3\\2/1',
+ '1-2-3-4-5-6-7-8-7-6-5-4-3-2-1'
+ ),
+ ('--', ''),
+ ('-f', '-f'),
+ ('-foo----', '-foo'),
+ ('--foo---', '-foo'),
+ ('---foo--', '-foo'),
+ ]:
+ with self.subTest(src=src):
+ self.assertEqual(platform._platform(src), res)
+
def test_system(self):
res = platform.system()
@@ -380,15 +399,6 @@ class PlatformTest(unittest.TestCase):
finally:
platform._uname_cache = None
- def test_java_ver(self):
- import re
- msg = re.escape(
- "'java_ver' is deprecated and slated for removal in Python 3.15"
- )
- with self.assertWarnsRegex(DeprecationWarning, msg):
- res = platform.java_ver()
- self.assertEqual(len(res), 4)
-
@unittest.skipUnless(support.MS_WINDOWS, 'This test only makes sense on Windows')
def test_win32_ver(self):
release1, version1, csd1, ptype1 = 'a', 'b', 'c', 'd'
@@ -407,7 +417,7 @@ class PlatformTest(unittest.TestCase):
for v in version.split('.'):
int(v) # should not fail
if csd:
- self.assertTrue(csd.startswith('SP'), msg=csd)
+ self.assertStartsWith(csd, 'SP')
if ptype:
if os.cpu_count() > 1:
self.assertIn('Multiprocessor', ptype)
@@ -741,5 +751,66 @@ class PlatformTest(unittest.TestCase):
self.assertEqual(len(info["SPECIALS"]), 5)
+class CommandLineTest(unittest.TestCase):
+ def setUp(self):
+ platform.invalidate_caches()
+ self.addCleanup(platform.invalidate_caches)
+
+ def invoke_platform(self, *flags):
+ output = io.StringIO()
+ with contextlib.redirect_stdout(output):
+ platform._main(args=flags)
+ return output.getvalue()
+
+    def test_unknown_flag(self):
+        output = io.StringIO()
+        # suppress argparse error message
+        with contextlib.redirect_stderr(output):
+            with self.assertRaises(SystemExit):
+                _ = self.invoke_platform('--unknown')
+        self.assertStartsWith(output.getvalue(), "usage: ")
+
+ def test_invocation(self):
+ flags = (
+ "--terse", "--nonaliased", "terse", "nonaliased"
+ )
+
+ for r in range(len(flags) + 1):
+ for combination in itertools.combinations(flags, r):
+ self.invoke_platform(*combination)
+
+ def test_arg_parsing(self):
+ # For backwards compatibility, the `aliased` and `terse` parameters are
+ # computed based on a combination of positional arguments and flags.
+ #
+ # Test that the arguments are correctly passed to the underlying
+ # `platform.platform()` call.
+ options = (
+ (["--nonaliased"], False, False),
+ (["nonaliased"], False, False),
+ (["--terse"], True, True),
+ (["terse"], True, True),
+ (["nonaliased", "terse"], False, True),
+ (["--nonaliased", "terse"], False, True),
+ (["--terse", "nonaliased"], False, True),
+ )
+
+ for flags, aliased, terse in options:
+ with self.subTest(flags=flags, aliased=aliased, terse=terse):
+ with mock.patch.object(platform, 'platform') as obj:
+ self.invoke_platform(*flags)
+ obj.assert_called_once_with(aliased, terse)
+
+ @support.force_not_colorized
+ def test_help(self):
+ output = io.StringIO()
+
+ with self.assertRaises(SystemExit):
+ with contextlib.redirect_stdout(output):
+ platform._main(args=["--help"])
+
+ self.assertStartsWith(output.getvalue(), "usage:")
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_positional_only_arg.py b/Lib/test/test_positional_only_arg.py
index eea0625012d..e412cb1d58d 100644
--- a/Lib/test/test_positional_only_arg.py
+++ b/Lib/test/test_positional_only_arg.py
@@ -37,8 +37,8 @@ class PositionalOnlyTestCase(unittest.TestCase):
check_syntax_error(self, "def f(/): pass")
check_syntax_error(self, "def f(*, a, /): pass")
check_syntax_error(self, "def f(*, /, a): pass")
- check_syntax_error(self, "def f(a, /, a): pass", "duplicate argument 'a' in function definition")
- check_syntax_error(self, "def f(a, /, *, a): pass", "duplicate argument 'a' in function definition")
+ check_syntax_error(self, "def f(a, /, a): pass", "duplicate parameter 'a' in function definition")
+ check_syntax_error(self, "def f(a, /, *, a): pass", "duplicate parameter 'a' in function definition")
check_syntax_error(self, "def f(a, b/2, c): pass")
check_syntax_error(self, "def f(a, /, c, /): pass")
check_syntax_error(self, "def f(a, /, c, /, d): pass")
@@ -59,8 +59,8 @@ class PositionalOnlyTestCase(unittest.TestCase):
check_syntax_error(self, "async def f(/): pass")
check_syntax_error(self, "async def f(*, a, /): pass")
check_syntax_error(self, "async def f(*, /, a): pass")
- check_syntax_error(self, "async def f(a, /, a): pass", "duplicate argument 'a' in function definition")
- check_syntax_error(self, "async def f(a, /, *, a): pass", "duplicate argument 'a' in function definition")
+ check_syntax_error(self, "async def f(a, /, a): pass", "duplicate parameter 'a' in function definition")
+ check_syntax_error(self, "async def f(a, /, *, a): pass", "duplicate parameter 'a' in function definition")
check_syntax_error(self, "async def f(a, b/2, c): pass")
check_syntax_error(self, "async def f(a, /, c, /): pass")
check_syntax_error(self, "async def f(a, /, c, /, d): pass")
@@ -247,8 +247,8 @@ class PositionalOnlyTestCase(unittest.TestCase):
check_syntax_error(self, "lambda /: None")
check_syntax_error(self, "lambda *, a, /: None")
check_syntax_error(self, "lambda *, /, a: None")
- check_syntax_error(self, "lambda a, /, a: None", "duplicate argument 'a' in function definition")
- check_syntax_error(self, "lambda a, /, *, a: None", "duplicate argument 'a' in function definition")
+ check_syntax_error(self, "lambda a, /, a: None", "duplicate parameter 'a' in function definition")
+ check_syntax_error(self, "lambda a, /, *, a: None", "duplicate parameter 'a' in function definition")
check_syntax_error(self, "lambda a, /, b, /: None")
check_syntax_error(self, "lambda a, /, b, /, c: None")
check_syntax_error(self, "lambda a, /, b, /, c, *, d: None")
diff --git a/Lib/test/test_posix.py b/Lib/test/test_posix.py
index c9cbe1541e7..628920e34b5 100644
--- a/Lib/test/test_posix.py
+++ b/Lib/test/test_posix.py
@@ -1107,7 +1107,7 @@ class PosixTester(unittest.TestCase):
def _test_chflags_regular_file(self, chflags_func, target_file, **kwargs):
st = os.stat(target_file)
- self.assertTrue(hasattr(st, 'st_flags'))
+ self.assertHasAttr(st, 'st_flags')
# ZFS returns EOPNOTSUPP when attempting to set flag UF_IMMUTABLE.
flags = st.st_flags | stat.UF_IMMUTABLE
@@ -1143,7 +1143,7 @@ class PosixTester(unittest.TestCase):
def test_lchflags_symlink(self):
testfn_st = os.stat(os_helper.TESTFN)
- self.assertTrue(hasattr(testfn_st, 'st_flags'))
+ self.assertHasAttr(testfn_st, 'st_flags')
self.addCleanup(os_helper.unlink, _DUMMY_SYMLINK)
os.symlink(os_helper.TESTFN, _DUMMY_SYMLINK)
@@ -1521,6 +1521,51 @@ class PosixTester(unittest.TestCase):
self.assertEqual(cm.exception.errno, errno.EINVAL)
os.close(os.pidfd_open(os.getpid(), 0))
+ @os_helper.skip_unless_hardlink
+ @os_helper.skip_unless_symlink
+ def test_link_follow_symlinks(self):
+ default_follow = sys.platform.startswith(
+ ('darwin', 'freebsd', 'netbsd', 'openbsd', 'dragonfly', 'sunos5'))
+ default_no_follow = sys.platform.startswith(('win32', 'linux'))
+ orig = os_helper.TESTFN
+ symlink = orig + 'symlink'
+ posix.symlink(orig, symlink)
+ self.addCleanup(os_helper.unlink, symlink)
+
+ with self.subTest('no follow_symlinks'):
+            # no follow_symlinks -> platform dependent
+ link = orig + 'link'
+ posix.link(symlink, link)
+ self.addCleanup(os_helper.unlink, link)
+ if os.link in os.supports_follow_symlinks or default_follow:
+ self.assertEqual(posix.lstat(link), posix.lstat(orig))
+ elif default_no_follow:
+ self.assertEqual(posix.lstat(link), posix.lstat(symlink))
+
+ with self.subTest('follow_symlinks=False'):
+ # follow_symlinks=False -> duplicate the symlink itself
+ link = orig + 'link_nofollow'
+ try:
+ posix.link(symlink, link, follow_symlinks=False)
+ except NotImplementedError:
+ if os.link in os.supports_follow_symlinks or default_no_follow:
+ raise
+ else:
+ self.addCleanup(os_helper.unlink, link)
+ self.assertEqual(posix.lstat(link), posix.lstat(symlink))
+
+ with self.subTest('follow_symlinks=True'):
+ # follow_symlinks=True -> duplicate the target file
+ link = orig + 'link_following'
+ try:
+ posix.link(symlink, link, follow_symlinks=True)
+ except NotImplementedError:
+ if os.link in os.supports_follow_symlinks or default_follow:
+ raise
+ else:
+ self.addCleanup(os_helper.unlink, link)
+ self.assertEqual(posix.lstat(link), posix.lstat(orig))
+
# tests for the posix *at functions follow
class TestPosixDirFd(unittest.TestCase):
@@ -2173,12 +2218,12 @@ class TestPosixWeaklinking(unittest.TestCase):
def test_pwritev(self):
self._verify_available("HAVE_PWRITEV")
if self.mac_ver >= (10, 16):
- self.assertTrue(hasattr(os, "pwritev"), "os.pwritev is not available")
- self.assertTrue(hasattr(os, "preadv"), "os.readv is not available")
+ self.assertHasAttr(os, "pwritev")
+ self.assertHasAttr(os, "preadv")
else:
- self.assertFalse(hasattr(os, "pwritev"), "os.pwritev is available")
- self.assertFalse(hasattr(os, "preadv"), "os.readv is available")
+ self.assertNotHasAttr(os, "pwritev")
+ self.assertNotHasAttr(os, "preadv")
def test_stat(self):
self._verify_available("HAVE_FSTATAT")
diff --git a/Lib/test/test_posixpath.py b/Lib/test/test_posixpath.py
index fa19d549c26..21f06712548 100644
--- a/Lib/test/test_posixpath.py
+++ b/Lib/test/test_posixpath.py
@@ -4,12 +4,13 @@ import posixpath
import random
import sys
import unittest
-from posixpath import realpath, abspath, dirname, basename
+from functools import partial
+from posixpath import realpath, abspath, dirname, basename, ALLOW_MISSING
from test import support
from test import test_genericpath
from test.support import import_helper
from test.support import os_helper
-from test.support.os_helper import FakePath
+from test.support.os_helper import FakePath, TESTFN
from unittest import mock
try:
@@ -21,7 +22,7 @@ except ImportError:
# An absolute path to a temporary filename for testing. We can't rely on TESTFN
# being an absolute path, so we need this.
-ABSTFN = abspath(os_helper.TESTFN)
+ABSTFN = abspath(TESTFN)
def skip_if_ABSTFN_contains_backslash(test):
"""
@@ -33,21 +34,16 @@ def skip_if_ABSTFN_contains_backslash(test):
msg = "ABSTFN is not a posix path - tests fail"
return [test, unittest.skip(msg)(test)][found_backslash]
-def safe_rmdir(dirname):
- try:
- os.rmdir(dirname)
- except OSError:
- pass
+
+def _parameterize(*parameters):
+ return support.subTests('kwargs', parameters)
+
class PosixPathTest(unittest.TestCase):
def setUp(self):
- self.tearDown()
-
- def tearDown(self):
for suffix in ["", "1", "2"]:
- os_helper.unlink(os_helper.TESTFN + suffix)
- safe_rmdir(os_helper.TESTFN + suffix)
+ self.assertFalse(posixpath.lexists(ABSTFN + suffix))
def test_join(self):
fn = posixpath.join
@@ -194,25 +190,28 @@ class PosixPathTest(unittest.TestCase):
self.assertEqual(posixpath.dirname(b"//foo//bar"), b"//foo")
def test_islink(self):
- self.assertIs(posixpath.islink(os_helper.TESTFN + "1"), False)
- self.assertIs(posixpath.lexists(os_helper.TESTFN + "2"), False)
+ self.assertIs(posixpath.islink(TESTFN + "1"), False)
+ self.assertIs(posixpath.lexists(TESTFN + "2"), False)
- with open(os_helper.TESTFN + "1", "wb") as f:
+ self.addCleanup(os_helper.unlink, TESTFN + "1")
+ with open(TESTFN + "1", "wb") as f:
f.write(b"foo")
- self.assertIs(posixpath.islink(os_helper.TESTFN + "1"), False)
+ self.assertIs(posixpath.islink(TESTFN + "1"), False)
if os_helper.can_symlink():
- os.symlink(os_helper.TESTFN + "1", os_helper.TESTFN + "2")
- self.assertIs(posixpath.islink(os_helper.TESTFN + "2"), True)
- os.remove(os_helper.TESTFN + "1")
- self.assertIs(posixpath.islink(os_helper.TESTFN + "2"), True)
- self.assertIs(posixpath.exists(os_helper.TESTFN + "2"), False)
- self.assertIs(posixpath.lexists(os_helper.TESTFN + "2"), True)
-
- self.assertIs(posixpath.islink(os_helper.TESTFN + "\udfff"), False)
- self.assertIs(posixpath.islink(os.fsencode(os_helper.TESTFN) + b"\xff"), False)
- self.assertIs(posixpath.islink(os_helper.TESTFN + "\x00"), False)
- self.assertIs(posixpath.islink(os.fsencode(os_helper.TESTFN) + b"\x00"), False)
+ self.addCleanup(os_helper.unlink, TESTFN + "2")
+ os.symlink(TESTFN + "1", TESTFN + "2")
+ self.assertIs(posixpath.islink(TESTFN + "2"), True)
+ os.remove(TESTFN + "1")
+ self.assertIs(posixpath.islink(TESTFN + "2"), True)
+ self.assertIs(posixpath.exists(TESTFN + "2"), False)
+ self.assertIs(posixpath.lexists(TESTFN + "2"), True)
+
+ def test_islink_invalid_paths(self):
+ self.assertIs(posixpath.islink(TESTFN + "\udfff"), False)
+ self.assertIs(posixpath.islink(os.fsencode(TESTFN) + b"\xff"), False)
+ self.assertIs(posixpath.islink(TESTFN + "\x00"), False)
+ self.assertIs(posixpath.islink(os.fsencode(TESTFN) + b"\x00"), False)
def test_ismount(self):
self.assertIs(posixpath.ismount("/"), True)
@@ -227,8 +226,9 @@ class PosixPathTest(unittest.TestCase):
os.mkdir(ABSTFN)
self.assertIs(posixpath.ismount(ABSTFN), False)
finally:
- safe_rmdir(ABSTFN)
+ os_helper.rmdir(ABSTFN)
+ def test_ismount_invalid_paths(self):
self.assertIs(posixpath.ismount('/\udfff'), False)
self.assertIs(posixpath.ismount(b'/\xff'), False)
self.assertIs(posixpath.ismount('/\x00'), False)
@@ -241,7 +241,7 @@ class PosixPathTest(unittest.TestCase):
os.symlink("/", ABSTFN)
self.assertIs(posixpath.ismount(ABSTFN), False)
finally:
- os.unlink(ABSTFN)
+ os_helper.unlink(ABSTFN)
@unittest.skipIf(posix is None, "Test requires posix module")
def test_ismount_different_device(self):
@@ -448,32 +448,35 @@ class PosixPathTest(unittest.TestCase):
self.assertEqual(result, expected)
@skip_if_ABSTFN_contains_backslash
- def test_realpath_curdir(self):
- self.assertEqual(realpath('.'), os.getcwd())
- self.assertEqual(realpath('./.'), os.getcwd())
- self.assertEqual(realpath('/'.join(['.'] * 100)), os.getcwd())
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_curdir(self, kwargs):
+ self.assertEqual(realpath('.', **kwargs), os.getcwd())
+ self.assertEqual(realpath('./.', **kwargs), os.getcwd())
+ self.assertEqual(realpath('/'.join(['.'] * 100), **kwargs), os.getcwd())
- self.assertEqual(realpath(b'.'), os.getcwdb())
- self.assertEqual(realpath(b'./.'), os.getcwdb())
- self.assertEqual(realpath(b'/'.join([b'.'] * 100)), os.getcwdb())
+ self.assertEqual(realpath(b'.', **kwargs), os.getcwdb())
+ self.assertEqual(realpath(b'./.', **kwargs), os.getcwdb())
+ self.assertEqual(realpath(b'/'.join([b'.'] * 100), **kwargs), os.getcwdb())
@skip_if_ABSTFN_contains_backslash
- def test_realpath_pardir(self):
- self.assertEqual(realpath('..'), dirname(os.getcwd()))
- self.assertEqual(realpath('../..'), dirname(dirname(os.getcwd())))
- self.assertEqual(realpath('/'.join(['..'] * 100)), '/')
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_pardir(self, kwargs):
+ self.assertEqual(realpath('..', **kwargs), dirname(os.getcwd()))
+ self.assertEqual(realpath('../..', **kwargs), dirname(dirname(os.getcwd())))
+ self.assertEqual(realpath('/'.join(['..'] * 100), **kwargs), '/')
- self.assertEqual(realpath(b'..'), dirname(os.getcwdb()))
- self.assertEqual(realpath(b'../..'), dirname(dirname(os.getcwdb())))
- self.assertEqual(realpath(b'/'.join([b'..'] * 100)), b'/')
+ self.assertEqual(realpath(b'..', **kwargs), dirname(os.getcwdb()))
+ self.assertEqual(realpath(b'../..', **kwargs), dirname(dirname(os.getcwdb())))
+ self.assertEqual(realpath(b'/'.join([b'..'] * 100), **kwargs), b'/')
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_basic(self):
+ @_parameterize({}, {'strict': ALLOW_MISSING})
+ def test_realpath_basic(self, kwargs):
# Basic operation.
try:
os.symlink(ABSTFN+"1", ABSTFN)
- self.assertEqual(realpath(ABSTFN), ABSTFN+"1")
+ self.assertEqual(realpath(ABSTFN, **kwargs), ABSTFN+"1")
finally:
os_helper.unlink(ABSTFN)
@@ -489,23 +492,121 @@ class PosixPathTest(unittest.TestCase):
finally:
os_helper.unlink(ABSTFN)
+ def test_realpath_invalid_paths(self):
+ path = '/\x00'
+ self.assertRaises(ValueError, realpath, path, strict=False)
+ self.assertRaises(ValueError, realpath, path, strict=True)
+ self.assertRaises(ValueError, realpath, path, strict=ALLOW_MISSING)
+ path = b'/\x00'
+ self.assertRaises(ValueError, realpath, path, strict=False)
+ self.assertRaises(ValueError, realpath, path, strict=True)
+ self.assertRaises(ValueError, realpath, path, strict=ALLOW_MISSING)
+ path = '/nonexistent/x\x00'
+ self.assertRaises(ValueError, realpath, path, strict=False)
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+ self.assertRaises(ValueError, realpath, path, strict=ALLOW_MISSING)
+ path = b'/nonexistent/x\x00'
+ self.assertRaises(ValueError, realpath, path, strict=False)
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+ self.assertRaises(ValueError, realpath, path, strict=ALLOW_MISSING)
+ path = '/\x00/..'
+ self.assertRaises(ValueError, realpath, path, strict=False)
+ self.assertRaises(ValueError, realpath, path, strict=True)
+ self.assertRaises(ValueError, realpath, path, strict=ALLOW_MISSING)
+ path = b'/\x00/..'
+ self.assertRaises(ValueError, realpath, path, strict=False)
+ self.assertRaises(ValueError, realpath, path, strict=True)
+ self.assertRaises(ValueError, realpath, path, strict=ALLOW_MISSING)
+
+ path = '/nonexistent/x\x00/..'
+ self.assertRaises(ValueError, realpath, path, strict=False)
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+ self.assertRaises(ValueError, realpath, path, strict=ALLOW_MISSING)
+ path = b'/nonexistent/x\x00/..'
+ self.assertRaises(ValueError, realpath, path, strict=False)
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+ self.assertRaises(ValueError, realpath, path, strict=ALLOW_MISSING)
+
+ path = '/\udfff'
+ if sys.platform == 'win32':
+ self.assertEqual(realpath(path, strict=False), path)
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), path)
+ else:
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=False)
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=True)
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=ALLOW_MISSING)
+ path = '/nonexistent/\udfff'
+ if sys.platform == 'win32':
+ self.assertEqual(realpath(path, strict=False), path)
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), path)
+ else:
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=False)
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=ALLOW_MISSING)
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+ path = '/\udfff/..'
+ if sys.platform == 'win32':
+ self.assertEqual(realpath(path, strict=False), '/')
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), '/')
+ else:
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=False)
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=True)
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=ALLOW_MISSING)
+ path = '/nonexistent/\udfff/..'
+ if sys.platform == 'win32':
+ self.assertEqual(realpath(path, strict=False), '/nonexistent')
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), '/nonexistent')
+ else:
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=False)
+ self.assertRaises(UnicodeEncodeError, realpath, path, strict=ALLOW_MISSING)
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+
+ path = b'/\xff'
+ if sys.platform == 'win32':
+ self.assertRaises(UnicodeDecodeError, realpath, path, strict=False)
+ self.assertRaises(UnicodeDecodeError, realpath, path, strict=True)
+ self.assertRaises(UnicodeDecodeError, realpath, path, strict=ALLOW_MISSING)
+ else:
+ self.assertEqual(realpath(path, strict=False), path)
+ if support.is_wasi:
+ self.assertRaises(OSError, realpath, path, strict=True)
+ self.assertRaises(OSError, realpath, path, strict=ALLOW_MISSING)
+ else:
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+ self.assertEqual(realpath(path, strict=ALLOW_MISSING), path)
+ path = b'/nonexistent/\xff'
+ if sys.platform == 'win32':
+ self.assertRaises(UnicodeDecodeError, realpath, path, strict=False)
+ self.assertRaises(UnicodeDecodeError, realpath, path, strict=ALLOW_MISSING)
+ else:
+ self.assertEqual(realpath(path, strict=False), path)
+ if support.is_wasi:
+ self.assertRaises(OSError, realpath, path, strict=True)
+ self.assertRaises(OSError, realpath, path, strict=ALLOW_MISSING)
+ else:
+ self.assertRaises(FileNotFoundError, realpath, path, strict=True)
+
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_relative(self):
+ @_parameterize({}, {'strict': ALLOW_MISSING})
+ def test_realpath_relative(self, kwargs):
try:
os.symlink(posixpath.relpath(ABSTFN+"1"), ABSTFN)
- self.assertEqual(realpath(ABSTFN), ABSTFN+"1")
+ self.assertEqual(realpath(ABSTFN, **kwargs), ABSTFN+"1")
finally:
os_helper.unlink(ABSTFN)
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_missing_pardir(self):
+ @_parameterize({}, {'strict': ALLOW_MISSING})
+ def test_realpath_missing_pardir(self, kwargs):
try:
- os.symlink(os_helper.TESTFN + "1", os_helper.TESTFN)
- self.assertEqual(realpath("nonexistent/../" + os_helper.TESTFN), ABSTFN + "1")
+ os.symlink(TESTFN + "1", TESTFN)
+ self.assertEqual(
+ realpath("nonexistent/../" + TESTFN, **kwargs), ABSTFN + "1")
finally:
- os_helper.unlink(os_helper.TESTFN)
+ os_helper.unlink(TESTFN)
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
@@ -550,37 +651,38 @@ class PosixPathTest(unittest.TestCase):
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_symlink_loops_strict(self):
+ @_parameterize({'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_symlink_loops_strict(self, kwargs):
# Bug #43757, raise OSError if we get into an infinite symlink loop in
- # strict mode.
+ # the strict modes.
try:
os.symlink(ABSTFN, ABSTFN)
- self.assertRaises(OSError, realpath, ABSTFN, strict=True)
+ self.assertRaises(OSError, realpath, ABSTFN, **kwargs)
os.symlink(ABSTFN+"1", ABSTFN+"2")
os.symlink(ABSTFN+"2", ABSTFN+"1")
- self.assertRaises(OSError, realpath, ABSTFN+"1", strict=True)
- self.assertRaises(OSError, realpath, ABSTFN+"2", strict=True)
+ self.assertRaises(OSError, realpath, ABSTFN+"1", **kwargs)
+ self.assertRaises(OSError, realpath, ABSTFN+"2", **kwargs)
- self.assertRaises(OSError, realpath, ABSTFN+"1/x", strict=True)
- self.assertRaises(OSError, realpath, ABSTFN+"1/..", strict=True)
- self.assertRaises(OSError, realpath, ABSTFN+"1/../x", strict=True)
+ self.assertRaises(OSError, realpath, ABSTFN+"1/x", **kwargs)
+ self.assertRaises(OSError, realpath, ABSTFN+"1/..", **kwargs)
+ self.assertRaises(OSError, realpath, ABSTFN+"1/../x", **kwargs)
os.symlink(ABSTFN+"x", ABSTFN+"y")
self.assertRaises(OSError, realpath,
- ABSTFN+"1/../" + basename(ABSTFN) + "y", strict=True)
+ ABSTFN+"1/../" + basename(ABSTFN) + "y", **kwargs)
self.assertRaises(OSError, realpath,
- ABSTFN+"1/../" + basename(ABSTFN) + "1", strict=True)
+ ABSTFN+"1/../" + basename(ABSTFN) + "1", **kwargs)
os.symlink(basename(ABSTFN) + "a/b", ABSTFN+"a")
- self.assertRaises(OSError, realpath, ABSTFN+"a", strict=True)
+ self.assertRaises(OSError, realpath, ABSTFN+"a", **kwargs)
os.symlink("../" + basename(dirname(ABSTFN)) + "/" +
basename(ABSTFN) + "c", ABSTFN+"c")
- self.assertRaises(OSError, realpath, ABSTFN+"c", strict=True)
+ self.assertRaises(OSError, realpath, ABSTFN+"c", **kwargs)
# Test using relative path as well.
with os_helper.change_cwd(dirname(ABSTFN)):
- self.assertRaises(OSError, realpath, basename(ABSTFN), strict=True)
+ self.assertRaises(OSError, realpath, basename(ABSTFN), **kwargs)
finally:
os_helper.unlink(ABSTFN)
os_helper.unlink(ABSTFN+"1")
@@ -591,28 +693,30 @@ class PosixPathTest(unittest.TestCase):
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_repeated_indirect_symlinks(self):
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_repeated_indirect_symlinks(self, kwargs):
# Issue #6975.
try:
os.mkdir(ABSTFN)
os.symlink('../' + basename(ABSTFN), ABSTFN + '/self')
os.symlink('self/self/self', ABSTFN + '/link')
- self.assertEqual(realpath(ABSTFN + '/link'), ABSTFN)
+ self.assertEqual(realpath(ABSTFN + '/link', **kwargs), ABSTFN)
finally:
os_helper.unlink(ABSTFN + '/self')
os_helper.unlink(ABSTFN + '/link')
- safe_rmdir(ABSTFN)
+ os_helper.rmdir(ABSTFN)
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_deep_recursion(self):
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_deep_recursion(self, kwargs):
depth = 10
try:
os.mkdir(ABSTFN)
for i in range(depth):
os.symlink('/'.join(['%d' % i] * 10), ABSTFN + '/%d' % (i + 1))
os.symlink('.', ABSTFN + '/0')
- self.assertEqual(realpath(ABSTFN + '/%d' % depth), ABSTFN)
+ self.assertEqual(realpath(ABSTFN + '/%d' % depth, **kwargs), ABSTFN)
# Test using relative path as well.
with os_helper.change_cwd(ABSTFN):
@@ -620,11 +724,12 @@ class PosixPathTest(unittest.TestCase):
finally:
for i in range(depth + 1):
os_helper.unlink(ABSTFN + '/%d' % i)
- safe_rmdir(ABSTFN)
+ os_helper.rmdir(ABSTFN)
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_resolve_parents(self):
+ @_parameterize({}, {'strict': ALLOW_MISSING})
+ def test_realpath_resolve_parents(self, kwargs):
# We also need to resolve any symlinks in the parents of a relative
# path passed to realpath. E.g.: current working directory is
# /usr/doc with 'doc' being a symlink to /usr/share/doc. We call
@@ -635,15 +740,17 @@ class PosixPathTest(unittest.TestCase):
os.symlink(ABSTFN + "/y", ABSTFN + "/k")
with os_helper.change_cwd(ABSTFN + "/k"):
- self.assertEqual(realpath("a"), ABSTFN + "/y/a")
+ self.assertEqual(realpath("a", **kwargs),
+ ABSTFN + "/y/a")
finally:
os_helper.unlink(ABSTFN + "/k")
- safe_rmdir(ABSTFN + "/y")
- safe_rmdir(ABSTFN)
+ os_helper.rmdir(ABSTFN + "/y")
+ os_helper.rmdir(ABSTFN)
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_resolve_before_normalizing(self):
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_resolve_before_normalizing(self, kwargs):
# Bug #990669: Symbolic links should be resolved before we
# normalize the path. E.g.: if we have directories 'a', 'k' and 'y'
# in the following hierarchy:
@@ -658,20 +765,21 @@ class PosixPathTest(unittest.TestCase):
os.symlink(ABSTFN + "/k/y", ABSTFN + "/link-y")
# Absolute path.
- self.assertEqual(realpath(ABSTFN + "/link-y/.."), ABSTFN + "/k")
+ self.assertEqual(realpath(ABSTFN + "/link-y/..", **kwargs), ABSTFN + "/k")
# Relative path.
with os_helper.change_cwd(dirname(ABSTFN)):
- self.assertEqual(realpath(basename(ABSTFN) + "/link-y/.."),
+ self.assertEqual(realpath(basename(ABSTFN) + "/link-y/..", **kwargs),
ABSTFN + "/k")
finally:
os_helper.unlink(ABSTFN + "/link-y")
- safe_rmdir(ABSTFN + "/k/y")
- safe_rmdir(ABSTFN + "/k")
- safe_rmdir(ABSTFN)
+ os_helper.rmdir(ABSTFN + "/k/y")
+ os_helper.rmdir(ABSTFN + "/k")
+ os_helper.rmdir(ABSTFN)
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
- def test_realpath_resolve_first(self):
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_resolve_first(self, kwargs):
# Bug #1213894: The first component of the path, if not absolute,
# must be resolved too.
@@ -681,12 +789,12 @@ class PosixPathTest(unittest.TestCase):
os.symlink(ABSTFN, ABSTFN + "link")
with os_helper.change_cwd(dirname(ABSTFN)):
base = basename(ABSTFN)
- self.assertEqual(realpath(base + "link"), ABSTFN)
- self.assertEqual(realpath(base + "link/k"), ABSTFN + "/k")
+ self.assertEqual(realpath(base + "link", **kwargs), ABSTFN)
+ self.assertEqual(realpath(base + "link/k", **kwargs), ABSTFN + "/k")
finally:
os_helper.unlink(ABSTFN + "link")
- safe_rmdir(ABSTFN + "/k")
- safe_rmdir(ABSTFN)
+ os_helper.rmdir(ABSTFN + "/k")
+ os_helper.rmdir(ABSTFN)
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
@@ -700,27 +808,95 @@ class PosixPathTest(unittest.TestCase):
self.assertEqual(realpath(ABSTFN + '/foo'), ABSTFN + '/foo')
self.assertEqual(realpath(ABSTFN + '/../foo'), dirname(ABSTFN) + '/foo')
self.assertEqual(realpath(ABSTFN + '/foo/..'), ABSTFN)
+ finally:
+ os.chmod(ABSTFN, 0o755, follow_symlinks=False)
+ os_helper.unlink(ABSTFN)
+
+ @os_helper.skip_unless_symlink
+ @skip_if_ABSTFN_contains_backslash
+ @unittest.skipIf(os.chmod not in os.supports_follow_symlinks, "Can't set symlink permissions")
+ @unittest.skipIf(sys.platform != "darwin", "only macOS requires read permission to readlink()")
+ @_parameterize({'strict': True}, {'strict': ALLOW_MISSING})
+ def test_realpath_unreadable_symlink_strict(self, kwargs):
+ try:
+ os.symlink(ABSTFN+"1", ABSTFN)
+ os.chmod(ABSTFN, 0o000, follow_symlinks=False)
+ with self.assertRaises(PermissionError):
+ realpath(ABSTFN, **kwargs)
+ with self.assertRaises(PermissionError):
+                realpath(ABSTFN + '/foo', **kwargs)
with self.assertRaises(PermissionError):
- realpath(ABSTFN, strict=True)
+ realpath(ABSTFN + '/../foo', **kwargs)
+ with self.assertRaises(PermissionError):
+ realpath(ABSTFN + '/foo/..', **kwargs)
finally:
os.chmod(ABSTFN, 0o755, follow_symlinks=False)
os.unlink(ABSTFN)
@skip_if_ABSTFN_contains_backslash
+ @os_helper.skip_unless_symlink
+ def test_realpath_unreadable_directory(self):
+ try:
+ os.mkdir(ABSTFN)
+ os.mkdir(ABSTFN + '/k')
+ os.chmod(ABSTFN, 0o000)
+ self.assertEqual(realpath(ABSTFN, strict=False), ABSTFN)
+ self.assertEqual(realpath(ABSTFN, strict=True), ABSTFN)
+ self.assertEqual(realpath(ABSTFN, strict=ALLOW_MISSING), ABSTFN)
+
+ try:
+ os.stat(ABSTFN)
+ except PermissionError:
+ pass
+ else:
+ self.skipTest('Cannot block permissions')
+
+ self.assertEqual(realpath(ABSTFN + '/k', strict=False),
+ ABSTFN + '/k')
+ self.assertRaises(PermissionError, realpath, ABSTFN + '/k',
+ strict=True)
+ self.assertRaises(PermissionError, realpath, ABSTFN + '/k',
+ strict=ALLOW_MISSING)
+
+ self.assertEqual(realpath(ABSTFN + '/missing', strict=False),
+ ABSTFN + '/missing')
+ self.assertRaises(PermissionError, realpath, ABSTFN + '/missing',
+ strict=True)
+ self.assertRaises(PermissionError, realpath, ABSTFN + '/missing',
+ strict=ALLOW_MISSING)
+ finally:
+ os.chmod(ABSTFN, 0o755)
+ os_helper.rmdir(ABSTFN + '/k')
+ os_helper.rmdir(ABSTFN)
+
+ @skip_if_ABSTFN_contains_backslash
def test_realpath_nonterminal_file(self):
try:
with open(ABSTFN, 'w') as f:
f.write('test_posixpath wuz ere')
self.assertEqual(realpath(ABSTFN, strict=False), ABSTFN)
self.assertEqual(realpath(ABSTFN, strict=True), ABSTFN)
+ self.assertEqual(realpath(ABSTFN, strict=ALLOW_MISSING), ABSTFN)
+
self.assertEqual(realpath(ABSTFN + "/", strict=False), ABSTFN)
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/.", strict=False), ABSTFN)
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/.", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/.",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/..", strict=False), dirname(ABSTFN))
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/..", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/..",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/subdir", strict=False), ABSTFN + "/subdir")
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/subdir", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/subdir",
+ strict=ALLOW_MISSING)
finally:
os_helper.unlink(ABSTFN)
@@ -733,16 +909,30 @@ class PosixPathTest(unittest.TestCase):
os.symlink(ABSTFN + "1", ABSTFN)
self.assertEqual(realpath(ABSTFN, strict=False), ABSTFN + "1")
self.assertEqual(realpath(ABSTFN, strict=True), ABSTFN + "1")
+ self.assertEqual(realpath(ABSTFN, strict=ALLOW_MISSING), ABSTFN + "1")
+
self.assertEqual(realpath(ABSTFN + "/", strict=False), ABSTFN + "1")
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/.", strict=False), ABSTFN + "1")
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/.", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/.",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/..", strict=False), dirname(ABSTFN))
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/..", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/..",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/subdir", strict=False), ABSTFN + "1/subdir")
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/subdir", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/subdir",
+ strict=ALLOW_MISSING)
finally:
os_helper.unlink(ABSTFN)
+ os_helper.unlink(ABSTFN + "1")
@os_helper.skip_unless_symlink
@skip_if_ABSTFN_contains_backslash
@@ -754,16 +944,31 @@ class PosixPathTest(unittest.TestCase):
os.symlink(ABSTFN + "1", ABSTFN)
self.assertEqual(realpath(ABSTFN, strict=False), ABSTFN + "2")
self.assertEqual(realpath(ABSTFN, strict=True), ABSTFN + "2")
+ self.assertEqual(realpath(ABSTFN, strict=ALLOW_MISSING), ABSTFN + "2")
+
self.assertEqual(realpath(ABSTFN + "/", strict=False), ABSTFN + "2")
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/.", strict=False), ABSTFN + "2")
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/.", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/.",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/..", strict=False), dirname(ABSTFN))
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/..", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/..",
+ strict=ALLOW_MISSING)
+
self.assertEqual(realpath(ABSTFN + "/subdir", strict=False), ABSTFN + "2/subdir")
self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/subdir", strict=True)
+ self.assertRaises(NotADirectoryError, realpath, ABSTFN + "/subdir",
+ strict=ALLOW_MISSING)
finally:
os_helper.unlink(ABSTFN)
+ os_helper.unlink(ABSTFN + "1")
+ os_helper.unlink(ABSTFN + "2")
def test_relpath(self):
(real_getcwd, os.getcwd) = (os.getcwd, lambda: r"/home/user/bar")
@@ -889,8 +1094,8 @@ class PathLikeTests(unittest.TestCase):
path = posixpath
def setUp(self):
- self.file_name = os_helper.TESTFN
- self.file_path = FakePath(os_helper.TESTFN)
+ self.file_name = TESTFN
+ self.file_path = FakePath(TESTFN)
self.addCleanup(os_helper.unlink, self.file_name)
with open(self.file_name, 'xb', 0) as file:
file.write(b"test_posixpath.PathLikeTests")
@@ -947,9 +1152,12 @@ class PathLikeTests(unittest.TestCase):
def test_path_abspath(self):
self.assertPathEqual(self.path.abspath)
- def test_path_realpath(self):
+ @_parameterize({}, {'strict': True}, {'strict': ALLOW_MISSING})
+ def test_path_realpath(self, kwargs):
self.assertPathEqual(self.path.realpath)
+ self.assertPathEqual(partial(self.path.realpath, **kwargs))
+
def test_path_relpath(self):
self.assertPathEqual(self.path.relpath)
diff --git a/Lib/test/test_pprint.py b/Lib/test/test_pprint.py
index dfbc2a06e73..41c337ade7e 100644
--- a/Lib/test/test_pprint.py
+++ b/Lib/test/test_pprint.py
@@ -10,6 +10,10 @@ import random
import re
import types
import unittest
+from collections.abc import ItemsView, KeysView, Mapping, MappingView, ValuesView
+
+from test.support import cpython_only
+from test.support.import_helper import ensure_lazy_imports
# list, tuple and dict subclasses that do or don't overwrite __repr__
class list2(list):
@@ -67,6 +71,14 @@ class dict_custom_repr(dict):
def __repr__(self):
return '*'*len(dict.__repr__(self))
+class mappingview_custom_repr(MappingView):
+ def __repr__(self):
+ return '*'*len(MappingView.__repr__(self))
+
+class keysview_custom_repr(KeysView):
+ def __repr__(self):
+ return '*'*len(KeysView.__repr__(self))
+
@dataclasses.dataclass
class dataclass1:
field1: str
@@ -129,6 +141,10 @@ class QueryTestCase(unittest.TestCase):
self.b = list(range(200))
self.a[-12] = self.b
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("pprint", {"dataclasses", "re"})
+
def test_init(self):
pp = pprint.PrettyPrinter()
pp = pprint.PrettyPrinter(indent=4, width=40, depth=5,
@@ -173,10 +189,17 @@ class QueryTestCase(unittest.TestCase):
# Messy dict.
self.d = {}
self.d[0] = self.d[1] = self.d[2] = self.d
+ self.e = {}
+ self.v = ValuesView(self.e)
+ self.m = MappingView(self.e)
+ self.dv = self.e.values()
+ self.e["v"] = self.v
+ self.e["m"] = self.m
+ self.e["dv"] = self.dv
pp = pprint.PrettyPrinter()
- for icky in self.a, self.b, self.d, (self.d, self.d):
+ for icky in self.a, self.b, self.d, (self.d, self.d), self.e, self.v, self.m, self.dv:
self.assertTrue(pprint.isrecursive(icky), "expected isrecursive")
self.assertFalse(pprint.isreadable(icky), "expected not isreadable")
self.assertTrue(pp.isrecursive(icky), "expected isrecursive")
@@ -184,10 +207,11 @@ class QueryTestCase(unittest.TestCase):
# Break the cycles.
self.d.clear()
+ self.e.clear()
del self.a[:]
del self.b[:]
- for safe in self.a, self.b, self.d, (self.d, self.d):
+ for safe in self.a, self.b, self.d, (self.d, self.d), self.e, self.v, self.m, self.dv:
# module-level convenience functions
self.assertFalse(pprint.isrecursive(safe),
"expected not isrecursive for %r" % (safe,))
@@ -230,6 +254,8 @@ class QueryTestCase(unittest.TestCase):
set(), set2(), set3(),
frozenset(), frozenset2(), frozenset3(),
{}, dict2(), dict3(),
+ {}.keys(), {}.values(), {}.items(),
+ MappingView({}), KeysView({}), ItemsView({}), ValuesView({}),
self.assertTrue, pprint,
-6, -6, -6-6j, -1.5, "x", b"x", bytearray(b"x"),
(3,), [3], {3: 6},
@@ -239,6 +265,9 @@ class QueryTestCase(unittest.TestCase):
set({7}), set2({7}), set3({7}),
frozenset({8}), frozenset2({8}), frozenset3({8}),
dict2({5: 6}), dict3({5: 6}),
+ {5: 6}.keys(), {5: 6}.values(), {5: 6}.items(),
+ MappingView({5: 6}), KeysView({5: 6}),
+ ItemsView({5: 6}), ValuesView({5: 6}),
range(10, -11, -1),
True, False, None, ...,
):
@@ -268,6 +297,12 @@ class QueryTestCase(unittest.TestCase):
dict_custom_repr(),
dict_custom_repr({5: 6}),
dict_custom_repr(zip(range(N),range(N))),
+ mappingview_custom_repr({}),
+ mappingview_custom_repr({5: 6}),
+ mappingview_custom_repr(dict(zip(range(N),range(N)))),
+ keysview_custom_repr({}),
+ keysview_custom_repr({5: 6}),
+ keysview_custom_repr(dict(zip(range(N),range(N)))),
):
native = repr(cont)
expected = '*' * len(native)
@@ -296,6 +331,56 @@ class QueryTestCase(unittest.TestCase):
self.assertEqual(pprint.pformat(type(o)), exp)
o = range(100)
+ exp = 'dict_keys([%s])' % ',\n '.join(map(str, o))
+ keys = dict.fromkeys(o).keys()
+ self.assertEqual(pprint.pformat(keys), exp)
+
+ o = range(100)
+ exp = 'dict_values([%s])' % ',\n '.join(map(str, o))
+ values = {v: v for v in o}.values()
+ self.assertEqual(pprint.pformat(values), exp)
+
+ o = range(100)
+ exp = 'dict_items([%s])' % ',\n '.join("(%s, %s)" % (i, i) for i in o)
+ items = {v: v for v in o}.items()
+ self.assertEqual(pprint.pformat(items), exp)
+
+ o = range(100)
+ exp = 'odict_keys([%s])' % ',\n '.join(map(str, o))
+ keys = collections.OrderedDict.fromkeys(o).keys()
+ self.assertEqual(pprint.pformat(keys), exp)
+
+ o = range(100)
+ exp = 'odict_values([%s])' % ',\n '.join(map(str, o))
+ values = collections.OrderedDict({v: v for v in o}).values()
+ self.assertEqual(pprint.pformat(values), exp)
+
+ o = range(100)
+ exp = 'odict_items([%s])' % ',\n '.join("(%s, %s)" % (i, i) for i in o)
+ items = collections.OrderedDict({v: v for v in o}).items()
+ self.assertEqual(pprint.pformat(items), exp)
+
+ o = range(100)
+ exp = 'KeysView({%s})' % (': None,\n '.join(map(str, o)) + ': None')
+ keys_view = KeysView(dict.fromkeys(o))
+ self.assertEqual(pprint.pformat(keys_view), exp)
+
+ o = range(100)
+ exp = 'ItemsView({%s})' % (': None,\n '.join(map(str, o)) + ': None')
+ items_view = ItemsView(dict.fromkeys(o))
+ self.assertEqual(pprint.pformat(items_view), exp)
+
+ o = range(100)
+ exp = 'MappingView({%s})' % (': None,\n '.join(map(str, o)) + ': None')
+ mapping_view = MappingView(dict.fromkeys(o))
+ self.assertEqual(pprint.pformat(mapping_view), exp)
+
+ o = range(100)
+ exp = 'ValuesView({%s})' % (': None,\n '.join(map(str, o)) + ': None')
+ values_view = ValuesView(dict.fromkeys(o))
+ self.assertEqual(pprint.pformat(values_view), exp)
+
+ o = range(100)
exp = '[%s]' % ',\n '.join(map(str, o))
for type in [list, list2]:
self.assertEqual(pprint.pformat(type(o)), exp)
@@ -373,7 +458,7 @@ class QueryTestCase(unittest.TestCase):
return super().__new__(Temperature, celsius_degrees)
def __repr__(self):
kelvin_degrees = self + 273.15
- return f"{kelvin_degrees}°K"
+ return f"{kelvin_degrees:.2f}°K"
self.assertEqual(pprint.pformat(Temperature(1000)), '1273.15°K')
def test_sorted_dict(self):
@@ -418,6 +503,30 @@ OrderedDict([('the', 0),
('a', 6),
('lazy', 7),
('dog', 8)])""")
+ self.assertEqual(pprint.pformat(d.keys(), sort_dicts=False),
+"""\
+odict_keys(['the',
+ 'quick',
+ 'brown',
+ 'fox',
+ 'jumped',
+ 'over',
+ 'a',
+ 'lazy',
+ 'dog'])""")
+ self.assertEqual(pprint.pformat(d.items(), sort_dicts=False),
+"""\
+odict_items([('the', 0),
+ ('quick', 1),
+ ('brown', 2),
+ ('fox', 3),
+ ('jumped', 4),
+ ('over', 5),
+ ('a', 6),
+ ('lazy', 7),
+ ('dog', 8)])""")
+ self.assertEqual(pprint.pformat(d.values(), sort_dicts=False),
+ "odict_values([0, 1, 2, 3, 4, 5, 6, 7, 8])")
def test_mapping_proxy(self):
words = 'the quick brown fox jumped over a lazy dog'.split()
@@ -446,6 +555,152 @@ mappingproxy(OrderedDict([('the', 0),
('lazy', 7),
('dog', 8)]))""")
+ def test_dict_views(self):
+ for dict_class in (dict, collections.OrderedDict, collections.Counter):
+ empty = dict_class({})
+ short = dict_class(dict(zip('edcba', 'edcba')))
+ long = dict_class(dict((chr(x), chr(x)) for x in range(90, 64, -1)))
+ lengths = {"empty": empty, "short": short, "long": long}
+ prefix = "odict" if dict_class is collections.OrderedDict else "dict"
+ for name, d in lengths.items():
+ with self.subTest(length=name, prefix=prefix):
+ is_short = len(d) < 6
+ joiner = ", " if is_short else ",\n "
+ k = d.keys()
+ v = d.values()
+ i = d.items()
+ self.assertEqual(pprint.pformat(k, sort_dicts=True),
+ prefix + "_keys([%s])" %
+ joiner.join(repr(key) for key in sorted(k)))
+ self.assertEqual(pprint.pformat(v, sort_dicts=True),
+ prefix + "_values([%s])" %
+ joiner.join(repr(val) for val in sorted(v)))
+ self.assertEqual(pprint.pformat(i, sort_dicts=True),
+ prefix + "_items([%s])" %
+ joiner.join(repr(item) for item in sorted(i)))
+ self.assertEqual(pprint.pformat(k, sort_dicts=False),
+ prefix + "_keys([%s])" %
+ joiner.join(repr(key) for key in k))
+ self.assertEqual(pprint.pformat(v, sort_dicts=False),
+ prefix + "_values([%s])" %
+ joiner.join(repr(val) for val in v))
+ self.assertEqual(pprint.pformat(i, sort_dicts=False),
+ prefix + "_items([%s])" %
+ joiner.join(repr(item) for item in i))
+
+ def test_abc_views(self):
+ empty = {}
+ short = dict(zip('edcba', 'edcba'))
+ long = dict((chr(x), chr(x)) for x in range(90, 64, -1))
+ lengths = {"empty": empty, "short": short, "long": long}
+ # Test that a subclass that doesn't replace __repr__ works with different lengths
+ class MV(MappingView): pass
+
+ for name, d in lengths.items():
+ with self.subTest(length=name, name="Views"):
+ is_short = len(d) < 6
+ joiner = ", " if is_short else ",\n "
+ i = d.items()
+ s = sorted(i)
+ joined_items = "({%s})" % joiner.join(["%r: %r" % (k, v) for (k, v) in i])
+ sorted_items = "({%s})" % joiner.join(["%r: %r" % (k, v) for (k, v) in s])
+ self.assertEqual(pprint.pformat(KeysView(d), sort_dicts=True),
+ KeysView.__name__ + sorted_items)
+ self.assertEqual(pprint.pformat(ItemsView(d), sort_dicts=True),
+ ItemsView.__name__ + sorted_items)
+ self.assertEqual(pprint.pformat(MappingView(d), sort_dicts=True),
+ MappingView.__name__ + sorted_items)
+ self.assertEqual(pprint.pformat(MV(d), sort_dicts=True),
+ MV.__name__ + sorted_items)
+ self.assertEqual(pprint.pformat(ValuesView(d), sort_dicts=True),
+ ValuesView.__name__ + sorted_items)
+ self.assertEqual(pprint.pformat(KeysView(d), sort_dicts=False),
+ KeysView.__name__ + joined_items)
+ self.assertEqual(pprint.pformat(ItemsView(d), sort_dicts=False),
+ ItemsView.__name__ + joined_items)
+ self.assertEqual(pprint.pformat(MappingView(d), sort_dicts=False),
+ MappingView.__name__ + joined_items)
+ self.assertEqual(pprint.pformat(MV(d), sort_dicts=False),
+ MV.__name__ + joined_items)
+ self.assertEqual(pprint.pformat(ValuesView(d), sort_dicts=False),
+ ValuesView.__name__ + joined_items)
+
+ def test_nested_views(self):
+ d = {1: MappingView({1: MappingView({1: MappingView({1: 2})})})}
+ self.assertEqual(repr(d),
+ "{1: MappingView({1: MappingView({1: MappingView({1: 2})})})}")
+ self.assertEqual(pprint.pformat(d),
+ "{1: MappingView({1: MappingView({1: MappingView({1: 2})})})}")
+ self.assertEqual(pprint.pformat(d, depth=2),
+ "{1: MappingView({1: {...}})}")
+ d = {}
+ d1 = {1: d.values()}
+ d2 = {1: d1.values()}
+ d3 = {1: d2.values()}
+ self.assertEqual(pprint.pformat(d3),
+ "{1: dict_values([dict_values([dict_values([])])])}")
+ self.assertEqual(pprint.pformat(d3, depth=2),
+ "{1: dict_values([{...}])}")
+
+ def test_unorderable_items_views(self):
+ """Check that views with unorderable items have stable sorting."""
+ d = dict((((3+1j), 3), ((1+1j), (1+0j)), (1j, 0j), (500, None), (499, None)))
+ iv = ItemsView(d)
+ self.assertEqual(pprint.pformat(iv),
+ pprint.pformat(iv))
+ self.assertTrue(pprint.pformat(iv).endswith(", 499: None, 500: None})"),
+ pprint.pformat(iv))
+ self.assertEqual(pprint.pformat(d.items()), # Won't be equal unless _safe_tuple
+ pprint.pformat(d.items())) # is used in _safe_repr
+ self.assertTrue(pprint.pformat(d.items()).endswith(", (499, None), (500, None)])"))
+
+ def test_mapping_view_subclass_no_mapping(self):
+ class BMV(MappingView):
+ def __init__(self, d):
+ super().__init__(d)
+ self.mapping = self._mapping
+ del self._mapping
+
+ self.assertRaises(AttributeError, pprint.pformat, BMV({}))
+
+ def test_mapping_subclass_repr(self):
+ """Test that mapping ABC views use their ._mapping's __repr__."""
+ class MyMapping(Mapping):
+ def __init__(self, keys=None):
+ self._keys = {} if keys is None else dict.fromkeys(keys)
+
+ def __getitem__(self, item):
+ return self._keys[item]
+
+ def __len__(self):
+ return len(self._keys)
+
+ def __iter__(self):
+ return iter(self._keys)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}([{', '.join(map(repr, self._keys.keys()))}])"
+
+ m = MyMapping(["test", 1])
+ self.assertEqual(repr(m), "MyMapping(['test', 1])")
+ short_view_repr = "%s(MyMapping(['test', 1]))"
+ self.assertEqual(repr(m.keys()), short_view_repr % "KeysView")
+ self.assertEqual(pprint.pformat(m.items()), short_view_repr % "ItemsView")
+ self.assertEqual(pprint.pformat(m.keys()), short_view_repr % "KeysView")
+ self.assertEqual(pprint.pformat(MappingView(m)), short_view_repr % "MappingView")
+ self.assertEqual(pprint.pformat(m.values()), short_view_repr % "ValuesView")
+
+ alpha = "abcdefghijklmnopqrstuvwxyz"
+ m = MyMapping(alpha)
+ alpha_repr = ", ".join(map(repr, list(alpha)))
+ long_view_repr = "%%s(MyMapping([%s]))" % alpha_repr
+ self.assertEqual(repr(m), "MyMapping([%s])" % alpha_repr)
+ self.assertEqual(repr(m.keys()), long_view_repr % "KeysView")
+ self.assertEqual(pprint.pformat(m.items()), long_view_repr % "ItemsView")
+ self.assertEqual(pprint.pformat(m.keys()), long_view_repr % "KeysView")
+ self.assertEqual(pprint.pformat(MappingView(m)), long_view_repr % "MappingView")
+ self.assertEqual(pprint.pformat(m.values()), long_view_repr % "ValuesView")
+
def test_empty_simple_namespace(self):
ns = types.SimpleNamespace()
formatted = pprint.pformat(ns)
@@ -761,6 +1016,10 @@ frozenset2({0,
'frozenset({' + ','.join(map(repr, skeys)) + '})')
self.assertEqual(clean(pprint.pformat(dict.fromkeys(keys))),
'{' + ','.join('%r:None' % k for k in skeys) + '}')
+ self.assertEqual(clean(pprint.pformat(dict.fromkeys(keys).keys())),
+ 'dict_keys([' + ','.join('%r' % k for k in skeys) + '])')
+ self.assertEqual(clean(pprint.pformat(dict.fromkeys(keys).items())),
+ 'dict_items([' + ','.join('(%r,None)' % k for k in skeys) + '])')
# Issue 10017: TypeError on user-defined types as dict keys.
self.assertEqual(pprint.pformat({Unorderable: 0, 1: 0}),
@@ -1042,6 +1301,66 @@ ChainMap({'a': 6,
('a', 6),
('lazy', 7),
('dog', 8)]))""")
+ self.assertEqual(pprint.pformat(d.keys()),
+"""\
+KeysView(ChainMap({'a': 6,
+ 'brown': 2,
+ 'dog': 8,
+ 'fox': 3,
+ 'jumped': 4,
+ 'lazy': 7,
+ 'over': 5,
+ 'quick': 1,
+ 'the': 0},
+ OrderedDict([('the', 0),
+ ('quick', 1),
+ ('brown', 2),
+ ('fox', 3),
+ ('jumped', 4),
+ ('over', 5),
+ ('a', 6),
+ ('lazy', 7),
+ ('dog', 8)])))""")
+ self.assertEqual(pprint.pformat(d.items()),
+ """\
+ItemsView(ChainMap({'a': 6,
+ 'brown': 2,
+ 'dog': 8,
+ 'fox': 3,
+ 'jumped': 4,
+ 'lazy': 7,
+ 'over': 5,
+ 'quick': 1,
+ 'the': 0},
+ OrderedDict([('the', 0),
+ ('quick', 1),
+ ('brown', 2),
+ ('fox', 3),
+ ('jumped', 4),
+ ('over', 5),
+ ('a', 6),
+ ('lazy', 7),
+ ('dog', 8)])))""")
+ self.assertEqual(pprint.pformat(d.values()),
+ """\
+ValuesView(ChainMap({'a': 6,
+ 'brown': 2,
+ 'dog': 8,
+ 'fox': 3,
+ 'jumped': 4,
+ 'lazy': 7,
+ 'over': 5,
+ 'quick': 1,
+ 'the': 0},
+ OrderedDict([('the', 0),
+ ('quick', 1),
+ ('brown', 2),
+ ('fox', 3),
+ ('jumped', 4),
+ ('over', 5),
+ ('a', 6),
+ ('lazy', 7),
+ ('dog', 8)])))""")
def test_deque(self):
d = collections.deque()
@@ -1089,6 +1408,36 @@ deque([('brown', 2),
'over': 5,
'quick': 1,
'the': 0}""")
+ self.assertEqual(pprint.pformat(d.keys()), """\
+KeysView({'a': 6,
+ 'brown': 2,
+ 'dog': 8,
+ 'fox': 3,
+ 'jumped': 4,
+ 'lazy': 7,
+ 'over': 5,
+ 'quick': 1,
+ 'the': 0})""")
+ self.assertEqual(pprint.pformat(d.items()), """\
+ItemsView({'a': 6,
+ 'brown': 2,
+ 'dog': 8,
+ 'fox': 3,
+ 'jumped': 4,
+ 'lazy': 7,
+ 'over': 5,
+ 'quick': 1,
+ 'the': 0})""")
+ self.assertEqual(pprint.pformat(d.values()), """\
+ValuesView({'a': 6,
+ 'brown': 2,
+ 'dog': 8,
+ 'fox': 3,
+ 'jumped': 4,
+ 'lazy': 7,
+ 'over': 5,
+ 'quick': 1,
+ 'the': 0})""")
def test_user_list(self):
d = collections.UserList()
diff --git a/Lib/test/test_property.py b/Lib/test/test_property.py
index cea241b0f20..26aefdbf042 100644
--- a/Lib/test/test_property.py
+++ b/Lib/test/test_property.py
@@ -87,8 +87,8 @@ class PropertyTests(unittest.TestCase):
self.assertEqual(base.spam, 10)
self.assertEqual(base._spam, 10)
delattr(base, "spam")
- self.assertTrue(not hasattr(base, "spam"))
- self.assertTrue(not hasattr(base, "_spam"))
+ self.assertNotHasAttr(base, "spam")
+ self.assertNotHasAttr(base, "_spam")
base.spam = 20
self.assertEqual(base.spam, 20)
self.assertEqual(base._spam, 20)
diff --git a/Lib/test/test_pstats.py b/Lib/test/test_pstats.py
index d5a5a9738c2..a26a8c1d522 100644
--- a/Lib/test/test_pstats.py
+++ b/Lib/test/test_pstats.py
@@ -1,6 +1,7 @@
import unittest
from test import support
+from test.support.import_helper import ensure_lazy_imports
from io import StringIO
from pstats import SortKey
from enum import StrEnum, _test_simple_enum
@@ -10,6 +11,12 @@ import pstats
import tempfile
import cProfile
+class LazyImportTest(unittest.TestCase):
+ @support.cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("pstats", {"typing"})
+
+
class AddCallersTestCase(unittest.TestCase):
"""Tests for pstats.add_callers helper."""
diff --git a/Lib/test/test_pty.py b/Lib/test/test_pty.py
index c1728f5019d..4836f38c388 100644
--- a/Lib/test/test_pty.py
+++ b/Lib/test/test_pty.py
@@ -20,7 +20,6 @@ import select
import signal
import socket
import io # readline
-import warnings
TEST_STRING_1 = b"I wish to buy a fish license.\n"
TEST_STRING_2 = b"For my pet fish, Eric.\n"
diff --git a/Lib/test/test_pulldom.py b/Lib/test/test_pulldom.py
index 6dc51e4371d..3c8ed251aca 100644
--- a/Lib/test/test_pulldom.py
+++ b/Lib/test/test_pulldom.py
@@ -46,7 +46,7 @@ class PullDOMTestCase(unittest.TestCase):
items = pulldom.parseString(SMALL_SAMPLE)
evt, node = next(items)
# Just check the node is a Document:
- self.assertTrue(hasattr(node, "createElement"))
+ self.assertHasAttr(node, "createElement")
self.assertEqual(pulldom.START_DOCUMENT, evt)
evt, node = next(items)
self.assertEqual(pulldom.START_ELEMENT, evt)
@@ -192,7 +192,7 @@ class ThoroughTestCase(unittest.TestCase):
evt, node = next(pd)
self.assertEqual(pulldom.START_DOCUMENT, evt)
# Just check the node is a Document:
- self.assertTrue(hasattr(node, "createElement"))
+ self.assertHasAttr(node, "createElement")
if before_root:
evt, node = next(pd)
diff --git a/Lib/test/test_pyclbr.py b/Lib/test/test_pyclbr.py
index df05cd07d7e..3e7b2cd0dc9 100644
--- a/Lib/test/test_pyclbr.py
+++ b/Lib/test/test_pyclbr.py
@@ -103,7 +103,7 @@ class PyclbrTest(TestCase):
for name, value in dict.items():
if name in ignore:
continue
- self.assertHasAttr(module, name, ignore)
+ self.assertHasAttr(module, name)
py_item = getattr(module, name)
if isinstance(value, pyclbr.Function):
self.assertIsInstance(py_item, (FunctionType, BuiltinFunctionType))
diff --git a/Lib/test/test_pydoc/test_pydoc.py b/Lib/test/test_pydoc/test_pydoc.py
index ac88b3c6f13..d1d6f4987de 100644
--- a/Lib/test/test_pydoc/test_pydoc.py
+++ b/Lib/test/test_pydoc/test_pydoc.py
@@ -553,7 +553,7 @@ class PydocDocTest(unittest.TestCase):
# of the known subclasses of object. (doc.docclass() used to
# fail if HeapType was imported before running this test, like
# when running tests sequentially.)
- from _testcapi import HeapType
+ from _testcapi import HeapType # noqa: F401
except ImportError:
pass
text = doc.docclass(object)
@@ -1380,7 +1380,7 @@ class PydocImportTest(PydocBaseTest):
helper('modules garbage')
result = help_io.getvalue()
- self.assertTrue(result.startswith(expected))
+ self.assertStartsWith(result, expected)
def test_importfile(self):
try:
diff --git a/Lib/test/test_pyrepl/support.py b/Lib/test/test_pyrepl/support.py
index 3692e164cb9..4f7f9d77933 100644
--- a/Lib/test/test_pyrepl/support.py
+++ b/Lib/test/test_pyrepl/support.py
@@ -113,9 +113,6 @@ handle_events_narrow_console = partial(
prepare_console=partial(prepare_console, width=10),
)
-reader_no_colors = partial(prepare_reader, can_colorize=False)
-reader_force_colors = partial(prepare_reader, can_colorize=True)
-
class FakeConsole(Console):
def __init__(self, events, encoding="utf-8") -> None:
diff --git a/Lib/test/test_pyrepl/test_eventqueue.py b/Lib/test/test_pyrepl/test_eventqueue.py
index afb55710342..edfe6ac4748 100644
--- a/Lib/test/test_pyrepl/test_eventqueue.py
+++ b/Lib/test/test_pyrepl/test_eventqueue.py
@@ -53,7 +53,7 @@ class EventQueueTestBase:
mock_keymap.compile_keymap.return_value = {"a": "b"}
eq = self.make_eventqueue()
eq.keymap = {b"a": "b"}
- eq.push("a")
+ eq.push(b"a")
mock_keymap.compile_keymap.assert_called()
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "b")
@@ -63,7 +63,7 @@ class EventQueueTestBase:
mock_keymap.compile_keymap.return_value = {"a": "b"}
eq = self.make_eventqueue()
eq.keymap = {b"c": "d"}
- eq.push("a")
+ eq.push(b"a")
mock_keymap.compile_keymap.assert_called()
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "a")
@@ -73,13 +73,13 @@ class EventQueueTestBase:
mock_keymap.compile_keymap.return_value = {"a": "b"}
eq = self.make_eventqueue()
eq.keymap = {b"a": {b"b": "c"}}
- eq.push("a")
+ eq.push(b"a")
mock_keymap.compile_keymap.assert_called()
self.assertTrue(eq.empty())
- eq.push("b")
+ eq.push(b"b")
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "c")
- eq.push("d")
+ eq.push(b"d")
self.assertEqual(eq.events[1].evt, "key")
self.assertEqual(eq.events[1].data, "d")
@@ -88,32 +88,32 @@ class EventQueueTestBase:
mock_keymap.compile_keymap.return_value = {"a": "b"}
eq = self.make_eventqueue()
eq.keymap = {b"a": {b"b": "c"}}
- eq.push("a")
+ eq.push(b"a")
mock_keymap.compile_keymap.assert_called()
self.assertTrue(eq.empty())
eq.flush_buf()
- eq.push("\033")
+ eq.push(b"\033")
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "\033")
- eq.push("b")
+ eq.push(b"b")
self.assertEqual(eq.events[1].evt, "key")
self.assertEqual(eq.events[1].data, "b")
def test_push_special_key(self):
eq = self.make_eventqueue()
eq.keymap = {}
- eq.push("\x1b")
- eq.push("[")
- eq.push("A")
+ eq.push(b"\x1b")
+ eq.push(b"[")
+ eq.push(b"A")
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "\x1b")
def test_push_unrecognized_escape_sequence(self):
eq = self.make_eventqueue()
eq.keymap = {}
- eq.push("\x1b")
- eq.push("[")
- eq.push("Z")
+ eq.push(b"\x1b")
+ eq.push(b"[")
+ eq.push(b"Z")
self.assertEqual(len(eq.events), 3)
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "\x1b")
@@ -122,12 +122,54 @@ class EventQueueTestBase:
self.assertEqual(eq.events[2].evt, "key")
self.assertEqual(eq.events[2].data, "Z")
- def test_push_unicode_character(self):
+ def test_push_unicode_character_as_str(self):
eq = self.make_eventqueue()
eq.keymap = {}
- eq.push("ч")
- self.assertEqual(eq.events[0].evt, "key")
- self.assertEqual(eq.events[0].data, "ч")
+ with self.assertRaises(AssertionError):
+ eq.push("ч")
+ with self.assertRaises(AssertionError):
+ eq.push("ñ")
+
+ def test_push_unicode_character_two_bytes(self):
+ eq = self.make_eventqueue()
+ eq.keymap = {}
+
+ encoded = "ч".encode(eq.encoding, "replace")
+ self.assertEqual(len(encoded), 2)
+
+ eq.push(encoded[0])
+ e = eq.get()
+ self.assertIsNone(e)
+
+ eq.push(encoded[1])
+ e = eq.get()
+ self.assertEqual(e.evt, "key")
+ self.assertEqual(e.data, "ч")
+
+ def test_push_single_chars_and_unicode_character_as_str(self):
+ eq = self.make_eventqueue()
+ eq.keymap = {}
+
+ def _event(evt, data, raw=None):
+ r = raw if raw is not None else data.encode(eq.encoding)
+ e = Event(evt, data, r)
+ return e
+
+ def _push(keys):
+ for k in keys:
+ eq.push(k)
+
+ self.assertIsInstance("ñ", str)
+
+ # If an exception happens during push, the existing events must be
+ # preserved and we can continue to push.
+ _push(b"b")
+ with self.assertRaises(AssertionError):
+ _push("ñ")
+ _push(b"a")
+
+ self.assertEqual(eq.get(), _event("key", "b"))
+ self.assertEqual(eq.get(), _event("key", "a"))
@unittest.skipIf(support.MS_WINDOWS, "No Unix event queue on Windows")
diff --git a/Lib/test/test_pyrepl/test_interact.py b/Lib/test/test_pyrepl/test_interact.py
index a20719033fc..8c0eeab6dca 100644
--- a/Lib/test/test_pyrepl/test_interact.py
+++ b/Lib/test/test_pyrepl/test_interact.py
@@ -113,7 +113,7 @@ class TestSimpleInteract(unittest.TestCase):
r = """
def f(x, x): ...
^
-SyntaxError: duplicate argument 'x' in function definition"""
+SyntaxError: duplicate parameter 'x' in function definition"""
self.assertIn(r, f.getvalue())
def test_runsource_shows_syntax_error_for_failed_compilation(self):
diff --git a/Lib/test/test_pyrepl/test_pyrepl.py b/Lib/test/test_pyrepl/test_pyrepl.py
index 93029ab6e08..98bae7dd703 100644
--- a/Lib/test/test_pyrepl/test_pyrepl.py
+++ b/Lib/test/test_pyrepl/test_pyrepl.py
@@ -8,9 +8,10 @@ import select
import subprocess
import sys
import tempfile
+from pkgutil import ModuleInfo
from unittest import TestCase, skipUnless, skipIf
from unittest.mock import patch
-from test.support import force_not_colorized, make_clean_env
+from test.support import force_not_colorized, make_clean_env, Py_DEBUG
from test.support import SHORT_TIMEOUT, STDLIB_DIR
from test.support.import_helper import import_module
from test.support.os_helper import EnvironmentVarGuard, unlink
@@ -452,6 +453,11 @@ class TestPyReplAutoindent(TestCase):
)
# fmt: on
+ events = code_to_events(input_code)
+ reader = self.prepare_reader(events)
+ output = multiline_input(reader)
+ self.assertEqual(output, output_code)
+
def test_auto_indent_continuation(self):
# auto indenting according to previous user indentation
# fmt: off
@@ -912,7 +918,14 @@ class TestPyReplCompleter(TestCase):
class TestPyReplModuleCompleter(TestCase):
def setUp(self):
+ import importlib
+ # Make iter_modules() search only the standard library.
+ # This makes the test more reliable in case there are
+ # other user packages/scripts on PYTHONPATH which can
+ # interfere with the completions.
+ lib_path = os.path.dirname(importlib.__path__[0])
self._saved_sys_path = sys.path
+ sys.path = [lib_path]
def tearDown(self):
sys.path = self._saved_sys_path
@@ -920,19 +933,12 @@ class TestPyReplModuleCompleter(TestCase):
def prepare_reader(self, events, namespace):
console = FakeConsole(events)
config = ReadlineConfig()
+ config.module_completer = ModuleCompleter(namespace)
config.readline_completer = rlcompleter.Completer(namespace).complete
reader = ReadlineAlikeReader(console=console, config=config)
return reader
def test_import_completions(self):
- import importlib
- # Make iter_modules() search only the standard library.
- # This makes the test more reliable in case there are
- # other user packages/scripts on PYTHONPATH which can
- # intefere with the completions.
- lib_path = os.path.dirname(importlib.__path__[0])
- sys.path = [lib_path]
-
cases = (
("import path\t\n", "import pathlib"),
("import importlib.\t\tres\t\n", "import importlib.resources"),
@@ -954,10 +960,38 @@ class TestPyReplModuleCompleter(TestCase):
output = reader.readline()
self.assertEqual(output, expected)
- def test_relative_import_completions(self):
+ @patch("pkgutil.iter_modules", lambda: [ModuleInfo(None, "public", True),
+ ModuleInfo(None, "_private", True)])
+ @patch("sys.builtin_module_names", ())
+ def test_private_completions(self):
+ cases = (
+ # Return public methods by default
+ ("import \t\n", "import public"),
+ ("from \t\n", "from public"),
+ # Return private methods if explicitly specified
+ ("import _\t\n", "import _private"),
+ ("from _\t\n", "from _private"),
+ )
+ for code, expected in cases:
+ with self.subTest(code=code):
+ events = code_to_events(code)
+ reader = self.prepare_reader(events, namespace={})
+ output = reader.readline()
+ self.assertEqual(output, expected)
+
+ @patch(
+ "_pyrepl._module_completer.ModuleCompleter.iter_submodules",
+ lambda *_: [
+ ModuleInfo(None, "public", True),
+ ModuleInfo(None, "_private", True),
+ ],
+ )
+ def test_sub_module_private_completions(self):
cases = (
- ("from .readl\t\n", "from .readline"),
- ("from . import readl\t\n", "from . import readline"),
+ # Return public methods by default
+ ("from foo import \t\n", "from foo import public"),
+ # Return private methods if explicitly specified
+ ("from foo import _\t\n", "from foo import _private"),
)
for code, expected in cases:
with self.subTest(code=code):
@@ -966,8 +1000,42 @@ class TestPyReplModuleCompleter(TestCase):
output = reader.readline()
self.assertEqual(output, expected)
- @patch("pkgutil.iter_modules", lambda: [(None, 'valid_name', None),
- (None, 'invalid-name', None)])
+ def test_builtin_completion_top_level(self):
+ import importlib
+ # Make iter_modules() search only the standard library.
+ # This makes the test more reliable in case there are
+ # other user packages/scripts on PYTHONPATH which can
+ # interfere with the completions.
+ lib_path = os.path.dirname(importlib.__path__[0])
+ sys.path = [lib_path]
+
+ cases = (
+ ("import bui\t\n", "import builtins"),
+ ("from bui\t\n", "from builtins"),
+ )
+ for code, expected in cases:
+ with self.subTest(code=code):
+ events = code_to_events(code)
+ reader = self.prepare_reader(events, namespace={})
+ output = reader.readline()
+ self.assertEqual(output, expected)
+
+ def test_relative_import_completions(self):
+ cases = (
+ (None, "from .readl\t\n", "from .readl"),
+ (None, "from . import readl\t\n", "from . import readl"),
+ ("_pyrepl", "from .readl\t\n", "from .readline"),
+ ("_pyrepl", "from . import readl\t\n", "from . import readline"),
+ )
+ for package, code, expected in cases:
+ with self.subTest(code=code):
+ events = code_to_events(code)
+ reader = self.prepare_reader(events, namespace={"__package__": package})
+ output = reader.readline()
+ self.assertEqual(output, expected)
+
+ @patch("pkgutil.iter_modules", lambda: [ModuleInfo(None, "valid_name", True),
+ ModuleInfo(None, "invalid-name", True)])
def test_invalid_identifiers(self):
# Make sure modules which are not valid identifiers
# are not suggested as those cannot be imported via 'import'.
@@ -983,6 +1051,19 @@ class TestPyReplModuleCompleter(TestCase):
output = reader.readline()
self.assertEqual(output, expected)
+ def test_no_fallback_on_regular_completion(self):
+ cases = (
+ ("import pri\t\n", "import pri"),
+ ("from pri\t\n", "from pri"),
+ ("from typing import Na\t\n", "from typing import Na"),
+ )
+ for code, expected in cases:
+ with self.subTest(code=code):
+ events = code_to_events(code)
+ reader = self.prepare_reader(events, namespace={})
+ output = reader.readline()
+ self.assertEqual(output, expected)
+
def test_get_path_and_prefix(self):
cases = (
('', ('', '')),
@@ -1050,11 +1131,15 @@ class TestPyReplModuleCompleter(TestCase):
self.assertEqual(actual, parsed)
# The parser should not get tripped up by any
# other preceding statements
- code = f'import xyz\n{code}'
- with self.subTest(code=code):
+ _code = f'import xyz\n{code}'
+ parser = ImportParser(_code)
+ actual = parser.parse()
+ with self.subTest(code=_code):
self.assertEqual(actual, parsed)
- code = f'import xyz;{code}'
- with self.subTest(code=code):
+ _code = f'import xyz;{code}'
+ parser = ImportParser(_code)
+ actual = parser.parse()
+ with self.subTest(code=_code):
self.assertEqual(actual, parsed)
def test_parse_error(self):
@@ -1327,7 +1412,7 @@ class TestMain(ReplTestCase):
)
@force_not_colorized
- def _run_repl_globals_test(self, expectations, *, as_file=False, as_module=False):
+ def _run_repl_globals_test(self, expectations, *, as_file=False, as_module=False, pythonstartup=False):
clean_env = make_clean_env()
clean_env["NO_COLOR"] = "1" # force_not_colorized doesn't touch subprocesses
@@ -1336,9 +1421,13 @@ class TestMain(ReplTestCase):
blue.mkdir()
mod = blue / "calx.py"
mod.write_text("FOO = 42", encoding="utf-8")
+ startup = blue / "startup.py"
+ startup.write_text("BAR = 64", encoding="utf-8")
commands = [
"print(f'^{" + var + "=}')" for var in expectations
] + ["exit()"]
+ if pythonstartup:
+ clean_env["PYTHONSTARTUP"] = str(startup)
if as_file and as_module:
self.fail("as_file and as_module are mutually exclusive")
elif as_file:
@@ -1357,7 +1446,13 @@ class TestMain(ReplTestCase):
skip=True,
)
else:
- self.fail("Choose one of as_file or as_module")
+ output, exit_code = self.run_repl(
+ commands,
+ cmdline_args=[],
+ env=clean_env,
+ cwd=td,
+ skip=True,
+ )
self.assertEqual(exit_code, 0)
for var, expected in expectations.items():
@@ -1370,6 +1465,23 @@ class TestMain(ReplTestCase):
self.assertNotIn("Exception", output)
self.assertNotIn("Traceback", output)
+ def test_globals_initialized_as_default(self):
+ expectations = {
+ "__name__": "'__main__'",
+ "__package__": "None",
+ # "__file__" is missing in -i, like in the basic REPL
+ }
+ self._run_repl_globals_test(expectations)
+
+ def test_globals_initialized_from_pythonstartup(self):
+ expectations = {
+ "BAR": "64",
+ "__name__": "'__main__'",
+ "__package__": "None",
+ # "__file__" is missing in -i, like in the basic REPL
+ }
+ self._run_repl_globals_test(expectations, pythonstartup=True)
+
def test_inspect_keeps_globals_from_inspected_file(self):
expectations = {
"FOO": "42",
@@ -1379,6 +1491,16 @@ class TestMain(ReplTestCase):
}
self._run_repl_globals_test(expectations, as_file=True)
+ def test_inspect_keeps_globals_from_inspected_file_with_pythonstartup(self):
+ expectations = {
+ "FOO": "42",
+ "BAR": "64",
+ "__name__": "'__main__'",
+ "__package__": "None",
+ # "__file__" is missing in -i, like in the basic REPL
+ }
+ self._run_repl_globals_test(expectations, as_file=True, pythonstartup=True)
+
def test_inspect_keeps_globals_from_inspected_module(self):
expectations = {
"FOO": "42",
@@ -1388,26 +1510,32 @@ class TestMain(ReplTestCase):
}
self._run_repl_globals_test(expectations, as_module=True)
+ def test_inspect_keeps_globals_from_inspected_module_with_pythonstartup(self):
+ expectations = {
+ "FOO": "42",
+ "BAR": "64",
+ "__name__": "'__main__'",
+ "__package__": "'blue'",
+ "__file__": re.compile(r"^'.*calx.py'$"),
+ }
+ self._run_repl_globals_test(expectations, as_module=True, pythonstartup=True)
+
@force_not_colorized
def test_python_basic_repl(self):
env = os.environ.copy()
- commands = ("from test.support import initialized_with_pyrepl\n"
- "initialized_with_pyrepl()\n"
- "exit()\n")
-
+ pyrepl_commands = "clear\nexit()\n"
env.pop("PYTHON_BASIC_REPL", None)
- output, exit_code = self.run_repl(commands, env=env, skip=True)
+ output, exit_code = self.run_repl(pyrepl_commands, env=env, skip=True)
self.assertEqual(exit_code, 0)
- self.assertIn("True", output)
- self.assertNotIn("False", output)
self.assertNotIn("Exception", output)
+ self.assertNotIn("NameError", output)
self.assertNotIn("Traceback", output)
+ basic_commands = "help\nexit()\n"
env["PYTHON_BASIC_REPL"] = "1"
- output, exit_code = self.run_repl(commands, env=env)
+ output, exit_code = self.run_repl(basic_commands, env=env)
self.assertEqual(exit_code, 0)
- self.assertIn("False", output)
- self.assertNotIn("True", output)
+ self.assertIn("Type help() for interactive help", output)
self.assertNotIn("Exception", output)
self.assertNotIn("Traceback", output)
@@ -1544,6 +1672,17 @@ class TestMain(ReplTestCase):
self.assertEqual(exit_code, 0)
self.assertNotIn("TypeError", output)
+ @force_not_colorized
+ def test_non_string_suggestion_candidates(self):
+ commands = ("import runpy\n"
+ "runpy._run_module_code('blech', {0: '', 'bluch': ''}, '')\n"
+ "exit()\n")
+
+ output, exit_code = self.run_repl(commands)
+ self.assertEqual(exit_code, 0)
+ self.assertNotIn("all elements in 'candidates' must be strings", output)
+ self.assertIn("bluch", output)
+
def test_readline_history_file(self):
# skip, if readline module is not available
readline = import_module('readline')
@@ -1605,3 +1744,16 @@ class TestMain(ReplTestCase):
# Extra stuff (newline and `exit` rewrites) are necessary
# because of how run_repl works.
self.assertNotIn(">>> \n>>> >>>", cleaned_output)
+
+ @skipUnless(Py_DEBUG, '-X showrefcount requires a Python debug build')
+ def test_showrefcount(self):
+ env = os.environ.copy()
+ env.pop("PYTHON_BASIC_REPL", "")
+ output, _ = self.run_repl("1\n1+2\nexit()\n", cmdline_args=['-Xshowrefcount'], env=env)
+ matches = re.findall(r'\[-?\d+ refs, \d+ blocks\]', output)
+ self.assertEqual(len(matches), 3)
+
+ env["PYTHON_BASIC_REPL"] = "1"
+ output, _ = self.run_repl("1\n1+2\nexit()\n", cmdline_args=['-Xshowrefcount'], env=env)
+ matches = re.findall(r'\[-?\d+ refs, \d+ blocks\]', output)
+ self.assertEqual(len(matches), 3)
diff --git a/Lib/test/test_pyrepl/test_reader.py b/Lib/test/test_pyrepl/test_reader.py
index 8d7fcf538d2..1f655264f1c 100644
--- a/Lib/test/test_pyrepl/test_reader.py
+++ b/Lib/test/test_pyrepl/test_reader.py
@@ -4,20 +4,21 @@ import rlcompleter
from textwrap import dedent
from unittest import TestCase
from unittest.mock import MagicMock
+from test.support import force_colorized_test_class, force_not_colorized_test_class
from .support import handle_all_events, handle_events_narrow_console
from .support import ScreenEqualMixin, code_to_events
-from .support import prepare_console, reader_force_colors
-from .support import reader_no_colors as prepare_reader
+from .support import prepare_reader, prepare_console
from _pyrepl.console import Event
from _pyrepl.reader import Reader
-from _colorize import theme
+from _colorize import default_theme
-overrides = {"RESET": "z", "SOFT_KEYWORD": "K"}
-colors = {overrides.get(k, k[0].lower()): v for k, v in theme.items()}
+overrides = {"reset": "z", "soft_keyword": "K"}
+colors = {overrides.get(k, k[0].lower()): v for k, v in default_theme.syntax.items()}
+@force_not_colorized_test_class
class TestReader(ScreenEqualMixin, TestCase):
def test_calc_screen_wrap_simple(self):
events = code_to_events(10 * "a")
@@ -127,13 +128,6 @@ class TestReader(ScreenEqualMixin, TestCase):
reader.setpos_from_xy(0, 0)
self.assertEqual(reader.pos, 0)
- def test_control_characters(self):
- code = 'flag = "🏳️‍🌈"'
- events = code_to_events(code)
- reader, _ = handle_all_events(events, prepare_reader=reader_force_colors)
- self.assert_screen_equal(reader, 'flag = "🏳️\\u200d🌈"', clean=True)
- self.assert_screen_equal(reader, 'flag {o}={z} {s}"🏳️\\u200d🌈"{z}'.format(**colors))
-
def test_setpos_from_xy_multiple_lines(self):
# fmt: off
code = (
@@ -364,6 +358,8 @@ class TestReader(ScreenEqualMixin, TestCase):
reader.setpos_from_xy(8, 0)
self.assertEqual(reader.pos, 7)
+@force_colorized_test_class
+class TestReaderInColor(ScreenEqualMixin, TestCase):
def test_syntax_highlighting_basic(self):
code = dedent(
"""\
@@ -403,7 +399,7 @@ class TestReader(ScreenEqualMixin, TestCase):
)
expected_sync = expected.format(a="", **colors)
events = code_to_events(code)
- reader, _ = handle_all_events(events, prepare_reader=reader_force_colors)
+ reader, _ = handle_all_events(events)
self.assert_screen_equal(reader, code, clean=True)
self.assert_screen_equal(reader, expected_sync)
self.assertEqual(reader.pos, 2**7 + 2**8)
@@ -416,7 +412,7 @@ class TestReader(ScreenEqualMixin, TestCase):
[Event(evt="key", data="up", raw=bytearray(b"\x1bOA"))] * 13,
code_to_events("async "),
)
- reader, _ = handle_all_events(more_events, prepare_reader=reader_force_colors)
+ reader, _ = handle_all_events(more_events)
self.assert_screen_equal(reader, expected_async)
self.assertEqual(reader.pos, 21)
self.assertEqual(reader.cxy, (6, 1))
@@ -433,7 +429,7 @@ class TestReader(ScreenEqualMixin, TestCase):
"""
).format(**colors)
events = code_to_events(code)
- reader, _ = handle_all_events(events, prepare_reader=reader_force_colors)
+ reader, _ = handle_all_events(events)
self.assert_screen_equal(reader, code, clean=True)
self.assert_screen_equal(reader, expected)
@@ -451,7 +447,7 @@ class TestReader(ScreenEqualMixin, TestCase):
"""
).format(**colors)
events = code_to_events(code)
- reader, _ = handle_all_events(events, prepare_reader=reader_force_colors)
+ reader, _ = handle_all_events(events)
self.assert_screen_equal(reader, code, clean=True)
self.assert_screen_equal(reader, expected)
@@ -471,7 +467,7 @@ class TestReader(ScreenEqualMixin, TestCase):
"""
).format(**colors)
events = code_to_events(code)
- reader, _ = handle_all_events(events, prepare_reader=reader_force_colors)
+ reader, _ = handle_all_events(events)
self.assert_screen_equal(reader, code, clean=True)
self.assert_screen_equal(reader, expected)
@@ -497,6 +493,64 @@ class TestReader(ScreenEqualMixin, TestCase):
"""
).format(OB="{", CB="}", **colors)
events = code_to_events(code)
- reader, _ = handle_all_events(events, prepare_reader=reader_force_colors)
+ reader, _ = handle_all_events(events)
self.assert_screen_equal(reader, code, clean=True)
self.assert_screen_equal(reader, expected)
+
+ def test_syntax_highlighting_indentation_error(self):
+ code = dedent(
+ """\
+ def unfinished_function():
+ var = 1
+ oops
+ """
+ )
+ expected = dedent(
+ """\
+ {k}def{z} {d}unfinished_function{z}{o}({z}{o}){z}{o}:{z}
+ var {o}={z} {n}1{z}
+ oops
+ """
+ ).format(**colors)
+ events = code_to_events(code)
+ reader, _ = handle_all_events(events)
+ self.assert_screen_equal(reader, code, clean=True)
+ self.assert_screen_equal(reader, expected)
+
+ def test_syntax_highlighting_literal_brace_in_fstring_or_tstring(self):
+ code = dedent(
+ """\
+ f"{{"
+ f"}}"
+ f"a{{b"
+ f"a}}b"
+ f"a{{b}}c"
+ t"a{{b}}c"
+ f"{{{0}}}"
+ f"{ {0} }"
+ """
+ )
+ expected = dedent(
+ """\
+ {s}f"{z}{s}<<{z}{s}"{z}
+ {s}f"{z}{s}>>{z}{s}"{z}
+ {s}f"{z}{s}a<<{z}{s}b{z}{s}"{z}
+ {s}f"{z}{s}a>>{z}{s}b{z}{s}"{z}
+ {s}f"{z}{s}a<<{z}{s}b>>{z}{s}c{z}{s}"{z}
+ {s}t"{z}{s}a<<{z}{s}b>>{z}{s}c{z}{s}"{z}
+ {s}f"{z}{s}<<{z}{o}<{z}{n}0{z}{o}>{z}{s}>>{z}{s}"{z}
+ {s}f"{z}{o}<{z} {o}<{z}{n}0{z}{o}>{z} {o}>{z}{s}"{z}
+ """
+ ).format(**colors).replace("<", "{").replace(">", "}")
+ events = code_to_events(code)
+ reader, _ = handle_all_events(events)
+ self.assert_screen_equal(reader, code, clean=True)
+ self.maxDiff=None
+ self.assert_screen_equal(reader, expected)
+
+ def test_control_characters(self):
+ code = 'flag = "🏳️‍🌈"'
+ events = code_to_events(code)
+ reader, _ = handle_all_events(events)
+ self.assert_screen_equal(reader, 'flag = "🏳️\\u200d🌈"', clean=True)
+ self.assert_screen_equal(reader, 'flag {o}={z} {s}"🏳️\\u200d🌈"{z}'.format(**colors))
diff --git a/Lib/test/test_pyrepl/test_unix_console.py b/Lib/test/test_pyrepl/test_unix_console.py
index 7acb84a94f7..b3f7dc028fe 100644
--- a/Lib/test/test_pyrepl/test_unix_console.py
+++ b/Lib/test/test_pyrepl/test_unix_console.py
@@ -3,11 +3,12 @@ import os
import sys
import unittest
from functools import partial
-from test.support import os_helper
+from test.support import os_helper, force_not_colorized_test_class
+
from unittest import TestCase
from unittest.mock import MagicMock, call, patch, ANY
-from .support import handle_all_events, code_to_events, reader_no_colors
+from .support import handle_all_events, code_to_events
try:
from _pyrepl.console import Event
@@ -19,6 +20,7 @@ except ImportError:
def unix_console(events, **kwargs):
console = UnixConsole()
console.get_event = MagicMock(side_effect=events)
+ console.getpending = MagicMock(return_value=Event("key", ""))
height = kwargs.get("height", 25)
width = kwargs.get("width", 80)
@@ -33,12 +35,10 @@ def unix_console(events, **kwargs):
handle_events_unix_console = partial(
handle_all_events,
- prepare_reader=reader_no_colors,
prepare_console=unix_console,
)
handle_events_narrow_unix_console = partial(
handle_all_events,
- prepare_reader=reader_no_colors,
prepare_console=partial(unix_console, width=5),
)
handle_events_short_unix_console = partial(
@@ -120,6 +120,7 @@ TERM_CAPABILITIES = {
)
@patch("termios.tcsetattr", lambda a, b, c: None)
@patch("os.write")
+@force_not_colorized_test_class
class TestConsole(TestCase):
def test_simple_addition(self, _os_write):
code = "12+34"
@@ -255,9 +256,7 @@ class TestConsole(TestCase):
# fmt: on
events = itertools.chain(code_to_events(code))
- reader, console = handle_events_short_unix_console(
- events, prepare_reader=reader_no_colors
- )
+ reader, console = handle_events_short_unix_console(events)
console.height = 2
console.getheightwidth = MagicMock(lambda _: (2, 80))
diff --git a/Lib/test/test_pyrepl/test_windows_console.py b/Lib/test/test_pyrepl/test_windows_console.py
index e95fec46a85..f9607e02c60 100644
--- a/Lib/test/test_pyrepl/test_windows_console.py
+++ b/Lib/test/test_pyrepl/test_windows_console.py
@@ -7,12 +7,13 @@ if sys.platform != "win32":
import itertools
from functools import partial
+from test.support import force_not_colorized_test_class
from typing import Iterable
from unittest import TestCase
from unittest.mock import MagicMock, call
from .support import handle_all_events, code_to_events
-from .support import reader_no_colors as default_prepare_reader
+from .support import prepare_reader as default_prepare_reader
try:
from _pyrepl.console import Event, Console
@@ -24,14 +25,17 @@ try:
MOVE_DOWN,
ERASE_IN_LINE,
)
+ import _pyrepl.windows_console as wc
except ImportError:
pass
+@force_not_colorized_test_class
class WindowsConsoleTests(TestCase):
def console(self, events, **kwargs) -> Console:
console = WindowsConsole()
console.get_event = MagicMock(side_effect=events)
+ console.getpending = MagicMock(return_value=Event("key", ""))
console.wait = MagicMock()
console._scroll = MagicMock()
console._hide_cursor = MagicMock()
@@ -350,8 +354,227 @@ class WindowsConsoleTests(TestCase):
Event(evt="key", data='\x1a', raw=bytearray(b'\x1a')),
],
)
- reader, _ = self.handle_events_narrow(events)
+ reader, con = self.handle_events_narrow(events)
self.assertEqual(reader.cxy, (2, 3))
+ con.restore()
+
+
+class WindowsConsoleGetEventTests(TestCase):
+ # Virtual-Key Codes: https://learn.microsoft.com/en-us/windows/win32/inputdev/virtual-key-codes
+ VK_BACK = 0x08
+ VK_RETURN = 0x0D
+ VK_LEFT = 0x25
+ VK_7 = 0x37
+ VK_M = 0x4D
+ # Used for miscellaneous characters; it can vary by keyboard.
+ # For the US standard keyboard, the '" key.
+ # For the German keyboard, the Ä key.
+ VK_OEM_7 = 0xDE
+
+ # State of control keys: https://learn.microsoft.com/en-us/windows/console/key-event-record-str
+ RIGHT_ALT_PRESSED = 0x0001
+ RIGHT_CTRL_PRESSED = 0x0004
+ LEFT_ALT_PRESSED = 0x0002
+ LEFT_CTRL_PRESSED = 0x0008
+ ENHANCED_KEY = 0x0100
+ SHIFT_PRESSED = 0x0010
+
+
+ def get_event(self, input_records, **kwargs) -> Console:
+ self.console = WindowsConsole(encoding='utf-8')
+ self.mock = MagicMock(side_effect=input_records)
+ self.console._read_input = self.mock
+ self.console._WindowsConsole__vt_support = kwargs.get("vt_support",
+ False)
+ self.console.wait = MagicMock(return_value=True)
+ event = self.console.get_event(block=False)
+ return event
+
+ def get_input_record(self, unicode_char, vcode=0, control=0):
+ return wc.INPUT_RECORD(
+ wc.KEY_EVENT,
+ wc.ConsoleEvent(KeyEvent=
+ wc.KeyEvent(
+ bKeyDown=True,
+ wRepeatCount=1,
+ wVirtualKeyCode=vcode,
+ wVirtualScanCode=0, # not used
+ uChar=wc.Char(unicode_char),
+ dwControlKeyState=control
+ )))
+
+ def test_EmptyBuffer(self):
+ self.assertEqual(self.get_event([None]), None)
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_WINDOW_BUFFER_SIZE_EVENT(self):
+ ir = wc.INPUT_RECORD(
+ wc.WINDOW_BUFFER_SIZE_EVENT,
+ wc.ConsoleEvent(WindowsBufferSizeEvent=
+ wc.WindowsBufferSizeEvent(
+ wc._COORD(0, 0))))
+ self.assertEqual(self.get_event([ir]), Event("resize", ""))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_KEY_EVENT_up_ignored(self):
+ ir = wc.INPUT_RECORD(
+ wc.KEY_EVENT,
+ wc.ConsoleEvent(KeyEvent=
+ wc.KeyEvent(bKeyDown=False)))
+ self.assertEqual(self.get_event([ir]), None)
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_unhandled_events(self):
+ for event in (wc.FOCUS_EVENT, wc.MENU_EVENT, wc.MOUSE_EVENT):
+ ir = wc.INPUT_RECORD(
+ event,
+ # fake data, nothing is read except bKeyDown
+ wc.ConsoleEvent(KeyEvent=
+ wc.KeyEvent(bKeyDown=False)))
+ self.assertEqual(self.get_event([ir]), None)
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_enter(self):
+ ir = self.get_input_record("\r", self.VK_RETURN)
+ self.assertEqual(self.get_event([ir]), Event("key", "\n"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_backspace(self):
+ ir = self.get_input_record("\x08", self.VK_BACK)
+ self.assertEqual(
+ self.get_event([ir]), Event("key", "backspace"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_m(self):
+ ir = self.get_input_record("m", self.VK_M)
+ self.assertEqual(self.get_event([ir]), Event("key", "m"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_M(self):
+ ir = self.get_input_record("M", self.VK_M, self.SHIFT_PRESSED)
+ self.assertEqual(self.get_event([ir]), Event("key", "M"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_left(self):
+ # VK_LEFT is sent as ENHANCED_KEY
+ ir = self.get_input_record("\x00", self.VK_LEFT, self.ENHANCED_KEY)
+ self.assertEqual(self.get_event([ir]), Event("key", "left"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_left_RIGHT_CTRL_PRESSED(self):
+ ir = self.get_input_record(
+ "\x00", self.VK_LEFT, self.RIGHT_CTRL_PRESSED | self.ENHANCED_KEY)
+ self.assertEqual(
+ self.get_event([ir]), Event("key", "ctrl left"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_left_LEFT_CTRL_PRESSED(self):
+ ir = self.get_input_record(
+ "\x00", self.VK_LEFT, self.LEFT_CTRL_PRESSED | self.ENHANCED_KEY)
+ self.assertEqual(
+ self.get_event([ir]), Event("key", "ctrl left"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_left_RIGHT_ALT_PRESSED(self):
+ ir = self.get_input_record(
+ "\x00", self.VK_LEFT, self.RIGHT_ALT_PRESSED | self.ENHANCED_KEY)
+ self.assertEqual(self.get_event([ir]), Event(evt="key", data="\033"))
+ self.assertEqual(
+ self.console.get_event(), Event("key", "left"))
+ # self.mock is not called again, since the second time we read from the
+ # command queue
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_left_LEFT_ALT_PRESSED(self):
+ ir = self.get_input_record(
+ "\x00", self.VK_LEFT, self.LEFT_ALT_PRESSED | self.ENHANCED_KEY)
+ self.assertEqual(self.get_event([ir]), Event(evt="key", data="\033"))
+ self.assertEqual(
+ self.console.get_event(), Event("key", "left"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_m_LEFT_ALT_PRESSED_and_LEFT_CTRL_PRESSED(self):
+ # For the shift keys, Windows does not send anything when
+ # ALT and CTRL are both pressed, so let's test with VK_M.
+ # get_event() receives this input, but does not
+ # generate an event.
+ # This is for e.g. an English keyboard layout, for a
+ # German layout this returns `µ`, see test_AltGr_m.
+ ir = self.get_input_record(
+ "\x00", self.VK_M, self.LEFT_ALT_PRESSED | self.LEFT_CTRL_PRESSED)
+ self.assertEqual(self.get_event([ir]), None)
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_m_LEFT_ALT_PRESSED(self):
+ ir = self.get_input_record(
+ "m", vcode=self.VK_M, control=self.LEFT_ALT_PRESSED)
+ self.assertEqual(self.get_event([ir]), Event(evt="key", data="\033"))
+ self.assertEqual(self.console.get_event(), Event("key", "m"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_m_RIGHT_ALT_PRESSED(self):
+ ir = self.get_input_record(
+ "m", vcode=self.VK_M, control=self.RIGHT_ALT_PRESSED)
+ self.assertEqual(self.get_event([ir]), Event(evt="key", data="\033"))
+ self.assertEqual(self.console.get_event(), Event("key", "m"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_AltGr_7(self):
+ # E.g. on a German keyboard layout, '{' is entered via
+ # AltGr + 7, where AltGr is the right Alt key on the keyboard.
+ # In this case, Windows automatically sets
+ # RIGHT_ALT_PRESSED = 0x0001 + LEFT_CTRL_PRESSED = 0x0008
+ # This can also be entered like
+ # LeftAlt + LeftCtrl + 7 or
+ # LeftAlt + RightCtrl + 7
+ # See https://learn.microsoft.com/en-us/windows/console/key-event-record-str
+ # https://learn.microsoft.com/en-us/windows/win32/api/winuser/nf-winuser-vkkeyscanw
+ ir = self.get_input_record(
+ "{", vcode=self.VK_7,
+ control=self.RIGHT_ALT_PRESSED | self.LEFT_CTRL_PRESSED)
+ self.assertEqual(self.get_event([ir]), Event("key", "{"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_AltGr_m(self):
+ # E.g. on a German keyboard layout, this yields 'µ'
+ # Let's use LEFT_ALT_PRESSED and RIGHT_CTRL_PRESSED this
+ # time, to cover that, too. See above in test_AltGr_7.
+ ir = self.get_input_record(
+ "µ", vcode=self.VK_M, control=self.LEFT_ALT_PRESSED | self.RIGHT_CTRL_PRESSED)
+ self.assertEqual(self.get_event([ir]), Event("key", "µ"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_umlaut_a_german(self):
+ ir = self.get_input_record("ä", self.VK_OEM_7)
+ self.assertEqual(self.get_event([ir]), Event("key", "ä"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ # virtual terminal tests
+ # Note: wVirtualKeyCode, wVirtualScanCode and dwControlKeyState
+ # are always zero in this case.
+ # "\r" and backspace are handled specially, everything else
+ # is handled in "elif self.__vt_support:" in WindowsConsole.get_event().
+ # Hence, only one regular key ("m") and a terminal sequence
+ # are sufficient to test here, the real tests happen in test_eventqueue
+ # and test_keymap.
+
+ def test_enter_vt(self):
+ ir = self.get_input_record("\r")
+ self.assertEqual(self.get_event([ir], vt_support=True),
+ Event("key", "\n"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_backspace_vt(self):
+ ir = self.get_input_record("\x7f")
+ self.assertEqual(self.get_event([ir], vt_support=True),
+ Event("key", "backspace", b"\x7f"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_up_vt(self):
+ irs = [self.get_input_record(x) for x in "\x1b[A"]
+ self.assertEqual(self.get_event(irs, vt_support=True),
+ Event(evt='key', data='up', raw=bytearray(b'\x1b[A')))
+ self.assertEqual(self.mock.call_count, 3)
if __name__ == "__main__":
diff --git a/Lib/test/test_queue.py b/Lib/test/test_queue.py
index 7f4fe357034..c855fb8fe2b 100644
--- a/Lib/test/test_queue.py
+++ b/Lib/test/test_queue.py
@@ -6,7 +6,7 @@ import threading
import time
import unittest
import weakref
-from test.support import gc_collect
+from test.support import gc_collect, bigmemtest
from test.support import import_helper
from test.support import threading_helper
@@ -963,33 +963,33 @@ class BaseSimpleQueueTest:
# One producer, one consumer => results appended in well-defined order
self.assertEqual(results, inputs)
- def test_many_threads(self):
+ @bigmemtest(size=50, memuse=100*2**20, dry_run=False)
+ def test_many_threads(self, size):
# Test multiple concurrent put() and get()
- N = 50
q = self.q
inputs = list(range(10000))
- results = self.run_threads(N, q, inputs, self.feed, self.consume)
+ results = self.run_threads(size, q, inputs, self.feed, self.consume)
# Multiple consumers without synchronization append the
# results in random order
self.assertEqual(sorted(results), inputs)
- def test_many_threads_nonblock(self):
+ @bigmemtest(size=50, memuse=100*2**20, dry_run=False)
+ def test_many_threads_nonblock(self, size):
# Test multiple concurrent put() and get(block=False)
- N = 50
q = self.q
inputs = list(range(10000))
- results = self.run_threads(N, q, inputs,
+ results = self.run_threads(size, q, inputs,
self.feed, self.consume_nonblock)
self.assertEqual(sorted(results), inputs)
- def test_many_threads_timeout(self):
+ @bigmemtest(size=50, memuse=100*2**20, dry_run=False)
+ def test_many_threads_timeout(self, size):
# Test multiple concurrent put() and get(timeout=...)
- N = 50
q = self.q
inputs = list(range(1000))
- results = self.run_threads(N, q, inputs,
+ results = self.run_threads(size, q, inputs,
self.feed, self.consume_timeout)
self.assertEqual(sorted(results), inputs)
diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py
index 96f6cc86219..0217ebd132b 100644
--- a/Lib/test/test_random.py
+++ b/Lib/test/test_random.py
@@ -14,6 +14,15 @@ from test import support
from fractions import Fraction
from collections import abc, Counter
+
+class MyIndex:
+ def __init__(self, value):
+ self.value = value
+
+ def __index__(self):
+ return self.value
+
+
class TestBasicOps:
# Superclass with tests common to all generators.
# Subclasses must arrange for self.gen to retrieve the Random instance
@@ -142,6 +151,7 @@ class TestBasicOps:
# Exception raised if size of sample exceeds that of population
self.assertRaises(ValueError, self.gen.sample, population, N+1)
self.assertRaises(ValueError, self.gen.sample, [], -1)
+ self.assertRaises(TypeError, self.gen.sample, population, 1.0)
def test_sample_distribution(self):
# For the entire allowable range of 0 <= k <= N, validate that
@@ -259,6 +269,7 @@ class TestBasicOps:
choices(data, range(4), k=5),
choices(k=5, population=data, weights=range(4)),
choices(k=5, population=data, cum_weights=range(4)),
+ choices(data, k=MyIndex(5)),
]:
self.assertEqual(len(sample), 5)
self.assertEqual(type(sample), list)
@@ -369,118 +380,40 @@ class TestBasicOps:
self.assertEqual(x1, x2)
self.assertEqual(y1, y2)
+ @support.requires_IEEE_754
+ def test_53_bits_per_float(self):
+ span = 2 ** 53
+ cum = 0
+ for i in range(100):
+ cum |= int(self.gen.random() * span)
+ self.assertEqual(cum, span-1)
+
def test_getrandbits(self):
+ getrandbits = self.gen.getrandbits
# Verify ranges
for k in range(1, 1000):
- self.assertTrue(0 <= self.gen.getrandbits(k) < 2**k)
- self.assertEqual(self.gen.getrandbits(0), 0)
+ self.assertTrue(0 <= getrandbits(k) < 2**k)
+ self.assertEqual(getrandbits(0), 0)
# Verify all bits active
- getbits = self.gen.getrandbits
for span in [1, 2, 3, 4, 31, 32, 32, 52, 53, 54, 119, 127, 128, 129]:
all_bits = 2**span-1
cum = 0
cpl_cum = 0
for i in range(100):
- v = getbits(span)
+ v = getrandbits(span)
cum |= v
cpl_cum |= all_bits ^ v
self.assertEqual(cum, all_bits)
self.assertEqual(cpl_cum, all_bits)
# Verify argument checking
- self.assertRaises(TypeError, self.gen.getrandbits)
- self.assertRaises(TypeError, self.gen.getrandbits, 1, 2)
- self.assertRaises(ValueError, self.gen.getrandbits, -1)
- self.assertRaises(TypeError, self.gen.getrandbits, 10.1)
-
- def test_pickling(self):
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- state = pickle.dumps(self.gen, proto)
- origseq = [self.gen.random() for i in range(10)]
- newgen = pickle.loads(state)
- restoredseq = [newgen.random() for i in range(10)]
- self.assertEqual(origseq, restoredseq)
-
- def test_bug_1727780(self):
- # verify that version-2-pickles can be loaded
- # fine, whether they are created on 32-bit or 64-bit
- # platforms, and that version-3-pickles load fine.
- files = [("randv2_32.pck", 780),
- ("randv2_64.pck", 866),
- ("randv3.pck", 343)]
- for file, value in files:
- with open(support.findfile(file),"rb") as f:
- r = pickle.load(f)
- self.assertEqual(int(r.random()*1000), value)
-
- def test_bug_9025(self):
- # Had problem with an uneven distribution in int(n*random())
- # Verify the fix by checking that distributions fall within expectations.
- n = 100000
- randrange = self.gen.randrange
- k = sum(randrange(6755399441055744) % 3 == 2 for i in range(n))
- self.assertTrue(0.30 < k/n < .37, (k/n))
-
- def test_randbytes(self):
- # Verify ranges
- for n in range(1, 10):
- data = self.gen.randbytes(n)
- self.assertEqual(type(data), bytes)
- self.assertEqual(len(data), n)
-
- self.assertEqual(self.gen.randbytes(0), b'')
-
- # Verify argument checking
- self.assertRaises(TypeError, self.gen.randbytes)
- self.assertRaises(TypeError, self.gen.randbytes, 1, 2)
- self.assertRaises(ValueError, self.gen.randbytes, -1)
- self.assertRaises(TypeError, self.gen.randbytes, 1.0)
-
- def test_mu_sigma_default_args(self):
- self.assertIsInstance(self.gen.normalvariate(), float)
- self.assertIsInstance(self.gen.gauss(), float)
-
-
-try:
- random.SystemRandom().random()
-except NotImplementedError:
- SystemRandom_available = False
-else:
- SystemRandom_available = True
-
-@unittest.skipUnless(SystemRandom_available, "random.SystemRandom not available")
-class SystemRandom_TestBasicOps(TestBasicOps, unittest.TestCase):
- gen = random.SystemRandom()
-
- def test_autoseed(self):
- # Doesn't need to do anything except not fail
- self.gen.seed()
-
- def test_saverestore(self):
- self.assertRaises(NotImplementedError, self.gen.getstate)
- self.assertRaises(NotImplementedError, self.gen.setstate, None)
-
- def test_seedargs(self):
- # Doesn't need to do anything except not fail
- self.gen.seed(100)
-
- def test_gauss(self):
- self.gen.gauss_next = None
- self.gen.seed(100)
- self.assertEqual(self.gen.gauss_next, None)
-
- def test_pickling(self):
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- self.assertRaises(NotImplementedError, pickle.dumps, self.gen, proto)
-
- def test_53_bits_per_float(self):
- # This should pass whenever a C double has 53 bit precision.
- span = 2 ** 53
- cum = 0
- for i in range(100):
- cum |= int(self.gen.random() * span)
- self.assertEqual(cum, span-1)
+ self.assertRaises(TypeError, getrandbits)
+ self.assertRaises(TypeError, getrandbits, 1, 2)
+ self.assertRaises(ValueError, getrandbits, -1)
+ self.assertRaises(OverflowError, getrandbits, 1<<1000)
+ self.assertRaises(ValueError, getrandbits, -1<<1000)
+ self.assertRaises(TypeError, getrandbits, 10.1)
def test_bigrand(self):
# The randrange routine should build-up the required number of bits
@@ -559,6 +492,10 @@ class SystemRandom_TestBasicOps(TestBasicOps, unittest.TestCase):
randrange(1000, step=100)
with self.assertRaises(TypeError):
randrange(1000, None, step=100)
+ with self.assertRaises(TypeError):
+ randrange(1000, step=MyIndex(1))
+ with self.assertRaises(TypeError):
+ randrange(1000, None, step=MyIndex(1))
def test_randbelow_logic(self, _log=log, int=int):
# check bitcount transition points: 2**i and 2**(i+1)-1
@@ -581,6 +518,116 @@ class SystemRandom_TestBasicOps(TestBasicOps, unittest.TestCase):
self.assertEqual(k, numbits) # note the stronger assertion
self.assertTrue(2**k > n > 2**(k-1)) # note the stronger assertion
+ def test_randrange_index(self):
+ randrange = self.gen.randrange
+ self.assertIn(randrange(MyIndex(5)), range(5))
+ self.assertIn(randrange(MyIndex(2), MyIndex(7)), range(2, 7))
+ self.assertIn(randrange(MyIndex(5), MyIndex(15), MyIndex(2)), range(5, 15, 2))
+
+ def test_randint(self):
+ randint = self.gen.randint
+ self.assertIn(randint(2, 5), (2, 3, 4, 5))
+ self.assertEqual(randint(2, 2), 2)
+ self.assertIn(randint(MyIndex(2), MyIndex(5)), (2, 3, 4, 5))
+ self.assertEqual(randint(MyIndex(2), MyIndex(2)), 2)
+
+ self.assertRaises(ValueError, randint, 5, 2)
+ self.assertRaises(TypeError, randint)
+ self.assertRaises(TypeError, randint, 2)
+ self.assertRaises(TypeError, randint, 2, 5, 1)
+ self.assertRaises(TypeError, randint, 2.0, 5)
+ self.assertRaises(TypeError, randint, 2, 5.0)
+
+ def test_pickling(self):
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ state = pickle.dumps(self.gen, proto)
+ origseq = [self.gen.random() for i in range(10)]
+ newgen = pickle.loads(state)
+ restoredseq = [newgen.random() for i in range(10)]
+ self.assertEqual(origseq, restoredseq)
+
+ def test_bug_1727780(self):
+ # verify that version-2-pickles can be loaded
+ # fine, whether they are created on 32-bit or 64-bit
+ # platforms, and that version-3-pickles load fine.
+ files = [("randv2_32.pck", 780),
+ ("randv2_64.pck", 866),
+ ("randv3.pck", 343)]
+ for file, value in files:
+ with open(support.findfile(file),"rb") as f:
+ r = pickle.load(f)
+ self.assertEqual(int(r.random()*1000), value)
+
+ def test_bug_9025(self):
+ # Had problem with an uneven distribution in int(n*random())
+ # Verify the fix by checking that distributions fall within expectations.
+ n = 100000
+ randrange = self.gen.randrange
+ k = sum(randrange(6755399441055744) % 3 == 2 for i in range(n))
+ self.assertTrue(0.30 < k/n < .37, (k/n))
+
+ def test_randrange_bug_1590891(self):
+ start = 1000000000000
+ stop = -100000000000000000000
+ step = -200
+ x = self.gen.randrange(start, stop, step)
+ self.assertTrue(stop < x <= start)
+ self.assertEqual((x+stop)%step, 0)
+
+ def test_randbytes(self):
+ # Verify ranges
+ for n in range(1, 10):
+ data = self.gen.randbytes(n)
+ self.assertEqual(type(data), bytes)
+ self.assertEqual(len(data), n)
+
+ self.assertEqual(self.gen.randbytes(0), b'')
+
+ # Verify argument checking
+ self.assertRaises(TypeError, self.gen.randbytes)
+ self.assertRaises(TypeError, self.gen.randbytes, 1, 2)
+ self.assertRaises(ValueError, self.gen.randbytes, -1)
+ self.assertRaises(OverflowError, self.gen.randbytes, 1<<1000)
+ self.assertRaises((ValueError, OverflowError), self.gen.randbytes, -1<<1000)
+ self.assertRaises(TypeError, self.gen.randbytes, 1.0)
+
+ def test_mu_sigma_default_args(self):
+ self.assertIsInstance(self.gen.normalvariate(), float)
+ self.assertIsInstance(self.gen.gauss(), float)
+
+
+try:
+ random.SystemRandom().random()
+except NotImplementedError:
+ SystemRandom_available = False
+else:
+ SystemRandom_available = True
+
+@unittest.skipUnless(SystemRandom_available, "random.SystemRandom not available")
+class SystemRandom_TestBasicOps(TestBasicOps, unittest.TestCase):
+ gen = random.SystemRandom()
+
+ def test_autoseed(self):
+ # Doesn't need to do anything except not fail
+ self.gen.seed()
+
+ def test_saverestore(self):
+ self.assertRaises(NotImplementedError, self.gen.getstate)
+ self.assertRaises(NotImplementedError, self.gen.setstate, None)
+
+ def test_seedargs(self):
+ # Doesn't need to do anything except not fail
+ self.gen.seed(100)
+
+ def test_gauss(self):
+ self.gen.gauss_next = None
+ self.gen.seed(100)
+ self.assertEqual(self.gen.gauss_next, None)
+
+ def test_pickling(self):
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ self.assertRaises(NotImplementedError, pickle.dumps, self.gen, proto)
+
class TestRawMersenneTwister(unittest.TestCase):
@test.support.cpython_only
@@ -766,38 +813,6 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
seed = (1 << (10000 * 8)) - 1 # about 10K bytes
self.gen.seed(seed)
- def test_53_bits_per_float(self):
- # This should pass whenever a C double has 53 bit precision.
- span = 2 ** 53
- cum = 0
- for i in range(100):
- cum |= int(self.gen.random() * span)
- self.assertEqual(cum, span-1)
-
- def test_bigrand(self):
- # The randrange routine should build-up the required number of bits
- # in stages so that all bit positions are active.
- span = 2 ** 500
- cum = 0
- for i in range(100):
- r = self.gen.randrange(span)
- self.assertTrue(0 <= r < span)
- cum |= r
- self.assertEqual(cum, span-1)
-
- def test_bigrand_ranges(self):
- for i in [40,80, 160, 200, 211, 250, 375, 512, 550]:
- start = self.gen.randrange(2 ** (i-2))
- stop = self.gen.randrange(2 ** i)
- if stop <= start:
- continue
- self.assertTrue(start <= self.gen.randrange(start, stop) < stop)
-
- def test_rangelimits(self):
- for start, stop in [(-2,0), (-(2**60)-2,-(2**60)), (2**60,2**60+2)]:
- self.assertEqual(set(range(start,stop)),
- set([self.gen.randrange(start,stop) for i in range(100)]))
-
def test_getrandbits(self):
super().test_getrandbits()
@@ -805,6 +820,25 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
self.gen.seed(1234567)
self.assertEqual(self.gen.getrandbits(100),
97904845777343510404718956115)
+ self.gen.seed(1234567)
+ self.assertEqual(self.gen.getrandbits(MyIndex(100)),
+ 97904845777343510404718956115)
+
+ def test_getrandbits_2G_bits(self):
+ size = 2**31
+ self.gen.seed(1234567)
+ x = self.gen.getrandbits(size)
+ self.assertEqual(x.bit_length(), size)
+ self.assertEqual(x & (2**100-1), 890186470919986886340158459475)
+ self.assertEqual(x >> (size-100), 1226514312032729439655761284440)
+
+ @support.bigmemtest(size=2**32, memuse=1/8+2/15, dry_run=False)
+ def test_getrandbits_4G_bits(self, size):
+ self.gen.seed(1234568)
+ x = self.gen.getrandbits(size)
+ self.assertEqual(x.bit_length(), size)
+ self.assertEqual(x & (2**100-1), 287241425661104632871036099814)
+ self.assertEqual(x >> (size-100), 739728759900339699429794460738)
def test_randrange_uses_getrandbits(self):
# Verify use of getrandbits by randrange
@@ -816,27 +850,6 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
self.assertEqual(self.gen.randrange(2**99),
97904845777343510404718956115)
- def test_randbelow_logic(self, _log=log, int=int):
- # check bitcount transition points: 2**i and 2**(i+1)-1
- # show that: k = int(1.001 + _log(n, 2))
- # is equal to or one greater than the number of bits in n
- for i in range(1, 1000):
- n = 1 << i # check an exact power of two
- numbits = i+1
- k = int(1.00001 + _log(n, 2))
- self.assertEqual(k, numbits)
- self.assertEqual(n, 2**(k-1))
-
- n += n - 1 # check 1 below the next power of two
- k = int(1.00001 + _log(n, 2))
- self.assertIn(k, [numbits, numbits+1])
- self.assertTrue(2**k > n > 2**(k-2))
-
- n -= n >> 15 # check a little farther below the next power of two
- k = int(1.00001 + _log(n, 2))
- self.assertEqual(k, numbits) # note the stronger assertion
- self.assertTrue(2**k > n > 2**(k-1)) # note the stronger assertion
-
def test_randbelow_without_getrandbits(self):
# Random._randbelow() can only use random() when the built-in one
# has been overridden but no new getrandbits() method was supplied.
@@ -871,14 +884,6 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
self.gen._randbelow_without_getrandbits(n, maxsize=maxsize)
self.assertEqual(random_mock.call_count, 2)
- def test_randrange_bug_1590891(self):
- start = 1000000000000
- stop = -100000000000000000000
- step = -200
- x = self.gen.randrange(start, stop, step)
- self.assertTrue(stop < x <= start)
- self.assertEqual((x+stop)%step, 0)
-
def test_choices_algorithms(self):
# The various ways of specifying weights should produce the same results
choices = self.gen.choices
@@ -962,6 +967,14 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
self.assertEqual(self.gen.randbytes(n),
gen2.getrandbits(n * 8).to_bytes(n, 'little'))
+ @support.bigmemtest(size=2**29, memuse=1+16/15, dry_run=False)
+ def test_randbytes_256M(self, size):
+ self.gen.seed(2849427419)
+ x = self.gen.randbytes(size)
+ self.assertEqual(len(x), size)
+ self.assertEqual(x[:12].hex(), 'f6fd9ae63855ab91ea238b4f')
+ self.assertEqual(x[-12:].hex(), '0e7af69a84ee99bf4a11becc')
+
def test_sample_counts_equivalence(self):
# Test the documented strong equivalence to a sample with repeated elements.
# We run this test on random.Random() which makes deterministic selections
@@ -1411,30 +1424,31 @@ class TestModule(unittest.TestCase):
class CommandLineTest(unittest.TestCase):
+ @support.force_not_colorized
def test_parse_args(self):
args, help_text = random._parse_args(shlex.split("--choice a b c"))
self.assertEqual(args.choice, ["a", "b", "c"])
- self.assertTrue(help_text.startswith("usage: "))
+ self.assertStartsWith(help_text, "usage: ")
args, help_text = random._parse_args(shlex.split("--integer 5"))
self.assertEqual(args.integer, 5)
- self.assertTrue(help_text.startswith("usage: "))
+ self.assertStartsWith(help_text, "usage: ")
args, help_text = random._parse_args(shlex.split("--float 2.5"))
self.assertEqual(args.float, 2.5)
- self.assertTrue(help_text.startswith("usage: "))
+ self.assertStartsWith(help_text, "usage: ")
args, help_text = random._parse_args(shlex.split("a b c"))
self.assertEqual(args.input, ["a", "b", "c"])
- self.assertTrue(help_text.startswith("usage: "))
+ self.assertStartsWith(help_text, "usage: ")
args, help_text = random._parse_args(shlex.split("5"))
self.assertEqual(args.input, ["5"])
- self.assertTrue(help_text.startswith("usage: "))
+ self.assertStartsWith(help_text, "usage: ")
args, help_text = random._parse_args(shlex.split("2.5"))
self.assertEqual(args.input, ["2.5"])
- self.assertTrue(help_text.startswith("usage: "))
+ self.assertStartsWith(help_text, "usage: ")
def test_main(self):
for command, expected in [
diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py
index cf8525ed901..e9128ac1d97 100644
--- a/Lib/test/test_re.py
+++ b/Lib/test/test_re.py
@@ -619,6 +619,7 @@ class ReTests(unittest.TestCase):
self.assertEqual(re.fullmatch(r"a.*?b", "axxb").span(), (0, 4))
self.assertIsNone(re.fullmatch(r"a+", "ab"))
self.assertIsNone(re.fullmatch(r"abc$", "abc\n"))
+ self.assertIsNone(re.fullmatch(r"abc\z", "abc\n"))
self.assertIsNone(re.fullmatch(r"abc\Z", "abc\n"))
self.assertIsNone(re.fullmatch(r"(?m)abc$", "abc\n"))
self.assertEqual(re.fullmatch(r"ab(?=c)cd", "abcd").span(), (0, 4))
@@ -802,6 +803,8 @@ class ReTests(unittest.TestCase):
self.assertEqual(re.search(r"\B(b.)\B",
"abc bcd bc abxd", re.ASCII).group(1), "bx")
self.assertEqual(re.search(r"^abc$", "\nabc\n", re.M).group(0), "abc")
+ self.assertEqual(re.search(r"^\Aabc\z$", "abc", re.M).group(0), "abc")
+ self.assertIsNone(re.search(r"^\Aabc\z$", "\nabc\n", re.M))
self.assertEqual(re.search(r"^\Aabc\Z$", "abc", re.M).group(0), "abc")
self.assertIsNone(re.search(r"^\Aabc\Z$", "\nabc\n", re.M))
self.assertEqual(re.search(br"\b(b.)\b",
@@ -813,6 +816,8 @@ class ReTests(unittest.TestCase):
self.assertEqual(re.search(br"\B(b.)\B",
b"abc bcd bc abxd", re.LOCALE).group(1), b"bx")
self.assertEqual(re.search(br"^abc$", b"\nabc\n", re.M).group(0), b"abc")
+ self.assertEqual(re.search(br"^\Aabc\z$", b"abc", re.M).group(0), b"abc")
+ self.assertIsNone(re.search(br"^\Aabc\z$", b"\nabc\n", re.M))
self.assertEqual(re.search(br"^\Aabc\Z$", b"abc", re.M).group(0), b"abc")
self.assertIsNone(re.search(br"^\Aabc\Z$", b"\nabc\n", re.M))
self.assertEqual(re.search(r"\d\D\w\W\s\S",
@@ -836,7 +841,7 @@ class ReTests(unittest.TestCase):
self.assertEqual(re.match(r"[\^a]+", 'a^').group(), 'a^')
self.assertIsNone(re.match(r"[\^a]+", 'b'))
re.purge() # for warnings
- for c in 'ceghijklmopqyzCEFGHIJKLMNOPQRTVXY':
+ for c in 'ceghijklmopqyCEFGHIJKLMNOPQRTVXY':
with self.subTest(c):
self.assertRaises(re.PatternError, re.compile, '\\%c' % c)
for c in 'ceghijklmopqyzABCEFGHIJKLMNOPQRTVXYZ':
@@ -2608,8 +2613,8 @@ class ReTests(unittest.TestCase):
self.assertEqual(re.findall(r'(?>(?:ab){1,3})', 'ababc'), ['abab'])
def test_bug_gh91616(self):
- self.assertTrue(re.fullmatch(r'(?s:(?>.*?\.).*)\Z', "a.txt")) # reproducer
- self.assertTrue(re.fullmatch(r'(?s:(?=(?P<g0>.*?\.))(?P=g0).*)\Z', "a.txt"))
+ self.assertTrue(re.fullmatch(r'(?s:(?>.*?\.).*)\z', "a.txt")) # reproducer
+ self.assertTrue(re.fullmatch(r'(?s:(?=(?P<g0>.*?\.))(?P=g0).*)\z', "a.txt"))
def test_bug_gh100061(self):
# gh-100061
@@ -2863,11 +2868,11 @@ class PatternReprTests(unittest.TestCase):
pattern = 'Very %spattern' % ('long ' * 1000)
r = repr(re.compile(pattern))
self.assertLess(len(r), 300)
- self.assertEqual(r[:30], "re.compile('Very long long lon")
+ self.assertStartsWith(r, "re.compile('Very long long lon")
r = repr(re.compile(pattern, re.I))
self.assertLess(len(r), 300)
- self.assertEqual(r[:30], "re.compile('Very long long lon")
- self.assertEqual(r[-16:], ", re.IGNORECASE)")
+ self.assertStartsWith(r, "re.compile('Very long long lon")
+ self.assertEndsWith(r, ", re.IGNORECASE)")
def test_flags_repr(self):
self.assertEqual(repr(re.I), "re.IGNORECASE")
@@ -2946,7 +2951,7 @@ class ImplementationTest(unittest.TestCase):
self.assertEqual(mod.__name__, name)
self.assertEqual(mod.__package__, '')
for attr in deprecated[name]:
- self.assertTrue(hasattr(mod, attr))
+ self.assertHasAttr(mod, attr)
del sys.modules[name]
@cpython_only
diff --git a/Lib/test/test_readline.py b/Lib/test/test_readline.py
index b9d082b3597..45192fe5082 100644
--- a/Lib/test/test_readline.py
+++ b/Lib/test/test_readline.py
@@ -1,6 +1,7 @@
"""
Very minimal unittests for parts of the readline module.
"""
+import codecs
import locale
import os
import sys
@@ -231,6 +232,13 @@ print("History length:", readline.get_current_history_length())
# writing and reading non-ASCII bytes into/from a TTY works, but
# readline or ncurses ignores non-ASCII bytes on read.
self.skipTest(f"the LC_CTYPE locale is {loc!r}")
+ if sys.flags.utf8_mode:
+ encoding = locale.getencoding()
+ encoding = codecs.lookup(encoding).name # normalize the name
+ if encoding != "utf-8":
+ # gh-133711: The Python UTF-8 Mode ignores the LC_CTYPE locale
+ # and always use the UTF-8 encoding.
+ self.skipTest(f"the LC_CTYPE encoding is {encoding!r}")
try:
readline.add_history("\xEB\xEF")
diff --git a/Lib/test/test_regrtest.py b/Lib/test/test_regrtest.py
index 7e317d5ab94..5bc3c5924b0 100644
--- a/Lib/test/test_regrtest.py
+++ b/Lib/test/test_regrtest.py
@@ -768,13 +768,16 @@ class BaseTestCase(unittest.TestCase):
self.fail(msg)
return proc
- def run_python(self, args, **kw):
+ def run_python(self, args, isolated=True, **kw):
extraargs = []
if 'uops' in sys._xoptions:
# Pass -X uops along
extraargs.extend(['-X', 'uops'])
- args = [sys.executable, *extraargs, '-X', 'faulthandler', '-I', *args]
- proc = self.run_command(args, **kw)
+ cmd = [sys.executable, *extraargs, '-X', 'faulthandler']
+ if isolated:
+ cmd.append('-I')
+ cmd.extend(args)
+ proc = self.run_command(cmd, **kw)
return proc.stdout
@@ -831,8 +834,8 @@ class ProgramsTestCase(BaseTestCase):
self.check_executed_tests(output, self.tests,
randomize=True, stats=len(self.tests))
- def run_tests(self, args, env=None):
- output = self.run_python(args, env=env)
+ def run_tests(self, args, env=None, isolated=True):
+ output = self.run_python(args, env=env, isolated=isolated)
self.check_output(output)
def test_script_regrtest(self):
@@ -874,7 +877,10 @@ class ProgramsTestCase(BaseTestCase):
self.run_tests(args)
def run_batch(self, *args):
- proc = self.run_command(args)
+ proc = self.run_command(args,
+ # gh-133711: cmd.exe uses the OEM code page
+ # to display the non-ASCII current directory
+ errors="backslashreplace")
self.check_output(proc.stdout)
@unittest.skipUnless(sysconfig.is_python_build(),
@@ -2064,7 +2070,7 @@ class ArgsTestCase(BaseTestCase):
self.check_executed_tests(output, [testname],
failed=[testname],
parallel=True,
- stats=TestStats(1, 1, 0))
+ stats=TestStats(1, 2, 1))
def _check_random_seed(self, run_workers: bool):
# gh-109276: When -r/--randomize is used, random.seed() is called
@@ -2273,7 +2279,6 @@ class ArgsTestCase(BaseTestCase):
def test_xml(self):
code = textwrap.dedent(r"""
import unittest
- from test import support
class VerboseTests(unittest.TestCase):
def test_failed(self):
@@ -2308,6 +2313,50 @@ class ArgsTestCase(BaseTestCase):
for out in testcase.iter('system-out'):
self.assertEqual(out.text, r"abc \x1b def")
+ def test_nonascii(self):
+ code = textwrap.dedent(r"""
+ import unittest
+
+ class NonASCIITests(unittest.TestCase):
+ def test_docstring(self):
+ '''docstring:\u20ac'''
+
+ def test_subtest(self):
+ with self.subTest(param='subtest:\u20ac'):
+ pass
+
+ def test_skip(self):
+ self.skipTest('skipped:\u20ac')
+ """)
+ testname = self.create_test(code=code)
+
+ env = dict(os.environ)
+ env['PYTHONIOENCODING'] = 'ascii'
+
+ def check(output):
+ self.check_executed_tests(output, testname, stats=TestStats(3, 0, 1))
+ self.assertIn(r'docstring:\u20ac', output)
+ self.assertIn(r'skipped:\u20ac', output)
+
+ # Run sequentially
+ output = self.run_tests('-v', testname, env=env, isolated=False)
+ check(output)
+
+ # Run in parallel
+ output = self.run_tests('-j1', '-v', testname, env=env, isolated=False)
+ check(output)
+
+ def test_pgo_exclude(self):
+ # Get PGO tests
+ output = self.run_tests('--pgo', '--list-tests')
+ pgo_tests = output.strip().split()
+
+ # Exclude test_re
+ output = self.run_tests('--pgo', '--list-tests', '-x', 'test_re')
+ tests = output.strip().split()
+ self.assertNotIn('test_re', tests)
+ self.assertEqual(len(tests), len(pgo_tests) - 1)
+
class TestUtils(unittest.TestCase):
def test_format_duration(self):
diff --git a/Lib/test/test_remote_pdb.py b/Lib/test/test_remote_pdb.py
index e4c44c78d4a..a1c50af15f3 100644
--- a/Lib/test/test_remote_pdb.py
+++ b/Lib/test/test_remote_pdb.py
@@ -1,22 +1,19 @@
import io
-import time
import itertools
import json
import os
+import re
import signal
import socket
import subprocess
import sys
-import tempfile
import textwrap
-import threading
import unittest
import unittest.mock
-from contextlib import contextmanager, redirect_stdout, ExitStack
-from pathlib import Path
-from test.support import is_wasi, os_helper, requires_subprocess, SHORT_TIMEOUT
-from test.support.os_helper import temp_dir, TESTFN, unlink
-from typing import Dict, List, Optional, Tuple, Union, Any
+from contextlib import closing, contextmanager, redirect_stdout, redirect_stderr, ExitStack
+from test.support import is_wasi, cpython_only, force_color, requires_subprocess, SHORT_TIMEOUT
+from test.support.os_helper import TESTFN, unlink
+from typing import List
import pdb
from pdb import _PdbServer, _PdbClient
@@ -79,44 +76,6 @@ class MockSocketFile:
return results
-class MockDebuggerSocket:
- """Mock file-like simulating a connection to a _RemotePdb instance"""
-
- def __init__(self, incoming):
- self.incoming = iter(incoming)
- self.outgoing = []
- self.buffered = bytearray()
-
- def write(self, data: bytes) -> None:
- """Simulate write to socket."""
- self.buffered += data
-
- def flush(self) -> None:
- """Ensure each line is valid JSON."""
- lines = self.buffered.splitlines(keepends=True)
- self.buffered.clear()
- for line in lines:
- assert line.endswith(b"\n")
- self.outgoing.append(json.loads(line))
-
- def readline(self) -> bytes:
- """Read a line from the prepared input queue."""
- # Anything written must be flushed before trying to read,
- # since the read will be dependent upon the last write.
- assert not self.buffered
- try:
- item = next(self.incoming)
- if not isinstance(item, bytes):
- item = json.dumps(item).encode()
- return item + b"\n"
- except StopIteration:
- return b""
-
- def close(self) -> None:
- """No-op close implementation."""
- pass
-
-
class PdbClientTestCase(unittest.TestCase):
"""Tests for the _PdbClient class."""
@@ -124,8 +83,11 @@ class PdbClientTestCase(unittest.TestCase):
self,
*,
incoming,
- simulate_failure=None,
+ simulate_send_failure=False,
+ simulate_sigint_during_stdout_write=False,
+ use_interrupt_socket=False,
expected_outgoing=None,
+ expected_outgoing_signals=None,
expected_completions=None,
expected_exception=None,
expected_stdout="",
@@ -134,6 +96,8 @@ class PdbClientTestCase(unittest.TestCase):
):
if expected_outgoing is None:
expected_outgoing = []
+ if expected_outgoing_signals is None:
+ expected_outgoing_signals = []
if expected_completions is None:
expected_completions = []
if expected_state is None:
@@ -142,16 +106,6 @@ class PdbClientTestCase(unittest.TestCase):
expected_state.setdefault("write_failed", False)
messages = [m for source, m in incoming if source == "server"]
prompts = [m["prompt"] for source, m in incoming if source == "user"]
- sockfile = MockDebuggerSocket(messages)
- stdout = io.StringIO()
-
- if simulate_failure:
- sockfile.write = unittest.mock.Mock()
- sockfile.flush = unittest.mock.Mock()
- if simulate_failure == "write":
- sockfile.write.side_effect = OSError("write failed")
- elif simulate_failure == "flush":
- sockfile.flush.side_effect = OSError("flush failed")
input_iter = (m for source, m in incoming if source == "user")
completions = []
@@ -178,18 +132,60 @@ class PdbClientTestCase(unittest.TestCase):
reply = message["input"]
if isinstance(reply, BaseException):
raise reply
- return reply
+ if isinstance(reply, str):
+ return reply
+ return reply()
with ExitStack() as stack:
+ client_sock, server_sock = socket.socketpair()
+ stack.enter_context(closing(client_sock))
+ stack.enter_context(closing(server_sock))
+
+ server_sock = unittest.mock.Mock(wraps=server_sock)
+
+ client_sock.sendall(
+ b"".join(
+ (m if isinstance(m, bytes) else json.dumps(m).encode()) + b"\n"
+ for m in messages
+ )
+ )
+ client_sock.shutdown(socket.SHUT_WR)
+
+ if simulate_send_failure:
+ server_sock.sendall = unittest.mock.Mock(
+ side_effect=OSError("sendall failed")
+ )
+ client_sock.shutdown(socket.SHUT_RD)
+
+ stdout = io.StringIO()
+
+ if simulate_sigint_during_stdout_write:
+ orig_stdout_write = stdout.write
+
+ def sigint_stdout_write(s):
+ signal.raise_signal(signal.SIGINT)
+ return orig_stdout_write(s)
+
+ stdout.write = sigint_stdout_write
+
input_mock = stack.enter_context(
unittest.mock.patch("pdb.input", side_effect=mock_input)
)
stack.enter_context(redirect_stdout(stdout))
+ if use_interrupt_socket:
+ interrupt_sock = unittest.mock.Mock(spec=socket.socket)
+ mock_kill = None
+ else:
+ interrupt_sock = None
+ mock_kill = stack.enter_context(
+ unittest.mock.patch("os.kill", spec=os.kill)
+ )
+
client = _PdbClient(
- pid=0,
- sockfile=sockfile,
- interrupt_script="/a/b.py",
+ pid=12345,
+ server_socket=server_sock,
+ interrupt_sock=interrupt_sock,
)
if expected_exception is not None:
@@ -199,13 +195,12 @@ class PdbClientTestCase(unittest.TestCase):
client.cmdloop()
- actual_outgoing = sockfile.outgoing
- if simulate_failure:
- actual_outgoing += [
- json.loads(msg.args[0]) for msg in sockfile.write.mock_calls
- ]
+ sent_msgs = [msg.args[0] for msg in server_sock.sendall.mock_calls]
+ for msg in sent_msgs:
+ assert msg.endswith(b"\n")
+ actual_outgoing = [json.loads(msg) for msg in sent_msgs]
- self.assertEqual(sockfile.outgoing, expected_outgoing)
+ self.assertEqual(actual_outgoing, expected_outgoing)
self.assertEqual(completions, expected_completions)
if expected_stdout_substring and not expected_stdout:
self.assertIn(expected_stdout_substring, stdout.getvalue())
@@ -215,6 +210,20 @@ class PdbClientTestCase(unittest.TestCase):
actual_state = {k: getattr(client, k) for k in expected_state}
self.assertEqual(actual_state, expected_state)
+ if use_interrupt_socket:
+ outgoing_signals = [
+ signal.Signals(int.from_bytes(call.args[0]))
+ for call in interrupt_sock.sendall.call_args_list
+ ]
+ else:
+ assert mock_kill is not None
+ outgoing_signals = []
+ for call in mock_kill.call_args_list:
+ pid, signum = call.args
+ self.assertEqual(pid, 12345)
+ outgoing_signals.append(signal.Signals(signum))
+ self.assertEqual(outgoing_signals, expected_outgoing_signals)
+
def test_remote_immediately_closing_the_connection(self):
"""Test the behavior when the remote closes the connection immediately."""
incoming = []
@@ -409,11 +418,38 @@ class PdbClientTestCase(unittest.TestCase):
expected_state={"state": "dumb"},
)
- def test_keyboard_interrupt_at_prompt(self):
- """Test signaling when a prompt gets a KeyboardInterrupt."""
+ def test_sigint_at_prompt(self):
+ """Test signaling when a prompt gets interrupted."""
incoming = [
("server", {"prompt": "(Pdb) ", "state": "pdb"}),
- ("user", {"prompt": "(Pdb) ", "input": KeyboardInterrupt()}),
+ (
+ "user",
+ {
+ "prompt": "(Pdb) ",
+ "input": lambda: signal.raise_signal(signal.SIGINT),
+ },
+ ),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {"signal": "INT"},
+ ],
+ expected_state={"state": "pdb"},
+ )
+
+ def test_sigint_at_continuation_prompt(self):
+ """Test signaling when a continuation prompt gets interrupted."""
+ incoming = [
+ ("server", {"prompt": "(Pdb) ", "state": "pdb"}),
+ ("user", {"prompt": "(Pdb) ", "input": "if True:"}),
+ (
+ "user",
+ {
+ "prompt": "... ",
+ "input": lambda: signal.raise_signal(signal.SIGINT),
+ },
+ ),
]
self.do_test(
incoming=incoming,
@@ -423,6 +459,22 @@ class PdbClientTestCase(unittest.TestCase):
expected_state={"state": "pdb"},
)
+ def test_sigint_when_writing(self):
+ """Test siginaling when sys.stdout.write() gets interrupted."""
+ incoming = [
+ ("server", {"message": "Some message or other\n", "type": "info"}),
+ ]
+ for use_interrupt_socket in [False, True]:
+ with self.subTest(use_interrupt_socket=use_interrupt_socket):
+ self.do_test(
+ incoming=incoming,
+ simulate_sigint_during_stdout_write=True,
+ use_interrupt_socket=use_interrupt_socket,
+ expected_outgoing=[],
+ expected_outgoing_signals=[signal.SIGINT],
+ expected_stdout="Some message or other\n",
+ )
+
def test_eof_at_prompt(self):
"""Test signaling when a prompt gets an EOFError."""
incoming = [
@@ -478,20 +530,7 @@ class PdbClientTestCase(unittest.TestCase):
self.do_test(
incoming=incoming,
expected_outgoing=[{"signal": "INT"}],
- simulate_failure="write",
- expected_state={"write_failed": True},
- )
-
- def test_flush_failing(self):
- """Test terminating if flush fails due to a half closed socket."""
- incoming = [
- ("server", {"prompt": "(Pdb) ", "state": "pdb"}),
- ("user", {"prompt": "(Pdb) ", "input": KeyboardInterrupt()}),
- ]
- self.do_test(
- incoming=incoming,
- expected_outgoing=[{"signal": "INT"}],
- simulate_failure="flush",
+ simulate_send_failure=True,
expected_state={"write_failed": True},
)
@@ -531,6 +570,44 @@ class PdbClientTestCase(unittest.TestCase):
expected_state={"state": "pdb"},
)
+ def test_multiline_completion_in_pdb_state(self):
+ """Test requesting tab completions at a (Pdb) continuation prompt."""
+ # GIVEN
+ incoming = [
+ ("server", {"prompt": "(Pdb) ", "state": "pdb"}),
+ ("user", {"prompt": "(Pdb) ", "input": "if True:"}),
+ (
+ "user",
+ {
+ "prompt": "... ",
+ "completion_request": {
+ "line": " b",
+ "begidx": 4,
+ "endidx": 5,
+ },
+ "input": " bool()",
+ },
+ ),
+ ("server", {"completions": ["bin", "bool", "bytes"]}),
+ ("user", {"prompt": "... ", "input": ""}),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {
+ "complete": {
+ "text": "b",
+ "line": "! b",
+ "begidx": 2,
+ "endidx": 3,
+ }
+ },
+ {"reply": "if True:\n bool()\n"},
+ ],
+ expected_completions=["bin", "bool", "bytes"],
+ expected_state={"state": "pdb"},
+ )
+
def test_completion_in_interact_state(self):
"""Test requesting tab completions at a >>> prompt."""
incoming = [
@@ -622,42 +699,7 @@ class PdbClientTestCase(unittest.TestCase):
},
{"reply": "xyz"},
],
- simulate_failure="write",
- expected_completions=[],
- expected_state={"state": "interact", "write_failed": True},
- )
-
- def test_flush_failure_during_completion(self):
- """Test failing to flush to the socket to request tab completions."""
- incoming = [
- ("server", {"prompt": ">>> ", "state": "interact"}),
- (
- "user",
- {
- "prompt": ">>> ",
- "completion_request": {
- "line": "xy",
- "begidx": 0,
- "endidx": 2,
- },
- "input": "xyz",
- },
- ),
- ]
- self.do_test(
- incoming=incoming,
- expected_outgoing=[
- {
- "complete": {
- "text": "xy",
- "line": "xy",
- "begidx": 0,
- "endidx": 2,
- }
- },
- {"reply": "xyz"},
- ],
- simulate_failure="flush",
+ simulate_send_failure=True,
expected_completions=[],
expected_state={"state": "interact", "write_failed": True},
)
@@ -994,6 +1036,8 @@ class PdbConnectTestCase(unittest.TestCase):
frame=frame,
commands="",
version=pdb._PdbServer.protocol_version(),
+ signal_raising_thread=False,
+ colorize=False,
)
return x # This line won't be reached in debugging
@@ -1051,23 +1095,6 @@ class PdbConnectTestCase(unittest.TestCase):
client_file.write(json.dumps({"reply": command}).encode() + b"\n")
client_file.flush()
- def _send_interrupt(self, pid):
- """Helper to send an interrupt signal to the debugger."""
- # with tempfile.NamedTemporaryFile("w", delete_on_close=False) as interrupt_script:
- interrupt_script = TESTFN + "_interrupt_script.py"
- with open(interrupt_script, 'w') as f:
- f.write(
- 'import pdb, sys\n'
- 'print("Hello, world!")\n'
- 'if inst := pdb.Pdb._last_pdb_instance:\n'
- ' inst.set_trace(sys._getframe(1))\n'
- )
- self.addCleanup(unlink, interrupt_script)
- try:
- sys.remote_exec(pid, interrupt_script)
- except PermissionError:
- self.skipTest("Insufficient permissions to execute code in remote process")
-
def test_connect_and_basic_commands(self):
"""Test connecting to a remote debugger and sending basic commands."""
self._create_script()
@@ -1180,6 +1207,8 @@ class PdbConnectTestCase(unittest.TestCase):
frame=frame,
commands="",
version=pdb._PdbServer.protocol_version(),
+ signal_raising_thread=True,
+ colorize=False,
)
print("Connected to debugger")
iterations = 50
@@ -1195,6 +1224,10 @@ class PdbConnectTestCase(unittest.TestCase):
self._create_script(script=script)
process, client_file = self._connect_and_get_client_file()
+ # Accept a 2nd connection from the subprocess to tell it about signals
+ signal_sock, _ = self.server_sock.accept()
+ self.addCleanup(signal_sock.close)
+
with kill_on_error(process):
# Skip initial messages until we get to the prompt
self._read_until_prompt(client_file)
@@ -1210,7 +1243,7 @@ class PdbConnectTestCase(unittest.TestCase):
break
# Inject a script to interrupt the running process
- self._send_interrupt(process.pid)
+ signal_sock.sendall(signal.SIGINT.to_bytes())
messages = self._read_until_prompt(client_file)
# Verify we got the keyboard interrupt message.
@@ -1266,6 +1299,8 @@ class PdbConnectTestCase(unittest.TestCase):
frame=frame,
commands="",
version=fake_version,
+ signal_raising_thread=False,
+ colorize=False,
)
# This should print if the debugger detaches correctly
@@ -1393,5 +1428,151 @@ class PdbConnectTestCase(unittest.TestCase):
self.assertIn("Function returned: 42", stdout)
self.assertEqual(process.returncode, 0)
+
+def _supports_remote_attaching():
+ PROCESS_VM_READV_SUPPORTED = False
+
+ try:
+ from _remote_debugging import PROCESS_VM_READV_SUPPORTED
+ except ImportError:
+ pass
+
+ return PROCESS_VM_READV_SUPPORTED
+
+
+@unittest.skipIf(not sys.is_remote_debug_enabled(), "Remote debugging is not enabled")
+@unittest.skipIf(sys.platform != "darwin" and sys.platform != "linux" and sys.platform != "win32",
+ "Test only runs on Linux, Windows and MacOS")
+@unittest.skipIf(sys.platform == "linux" and not _supports_remote_attaching(),
+ "Testing on Linux requires process_vm_readv support")
+@cpython_only
+@requires_subprocess()
+class PdbAttachTestCase(unittest.TestCase):
+ def setUp(self):
+ # Create a server socket that will wait for the debugger to connect
+ self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.sock.bind(('127.0.0.1', 0)) # Let OS assign port
+ self.sock.listen(1)
+ self.port = self.sock.getsockname()[1]
+ self._create_script()
+
+ def _create_script(self, script=None):
+ # Create a file for subprocess script
+ script = textwrap.dedent(
+ f"""
+ import socket
+ import time
+
+ def foo():
+ return bar()
+
+ def bar():
+ return baz()
+
+ def baz():
+ x = 1
+ # Trigger attach
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.connect(('127.0.0.1', {self.port}))
+ sock.close()
+ count = 0
+ while x == 1 and count < 100:
+ count += 1
+ time.sleep(0.1)
+ return x
+
+ result = foo()
+ print(f"Function returned: {{result}}")
+ """
+ )
+
+ self.script_path = TESTFN + "_connect_test.py"
+ with open(self.script_path, 'w') as f:
+ f.write(script)
+
+ def tearDown(self):
+ self.sock.close()
+ try:
+ unlink(self.script_path)
+ except OSError:
+ pass
+
+ def do_integration_test(self, client_stdin):
+ process = subprocess.Popen(
+ [sys.executable, self.script_path],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True
+ )
+ self.addCleanup(process.stdout.close)
+ self.addCleanup(process.stderr.close)
+
+ # Wait for the process to reach our attachment point
+ self.sock.settimeout(10)
+ conn, _ = self.sock.accept()
+ conn.close()
+
+ client_stdin = io.StringIO(client_stdin)
+ client_stdout = io.StringIO()
+ client_stderr = io.StringIO()
+
+ self.addCleanup(client_stdin.close)
+ self.addCleanup(client_stdout.close)
+ self.addCleanup(client_stderr.close)
+ self.addCleanup(process.wait)
+
+ with (
+ unittest.mock.patch("sys.stdin", client_stdin),
+ redirect_stdout(client_stdout),
+ redirect_stderr(client_stderr),
+ unittest.mock.patch("sys.argv", ["pdb", "-p", str(process.pid)]),
+ ):
+ try:
+ pdb.main()
+ except PermissionError:
+ self.skipTest("Insufficient permissions for remote execution")
+
+ process.wait()
+ server_stdout = process.stdout.read()
+ server_stderr = process.stderr.read()
+
+ if process.returncode != 0:
+ print("server failed")
+ print(f"server stdout:\n{server_stdout}")
+ print(f"server stderr:\n{server_stderr}")
+
+ self.assertEqual(process.returncode, 0)
+ return {
+ "client": {
+ "stdout": client_stdout.getvalue(),
+ "stderr": client_stderr.getvalue(),
+ },
+ "server": {
+ "stdout": server_stdout,
+ "stderr": server_stderr,
+ },
+ }
+
+ def test_attach_to_process_without_colors(self):
+ with force_color(False):
+ output = self.do_integration_test("ll\nx=42\n")
+ self.assertEqual(output["client"]["stderr"], "")
+ self.assertEqual(output["server"]["stderr"], "")
+
+ self.assertEqual(output["server"]["stdout"], "Function returned: 42\n")
+ self.assertIn("while x == 1", output["client"]["stdout"])
+ self.assertNotIn("\x1b", output["client"]["stdout"])
+
+ def test_attach_to_process_with_colors(self):
+ with force_color(True):
+ output = self.do_integration_test("ll\nx=42\n")
+ self.assertEqual(output["client"]["stderr"], "")
+ self.assertEqual(output["server"]["stderr"], "")
+
+ self.assertEqual(output["server"]["stdout"], "Function returned: 42\n")
+ self.assertIn("\x1b", output["client"]["stdout"])
+ self.assertNotIn("while x == 1", output["client"]["stdout"])
+ self.assertIn("while x == 1", re.sub("\x1b[^m]*m", "", output["client"]["stdout"]))
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_repl.py b/Lib/test/test_repl.py
index 228b326699e..f4a4634fc62 100644
--- a/Lib/test/test_repl.py
+++ b/Lib/test/test_repl.py
@@ -38,8 +38,8 @@ def spawn_repl(*args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kw):
# line option '-i' and the process name set to '<stdin>'.
# The directory of argv[0] must match the directory of the Python
# executable for the Popen() call to python to succeed as the directory
- # path may be used by Py_GetPath() to build the default module search
- # path.
+ # path may be used by PyConfig_Get("module_search_paths") to build the
+ # default module search path.
stdin_fname = os.path.join(os.path.dirname(sys.executable), "<stdin>")
cmd_line = [stdin_fname, '-I', '-i']
cmd_line.extend(args)
@@ -197,7 +197,7 @@ class TestInteractiveInterpreter(unittest.TestCase):
expected_lines = [
' def f(x, x): ...',
' ^',
- "SyntaxError: duplicate argument 'x' in function definition"
+ "SyntaxError: duplicate parameter 'x' in function definition"
]
self.assertEqual(output.splitlines()[4:-1], expected_lines)
diff --git a/Lib/test/test_reprlib.py b/Lib/test/test_reprlib.py
index ffeb1fba7b8..22a55b57c07 100644
--- a/Lib/test/test_reprlib.py
+++ b/Lib/test/test_reprlib.py
@@ -3,6 +3,7 @@
Nick Mathewson
"""
+import annotationlib
import sys
import os
import shutil
@@ -11,7 +12,7 @@ import importlib.util
import unittest
import textwrap
-from test.support import verbose
+from test.support import verbose, EqualToForwardRef
from test.support.os_helper import create_empty_file
from reprlib import repr as r # Don't shadow builtin repr
from reprlib import Repr
@@ -150,14 +151,38 @@ class ReprTests(unittest.TestCase):
eq(r(frozenset({1, 2, 3, 4, 5, 6, 7})), "frozenset({1, 2, 3, 4, 5, 6, ...})")
def test_numbers(self):
- eq = self.assertEqual
- eq(r(123), repr(123))
- eq(r(123), repr(123))
- eq(r(1.0/3), repr(1.0/3))
-
- n = 10**100
- expected = repr(n)[:18] + "..." + repr(n)[-19:]
- eq(r(n), expected)
+ for x in [123, 1.0 / 3]:
+ self.assertEqual(r(x), repr(x))
+
+ max_digits = sys.get_int_max_str_digits()
+ for k in [100, max_digits - 1]:
+ with self.subTest(f'10 ** {k}', k=k):
+ n = 10 ** k
+ expected = repr(n)[:18] + "..." + repr(n)[-19:]
+ self.assertEqual(r(n), expected)
+
+ def re_msg(n, d):
+ return (rf'<{n.__class__.__name__} instance with roughly {d} '
+ rf'digits \(limit at {max_digits}\) at 0x[a-f0-9]+>')
+
+ k = max_digits
+ with self.subTest(f'10 ** {k}', k=k):
+ n = 10 ** k
+ self.assertRaises(ValueError, repr, n)
+ self.assertRegex(r(n), re_msg(n, k + 1))
+
+ for k in [max_digits + 1, 2 * max_digits]:
+ self.assertGreater(k, 100)
+ with self.subTest(f'10 ** {k}', k=k):
+ n = 10 ** k
+ self.assertRaises(ValueError, repr, n)
+ self.assertRegex(r(n), re_msg(n, k + 1))
+ with self.subTest(f'10 ** {k} - 1', k=k):
+ n = 10 ** k - 1
+ # Here, since math.log10(n) == math.log10(n-1),
+ # the number of digits of n - 1 is overestimated.
+ self.assertRaises(ValueError, repr, n)
+ self.assertRegex(r(n), re_msg(n, k + 1))
def test_instance(self):
eq = self.assertEqual
@@ -172,13 +197,13 @@ class ReprTests(unittest.TestCase):
eq(r(i3), ("<ClassWithFailingRepr instance at %#x>"%id(i3)))
s = r(ClassWithFailingRepr)
- self.assertTrue(s.startswith("<class "))
- self.assertTrue(s.endswith(">"))
+ self.assertStartsWith(s, "<class ")
+ self.assertEndsWith(s, ">")
self.assertIn(s.find("..."), [12, 13])
def test_lambda(self):
r = repr(lambda x: x)
- self.assertTrue(r.startswith("<function ReprTests.test_lambda.<locals>.<lambda"), r)
+ self.assertStartsWith(r, "<function ReprTests.test_lambda.<locals>.<lambda")
# XXX anonymous functions? see func_repr
def test_builtin_function(self):
@@ -186,8 +211,8 @@ class ReprTests(unittest.TestCase):
# Functions
eq(repr(hash), '<built-in function hash>')
# Methods
- self.assertTrue(repr(''.split).startswith(
- '<built-in method split of str object at 0x'))
+ self.assertStartsWith(repr(''.split),
+ '<built-in method split of str object at 0x')
def test_range(self):
eq = self.assertEqual
@@ -372,20 +397,20 @@ class ReprTests(unittest.TestCase):
'object': {
1: 'two',
b'three': [
- (4.5, 6.7),
+ (4.5, 6.25),
[set((8, 9)), frozenset((10, 11))],
],
},
'tests': (
(dict(indent=None), '''\
- {1: 'two', b'three': [(4.5, 6.7), [{8, 9}, frozenset({10, 11})]]}'''),
+ {1: 'two', b'three': [(4.5, 6.25), [{8, 9}, frozenset({10, 11})]]}'''),
(dict(indent=False), '''\
{
1: 'two',
b'three': [
(
4.5,
- 6.7,
+ 6.25,
),
[
{
@@ -405,7 +430,7 @@ class ReprTests(unittest.TestCase):
b'three': [
(
4.5,
- 6.7,
+ 6.25,
),
[
{
@@ -425,7 +450,7 @@ class ReprTests(unittest.TestCase):
b'three': [
(
4.5,
- 6.7,
+ 6.25,
),
[
{
@@ -445,7 +470,7 @@ class ReprTests(unittest.TestCase):
b'three': [
(
4.5,
- 6.7,
+ 6.25,
),
[
{
@@ -465,7 +490,7 @@ class ReprTests(unittest.TestCase):
b'three': [
(
4.5,
- 6.7,
+ 6.25,
),
[
{
@@ -493,7 +518,7 @@ class ReprTests(unittest.TestCase):
b'three': [
(
4.5,
- 6.7,
+ 6.25,
),
[
{
@@ -513,7 +538,7 @@ class ReprTests(unittest.TestCase):
-->b'three': [
-->-->(
-->-->-->4.5,
- -->-->-->6.7,
+ -->-->-->6.25,
-->-->),
-->-->[
-->-->-->{
@@ -533,7 +558,7 @@ class ReprTests(unittest.TestCase):
....b'three': [
........(
............4.5,
- ............6.7,
+ ............6.25,
........),
........[
............{
@@ -729,8 +754,8 @@ class baz:
importlib.invalidate_caches()
from areallylongpackageandmodulenametotestreprtruncation.areallylongpackageandmodulenametotestreprtruncation import baz
ibaz = baz.baz()
- self.assertTrue(repr(ibaz).startswith(
- "<%s.baz object at 0x" % baz.__name__))
+ self.assertStartsWith(repr(ibaz),
+ "<%s.baz object at 0x" % baz.__name__)
def test_method(self):
self._check_path_limitations('qux')
@@ -743,13 +768,13 @@ class aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
from areallylongpackageandmodulenametotestreprtruncation.areallylongpackageandmodulenametotestreprtruncation import qux
# Unbound methods first
r = repr(qux.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.amethod)
- self.assertTrue(r.startswith('<function aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.amethod'), r)
+ self.assertStartsWith(r, '<function aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.amethod')
# Bound method next
iqux = qux.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa()
r = repr(iqux.amethod)
- self.assertTrue(r.startswith(
+ self.assertStartsWith(r,
'<bound method aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.amethod of <%s.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa object at 0x' \
- % (qux.__name__,) ), r)
+ % (qux.__name__,) )
@unittest.skip('needs a built-in function with a really long name')
def test_builtin_function(self):
@@ -829,5 +854,19 @@ class TestRecursiveRepr(unittest.TestCase):
self.assertEqual(type_params[0].__name__, 'T')
self.assertEqual(type_params[0].__bound__, str)
+ def test_annotations(self):
+ class My:
+ @recursive_repr()
+ def __repr__(self, default: undefined = ...):
+ return default
+
+ annotations = annotationlib.get_annotations(
+ My.__repr__, format=annotationlib.Format.FORWARDREF
+ )
+ self.assertEqual(
+ annotations,
+ {'default': EqualToForwardRef("undefined", owner=My.__repr__)}
+ )
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_rlcompleter.py b/Lib/test/test_rlcompleter.py
index 1cff6a218f8..a8914953ce9 100644
--- a/Lib/test/test_rlcompleter.py
+++ b/Lib/test/test_rlcompleter.py
@@ -54,11 +54,26 @@ class TestRlcompleter(unittest.TestCase):
['str.{}('.format(x) for x in dir(str)
if x.startswith('s')])
self.assertEqual(self.stdcompleter.attr_matches('tuple.foospamegg'), [])
- expected = sorted({'None.%s%s' % (x,
- '()' if x in ('__init_subclass__', '__class__')
- else '' if x == '__doc__'
- else '(')
- for x in dir(None)})
+
+ def create_expected_for_none():
+ if not MISSING_C_DOCSTRINGS:
+ parentheses = ('__init_subclass__', '__class__')
+ else:
+ # When `--without-doc-strings` is used, `__class__`
+ # won't have a known signature.
+ parentheses = ('__init_subclass__',)
+
+ items = set()
+ for x in dir(None):
+ if x in parentheses:
+ items.add(f'None.{x}()')
+ elif x == '__doc__':
+ items.add(f'None.{x}')
+ else:
+ items.add(f'None.{x}(')
+ return sorted(items)
+
+ expected = create_expected_for_none()
self.assertEqual(self.stdcompleter.attr_matches('None.'), expected)
self.assertEqual(self.stdcompleter.attr_matches('None._'), expected)
self.assertEqual(self.stdcompleter.attr_matches('None.__'), expected)
@@ -73,7 +88,7 @@ class TestRlcompleter(unittest.TestCase):
['CompleteMe._ham'])
matches = self.completer.attr_matches('CompleteMe.__')
for x in matches:
- self.assertTrue(x.startswith('CompleteMe.__'), x)
+ self.assertStartsWith(x, 'CompleteMe.__')
self.assertIn('CompleteMe.__name__', matches)
self.assertIn('CompleteMe.__new__(', matches)
diff --git a/Lib/test/test_runpy.py b/Lib/test/test_runpy.py
index ada78ec8e6b..a2a07c04f58 100644
--- a/Lib/test/test_runpy.py
+++ b/Lib/test/test_runpy.py
@@ -796,7 +796,7 @@ class TestExit(unittest.TestCase):
# Use -E to ignore PYTHONSAFEPATH
cmd = [sys.executable, '-E', *cmd]
proc = subprocess.run(cmd, *args, **kwargs, text=True, stderr=subprocess.PIPE)
- self.assertTrue(proc.stderr.endswith("\nKeyboardInterrupt\n"), proc.stderr)
+ self.assertEndsWith(proc.stderr, "\nKeyboardInterrupt\n")
self.assertEqual(proc.returncode, self.EXPECTED_CODE)
def test_pymain_run_file(self):
diff --git a/Lib/test/test_scope.py b/Lib/test/test_scope.py
index 24a366efc6c..520fbc1b662 100644
--- a/Lib/test/test_scope.py
+++ b/Lib/test/test_scope.py
@@ -778,7 +778,7 @@ class ScopeTests(unittest.TestCase):
class X:
locals()["x"] = 43
del x
- self.assertFalse(hasattr(X, "x"))
+ self.assertNotHasAttr(X, "x")
self.assertEqual(x, 42)
@cpython_only
diff --git a/Lib/test/test_script_helper.py b/Lib/test/test_script_helper.py
index f7871fd3b77..eeea6c4842b 100644
--- a/Lib/test/test_script_helper.py
+++ b/Lib/test/test_script_helper.py
@@ -74,8 +74,7 @@ class TestScriptHelperEnvironment(unittest.TestCase):
"""Code coverage for interpreter_requires_environment()."""
def setUp(self):
- self.assertTrue(
- hasattr(script_helper, '__cached_interp_requires_environment'))
+ self.assertHasAttr(script_helper, '__cached_interp_requires_environment')
# Reset the private cached state.
script_helper.__dict__['__cached_interp_requires_environment'] = None
diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py
index c01e323553d..c0df9507bd7 100644
--- a/Lib/test/test_set.py
+++ b/Lib/test/test_set.py
@@ -237,7 +237,7 @@ class TestJointOps:
if type(self.s) not in (set, frozenset):
self.assertEqual(self.s.x, dup.x)
self.assertEqual(self.s.z, dup.z)
- self.assertFalse(hasattr(self.s, 'y'))
+ self.assertNotHasAttr(self.s, 'y')
del self.s.x, self.s.z
def test_iterator_pickling(self):
@@ -876,8 +876,8 @@ class TestBasicOps:
def check_repr_against_values(self):
text = repr(self.set)
- self.assertTrue(text.startswith('{'))
- self.assertTrue(text.endswith('}'))
+ self.assertStartsWith(text, '{')
+ self.assertEndsWith(text, '}')
result = text[1:-1].split(', ')
result.sort()
diff --git a/Lib/test/test_shlex.py b/Lib/test/test_shlex.py
index f35571ea886..a13ddcb76b7 100644
--- a/Lib/test/test_shlex.py
+++ b/Lib/test/test_shlex.py
@@ -3,6 +3,7 @@ import itertools
import shlex
import string
import unittest
+from test.support import cpython_only
from test.support import import_helper
@@ -364,6 +365,7 @@ class ShlexTest(unittest.TestCase):
with self.assertRaises(AttributeError):
shlex_instance.punctuation_chars = False
+ @cpython_only
def test_lazy_imports(self):
import_helper.ensure_lazy_imports('shlex', {'collections', 're', 'os'})
diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py
index ed01163074a..ebb6cf88336 100644
--- a/Lib/test/test_shutil.py
+++ b/Lib/test/test_shutil.py
@@ -427,12 +427,12 @@ class TestRmTree(BaseTest, unittest.TestCase):
else:
self.assertIs(func, os.listdir)
self.assertIn(arg, [TESTFN, self.child_dir_path])
- self.assertTrue(issubclass(exc[0], OSError))
+ self.assertIsSubclass(exc[0], OSError)
self.errorState += 1
else:
self.assertEqual(func, os.rmdir)
self.assertEqual(arg, TESTFN)
- self.assertTrue(issubclass(exc[0], OSError))
+ self.assertIsSubclass(exc[0], OSError)
self.errorState = 3
@unittest.skipIf(sys.platform[:6] == 'cygwin',
@@ -2153,6 +2153,10 @@ class TestArchives(BaseTest, unittest.TestCase):
def test_unpack_archive_bztar(self):
self.check_unpack_tarball('bztar')
+ @support.requires_zstd()
+ def test_unpack_archive_zstdtar(self):
+ self.check_unpack_tarball('zstdtar')
+
@support.requires_lzma()
@unittest.skipIf(AIX and not _maxdataOK(), "AIX MAXDATA must be 0x20000000 or larger")
def test_unpack_archive_xztar(self):
@@ -3475,7 +3479,7 @@ class PublicAPITests(unittest.TestCase):
"""Ensures that the correct values are exposed in the public API."""
def test_module_all_attribute(self):
- self.assertTrue(hasattr(shutil, '__all__'))
+ self.assertHasAttr(shutil, '__all__')
target_api = ['copyfileobj', 'copyfile', 'copymode', 'copystat',
'copy', 'copy2', 'copytree', 'move', 'rmtree', 'Error',
'SpecialFileError', 'make_archive',
@@ -3488,7 +3492,7 @@ class PublicAPITests(unittest.TestCase):
target_api.append('disk_usage')
self.assertEqual(set(shutil.__all__), set(target_api))
with self.assertWarns(DeprecationWarning):
- from shutil import ExecError
+ from shutil import ExecError # noqa: F401
if __name__ == '__main__':
diff --git a/Lib/test/test_site.py b/Lib/test/test_site.py
index a7e9241f44d..d0e32942635 100644
--- a/Lib/test/test_site.py
+++ b/Lib/test/test_site.py
@@ -307,8 +307,7 @@ class HelperFunctionsTests(unittest.TestCase):
with EnvironmentVarGuard() as environ:
environ['PYTHONUSERBASE'] = 'xoxo'
- self.assertTrue(site.getuserbase().startswith('xoxo'),
- site.getuserbase())
+ self.assertStartsWith(site.getuserbase(), 'xoxo')
@unittest.skipUnless(HAS_USER_SITE, 'need user site')
def test_getusersitepackages(self):
@@ -318,7 +317,7 @@ class HelperFunctionsTests(unittest.TestCase):
# the call sets USER_BASE *and* USER_SITE
self.assertEqual(site.USER_SITE, user_site)
- self.assertTrue(user_site.startswith(site.USER_BASE), user_site)
+ self.assertStartsWith(user_site, site.USER_BASE)
self.assertEqual(site.USER_BASE, site.getuserbase())
def test_getsitepackages(self):
@@ -359,11 +358,10 @@ class HelperFunctionsTests(unittest.TestCase):
environ.unset('PYTHONUSERBASE', 'APPDATA')
user_base = site.getuserbase()
- self.assertTrue(user_base.startswith('~' + os.sep),
- user_base)
+ self.assertStartsWith(user_base, '~' + os.sep)
user_site = site.getusersitepackages()
- self.assertTrue(user_site.startswith(user_base), user_site)
+ self.assertStartsWith(user_site, user_base)
with mock.patch('os.path.isdir', return_value=False) as mock_isdir, \
mock.patch.object(site, 'addsitedir') as mock_addsitedir, \
@@ -495,18 +493,18 @@ class ImportSideEffectTests(unittest.TestCase):
def test_setting_quit(self):
# 'quit' and 'exit' should be injected into builtins
- self.assertTrue(hasattr(builtins, "quit"))
- self.assertTrue(hasattr(builtins, "exit"))
+ self.assertHasAttr(builtins, "quit")
+ self.assertHasAttr(builtins, "exit")
def test_setting_copyright(self):
# 'copyright', 'credits', and 'license' should be in builtins
- self.assertTrue(hasattr(builtins, "copyright"))
- self.assertTrue(hasattr(builtins, "credits"))
- self.assertTrue(hasattr(builtins, "license"))
+ self.assertHasAttr(builtins, "copyright")
+ self.assertHasAttr(builtins, "credits")
+ self.assertHasAttr(builtins, "license")
def test_setting_help(self):
# 'help' should be set in builtins
- self.assertTrue(hasattr(builtins, "help"))
+ self.assertHasAttr(builtins, "help")
def test_sitecustomize_executed(self):
# If sitecustomize is available, it should have been imported.
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index ace97ce0cbe..3dd67b2a2ab 100644
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -2,8 +2,9 @@ import unittest
from unittest import mock
from test import support
from test.support import (
- is_apple, os_helper, refleak_helper, socket_helper, threading_helper
+ cpython_only, is_apple, os_helper, refleak_helper, socket_helper, threading_helper
)
+from test.support.import_helper import ensure_lazy_imports
import _thread as thread
import array
import contextlib
@@ -257,6 +258,12 @@ HAVE_SOCKET_HYPERV = _have_socket_hyperv()
# Size in bytes of the int type
SIZEOF_INT = array.array("i").itemsize
+class TestLazyImport(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("socket", {"array", "selectors"})
+
+
class SocketTCPTest(unittest.TestCase):
def setUp(self):
@@ -1078,9 +1085,7 @@ class GeneralModuleTests(unittest.TestCase):
'IPV6_USE_MIN_MTU',
}
for opt in opts:
- self.assertTrue(
- hasattr(socket, opt), f"Missing RFC3542 socket option '{opt}'"
- )
+ self.assertHasAttr(socket, opt)
def testHostnameRes(self):
# Testing hostname resolution mechanisms
@@ -1586,11 +1591,11 @@ class GeneralModuleTests(unittest.TestCase):
@unittest.skipUnless(os.name == "nt", "Windows specific")
def test_sock_ioctl(self):
- self.assertTrue(hasattr(socket.socket, 'ioctl'))
- self.assertTrue(hasattr(socket, 'SIO_RCVALL'))
- self.assertTrue(hasattr(socket, 'RCVALL_ON'))
- self.assertTrue(hasattr(socket, 'RCVALL_OFF'))
- self.assertTrue(hasattr(socket, 'SIO_KEEPALIVE_VALS'))
+ self.assertHasAttr(socket.socket, 'ioctl')
+ self.assertHasAttr(socket, 'SIO_RCVALL')
+ self.assertHasAttr(socket, 'RCVALL_ON')
+ self.assertHasAttr(socket, 'RCVALL_OFF')
+ self.assertHasAttr(socket, 'SIO_KEEPALIVE_VALS')
s = socket.socket()
self.addCleanup(s.close)
self.assertRaises(ValueError, s.ioctl, -1, None)
@@ -6075,10 +6080,10 @@ class UDPLITETimeoutTest(SocketUDPLITETest):
class TestExceptions(unittest.TestCase):
def testExceptionTree(self):
- self.assertTrue(issubclass(OSError, Exception))
- self.assertTrue(issubclass(socket.herror, OSError))
- self.assertTrue(issubclass(socket.gaierror, OSError))
- self.assertTrue(issubclass(socket.timeout, OSError))
+ self.assertIsSubclass(OSError, Exception)
+ self.assertIsSubclass(socket.herror, OSError)
+ self.assertIsSubclass(socket.gaierror, OSError)
+ self.assertIsSubclass(socket.timeout, OSError)
self.assertIs(socket.error, OSError)
self.assertIs(socket.timeout, TimeoutError)
diff --git a/Lib/test/test_source_encoding.py b/Lib/test/test_source_encoding.py
index 61b00778f83..1399f3fcd2d 100644
--- a/Lib/test/test_source_encoding.py
+++ b/Lib/test/test_source_encoding.py
@@ -145,8 +145,7 @@ class MiscSourceEncodingTest(unittest.TestCase):
compile(input, "<string>", "exec")
expected = "'ascii' codec can't decode byte 0xe2 in position 16: " \
"ordinal not in range(128)"
- self.assertTrue(c.exception.args[0].startswith(expected),
- msg=c.exception.args[0])
+ self.assertStartsWith(c.exception.args[0], expected)
def test_file_parse_error_multiline(self):
# gh96611:
diff --git a/Lib/test/test_sqlite3/test_cli.py b/Lib/test/test_sqlite3/test_cli.py
index dcd90d11d46..720fa3c4c1e 100644
--- a/Lib/test/test_sqlite3/test_cli.py
+++ b/Lib/test/test_sqlite3/test_cli.py
@@ -1,12 +1,26 @@
"""sqlite3 CLI tests."""
import sqlite3
+import sys
+import textwrap
import unittest
+import unittest.mock
+import os
from sqlite3.__main__ import main as cli
+from test.support.import_helper import import_module
from test.support.os_helper import TESTFN, unlink
-from test.support import captured_stdout, captured_stderr, captured_stdin
+from test.support.pty_helper import run_pty
+from test.support import (
+ captured_stdout,
+ captured_stderr,
+ captured_stdin,
+ force_not_colorized_test_class,
+ requires_subprocess,
+ verbose,
+)
+@force_not_colorized_test_class
class CommandLineInterface(unittest.TestCase):
def _do_test(self, *args, expect_success=True):
@@ -63,6 +77,7 @@ class CommandLineInterface(unittest.TestCase):
self.assertIn("(0,)", out)
+@force_not_colorized_test_class
class InteractiveSession(unittest.TestCase):
MEMORY_DB_MSG = "Connected to a transient in-memory database"
PS1 = "sqlite> "
@@ -110,6 +125,38 @@ class InteractiveSession(unittest.TestCase):
self.assertEqual(out.count(self.PS2), 0)
self.assertIn(sqlite3.sqlite_version, out)
+ def test_interact_empty_source(self):
+ out, err = self.run_cli(commands=("", " "))
+ self.assertIn(self.MEMORY_DB_MSG, err)
+ self.assertEndsWith(out, self.PS1)
+ self.assertEqual(out.count(self.PS1), 3)
+ self.assertEqual(out.count(self.PS2), 0)
+
+ def test_interact_dot_commands_unknown(self):
+ out, err = self.run_cli(commands=(".unknown_command", ))
+ self.assertIn(self.MEMORY_DB_MSG, err)
+ self.assertEndsWith(out, self.PS1)
+ self.assertEqual(out.count(self.PS1), 2)
+ self.assertEqual(out.count(self.PS2), 0)
+ self.assertIn('Error: unknown command: "', err)
+ # test "unknown_command" is pointed out in the error message
+ self.assertIn("unknown_command", err)
+
+ def test_interact_dot_commands_empty(self):
+ out, err = self.run_cli(commands=("."))
+ self.assertIn(self.MEMORY_DB_MSG, err)
+ self.assertEndsWith(out, self.PS1)
+ self.assertEqual(out.count(self.PS1), 2)
+ self.assertEqual(out.count(self.PS2), 0)
+
+ def test_interact_dot_commands_with_whitespaces(self):
+ out, err = self.run_cli(commands=(".version ", ". version"))
+ self.assertIn(self.MEMORY_DB_MSG, err)
+ self.assertEqual(out.count(sqlite3.sqlite_version + "\n"), 2)
+ self.assertEndsWith(out, self.PS1)
+ self.assertEqual(out.count(self.PS1), 3)
+ self.assertEqual(out.count(self.PS2), 0)
+
def test_interact_valid_sql(self):
out, err = self.run_cli(commands=("SELECT 1;",))
self.assertIn(self.MEMORY_DB_MSG, err)
@@ -152,6 +199,117 @@ class InteractiveSession(unittest.TestCase):
out, _ = self.run_cli(TESTFN, commands=("SELECT count(t) FROM t;",))
self.assertIn("(0,)\n", out)
+ def test_color(self):
+ with unittest.mock.patch("_colorize.can_colorize", return_value=True):
+ out, err = self.run_cli(commands="TEXT\n")
+ self.assertIn("\x1b[1;35msqlite> \x1b[0m", out)
+ self.assertIn("\x1b[1;35m ... \x1b[0m\x1b", out)
+ out, err = self.run_cli(commands=("sel;",))
+ self.assertIn('\x1b[1;35mOperationalError (SQLITE_ERROR)\x1b[0m: '
+ '\x1b[35mnear "sel": syntax error\x1b[0m', err)
+
+
+@requires_subprocess()
+@force_not_colorized_test_class
+class Completion(unittest.TestCase):
+ PS1 = "sqlite> "
+
+ @classmethod
+ def setUpClass(cls):
+ _sqlite3 = import_module("_sqlite3")
+ if not hasattr(_sqlite3, "SQLITE_KEYWORDS"):
+ raise unittest.SkipTest("unable to determine SQLite keywords")
+
+ readline = import_module("readline")
+ if readline.backend == "editline":
+ raise unittest.SkipTest("libedit readline is not supported")
+
+ def write_input(self, input_, env=None):
+ script = textwrap.dedent("""
+ import readline
+ from sqlite3.__main__ import main
+
+ readline.parse_and_bind("set colored-completion-prefix off")
+ main()
+ """)
+ return run_pty(script, input_, env)
+
+ def test_complete_sql_keywords(self):
+ # List candidates starting with 'S', there should be multiple matches.
+ input_ = b"S\t\tEL\t 1;\n.quit\n"
+ output = self.write_input(input_)
+ self.assertIn(b"SELECT", output)
+ self.assertIn(b"SET", output)
+ self.assertIn(b"SAVEPOINT", output)
+ self.assertIn(b"(1,)", output)
+
+ # Keywords are completed in upper case for even lower case user input.
+ input_ = b"sel\t\t 1;\n.quit\n"
+ output = self.write_input(input_)
+ self.assertIn(b"SELECT", output)
+ self.assertIn(b"(1,)", output)
+
+ @unittest.skipIf(sys.platform.startswith("freebsd"),
+ "Two actual tabs are inserted when there are no matching"
+ " completions in the pseudo-terminal opened by run_pty()"
+ " on FreeBSD")
+ def test_complete_no_match(self):
+ input_ = b"xyzzy\t\t\b\b\b\b\b\b\b.quit\n"
+ # Set NO_COLOR to disable coloring for self.PS1.
+ output = self.write_input(input_, env={**os.environ, "NO_COLOR": "1"})
+ lines = output.decode().splitlines()
+ indices = (
+ i for i, line in enumerate(lines, 1)
+ if line.startswith(f"{self.PS1}xyzzy")
+ )
+ line_num = next(indices, -1)
+ self.assertNotEqual(line_num, -1)
+ # Completions occupy lines, assert no extra lines when there is nothing
+ # to complete.
+ self.assertEqual(line_num, len(lines))
+
+ def test_complete_no_input(self):
+ from _sqlite3 import SQLITE_KEYWORDS
+
+ script = textwrap.dedent("""
+ import readline
+ from sqlite3.__main__ import main
+
+ # Configure readline to ...:
+ # - hide control sequences surrounding each candidate
+ # - hide "Display all xxx possibilities? (y or n)"
+ # - hide "--More--"
+ # - show candidates one per line
+ readline.parse_and_bind("set colored-completion-prefix off")
+ readline.parse_and_bind("set colored-stats off")
+ readline.parse_and_bind("set completion-query-items 0")
+ readline.parse_and_bind("set page-completions off")
+ readline.parse_and_bind("set completion-display-width 0")
+ readline.parse_and_bind("set show-all-if-ambiguous off")
+ readline.parse_and_bind("set show-all-if-unmodified off")
+
+ main()
+ """)
+ input_ = b"\t\t.quit\n"
+ output = run_pty(script, input_, env={**os.environ, "NO_COLOR": "1"})
+ try:
+ lines = output.decode().splitlines()
+ indices = [
+ i for i, line in enumerate(lines)
+ if line.startswith(self.PS1)
+ ]
+ self.assertEqual(len(indices), 2)
+ start, end = indices
+ candidates = [l.strip() for l in lines[start+1:end]]
+ self.assertEqual(candidates, sorted(SQLITE_KEYWORDS))
+ except:
+ if verbose:
+ print(' PTY output: '.center(30, '-'))
+ print(output.decode(errors='replace'))
+ print(' end PTY output '.center(30, '-'))
+ raise
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py
index c3aa3bf2d7b..291e0356253 100644
--- a/Lib/test/test_sqlite3/test_dbapi.py
+++ b/Lib/test/test_sqlite3/test_dbapi.py
@@ -550,17 +550,9 @@ class ConnectionTests(unittest.TestCase):
cx.execute("insert into u values(0)")
def test_connect_positional_arguments(self):
- regex = (
- r"Passing more than 1 positional argument to sqlite3.connect\(\)"
- " is deprecated. Parameters 'timeout', 'detect_types', "
- "'isolation_level', 'check_same_thread', 'factory', "
- "'cached_statements' and 'uri' will become keyword-only "
- "parameters in Python 3.15."
- )
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
- cx = sqlite.connect(":memory:", 1.0)
- cx.close()
- self.assertEqual(cm.filename, __file__)
+ with self.assertRaisesRegex(TypeError,
+ r'connect\(\) takes at most 1 positional arguments'):
+ sqlite.connect(":memory:", 1.0)
def test_connection_resource_warning(self):
with self.assertWarns(ResourceWarning):
diff --git a/Lib/test/test_sqlite3/test_factory.py b/Lib/test/test_sqlite3/test_factory.py
index cc9f1ec5c4b..776659e3b16 100644
--- a/Lib/test/test_sqlite3/test_factory.py
+++ b/Lib/test/test_sqlite3/test_factory.py
@@ -71,18 +71,9 @@ class ConnectionFactoryTests(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(Factory, self).__init__(*args, **kwargs)
- regex = (
- r"Passing more than 1 positional argument to _sqlite3.Connection\(\) "
- r"is deprecated. Parameters 'timeout', 'detect_types', "
- r"'isolation_level', 'check_same_thread', 'factory', "
- r"'cached_statements' and 'uri' will become keyword-only "
- r"parameters in Python 3.15."
- )
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
- with memory_database(5.0, 0, None, True, Factory) as con:
- self.assertIsNone(con.isolation_level)
- self.assertIsInstance(con, Factory)
- self.assertEqual(cm.filename, __file__)
+ with self.assertRaisesRegex(TypeError,
+ r'connect\(\) takes at most 1 positional arguments'):
+ memory_database(5.0, 0, None, True, Factory)
class CursorFactoryTests(MemoryDatabaseMixin, unittest.TestCase):
diff --git a/Lib/test/test_sqlite3/test_hooks.py b/Lib/test/test_sqlite3/test_hooks.py
index 53b8a39bf29..2b907e35131 100644
--- a/Lib/test/test_sqlite3/test_hooks.py
+++ b/Lib/test/test_sqlite3/test_hooks.py
@@ -220,16 +220,9 @@ class ProgressTests(MemoryDatabaseMixin, unittest.TestCase):
""")
def test_progress_handler_keyword_args(self):
- regex = (
- r"Passing keyword argument 'progress_handler' to "
- r"_sqlite3.Connection.set_progress_handler\(\) is deprecated. "
- r"Parameter 'progress_handler' will become positional-only in "
- r"Python 3.15."
- )
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
+ with self.assertRaisesRegex(TypeError,
+ 'takes at least 1 positional argument'):
self.con.set_progress_handler(progress_handler=lambda: None, n=1)
- self.assertEqual(cm.filename, __file__)
class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase):
@@ -353,16 +346,9 @@ class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase):
cx.execute("select 1")
def test_trace_keyword_args(self):
- regex = (
- r"Passing keyword argument 'trace_callback' to "
- r"_sqlite3.Connection.set_trace_callback\(\) is deprecated. "
- r"Parameter 'trace_callback' will become positional-only in "
- r"Python 3.15."
- )
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
+ with self.assertRaisesRegex(TypeError,
+ 'takes exactly 1 positional argument'):
self.con.set_trace_callback(trace_callback=lambda: None)
- self.assertEqual(cm.filename, __file__)
if __name__ == "__main__":
diff --git a/Lib/test/test_sqlite3/test_userfunctions.py b/Lib/test/test_sqlite3/test_userfunctions.py
index 3abc43a3b1a..11cf877a011 100644
--- a/Lib/test/test_sqlite3/test_userfunctions.py
+++ b/Lib/test/test_sqlite3/test_userfunctions.py
@@ -422,27 +422,9 @@ class FunctionTests(unittest.TestCase):
self.con.execute, "select badreturn()")
def test_func_keyword_args(self):
- regex = (
- r"Passing keyword arguments 'name', 'narg' and 'func' to "
- r"_sqlite3.Connection.create_function\(\) is deprecated. "
- r"Parameters 'name', 'narg' and 'func' will become "
- r"positional-only in Python 3.15."
- )
-
- def noop():
- return None
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
- self.con.create_function("noop", 0, func=noop)
- self.assertEqual(cm.filename, __file__)
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
- self.con.create_function("noop", narg=0, func=noop)
- self.assertEqual(cm.filename, __file__)
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
- self.con.create_function(name="noop", narg=0, func=noop)
- self.assertEqual(cm.filename, __file__)
+ with self.assertRaisesRegex(TypeError,
+ 'takes exactly 3 positional arguments'):
+ self.con.create_function("noop", 0, func=lambda: None)
class WindowSumInt:
@@ -737,25 +719,9 @@ class AggregateTests(unittest.TestCase):
self.assertEqual(val, txt)
def test_agg_keyword_args(self):
- regex = (
- r"Passing keyword arguments 'name', 'n_arg' and 'aggregate_class' to "
- r"_sqlite3.Connection.create_aggregate\(\) is deprecated. "
- r"Parameters 'name', 'n_arg' and 'aggregate_class' will become "
- r"positional-only in Python 3.15."
- )
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
+ with self.assertRaisesRegex(TypeError,
+ 'takes exactly 3 positional arguments'):
self.con.create_aggregate("test", 1, aggregate_class=AggrText)
- self.assertEqual(cm.filename, __file__)
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
- self.con.create_aggregate("test", n_arg=1, aggregate_class=AggrText)
- self.assertEqual(cm.filename, __file__)
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
- self.con.create_aggregate(name="test", n_arg=0,
- aggregate_class=AggrText)
- self.assertEqual(cm.filename, __file__)
class AuthorizerTests(unittest.TestCase):
@@ -800,16 +766,9 @@ class AuthorizerTests(unittest.TestCase):
self.con.execute("select c2 from t1")
def test_authorizer_keyword_args(self):
- regex = (
- r"Passing keyword argument 'authorizer_callback' to "
- r"_sqlite3.Connection.set_authorizer\(\) is deprecated. "
- r"Parameter 'authorizer_callback' will become positional-only in "
- r"Python 3.15."
- )
-
- with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
+ with self.assertRaisesRegex(TypeError,
+ 'takes exactly 1 positional argument'):
self.con.set_authorizer(authorizer_callback=lambda: None)
- self.assertEqual(cm.filename, __file__)
class AuthorizerRaiseExceptionTests(AuthorizerTests):
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 395b2ef88ab..f123f6ece40 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -31,6 +31,7 @@ import weakref
import platform
import sysconfig
import functools
+from contextlib import nullcontext
try:
import ctypes
except ImportError:
@@ -539,9 +540,9 @@ class BasicSocketTests(unittest.TestCase):
openssl_ver = f"OpenSSL {major:d}.{minor:d}.{patch:d}"
else:
openssl_ver = f"OpenSSL {major:d}.{minor:d}.{fix:d}"
- self.assertTrue(
- s.startswith((openssl_ver, libressl_ver, "AWS-LC")),
- (s, t, hex(n))
+ self.assertStartsWith(
+ s, (openssl_ver, libressl_ver, "AWS-LC"),
+ (t, hex(n))
)
@support.cpython_only
@@ -1668,7 +1669,7 @@ class SSLErrorTests(unittest.TestCase):
regex = "(NO_START_LINE|UNSUPPORTED_PUBLIC_KEY_TYPE)"
self.assertRegex(cm.exception.reason, regex)
s = str(cm.exception)
- self.assertTrue("NO_START_LINE" in s, s)
+ self.assertIn("NO_START_LINE", s)
def test_subclass(self):
# Check that the appropriate SSLError subclass is raised
@@ -1683,7 +1684,7 @@ class SSLErrorTests(unittest.TestCase):
with self.assertRaises(ssl.SSLWantReadError) as cm:
c.do_handshake()
s = str(cm.exception)
- self.assertTrue(s.startswith("The operation did not complete (read)"), s)
+ self.assertStartsWith(s, "The operation did not complete (read)")
# For compatibility
self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ)
@@ -2843,6 +2844,7 @@ class ThreadedTests(unittest.TestCase):
# See GH-124984: OpenSSL is not thread safe.
threads = []
+ warnings_filters = sys.flags.context_aware_warnings
global USE_SAME_TEST_CONTEXT
USE_SAME_TEST_CONTEXT = True
try:
@@ -2851,7 +2853,10 @@ class ThreadedTests(unittest.TestCase):
self.test_alpn_protocols,
self.test_getpeercert,
self.test_crl_check,
- self.test_check_hostname_idn,
+ functools.partial(
+ self.test_check_hostname_idn,
+ warnings_filters=warnings_filters,
+ ),
self.test_wrong_cert_tls12,
self.test_wrong_cert_tls13,
):
@@ -3097,7 +3102,7 @@ class ThreadedTests(unittest.TestCase):
cipher = s.cipher()[0].split('-')
self.assertTrue(cipher[:2], ('ECDHE', 'ECDSA'))
- def test_check_hostname_idn(self):
+ def test_check_hostname_idn(self, warnings_filters=True):
if support.verbose:
sys.stdout.write("\n")
@@ -3152,16 +3157,30 @@ class ThreadedTests(unittest.TestCase):
server_hostname="python.example.org") as s:
with self.assertRaises(ssl.CertificateError):
s.connect((HOST, server.port))
- with ThreadedEchoServer(context=server_context, chatty=True) as server:
- with warnings_helper.check_no_resource_warning(self):
- with self.assertRaises(UnicodeError):
- context.wrap_socket(socket.socket(),
- server_hostname='.pythontest.net')
- with ThreadedEchoServer(context=server_context, chatty=True) as server:
- with warnings_helper.check_no_resource_warning(self):
- with self.assertRaises(UnicodeDecodeError):
- context.wrap_socket(socket.socket(),
- server_hostname=b'k\xf6nig.idn.pythontest.net')
+ with (
+ ThreadedEchoServer(context=server_context, chatty=True) as server,
+ (
+ warnings_helper.check_no_resource_warning(self)
+ if warnings_filters
+ else nullcontext()
+ ),
+ self.assertRaises(UnicodeError),
+ ):
+ context.wrap_socket(socket.socket(), server_hostname='.pythontest.net')
+
+ with (
+ ThreadedEchoServer(context=server_context, chatty=True) as server,
+ (
+ warnings_helper.check_no_resource_warning(self)
+ if warnings_filters
+ else nullcontext()
+ ),
+ self.assertRaises(UnicodeDecodeError),
+ ):
+ context.wrap_socket(
+ socket.socket(),
+ server_hostname=b'k\xf6nig.idn.pythontest.net',
+ )
def test_wrong_cert_tls12(self):
"""Connecting when the server rejects the client's certificate
@@ -4488,6 +4507,7 @@ class ThreadedTests(unittest.TestCase):
@requires_tls_version('TLSv1_3')
@unittest.skipUnless(ssl.HAS_PSK, 'TLS-PSK disabled on this OpenSSL build')
+ @unittest.skipUnless(ssl.HAS_PSK_TLS13, 'TLS 1.3 PSK disabled on this OpenSSL build')
def test_psk_tls1_3(self):
psk = bytes.fromhex('deadbeef')
identity_hint = 'identity-hint'
diff --git a/Lib/test/test_stable_abi_ctypes.py b/Lib/test/test_stable_abi_ctypes.py
index 1e6f69d49e9..5a6ba9de337 100644
--- a/Lib/test/test_stable_abi_ctypes.py
+++ b/Lib/test/test_stable_abi_ctypes.py
@@ -658,7 +658,11 @@ SYMBOL_NAMES = (
"PySys_AuditTuple",
"PySys_FormatStderr",
"PySys_FormatStdout",
+ "PySys_GetAttr",
+ "PySys_GetAttrString",
"PySys_GetObject",
+ "PySys_GetOptionalAttr",
+ "PySys_GetOptionalAttrString",
"PySys_GetXOptions",
"PySys_HasWarnOptions",
"PySys_ResetWarnOptions",
diff --git a/Lib/test/test_stat.py b/Lib/test/test_stat.py
index 49013a4bcd8..5fd25d5012c 100644
--- a/Lib/test/test_stat.py
+++ b/Lib/test/test_stat.py
@@ -157,7 +157,7 @@ class TestFilemode:
os.chmod(TESTFN, 0o700)
st_mode, modestr = self.get_mode()
- self.assertEqual(modestr[:3], '-rw')
+ self.assertStartsWith(modestr, '-rw')
self.assertS_IS("REG", st_mode)
self.assertEqual(self.statmod.S_IFMT(st_mode),
self.statmod.S_IFREG)
@@ -256,7 +256,7 @@ class TestFilemode:
"FILE_ATTRIBUTE_* constants are Win32 specific")
def test_file_attribute_constants(self):
for key, value in sorted(self.file_attributes.items()):
- self.assertTrue(hasattr(self.statmod, key), key)
+ self.assertHasAttr(self.statmod, key)
modvalue = getattr(self.statmod, key)
self.assertEqual(value, modvalue, key)
@@ -314,7 +314,7 @@ class TestFilemode:
self.assertEqual(self.statmod.S_ISGID, 0o002000)
self.assertEqual(self.statmod.S_ISVTX, 0o001000)
- self.assertFalse(hasattr(self.statmod, "S_ISTXT"))
+ self.assertNotHasAttr(self.statmod, "S_ISTXT")
self.assertEqual(self.statmod.S_IREAD, self.statmod.S_IRUSR)
self.assertEqual(self.statmod.S_IWRITE, self.statmod.S_IWUSR)
self.assertEqual(self.statmod.S_IEXEC, self.statmod.S_IXUSR)
diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py
index c69baa4bf4d..8250b0aef09 100644
--- a/Lib/test/test_statistics.py
+++ b/Lib/test/test_statistics.py
@@ -645,7 +645,7 @@ class TestNumericTestCase(unittest.TestCase):
def test_numerictestcase_is_testcase(self):
# Ensure that NumericTestCase actually is a TestCase.
- self.assertTrue(issubclass(NumericTestCase, unittest.TestCase))
+ self.assertIsSubclass(NumericTestCase, unittest.TestCase)
def test_error_msg_numeric(self):
# Test the error message generated for numeric comparisons.
@@ -683,32 +683,23 @@ class GlobalsTest(unittest.TestCase):
def test_meta(self):
# Test for the existence of metadata.
for meta in self.expected_metadata:
- self.assertTrue(hasattr(self.module, meta),
- "%s not present" % meta)
+ self.assertHasAttr(self.module, meta)
def test_check_all(self):
# Check everything in __all__ exists and is public.
module = self.module
for name in module.__all__:
# No private names in __all__:
- self.assertFalse(name.startswith("_"),
+ self.assertNotStartsWith(name, "_",
'private name "%s" in __all__' % name)
# And anything in __all__ must exist:
- self.assertTrue(hasattr(module, name),
- 'missing name "%s" in __all__' % name)
+ self.assertHasAttr(module, name)
class StatisticsErrorTest(unittest.TestCase):
def test_has_exception(self):
- errmsg = (
- "Expected StatisticsError to be a ValueError, but got a"
- " subclass of %r instead."
- )
- self.assertTrue(hasattr(statistics, 'StatisticsError'))
- self.assertTrue(
- issubclass(statistics.StatisticsError, ValueError),
- errmsg % statistics.StatisticsError.__base__
- )
+ self.assertHasAttr(statistics, 'StatisticsError')
+ self.assertIsSubclass(statistics.StatisticsError, ValueError)
# === Tests for private utility functions ===
@@ -2355,6 +2346,7 @@ class TestGeometricMean(unittest.TestCase):
class TestKDE(unittest.TestCase):
+ @support.requires_resource('cpu')
def test_kde(self):
kde = statistics.kde
StatisticsError = statistics.StatisticsError
@@ -3327,7 +3319,8 @@ class TestNormalDistC(unittest.TestCase, TestNormalDist):
def load_tests(loader, tests, ignore):
"""Used for doctest/unittest integration."""
tests.addTests(doctest.DocTestSuite())
- tests.addTests(doctest.DocTestSuite(statistics))
+ if sys.float_repr_style == 'short':
+ tests.addTests(doctest.DocTestSuite(statistics))
return tests
diff --git a/Lib/test/test_str.py b/Lib/test/test_str.py
index d6a7bd0da59..2584fbf72d3 100644
--- a/Lib/test/test_str.py
+++ b/Lib/test/test_str.py
@@ -1231,10 +1231,10 @@ class StrTest(string_tests.StringLikeTest,
self.assertEqual('{0:\x00^6}'.format(3), '\x00\x003\x00\x00\x00')
self.assertEqual('{0:<6}'.format(3), '3 ')
- self.assertEqual('{0:\x00<6}'.format(3.14), '3.14\x00\x00')
- self.assertEqual('{0:\x01<6}'.format(3.14), '3.14\x01\x01')
- self.assertEqual('{0:\x00^6}'.format(3.14), '\x003.14\x00')
- self.assertEqual('{0:^6}'.format(3.14), ' 3.14 ')
+ self.assertEqual('{0:\x00<6}'.format(3.25), '3.25\x00\x00')
+ self.assertEqual('{0:\x01<6}'.format(3.25), '3.25\x01\x01')
+ self.assertEqual('{0:\x00^6}'.format(3.25), '\x003.25\x00')
+ self.assertEqual('{0:^6}'.format(3.25), ' 3.25 ')
self.assertEqual('{0:\x00<12}'.format(3+2.0j), '(3+2j)\x00\x00\x00\x00\x00\x00')
self.assertEqual('{0:\x01<12}'.format(3+2.0j), '(3+2j)\x01\x01\x01\x01\x01\x01')
diff --git a/Lib/test/test_strftime.py b/Lib/test/test_strftime.py
index 752e31359cf..375f6aaedd8 100644
--- a/Lib/test/test_strftime.py
+++ b/Lib/test/test_strftime.py
@@ -39,7 +39,21 @@ class StrftimeTest(unittest.TestCase):
if now[3] < 12: self.ampm='(AM|am)'
else: self.ampm='(PM|pm)'
- self.jan1 = time.localtime(time.mktime((now[0], 1, 1, 0, 0, 0, 0, 1, 0)))
+ jan1 = time.struct_time(
+ (
+ now.tm_year, # Year
+ 1, # Month (January)
+ 1, # Day (1st)
+ 0, # Hour (0)
+ 0, # Minute (0)
+ 0, # Second (0)
+ -1, # tm_wday (will be determined)
+ 1, # tm_yday (day 1 of the year)
+ -1, # tm_isdst (let the system determine)
+ )
+ )
+ # use mktime to get the correct tm_wday and tm_isdst values
+ self.jan1 = time.localtime(time.mktime(jan1))
try:
if now[8]: self.tz = time.tzname[1]
diff --git a/Lib/test/test_string/_support.py b/Lib/test/test_string/_support.py
index eaa3354a559..abdddaf187b 100644
--- a/Lib/test/test_string/_support.py
+++ b/Lib/test/test_string/_support.py
@@ -1,4 +1,3 @@
-import unittest
from string.templatelib import Interpolation
diff --git a/Lib/test/test_string/test_string.py b/Lib/test/test_string/test_string.py
index f6d112d8a93..5394fe4e12c 100644
--- a/Lib/test/test_string/test_string.py
+++ b/Lib/test/test_string/test_string.py
@@ -2,6 +2,14 @@ import unittest
import string
from string import Template
import types
+from test.support import cpython_only
+from test.support.import_helper import ensure_lazy_imports
+
+
+class LazyImportTest(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("string", {"re", "collections"})
class ModuleTest(unittest.TestCase):
diff --git a/Lib/test/test_string/test_templatelib.py b/Lib/test/test_string/test_templatelib.py
index 5b9490c2be6..85fcff486d6 100644
--- a/Lib/test/test_string/test_templatelib.py
+++ b/Lib/test/test_string/test_templatelib.py
@@ -148,6 +148,13 @@ class TemplateIterTests(unittest.TestCase):
self.assertEqual(res[1].format_spec, '')
self.assertEqual(res[2], ' yz')
+ def test_exhausted(self):
+ # See https://github.com/python/cpython/issues/134119.
+ template_iter = iter(t"{1}")
+ self.assertIsInstance(next(template_iter), Interpolation)
+ self.assertRaises(StopIteration, next, template_iter)
+ self.assertRaises(StopIteration, next, template_iter)
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_strptime.py b/Lib/test/test_strptime.py
index 268230f6da7..0241e543cd7 100644
--- a/Lib/test/test_strptime.py
+++ b/Lib/test/test_strptime.py
@@ -221,14 +221,16 @@ class StrptimeTests(unittest.TestCase):
self.assertRaises(ValueError, _strptime._strptime_time, data_string="%d",
format="%A")
for bad_format in ("%", "% ", "%\n"):
- with self.assertRaisesRegex(ValueError, "stray % in format "):
+ with (self.subTest(format=bad_format),
+ self.assertRaisesRegex(ValueError, "stray % in format ")):
_strptime._strptime_time("2005", bad_format)
- for bad_format in ("%e", "%Oe", "%O", "%O ", "%Ee", "%E", "%E ",
- "%.", "%+", "%_", "%~", "%\\",
+ for bad_format in ("%i", "%Oi", "%O", "%O ", "%Ee", "%E", "%E ",
+ "%.", "%+", "%~", "%\\",
"%O.", "%O+", "%O_", "%O~", "%O\\"):
directive = bad_format[1:].rstrip()
- with self.assertRaisesRegex(ValueError,
- f"'{re.escape(directive)}' is a bad directive in format "):
+ with (self.subTest(format=bad_format),
+ self.assertRaisesRegex(ValueError,
+ f"'{re.escape(directive)}' is a bad directive in format ")):
_strptime._strptime_time("2005", bad_format)
msg_week_no_year_or_weekday = r"ISO week directive '%V' must be used with " \
@@ -335,6 +337,15 @@ class StrptimeTests(unittest.TestCase):
self.roundtrip('%B', 1, (1900, m, 1, 0, 0, 0, 0, 1, 0))
self.roundtrip('%b', 1, (1900, m, 1, 0, 0, 0, 0, 1, 0))
+ @run_with_locales('LC_TIME', 'az_AZ', 'ber_DZ', 'ber_MA', 'crh_UA')
+ def test_month_locale2(self):
+ # Test for month directives
+ # Month name contains 'İ' ('\u0130')
+ self.roundtrip('%B', 1, (2025, 6, 1, 0, 0, 0, 6, 152, 0))
+ self.roundtrip('%b', 1, (2025, 6, 1, 0, 0, 0, 6, 152, 0))
+ self.roundtrip('%B', 1, (2025, 7, 1, 0, 0, 0, 1, 182, 0))
+ self.roundtrip('%b', 1, (2025, 7, 1, 0, 0, 0, 1, 182, 0))
+
def test_day(self):
# Test for day directives
self.roundtrip('%d %Y', 2)
@@ -480,13 +491,11 @@ class StrptimeTests(unittest.TestCase):
# * Year is not included: ha_NG.
# * Use non-Gregorian calendar: lo_LA, thai, th_TH.
# On Windows: ar_IN, ar_SA, fa_IR, ps_AF.
- #
- # BUG: Generates regexp that does not match the current date and time
- # for lzh_TW.
@run_with_locales('LC_TIME', 'C', 'en_US', 'fr_FR', 'de_DE', 'ja_JP',
'he_IL', 'eu_ES', 'ar_AE', 'mfe_MU', 'yo_NG',
'csb_PL', 'br_FR', 'gez_ET', 'brx_IN',
- 'my_MM', 'or_IN', 'shn_MM', 'az_IR')
+ 'my_MM', 'or_IN', 'shn_MM', 'az_IR',
+ 'byn_ER', 'wal_ET', 'lzh_TW')
def test_date_time_locale(self):
# Test %c directive
loc = locale.getlocale(locale.LC_TIME)[0]
@@ -525,11 +534,9 @@ class StrptimeTests(unittest.TestCase):
# NB: Does not roundtrip because use non-Gregorian calendar:
# lo_LA, thai, th_TH. On Windows: ar_IN, ar_SA, fa_IR, ps_AF.
- # BUG: Generates regexp that does not match the current date
- # for lzh_TW.
@run_with_locales('LC_TIME', 'C', 'en_US', 'fr_FR', 'de_DE', 'ja_JP',
'he_IL', 'eu_ES', 'ar_AE',
- 'az_IR', 'my_MM', 'or_IN', 'shn_MM')
+ 'az_IR', 'my_MM', 'or_IN', 'shn_MM', 'lzh_TW')
def test_date_locale(self):
# Test %x directive
now = time.time()
@@ -546,7 +553,7 @@ class StrptimeTests(unittest.TestCase):
# NB: Dates before 1969 do not roundtrip on many locales, including C.
@unittest.skipIf(support.linked_to_musl(), "musl libc issue, bpo-46390")
@run_with_locales('LC_TIME', 'en_US', 'fr_FR', 'de_DE', 'ja_JP',
- 'eu_ES', 'ar_AE', 'my_MM', 'shn_MM')
+ 'eu_ES', 'ar_AE', 'my_MM', 'shn_MM', 'lzh_TW')
def test_date_locale2(self):
# Test %x directive
loc = locale.getlocale(locale.LC_TIME)[0]
@@ -562,11 +569,11 @@ class StrptimeTests(unittest.TestCase):
# norwegian, nynorsk.
# * Hours are in 12-hour notation without AM/PM indication: hy_AM,
# ms_MY, sm_WS.
- # BUG: Generates regexp that does not match the current time for lzh_TW.
@run_with_locales('LC_TIME', 'C', 'en_US', 'fr_FR', 'de_DE', 'ja_JP',
'aa_ET', 'am_ET', 'az_IR', 'byn_ER', 'fa_IR', 'gez_ET',
'my_MM', 'om_ET', 'or_IN', 'shn_MM', 'sid_ET', 'so_SO',
- 'ti_ET', 'tig_ER', 'wal_ET')
+ 'ti_ET', 'tig_ER', 'wal_ET', 'lzh_TW',
+ 'ar_SA', 'bg_BG')
def test_time_locale(self):
# Test %X directive
loc = locale.getlocale(locale.LC_TIME)[0]
diff --git a/Lib/test/test_strtod.py b/Lib/test/test_strtod.py
index 2727514fad4..570de390a95 100644
--- a/Lib/test/test_strtod.py
+++ b/Lib/test/test_strtod.py
@@ -19,7 +19,7 @@ strtod_parser = re.compile(r""" # A numeric string consists of:
(?P<int>\d*) # having a (possibly empty) integer part
(?:\.(?P<frac>\d*))? # followed by an optional fractional part
(?:E(?P<exp>[-+]?\d+))? # and an optional exponent
- \Z
+ \z
""", re.VERBOSE | re.IGNORECASE).match
# Pure Python version of correctly rounded string->float conversion.
diff --git a/Lib/test/test_structseq.py b/Lib/test/test_structseq.py
index d0bc0bd7b61..9622151143c 100644
--- a/Lib/test/test_structseq.py
+++ b/Lib/test/test_structseq.py
@@ -42,7 +42,7 @@ class StructSeqTest(unittest.TestCase):
# os.stat() gives a complicated struct sequence.
st = os.stat(__file__)
rep = repr(st)
- self.assertTrue(rep.startswith("os.stat_result"))
+ self.assertStartsWith(rep, "os.stat_result")
self.assertIn("st_mode=", rep)
self.assertIn("st_ino=", rep)
self.assertIn("st_dev=", rep)
@@ -307,7 +307,7 @@ class StructSeqTest(unittest.TestCase):
self.assertEqual(t5.tm_mon, 2)
# named invisible fields
- self.assertTrue(hasattr(t, 'tm_zone'), f"{t} has no attribute 'tm_zone'")
+ self.assertHasAttr(t, 'tm_zone')
with self.assertRaisesRegex(AttributeError, 'readonly attribute'):
t.tm_zone = 'some other zone'
self.assertEqual(t2.tm_zone, t.tm_zone)
diff --git a/Lib/test/test_subprocess.py b/Lib/test/test_subprocess.py
index 3cb755cd56c..f0e350c71f6 100644
--- a/Lib/test/test_subprocess.py
+++ b/Lib/test/test_subprocess.py
@@ -162,6 +162,20 @@ class ProcessTestCase(BaseTestCase):
[sys.executable, "-c", "while True: pass"],
timeout=0.1)
+ def test_timeout_exception(self):
+ try:
+ subprocess.run([sys.executable, '-c', 'import time;time.sleep(9)'], timeout = -1)
+ except subprocess.TimeoutExpired as e:
+ self.assertIn("-1 seconds", str(e))
+ else:
+ self.fail("Expected TimeoutExpired exception not raised")
+ try:
+ subprocess.run([sys.executable, '-c', 'import time;time.sleep(9)'], timeout = 0)
+ except subprocess.TimeoutExpired as e:
+ self.assertIn("0 seconds", str(e))
+ else:
+ self.fail("Expected TimeoutExpired exception not raised")
+
def test_check_call_zero(self):
# check_call() function with zero return code
rc = subprocess.check_call(ZERO_RETURN_CMD)
@@ -1164,7 +1178,7 @@ class ProcessTestCase(BaseTestCase):
self.assertEqual("line1\nline2\nline3\nline4\nline5\n", stdout)
# Python debug build push something like "[42442 refs]\n"
# to stderr at exit of subprocess.
- self.assertTrue(stderr.startswith("eline2\neline6\neline7\n"))
+ self.assertStartsWith(stderr, "eline2\neline6\neline7\n")
def test_universal_newlines_communicate_encodings(self):
# Check that universal newlines mode works for various encodings,
@@ -1496,7 +1510,7 @@ class ProcessTestCase(BaseTestCase):
"[sys.executable, '-c', 'print(\"Hello World!\")'])",
'assert retcode == 0'))
output = subprocess.check_output([sys.executable, '-c', code])
- self.assertTrue(output.startswith(b'Hello World!'), ascii(output))
+ self.assertStartsWith(output, b'Hello World!')
def test_handles_closed_on_exception(self):
# If CreateProcess exits with an error, ensure the
@@ -1821,8 +1835,8 @@ class RunFuncTestCase(BaseTestCase):
capture_output=True)
lines = cp.stderr.splitlines()
self.assertEqual(len(lines), 2, lines)
- self.assertTrue(lines[0].startswith(b"<string>:2: EncodingWarning: "))
- self.assertTrue(lines[1].startswith(b"<string>:3: EncodingWarning: "))
+ self.assertStartsWith(lines[0], b"<string>:2: EncodingWarning: ")
+ self.assertStartsWith(lines[1], b"<string>:3: EncodingWarning: ")
def _get_test_grp_name():
diff --git a/Lib/test/test_super.py b/Lib/test/test_super.py
index 5cef612a340..193c8b7d7f3 100644
--- a/Lib/test/test_super.py
+++ b/Lib/test/test_super.py
@@ -547,11 +547,11 @@ class TestSuper(unittest.TestCase):
self.assertEqual(s.__reduce__, e.__reduce__)
self.assertEqual(s.__reduce_ex__, e.__reduce_ex__)
self.assertEqual(s.__getstate__, e.__getstate__)
- self.assertFalse(hasattr(s, '__getnewargs__'))
- self.assertFalse(hasattr(s, '__getnewargs_ex__'))
- self.assertFalse(hasattr(s, '__setstate__'))
- self.assertFalse(hasattr(s, '__copy__'))
- self.assertFalse(hasattr(s, '__deepcopy__'))
+ self.assertNotHasAttr(s, '__getnewargs__')
+ self.assertNotHasAttr(s, '__getnewargs_ex__')
+ self.assertNotHasAttr(s, '__setstate__')
+ self.assertNotHasAttr(s, '__copy__')
+ self.assertNotHasAttr(s, '__deepcopy__')
def test_pickling(self):
e = E()
diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py
index 468bac82924..e48a2464ee5 100644
--- a/Lib/test/test_support.py
+++ b/Lib/test/test_support.py
@@ -407,10 +407,10 @@ class TestSupport(unittest.TestCase):
with support.swap_attr(obj, "y", 5) as y:
self.assertEqual(obj.y, 5)
self.assertIsNone(y)
- self.assertFalse(hasattr(obj, 'y'))
+ self.assertNotHasAttr(obj, 'y')
with support.swap_attr(obj, "y", 5):
del obj.y
- self.assertFalse(hasattr(obj, 'y'))
+ self.assertNotHasAttr(obj, 'y')
def test_swap_item(self):
D = {"x":1}
@@ -561,6 +561,7 @@ class TestSupport(unittest.TestCase):
['-Wignore', '-X', 'dev'],
['-X', 'faulthandler'],
['-X', 'importtime'],
+ ['-X', 'importtime=2'],
['-X', 'showrefcount'],
['-X', 'tracemalloc'],
['-X', 'tracemalloc=3'],
diff --git a/Lib/test/test_syntax.py b/Lib/test/test_syntax.py
index 0ee17849e28..c52d2421941 100644
--- a/Lib/test/test_syntax.py
+++ b/Lib/test/test_syntax.py
@@ -382,6 +382,13 @@ SyntaxError: invalid syntax
Traceback (most recent call last):
SyntaxError: invalid syntax
+# But prefixes of soft keywords should
+# still raise specialized errors
+
+>>> (mat x)
+Traceback (most recent call last):
+SyntaxError: invalid syntax. Perhaps you forgot a comma?
+
From compiler_complex_args():
>>> def f(None=1):
@@ -419,7 +426,7 @@ SyntaxError: invalid syntax
>>> def foo(/,a,b=,c):
... pass
Traceback (most recent call last):
-SyntaxError: at least one argument must precede /
+SyntaxError: at least one parameter must precede /
>>> def foo(a,/,/,b,c):
... pass
@@ -454,67 +461,67 @@ SyntaxError: / must be ahead of *
>>> def foo(a,*b=3,c):
... pass
Traceback (most recent call last):
-SyntaxError: var-positional argument cannot have default value
+SyntaxError: var-positional parameter cannot have default value
>>> def foo(a,*b: int=,c):
... pass
Traceback (most recent call last):
-SyntaxError: var-positional argument cannot have default value
+SyntaxError: var-positional parameter cannot have default value
>>> def foo(a,**b=3):
... pass
Traceback (most recent call last):
-SyntaxError: var-keyword argument cannot have default value
+SyntaxError: var-keyword parameter cannot have default value
>>> def foo(a,**b: int=3):
... pass
Traceback (most recent call last):
-SyntaxError: var-keyword argument cannot have default value
+SyntaxError: var-keyword parameter cannot have default value
>>> def foo(a,*a, b, **c, d):
... pass
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> def foo(a,*a, b, **c, d=4):
... pass
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> def foo(a,*a, b, **c, *d):
... pass
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> def foo(a,*a, b, **c, **d):
... pass
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> def foo(a=1,/,**b,/,c):
... pass
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> def foo(*b,*d):
... pass
Traceback (most recent call last):
-SyntaxError: * argument may appear only once
+SyntaxError: * may appear only once
>>> def foo(a,*b,c,*d,*e,c):
... pass
Traceback (most recent call last):
-SyntaxError: * argument may appear only once
+SyntaxError: * may appear only once
>>> def foo(a,b,/,c,*b,c,*d,*e,c):
... pass
Traceback (most recent call last):
-SyntaxError: * argument may appear only once
+SyntaxError: * may appear only once
>>> def foo(a,b,/,c,*b,c,*d,**e):
... pass
Traceback (most recent call last):
-SyntaxError: * argument may appear only once
+SyntaxError: * may appear only once
>>> def foo(a=1,/*,b,c):
... pass
@@ -538,7 +545,7 @@ SyntaxError: expected default value expression
>>> lambda /,a,b,c: None
Traceback (most recent call last):
-SyntaxError: at least one argument must precede /
+SyntaxError: at least one parameter must precede /
>>> lambda a,/,/,b,c: None
Traceback (most recent call last):
@@ -570,47 +577,47 @@ SyntaxError: expected comma between / and *
>>> lambda a,*b=3,c: None
Traceback (most recent call last):
-SyntaxError: var-positional argument cannot have default value
+SyntaxError: var-positional parameter cannot have default value
>>> lambda a,**b=3: None
Traceback (most recent call last):
-SyntaxError: var-keyword argument cannot have default value
+SyntaxError: var-keyword parameter cannot have default value
>>> lambda a, *a, b, **c, d: None
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> lambda a,*a, b, **c, d=4: None
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> lambda a,*a, b, **c, *d: None
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> lambda a,*a, b, **c, **d: None
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> lambda a=1,/,**b,/,c: None
Traceback (most recent call last):
-SyntaxError: arguments cannot follow var-keyword argument
+SyntaxError: parameters cannot follow var-keyword parameter
>>> lambda *b,*d: None
Traceback (most recent call last):
-SyntaxError: * argument may appear only once
+SyntaxError: * may appear only once
>>> lambda a,*b,c,*d,*e,c: None
Traceback (most recent call last):
-SyntaxError: * argument may appear only once
+SyntaxError: * may appear only once
>>> lambda a,b,/,c,*b,c,*d,*e,c: None
Traceback (most recent call last):
-SyntaxError: * argument may appear only once
+SyntaxError: * may appear only once
>>> lambda a,b,/,c,*b,c,*d,**e: None
Traceback (most recent call last):
-SyntaxError: * argument may appear only once
+SyntaxError: * may appear only once
>>> lambda a=1,d=,c: None
Traceback (most recent call last):
@@ -1304,7 +1311,7 @@ Missing parens after function definition
Traceback (most recent call last):
SyntaxError: expected '('
-Parenthesized arguments in function definitions
+Parenthesized parameters in function definitions
>>> def f(x, (y, z), w):
... pass
@@ -1431,6 +1438,23 @@ Better error message for using `except as` with not a name:
Traceback (most recent call last):
SyntaxError: cannot use except* statement with literal
+Regression tests for gh-133999:
+
+ >>> try: pass
+ ... except TypeError as name: raise from None
+ Traceback (most recent call last):
+ SyntaxError: did you forget an expression between 'raise' and 'from'?
+
+ >>> try: pass
+ ... except* TypeError as name: raise from None
+ Traceback (most recent call last):
+ SyntaxError: did you forget an expression between 'raise' and 'from'?
+
+ >>> match 1:
+ ... case 1 | 2 as abc: raise from None
+ Traceback (most recent call last):
+ SyntaxError: did you forget an expression between 'raise' and 'from'?
+
Ensure that early = are not matched by the parser as invalid comparisons
>>> f(2, 4, x=34); 1 $ 2
Traceback (most recent call last):
@@ -1678,6 +1702,28 @@ Make sure that the old "raise X, Y[, Z]" form is gone:
...
SyntaxError: invalid syntax
+Better errors for `raise` statement:
+
+ >>> raise ValueError from
+ Traceback (most recent call last):
+ SyntaxError: did you forget an expression after 'from'?
+
+ >>> raise mod.ValueError() from
+ Traceback (most recent call last):
+ SyntaxError: did you forget an expression after 'from'?
+
+ >>> raise from exc
+ Traceback (most recent call last):
+ SyntaxError: did you forget an expression between 'raise' and 'from'?
+
+ >>> raise from None
+ Traceback (most recent call last):
+ SyntaxError: did you forget an expression between 'raise' and 'from'?
+
+ >>> raise from
+ Traceback (most recent call last):
+ SyntaxError: did you forget an expression between 'raise' and 'from'?
+
Check that multiple exception types with missing parentheses
raise a custom exception only when using 'as'
@@ -2178,7 +2224,7 @@ Corner-cases that used to fail to raise the correct error:
>>> with (lambda *:0): pass
Traceback (most recent call last):
- SyntaxError: named arguments must follow bare *
+ SyntaxError: named parameters must follow bare *
Corner-cases that used to crash:
@@ -2826,6 +2872,13 @@ class SyntaxErrorTestCase(unittest.TestCase):
"""
self._check_error(source, "parameter and nonlocal", lineno=3)
+ def test_raise_from_error_message(self):
+ source = """if 1:
+ raise AssertionError() from None
+ print(1,,2)
+ """
+ self._check_error(source, "invalid syntax", lineno=3)
+
def test_yield_outside_function(self):
self._check_error("if 0: yield", "outside function")
self._check_error("if 0: yield\nelse: x=1", "outside function")
diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py
index 56413d00823..486bf10a0b5 100644
--- a/Lib/test/test_sys.py
+++ b/Lib/test/test_sys.py
@@ -24,7 +24,7 @@ from test.support import import_helper
from test.support import force_not_colorized
from test.support import SHORT_TIMEOUT
try:
- from test.support import interpreters
+ from concurrent import interpreters
except ImportError:
interpreters = None
import textwrap
@@ -57,7 +57,7 @@ class DisplayHookTest(unittest.TestCase):
dh(None)
self.assertEqual(out.getvalue(), "")
- self.assertTrue(not hasattr(builtins, "_"))
+ self.assertNotHasAttr(builtins, "_")
# sys.displayhook() requires arguments
self.assertRaises(TypeError, dh)
@@ -172,7 +172,7 @@ class ExceptHookTest(unittest.TestCase):
with support.captured_stderr() as err:
sys.__excepthook__(*sys.exc_info())
- self.assertTrue(err.getvalue().endswith("ValueError: 42\n"))
+ self.assertEndsWith(err.getvalue(), "ValueError: 42\n")
self.assertRaises(TypeError, sys.__excepthook__)
@@ -192,7 +192,7 @@ class ExceptHookTest(unittest.TestCase):
err = err.getvalue()
self.assertIn(""" File "b'bytes_filename'", line 123\n""", err)
self.assertIn(""" text\n""", err)
- self.assertTrue(err.endswith("SyntaxError: msg\n"))
+ self.assertEndsWith(err, "SyntaxError: msg\n")
def test_excepthook(self):
with test.support.captured_output("stderr") as stderr:
@@ -269,8 +269,7 @@ class SysModuleTest(unittest.TestCase):
rc, out, err = assert_python_failure('-c', code, **env_vars)
self.assertEqual(rc, 1)
self.assertEqual(out, b'')
- self.assertTrue(err.startswith(expected),
- "%s doesn't start with %s" % (ascii(err), ascii(expected)))
+ self.assertStartsWith(err, expected)
# test that stderr buffer is flushed before the exit message is written
# into stderr
@@ -437,7 +436,7 @@ class SysModuleTest(unittest.TestCase):
@unittest.skipUnless(hasattr(sys, "setdlopenflags"),
'test needs sys.setdlopenflags()')
def test_dlopenflags(self):
- self.assertTrue(hasattr(sys, "getdlopenflags"))
+ self.assertHasAttr(sys, "getdlopenflags")
self.assertRaises(TypeError, sys.getdlopenflags, 42)
oldflags = sys.getdlopenflags()
self.assertRaises(TypeError, sys.setdlopenflags)
@@ -623,8 +622,7 @@ class SysModuleTest(unittest.TestCase):
# And the next record must be for g456().
filename, lineno, funcname, sourceline = stack[i+1]
self.assertEqual(funcname, "g456")
- self.assertTrue((sourceline.startswith("if leave_g.wait(") or
- sourceline.startswith("g_raised.set()")))
+ self.assertStartsWith(sourceline, ("if leave_g.wait(", "g_raised.set()"))
finally:
# Reap the spawned thread.
leave_g.set()
@@ -731,7 +729,7 @@ class SysModuleTest(unittest.TestCase):
info = sys.thread_info
self.assertEqual(len(info), 3)
self.assertIn(info.name, ('nt', 'pthread', 'pthread-stubs', 'solaris', None))
- self.assertIn(info.lock, ('semaphore', 'mutex+cond', None))
+ self.assertIn(info.lock, ('pymutex', None))
if sys.platform.startswith(("linux", "android", "freebsd")):
self.assertEqual(info.name, "pthread")
elif sys.platform == "win32":
@@ -860,7 +858,7 @@ class SysModuleTest(unittest.TestCase):
"hash_randomization", "isolated", "dev_mode", "utf8_mode",
"warn_default_encoding", "safe_path", "int_max_str_digits")
for attr in attrs:
- self.assertTrue(hasattr(sys.flags, attr), attr)
+ self.assertHasAttr(sys.flags, attr)
attr_type = bool if attr in ("dev_mode", "safe_path") else int
self.assertEqual(type(getattr(sys.flags, attr)), attr_type, attr)
self.assertTrue(repr(sys.flags))
@@ -871,12 +869,7 @@ class SysModuleTest(unittest.TestCase):
def assert_raise_on_new_sys_type(self, sys_attr):
# Users are intentionally prevented from creating new instances of
# sys.flags, sys.version_info, and sys.getwindowsversion.
- arg = sys_attr
- attr_type = type(sys_attr)
- with self.assertRaises(TypeError):
- attr_type(arg)
- with self.assertRaises(TypeError):
- attr_type.__new__(attr_type, arg)
+ support.check_disallow_instantiation(self, type(sys_attr), sys_attr)
def test_sys_flags_no_instantiation(self):
self.assert_raise_on_new_sys_type(sys.flags)
@@ -1072,10 +1065,11 @@ class SysModuleTest(unittest.TestCase):
levels = {'alpha': 0xA, 'beta': 0xB, 'candidate': 0xC, 'final': 0xF}
- self.assertTrue(hasattr(sys.implementation, 'name'))
- self.assertTrue(hasattr(sys.implementation, 'version'))
- self.assertTrue(hasattr(sys.implementation, 'hexversion'))
- self.assertTrue(hasattr(sys.implementation, 'cache_tag'))
+ self.assertHasAttr(sys.implementation, 'name')
+ self.assertHasAttr(sys.implementation, 'version')
+ self.assertHasAttr(sys.implementation, 'hexversion')
+ self.assertHasAttr(sys.implementation, 'cache_tag')
+ self.assertHasAttr(sys.implementation, 'supports_isolated_interpreters')
version = sys.implementation.version
self.assertEqual(version[:2], (version.major, version.minor))
@@ -1089,6 +1083,15 @@ class SysModuleTest(unittest.TestCase):
self.assertEqual(sys.implementation.name,
sys.implementation.name.lower())
+ # https://peps.python.org/pep-0734
+ sii = sys.implementation.supports_isolated_interpreters
+ self.assertIsInstance(sii, bool)
+ if test.support.check_impl_detail(cpython=True):
+ if test.support.is_emscripten or test.support.is_wasi:
+ self.assertFalse(sii)
+ else:
+ self.assertTrue(sii)
+
@test.support.cpython_only
def test_debugmallocstats(self):
# Test sys._debugmallocstats()
@@ -1137,23 +1140,12 @@ class SysModuleTest(unittest.TestCase):
b = sys.getallocatedblocks()
self.assertLessEqual(b, a)
try:
- # While we could imagine a Python session where the number of
- # multiple buffer objects would exceed the sharing of references,
- # it is unlikely to happen in a normal test run.
- #
- # In free-threaded builds each code object owns an array of
- # pointers to copies of the bytecode. When the number of
- # code objects is a large fraction of the total number of
- # references, this can cause the total number of allocated
- # blocks to exceed the total number of references.
- #
- # For some reason, iOS seems to trigger the "unlikely to happen"
- # case reliably under CI conditions. It's not clear why; but as
- # this test is checking the behavior of getallocatedblock()
- # under garbage collection, we can skip this pre-condition check
- # for now. See GH-130384.
- if not support.Py_GIL_DISABLED and not support.is_apple_mobile:
- self.assertLess(a, sys.gettotalrefcount())
+ # The reported blocks will include immortalized strings, but the
+ # total ref count will not. This will sanity check that among all
+ # other objects (those eligible for garbage collection) there
+ # are more references being tracked than allocated blocks.
+ interned_immortal = sys.getunicodeinternedsize(_only_immortal=True)
+ self.assertLess(a - interned_immortal, sys.gettotalrefcount())
except AttributeError:
# gettotalrefcount() not available
pass
@@ -1301,6 +1293,7 @@ class SysModuleTest(unittest.TestCase):
for name in sys.stdlib_module_names:
self.assertIsInstance(name, str)
+ @unittest.skipUnless(hasattr(sys, '_stdlib_dir'), 'need sys._stdlib_dir')
def test_stdlib_dir(self):
os = import_helper.import_fresh_module('os')
marker = getattr(os, '__file__', None)
@@ -1419,7 +1412,7 @@ class UnraisableHookTest(unittest.TestCase):
else:
self.assertIn("ValueError", report)
self.assertIn("del is broken", report)
- self.assertTrue(report.endswith("\n"))
+ self.assertEndsWith(report, "\n")
def test_original_unraisablehook_exception_qualname(self):
# See bpo-41031, bpo-45083.
@@ -1955,33 +1948,19 @@ class SizeofTest(unittest.TestCase):
self.assertEqual(out, b"")
self.assertEqual(err, b"")
-
-def _supports_remote_attaching():
- PROCESS_VM_READV_SUPPORTED = False
-
- try:
- from _testexternalinspection import PROCESS_VM_READV_SUPPORTED
- except ImportError:
- pass
-
- return PROCESS_VM_READV_SUPPORTED
-
-@unittest.skipIf(not sys.is_remote_debug_enabled(), "Remote debugging is not enabled")
-@unittest.skipIf(sys.platform != "darwin" and sys.platform != "linux" and sys.platform != "win32",
- "Test only runs on Linux, Windows and MacOS")
-@unittest.skipIf(sys.platform == "linux" and not _supports_remote_attaching(),
- "Test only runs on Linux with process_vm_readv support")
+@test.support.support_remote_exec_only
@test.support.cpython_only
class TestRemoteExec(unittest.TestCase):
def tearDown(self):
test.support.reap_children()
- def _run_remote_exec_test(self, script_code, python_args=None, env=None, prologue=''):
+ def _run_remote_exec_test(self, script_code, python_args=None, env=None,
+ prologue='',
+ script_path=os_helper.TESTFN + '_remote.py'):
# Create the script that will be remotely executed
- script = os_helper.TESTFN + '_remote.py'
- self.addCleanup(os_helper.unlink, script)
+ self.addCleanup(os_helper.unlink, script_path)
- with open(script, 'w') as f:
+ with open(script_path, 'w') as f:
f.write(script_code)
# Create and run the target process
@@ -2050,7 +2029,7 @@ sock.close()
self.assertEqual(response, b"ready")
# Try remote exec on the target process
- sys.remote_exec(proc.pid, script)
+ sys.remote_exec(proc.pid, script_path)
# Signal script to continue
client_socket.sendall(b"continue")
@@ -2073,14 +2052,32 @@ sock.close()
def test_remote_exec(self):
"""Test basic remote exec functionality"""
- script = '''
-print("Remote script executed successfully!")
-'''
+ script = 'print("Remote script executed successfully!")'
returncode, stdout, stderr = self._run_remote_exec_test(script)
# self.assertEqual(returncode, 0)
self.assertIn(b"Remote script executed successfully!", stdout)
self.assertEqual(stderr, b"")
+ def test_remote_exec_bytes(self):
+ script = 'print("Remote script executed successfully!")'
+ script_path = os.fsencode(os_helper.TESTFN) + b'_bytes_remote.py'
+ returncode, stdout, stderr = self._run_remote_exec_test(script,
+ script_path=script_path)
+ self.assertIn(b"Remote script executed successfully!", stdout)
+ self.assertEqual(stderr, b"")
+
+ @unittest.skipUnless(os_helper.TESTFN_UNDECODABLE, 'requires undecodable path')
+ @unittest.skipIf(sys.platform == 'darwin',
+ 'undecodable paths are not supported on macOS')
+ def test_remote_exec_undecodable(self):
+ script = 'print("Remote script executed successfully!")'
+ script_path = os_helper.TESTFN_UNDECODABLE + b'_undecodable_remote.py'
+ for script_path in [script_path, os.fsdecode(script_path)]:
+ returncode, stdout, stderr = self._run_remote_exec_test(script,
+ script_path=script_path)
+ self.assertIn(b"Remote script executed successfully!", stdout)
+ self.assertEqual(stderr, b"")
+
def test_remote_exec_with_self_process(self):
"""Test remote exec with the target process being the same as the test process"""
@@ -2101,7 +2098,7 @@ print("Remote script executed successfully!")
prologue = '''\
import sys
def audit_hook(event, arg):
- print(f"Audit event: {event}, arg: {arg}")
+ print(f"Audit event: {event}, arg: {arg}".encode("ascii", errors="replace"))
sys.addaudithook(audit_hook)
'''
script = '''
@@ -2110,7 +2107,7 @@ print("Remote script executed successfully!")
returncode, stdout, stderr = self._run_remote_exec_test(script, prologue=prologue)
self.assertEqual(returncode, 0)
self.assertIn(b"Remote script executed successfully!", stdout)
- self.assertIn(b"Audit event: remote_debugger_script, arg: ", stdout)
+ self.assertIn(b"Audit event: cpython.remote_debugger_script, arg: ", stdout)
self.assertEqual(stderr, b"")
def test_remote_exec_with_exception(self):
@@ -2157,6 +2154,13 @@ raise Exception("Remote script exception")
with self.assertRaises(OSError):
sys.remote_exec(99999, "print('should not run')")
+ def test_remote_exec_invalid_script(self):
+ """Test remote exec with invalid script type"""
+ with self.assertRaises(TypeError):
+ sys.remote_exec(0, None)
+ with self.assertRaises(TypeError):
+ sys.remote_exec(0, 123)
+
def test_remote_exec_syntax_error(self):
"""Test remote exec with syntax error in script"""
script = '''
@@ -2196,6 +2200,64 @@ this is invalid python code
self.assertIn(b"Remote debugging is not enabled", err)
self.assertEqual(out, b"")
+class TestSysJIT(unittest.TestCase):
+
+ def test_jit_is_available(self):
+ available = sys._jit.is_available()
+ script = f"import sys; assert sys._jit.is_available() is {available}"
+ assert_python_ok("-c", script, PYTHON_JIT="0")
+ assert_python_ok("-c", script, PYTHON_JIT="1")
+
+ def test_jit_is_enabled(self):
+ available = sys._jit.is_available()
+ script = "import sys; assert sys._jit.is_enabled() is {enabled}"
+ assert_python_ok("-c", script.format(enabled=False), PYTHON_JIT="0")
+ assert_python_ok("-c", script.format(enabled=available), PYTHON_JIT="1")
+
+ def test_jit_is_active(self):
+ available = sys._jit.is_available()
+ script = textwrap.dedent(
+ """
+ import _testcapi
+ import _testinternalcapi
+ import sys
+
+ def frame_0_interpreter() -> None:
+ assert sys._jit.is_active() is False
+
+ def frame_1_interpreter() -> None:
+ assert sys._jit.is_active() is False
+ frame_0_interpreter()
+ assert sys._jit.is_active() is False
+
+ def frame_2_jit(expected: bool) -> None:
+ # Inlined into the last loop of frame_3_jit:
+ assert sys._jit.is_active() is expected
+ # Insert C frame:
+ _testcapi.pyobject_vectorcall(frame_1_interpreter, None, None)
+ assert sys._jit.is_active() is expected
+
+ def frame_3_jit() -> None:
+ # JITs just before the last loop:
+ for i in range(_testinternalcapi.TIER2_THRESHOLD + 1):
+ # Careful, doing this in the reverse order breaks tracing:
+ expected = {enabled} and i == _testinternalcapi.TIER2_THRESHOLD
+ assert sys._jit.is_active() is expected
+ frame_2_jit(expected)
+ assert sys._jit.is_active() is expected
+
+ def frame_4_interpreter() -> None:
+ assert sys._jit.is_active() is False
+ frame_3_jit()
+ assert sys._jit.is_active() is False
+
+ assert sys._jit.is_active() is False
+ frame_4_interpreter()
+ assert sys._jit.is_active() is False
+ """
+ )
+ assert_python_ok("-c", script.format(enabled=False), PYTHON_JIT="0")
+ assert_python_ok("-c", script.format(enabled=available), PYTHON_JIT="1")
if __name__ == "__main__":
diff --git a/Lib/test/test_sysconfig.py b/Lib/test/test_sysconfig.py
index 53e55383bf9..2eb8de4b29f 100644
--- a/Lib/test/test_sysconfig.py
+++ b/Lib/test/test_sysconfig.py
@@ -32,7 +32,6 @@ from sysconfig import (get_paths, get_platform, get_config_vars,
from sysconfig.__main__ import _main, _parse_makefile, _get_pybuilddir, _get_json_data_name
import _imp
import _osx_support
-import _sysconfig
HAS_USER_BASE = sysconfig._HAS_USER_BASE
@@ -186,7 +185,7 @@ class TestSysConfig(unittest.TestCase, VirtualEnvironmentMixin):
# The include directory on POSIX isn't exactly the same as before,
# but it is "within"
sysconfig_includedir = sysconfig.get_path('include', scheme='posix_venv', vars=vars)
- self.assertTrue(sysconfig_includedir.startswith(incpath + os.sep))
+ self.assertStartsWith(sysconfig_includedir, incpath + os.sep)
def test_nt_venv_scheme(self):
# The following directories were hardcoded in the venv module
@@ -531,13 +530,10 @@ class TestSysConfig(unittest.TestCase, VirtualEnvironmentMixin):
Python_h = os.path.join(srcdir, 'Include', 'Python.h')
self.assertTrue(os.path.exists(Python_h), Python_h)
# <srcdir>/PC/pyconfig.h.in always exists even if unused
- pyconfig_h = os.path.join(srcdir, 'PC', 'pyconfig.h.in')
- self.assertTrue(os.path.exists(pyconfig_h), pyconfig_h)
pyconfig_h_in = os.path.join(srcdir, 'pyconfig.h.in')
self.assertTrue(os.path.exists(pyconfig_h_in), pyconfig_h_in)
if os.name == 'nt':
- # <executable dir>/pyconfig.h exists on Windows in a build tree
- pyconfig_h = os.path.join(sys.executable, '..', 'pyconfig.h')
+ pyconfig_h = os.path.join(srcdir, 'PC', 'pyconfig.h')
self.assertTrue(os.path.exists(pyconfig_h), pyconfig_h)
elif os.name == 'posix':
makefile_dir = os.path.dirname(sysconfig.get_makefile_filename())
@@ -572,8 +568,7 @@ class TestSysConfig(unittest.TestCase, VirtualEnvironmentMixin):
expected_suffixes = 'i386-linux-gnu.so', 'x86_64-linux-gnux32.so', 'i386-linux-musl.so'
else: # 8 byte pointer size
expected_suffixes = 'x86_64-linux-gnu.so', 'x86_64-linux-musl.so'
- self.assertTrue(suffix.endswith(expected_suffixes),
- f'unexpected suffix {suffix!r}')
+ self.assertEndsWith(suffix, expected_suffixes)
@unittest.skipUnless(sys.platform == 'android', 'Android-specific test')
def test_android_ext_suffix(self):
@@ -585,13 +580,12 @@ class TestSysConfig(unittest.TestCase, VirtualEnvironmentMixin):
"aarch64": "aarch64-linux-android",
"armv7l": "arm-linux-androideabi",
}[machine]
- self.assertTrue(suffix.endswith(f"-{expected_triplet}.so"),
- f"{machine=}, {suffix=}")
+ self.assertEndsWith(suffix, f"-{expected_triplet}.so")
@unittest.skipUnless(sys.platform == 'darwin', 'OS X-specific test')
def test_osx_ext_suffix(self):
suffix = sysconfig.get_config_var('EXT_SUFFIX')
- self.assertTrue(suffix.endswith('-darwin.so'), suffix)
+ self.assertEndsWith(suffix, '-darwin.so')
def test_always_set_py_debug(self):
self.assertIn('Py_DEBUG', sysconfig.get_config_vars())
@@ -717,8 +711,8 @@ class TestSysConfig(unittest.TestCase, VirtualEnvironmentMixin):
ignore_keys |= {'prefix', 'exec_prefix', 'base', 'platbase', 'installed_base', 'installed_platbase'}
for key in ignore_keys:
- json_config_vars.pop(key)
- system_config_vars.pop(key)
+ json_config_vars.pop(key, None)
+ system_config_vars.pop(key, None)
self.assertEqual(system_config_vars, json_config_vars)
diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py
index fcbaf854cc2..7055e1ed147 100644
--- a/Lib/test/test_tarfile.py
+++ b/Lib/test/test_tarfile.py
@@ -38,6 +38,10 @@ try:
import lzma
except ImportError:
lzma = None
+try:
+ from compression import zstd
+except ImportError:
+ zstd = None
def sha256sum(data):
return sha256(data).hexdigest()
@@ -48,6 +52,7 @@ tarname = support.findfile("testtar.tar", subdir="archivetestdata")
gzipname = os.path.join(TEMPDIR, "testtar.tar.gz")
bz2name = os.path.join(TEMPDIR, "testtar.tar.bz2")
xzname = os.path.join(TEMPDIR, "testtar.tar.xz")
+zstname = os.path.join(TEMPDIR, "testtar.tar.zst")
tmpname = os.path.join(TEMPDIR, "tmp.tar")
dotlessname = os.path.join(TEMPDIR, "testtar")
@@ -90,6 +95,12 @@ class LzmaTest:
open = lzma.LZMAFile if lzma else None
taropen = tarfile.TarFile.xzopen
+@support.requires_zstd()
+class ZstdTest:
+ tarname = zstname
+ suffix = 'zst'
+ open = zstd.ZstdFile if zstd else None
+ taropen = tarfile.TarFile.zstopen
class ReadTest(TarTest):
@@ -271,6 +282,8 @@ class Bz2UstarReadTest(Bz2Test, UstarReadTest):
class LzmaUstarReadTest(LzmaTest, UstarReadTest):
pass
+class ZstdUstarReadTest(ZstdTest, UstarReadTest):
+ pass
class ListTest(ReadTest, unittest.TestCase):
@@ -375,6 +388,8 @@ class Bz2ListTest(Bz2Test, ListTest):
class LzmaListTest(LzmaTest, ListTest):
pass
+class ZstdListTest(ZstdTest, ListTest):
+ pass
class CommonReadTest(ReadTest):
@@ -837,6 +852,8 @@ class Bz2MiscReadTest(Bz2Test, MiscReadTestBase, unittest.TestCase):
class LzmaMiscReadTest(LzmaTest, MiscReadTestBase, unittest.TestCase):
pass
+class ZstdMiscReadTest(ZstdTest, MiscReadTestBase, unittest.TestCase):
+ pass
class StreamReadTest(CommonReadTest, unittest.TestCase):
@@ -909,6 +926,9 @@ class Bz2StreamReadTest(Bz2Test, StreamReadTest):
class LzmaStreamReadTest(LzmaTest, StreamReadTest):
pass
+class ZstdStreamReadTest(ZstdTest, StreamReadTest):
+ pass
+
class TarStreamModeReadTest(StreamModeTest, unittest.TestCase):
def test_stream_mode_no_cache(self):
@@ -925,6 +945,9 @@ class Bz2StreamModeReadTest(Bz2Test, TarStreamModeReadTest):
class LzmaStreamModeReadTest(LzmaTest, TarStreamModeReadTest):
pass
+class ZstdStreamModeReadTest(ZstdTest, TarStreamModeReadTest):
+ pass
+
class DetectReadTest(TarTest, unittest.TestCase):
def _testfunc_file(self, name, mode):
try:
@@ -986,6 +1009,8 @@ class Bz2DetectReadTest(Bz2Test, DetectReadTest):
class LzmaDetectReadTest(LzmaTest, DetectReadTest):
pass
+class ZstdDetectReadTest(ZstdTest, DetectReadTest):
+ pass
class GzipBrokenHeaderCorrectException(GzipTest, unittest.TestCase):
"""
@@ -1625,7 +1650,7 @@ class WriteTest(WriteTestBase, unittest.TestCase):
try:
for t in tar:
if t.name != ".":
- self.assertTrue(t.name.startswith("./"), t.name)
+ self.assertStartsWith(t.name, "./")
finally:
tar.close()
@@ -1666,6 +1691,8 @@ class Bz2WriteTest(Bz2Test, WriteTest):
class LzmaWriteTest(LzmaTest, WriteTest):
pass
+class ZstdWriteTest(ZstdTest, WriteTest):
+ pass
class StreamWriteTest(WriteTestBase, unittest.TestCase):
@@ -1727,6 +1754,9 @@ class Bz2StreamWriteTest(Bz2Test, StreamWriteTest):
class LzmaStreamWriteTest(LzmaTest, StreamWriteTest):
decompressor = lzma.LZMADecompressor if lzma else None
+class ZstdStreamWriteTest(ZstdTest, StreamWriteTest):
+ decompressor = zstd.ZstdDecompressor if zstd else None
+
class _CompressedWriteTest(TarTest):
# This is not actually a standalone test.
# It does not inherit WriteTest because it only makes sense with gz,bz2
@@ -2042,6 +2072,14 @@ class LzmaCreateTest(LzmaTest, CreateTest):
tobj.add(self.file_path)
+class ZstdCreateTest(ZstdTest, CreateTest):
+
+ # Unlike gz and bz2, zstd uses the level keyword instead of compresslevel.
+ # It does not allow for level to be specified when reading.
+ def test_create_with_level(self):
+ with tarfile.open(tmpname, self.mode, level=1) as tobj:
+ tobj.add(self.file_path)
+
class CreateWithXModeTest(CreateTest):
prefix = "x"
@@ -2523,6 +2561,8 @@ class Bz2AppendTest(Bz2Test, AppendTestBase, unittest.TestCase):
class LzmaAppendTest(LzmaTest, AppendTestBase, unittest.TestCase):
pass
+class ZstdAppendTest(ZstdTest, AppendTestBase, unittest.TestCase):
+ pass
class LimitsTest(unittest.TestCase):
@@ -2675,6 +2715,31 @@ class MiscTest(unittest.TestCase):
str(excinfo.exception),
)
+ @unittest.skipUnless(os_helper.can_symlink(), 'requires symlink support')
+ @unittest.skipUnless(hasattr(os, 'chmod'), "missing os.chmod")
+ @unittest.mock.patch('os.chmod')
+ def test_deferred_directory_attributes_update(self, mock_chmod):
+ # Regression test for gh-127987: setting attributes on arbitrary files
+ tempdir = os.path.join(TEMPDIR, 'test127987')
+ def mock_chmod_side_effect(path, mode, **kwargs):
+ target_path = os.path.realpath(path)
+ if os.path.commonpath([target_path, tempdir]) != tempdir:
+ raise Exception("should not try to chmod anything outside the destination", target_path)
+ mock_chmod.side_effect = mock_chmod_side_effect
+
+ outside_tree_dir = os.path.join(TEMPDIR, 'outside_tree_dir')
+ with ArchiveMaker() as arc:
+ arc.add('x', symlink_to='.')
+ arc.add('x', type=tarfile.DIRTYPE, mode='?rwsrwsrwt')
+ arc.add('x', symlink_to=outside_tree_dir)
+
+ os.makedirs(outside_tree_dir)
+ try:
+ arc.open().extractall(path=tempdir, filter='tar')
+ finally:
+ os_helper.rmtree(outside_tree_dir)
+ os_helper.rmtree(tempdir)
+
class CommandLineTest(unittest.TestCase):
@@ -2835,7 +2900,7 @@ class CommandLineTest(unittest.TestCase):
support.findfile('tokenize_tests-no-coding-cookie-'
'and-utf8-bom-sig-only.txt',
subdir='tokenizedata')]
- for filetype in (GzipTest, Bz2Test, LzmaTest):
+ for filetype in (GzipTest, Bz2Test, LzmaTest, ZstdTest):
if not filetype.open:
continue
try:
@@ -3235,6 +3300,10 @@ class NoneInfoExtractTests(ReadTest):
got_paths = set(
p.relative_to(directory)
for p in pathlib.Path(directory).glob('**/*'))
+ if self.extraction_filter in (None, 'data'):
+ # The 'data' filter is expected to reject special files
+ for path in 'ustar/fifotype', 'ustar/blktype', 'ustar/chrtype':
+ got_paths.discard(pathlib.Path(path))
self.assertEqual(self.control_paths, got_paths)
@contextmanager
@@ -3450,11 +3519,12 @@ class ArchiveMaker:
with t.open() as tar:
... # `tar` is now a TarFile with 'filename' in it!
"""
- def __init__(self):
+ def __init__(self, **kwargs):
self.bio = io.BytesIO()
+ self.tar_kwargs = dict(kwargs)
def __enter__(self):
- self.tar_w = tarfile.TarFile(mode='w', fileobj=self.bio)
+ self.tar_w = tarfile.TarFile(mode='w', fileobj=self.bio, **self.tar_kwargs)
return self
def __exit__(self, *exc):
@@ -3463,12 +3533,28 @@ class ArchiveMaker:
self.bio = None
def add(self, name, *, type=None, symlink_to=None, hardlink_to=None,
- mode=None, size=None, **kwargs):
- """Add a member to the test archive. Call within `with`."""
+ mode=None, size=None, content=None, **kwargs):
+ """Add a member to the test archive. Call within `with`.
+
+ Provides many shortcuts:
+ - default `type` is based on symlink_to, hardlink_to, and trailing `/`
+ in name (which is stripped)
+ - size & content defaults are based on each other
+ - content can be str or bytes
+ - mode should be textual ('-rwxrwxrwx')
+
+ (add more! this is unstable internal test-only API)
+ """
name = str(name)
tarinfo = tarfile.TarInfo(name).replace(**kwargs)
+ if content is not None:
+ if isinstance(content, str):
+ content = content.encode()
+ size = len(content)
if size is not None:
tarinfo.size = size
+ if content is None:
+ content = bytes(tarinfo.size)
if mode:
tarinfo.mode = _filemode_to_int(mode)
if symlink_to is not None:
@@ -3482,7 +3568,7 @@ class ArchiveMaker:
if type is not None:
tarinfo.type = type
if tarinfo.isreg():
- fileobj = io.BytesIO(bytes(tarinfo.size))
+ fileobj = io.BytesIO(content)
else:
fileobj = None
self.tar_w.addfile(tarinfo, fileobj)
@@ -3516,7 +3602,7 @@ class TestExtractionFilters(unittest.TestCase):
destdir = outerdir / 'dest'
@contextmanager
- def check_context(self, tar, filter):
+ def check_context(self, tar, filter, *, check_flag=True):
"""Extracts `tar` to `self.destdir` and allows checking the result
If an error occurs, it must be checked using `expect_exception`
@@ -3525,27 +3611,40 @@ class TestExtractionFilters(unittest.TestCase):
except the destination directory itself and parent directories of
other files.
When checking directories, do so before their contents.
+
+ A file called 'flag' is made in outerdir (i.e. outside destdir)
+ before extraction; it should not be altered nor should its contents
+ be read/copied.
"""
with os_helper.temp_dir(self.outerdir):
+ flag_path = self.outerdir / 'flag'
+ flag_path.write_text('capture me')
try:
tar.extractall(self.destdir, filter=filter)
except Exception as exc:
self.raised_exception = exc
+ self.reraise_exception = True
self.expected_paths = set()
else:
self.raised_exception = None
+ self.reraise_exception = False
self.expected_paths = set(self.outerdir.glob('**/*'))
self.expected_paths.discard(self.destdir)
+ self.expected_paths.discard(flag_path)
try:
- yield
+ yield self
finally:
tar.close()
- if self.raised_exception:
+ if self.reraise_exception:
raise self.raised_exception
self.assertEqual(self.expected_paths, set())
+ if check_flag:
+ self.assertEqual(flag_path.read_text(), 'capture me')
+ else:
+ assert filter == 'fully_trusted'
def expect_file(self, name, type=None, symlink_to=None, mode=None,
- size=None):
+ size=None, content=None):
"""Check a single file. See check_context."""
if self.raised_exception:
raise self.raised_exception
@@ -3564,26 +3663,45 @@ class TestExtractionFilters(unittest.TestCase):
# The symlink might be the same (textually) as what we expect,
# but some systems change the link to an equivalent path, so
# we fall back to samefile().
- if expected != got:
- self.assertTrue(got.samefile(expected))
+ try:
+ if expected != got:
+ self.assertTrue(got.samefile(expected))
+ except Exception as e:
+ # attach a note, so it's shown even if `samefile` fails
+ e.add_note(f'{expected=}, {got=}')
+ raise
elif type == tarfile.REGTYPE or type is None:
self.assertTrue(path.is_file())
elif type == tarfile.DIRTYPE:
self.assertTrue(path.is_dir())
elif type == tarfile.FIFOTYPE:
self.assertTrue(path.is_fifo())
+ elif type == tarfile.SYMTYPE:
+ self.assertTrue(path.is_symlink())
else:
raise NotImplementedError(type)
if size is not None:
self.assertEqual(path.stat().st_size, size)
+ if content is not None:
+ self.assertEqual(path.read_text(), content)
for parent in path.parents:
self.expected_paths.discard(parent)
+ def expect_any_tree(self, name):
+ """Check a directory; forget about its contents."""
+ tree_path = (self.destdir / name).resolve()
+ self.expect_file(tree_path, type=tarfile.DIRTYPE)
+ self.expected_paths = {
+ p for p in self.expected_paths
+ if tree_path not in p.parents
+ }
+
def expect_exception(self, exc_type, message_re='.'):
with self.assertRaisesRegex(exc_type, message_re):
if self.raised_exception is not None:
raise self.raised_exception
- self.raised_exception = None
+ self.reraise_exception = False
+ return self.raised_exception
def test_benign_file(self):
with ArchiveMaker() as arc:
@@ -3669,6 +3787,80 @@ class TestExtractionFilters(unittest.TestCase):
self.expect_file('parent/evil')
@symlink_test
+ @os_helper.skip_unless_symlink
+ def test_realpath_limit_attack(self):
+ # (CVE-2025-4517)
+
+ with ArchiveMaker() as arc:
+ # populate the symlinks and dirs that expand in os.path.realpath()
+ # The component length is chosen so that in common cases, the unexpanded
+ # path fits in PATH_MAX, but it overflows when the final symlink
+ # is expanded
+ steps = "abcdefghijklmnop"
+ if sys.platform == 'win32':
+ component = 'd' * 25
+ elif 'PC_PATH_MAX' in os.pathconf_names:
+ max_path_len = os.pathconf(self.outerdir.parent, "PC_PATH_MAX")
+ path_sep_len = 1
+ dest_len = len(str(self.destdir)) + path_sep_len
+ component_len = (max_path_len - dest_len) // (len(steps) + path_sep_len)
+ component = 'd' * component_len
+ else:
+ raise NotImplementedError("Need to guess component length for {sys.platform}")
+ path = ""
+ step_path = ""
+ for i in steps:
+ arc.add(os.path.join(path, component), type=tarfile.DIRTYPE,
+ mode='drwxrwxrwx')
+ arc.add(os.path.join(path, i), symlink_to=component)
+ path = os.path.join(path, component)
+ step_path = os.path.join(step_path, i)
+ # create the final symlink that exceeds PATH_MAX and simply points
+ # to the top dir.
+ # this link will never be expanded by
+ # os.path.realpath(strict=False), nor anything after it.
+ linkpath = os.path.join(*steps, "l"*254)
+ parent_segments = [".."] * len(steps)
+ arc.add(linkpath, symlink_to=os.path.join(*parent_segments))
+ # make a symlink outside to keep the tar command happy
+ arc.add("escape", symlink_to=os.path.join(linkpath, ".."))
+ # use the symlinks above, that are not checked, to create a hardlink
+ # to a file outside of the destination path
+ arc.add("flaglink", hardlink_to=os.path.join("escape", "flag"))
+ # now that we have the hardlink we can overwrite the file
+ arc.add("flaglink", content='overwrite')
+ # we can also create new files as well!
+ arc.add("escape/newfile", content='new')
+
+ with (self.subTest('fully_trusted'),
+ self.check_context(arc.open(), filter='fully_trusted',
+ check_flag=False)):
+ if sys.platform == 'win32':
+ self.expect_exception((FileNotFoundError, FileExistsError))
+ elif self.raised_exception:
+ # Cannot symlink/hardlink: tarfile falls back to getmember()
+ self.expect_exception(KeyError)
+ # Otherwise, this branch should never be reached.
+ else:
+ self.expect_any_tree(component)
+ self.expect_file('flaglink', content='overwrite')
+ self.expect_file('../newfile', content='new')
+ self.expect_file('escape', type=tarfile.SYMTYPE)
+ self.expect_file('a', symlink_to=component)
+
+ for filter in 'tar', 'data':
+ with self.subTest(filter), self.check_context(arc.open(), filter=filter):
+ exc = self.expect_exception((OSError, KeyError))
+ if isinstance(exc, OSError):
+ if sys.platform == 'win32':
+ # 3: ERROR_PATH_NOT_FOUND
+ # 5: ERROR_ACCESS_DENIED
+ # 206: ERROR_FILENAME_EXCED_RANGE
+ self.assertIn(exc.winerror, (3, 5, 206))
+ else:
+ self.assertEqual(exc.errno, errno.ENAMETOOLONG)
+
+ @symlink_test
def test_parent_symlink2(self):
# Test interplaying symlinks
# Inspired by 'dirsymlink2b' in jwilk/traversal-archives
@@ -3890,8 +4082,8 @@ class TestExtractionFilters(unittest.TestCase):
arc.add('symlink2', symlink_to=os.path.join(
'linkdir', 'hardlink2'))
arc.add('targetdir/target', size=3)
- arc.add('linkdir/hardlink', hardlink_to='targetdir/target')
- arc.add('linkdir/hardlink2', hardlink_to='linkdir/symlink')
+ arc.add('linkdir/hardlink', hardlink_to=os.path.join('targetdir', 'target'))
+ arc.add('linkdir/hardlink2', hardlink_to=os.path.join('linkdir', 'symlink'))
for filter in 'tar', 'data', 'fully_trusted':
with self.check_context(arc.open(), filter):
@@ -3907,6 +4099,129 @@ class TestExtractionFilters(unittest.TestCase):
self.expect_file('linkdir/symlink', size=3)
self.expect_file('symlink2', size=3)
+ @symlink_test
+ def test_sneaky_hardlink_fallback(self):
+ # (CVE-2025-4330)
+ # Test that when hardlink extraction falls back to extracting members
+ # from the archive, the extracted member is (re-)filtered.
+ with ArchiveMaker() as arc:
+ # Create a directory structure so the c/escape symlink stays
+ # inside the path
+ arc.add("a/t/dummy")
+ # Create b/ directory
+ arc.add("b/")
+ # Point "c" to the bottom of the tree in "a"
+ arc.add("c", symlink_to=os.path.join("a", "t"))
+ # link to non-existent location under "a"
+ arc.add("c/escape", symlink_to=os.path.join("..", "..",
+ "link_here"))
+ # Move "c" to point to "b" ("c/escape" no longer exists)
+ arc.add("c", symlink_to="b")
+ # Attempt to create a hard link to "c/escape". Since it doesn't
+ # exist, tarfile will fall back to extracting "c/escape" at "boom".
+ arc.add("boom", hardlink_to=os.path.join("c", "escape"))
+
+ with self.check_context(arc.open(), 'data'):
+ if not os_helper.can_symlink():
+ # When 'c/escape' is extracted, 'c' is a regular
+ # directory, and 'c/escape' *would* point outside
+ # the destination if symlinks were allowed.
+ self.expect_exception(
+ tarfile.LinkOutsideDestinationError)
+ elif sys.platform == "win32":
+ # On Windows, 'c/escape' points outside the destination
+ self.expect_exception(tarfile.LinkOutsideDestinationError)
+ else:
+ e = self.expect_exception(
+ tarfile.LinkFallbackError,
+ "link 'boom' would be extracted as a copy of "
+ + "'c/escape', which was rejected")
+ self.assertIsInstance(e.__cause__,
+ tarfile.LinkOutsideDestinationError)
+ for filter in 'tar', 'fully_trusted':
+ with self.subTest(filter), self.check_context(arc.open(), filter):
+ if not os_helper.can_symlink():
+ self.expect_file("a/t/dummy")
+ self.expect_file("b/")
+ self.expect_file("c/")
+ else:
+ self.expect_file("a/t/dummy")
+ self.expect_file("b/")
+ self.expect_file("a/t/escape", symlink_to='../../link_here')
+ self.expect_file("boom", symlink_to='../../link_here')
+ self.expect_file("c", symlink_to='b')
+
+ @symlink_test
+ def test_exfiltration_via_symlink(self):
+ # (CVE-2025-4138)
+ # Test changing symlinks that result in a symlink pointing outside
+ # the extraction directory, unless prevented by 'data' filter's
+ # normalization.
+ with ArchiveMaker() as arc:
+ arc.add("escape", symlink_to=os.path.join('link', 'link', '..', '..', 'link-here'))
+ arc.add("link", symlink_to='./')
+
+ for filter in 'tar', 'data', 'fully_trusted':
+ with self.check_context(arc.open(), filter):
+ if os_helper.can_symlink():
+ self.expect_file("link", symlink_to='./')
+ if filter == 'data':
+ self.expect_file("escape", symlink_to='link-here')
+ else:
+ self.expect_file("escape",
+ symlink_to='link/link/../../link-here')
+ else:
+ # Nothing is extracted.
+ pass
+
+ @symlink_test
+ def test_chmod_outside_dir(self):
+ # (CVE-2024-12718)
+ # Test that members used for delayed updates of directory metadata
+ # are (re-)filtered.
+ with ArchiveMaker() as arc:
+ # "pwn" is a veeeery innocent symlink:
+ arc.add("a/pwn", symlink_to='.')
+ # But now "pwn" is also a directory, so it's scheduled to have its
+ # metadata updated later:
+ arc.add("a/pwn/", mode='drwxrwxrwx')
+ # Oops, "pwn" is not so innocent any more:
+ arc.add("a/pwn", symlink_to='x/../')
+ # Newly created symlink points to the dest dir,
+ # so it's OK for the "data" filter.
+ arc.add('a/x', symlink_to=('../'))
+ # But now "pwn" points outside the dest dir
+
+ for filter in 'tar', 'data', 'fully_trusted':
+ with self.check_context(arc.open(), filter) as cc:
+ if not os_helper.can_symlink():
+ self.expect_file("a/pwn/")
+ elif filter == 'data':
+ self.expect_file("a/x", symlink_to='../')
+ self.expect_file("a/pwn", symlink_to='.')
+ else:
+ self.expect_file("a/x", symlink_to='../')
+ self.expect_file("a/pwn", symlink_to='x/../')
+ if sys.platform != "win32":
+ st_mode = cc.outerdir.stat().st_mode
+ self.assertNotEqual(st_mode & 0o777, 0o777)
+
+ def test_link_fallback_normalizes(self):
+ # Make sure hardlink fallbacks work for non-normalized paths for all
+ # filters
+ with ArchiveMaker() as arc:
+ arc.add("dir/")
+ arc.add("dir/../afile")
+ arc.add("link1", hardlink_to='dir/../afile')
+ arc.add("link2", hardlink_to='dir/../dir/../afile')
+
+ for filter in 'tar', 'data', 'fully_trusted':
+ with self.check_context(arc.open(), filter) as cc:
+ self.expect_file("dir/")
+ self.expect_file("afile")
+ self.expect_file("link1")
+ self.expect_file("link2")
+
def test_modes(self):
# Test how file modes are extracted
# (Note that the modes are ignored on platforms without working chmod)
@@ -4031,24 +4346,64 @@ class TestExtractionFilters(unittest.TestCase):
# The 'tar' filter returns TarInfo objects with the same name/type.
# (It can also fail for particularly "evil" input, but we don't have
# that in the test archive.)
- with tarfile.TarFile.open(tarname) as tar:
+ with tarfile.TarFile.open(tarname, encoding="iso8859-1") as tar:
for tarinfo in tar.getmembers():
- filtered = tarfile.tar_filter(tarinfo, '')
+ try:
+ filtered = tarfile.tar_filter(tarinfo, '')
+ except UnicodeEncodeError:
+ continue
self.assertIs(filtered.name, tarinfo.name)
self.assertIs(filtered.type, tarinfo.type)
def test_data_filter(self):
# The 'data' filter either raises, or returns TarInfo with the same
# name/type.
- with tarfile.TarFile.open(tarname) as tar:
+ with tarfile.TarFile.open(tarname, encoding="iso8859-1") as tar:
for tarinfo in tar.getmembers():
try:
filtered = tarfile.data_filter(tarinfo, '')
- except tarfile.FilterError:
+ except (tarfile.FilterError, UnicodeEncodeError):
continue
self.assertIs(filtered.name, tarinfo.name)
self.assertIs(filtered.type, tarinfo.type)
+ @unittest.skipIf(sys.platform == 'win32', 'requires native bytes paths')
+ def test_filter_unencodable(self):
+ # Sanity check using a valid path.
+ tarinfo = tarfile.TarInfo(os_helper.TESTFN)
+ filtered = tarfile.tar_filter(tarinfo, '')
+ self.assertIs(filtered.name, tarinfo.name)
+ filtered = tarfile.data_filter(tarinfo, '')
+ self.assertIs(filtered.name, tarinfo.name)
+
+ tarinfo = tarfile.TarInfo('test\x00')
+ self.assertRaises(ValueError, tarfile.tar_filter, tarinfo, '')
+ self.assertRaises(ValueError, tarfile.data_filter, tarinfo, '')
+ tarinfo = tarfile.TarInfo('\ud800')
+ self.assertRaises(UnicodeEncodeError, tarfile.tar_filter, tarinfo, '')
+ self.assertRaises(UnicodeEncodeError, tarfile.data_filter, tarinfo, '')
+
+ @unittest.skipIf(sys.platform == 'win32', 'requires native bytes paths')
+ def test_extract_unencodable(self):
+ # Create a member with name \xed\xa0\x80 which is UTF-8 encoded
+ # lone surrogate \ud800.
+ with ArchiveMaker(encoding='ascii', errors='surrogateescape') as arc:
+ arc.add('\udced\udca0\udc80')
+ with os_helper.temp_cwd() as tmp:
+ tar = arc.open(encoding='utf-8', errors='surrogatepass',
+ errorlevel=1)
+ self.assertEqual(tar.getnames(), ['\ud800'])
+ with self.assertRaises(UnicodeEncodeError):
+ tar.extractall()
+ self.assertEqual(os.listdir(), [])
+
+ tar = arc.open(encoding='utf-8', errors='surrogatepass',
+ errorlevel=0, debug=1)
+ with support.captured_stderr() as stderr:
+ tar.extractall()
+ self.assertEqual(os.listdir(), [])
+ self.assertIn('tarfile: UnicodeEncodeError ', stderr.getvalue())
+
def test_change_default_filter_on_instance(self):
tar = tarfile.TarFile(tarname, 'r')
def strict_filter(tarinfo, path):
@@ -4161,13 +4516,13 @@ class TestExtractionFilters(unittest.TestCase):
# If errorlevel is 0, errors affected by errorlevel are ignored
with self.check_context(arc.open(errorlevel=0), extracterror_filter):
- self.expect_file('file')
+ pass
with self.check_context(arc.open(errorlevel=0), filtererror_filter):
- self.expect_file('file')
+ pass
with self.check_context(arc.open(errorlevel=0), oserror_filter):
- self.expect_file('file')
+ pass
with self.check_context(arc.open(errorlevel=0), tarerror_filter):
self.expect_exception(tarfile.TarError)
@@ -4178,7 +4533,7 @@ class TestExtractionFilters(unittest.TestCase):
# If 1, all fatal errors are raised
with self.check_context(arc.open(errorlevel=1), extracterror_filter):
- self.expect_file('file')
+ pass
with self.check_context(arc.open(errorlevel=1), filtererror_filter):
self.expect_exception(tarfile.FilterError)
@@ -4257,7 +4612,7 @@ def setUpModule():
data = fobj.read()
# Create compressed tarfiles.
- for c in GzipTest, Bz2Test, LzmaTest:
+ for c in GzipTest, Bz2Test, LzmaTest, ZstdTest:
if c.open:
os_helper.unlink(c.tarname)
testtarnames.append(c.tarname)
diff --git a/Lib/test/test_tempfile.py b/Lib/test/test_tempfile.py
index d46d3c0f040..52b13b98cbc 100644
--- a/Lib/test/test_tempfile.py
+++ b/Lib/test/test_tempfile.py
@@ -516,11 +516,11 @@ class TestMkstempInner(TestBadTempdir, BaseTestCase):
_mock_candidate_names('aaa', 'aaa', 'bbb'):
(fd1, name1) = self.make_temp()
os.close(fd1)
- self.assertTrue(name1.endswith('aaa'))
+ self.assertEndsWith(name1, 'aaa')
(fd2, name2) = self.make_temp()
os.close(fd2)
- self.assertTrue(name2.endswith('bbb'))
+ self.assertEndsWith(name2, 'bbb')
def test_collision_with_existing_directory(self):
# _mkstemp_inner tries another name when a directory with
@@ -528,11 +528,11 @@ class TestMkstempInner(TestBadTempdir, BaseTestCase):
with _inside_empty_temp_dir(), \
_mock_candidate_names('aaa', 'aaa', 'bbb'):
dir = tempfile.mkdtemp()
- self.assertTrue(dir.endswith('aaa'))
+ self.assertEndsWith(dir, 'aaa')
(fd, name) = self.make_temp()
os.close(fd)
- self.assertTrue(name.endswith('bbb'))
+ self.assertEndsWith(name, 'bbb')
class TestGetTempPrefix(BaseTestCase):
@@ -828,9 +828,9 @@ class TestMkdtemp(TestBadTempdir, BaseTestCase):
_mock_candidate_names('aaa', 'aaa', 'bbb'):
file = tempfile.NamedTemporaryFile(delete=False)
file.close()
- self.assertTrue(file.name.endswith('aaa'))
+ self.assertEndsWith(file.name, 'aaa')
dir = tempfile.mkdtemp()
- self.assertTrue(dir.endswith('bbb'))
+ self.assertEndsWith(dir, 'bbb')
def test_collision_with_existing_directory(self):
# mkdtemp tries another name when a directory with
@@ -838,9 +838,9 @@ class TestMkdtemp(TestBadTempdir, BaseTestCase):
with _inside_empty_temp_dir(), \
_mock_candidate_names('aaa', 'aaa', 'bbb'):
dir1 = tempfile.mkdtemp()
- self.assertTrue(dir1.endswith('aaa'))
+ self.assertEndsWith(dir1, 'aaa')
dir2 = tempfile.mkdtemp()
- self.assertTrue(dir2.endswith('bbb'))
+ self.assertEndsWith(dir2, 'bbb')
def test_for_tempdir_is_bytes_issue40701_api_warts(self):
orig_tempdir = tempfile.tempdir
diff --git a/Lib/test/test_termios.py b/Lib/test/test_termios.py
index e5d11cf84d2..ce8392a6ccd 100644
--- a/Lib/test/test_termios.py
+++ b/Lib/test/test_termios.py
@@ -290,8 +290,8 @@ class TestModule(unittest.TestCase):
self.assertGreaterEqual(value, 0)
def test_exception(self):
- self.assertTrue(issubclass(termios.error, Exception))
- self.assertFalse(issubclass(termios.error, OSError))
+ self.assertIsSubclass(termios.error, Exception)
+ self.assertNotIsSubclass(termios.error, OSError)
if __name__ == '__main__':
diff --git a/Lib/test/test_threadedtempfile.py b/Lib/test/test_threadedtempfile.py
index 420fc6ec8be..acb427b0c78 100644
--- a/Lib/test/test_threadedtempfile.py
+++ b/Lib/test/test_threadedtempfile.py
@@ -15,6 +15,7 @@ provoking a 2.0 failure under Linux.
import tempfile
+from test import support
from test.support import threading_helper
import unittest
import io
@@ -49,7 +50,8 @@ class TempFileGreedy(threading.Thread):
class ThreadedTempFileTest(unittest.TestCase):
- def test_main(self):
+ @support.bigmemtest(size=NUM_THREADS, memuse=60*2**20, dry_run=False)
+ def test_main(self, size):
threads = [TempFileGreedy() for i in range(NUM_THREADS)]
with threading_helper.start_threads(threads, startEvent.set):
pass
diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py
index 814c00ca0fd..13b55d0f0a2 100644
--- a/Lib/test/test_threading.py
+++ b/Lib/test/test_threading.py
@@ -5,7 +5,7 @@ Tests for the threading module.
import test.support
from test.support import threading_helper, requires_subprocess, requires_gil_enabled
from test.support import verbose, cpython_only, os_helper
-from test.support.import_helper import import_module
+from test.support.import_helper import ensure_lazy_imports, import_module
from test.support.script_helper import assert_python_ok, assert_python_failure
from test.support import force_not_colorized
@@ -28,7 +28,7 @@ from test import lock_tests
from test import support
try:
- from test.support import interpreters
+ from concurrent import interpreters
except ImportError:
interpreters = None
@@ -121,6 +121,10 @@ class ThreadTests(BaseTestCase):
maxDiff = 9999
@cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("threading", {"functools", "warnings"})
+
+ @cpython_only
def test_name(self):
def func(): pass
@@ -526,7 +530,8 @@ class ThreadTests(BaseTestCase):
finally:
sys.setswitchinterval(old_interval)
- def test_join_from_multiple_threads(self):
+ @support.bigmemtest(size=20, memuse=72*2**20, dry_run=False)
+ def test_join_from_multiple_threads(self, size):
# Thread.join() should be thread-safe
errors = []
@@ -1248,7 +1253,7 @@ class ThreadTests(BaseTestCase):
# its state should be removed from interpreter' thread states list
# to avoid its double cleanup
try:
- from resource import setrlimit, RLIMIT_NPROC
+ from resource import setrlimit, RLIMIT_NPROC # noqa: F401
except ImportError as err:
self.skipTest(err) # RLIMIT_NPROC is specific to Linux and BSD
code = """if 1:
@@ -1279,12 +1284,6 @@ class ThreadTests(BaseTestCase):
@cpython_only
def test_finalize_daemon_thread_hang(self):
- if support.check_sanitizer(thread=True, memory=True):
- # the thread running `time.sleep(100)` below will still be alive
- # at process exit
- self.skipTest(
- "https://github.com/python/cpython/issues/124878 - Known"
- " race condition that TSAN identifies.")
# gh-87135: tests that daemon threads hang during finalization
script = textwrap.dedent('''
import os
@@ -1347,6 +1346,35 @@ class ThreadTests(BaseTestCase):
''')
assert_python_ok("-c", script)
+ @skip_unless_reliable_fork
+ @unittest.skipUnless(hasattr(threading, 'get_native_id'), "test needs threading.get_native_id()")
+ def test_native_id_after_fork(self):
+ script = """if True:
+ import threading
+ import os
+ from test import support
+
+ parent_thread_native_id = threading.current_thread().native_id
+ print(parent_thread_native_id, flush=True)
+ assert parent_thread_native_id == threading.get_native_id()
+ childpid = os.fork()
+ if childpid == 0:
+ print(threading.current_thread().native_id, flush=True)
+ assert threading.current_thread().native_id == threading.get_native_id()
+ else:
+ try:
+ assert parent_thread_native_id == threading.current_thread().native_id
+ assert parent_thread_native_id == threading.get_native_id()
+ finally:
+ support.wait_process(childpid, exitcode=0)
+ """
+ rc, out, err = assert_python_ok('-c', script)
+ self.assertEqual(rc, 0)
+ self.assertEqual(err, b"")
+ native_ids = out.strip().splitlines()
+ self.assertEqual(len(native_ids), 2)
+ self.assertNotEqual(native_ids[0], native_ids[1])
+
class ThreadJoinOnShutdown(BaseTestCase):
def _run_and_join(self, script):
@@ -1427,7 +1455,8 @@ class ThreadJoinOnShutdown(BaseTestCase):
self._run_and_join(script)
@unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug")
- def test_4_daemon_threads(self):
+ @support.bigmemtest(size=40, memuse=70*2**20, dry_run=False)
+ def test_4_daemon_threads(self, size):
# Check that a daemon thread cannot crash the interpreter on shutdown
# by manipulating internal structures that are being disposed of in
# the main thread.
@@ -2131,8 +2160,7 @@ class CRLockTests(lock_tests.RLockTests):
]
for args, kwargs in arg_types:
with self.subTest(args=args, kwargs=kwargs):
- with self.assertWarns(DeprecationWarning):
- threading.RLock(*args, **kwargs)
+ self.assertRaises(TypeError, threading.RLock, *args, **kwargs)
# Subtypes with custom `__init__` are allowed (but, not recommended):
class CustomRLock(self.locktype):
@@ -2150,6 +2178,9 @@ class ConditionAsRLockTests(lock_tests.RLockTests):
# Condition uses an RLock by default and exports its API.
locktype = staticmethod(threading.Condition)
+ def test_constructor_noargs(self):
+ self.skipTest("Condition allows positional arguments")
+
def test_recursion_count(self):
self.skipTest("Condition does not expose _recursion_count()")
diff --git a/Lib/test/test_time.py b/Lib/test/test_time.py
index d06f65270ef..5312faa5077 100644
--- a/Lib/test/test_time.py
+++ b/Lib/test/test_time.py
@@ -761,17 +761,17 @@ class TestPytime(unittest.TestCase):
# Get the localtime and examine it for the offset and zone.
lt = time.localtime()
- self.assertTrue(hasattr(lt, "tm_gmtoff"))
- self.assertTrue(hasattr(lt, "tm_zone"))
+ self.assertHasAttr(lt, "tm_gmtoff")
+ self.assertHasAttr(lt, "tm_zone")
# See if the offset and zone are similar to the module
# attributes.
if lt.tm_gmtoff is None:
- self.assertTrue(not hasattr(time, "timezone"))
+ self.assertNotHasAttr(time, "timezone")
else:
self.assertEqual(lt.tm_gmtoff, -[time.timezone, time.altzone][lt.tm_isdst])
if lt.tm_zone is None:
- self.assertTrue(not hasattr(time, "tzname"))
+ self.assertNotHasAttr(time, "tzname")
else:
self.assertEqual(lt.tm_zone, time.tzname[lt.tm_isdst])
@@ -1184,11 +1184,11 @@ class TestTimeWeaklinking(unittest.TestCase):
if mac_ver >= (10, 12):
for name in clock_names:
- self.assertTrue(hasattr(time, name), f"time.{name} is not available")
+ self.assertHasAttr(time, name)
else:
for name in clock_names:
- self.assertFalse(hasattr(time, name), f"time.{name} is available")
+ self.assertNotHasAttr(time, name)
if __name__ == "__main__":
diff --git a/Lib/test/test_timeit.py b/Lib/test/test_timeit.py
index f5ae0a84eb3..2aeebea9f93 100644
--- a/Lib/test/test_timeit.py
+++ b/Lib/test/test_timeit.py
@@ -222,8 +222,8 @@ class TestTimeit(unittest.TestCase):
def assert_exc_string(self, exc_string, expected_exc_name):
exc_lines = exc_string.splitlines()
self.assertGreater(len(exc_lines), 2)
- self.assertTrue(exc_lines[0].startswith('Traceback'))
- self.assertTrue(exc_lines[-1].startswith(expected_exc_name))
+ self.assertStartsWith(exc_lines[0], 'Traceback')
+ self.assertStartsWith(exc_lines[-1], expected_exc_name)
def test_print_exc(self):
s = io.StringIO()
diff --git a/Lib/test/test_tkinter/support.py b/Lib/test/test_tkinter/support.py
index ebb9e00ff91..46b01e6f131 100644
--- a/Lib/test/test_tkinter/support.py
+++ b/Lib/test/test_tkinter/support.py
@@ -58,7 +58,7 @@ class AbstractDefaultRootTest:
destroy_default_root()
tkinter.NoDefaultRoot()
self.assertRaises(RuntimeError, constructor)
- self.assertFalse(hasattr(tkinter, '_default_root'))
+ self.assertNotHasAttr(tkinter, '_default_root')
def destroy_default_root():
diff --git a/Lib/test/test_tkinter/test_misc.py b/Lib/test/test_tkinter/test_misc.py
index 96ea3f0117c..0c76e07066f 100644
--- a/Lib/test/test_tkinter/test_misc.py
+++ b/Lib/test/test_tkinter/test_misc.py
@@ -497,7 +497,7 @@ class MiscTest(AbstractTkTest, unittest.TestCase):
self.assertEqual(vi.serial, 0)
else:
self.assertEqual(vi.micro, 0)
- self.assertTrue(str(vi).startswith(f'{vi.major}.{vi.minor}'))
+ self.assertStartsWith(str(vi), f'{vi.major}.{vi.minor}')
def test_embedded_null(self):
widget = tkinter.Entry(self.root)
@@ -609,7 +609,7 @@ class EventTest(AbstractTkTest, unittest.TestCase):
self.assertIsInstance(e.serial, int)
self.assertEqual(e.time, '??')
self.assertIs(e.send_event, False)
- self.assertFalse(hasattr(e, 'focus'))
+ self.assertNotHasAttr(e, 'focus')
self.assertEqual(e.num, '??')
self.assertEqual(e.state, '??')
self.assertEqual(e.char, '??')
@@ -642,7 +642,7 @@ class EventTest(AbstractTkTest, unittest.TestCase):
self.assertIsInstance(e.serial, int)
self.assertEqual(e.time, '??')
self.assertIs(e.send_event, False)
- self.assertFalse(hasattr(e, 'focus'))
+ self.assertNotHasAttr(e, 'focus')
self.assertEqual(e.num, '??')
self.assertEqual(e.state, '??')
self.assertEqual(e.char, '??')
@@ -676,7 +676,7 @@ class EventTest(AbstractTkTest, unittest.TestCase):
self.assertIsInstance(e.serial, int)
self.assertEqual(e.time, 0)
self.assertIs(e.send_event, False)
- self.assertFalse(hasattr(e, 'focus'))
+ self.assertNotHasAttr(e, 'focus')
self.assertEqual(e.num, '??')
self.assertIsInstance(e.state, int)
self.assertNotEqual(e.state, 0)
@@ -747,7 +747,7 @@ class EventTest(AbstractTkTest, unittest.TestCase):
self.assertIsInstance(e.serial, int)
self.assertEqual(e.time, 0)
self.assertIs(e.send_event, False)
- self.assertFalse(hasattr(e, 'focus'))
+ self.assertNotHasAttr(e, 'focus')
self.assertEqual(e.num, 1)
self.assertEqual(e.state, 0)
self.assertEqual(e.char, '??')
@@ -781,7 +781,7 @@ class EventTest(AbstractTkTest, unittest.TestCase):
self.assertIsInstance(e.serial, int)
self.assertEqual(e.time, 0)
self.assertIs(e.send_event, False)
- self.assertFalse(hasattr(e, 'focus'))
+ self.assertNotHasAttr(e, 'focus')
self.assertEqual(e.num, '??')
self.assertEqual(e.state, 0x100)
self.assertEqual(e.char, '??')
@@ -814,7 +814,7 @@ class EventTest(AbstractTkTest, unittest.TestCase):
self.assertIs(e.widget, f)
self.assertIsInstance(e.serial, int)
self.assertIs(e.send_event, False)
- self.assertFalse(hasattr(e, 'focus'))
+ self.assertNotHasAttr(e, 'focus')
self.assertEqual(e.time, 0)
self.assertEqual(e.num, '??')
self.assertEqual(e.state, 0)
@@ -849,7 +849,7 @@ class EventTest(AbstractTkTest, unittest.TestCase):
self.assertIsInstance(e.serial, int)
self.assertEqual(e.time, 0)
self.assertIs(e.send_event, False)
- self.assertFalse(hasattr(e, 'focus'))
+ self.assertNotHasAttr(e, 'focus')
self.assertEqual(e.num, '??')
self.assertEqual(e.state, 0)
self.assertEqual(e.char, '??')
@@ -1308,17 +1308,17 @@ class DefaultRootTest(AbstractDefaultRootTest, unittest.TestCase):
self.assertIs(tkinter._default_root, root)
tkinter.NoDefaultRoot()
self.assertIs(tkinter._support_default_root, False)
- self.assertFalse(hasattr(tkinter, '_default_root'))
+ self.assertNotHasAttr(tkinter, '_default_root')
# repeated call is no-op
tkinter.NoDefaultRoot()
self.assertIs(tkinter._support_default_root, False)
- self.assertFalse(hasattr(tkinter, '_default_root'))
+ self.assertNotHasAttr(tkinter, '_default_root')
root.destroy()
self.assertIs(tkinter._support_default_root, False)
- self.assertFalse(hasattr(tkinter, '_default_root'))
+ self.assertNotHasAttr(tkinter, '_default_root')
root = tkinter.Tk()
self.assertIs(tkinter._support_default_root, False)
- self.assertFalse(hasattr(tkinter, '_default_root'))
+ self.assertNotHasAttr(tkinter, '_default_root')
root.destroy()
def test_getboolean(self):
diff --git a/Lib/test/test_tkinter/widget_tests.py b/Lib/test/test_tkinter/widget_tests.py
index ac7fb5977e0..f518925e994 100644
--- a/Lib/test/test_tkinter/widget_tests.py
+++ b/Lib/test/test_tkinter/widget_tests.py
@@ -65,7 +65,7 @@ class AbstractWidgetTest(AbstractTkTest):
orig = widget[name]
if errmsg is not None:
errmsg = errmsg.format(re.escape(str(value)))
- errmsg = fr'\A{errmsg}\Z'
+ errmsg = fr'\A{errmsg}\z'
with self.assertRaisesRegex(tkinter.TclError, errmsg or ''):
widget[name] = value
self.assertEqual(widget[name], orig)
diff --git a/Lib/test/test_tokenize.py b/Lib/test/test_tokenize.py
index 2d41a5e5ac0..865e0c5b40d 100644
--- a/Lib/test/test_tokenize.py
+++ b/Lib/test/test_tokenize.py
@@ -1,6 +1,8 @@
import contextlib
+import itertools
import os
import re
+import string
import tempfile
import token
import tokenize
@@ -1975,6 +1977,10 @@ if 1:
for case in cases:
self.check_roundtrip(case)
+ self.check_roundtrip(r"t'{ {}}'")
+ self.check_roundtrip(r"t'{f'{ {}}'}{ {}}'")
+ self.check_roundtrip(r"f'{t'{ {}}'}{ {}}'")
+
def test_continuation(self):
# Balancing continuation
@@ -3234,5 +3240,77 @@ class CommandLineTest(unittest.TestCase):
self.check_output(source, expect, flag)
+class StringPrefixTest(unittest.TestCase):
+ @staticmethod
+ def determine_valid_prefixes():
+ # Try all lengths until we find a length that has zero valid
+ # prefixes. This will miss the case where for example there
+ # are no valid 3 character prefixes, but there are valid 4
+ # character prefixes. That seems unlikely.
+
+ single_char_valid_prefixes = set()
+
+ # Find all of the single character string prefixes. Just get
+ # the lowercase version, we'll deal with combinations of upper
+ # and lower case later. I'm using this logic just in case
+ # some uppercase-only prefix is added.
+ for letter in itertools.chain(string.ascii_lowercase, string.ascii_uppercase):
+ try:
+ eval(f'{letter}""')
+ single_char_valid_prefixes.add(letter.lower())
+ except SyntaxError:
+ pass
+
+ # This logic assumes that all combinations of valid prefixes only use
+ # the characters that are valid single character prefixes. That seems
+ # like a valid assumption, but if it ever changes this will need
+ # adjusting.
+ valid_prefixes = set()
+ for length in itertools.count():
+ num_at_this_length = 0
+ for prefix in (
+ "".join(l)
+ for l in itertools.combinations(single_char_valid_prefixes, length)
+ ):
+ for t in itertools.permutations(prefix):
+ for u in itertools.product(*[(c, c.upper()) for c in t]):
+ p = "".join(u)
+ if p == "not":
+ # 'not' can never be a string prefix,
+ # because it's a valid expression: not ""
+ continue
+ try:
+ eval(f'{p}""')
+
+ # No syntax error, so p is a valid string
+ # prefix.
+
+ valid_prefixes.add(p)
+ num_at_this_length += 1
+ except SyntaxError:
+ pass
+ if num_at_this_length == 0:
+ return valid_prefixes
+
+
+ def test_prefixes(self):
+ # Get the list of defined string prefixes. I don't see an
+ # obvious documented way of doing this, but probably the best
+ # thing is to split apart tokenize.StringPrefix.
+
+ # Make sure StringPrefix begins and ends in parens. We're
+ # assuming it's of the form "(a|b|ab)", if a, b, and ab are
+ # valid string prefixes.
+ self.assertEqual(tokenize.StringPrefix[0], '(')
+ self.assertEqual(tokenize.StringPrefix[-1], ')')
+
+ # Then split apart everything else by '|'.
+ defined_prefixes = set(tokenize.StringPrefix[1:-1].split('|'))
+
+ # Now compute the actual allowed string prefixes and compare
+ # to what is defined in the tokenize module.
+ self.assertEqual(defined_prefixes, self.determine_valid_prefixes())
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_tools/i18n_data/docstrings.py b/Lib/test/test_tools/i18n_data/docstrings.py
index 151a55a4b56..14559a632da 100644
--- a/Lib/test/test_tools/i18n_data/docstrings.py
+++ b/Lib/test/test_tools/i18n_data/docstrings.py
@@ -1,7 +1,7 @@
"""Module docstring"""
# Test docstring extraction
-from gettext import gettext as _
+from gettext import gettext as _ # noqa: F401
# Empty docstring
diff --git a/Lib/test/test_traceback.py b/Lib/test/test_traceback.py
index 683486e9aca..74b979d0096 100644
--- a/Lib/test/test_traceback.py
+++ b/Lib/test/test_traceback.py
@@ -37,6 +37,12 @@ test_code.co_positions = lambda _: iter([(6, 6, 0, 0)])
test_frame = namedtuple('frame', ['f_code', 'f_globals', 'f_locals'])
test_tb = namedtuple('tb', ['tb_frame', 'tb_lineno', 'tb_next', 'tb_lasti'])
+color_overrides = {"reset": "z", "filename": "fn", "error_highlight": "E"}
+colors = {
+ color_overrides.get(k, k[0].lower()): v
+ for k, v in _colorize.default_theme.traceback.items()
+}
+
LEVENSHTEIN_DATA_FILE = Path(__file__).parent / 'levenshtein_examples.json'
@@ -4182,6 +4188,15 @@ class SuggestionFormattingTestBase:
self.assertNotIn("blech", actual)
self.assertNotIn("oh no!", actual)
+ def test_attribute_error_with_non_string_candidates(self):
+ class T:
+ bluch = 1
+
+ instance = T()
+ instance.__dict__[0] = 1
+ actual = self.get_suggestion(instance, 'blich')
+ self.assertIn("bluch", actual)
+
def test_attribute_error_with_bad_name(self):
def raise_attribute_error_with_bad_name():
raise AttributeError(name=12, obj=23)
@@ -4217,8 +4232,8 @@ class SuggestionFormattingTestBase:
return mod_name
- def get_import_from_suggestion(self, mod_dict, name):
- modname = self.make_module(mod_dict)
+ def get_import_from_suggestion(self, code, name):
+ modname = self.make_module(code)
def callable():
try:
@@ -4295,6 +4310,13 @@ class SuggestionFormattingTestBase:
self.assertIn("'_bluch'", self.get_import_from_suggestion(code, '_luch'))
self.assertNotIn("'_bluch'", self.get_import_from_suggestion(code, 'bluch'))
+ def test_import_from_suggestions_non_string(self):
+ modWithNonStringAttr = textwrap.dedent("""\
+ globals()[0] = 1
+ bluch = 1
+ """)
+ self.assertIn("'bluch'", self.get_import_from_suggestion(modWithNonStringAttr, 'blech'))
+
def test_import_from_suggestions_do_not_trigger_for_long_attributes(self):
code = "blech = None"
@@ -4391,6 +4413,15 @@ class SuggestionFormattingTestBase:
actual = self.get_suggestion(func)
self.assertIn("'ZeroDivisionError'?", actual)
+ def test_name_error_suggestions_with_non_string_candidates(self):
+ def func():
+ abc = 1
+ custom_globals = globals().copy()
+ custom_globals[0] = 1
+ print(eval("abv", custom_globals, locals()))
+ actual = self.get_suggestion(func)
+ self.assertIn("abc", actual)
+
def test_name_error_suggestions_do_not_trigger_for_long_names(self):
def func():
somethingverywronghehehehehehe = None
@@ -4721,6 +4752,8 @@ class MiscTest(unittest.TestCase):
class TestColorizedTraceback(unittest.TestCase):
+ maxDiff = None
+
def test_colorized_traceback(self):
def foo(*args):
x = {'a':{'b': None}}
@@ -4743,9 +4776,9 @@ class TestColorizedTraceback(unittest.TestCase):
e, capture_locals=True
)
lines = "".join(exc.format(colorize=True))
- red = _colorize.ANSIColors.RED
- boldr = _colorize.ANSIColors.BOLD_RED
- reset = _colorize.ANSIColors.RESET
+ red = colors["e"]
+ boldr = colors["E"]
+ reset = colors["z"]
self.assertIn("y = " + red + "x['a']['b']" + reset + boldr + "['c']" + reset, lines)
self.assertIn("return " + red + "(lambda *args: foo(*args))" + reset + boldr + "(1,2,3,4)" + reset, lines)
self.assertIn("return (lambda *args: " + red + "foo" + reset + boldr + "(*args)" + reset + ")(1,2,3,4)", lines)
@@ -4761,18 +4794,16 @@ class TestColorizedTraceback(unittest.TestCase):
e, capture_locals=True
)
actual = "".join(exc.format(colorize=True))
- red = _colorize.ANSIColors.RED
- magenta = _colorize.ANSIColors.MAGENTA
- boldm = _colorize.ANSIColors.BOLD_MAGENTA
- boldr = _colorize.ANSIColors.BOLD_RED
- reset = _colorize.ANSIColors.RESET
- expected = "".join([
- f' File {magenta}"<string>"{reset}, line {magenta}1{reset}\n',
- f' a {boldr}${reset} b\n',
- f' {boldr}^{reset}\n',
- f'{boldm}SyntaxError{reset}: {magenta}invalid syntax{reset}\n']
- )
- self.assertIn(expected, actual)
+ def expected(t, m, fn, l, f, E, e, z):
+ return "".join(
+ [
+ f' File {fn}"<string>"{z}, line {l}1{z}\n',
+ f' a {E}${z} b\n',
+ f' {E}^{z}\n',
+ f'{t}SyntaxError{z}: {m}invalid syntax{z}\n'
+ ]
+ )
+ self.assertIn(expected(**colors), actual)
def test_colorized_traceback_is_the_default(self):
def foo():
@@ -4788,23 +4819,21 @@ class TestColorizedTraceback(unittest.TestCase):
exception_print(e)
actual = tbstderr.getvalue().splitlines()
- red = _colorize.ANSIColors.RED
- boldr = _colorize.ANSIColors.BOLD_RED
- magenta = _colorize.ANSIColors.MAGENTA
- boldm = _colorize.ANSIColors.BOLD_MAGENTA
- reset = _colorize.ANSIColors.RESET
lno_foo = foo.__code__.co_firstlineno
- expected = ['Traceback (most recent call last):',
- f' File {magenta}"{__file__}"{reset}, '
- f'line {magenta}{lno_foo+5}{reset}, in {magenta}test_colorized_traceback_is_the_default{reset}',
- f' {red}foo{reset+boldr}(){reset}',
- f' {red}~~~{reset+boldr}^^{reset}',
- f' File {magenta}"{__file__}"{reset}, '
- f'line {magenta}{lno_foo+1}{reset}, in {magenta}foo{reset}',
- f' {red}1{reset+boldr}/{reset+red}0{reset}',
- f' {red}~{reset+boldr}^{reset+red}~{reset}',
- f'{boldm}ZeroDivisionError{reset}: {magenta}division by zero{reset}']
- self.assertEqual(actual, expected)
+ def expected(t, m, fn, l, f, E, e, z):
+ return [
+ 'Traceback (most recent call last):',
+ f' File {fn}"{__file__}"{z}, '
+ f'line {l}{lno_foo+5}{z}, in {f}test_colorized_traceback_is_the_default{z}',
+ f' {e}foo{z}{E}(){z}',
+ f' {e}~~~{z}{E}^^{z}',
+ f' File {fn}"{__file__}"{z}, '
+ f'line {l}{lno_foo+1}{z}, in {f}foo{z}',
+ f' {e}1{z}{E}/{z}{e}0{z}',
+ f' {e}~{z}{E}^{z}{e}~{z}',
+ f'{t}ZeroDivisionError{z}: {m}division by zero{z}',
+ ]
+ self.assertEqual(actual, expected(**colors))
def test_colorized_traceback_from_exception_group(self):
def foo():
@@ -4822,33 +4851,31 @@ class TestColorizedTraceback(unittest.TestCase):
e, capture_locals=True
)
- red = _colorize.ANSIColors.RED
- boldr = _colorize.ANSIColors.BOLD_RED
- magenta = _colorize.ANSIColors.MAGENTA
- boldm = _colorize.ANSIColors.BOLD_MAGENTA
- reset = _colorize.ANSIColors.RESET
lno_foo = foo.__code__.co_firstlineno
actual = "".join(exc.format(colorize=True)).splitlines()
- expected = [f" + Exception Group Traceback (most recent call last):",
- f' | File {magenta}"{__file__}"{reset}, line {magenta}{lno_foo+9}{reset}, in {magenta}test_colorized_traceback_from_exception_group{reset}',
- f' | {red}foo{reset}{boldr}(){reset}',
- f' | {red}~~~{reset}{boldr}^^{reset}',
- f" | e = ExceptionGroup('test', [ZeroDivisionError('division by zero')])",
- f" | foo = {foo}",
- f' | self = <{__name__}.TestColorizedTraceback testMethod=test_colorized_traceback_from_exception_group>',
- f' | File {magenta}"{__file__}"{reset}, line {magenta}{lno_foo+6}{reset}, in {magenta}foo{reset}',
- f' | raise ExceptionGroup("test", exceptions)',
- f" | exceptions = [ZeroDivisionError('division by zero')]",
- f' | {boldm}ExceptionGroup{reset}: {magenta}test (1 sub-exception){reset}',
- f' +-+---------------- 1 ----------------',
- f' | Traceback (most recent call last):',
- f' | File {magenta}"{__file__}"{reset}, line {magenta}{lno_foo+3}{reset}, in {magenta}foo{reset}',
- f' | {red}1 {reset}{boldr}/{reset}{red} 0{reset}',
- f' | {red}~~{reset}{boldr}^{reset}{red}~~{reset}',
- f" | exceptions = [ZeroDivisionError('division by zero')]",
- f' | {boldm}ZeroDivisionError{reset}: {magenta}division by zero{reset}',
- f' +------------------------------------']
- self.assertEqual(actual, expected)
+ def expected(t, m, fn, l, f, E, e, z):
+ return [
+ f" + Exception Group Traceback (most recent call last):",
+ f' | File {fn}"{__file__}"{z}, line {l}{lno_foo+9}{z}, in {f}test_colorized_traceback_from_exception_group{z}',
+ f' | {e}foo{z}{E}(){z}',
+ f' | {e}~~~{z}{E}^^{z}',
+ f" | e = ExceptionGroup('test', [ZeroDivisionError('division by zero')])",
+ f" | foo = {foo}",
+ f' | self = <{__name__}.TestColorizedTraceback testMethod=test_colorized_traceback_from_exception_group>',
+ f' | File {fn}"{__file__}"{z}, line {l}{lno_foo+6}{z}, in {f}foo{z}',
+ f' | raise ExceptionGroup("test", exceptions)',
+ f" | exceptions = [ZeroDivisionError('division by zero')]",
+ f' | {t}ExceptionGroup{z}: {m}test (1 sub-exception){z}',
+ f' +-+---------------- 1 ----------------',
+ f' | Traceback (most recent call last):',
+ f' | File {fn}"{__file__}"{z}, line {l}{lno_foo+3}{z}, in {f}foo{z}',
+ f' | {e}1 {z}{E}/{z}{e} 0{z}',
+ f' | {e}~~{z}{E}^{z}{e}~~{z}',
+ f" | exceptions = [ZeroDivisionError('division by zero')]",
+ f' | {t}ZeroDivisionError{z}: {m}division by zero{z}',
+ f' +------------------------------------',
+ ]
+ self.assertEqual(actual, expected(**colors))
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_tstring.py b/Lib/test/test_tstring.py
index e72a1ea5417..aabae385567 100644
--- a/Lib/test/test_tstring.py
+++ b/Lib/test/test_tstring.py
@@ -219,6 +219,7 @@ class TestTString(unittest.TestCase, TStringBaseCase):
("t'{lambda:1}'", "t-string: lambda expressions are not allowed "
"without parentheses"),
("t'{x:{;}}'", "t-string: expecting a valid expression after '{'"),
+ ("t'{1:d\n}'", "t-string: newlines are not allowed in format specifiers")
):
with self.subTest(case), self.assertRaisesRegex(SyntaxError, err):
eval(case)
diff --git a/Lib/test/test_ttk/test_widgets.py b/Lib/test/test_ttk/test_widgets.py
index d5620becfa7..f33da2a8848 100644
--- a/Lib/test/test_ttk/test_widgets.py
+++ b/Lib/test/test_ttk/test_widgets.py
@@ -490,7 +490,7 @@ class ComboboxTest(EntryTest, unittest.TestCase):
width = self.combo.winfo_width()
x, y = width - 5, 5
if sys.platform != 'darwin': # there's no down arrow on macOS
- self.assertRegex(self.combo.identify(x, y), r'.*downarrow\Z')
+ self.assertRegex(self.combo.identify(x, y), r'.*downarrow\z')
self.combo.event_generate('<Button-1>', x=x, y=y)
self.combo.event_generate('<ButtonRelease-1>', x=x, y=y)
@@ -1250,7 +1250,7 @@ class SpinboxTest(EntryTest, unittest.TestCase):
height = self.spin.winfo_height()
x = width - 5
y = height//2 - 5
- self.assertRegex(self.spin.identify(x, y), r'.*uparrow\Z')
+ self.assertRegex(self.spin.identify(x, y), r'.*uparrow\z')
self.spin.event_generate('<ButtonPress-1>', x=x, y=y)
self.spin.event_generate('<ButtonRelease-1>', x=x, y=y)
self.spin.update_idletasks()
@@ -1260,7 +1260,7 @@ class SpinboxTest(EntryTest, unittest.TestCase):
height = self.spin.winfo_height()
x = width - 5
y = height//2 + 4
- self.assertRegex(self.spin.identify(x, y), r'.*downarrow\Z')
+ self.assertRegex(self.spin.identify(x, y), r'.*downarrow\z')
self.spin.event_generate('<ButtonPress-1>', x=x, y=y)
self.spin.event_generate('<ButtonRelease-1>', x=x, y=y)
self.spin.update_idletasks()
diff --git a/Lib/test/test_type_annotations.py b/Lib/test/test_type_annotations.py
index b72d3dbe516..c66cb058552 100644
--- a/Lib/test/test_type_annotations.py
+++ b/Lib/test/test_type_annotations.py
@@ -327,6 +327,25 @@ class AnnotateTests(unittest.TestCase):
f.__annotations__ = {"z": 43}
self.assertIs(f.__annotate__, None)
+ def test_user_defined_annotate(self):
+ class X:
+ a: int
+
+ def __annotate__(format):
+ return {"a": str}
+ self.assertEqual(X.__annotate__(annotationlib.Format.VALUE), {"a": str})
+ self.assertEqual(annotationlib.get_annotations(X), {"a": str})
+
+ mod = build_module(
+ """
+ a: int
+ def __annotate__(format):
+ return {"a": str}
+ """
+ )
+ self.assertEqual(mod.__annotate__(annotationlib.Format.VALUE), {"a": str})
+ self.assertEqual(annotationlib.get_annotations(mod), {"a": str})
+
class DeferredEvaluationTests(unittest.TestCase):
def test_function(self):
@@ -479,6 +498,28 @@ class DeferredEvaluationTests(unittest.TestCase):
self.assertEqual(f.__annotate__(annotationlib.Format.VALUE), annos)
self.assertEqual(f.__annotations__, annos)
+ def test_set_annotations(self):
+ function_code = textwrap.dedent("""
+ def f(x: int):
+ pass
+ """)
+ class_code = textwrap.dedent("""
+ class f:
+ x: int
+ """)
+ for future in (False, True):
+ for label, code in (("function", function_code), ("class", class_code)):
+ with self.subTest(future=future, label=label):
+ if future:
+ code = "from __future__ import annotations\n" + code
+ ns = run_code(code)
+ f = ns["f"]
+ anno = "int" if future else int
+ self.assertEqual(f.__annotations__, {"x": anno})
+
+ f.__annotations__ = {"x": str}
+ self.assertEqual(f.__annotations__, {"x": str})
+
def test_name_clash_with_format(self):
# this test would fail if __annotate__'s parameter was called "format"
# during symbol table construction
diff --git a/Lib/test/test_type_comments.py b/Lib/test/test_type_comments.py
index ee8939f62d0..c40c45594f4 100644
--- a/Lib/test/test_type_comments.py
+++ b/Lib/test/test_type_comments.py
@@ -344,7 +344,7 @@ class TypeCommentTests(unittest.TestCase):
todo = set(t.name[1:])
self.assertEqual(len(t.args.args) + len(t.args.posonlyargs),
len(todo) - bool(t.args.vararg) - bool(t.args.kwarg))
- self.assertTrue(t.name.startswith('f'), t.name)
+ self.assertStartsWith(t.name, 'f')
for index, c in enumerate(t.name[1:]):
todo.remove(c)
if c == 'v':
diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py
index 3552b6b4ef8..fc26e71ffcb 100644
--- a/Lib/test/test_types.py
+++ b/Lib/test/test_types.py
@@ -2,7 +2,7 @@
from test.support import (
run_with_locale, cpython_only, no_rerun,
- MISSING_C_DOCSTRINGS, EqualToForwardRef,
+ MISSING_C_DOCSTRINGS, EqualToForwardRef, check_disallow_instantiation,
)
from test.support.script_helper import assert_python_ok
from test.support.import_helper import import_fresh_module
@@ -517,8 +517,8 @@ class TypesTests(unittest.TestCase):
# and a number after the decimal. This is tricky, because
# a totally empty format specifier means something else.
# So, just use a sign flag
- test(1e200, '+g', '+1e+200')
- test(1e200, '+', '+1e+200')
+ test(1.25e200, '+g', '+1.25e+200')
+ test(1.25e200, '+', '+1.25e+200')
test(1.1e200, '+g', '+1.1e+200')
test(1.1e200, '+', '+1.1e+200')
@@ -827,15 +827,15 @@ class UnionTests(unittest.TestCase):
self.assertIsInstance(True, x)
self.assertIsInstance('a', x)
self.assertNotIsInstance(None, x)
- self.assertTrue(issubclass(int, x))
- self.assertTrue(issubclass(bool, x))
- self.assertTrue(issubclass(str, x))
- self.assertFalse(issubclass(type(None), x))
+ self.assertIsSubclass(int, x)
+ self.assertIsSubclass(bool, x)
+ self.assertIsSubclass(str, x)
+ self.assertNotIsSubclass(type(None), x)
for x in (int | None, typing.Union[int, None]):
with self.subTest(x=x):
self.assertIsInstance(None, x)
- self.assertTrue(issubclass(type(None), x))
+ self.assertIsSubclass(type(None), x)
for x in (
int | collections.abc.Mapping,
@@ -844,8 +844,8 @@ class UnionTests(unittest.TestCase):
with self.subTest(x=x):
self.assertIsInstance({}, x)
self.assertNotIsInstance((), x)
- self.assertTrue(issubclass(dict, x))
- self.assertFalse(issubclass(list, x))
+ self.assertIsSubclass(dict, x)
+ self.assertNotIsSubclass(list, x)
def test_instancecheck_and_subclasscheck_order(self):
T = typing.TypeVar('T')
@@ -857,7 +857,7 @@ class UnionTests(unittest.TestCase):
for x in will_resolve:
with self.subTest(x=x):
self.assertIsInstance(1, x)
- self.assertTrue(issubclass(int, x))
+ self.assertIsSubclass(int, x)
wont_resolve = (
T | int,
@@ -890,7 +890,7 @@ class UnionTests(unittest.TestCase):
def __subclasscheck__(cls, sub):
1/0
x = int | BadMeta('A', (), {})
- self.assertTrue(issubclass(int, x))
+ self.assertIsSubclass(int, x)
self.assertRaises(ZeroDivisionError, issubclass, list, x)
def test_or_type_operator_with_TypeVar(self):
@@ -1148,8 +1148,7 @@ class UnionTests(unittest.TestCase):
msg='Check for union reference leak.')
def test_instantiation(self):
- with self.assertRaises(TypeError):
- types.UnionType()
+ check_disallow_instantiation(self, types.UnionType)
self.assertIs(int, types.UnionType[int])
self.assertIs(int, types.UnionType[int, int])
self.assertEqual(int | str, types.UnionType[int, str])
@@ -1399,7 +1398,7 @@ class ClassCreationTests(unittest.TestCase):
def test_new_class_subclass(self):
C = types.new_class("C", (int,))
- self.assertTrue(issubclass(C, int))
+ self.assertIsSubclass(C, int)
def test_new_class_meta(self):
Meta = self.Meta
@@ -1444,7 +1443,7 @@ class ClassCreationTests(unittest.TestCase):
bases=(int,),
kwds=dict(metaclass=Meta, z=2),
exec_body=func)
- self.assertTrue(issubclass(C, int))
+ self.assertIsSubclass(C, int)
self.assertIsInstance(C, Meta)
self.assertEqual(C.x, 0)
self.assertEqual(C.y, 1)
@@ -2513,15 +2512,16 @@ class SubinterpreterTests(unittest.TestCase):
def setUpClass(cls):
global interpreters
try:
- from test.support import interpreters
+ from concurrent import interpreters
except ModuleNotFoundError:
raise unittest.SkipTest('subinterpreters required')
- import test.support.interpreters.channels
+ from test.support import channels # noqa: F401
+ cls.create_channel = staticmethod(channels.create)
@cpython_only
@no_rerun('channels (and queues) might have a refleak; see gh-122199')
def test_static_types_inherited_slots(self):
- rch, sch = interpreters.channels.create()
+ rch, sch = self.create_channel()
script = textwrap.dedent("""
import test.support
@@ -2547,7 +2547,7 @@ class SubinterpreterTests(unittest.TestCase):
main_results = collate_results(raw)
interp = interpreters.create()
- interp.exec('from test.support import interpreters')
+ interp.exec('from concurrent import interpreters')
interp.prepare_main(sch=sch)
interp.exec(script)
raw = rch.recv_nowait()
diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py
index 8c55ba4623e..ef02e8202fc 100644
--- a/Lib/test/test_typing.py
+++ b/Lib/test/test_typing.py
@@ -46,11 +46,10 @@ import abc
import textwrap
import typing
import weakref
-import warnings
import types
from test.support import (
- captured_stderr, cpython_only, infinite_recursion, requires_docstrings, import_helper, run_code,
+ captured_stderr, cpython_only, requires_docstrings, import_helper, run_code,
EqualToForwardRef,
)
from test.typinganndata import (
@@ -6859,12 +6858,10 @@ class GetTypeHintsTests(BaseTestCase):
self.assertEqual(hints, {'value': Final})
def test_top_level_class_var(self):
- # https://bugs.python.org/issue45166
- with self.assertRaisesRegex(
- TypeError,
- r'typing.ClassVar\[int\] is not valid as type argument',
- ):
- get_type_hints(ann_module6)
+ # This is not meaningful but we don't raise for it.
+ # https://github.com/python/cpython/issues/133959
+ hints = get_type_hints(ann_module6)
+ self.assertEqual(hints, {'wrong': ClassVar[int]})
def test_get_type_hints_typeddict(self):
self.assertEqual(get_type_hints(TotalMovie), {'title': str, 'year': int})
@@ -6967,6 +6964,11 @@ class GetTypeHintsTests(BaseTestCase):
self.assertEqual(get_type_hints(foo, globals(), locals()),
{'a': Callable[..., T]})
+ def test_special_forms_no_forward(self):
+ def f(x: ClassVar[int]):
+ pass
+ self.assertEqual(get_type_hints(f), {'x': ClassVar[int]})
+
def test_special_forms_forward(self):
class C:
@@ -6982,8 +6984,9 @@ class GetTypeHintsTests(BaseTestCase):
self.assertEqual(get_type_hints(C, globals())['b'], Final[int])
self.assertEqual(get_type_hints(C, globals())['x'], ClassVar)
self.assertEqual(get_type_hints(C, globals())['y'], Final)
- with self.assertRaises(TypeError):
- get_type_hints(CF, globals()),
+ lfi = get_type_hints(CF, globals())['b']
+ self.assertIs(get_origin(lfi), list)
+ self.assertEqual(get_args(lfi), (Final[int],))
def test_union_forward_recursion(self):
ValueList = List['Value']
@@ -7216,33 +7219,113 @@ class GetUtilitiesTestCase(TestCase):
class EvaluateForwardRefTests(BaseTestCase):
def test_evaluate_forward_ref(self):
int_ref = ForwardRef('int')
- missing = ForwardRef('missing')
+ self.assertIs(typing.evaluate_forward_ref(int_ref), int)
self.assertIs(
typing.evaluate_forward_ref(int_ref, type_params=()),
int,
)
self.assertIs(
+ typing.evaluate_forward_ref(int_ref, format=annotationlib.Format.VALUE),
+ int,
+ )
+ self.assertIs(
typing.evaluate_forward_ref(
- int_ref, type_params=(), format=annotationlib.Format.FORWARDREF,
+ int_ref, format=annotationlib.Format.FORWARDREF,
),
int,
)
+ self.assertEqual(
+ typing.evaluate_forward_ref(
+ int_ref, format=annotationlib.Format.STRING,
+ ),
+ 'int',
+ )
+
+ def test_evaluate_forward_ref_undefined(self):
+ missing = ForwardRef('missing')
+ with self.assertRaises(NameError):
+ typing.evaluate_forward_ref(missing)
self.assertIs(
typing.evaluate_forward_ref(
- missing, type_params=(), format=annotationlib.Format.FORWARDREF,
+ missing, format=annotationlib.Format.FORWARDREF,
),
missing,
)
self.assertEqual(
typing.evaluate_forward_ref(
- int_ref, type_params=(), format=annotationlib.Format.STRING,
+ missing, format=annotationlib.Format.STRING,
),
- 'int',
+ "missing",
)
- def test_evaluate_forward_ref_no_type_params(self):
- ref = ForwardRef('int')
- self.assertIs(typing.evaluate_forward_ref(ref), int)
+ def test_evaluate_forward_ref_nested(self):
+ ref = ForwardRef("int | list['str']")
+ self.assertEqual(
+ typing.evaluate_forward_ref(ref),
+ int | list[str],
+ )
+ self.assertEqual(
+ typing.evaluate_forward_ref(ref, format=annotationlib.Format.FORWARDREF),
+ int | list[str],
+ )
+ self.assertEqual(
+ typing.evaluate_forward_ref(ref, format=annotationlib.Format.STRING),
+ "int | list['str']",
+ )
+
+ why = ForwardRef('"\'str\'"')
+ self.assertIs(typing.evaluate_forward_ref(why), str)
+
+ def test_evaluate_forward_ref_none(self):
+ none_ref = ForwardRef('None')
+ self.assertIs(typing.evaluate_forward_ref(none_ref), None)
+
+ def test_globals(self):
+ A = "str"
+ ref = ForwardRef('list[A]')
+ with self.assertRaises(NameError):
+ typing.evaluate_forward_ref(ref)
+ self.assertEqual(
+ typing.evaluate_forward_ref(ref, globals={'A': A}),
+ list[str],
+ )
+
+ def test_owner(self):
+ ref = ForwardRef("A")
+
+ with self.assertRaises(NameError):
+ typing.evaluate_forward_ref(ref)
+
+ # We default to the globals of `owner`,
+ # so it no longer raises `NameError`
+ self.assertIs(
+ typing.evaluate_forward_ref(ref, owner=Loop), A
+ )
+
+ def test_inherited_owner(self):
+ # owner passed to evaluate_forward_ref
+ ref = ForwardRef("list['A']")
+ self.assertEqual(
+ typing.evaluate_forward_ref(ref, owner=Loop),
+ list[A],
+ )
+
+ # owner set on the ForwardRef
+ ref = ForwardRef("list['A']", owner=Loop)
+ self.assertEqual(
+ typing.evaluate_forward_ref(ref),
+ list[A],
+ )
+
+ def test_partial_evaluation(self):
+ ref = ForwardRef("list[A]")
+ with self.assertRaises(NameError):
+ typing.evaluate_forward_ref(ref)
+
+ self.assertEqual(
+ typing.evaluate_forward_ref(ref, format=annotationlib.Format.FORWARDREF),
+ list[EqualToForwardRef('A')],
+ )
class CollectionsAbcTests(BaseTestCase):
@@ -8080,78 +8163,13 @@ class NamedTupleTests(BaseTestCase):
self.assertIs(type(a), Group)
self.assertEqual(a, (1, [2]))
- def test_namedtuple_keyword_usage(self):
- with self.assertWarnsRegex(
- DeprecationWarning,
- "Creating NamedTuple classes using keyword arguments is deprecated"
- ):
- LocalEmployee = NamedTuple("LocalEmployee", name=str, age=int)
-
- nick = LocalEmployee('Nick', 25)
- self.assertIsInstance(nick, tuple)
- self.assertEqual(nick.name, 'Nick')
- self.assertEqual(LocalEmployee.__name__, 'LocalEmployee')
- self.assertEqual(LocalEmployee._fields, ('name', 'age'))
- self.assertEqual(LocalEmployee.__annotations__, dict(name=str, age=int))
-
- with self.assertRaisesRegex(
- TypeError,
- "Either list of fields or keywords can be provided to NamedTuple, not both"
- ):
- NamedTuple('Name', [('x', int)], y=str)
-
- with self.assertRaisesRegex(
- TypeError,
- "Either list of fields or keywords can be provided to NamedTuple, not both"
- ):
- NamedTuple('Name', [], y=str)
-
- with self.assertRaisesRegex(
- TypeError,
- (
- r"Cannot pass `None` as the 'fields' parameter "
- r"and also specify fields using keyword arguments"
- )
- ):
- NamedTuple('Name', None, x=int)
-
- def test_namedtuple_special_keyword_names(self):
- with self.assertWarnsRegex(
- DeprecationWarning,
- "Creating NamedTuple classes using keyword arguments is deprecated"
- ):
- NT = NamedTuple("NT", cls=type, self=object, typename=str, fields=list)
-
- self.assertEqual(NT.__name__, 'NT')
- self.assertEqual(NT._fields, ('cls', 'self', 'typename', 'fields'))
- a = NT(cls=str, self=42, typename='foo', fields=[('bar', tuple)])
- self.assertEqual(a.cls, str)
- self.assertEqual(a.self, 42)
- self.assertEqual(a.typename, 'foo')
- self.assertEqual(a.fields, [('bar', tuple)])
-
def test_empty_namedtuple(self):
- expected_warning = re.escape(
- "Failing to pass a value for the 'fields' parameter is deprecated "
- "and will be disallowed in Python 3.15. "
- "To create a NamedTuple class with 0 fields "
- "using the functional syntax, "
- "pass an empty list, e.g. `NT1 = NamedTuple('NT1', [])`."
- )
- with self.assertWarnsRegex(DeprecationWarning, fr"^{expected_warning}$"):
- NT1 = NamedTuple('NT1')
+ with self.assertRaisesRegex(TypeError, "missing.*required.*argument"):
+ BAD = NamedTuple('BAD')
- expected_warning = re.escape(
- "Passing `None` as the 'fields' parameter is deprecated "
- "and will be disallowed in Python 3.15. "
- "To create a NamedTuple class with 0 fields "
- "using the functional syntax, "
- "pass an empty list, e.g. `NT2 = NamedTuple('NT2', [])`."
- )
- with self.assertWarnsRegex(DeprecationWarning, fr"^{expected_warning}$"):
- NT2 = NamedTuple('NT2', None)
-
- NT3 = NamedTuple('NT2', [])
+ NT1 = NamedTuple('NT1', {})
+ NT2 = NamedTuple('NT2', ())
+ NT3 = NamedTuple('NT3', [])
class CNT(NamedTuple):
pass # empty body
@@ -8166,16 +8184,18 @@ class NamedTupleTests(BaseTestCase):
def test_namedtuple_errors(self):
with self.assertRaises(TypeError):
NamedTuple.__new__()
+ with self.assertRaisesRegex(TypeError, "object is not iterable"):
+ NamedTuple('Name', None)
with self.assertRaisesRegex(
TypeError,
- "missing 1 required positional argument"
+ "missing 2 required positional arguments"
):
NamedTuple()
with self.assertRaisesRegex(
TypeError,
- "takes from 1 to 2 positional arguments but 3 were given"
+ "takes 2 positional arguments but 3 were given"
):
NamedTuple('Emp', [('name', str)], None)
@@ -8187,10 +8207,22 @@ class NamedTupleTests(BaseTestCase):
with self.assertRaisesRegex(
TypeError,
- "missing 1 required positional argument: 'typename'"
+ "got some positional-only arguments passed as keyword arguments"
):
NamedTuple(typename='Emp', name=str, id=int)
+ with self.assertRaisesRegex(
+ TypeError,
+ "got an unexpected keyword argument"
+ ):
+ NamedTuple('Name', [('x', int)], y=str)
+
+ with self.assertRaisesRegex(
+ TypeError,
+ "got an unexpected keyword argument"
+ ):
+ NamedTuple('Name', [], y=str)
+
def test_copy_and_pickle(self):
global Emp # pickle wants to reference the class by name
Emp = NamedTuple('Emp', [('name', str), ('cool', int)])
@@ -8538,6 +8570,36 @@ class TypedDictTests(BaseTestCase):
self.assertEqual(Child.__required_keys__, frozenset(['a']))
self.assertEqual(Child.__optional_keys__, frozenset())
+ def test_inheritance_pep563(self):
+ def _make_td(future, class_name, annos, base, extra_names=None):
+ lines = []
+ if future:
+ lines.append('from __future__ import annotations')
+ lines.append('from typing import TypedDict')
+ lines.append(f'class {class_name}({base}):')
+ for name, anno in annos.items():
+ lines.append(f' {name}: {anno}')
+ code = '\n'.join(lines)
+ ns = run_code(code, extra_names)
+ return ns[class_name]
+
+ for base_future in (True, False):
+ for child_future in (True, False):
+ with self.subTest(base_future=base_future, child_future=child_future):
+ base = _make_td(
+ base_future, "Base", {"base": "int"}, "TypedDict"
+ )
+ self.assertIsNotNone(base.__annotate__)
+ child = _make_td(
+ child_future, "Child", {"child": "int"}, "Base", {"Base": base}
+ )
+ base_anno = ForwardRef("int", module="builtins") if base_future else int
+ child_anno = ForwardRef("int", module="builtins") if child_future else int
+ self.assertEqual(base.__annotations__, {'base': base_anno})
+ self.assertEqual(
+ child.__annotations__, {'child': child_anno, 'base': base_anno}
+ )
+
def test_required_notrequired_keys(self):
self.assertEqual(NontotalMovie.__required_keys__,
frozenset({"title"}))
@@ -8904,39 +8966,27 @@ class TypedDictTests(BaseTestCase):
self.assertEqual(CallTypedDict.__orig_bases__, (TypedDict,))
def test_zero_fields_typeddicts(self):
- T1 = TypedDict("T1", {})
+ T1a = TypedDict("T1a", {})
+ T1b = TypedDict("T1b", [])
+ T1c = TypedDict("T1c", ())
class T2(TypedDict): pass
class T3[tvar](TypedDict): pass
S = TypeVar("S")
class T4(TypedDict, Generic[S]): pass
- expected_warning = re.escape(
- "Failing to pass a value for the 'fields' parameter is deprecated "
- "and will be disallowed in Python 3.15. "
- "To create a TypedDict class with 0 fields "
- "using the functional syntax, "
- "pass an empty dictionary, e.g. `T5 = TypedDict('T5', {})`."
- )
- with self.assertWarnsRegex(DeprecationWarning, fr"^{expected_warning}$"):
- T5 = TypedDict('T5')
-
- expected_warning = re.escape(
- "Passing `None` as the 'fields' parameter is deprecated "
- "and will be disallowed in Python 3.15. "
- "To create a TypedDict class with 0 fields "
- "using the functional syntax, "
- "pass an empty dictionary, e.g. `T6 = TypedDict('T6', {})`."
- )
- with self.assertWarnsRegex(DeprecationWarning, fr"^{expected_warning}$"):
- T6 = TypedDict('T6', None)
-
- for klass in T1, T2, T3, T4, T5, T6:
+ for klass in T1a, T1b, T1c, T2, T3, T4:
with self.subTest(klass=klass.__name__):
self.assertEqual(klass.__annotations__, {})
self.assertEqual(klass.__required_keys__, set())
self.assertEqual(klass.__optional_keys__, set())
self.assertIsInstance(klass(), dict)
+ def test_errors(self):
+ with self.assertRaisesRegex(TypeError, "missing 1 required.*argument"):
+ TypedDict('TD')
+ with self.assertRaisesRegex(TypeError, "object is not iterable"):
+ TypedDict('TD', None)
+
def test_readonly_inheritance(self):
class Base1(TypedDict):
a: ReadOnly[int]
@@ -10731,6 +10781,9 @@ class UnionGenericAliasTests(BaseTestCase):
with self.assertWarns(DeprecationWarning):
self.assertNotEqual(int, typing._UnionGenericAlias)
+ def test_hashable(self):
+ self.assertEqual(hash(typing._UnionGenericAlias), hash(Union))
+
def load_tests(loader, tests, pattern):
import doctest
diff --git a/Lib/test/test_unittest/test_case.py b/Lib/test/test_unittest/test_case.py
index a04af55f3fc..d66cab146af 100644
--- a/Lib/test/test_unittest/test_case.py
+++ b/Lib/test/test_unittest/test_case.py
@@ -1989,7 +1989,7 @@ test case
pass
self.assertIsNone(value)
- def testAssertStartswith(self):
+ def testAssertStartsWith(self):
self.assertStartsWith('ababahalamaha', 'ababa')
self.assertStartsWith('ababahalamaha', ('x', 'ababa', 'y'))
self.assertStartsWith(UserString('ababahalamaha'), 'ababa')
@@ -2034,7 +2034,7 @@ test case
self.assertStartsWith('ababahalamaha', 'amaha', msg='abracadabra')
self.assertIn('ababahalamaha', str(cm.exception))
- def testAssertNotStartswith(self):
+ def testAssertNotStartsWith(self):
self.assertNotStartsWith('ababahalamaha', 'amaha')
self.assertNotStartsWith('ababahalamaha', ('x', 'amaha', 'y'))
self.assertNotStartsWith(UserString('ababahalamaha'), 'amaha')
@@ -2079,7 +2079,7 @@ test case
self.assertNotStartsWith('ababahalamaha', 'ababa', msg='abracadabra')
self.assertIn('ababahalamaha', str(cm.exception))
- def testAssertEndswith(self):
+ def testAssertEndsWith(self):
self.assertEndsWith('ababahalamaha', 'amaha')
self.assertEndsWith('ababahalamaha', ('x', 'amaha', 'y'))
self.assertEndsWith(UserString('ababahalamaha'), 'amaha')
@@ -2124,7 +2124,7 @@ test case
self.assertEndsWith('ababahalamaha', 'ababa', msg='abracadabra')
self.assertIn('ababahalamaha', str(cm.exception))
- def testAssertNotEndswith(self):
+ def testAssertNotEndsWith(self):
self.assertNotEndsWith('ababahalamaha', 'ababa')
self.assertNotEndsWith('ababahalamaha', ('x', 'ababa', 'y'))
self.assertNotEndsWith(UserString('ababahalamaha'), 'ababa')
diff --git a/Lib/test/test_unittest/test_result.py b/Lib/test/test_unittest/test_result.py
index 9ac4c52449c..3f44e617303 100644
--- a/Lib/test/test_unittest/test_result.py
+++ b/Lib/test/test_unittest/test_result.py
@@ -1282,14 +1282,22 @@ class TestOutputBuffering(unittest.TestCase):
suite(result)
expected_out = '\nStdout:\ndo cleanup2\ndo cleanup1\n'
self.assertEqual(stdout.getvalue(), expected_out)
- self.assertEqual(len(result.errors), 1)
+ self.assertEqual(len(result.errors), 2)
description = 'tearDownModule (Module)'
test_case, formatted_exc = result.errors[0]
self.assertEqual(test_case.description, description)
self.assertIn('ValueError: bad cleanup2', formatted_exc)
+ self.assertNotIn('ExceptionGroup', formatted_exc)
self.assertNotIn('TypeError', formatted_exc)
self.assertIn(expected_out, formatted_exc)
+ test_case, formatted_exc = result.errors[1]
+ self.assertEqual(test_case.description, description)
+ self.assertIn('TypeError: bad cleanup1', formatted_exc)
+ self.assertNotIn('ExceptionGroup', formatted_exc)
+ self.assertNotIn('ValueError', formatted_exc)
+ self.assertIn(expected_out, formatted_exc)
+
def testBufferSetUpModule_DoModuleCleanups(self):
with captured_stdout() as stdout:
result = unittest.TestResult()
@@ -1313,22 +1321,34 @@ class TestOutputBuffering(unittest.TestCase):
suite(result)
expected_out = '\nStdout:\nset up module\ndo cleanup2\ndo cleanup1\n'
self.assertEqual(stdout.getvalue(), expected_out)
- self.assertEqual(len(result.errors), 2)
+ self.assertEqual(len(result.errors), 3)
description = 'setUpModule (Module)'
test_case, formatted_exc = result.errors[0]
self.assertEqual(test_case.description, description)
self.assertIn('ZeroDivisionError: division by zero', formatted_exc)
+ self.assertNotIn('ExceptionGroup', formatted_exc)
self.assertNotIn('ValueError', formatted_exc)
self.assertNotIn('TypeError', formatted_exc)
self.assertIn('\nStdout:\nset up module\n', formatted_exc)
+
test_case, formatted_exc = result.errors[1]
self.assertIn(expected_out, formatted_exc)
self.assertEqual(test_case.description, description)
self.assertIn('ValueError: bad cleanup2', formatted_exc)
+ self.assertNotIn('ExceptionGroup', formatted_exc)
self.assertNotIn('ZeroDivisionError', formatted_exc)
self.assertNotIn('TypeError', formatted_exc)
self.assertIn(expected_out, formatted_exc)
+ test_case, formatted_exc = result.errors[2]
+ self.assertIn(expected_out, formatted_exc)
+ self.assertEqual(test_case.description, description)
+ self.assertIn('TypeError: bad cleanup1', formatted_exc)
+ self.assertNotIn('ExceptionGroup', formatted_exc)
+ self.assertNotIn('ZeroDivisionError', formatted_exc)
+ self.assertNotIn('ValueError', formatted_exc)
+ self.assertIn(expected_out, formatted_exc)
+
def testBufferTearDownModule_DoModuleCleanups(self):
with captured_stdout() as stdout:
result = unittest.TestResult()
@@ -1355,21 +1375,32 @@ class TestOutputBuffering(unittest.TestCase):
suite(result)
expected_out = '\nStdout:\ntear down module\ndo cleanup2\ndo cleanup1\n'
self.assertEqual(stdout.getvalue(), expected_out)
- self.assertEqual(len(result.errors), 2)
+ self.assertEqual(len(result.errors), 3)
description = 'tearDownModule (Module)'
test_case, formatted_exc = result.errors[0]
self.assertEqual(test_case.description, description)
self.assertIn('ZeroDivisionError: division by zero', formatted_exc)
+ self.assertNotIn('ExceptionGroup', formatted_exc)
self.assertNotIn('ValueError', formatted_exc)
self.assertNotIn('TypeError', formatted_exc)
self.assertIn('\nStdout:\ntear down module\n', formatted_exc)
+
test_case, formatted_exc = result.errors[1]
self.assertEqual(test_case.description, description)
self.assertIn('ValueError: bad cleanup2', formatted_exc)
+ self.assertNotIn('ExceptionGroup', formatted_exc)
self.assertNotIn('ZeroDivisionError', formatted_exc)
self.assertNotIn('TypeError', formatted_exc)
self.assertIn(expected_out, formatted_exc)
+ test_case, formatted_exc = result.errors[2]
+ self.assertEqual(test_case.description, description)
+ self.assertIn('TypeError: bad cleanup1', formatted_exc)
+ self.assertNotIn('ExceptionGroup', formatted_exc)
+ self.assertNotIn('ZeroDivisionError', formatted_exc)
+ self.assertNotIn('ValueError', formatted_exc)
+ self.assertIn(expected_out, formatted_exc)
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_unittest/test_runner.py b/Lib/test/test_unittest/test_runner.py
index 4d3cfd60b8d..a47e2ebb59d 100644
--- a/Lib/test/test_unittest/test_runner.py
+++ b/Lib/test/test_unittest/test_runner.py
@@ -13,6 +13,7 @@ from test.test_unittest.support import (
LoggingResult,
ResultWithNoStartTestRunStopTestRun,
)
+from test.support.testcase import ExceptionIsLikeMixin
def resultFactory(*_):
@@ -604,7 +605,7 @@ class TestClassCleanup(unittest.TestCase):
@support.force_not_colorized_test_class
-class TestModuleCleanUp(unittest.TestCase):
+class TestModuleCleanUp(ExceptionIsLikeMixin, unittest.TestCase):
def test_add_and_do_ModuleCleanup(self):
module_cleanups = []
@@ -646,11 +647,50 @@ class TestModuleCleanUp(unittest.TestCase):
[(module_cleanup_good, (1, 2, 3),
dict(four='hello', five='goodbye')),
(module_cleanup_bad, (), {})])
- with self.assertRaises(CustomError) as e:
+ with self.assertRaises(Exception) as e:
unittest.case.doModuleCleanups()
- self.assertEqual(str(e.exception), 'CleanUpExc')
+ self.assertExceptionIsLike(e.exception,
+ ExceptionGroup('module cleanup failed',
+ [CustomError('CleanUpExc')]))
self.assertEqual(unittest.case._module_cleanups, [])
+ def test_doModuleCleanup_with_multiple_errors_in_addModuleCleanup(self):
+ def module_cleanup_bad1():
+ raise TypeError('CleanUpExc1')
+
+ def module_cleanup_bad2():
+ raise ValueError('CleanUpExc2')
+
+ class Module:
+ unittest.addModuleCleanup(module_cleanup_bad1)
+ unittest.addModuleCleanup(module_cleanup_bad2)
+ with self.assertRaises(ExceptionGroup) as e:
+ unittest.case.doModuleCleanups()
+ self.assertExceptionIsLike(e.exception,
+ ExceptionGroup('module cleanup failed', [
+ ValueError('CleanUpExc2'),
+ TypeError('CleanUpExc1'),
+ ]))
+
+ def test_doModuleCleanup_with_exception_group_in_addModuleCleanup(self):
+ def module_cleanup_bad():
+ raise ExceptionGroup('CleanUpExc', [
+ ValueError('CleanUpExc2'),
+ TypeError('CleanUpExc1'),
+ ])
+
+ class Module:
+ unittest.addModuleCleanup(module_cleanup_bad)
+ with self.assertRaises(ExceptionGroup) as e:
+ unittest.case.doModuleCleanups()
+ self.assertExceptionIsLike(e.exception,
+ ExceptionGroup('module cleanup failed', [
+ ExceptionGroup('CleanUpExc', [
+ ValueError('CleanUpExc2'),
+ TypeError('CleanUpExc1'),
+ ]),
+ ]))
+
def test_addModuleCleanup_arg_errors(self):
cleanups = []
def cleanup(*args, **kwargs):
@@ -871,9 +911,11 @@ class TestModuleCleanUp(unittest.TestCase):
ordering = []
blowUp = True
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestableTest)
- with self.assertRaises(CustomError) as cm:
+ with self.assertRaises(Exception) as cm:
suite.debug()
- self.assertEqual(str(cm.exception), 'CleanUpExc')
+ self.assertExceptionIsLike(cm.exception,
+ ExceptionGroup('module cleanup failed',
+ [CustomError('CleanUpExc')]))
self.assertEqual(ordering, ['setUpModule', 'setUpClass', 'test',
'tearDownClass', 'tearDownModule', 'cleanup_exc'])
self.assertEqual(unittest.case._module_cleanups, [])
diff --git a/Lib/test/test_unittest/testmock/testhelpers.py b/Lib/test/test_unittest/testmock/testhelpers.py
index d1e48bde982..0e82c723ec3 100644
--- a/Lib/test/test_unittest/testmock/testhelpers.py
+++ b/Lib/test/test_unittest/testmock/testhelpers.py
@@ -1050,6 +1050,7 @@ class SpecSignatureTest(unittest.TestCase):
create_autospec(WithPostInit()),
]:
with self.subTest(mock=mock):
+ self.assertIsInstance(mock, WithPostInit)
self.assertIsInstance(mock.a, int)
self.assertIsInstance(mock.b, int)
@@ -1072,6 +1073,7 @@ class SpecSignatureTest(unittest.TestCase):
create_autospec(WithDefault(1)),
]:
with self.subTest(mock=mock):
+ self.assertIsInstance(mock, WithDefault)
self.assertIsInstance(mock.a, int)
self.assertIsInstance(mock.b, int)
@@ -1087,6 +1089,7 @@ class SpecSignatureTest(unittest.TestCase):
create_autospec(WithMethod(1)),
]:
with self.subTest(mock=mock):
+ self.assertIsInstance(mock, WithMethod)
self.assertIsInstance(mock.a, int)
mock.b.assert_not_called()
@@ -1102,11 +1105,29 @@ class SpecSignatureTest(unittest.TestCase):
create_autospec(WithNonFields(1)),
]:
with self.subTest(mock=mock):
+ self.assertIsInstance(mock, WithNonFields)
with self.assertRaisesRegex(AttributeError, msg):
mock.a
with self.assertRaisesRegex(AttributeError, msg):
mock.b
+ def test_dataclass_special_attrs(self):
+ @dataclass
+ class Description:
+ name: str
+
+ for mock in [
+ create_autospec(Description, instance=True),
+ create_autospec(Description(1)),
+ ]:
+ with self.subTest(mock=mock):
+ self.assertIsInstance(mock, Description)
+ self.assertIs(mock.__class__, Description)
+ self.assertIsInstance(mock.__dataclass_fields__, MagicMock)
+ self.assertIsInstance(mock.__dataclass_params__, MagicMock)
+ self.assertIsInstance(mock.__match_args__, MagicMock)
+ self.assertIsInstance(mock.__hash__, MagicMock)
+
class TestCallList(unittest.TestCase):
def test_args_list_contains_call_list(self):
diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py
index d3af7a8489e..d4db5e60af7 100644
--- a/Lib/test/test_unparse.py
+++ b/Lib/test/test_unparse.py
@@ -817,6 +817,15 @@ class CosmeticTestCase(ASTTestCase):
self.check_ast_roundtrip("def f[T: int = int, **P = int, *Ts = *int]():\n pass")
self.check_ast_roundtrip("class C[T: int = int, **P = int, *Ts = *int]():\n pass")
+ def test_tstr(self):
+ self.check_ast_roundtrip("t'{a + b}'")
+ self.check_ast_roundtrip("t'{a + b:x}'")
+ self.check_ast_roundtrip("t'{a + b!s}'")
+ self.check_ast_roundtrip("t'{ {a}}'")
+ self.check_ast_roundtrip("t'{ {a}=}'")
+ self.check_ast_roundtrip("t'{{a}}'")
+ self.check_ast_roundtrip("t''")
+
class ManualASTCreationTestCase(unittest.TestCase):
"""Test that AST nodes created without a type_params field unparse correctly."""
diff --git a/Lib/test/test_urllib.py b/Lib/test/test_urllib.py
index 90de828cc71..1d889ae7cf4 100644
--- a/Lib/test/test_urllib.py
+++ b/Lib/test/test_urllib.py
@@ -109,7 +109,7 @@ class urlopen_FileTests(unittest.TestCase):
finally:
f.close()
self.pathname = os_helper.TESTFN
- self.quoted_pathname = urllib.parse.quote(self.pathname)
+ self.quoted_pathname = urllib.parse.quote(os.fsencode(self.pathname))
self.returned_obj = urllib.request.urlopen("file:%s" % self.quoted_pathname)
def tearDown(self):
@@ -1551,7 +1551,8 @@ class Pathname_Tests(unittest.TestCase):
urllib.request.url2pathname(url, require_scheme=True),
expected_path)
- error_subtests = [
+ def test_url2pathname_require_scheme_errors(self):
+ subtests = [
'',
':',
'foo',
@@ -1561,13 +1562,21 @@ class Pathname_Tests(unittest.TestCase):
'data:file:foo',
'data:file://foo',
]
- for url in error_subtests:
+ for url in subtests:
with self.subTest(url=url):
self.assertRaises(
urllib.error.URLError,
urllib.request.url2pathname,
url, require_scheme=True)
+ @unittest.skipIf(support.is_emscripten, "Fixed by https://github.com/emscripten-core/emscripten/pull/24593")
+ def test_url2pathname_resolve_host(self):
+ fn = urllib.request.url2pathname
+ sep = os.path.sep
+ self.assertEqual(fn('//127.0.0.1/foo/bar', resolve_host=True), f'{sep}foo{sep}bar')
+ self.assertEqual(fn(f'//{socket.gethostname()}/foo/bar'), f'{sep}foo{sep}bar')
+ self.assertEqual(fn(f'//{socket.gethostname()}/foo/bar', resolve_host=True), f'{sep}foo{sep}bar')
+
@unittest.skipUnless(sys.platform == 'win32',
'test specific to Windows pathnames.')
def test_url2pathname_win(self):
@@ -1598,6 +1607,7 @@ class Pathname_Tests(unittest.TestCase):
self.assertEqual(fn('//server/path/to/file'), '\\\\server\\path\\to\\file')
self.assertEqual(fn('////server/path/to/file'), '\\\\server\\path\\to\\file')
self.assertEqual(fn('/////server/path/to/file'), '\\\\server\\path\\to\\file')
+ self.assertEqual(fn('//127.0.0.1/path/to/file'), '\\\\127.0.0.1\\path\\to\\file')
# Localhost paths
self.assertEqual(fn('//localhost/C:/path/to/file'), 'C:\\path\\to\\file')
self.assertEqual(fn('//localhost/C|/path/to/file'), 'C:\\path\\to\\file')
@@ -1622,8 +1632,7 @@ class Pathname_Tests(unittest.TestCase):
self.assertRaises(urllib.error.URLError, fn, '//:80/foo/bar')
self.assertRaises(urllib.error.URLError, fn, '//:/foo/bar')
self.assertRaises(urllib.error.URLError, fn, '//c:80/foo/bar')
- self.assertEqual(fn('//127.0.0.1/foo/bar'), '/foo/bar')
- self.assertEqual(fn(f'//{socket.gethostname()}/foo/bar'), '/foo/bar')
+ self.assertRaises(urllib.error.URLError, fn, '//127.0.0.1/foo/bar')
@unittest.skipUnless(os_helper.FS_NONASCII, 'need os_helper.FS_NONASCII')
def test_url2pathname_nonascii(self):
diff --git a/Lib/test/test_urlparse.py b/Lib/test/test_urlparse.py
index aabc360289a..b2bde5a9b1d 100644
--- a/Lib/test/test_urlparse.py
+++ b/Lib/test/test_urlparse.py
@@ -2,6 +2,7 @@ import sys
import unicodedata
import unittest
import urllib.parse
+from test import support
RFC1808_BASE = "http://a/b/c/d;p?q#f"
RFC2396_BASE = "http://a/b/c/d;p?q"
@@ -156,27 +157,25 @@ class UrlParseTestCase(unittest.TestCase):
self.assertEqual(result3.hostname, result.hostname)
self.assertEqual(result3.port, result.port)
- def test_qsl(self):
- for orig, expect in parse_qsl_test_cases:
- result = urllib.parse.parse_qsl(orig, keep_blank_values=True)
- self.assertEqual(result, expect, "Error parsing %r" % orig)
- expect_without_blanks = [v for v in expect if len(v[1])]
- result = urllib.parse.parse_qsl(orig, keep_blank_values=False)
- self.assertEqual(result, expect_without_blanks,
- "Error parsing %r" % orig)
-
- def test_qs(self):
- for orig, expect in parse_qs_test_cases:
- result = urllib.parse.parse_qs(orig, keep_blank_values=True)
- self.assertEqual(result, expect, "Error parsing %r" % orig)
- expect_without_blanks = {v: expect[v]
- for v in expect if len(expect[v][0])}
- result = urllib.parse.parse_qs(orig, keep_blank_values=False)
- self.assertEqual(result, expect_without_blanks,
- "Error parsing %r" % orig)
-
- def test_roundtrips(self):
- str_cases = [
+ @support.subTests('orig,expect', parse_qsl_test_cases)
+ def test_qsl(self, orig, expect):
+ result = urllib.parse.parse_qsl(orig, keep_blank_values=True)
+ self.assertEqual(result, expect)
+ expect_without_blanks = [v for v in expect if len(v[1])]
+ result = urllib.parse.parse_qsl(orig, keep_blank_values=False)
+ self.assertEqual(result, expect_without_blanks)
+
+ @support.subTests('orig,expect', parse_qs_test_cases)
+ def test_qs(self, orig, expect):
+ result = urllib.parse.parse_qs(orig, keep_blank_values=True)
+ self.assertEqual(result, expect)
+ expect_without_blanks = {v: expect[v]
+ for v in expect if len(expect[v][0])}
+ result = urllib.parse.parse_qs(orig, keep_blank_values=False)
+ self.assertEqual(result, expect_without_blanks)
+
+ @support.subTests('bytes', (False, True))
+ @support.subTests('url,parsed,split', [
('path/to/file',
('', '', 'path/to/file', '', '', ''),
('', '', 'path/to/file', '', '')),
@@ -263,23 +262,21 @@ class UrlParseTestCase(unittest.TestCase):
('sch_me:path/to/file',
('', '', 'sch_me:path/to/file', '', '', ''),
('', '', 'sch_me:path/to/file', '', '')),
- ]
- def _encode(t):
- return (t[0].encode('ascii'),
- tuple(x.encode('ascii') for x in t[1]),
- tuple(x.encode('ascii') for x in t[2]))
- bytes_cases = [_encode(x) for x in str_cases]
- str_cases += [
('schème:path/to/file',
('', '', 'schème:path/to/file', '', '', ''),
('', '', 'schème:path/to/file', '', '')),
- ]
- for url, parsed, split in str_cases + bytes_cases:
- with self.subTest(url):
- self.checkRoundtrips(url, parsed, split)
-
- def test_roundtrips_normalization(self):
- str_cases = [
+ ])
+ def test_roundtrips(self, bytes, url, parsed, split):
+ if bytes:
+ if not url.isascii():
+ self.skipTest('non-ASCII bytes')
+ url = str_encode(url)
+ parsed = tuple_encode(parsed)
+ split = tuple_encode(split)
+ self.checkRoundtrips(url, parsed, split)
+
+ @support.subTests('bytes', (False, True))
+ @support.subTests('url,url2,parsed,split', [
('///path/to/file',
'/path/to/file',
('', '', '/path/to/file', '', '', ''),
@@ -300,22 +297,18 @@ class UrlParseTestCase(unittest.TestCase):
'https:///tmp/junk.txt',
('https', '', '/tmp/junk.txt', '', '', ''),
('https', '', '/tmp/junk.txt', '', '')),
- ]
- def _encode(t):
- return (t[0].encode('ascii'),
- t[1].encode('ascii'),
- tuple(x.encode('ascii') for x in t[2]),
- tuple(x.encode('ascii') for x in t[3]))
- bytes_cases = [_encode(x) for x in str_cases]
- for url, url2, parsed, split in str_cases + bytes_cases:
- with self.subTest(url):
- self.checkRoundtrips(url, parsed, split, url2)
-
- def test_http_roundtrips(self):
- # urllib.parse.urlsplit treats 'http:' as an optimized special case,
- # so we test both 'http:' and 'https:' in all the following.
- # Three cheers for white box knowledge!
- str_cases = [
+ ])
+ def test_roundtrips_normalization(self, bytes, url, url2, parsed, split):
+ if bytes:
+ url = str_encode(url)
+ url2 = str_encode(url2)
+ parsed = tuple_encode(parsed)
+ split = tuple_encode(split)
+ self.checkRoundtrips(url, parsed, split, url2)
+
+ @support.subTests('bytes', (False, True))
+ @support.subTests('scheme', ('http', 'https'))
+ @support.subTests('url,parsed,split', [
('://www.python.org',
('www.python.org', '', '', '', ''),
('www.python.org', '', '', '')),
@@ -331,23 +324,20 @@ class UrlParseTestCase(unittest.TestCase):
('://a/b/c/d;p?q#f',
('a', '/b/c/d', 'p', 'q', 'f'),
('a', '/b/c/d;p', 'q', 'f')),
- ]
- def _encode(t):
- return (t[0].encode('ascii'),
- tuple(x.encode('ascii') for x in t[1]),
- tuple(x.encode('ascii') for x in t[2]))
- bytes_cases = [_encode(x) for x in str_cases]
- str_schemes = ('http', 'https')
- bytes_schemes = (b'http', b'https')
- str_tests = str_schemes, str_cases
- bytes_tests = bytes_schemes, bytes_cases
- for schemes, test_cases in (str_tests, bytes_tests):
- for scheme in schemes:
- for url, parsed, split in test_cases:
- url = scheme + url
- parsed = (scheme,) + parsed
- split = (scheme,) + split
- self.checkRoundtrips(url, parsed, split)
+ ])
+ def test_http_roundtrips(self, bytes, scheme, url, parsed, split):
+ # urllib.parse.urlsplit treats 'http:' as an optimized special case,
+ # so we test both 'http:' and 'https:' in all the following.
+ # Three cheers for white box knowledge!
+ if bytes:
+ scheme = str_encode(scheme)
+ url = str_encode(url)
+ parsed = tuple_encode(parsed)
+ split = tuple_encode(split)
+ url = scheme + url
+ parsed = (scheme,) + parsed
+ split = (scheme,) + split
+ self.checkRoundtrips(url, parsed, split)
def checkJoin(self, base, relurl, expected, *, relroundtrip=True):
with self.subTest(base=base, relurl=relurl):
@@ -363,12 +353,13 @@ class UrlParseTestCase(unittest.TestCase):
relurlb = urllib.parse.urlunsplit(urllib.parse.urlsplit(relurlb))
self.assertEqual(urllib.parse.urljoin(baseb, relurlb), expectedb)
- def test_unparse_parse(self):
- str_cases = ['Python', './Python','x-newscheme://foo.com/stuff','x://y','x:/y','x:/','/',]
- bytes_cases = [x.encode('ascii') for x in str_cases]
- for u in str_cases + bytes_cases:
- self.assertEqual(urllib.parse.urlunsplit(urllib.parse.urlsplit(u)), u)
- self.assertEqual(urllib.parse.urlunparse(urllib.parse.urlparse(u)), u)
+ @support.subTests('bytes', (False, True))
+ @support.subTests('u', ['Python', './Python','x-newscheme://foo.com/stuff','x://y','x:/y','x:/','/',])
+ def test_unparse_parse(self, bytes, u):
+ if bytes:
+ u = str_encode(u)
+ self.assertEqual(urllib.parse.urlunsplit(urllib.parse.urlsplit(u)), u)
+ self.assertEqual(urllib.parse.urlunparse(urllib.parse.urlparse(u)), u)
def test_RFC1808(self):
# "normal" cases from RFC 1808:
@@ -695,8 +686,8 @@ class UrlParseTestCase(unittest.TestCase):
self.checkJoin('///b/c', '///w', '///w')
self.checkJoin('///b/c', 'w', '///b/w')
- def test_RFC2732(self):
- str_cases = [
+ @support.subTests('bytes', (False, True))
+ @support.subTests('url,hostname,port', [
('http://Test.python.org:5432/foo/', 'test.python.org', 5432),
('http://12.34.56.78:5432/foo/', '12.34.56.78', 5432),
('http://[::1]:5432/foo/', '::1', 5432),
@@ -727,26 +718,28 @@ class UrlParseTestCase(unittest.TestCase):
('http://[::12.34.56.78]:/foo/', '::12.34.56.78', None),
('http://[::ffff:12.34.56.78]:/foo/',
'::ffff:12.34.56.78', None),
- ]
- def _encode(t):
- return t[0].encode('ascii'), t[1].encode('ascii'), t[2]
- bytes_cases = [_encode(x) for x in str_cases]
- for url, hostname, port in str_cases + bytes_cases:
- urlparsed = urllib.parse.urlparse(url)
- self.assertEqual((urlparsed.hostname, urlparsed.port) , (hostname, port))
-
- str_cases = [
+ ])
+ def test_RFC2732(self, bytes, url, hostname, port):
+ if bytes:
+ url = str_encode(url)
+ hostname = str_encode(hostname)
+ urlparsed = urllib.parse.urlparse(url)
+ self.assertEqual((urlparsed.hostname, urlparsed.port), (hostname, port))
+
+ @support.subTests('bytes', (False, True))
+ @support.subTests('invalid_url', [
'http://::12.34.56.78]/',
'http://[::1/foo/',
'ftp://[::1/foo/bad]/bad',
'http://[::1/foo/bad]/bad',
- 'http://[::ffff:12.34.56.78']
- bytes_cases = [x.encode('ascii') for x in str_cases]
- for invalid_url in str_cases + bytes_cases:
- self.assertRaises(ValueError, urllib.parse.urlparse, invalid_url)
-
- def test_urldefrag(self):
- str_cases = [
+ 'http://[::ffff:12.34.56.78'])
+ def test_RFC2732_invalid(self, bytes, invalid_url):
+ if bytes:
+ invalid_url = str_encode(invalid_url)
+ self.assertRaises(ValueError, urllib.parse.urlparse, invalid_url)
+
+ @support.subTests('bytes', (False, True))
+ @support.subTests('url,defrag,frag', [
('http://python.org#frag', 'http://python.org', 'frag'),
('http://python.org', 'http://python.org', ''),
('http://python.org/#frag', 'http://python.org/', 'frag'),
@@ -770,18 +763,18 @@ class UrlParseTestCase(unittest.TestCase):
('http:?q#f', 'http:?q', 'f'),
('//a/b/c;p?q#f', '//a/b/c;p?q', 'f'),
('://a/b/c;p?q#f', '://a/b/c;p?q', 'f'),
- ]
- def _encode(t):
- return type(t)(x.encode('ascii') for x in t)
- bytes_cases = [_encode(x) for x in str_cases]
- for url, defrag, frag in str_cases + bytes_cases:
- with self.subTest(url):
- result = urllib.parse.urldefrag(url)
- hash = '#' if isinstance(url, str) else b'#'
- self.assertEqual(result.geturl(), url.rstrip(hash))
- self.assertEqual(result, (defrag, frag))
- self.assertEqual(result.url, defrag)
- self.assertEqual(result.fragment, frag)
+ ])
+ def test_urldefrag(self, bytes, url, defrag, frag):
+ if bytes:
+ url = str_encode(url)
+ defrag = str_encode(defrag)
+ frag = str_encode(frag)
+ result = urllib.parse.urldefrag(url)
+ hash = '#' if isinstance(url, str) else b'#'
+ self.assertEqual(result.geturl(), url.rstrip(hash))
+ self.assertEqual(result, (defrag, frag))
+ self.assertEqual(result.url, defrag)
+ self.assertEqual(result.fragment, frag)
def test_urlsplit_scoped_IPv6(self):
p = urllib.parse.urlsplit('http://[FE80::822a:a8ff:fe49:470c%tESt]:1234')
@@ -981,42 +974,35 @@ class UrlParseTestCase(unittest.TestCase):
self.assertEqual(p.scheme, "https")
self.assertEqual(p.geturl(), "https://www.python.org/")
- def test_attributes_bad_port(self):
+ @support.subTests('bytes', (False, True))
+ @support.subTests('parse', (urllib.parse.urlsplit, urllib.parse.urlparse))
+ @support.subTests('port', ("foo", "1.5", "-1", "0x10", "-0", "1_1", " 1", "1 ", "६"))
+ def test_attributes_bad_port(self, bytes, parse, port):
"""Check handling of invalid ports."""
- for bytes in (False, True):
- for parse in (urllib.parse.urlsplit, urllib.parse.urlparse):
- for port in ("foo", "1.5", "-1", "0x10", "-0", "1_1", " 1", "1 ", "६"):
- with self.subTest(bytes=bytes, parse=parse, port=port):
- netloc = "www.example.net:" + port
- url = "http://" + netloc + "/"
- if bytes:
- if netloc.isascii() and port.isascii():
- netloc = netloc.encode("ascii")
- url = url.encode("ascii")
- else:
- continue
- p = parse(url)
- self.assertEqual(p.netloc, netloc)
- with self.assertRaises(ValueError):
- p.port
+ netloc = "www.example.net:" + port
+ url = "http://" + netloc + "/"
+ if bytes:
+ if not (netloc.isascii() and port.isascii()):
+ self.skipTest('non-ASCII bytes')
+ netloc = str_encode(netloc)
+ url = str_encode(url)
+ p = parse(url)
+ self.assertEqual(p.netloc, netloc)
+ with self.assertRaises(ValueError):
+ p.port
- def test_attributes_bad_scheme(self):
+ @support.subTests('bytes', (False, True))
+ @support.subTests('parse', (urllib.parse.urlsplit, urllib.parse.urlparse))
+ @support.subTests('scheme', (".", "+", "-", "0", "http&", "६http"))
+ def test_attributes_bad_scheme(self, bytes, parse, scheme):
"""Check handling of invalid schemes."""
- for bytes in (False, True):
- for parse in (urllib.parse.urlsplit, urllib.parse.urlparse):
- for scheme in (".", "+", "-", "0", "http&", "६http"):
- with self.subTest(bytes=bytes, parse=parse, scheme=scheme):
- url = scheme + "://www.example.net"
- if bytes:
- if url.isascii():
- url = url.encode("ascii")
- else:
- continue
- p = parse(url)
- if bytes:
- self.assertEqual(p.scheme, b"")
- else:
- self.assertEqual(p.scheme, "")
+ url = scheme + "://www.example.net"
+ if bytes:
+ if not url.isascii():
+ self.skipTest('non-ASCII bytes')
+ url = url.encode("ascii")
+ p = parse(url)
+ self.assertEqual(p.scheme, b"" if bytes else "")
def test_attributes_without_netloc(self):
# This example is straight from RFC 3261. It looks like it
@@ -1128,24 +1114,21 @@ class UrlParseTestCase(unittest.TestCase):
self.assertEqual(urllib.parse.urlparse(b"x-newscheme://foo.com/stuff?query"),
(b'x-newscheme', b'foo.com', b'/stuff', b'', b'query', b''))
- def test_default_scheme(self):
+ @support.subTests('func', (urllib.parse.urlparse, urllib.parse.urlsplit))
+ def test_default_scheme(self, func):
# Exercise the scheme parameter of urlparse() and urlsplit()
- for func in (urllib.parse.urlparse, urllib.parse.urlsplit):
- with self.subTest(function=func):
- result = func("http://example.net/", "ftp")
- self.assertEqual(result.scheme, "http")
- result = func(b"http://example.net/", b"ftp")
- self.assertEqual(result.scheme, b"http")
- self.assertEqual(func("path", "ftp").scheme, "ftp")
- self.assertEqual(func("path", scheme="ftp").scheme, "ftp")
- self.assertEqual(func(b"path", scheme=b"ftp").scheme, b"ftp")
- self.assertEqual(func("path").scheme, "")
- self.assertEqual(func(b"path").scheme, b"")
- self.assertEqual(func(b"path", "").scheme, b"")
-
- def test_parse_fragments(self):
- # Exercise the allow_fragments parameter of urlparse() and urlsplit()
- tests = (
+ result = func("http://example.net/", "ftp")
+ self.assertEqual(result.scheme, "http")
+ result = func(b"http://example.net/", b"ftp")
+ self.assertEqual(result.scheme, b"http")
+ self.assertEqual(func("path", "ftp").scheme, "ftp")
+ self.assertEqual(func("path", scheme="ftp").scheme, "ftp")
+ self.assertEqual(func(b"path", scheme=b"ftp").scheme, b"ftp")
+ self.assertEqual(func("path").scheme, "")
+ self.assertEqual(func(b"path").scheme, b"")
+ self.assertEqual(func(b"path", "").scheme, b"")
+
+ @support.subTests('url,attr,expected_frag', (
("http:#frag", "path", "frag"),
("//example.net#frag", "path", "frag"),
("index.html#frag", "path", "frag"),
@@ -1156,24 +1139,24 @@ class UrlParseTestCase(unittest.TestCase):
("//abc#@frag", "path", "@frag"),
("//abc:80#@frag", "path", "@frag"),
("//abc#@frag:80", "path", "@frag:80"),
- )
- for url, attr, expected_frag in tests:
- for func in (urllib.parse.urlparse, urllib.parse.urlsplit):
- if attr == "params" and func is urllib.parse.urlsplit:
- attr = "path"
- with self.subTest(url=url, function=func):
- result = func(url, allow_fragments=False)
- self.assertEqual(result.fragment, "")
- self.assertEndsWith(getattr(result, attr),
- "#" + expected_frag)
- self.assertEqual(func(url, "", False).fragment, "")
-
- result = func(url, allow_fragments=True)
- self.assertEqual(result.fragment, expected_frag)
- self.assertNotEndsWith(getattr(result, attr), expected_frag)
- self.assertEqual(func(url, "", True).fragment,
- expected_frag)
- self.assertEqual(func(url).fragment, expected_frag)
+ ))
+ @support.subTests('func', (urllib.parse.urlparse, urllib.parse.urlsplit))
+ def test_parse_fragments(self, url, attr, expected_frag, func):
+ # Exercise the allow_fragments parameter of urlparse() and urlsplit()
+ if attr == "params" and func is urllib.parse.urlsplit:
+ attr = "path"
+ result = func(url, allow_fragments=False)
+ self.assertEqual(result.fragment, "")
+ self.assertEndsWith(getattr(result, attr),
+ "#" + expected_frag)
+ self.assertEqual(func(url, "", False).fragment, "")
+
+ result = func(url, allow_fragments=True)
+ self.assertEqual(result.fragment, expected_frag)
+ self.assertNotEndsWith(getattr(result, attr), expected_frag)
+ self.assertEqual(func(url, "", True).fragment,
+ expected_frag)
+ self.assertEqual(func(url).fragment, expected_frag)
def test_mixed_types_rejected(self):
# Several functions that process either strings or ASCII encoded bytes
@@ -1199,7 +1182,14 @@ class UrlParseTestCase(unittest.TestCase):
with self.assertRaisesRegex(TypeError, "Cannot mix str"):
urllib.parse.urljoin(b"http://python.org", "http://python.org")
- def _check_result_type(self, str_type):
+ @support.subTests('result_type', [
+ urllib.parse.DefragResult,
+ urllib.parse.SplitResult,
+ urllib.parse.ParseResult,
+ ])
+ def test_result_pairs(self, result_type):
+ # Check encoding and decoding between result pairs
+ str_type = result_type
num_args = len(str_type._fields)
bytes_type = str_type._encoded_counterpart
self.assertIs(bytes_type._decoded_counterpart, str_type)
@@ -1224,16 +1214,6 @@ class UrlParseTestCase(unittest.TestCase):
self.assertEqual(str_result.encode(encoding, errors), bytes_args)
self.assertEqual(str_result.encode(encoding, errors), bytes_result)
- def test_result_pairs(self):
- # Check encoding and decoding between result pairs
- result_types = [
- urllib.parse.DefragResult,
- urllib.parse.SplitResult,
- urllib.parse.ParseResult,
- ]
- for result_type in result_types:
- self._check_result_type(result_type)
-
def test_parse_qs_encoding(self):
result = urllib.parse.parse_qs("key=\u0141%E9", encoding="latin-1")
self.assertEqual(result, {'key': ['\u0141\xE9']})
@@ -1265,8 +1245,7 @@ class UrlParseTestCase(unittest.TestCase):
urllib.parse.parse_qsl('&'.join(['a=a']*11), max_num_fields=10)
urllib.parse.parse_qsl('&'.join(['a=a']*10), max_num_fields=10)
- def test_parse_qs_separator(self):
- parse_qs_semicolon_cases = [
+ @support.subTests('orig,expect', [
(";", {}),
(";;", {}),
(";a=b", {'a': ['b']}),
@@ -1277,17 +1256,14 @@ class UrlParseTestCase(unittest.TestCase):
(b";a=b", {b'a': [b'b']}),
(b"a=a+b;b=b+c", {b'a': [b'a b'], b'b': [b'b c']}),
(b"a=1;a=2", {b'a': [b'1', b'2']}),
- ]
- for orig, expect in parse_qs_semicolon_cases:
- with self.subTest(f"Original: {orig!r}, Expected: {expect!r}"):
- result = urllib.parse.parse_qs(orig, separator=';')
- self.assertEqual(result, expect, "Error parsing %r" % orig)
- result_bytes = urllib.parse.parse_qs(orig, separator=b';')
- self.assertEqual(result_bytes, expect, "Error parsing %r" % orig)
-
-
- def test_parse_qsl_separator(self):
- parse_qsl_semicolon_cases = [
+ ])
+ def test_parse_qs_separator(self, orig, expect):
+ result = urllib.parse.parse_qs(orig, separator=';')
+ self.assertEqual(result, expect)
+ result_bytes = urllib.parse.parse_qs(orig, separator=b';')
+ self.assertEqual(result_bytes, expect)
+
+ @support.subTests('orig,expect', [
(";", []),
(";;", []),
(";a=b", [('a', 'b')]),
@@ -1298,13 +1274,12 @@ class UrlParseTestCase(unittest.TestCase):
(b";a=b", [(b'a', b'b')]),
(b"a=a+b;b=b+c", [(b'a', b'a b'), (b'b', b'b c')]),
(b"a=1;a=2", [(b'a', b'1'), (b'a', b'2')]),
- ]
- for orig, expect in parse_qsl_semicolon_cases:
- with self.subTest(f"Original: {orig!r}, Expected: {expect!r}"):
- result = urllib.parse.parse_qsl(orig, separator=';')
- self.assertEqual(result, expect, "Error parsing %r" % orig)
- result_bytes = urllib.parse.parse_qsl(orig, separator=b';')
- self.assertEqual(result_bytes, expect, "Error parsing %r" % orig)
+ ])
+ def test_parse_qsl_separator(self, orig, expect):
+ result = urllib.parse.parse_qsl(orig, separator=';')
+ self.assertEqual(result, expect)
+ result_bytes = urllib.parse.parse_qsl(orig, separator=b';')
+ self.assertEqual(result_bytes, expect)
def test_parse_qsl_bytes(self):
self.assertEqual(urllib.parse.parse_qsl(b'a=b'), [(b'a', b'b')])
@@ -1695,11 +1670,12 @@ class Utility_Tests(unittest.TestCase):
self.assertRaises(UnicodeError, urllib.parse._to_bytes,
'http://www.python.org/medi\u00e6val')
- def test_unwrap(self):
- for wrapped_url in ('<URL:scheme://host/path>', '<scheme://host/path>',
- 'URL:scheme://host/path', 'scheme://host/path'):
- url = urllib.parse.unwrap(wrapped_url)
- self.assertEqual(url, 'scheme://host/path')
+ @support.subTests('wrapped_url',
+ ('<URL:scheme://host/path>', '<scheme://host/path>',
+ 'URL:scheme://host/path', 'scheme://host/path'))
+ def test_unwrap(self, wrapped_url):
+ url = urllib.parse.unwrap(wrapped_url)
+ self.assertEqual(url, 'scheme://host/path')
class DeprecationTest(unittest.TestCase):
@@ -1780,5 +1756,11 @@ class DeprecationTest(unittest.TestCase):
'urllib.parse.to_bytes() is deprecated as of 3.8')
+def str_encode(s):
+ return s.encode('ascii')
+
+def tuple_encode(t):
+ return tuple(str_encode(x) for x in t)
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_userdict.py b/Lib/test/test_userdict.py
index ace84ef564d..75de9ea252d 100644
--- a/Lib/test/test_userdict.py
+++ b/Lib/test/test_userdict.py
@@ -166,7 +166,7 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol):
def test_missing(self):
# Make sure UserDict doesn't have a __missing__ method
- self.assertEqual(hasattr(collections.UserDict, "__missing__"), False)
+ self.assertNotHasAttr(collections.UserDict, "__missing__")
# Test several cases:
# (D) subclass defines __missing__ method returning a value
# (E) subclass defines __missing__ method raising RuntimeError
diff --git a/Lib/test/test_uuid.py b/Lib/test/test_uuid.py
index 958be5408ce..7ddacf07a2c 100755
--- a/Lib/test/test_uuid.py
+++ b/Lib/test/test_uuid.py
@@ -14,6 +14,7 @@ from unittest import mock
from test import support
from test.support import import_helper
+from test.support.script_helper import assert_python_ok
py_uuid = import_helper.import_fresh_module('uuid', blocked=['_uuid'])
c_uuid = import_helper.import_fresh_module('uuid', fresh=['_uuid'])
@@ -1217,10 +1218,37 @@ class BaseTestUUID:
class TestUUIDWithoutExtModule(BaseTestUUID, unittest.TestCase):
uuid = py_uuid
+
@unittest.skipUnless(c_uuid, 'requires the C _uuid module')
class TestUUIDWithExtModule(BaseTestUUID, unittest.TestCase):
uuid = c_uuid
+ def check_has_stable_libuuid_extractable_node(self):
+ if not self.uuid._has_stable_extractable_node:
+ self.skipTest("libuuid cannot deduce MAC address")
+
+ @unittest.skipUnless(os.name == 'posix', 'POSIX only')
+ def test_unix_getnode_from_libuuid(self):
+ self.check_has_stable_libuuid_extractable_node()
+ script = 'import uuid; print(uuid._unix_getnode())'
+ _, n_a, _ = assert_python_ok('-c', script)
+ _, n_b, _ = assert_python_ok('-c', script)
+ n_a, n_b = n_a.decode().strip(), n_b.decode().strip()
+ self.assertTrue(n_a.isdigit())
+ self.assertTrue(n_b.isdigit())
+ self.assertEqual(n_a, n_b)
+
+ @unittest.skipUnless(os.name == 'nt', 'Windows only')
+ def test_windows_getnode_from_libuuid(self):
+ self.check_has_stable_libuuid_extractable_node()
+ script = 'import uuid; print(uuid._windll_getnode())'
+ _, n_a, _ = assert_python_ok('-c', script)
+ _, n_b, _ = assert_python_ok('-c', script)
+ n_a, n_b = n_a.decode().strip(), n_b.decode().strip()
+ self.assertTrue(n_a.isdigit())
+ self.assertTrue(n_b.isdigit())
+ self.assertEqual(n_a, n_b)
+
class BaseTestInternals:
_uuid = py_uuid
diff --git a/Lib/test/test_venv.py b/Lib/test/test_venv.py
index adc86a49b06..d62f3fba2d1 100644
--- a/Lib/test/test_venv.py
+++ b/Lib/test/test_venv.py
@@ -774,7 +774,7 @@ class BasicTest(BaseTest):
with open(script_path, 'rb') as script:
for i, line in enumerate(script, 1):
error_message = f"CR LF found in line {i}"
- self.assertFalse(line.endswith(b'\r\n'), error_message)
+ self.assertNotEndsWith(line, b'\r\n', error_message)
@requireVenvCreate
def test_scm_ignore_files_git(self):
@@ -978,7 +978,7 @@ class EnsurePipTest(BaseTest):
self.assertEqual(err, "")
out = out.decode("latin-1") # Force to text, prevent decoding errors
expected_version = "pip {}".format(ensurepip.version())
- self.assertEqual(out[:len(expected_version)], expected_version)
+ self.assertStartsWith(out, expected_version)
env_dir = os.fsencode(self.env_dir).decode("latin-1")
self.assertIn(env_dir, out)
@@ -1008,7 +1008,7 @@ class EnsurePipTest(BaseTest):
err, flags=re.MULTILINE)
# Ignore warning about missing optional module:
try:
- import ssl
+ import ssl # noqa: F401
except ImportError:
err = re.sub(
"^WARNING: Disabling truststore since ssl support is missing$",
diff --git a/Lib/test/test_warnings/__init__.py b/Lib/test/test_warnings/__init__.py
index 05710c46934..5c3b1250ceb 100644
--- a/Lib/test/test_warnings/__init__.py
+++ b/Lib/test/test_warnings/__init__.py
@@ -102,7 +102,7 @@ class PublicAPITests(BaseTest):
"""
def test_module_all_attribute(self):
- self.assertTrue(hasattr(self.module, '__all__'))
+ self.assertHasAttr(self.module, '__all__')
target_api = ["warn", "warn_explicit", "showwarning",
"formatwarning", "filterwarnings", "simplefilter",
"resetwarnings", "catch_warnings", "deprecated"]
@@ -735,7 +735,7 @@ class CWarnTests(WarnTests, unittest.TestCase):
# test.import_helper.import_fresh_module utility function
def test_accelerated(self):
self.assertIsNot(original_warnings, self.module)
- self.assertFalse(hasattr(self.module.warn, '__code__'))
+ self.assertNotHasAttr(self.module.warn, '__code__')
class PyWarnTests(WarnTests, unittest.TestCase):
module = py_warnings
@@ -744,7 +744,7 @@ class PyWarnTests(WarnTests, unittest.TestCase):
# test.import_helper.import_fresh_module utility function
def test_pure_python(self):
self.assertIsNot(original_warnings, self.module)
- self.assertTrue(hasattr(self.module.warn, '__code__'))
+ self.assertHasAttr(self.module.warn, '__code__')
class WCmdLineTests(BaseTest):
@@ -1528,12 +1528,12 @@ a=A()
# (_warnings will try to import it)
code = "f = open(%a)" % __file__
rc, out, err = assert_python_ok("-Wd", "-c", code)
- self.assertTrue(err.startswith(expected), ascii(err))
+ self.assertStartsWith(err, expected)
# import the warnings module
code = "import warnings; f = open(%a)" % __file__
rc, out, err = assert_python_ok("-Wd", "-c", code)
- self.assertTrue(err.startswith(expected), ascii(err))
+ self.assertStartsWith(err, expected)
class AsyncTests(BaseTest):
diff --git a/Lib/test/test_wave.py b/Lib/test/test_wave.py
index 5e771c8de96..6c3362857fc 100644
--- a/Lib/test/test_wave.py
+++ b/Lib/test/test_wave.py
@@ -136,32 +136,6 @@ class MiscTestCase(unittest.TestCase):
not_exported = {'WAVE_FORMAT_PCM', 'WAVE_FORMAT_EXTENSIBLE', 'KSDATAFORMAT_SUBTYPE_PCM'}
support.check__all__(self, wave, not_exported=not_exported)
- def test_read_deprecations(self):
- filename = support.findfile('pluck-pcm8.wav', subdir='audiodata')
- with wave.open(filename) as reader:
- with self.assertWarns(DeprecationWarning):
- with self.assertRaises(wave.Error):
- reader.getmark('mark')
- with self.assertWarns(DeprecationWarning):
- self.assertIsNone(reader.getmarkers())
-
- def test_write_deprecations(self):
- with io.BytesIO(b'') as tmpfile:
- with wave.open(tmpfile, 'wb') as writer:
- writer.setnchannels(1)
- writer.setsampwidth(1)
- writer.setframerate(1)
- writer.setcomptype('NONE', 'not compressed')
-
- with self.assertWarns(DeprecationWarning):
- with self.assertRaises(wave.Error):
- writer.setmark(0, 0, 'mark')
- with self.assertWarns(DeprecationWarning):
- with self.assertRaises(wave.Error):
- writer.getmark('mark')
- with self.assertWarns(DeprecationWarning):
- self.assertIsNone(writer.getmarkers())
-
class WaveLowLevelTest(unittest.TestCase):
diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py
index 4faad6629fe..4c7c900eb56 100644
--- a/Lib/test/test_weakref.py
+++ b/Lib/test/test_weakref.py
@@ -432,7 +432,7 @@ class ReferencesTestCase(TestBase):
self.assertEqual(proxy.foo, 2,
"proxy does not reflect attribute modification")
del o.foo
- self.assertFalse(hasattr(proxy, 'foo'),
+ self.assertNotHasAttr(proxy, 'foo',
"proxy does not reflect attribute removal")
proxy.foo = 1
@@ -442,7 +442,7 @@ class ReferencesTestCase(TestBase):
self.assertEqual(o.foo, 2,
"object does not reflect attribute modification via proxy")
del proxy.foo
- self.assertFalse(hasattr(o, 'foo'),
+ self.assertNotHasAttr(o, 'foo',
"object does not reflect attribute removal via proxy")
def test_proxy_deletion(self):
@@ -1108,7 +1108,7 @@ class SubclassableWeakrefTestCase(TestBase):
self.assertEqual(r.slot1, "abc")
self.assertEqual(r.slot2, "def")
self.assertEqual(r.meth(), "abcdef")
- self.assertFalse(hasattr(r, "__dict__"))
+ self.assertNotHasAttr(r, "__dict__")
def test_subclass_refs_with_cycle(self):
"""Confirm https://bugs.python.org/issue3100 is fixed."""
diff --git a/Lib/test/test_weakset.py b/Lib/test/test_weakset.py
index 76e8e5c8ab7..c1e4f9c8366 100644
--- a/Lib/test/test_weakset.py
+++ b/Lib/test/test_weakset.py
@@ -466,7 +466,7 @@ class TestWeakSet(unittest.TestCase):
self.assertIsNot(dup, s)
self.assertIs(dup.x, s.x)
self.assertIs(dup.z, s.z)
- self.assertFalse(hasattr(dup, 'y'))
+ self.assertNotHasAttr(dup, 'y')
dup = copy.deepcopy(s)
self.assertIsInstance(dup, cls)
@@ -476,7 +476,7 @@ class TestWeakSet(unittest.TestCase):
self.assertIsNot(dup.x, s.x)
self.assertEqual(dup.z, s.z)
self.assertIsNot(dup.z, s.z)
- self.assertFalse(hasattr(dup, 'y'))
+ self.assertNotHasAttr(dup, 'y')
if __name__ == "__main__":
diff --git a/Lib/test/test_webbrowser.py b/Lib/test/test_webbrowser.py
index 4c3ea1cd8df..6b577ae100e 100644
--- a/Lib/test/test_webbrowser.py
+++ b/Lib/test/test_webbrowser.py
@@ -6,7 +6,6 @@ import subprocess
import sys
import unittest
import webbrowser
-from functools import partial
from test import support
from test.support import import_helper
from test.support import is_apple_mobile
diff --git a/Lib/test/test_winconsoleio.py b/Lib/test/test_winconsoleio.py
index d9076e77c15..1bae884ed9a 100644
--- a/Lib/test/test_winconsoleio.py
+++ b/Lib/test/test_winconsoleio.py
@@ -17,9 +17,9 @@ ConIO = io._WindowsConsoleIO
class WindowsConsoleIOTests(unittest.TestCase):
def test_abc(self):
- self.assertTrue(issubclass(ConIO, io.RawIOBase))
- self.assertFalse(issubclass(ConIO, io.BufferedIOBase))
- self.assertFalse(issubclass(ConIO, io.TextIOBase))
+ self.assertIsSubclass(ConIO, io.RawIOBase)
+ self.assertNotIsSubclass(ConIO, io.BufferedIOBase)
+ self.assertNotIsSubclass(ConIO, io.TextIOBase)
def test_open_fd(self):
self.assertRaisesRegex(ValueError,
diff --git a/Lib/test/test_with.py b/Lib/test/test_with.py
index fd7abd1782e..f16611b29a2 100644
--- a/Lib/test/test_with.py
+++ b/Lib/test/test_with.py
@@ -679,7 +679,7 @@ class AssignmentTargetTestCase(unittest.TestCase):
class C: pass
blah = C()
with mock_contextmanager_generator() as blah.foo:
- self.assertEqual(hasattr(blah, "foo"), True)
+ self.assertHasAttr(blah, "foo")
def testMultipleComplexTargets(self):
class C:
diff --git a/Lib/test/test_wmi.py b/Lib/test/test_wmi.py
index ac7c9cb3a5a..90eb40439d4 100644
--- a/Lib/test/test_wmi.py
+++ b/Lib/test/test_wmi.py
@@ -70,8 +70,8 @@ class WmiTests(unittest.TestCase):
def test_wmi_query_multiple_rows(self):
# Multiple instances should have an extra null separator
r = wmi_exec_query("SELECT ProcessId FROM Win32_Process WHERE ProcessId < 1000")
- self.assertFalse(r.startswith("\0"), r)
- self.assertFalse(r.endswith("\0"), r)
+ self.assertNotStartsWith(r, "\0")
+ self.assertNotEndsWith(r, "\0")
it = iter(r.split("\0"))
try:
while True:
diff --git a/Lib/test/test_wsgiref.py b/Lib/test/test_wsgiref.py
index b047f7b06f8..e04a4d2c221 100644
--- a/Lib/test/test_wsgiref.py
+++ b/Lib/test/test_wsgiref.py
@@ -149,9 +149,9 @@ class IntegrationTests(TestCase):
start_response("200 OK", ('Content-Type','text/plain'))
return ["Hello, world!"]
out, err = run_amock(validator(bad_app))
- self.assertTrue(out.endswith(
+ self.assertEndsWith(out,
b"A server error occurred. Please contact the administrator."
- ))
+ )
self.assertEqual(
err.splitlines()[-2],
"AssertionError: Headers (('Content-Type', 'text/plain')) must"
@@ -174,9 +174,9 @@ class IntegrationTests(TestCase):
for status, exc_message in tests:
with self.subTest(status=status):
out, err = run_amock(create_bad_app(status))
- self.assertTrue(out.endswith(
+ self.assertEndsWith(out,
b"A server error occurred. Please contact the administrator."
- ))
+ )
self.assertEqual(err.splitlines()[-2], exc_message)
def test_wsgi_input(self):
@@ -185,9 +185,9 @@ class IntegrationTests(TestCase):
s("200 OK", [("Content-Type", "text/plain; charset=utf-8")])
return [b"data"]
out, err = run_amock(validator(bad_app))
- self.assertTrue(out.endswith(
+ self.assertEndsWith(out,
b"A server error occurred. Please contact the administrator."
- ))
+ )
self.assertEqual(
err.splitlines()[-2], "AssertionError"
)
@@ -200,7 +200,7 @@ class IntegrationTests(TestCase):
])
return [b"data"]
out, err = run_amock(validator(app))
- self.assertTrue(err.endswith('"GET / HTTP/1.0" 200 4\n'))
+ self.assertEndsWith(err, '"GET / HTTP/1.0" 200 4\n')
ver = sys.version.split()[0].encode('ascii')
py = python_implementation().encode('ascii')
pyver = py + b"/" + ver
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py
index 5fe9d688410..38be2cd437f 100644
--- a/Lib/test/test_xml_etree.py
+++ b/Lib/test/test_xml_etree.py
@@ -225,8 +225,7 @@ class ElementTreeTest(unittest.TestCase):
self.assertTrue(ET.iselement(element), msg="not an element")
direlem = dir(element)
for attr in 'tag', 'attrib', 'text', 'tail':
- self.assertTrue(hasattr(element, attr),
- msg='no %s member' % attr)
+ self.assertHasAttr(element, attr)
self.assertIn(attr, direlem,
msg='no %s visible by dir' % attr)
@@ -251,7 +250,7 @@ class ElementTreeTest(unittest.TestCase):
# Make sure all standard element methods exist.
def check_method(method):
- self.assertTrue(hasattr(method, '__call__'),
+ self.assertHasAttr(method, '__call__',
msg="%s not callable" % method)
check_method(element.append)
@@ -2960,6 +2959,50 @@ class BadElementTest(ElementTestCase, unittest.TestCase):
del b
gc_collect()
+ def test_deepcopy_clear(self):
+ # Prevent crashes when __deepcopy__() clears the children list.
+ # See https://github.com/python/cpython/issues/133009.
+ class X(ET.Element):
+ def __deepcopy__(self, memo):
+ root.clear()
+ return self
+
+ root = ET.Element('a')
+ evil = X('x')
+ root.extend([evil, ET.Element('y')])
+ if is_python_implementation():
+ # Mutating a list over which we iterate raises an error.
+ self.assertRaises(RuntimeError, copy.deepcopy, root)
+ else:
+ c = copy.deepcopy(root)
+ # In the C implementation, we can still copy the evil element.
+ self.assertListEqual(list(c), [evil])
+
+ def test_deepcopy_grow(self):
+ # Prevent crashes when __deepcopy__() mutates the children list.
+ # See https://github.com/python/cpython/issues/133009.
+ a = ET.Element('a')
+ b = ET.Element('b')
+ c = ET.Element('c')
+
+ class X(ET.Element):
+ def __deepcopy__(self, memo):
+ root.append(a)
+ root.append(b)
+ return self
+
+ root = ET.Element('top')
+ evil1, evil2 = X('1'), X('2')
+ root.extend([evil1, c, evil2])
+ children = list(copy.deepcopy(root))
+ # mock deep copies
+ self.assertIs(children[0], evil1)
+ self.assertIs(children[2], evil2)
+ # true deep copies
+ self.assertEqual(children[1].tag, c.tag)
+ self.assertEqual([c.tag for c in children[3:]],
+ [a.tag, b.tag, a.tag, b.tag])
+
class MutationDeleteElementPath(str):
def __new__(cls, elem, *args):
diff --git a/Lib/test/test_xxlimited.py b/Lib/test/test_xxlimited.py
index 6dbfb3f4393..b52e78bc4fb 100644
--- a/Lib/test/test_xxlimited.py
+++ b/Lib/test/test_xxlimited.py
@@ -31,7 +31,7 @@ class CommonTests:
self.assertEqual(self.module.foo(1, 2), 3)
def test_str(self):
- self.assertTrue(issubclass(self.module.Str, str))
+ self.assertIsSubclass(self.module.Str, str)
self.assertIsNot(self.module.Str, str)
custom_string = self.module.Str("abcd")
diff --git a/Lib/test/test_zipapp.py b/Lib/test/test_zipapp.py
index d4766c59a10..8fb0a68deba 100644
--- a/Lib/test/test_zipapp.py
+++ b/Lib/test/test_zipapp.py
@@ -259,7 +259,7 @@ class ZipAppTest(unittest.TestCase):
(source / '__main__.py').touch()
target = io.BytesIO()
zipapp.create_archive(str(source), target, interpreter='python')
- self.assertTrue(target.getvalue().startswith(b'#!python\n'))
+ self.assertStartsWith(target.getvalue(), b'#!python\n')
def test_read_shebang(self):
# Test that we can read the shebang line correctly.
@@ -300,7 +300,7 @@ class ZipAppTest(unittest.TestCase):
zipapp.create_archive(str(source), str(target), interpreter='python')
new_target = io.BytesIO()
zipapp.create_archive(str(target), new_target, interpreter='python2.7')
- self.assertTrue(new_target.getvalue().startswith(b'#!python2.7\n'))
+ self.assertStartsWith(new_target.getvalue(), b'#!python2.7\n')
def test_read_from_pathlike_obj(self):
# Test that we can copy an archive using a path-like object
@@ -326,7 +326,7 @@ class ZipAppTest(unittest.TestCase):
new_target = io.BytesIO()
temp_archive.seek(0)
zipapp.create_archive(temp_archive, new_target, interpreter='python2.7')
- self.assertTrue(new_target.getvalue().startswith(b'#!python2.7\n'))
+ self.assertStartsWith(new_target.getvalue(), b'#!python2.7\n')
def test_remove_shebang(self):
# Test that we can remove the shebang from a file.
diff --git a/Lib/test/test_zipfile/__main__.py b/Lib/test/test_zipfile/__main__.py
index e25ac946edf..90da74ade38 100644
--- a/Lib/test/test_zipfile/__main__.py
+++ b/Lib/test/test_zipfile/__main__.py
@@ -1,6 +1,6 @@
import unittest
-from . import load_tests # noqa: F401
+from . import load_tests
if __name__ == "__main__":
diff --git a/Lib/test/test_zipfile/_path/_test_params.py b/Lib/test/test_zipfile/_path/_test_params.py
index bc95b4ebf4a..00a9eaf2f99 100644
--- a/Lib/test/test_zipfile/_path/_test_params.py
+++ b/Lib/test/test_zipfile/_path/_test_params.py
@@ -1,5 +1,5 @@
-import types
import functools
+import types
from ._itertools import always_iterable
diff --git a/Lib/test/test_zipfile/_path/test_complexity.py b/Lib/test/test_zipfile/_path/test_complexity.py
index b505dd7c376..7c108fc6ab8 100644
--- a/Lib/test/test_zipfile/_path/test_complexity.py
+++ b/Lib/test/test_zipfile/_path/test_complexity.py
@@ -8,10 +8,8 @@ import zipfile
from ._functools import compose
from ._itertools import consume
-
from ._support import import_or_skip
-
big_o = import_or_skip('big_o')
pytest = import_or_skip('pytest')
diff --git a/Lib/test/test_zipfile/_path/test_path.py b/Lib/test/test_zipfile/_path/test_path.py
index 0afabc0c668..696134023a5 100644
--- a/Lib/test/test_zipfile/_path/test_path.py
+++ b/Lib/test/test_zipfile/_path/test_path.py
@@ -1,6 +1,6 @@
+import contextlib
import io
import itertools
-import contextlib
import pathlib
import pickle
import stat
@@ -9,12 +9,11 @@ import unittest
import zipfile
import zipfile._path
-from test.support.os_helper import temp_dir, FakePath
+from test.support.os_helper import FakePath, temp_dir
from ._functools import compose
from ._itertools import Counter
-
-from ._test_params import parameterize, Invoked
+from ._test_params import Invoked, parameterize
class jaraco:
@@ -193,10 +192,10 @@ class TestPath(unittest.TestCase):
"""EncodingWarning must blame the read_text and open calls."""
assert sys.flags.warn_default_encoding
root = zipfile.Path(alpharep)
- with self.assertWarns(EncodingWarning) as wc:
+ with self.assertWarns(EncodingWarning) as wc: # noqa: F821 (astral-sh/ruff#13296)
root.joinpath("a.txt").read_text()
assert __file__ == wc.filename
- with self.assertWarns(EncodingWarning) as wc:
+ with self.assertWarns(EncodingWarning) as wc: # noqa: F821 (astral-sh/ruff#13296)
root.joinpath("a.txt").open("r").close()
assert __file__ == wc.filename
@@ -365,6 +364,17 @@ class TestPath(unittest.TestCase):
assert root.name == 'alpharep.zip' == root.filename.name
@pass_alpharep
+ def test_root_on_disk(self, alpharep):
+ """
+ The name/stem of the root should match the zipfile on disk.
+
+ This condition must hold across platforms.
+ """
+ root = zipfile.Path(self.zipfile_ondisk(alpharep))
+ assert root.name == 'alpharep.zip' == root.filename.name
+ assert root.stem == 'alpharep' == root.filename.stem
+
+ @pass_alpharep
def test_suffix(self, alpharep):
"""
The suffix of the root should be the suffix of the zipfile.
diff --git a/Lib/test/test_zipfile/_path/write-alpharep.py b/Lib/test/test_zipfile/_path/write-alpharep.py
index 48c09b53717..7418391abad 100644
--- a/Lib/test/test_zipfile/_path/write-alpharep.py
+++ b/Lib/test/test_zipfile/_path/write-alpharep.py
@@ -1,4 +1,3 @@
from . import test_path
-
__name__ == '__main__' and test_path.build_alpharep_fixture().extractall('alpharep')
diff --git a/Lib/test/test_zipfile/test_core.py b/Lib/test/test_zipfile/test_core.py
index 7c8a82d821a..ada96813709 100644
--- a/Lib/test/test_zipfile/test_core.py
+++ b/Lib/test/test_zipfile/test_core.py
@@ -23,11 +23,13 @@ from test import archiver_tests
from test.support import script_helper, os_helper
from test.support import (
findfile, requires_zlib, requires_bz2, requires_lzma,
- captured_stdout, captured_stderr, requires_subprocess,
+ requires_zstd, captured_stdout, captured_stderr, requires_subprocess,
+ cpython_only
)
from test.support.os_helper import (
TESTFN, unlink, rmtree, temp_dir, temp_cwd, fd_count, FakePath
)
+from test.support.import_helper import ensure_lazy_imports
TESTFN2 = TESTFN + "2"
@@ -49,6 +51,13 @@ def get_files(test):
yield f
test.assertFalse(f.closed)
+
+class LazyImportTest(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("zipfile", {"typing"})
+
+
class AbstractTestsWithSourceFile:
@classmethod
def setUpClass(cls):
@@ -693,6 +702,10 @@ class LzmaTestsWithSourceFile(AbstractTestsWithSourceFile,
unittest.TestCase):
compression = zipfile.ZIP_LZMA
+@requires_zstd()
+class ZstdTestsWithSourceFile(AbstractTestsWithSourceFile,
+ unittest.TestCase):
+ compression = zipfile.ZIP_ZSTANDARD
class AbstractTestZip64InSmallFiles:
# These tests test the ZIP64 functionality without using large files,
@@ -1270,6 +1283,10 @@ class LzmaTestZip64InSmallFiles(AbstractTestZip64InSmallFiles,
unittest.TestCase):
compression = zipfile.ZIP_LZMA
+@requires_zstd()
+class ZstdTestZip64InSmallFiles(AbstractTestZip64InSmallFiles,
+ unittest.TestCase):
+ compression = zipfile.ZIP_ZSTANDARD
class AbstractWriterTests:
@@ -1339,6 +1356,9 @@ class Bzip2WriterTests(AbstractWriterTests, unittest.TestCase):
class LzmaWriterTests(AbstractWriterTests, unittest.TestCase):
compression = zipfile.ZIP_LZMA
+@requires_zstd()
+class ZstdWriterTests(AbstractWriterTests, unittest.TestCase):
+ compression = zipfile.ZIP_ZSTANDARD
class PyZipFileTests(unittest.TestCase):
def assertCompiledIn(self, name, namelist):
@@ -1971,6 +1991,25 @@ class OtherTests(unittest.TestCase):
self.assertFalse(zipfile.is_zipfile(fp))
fp.seek(0, 0)
self.assertFalse(zipfile.is_zipfile(fp))
+ # - passing non-zipfile with ZIP header elements
+ # data created using pyPNG like so:
+ # d = [(ord('P'), ord('K'), 5, 6), (ord('P'), ord('K'), 6, 6)]
+ # w = png.Writer(1,2,alpha=True,compression=0)
+ # f = open('onepix.png', 'wb')
+ # w.write(f, d)
+ # w.close()
+ data = (b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00"
+ b"\x00\x02\x08\x06\x00\x00\x00\x99\x81\xb6'\x00\x00\x00\x15I"
+ b"DATx\x01\x01\n\x00\xf5\xff\x00PK\x05\x06\x00PK\x06\x06\x07"
+ b"\xac\x01N\xc6|a\r\x00\x00\x00\x00IEND\xaeB`\x82")
+ # - passing a filename
+ with open(TESTFN, "wb") as fp:
+ fp.write(data)
+ self.assertFalse(zipfile.is_zipfile(TESTFN))
+ # - passing a file-like object
+ fp = io.BytesIO()
+ fp.write(data)
+ self.assertFalse(zipfile.is_zipfile(fp))
def test_damaged_zipfile(self):
"""Check that zipfiles with missing bytes at the end raise BadZipFile."""
@@ -2669,6 +2708,17 @@ class LzmaBadCrcTests(AbstractBadCrcTests, unittest.TestCase):
b'ePK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00'
b'\x00>\x00\x00\x00\x00\x00')
+@requires_zstd()
+class ZstdBadCrcTests(AbstractBadCrcTests, unittest.TestCase):
+ compression = zipfile.ZIP_ZSTANDARD
+ zip_with_bad_crc = (
+ b'PK\x03\x04?\x00\x00\x00]\x00\x00\x00!\x00V\xb1\x17J\x14\x00'
+ b'\x00\x00\x0b\x00\x00\x00\x05\x00\x00\x00afile(\xb5/\xfd\x00'
+ b'XY\x00\x00Hello WorldPK\x01\x02?\x03?\x00\x00\x00]\x00\x00\x00'
+ b'!\x00V\xb0\x17J\x14\x00\x00\x00\x0b\x00\x00\x00\x05\x00\x00\x00'
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00\x00afilePK'
+ b'\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00\x007\x00\x00\x00'
+ b'\x00\x00')
class DecryptionTests(unittest.TestCase):
"""Check that ZIP decryption works. Since the library does not
@@ -2896,6 +2946,10 @@ class LzmaTestsWithRandomBinaryFiles(AbstractTestsWithRandomBinaryFiles,
unittest.TestCase):
compression = zipfile.ZIP_LZMA
+@requires_zstd()
+class ZstdTestsWithRandomBinaryFiles(AbstractTestsWithRandomBinaryFiles,
+ unittest.TestCase):
+ compression = zipfile.ZIP_ZSTANDARD
# Provide the tell() method but not seek()
class Tellable:
@@ -3144,7 +3198,7 @@ class TestWithDirectory(unittest.TestCase):
with zipfile.ZipFile(TESTFN, "w") as zipf:
zipf.write(dirpath)
zinfo = zipf.filelist[0]
- self.assertTrue(zinfo.filename.endswith("/x/"))
+ self.assertEndsWith(zinfo.filename, "/x/")
self.assertEqual(zinfo.external_attr, (mode << 16) | 0x10)
zipf.write(dirpath, "y")
zinfo = zipf.filelist[1]
@@ -3152,7 +3206,7 @@ class TestWithDirectory(unittest.TestCase):
self.assertEqual(zinfo.external_attr, (mode << 16) | 0x10)
with zipfile.ZipFile(TESTFN, "r") as zipf:
zinfo = zipf.filelist[0]
- self.assertTrue(zinfo.filename.endswith("/x/"))
+ self.assertEndsWith(zinfo.filename, "/x/")
self.assertEqual(zinfo.external_attr, (mode << 16) | 0x10)
zinfo = zipf.filelist[1]
self.assertTrue(zinfo.filename, "y/")
@@ -3172,7 +3226,7 @@ class TestWithDirectory(unittest.TestCase):
self.assertEqual(zinfo.external_attr, (0o40775 << 16) | 0x10)
with zipfile.ZipFile(TESTFN, "r") as zipf:
zinfo = zipf.filelist[0]
- self.assertTrue(zinfo.filename.endswith("x/"))
+ self.assertEndsWith(zinfo.filename, "x/")
self.assertEqual(zinfo.external_attr, (0o40775 << 16) | 0x10)
target = os.path.join(TESTFN2, "target")
os.mkdir(target)
@@ -3607,7 +3661,7 @@ class EncodedMetadataTests(unittest.TestCase):
except OSError:
pass
except UnicodeEncodeError:
- self.skipTest(f'cannot encode file name {fn!r}')
+ self.skipTest(f'cannot encode file name {fn!a}')
zipfile.main(["--metadata-encoding=shift_jis", "-e", TESTFN, TESTFN2])
listing = os.listdir(TESTFN2)
diff --git a/Lib/test/test_zipimport.py b/Lib/test/test_zipimport.py
index 1f288c8b45d..b5b4acf5f85 100644
--- a/Lib/test/test_zipimport.py
+++ b/Lib/test/test_zipimport.py
@@ -835,11 +835,11 @@ class UncompressedZipImportTestCase(ImportHooksBaseTestCase):
s = io.StringIO()
print_tb(tb, 1, s)
- self.assertTrue(s.getvalue().endswith(
+ self.assertEndsWith(s.getvalue(),
' def do_raise(): raise TypeError\n'
'' if support.has_no_debug_ranges() else
' ^^^^^^^^^^^^^^^\n'
- ))
+ )
else:
raise AssertionError("This ought to be impossible")
diff --git a/Lib/test/test_zlib.py b/Lib/test/test_zlib.py
index 4d97fe56f3a..c57ab51eca1 100644
--- a/Lib/test/test_zlib.py
+++ b/Lib/test/test_zlib.py
@@ -119,6 +119,114 @@ class ChecksumTestCase(unittest.TestCase):
self.assertEqual(binascii.crc32(b'spam'), zlib.crc32(b'spam'))
+class ChecksumCombineMixin:
+ """Mixin class for testing checksum combination."""
+
+ N = 1000
+ default_iv: int
+
+ def parse_iv(self, iv):
+ """Parse an IV value.
+
+ - The default IV is returned if *iv* is None.
+ - A random IV is returned if *iv* is -1.
+ - Otherwise, *iv* is returned as is.
+ """
+ if iv is None:
+ return self.default_iv
+ if iv == -1:
+ return random.randint(1, 0x80000000)
+ return iv
+
+ def checksum(self, data, init=None):
+ """Compute the checksum of data with a given initial value.
+
+ The *init* value is parsed by ``parse_iv``.
+ """
+ iv = self.parse_iv(init)
+ return self._checksum(data, iv)
+
+ def _checksum(self, data, init):
+ raise NotImplementedError
+
+ def combine(self, a, b, blen):
+ """Combine two checksums together."""
+ raise NotImplementedError
+
+ def get_random_data(self, data_len, *, iv=None):
+ """Get a triplet (data, iv, checksum)."""
+ data = random.randbytes(data_len)
+ init = self.parse_iv(iv)
+ checksum = self.checksum(data, init)
+ return data, init, checksum
+
+ def test_combine_empty(self):
+ for _ in range(self.N):
+ a, iv, checksum = self.get_random_data(32, iv=-1)
+ res = self.combine(iv, self.checksum(a), len(a))
+ self.assertEqual(res, checksum)
+
+ def test_combine_no_iv(self):
+ for _ in range(self.N):
+ a, _, chk_a = self.get_random_data(32)
+ b, _, chk_b = self.get_random_data(64)
+ res = self.combine(chk_a, chk_b, len(b))
+ self.assertEqual(res, self.checksum(a + b))
+
+ def test_combine_no_iv_invalid_length(self):
+ a, _, chk_a = self.get_random_data(32)
+ b, _, chk_b = self.get_random_data(64)
+ checksum = self.checksum(a + b)
+ for invalid_len in [1, len(a), 48, len(b) + 1, 191]:
+ invalid_res = self.combine(chk_a, chk_b, invalid_len)
+ self.assertNotEqual(invalid_res, checksum)
+
+ self.assertRaises(TypeError, self.combine, 0, 0, "len")
+
+ def test_combine_with_iv(self):
+ for _ in range(self.N):
+ a, iv_a, chk_a_with_iv = self.get_random_data(32, iv=-1)
+ chk_a_no_iv = self.checksum(a)
+ b, iv_b, chk_b_with_iv = self.get_random_data(64, iv=-1)
+ chk_b_no_iv = self.checksum(b)
+
+ # We can represent c = COMBINE(CHK(a, iv_a), CHK(b, iv_b)) as:
+ #
+ # c = CHK(CHK(b'', iv_a) + CHK(a) + CHK(b'', iv_b) + CHK(b))
+ # = COMBINE(
+ # COMBINE(CHK(b'', iv_a), CHK(a)),
+ # COMBINE(CHK(b'', iv_b), CHK(b)),
+ # )
+ # = COMBINE(COMBINE(iv_a, CHK(a)), COMBINE(iv_b, CHK(b)))
+ tmp0 = self.combine(iv_a, chk_a_no_iv, len(a))
+ tmp1 = self.combine(iv_b, chk_b_no_iv, len(b))
+ expected = self.combine(tmp0, tmp1, len(b))
+ checksum = self.combine(chk_a_with_iv, chk_b_with_iv, len(b))
+ self.assertEqual(checksum, expected)
+
+
+class CRC32CombineTestCase(ChecksumCombineMixin, unittest.TestCase):
+
+ default_iv = 0
+
+ def _checksum(self, data, init):
+ return zlib.crc32(data, init)
+
+ def combine(self, a, b, blen):
+ return zlib.crc32_combine(a, b, blen)
+
+
+class Adler32CombineTestCase(ChecksumCombineMixin, unittest.TestCase):
+
+ default_iv = 1
+
+ def _checksum(self, data, init):
+ return zlib.adler32(data, init)
+
+ def combine(self, a, b, blen):
+ return zlib.adler32_combine(a, b, blen)
+
+
# Issue #10276 - check that inputs >=4 GiB are handled correctly.
class ChecksumBigBufferTestCase(unittest.TestCase):
diff --git a/Lib/test/test_zoneinfo/test_zoneinfo.py b/Lib/test/test_zoneinfo/test_zoneinfo.py
index b0dbd768cab..f313e394f49 100644
--- a/Lib/test/test_zoneinfo/test_zoneinfo.py
+++ b/Lib/test/test_zoneinfo/test_zoneinfo.py
@@ -237,7 +237,6 @@ class ZoneInfoTest(TzPathUserMixin, ZoneInfoTestBase):
"../zoneinfo/America/Los_Angeles", # Traverses above TZPATH
"America/../America/Los_Angeles", # Not normalized
"America/./Los_Angeles",
- "",
]
for bad_key in bad_keys:
@@ -1916,8 +1915,8 @@ class ExtensionBuiltTest(unittest.TestCase):
def test_cache_location(self):
# The pure Python version stores caches on attributes, but the C
# extension stores them in C globals (at least for now)
- self.assertFalse(hasattr(c_zoneinfo.ZoneInfo, "_weak_cache"))
- self.assertTrue(hasattr(py_zoneinfo.ZoneInfo, "_weak_cache"))
+ self.assertNotHasAttr(c_zoneinfo.ZoneInfo, "_weak_cache")
+ self.assertHasAttr(py_zoneinfo.ZoneInfo, "_weak_cache")
def test_gc_tracked(self):
import gc
diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py
new file mode 100644
index 00000000000..d4c28aed38e
--- /dev/null
+++ b/Lib/test/test_zstd.py
@@ -0,0 +1,2794 @@
+import array
+import gc
+import io
+import pathlib
+import random
+import re
+import os
+import unittest
+import tempfile
+import threading
+
+from test.support.import_helper import import_module
+from test.support import threading_helper
+from test.support import _1M
+
+_zstd = import_module("_zstd")
+zstd = import_module("compression.zstd")
+
+from compression.zstd import (
+ open,
+ compress,
+ decompress,
+ ZstdCompressor,
+ ZstdDecompressor,
+ ZstdDict,
+ ZstdError,
+ zstd_version,
+ zstd_version_info,
+ COMPRESSION_LEVEL_DEFAULT,
+ get_frame_info,
+ get_frame_size,
+ finalize_dict,
+ train_dict,
+ CompressionParameter,
+ DecompressionParameter,
+ Strategy,
+ ZstdFile,
+)
+
+_1K = 1024
+_130_1K = 130 * _1K
+DICT_SIZE1 = 3*_1K
+
+DAT_130K_D = None
+DAT_130K_C = None
+
+DECOMPRESSED_DAT = None
+COMPRESSED_DAT = None
+
+DECOMPRESSED_100_PLUS_32KB = None
+COMPRESSED_100_PLUS_32KB = None
+
+SKIPPABLE_FRAME = None
+
+THIS_FILE_BYTES = None
+THIS_FILE_STR = None
+COMPRESSED_THIS_FILE = None
+
+COMPRESSED_BOGUS = None
+
+SAMPLES = None
+
+TRAINED_DICT = None
+
+SUPPORT_MULTITHREADING = False
+
+C_INT_MIN = -(2**31)
+C_INT_MAX = (2**31) - 1
+
+
+def setUpModule():
+ global SUPPORT_MULTITHREADING
+ SUPPORT_MULTITHREADING = CompressionParameter.nb_workers.bounds() != (0, 0)
+ # uncompressed size is 130KB, more than a single zstd block;
+ # the frame epilogue carries a 4-byte checksum.
+ global DAT_130K_D
+ DAT_130K_D = bytes([random.randint(0, 127) for _ in range(130*_1K)])
+
+ global DAT_130K_C
+ DAT_130K_C = compress(DAT_130K_D, options={CompressionParameter.checksum_flag:1})
+
+ global DECOMPRESSED_DAT
+ DECOMPRESSED_DAT = b'abcdefg123456' * 1000
+
+ global COMPRESSED_DAT
+ COMPRESSED_DAT = compress(DECOMPRESSED_DAT)
+
+ global DECOMPRESSED_100_PLUS_32KB
+ DECOMPRESSED_100_PLUS_32KB = b'a' * (100 + 32*_1K)
+
+ global COMPRESSED_100_PLUS_32KB
+ COMPRESSED_100_PLUS_32KB = compress(DECOMPRESSED_100_PLUS_32KB)
+
+ global SKIPPABLE_FRAME
+ SKIPPABLE_FRAME = (0x184D2A50).to_bytes(4, byteorder='little') + \
+ (32*_1K).to_bytes(4, byteorder='little') + \
+ b'a' * (32*_1K)
+
+ global THIS_FILE_BYTES, THIS_FILE_STR
+ with io.open(os.path.abspath(__file__), 'rb') as f:
+ THIS_FILE_BYTES = f.read()
+ THIS_FILE_BYTES = re.sub(rb'\r?\n', rb'\n', THIS_FILE_BYTES)
+ THIS_FILE_STR = THIS_FILE_BYTES.decode('utf-8')
+
+ global COMPRESSED_THIS_FILE
+ COMPRESSED_THIS_FILE = compress(THIS_FILE_BYTES)
+
+ global COMPRESSED_BOGUS
+ COMPRESSED_BOGUS = DECOMPRESSED_DAT
+
+ # dict data
+ words = [b'red', b'green', b'yellow', b'black', b'withe', b'blue',
+ b'lilac', b'purple', b'navy', b'glod', b'silver', b'olive',
+ b'dog', b'cat', b'tiger', b'lion', b'fish', b'bird']
+ lst = []
+ for i in range(300):
+ sample = [b'%s = %d' % (random.choice(words), random.randrange(100))
+ for j in range(20)]
+ sample = b'\n'.join(sample)
+
+ lst.append(sample)
+ global SAMPLES
+ SAMPLES = lst
+ assert len(SAMPLES) > 10
+
+ global TRAINED_DICT
+ TRAINED_DICT = train_dict(SAMPLES, 3*_1K)
+ assert len(TRAINED_DICT.dict_content) <= 3*_1K
+
+
+class FunctionsTestCase(unittest.TestCase):
+
+ def test_version(self):
+ s = ".".join((str(i) for i in zstd_version_info))
+ self.assertEqual(s, zstd_version)
+
+ def test_compressionLevel_values(self):
+ min, max = CompressionParameter.compression_level.bounds()
+ self.assertIs(type(COMPRESSION_LEVEL_DEFAULT), int)
+ self.assertIs(type(min), int)
+ self.assertIs(type(max), int)
+ self.assertLess(min, max)
+
+ def test_roundtrip_default(self):
+ raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6]
+ dat1 = compress(raw_dat)
+ dat2 = decompress(dat1)
+ self.assertEqual(dat2, raw_dat)
+
+ def test_roundtrip_level(self):
+ raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6]
+ level_min, level_max = CompressionParameter.compression_level.bounds()
+
+ for level in range(max(-20, level_min), level_max + 1):
+ dat1 = compress(raw_dat, level)
+ dat2 = decompress(dat1)
+ self.assertEqual(dat2, raw_dat)
+
+ def test_get_frame_info(self):
+ # no dict
+ info = get_frame_info(COMPRESSED_100_PLUS_32KB[:20])
+ self.assertEqual(info.decompressed_size, 32 * _1K + 100)
+ self.assertEqual(info.dictionary_id, 0)
+
+ # use dict
+ dat = compress(b"a" * 345, zstd_dict=TRAINED_DICT)
+ info = get_frame_info(dat)
+ self.assertEqual(info.decompressed_size, 345)
+ self.assertEqual(info.dictionary_id, TRAINED_DICT.dict_id)
+
+ with self.assertRaisesRegex(ZstdError, "not less than the frame header"):
+ get_frame_info(b"aaaaaaaaaaaaaa")
+
+ def test_get_frame_size(self):
+ size = get_frame_size(COMPRESSED_100_PLUS_32KB)
+ self.assertEqual(size, len(COMPRESSED_100_PLUS_32KB))
+
+ with self.assertRaisesRegex(ZstdError, "not less than this complete frame"):
+ get_frame_size(b"aaaaaaaaaaaaaa")
+
+ def test_decompress_2x130_1K(self):
+ decompressed_size = get_frame_info(DAT_130K_C).decompressed_size
+ self.assertEqual(decompressed_size, _130_1K)
+
+ dat = decompress(DAT_130K_C + DAT_130K_C)
+ self.assertEqual(len(dat), 2 * _130_1K)
+
+
+class CompressorTestCase(unittest.TestCase):
+
+ def test_simple_compress_bad_args(self):
+ # ZstdCompressor
+ self.assertRaises(TypeError, ZstdCompressor, [])
+ self.assertRaises(TypeError, ZstdCompressor, level=3.14)
+ self.assertRaises(TypeError, ZstdCompressor, level="abc")
+ self.assertRaises(TypeError, ZstdCompressor, options=b"abc")
+
+ self.assertRaises(TypeError, ZstdCompressor, zstd_dict=123)
+ self.assertRaises(TypeError, ZstdCompressor, zstd_dict=b"abcd1234")
+ self.assertRaises(TypeError, ZstdCompressor, zstd_dict={1: 2, 3: 4})
+
+ # valid range for compression level is [-(1<<17), 22]
+ msg = r'illegal compression level {}; the valid range is \[-?\d+, -?\d+\]'
+ with self.assertRaisesRegex(ValueError, msg.format(C_INT_MAX)):
+ ZstdCompressor(C_INT_MAX)
+ with self.assertRaisesRegex(ValueError, msg.format(C_INT_MIN)):
+ ZstdCompressor(C_INT_MIN)
+ msg = r'illegal compression level; the valid range is \[-?\d+, -?\d+\]'
+ with self.assertRaisesRegex(ValueError, msg):
+ ZstdCompressor(level=-(2**1000))
+ with self.assertRaisesRegex(ValueError, msg):
+ ZstdCompressor(level=2**1000)
+
+ with self.assertRaises(ValueError):
+ ZstdCompressor(options={CompressionParameter.window_log: 100})
+ with self.assertRaises(ValueError):
+ ZstdCompressor(options={3333: 100})
+
+ # Method bad arguments
+ zc = ZstdCompressor()
+ self.assertRaises(TypeError, zc.compress)
+ self.assertRaises((TypeError, ValueError), zc.compress, b"foo", b"bar")
+ self.assertRaises(TypeError, zc.compress, "str")
+ self.assertRaises((TypeError, ValueError), zc.flush, b"foo")
+ self.assertRaises(TypeError, zc.flush, b"blah", 1)
+
+ self.assertRaises(ValueError, zc.compress, b'', -1)
+ self.assertRaises(ValueError, zc.compress, b'', 3)
+ self.assertRaises(ValueError, zc.flush, zc.CONTINUE) # 0
+ self.assertRaises(ValueError, zc.flush, 3)
+
+ zc.compress(b'')
+ zc.compress(b'', zc.CONTINUE)
+ zc.compress(b'', zc.FLUSH_BLOCK)
+ zc.compress(b'', zc.FLUSH_FRAME)
+ empty = zc.flush()
+ zc.flush(zc.FLUSH_BLOCK)
+ zc.flush(zc.FLUSH_FRAME)
+
+ def test_compress_parameters(self):
+ d = {CompressionParameter.compression_level : 10,
+
+ CompressionParameter.window_log : 12,
+ CompressionParameter.hash_log : 10,
+ CompressionParameter.chain_log : 12,
+ CompressionParameter.search_log : 12,
+ CompressionParameter.min_match : 4,
+ CompressionParameter.target_length : 12,
+ CompressionParameter.strategy : Strategy.lazy,
+
+ CompressionParameter.enable_long_distance_matching : 1,
+ CompressionParameter.ldm_hash_log : 12,
+ CompressionParameter.ldm_min_match : 11,
+ CompressionParameter.ldm_bucket_size_log : 5,
+ CompressionParameter.ldm_hash_rate_log : 12,
+
+ CompressionParameter.content_size_flag : 1,
+ CompressionParameter.checksum_flag : 1,
+ CompressionParameter.dict_id_flag : 0,
+
+ CompressionParameter.nb_workers : 2 if SUPPORT_MULTITHREADING else 0,
+ CompressionParameter.job_size : 5*_1M if SUPPORT_MULTITHREADING else 0,
+ CompressionParameter.overlap_log : 9 if SUPPORT_MULTITHREADING else 0,
+ }
+ ZstdCompressor(options=d)
+
+ d1 = d.copy()
+ # larger than signed int
+ d1[CompressionParameter.ldm_bucket_size_log] = C_INT_MAX
+ with self.assertRaises(ValueError):
+ ZstdCompressor(options=d1)
+ # smaller than signed int
+ d1[CompressionParameter.ldm_bucket_size_log] = C_INT_MIN
+ with self.assertRaises(ValueError):
+ ZstdCompressor(options=d1)
+
+ # out of bounds compression level
+ level_min, level_max = CompressionParameter.compression_level.bounds()
+ with self.assertRaises(ValueError):
+ compress(b'', level_max+1)
+ with self.assertRaises(ValueError):
+ compress(b'', level_min-1)
+ with self.assertRaises(ValueError):
+ compress(b'', 2**1000)
+ with self.assertRaises(ValueError):
+ compress(b'', -(2**1000))
+ with self.assertRaises(ValueError):
+ compress(b'', options={
+ CompressionParameter.compression_level: level_max+1})
+ with self.assertRaises(ValueError):
+ compress(b'', options={
+ CompressionParameter.compression_level: level_min-1})
+
+ # zstd lib doesn't support MT compression
+ if not SUPPORT_MULTITHREADING:
+ with self.assertRaises(ValueError):
+ ZstdCompressor(options={CompressionParameter.nb_workers:4})
+ with self.assertRaises(ValueError):
+ ZstdCompressor(options={CompressionParameter.job_size:4})
+ with self.assertRaises(ValueError):
+ ZstdCompressor(options={CompressionParameter.overlap_log:4})
+
+ # out of bounds error msg
+ option = {CompressionParameter.window_log:100}
+ with self.assertRaisesRegex(
+ ValueError,
+ "compression parameter 'window_log' received an illegal value 100; "
+ r'the valid range is \[-?\d+, -?\d+\]',
+ ):
+ compress(b'', options=option)
+
+ def test_unknown_compression_parameter(self):
+ KEY = 100001234
+ option = {CompressionParameter.compression_level: 10,
+ KEY: 200000000}
+ pattern = rf"invalid compression parameter 'unknown parameter \(key {KEY}\)'"
+ with self.assertRaisesRegex(ValueError, pattern):
+ ZstdCompressor(options=option)
+
+ @unittest.skipIf(not SUPPORT_MULTITHREADING,
+ "zstd build doesn't support multi-threaded compression")
+ def test_zstd_multithread_compress(self):
+ size = 40*_1M
+ b = THIS_FILE_BYTES * (size // len(THIS_FILE_BYTES))
+
+ options = {CompressionParameter.compression_level : 4,
+ CompressionParameter.nb_workers : 2}
+
+ # compress()
+ dat1 = compress(b, options=options)
+ dat2 = decompress(dat1)
+ self.assertEqual(dat2, b)
+
+ # ZstdCompressor
+ c = ZstdCompressor(options=options)
+ dat1 = c.compress(b, c.CONTINUE)
+ dat2 = c.compress(b, c.FLUSH_BLOCK)
+ dat3 = c.compress(b, c.FLUSH_FRAME)
+ dat4 = decompress(dat1+dat2+dat3)
+ self.assertEqual(dat4, b * 3)
+
+ # ZstdFile
+ with ZstdFile(io.BytesIO(), 'w', options=options) as f:
+ f.write(b)
+
+ def test_compress_flushblock(self):
+ point = len(THIS_FILE_BYTES) // 2
+
+ c = ZstdCompressor()
+ self.assertEqual(c.last_mode, c.FLUSH_FRAME)
+ dat1 = c.compress(THIS_FILE_BYTES[:point])
+ self.assertEqual(c.last_mode, c.CONTINUE)
+ dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_BLOCK)
+ self.assertEqual(c.last_mode, c.FLUSH_BLOCK)
+ dat2 = c.flush()
+ pattern = "Compressed data ended before the end-of-stream marker"
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(dat1)
+
+ dat3 = decompress(dat1 + dat2)
+
+ self.assertEqual(dat3, THIS_FILE_BYTES)
+
+ def test_compress_flushframe(self):
+ # test compress & decompress
+ point = len(THIS_FILE_BYTES) // 2
+
+ c = ZstdCompressor()
+
+ dat1 = c.compress(THIS_FILE_BYTES[:point])
+ self.assertEqual(c.last_mode, c.CONTINUE)
+
+ dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_FRAME)
+ self.assertEqual(c.last_mode, c.FLUSH_FRAME)
+
+ nt = get_frame_info(dat1)
+ self.assertEqual(nt.decompressed_size, None) # no content size
+
+ dat2 = decompress(dat1)
+
+ self.assertEqual(dat2, THIS_FILE_BYTES)
+
+ # single .FLUSH_FRAME mode has content size
+ c = ZstdCompressor()
+ dat = c.compress(THIS_FILE_BYTES, mode=c.FLUSH_FRAME)
+ self.assertEqual(c.last_mode, c.FLUSH_FRAME)
+
+ nt = get_frame_info(dat)
+ self.assertEqual(nt.decompressed_size, len(THIS_FILE_BYTES))
+
+ def test_compress_empty(self):
+ # output empty content frame
+ self.assertNotEqual(compress(b''), b'')
+
+ c = ZstdCompressor()
+ self.assertNotEqual(c.compress(b'', c.FLUSH_FRAME), b'')
+
+ def test_set_pledged_input_size(self):
+ DAT = DECOMPRESSED_100_PLUS_32KB
+ CHUNK_SIZE = len(DAT) // 3
+
+ # wrong value
+ c = ZstdCompressor()
+ with self.assertRaisesRegex(ValueError,
+ r'should be a positive int less than \d+'):
+ c.set_pledged_input_size(-300)
+ # overflow
+ with self.assertRaisesRegex(ValueError,
+ r'should be a positive int less than \d+'):
+ c.set_pledged_input_size(2**64)
+ # ZSTD_CONTENTSIZE_ERROR is invalid
+ with self.assertRaisesRegex(ValueError,
+ r'should be a positive int less than \d+'):
+ c.set_pledged_input_size(2**64-2)
+ # ZSTD_CONTENTSIZE_UNKNOWN should use None
+ with self.assertRaisesRegex(ValueError,
+ r'should be a positive int less than \d+'):
+ c.set_pledged_input_size(2**64-1)
+
+ # check valid values are settable
+ c.set_pledged_input_size(2**63)
+ c.set_pledged_input_size(2**64-3)
+
+ # check that zero means empty frame
+ c = ZstdCompressor(level=1)
+ c.set_pledged_input_size(0)
+ c.compress(b'')
+ dat = c.flush()
+ ret = get_frame_info(dat)
+ self.assertEqual(ret.decompressed_size, 0)
+
+
+ # wrong mode
+ c = ZstdCompressor(level=1)
+ c.compress(b'123456')
+ self.assertEqual(c.last_mode, c.CONTINUE)
+ with self.assertRaisesRegex(ValueError,
+ r'last_mode == FLUSH_FRAME'):
+ c.set_pledged_input_size(300)
+
+ # None value
+ c = ZstdCompressor(level=1)
+ c.set_pledged_input_size(None)
+ dat = c.compress(DAT) + c.flush()
+
+ ret = get_frame_info(dat)
+ self.assertEqual(ret.decompressed_size, None)
+
+ # correct value
+ c = ZstdCompressor(level=1)
+ c.set_pledged_input_size(len(DAT))
+
+ chunks = []
+ posi = 0
+ while posi < len(DAT):
+ dat = c.compress(DAT[posi:posi+CHUNK_SIZE])
+ posi += CHUNK_SIZE
+ chunks.append(dat)
+
+ dat = c.flush()
+ chunks.append(dat)
+ chunks = b''.join(chunks)
+
+ ret = get_frame_info(chunks)
+ self.assertEqual(ret.decompressed_size, len(DAT))
+ self.assertEqual(decompress(chunks), DAT)
+
+ c.set_pledged_input_size(len(DAT)) # the second frame
+ dat = c.compress(DAT) + c.flush()
+
+ ret = get_frame_info(dat)
+ self.assertEqual(ret.decompressed_size, len(DAT))
+ self.assertEqual(decompress(dat), DAT)
+
+ # not enough data
+ c = ZstdCompressor(level=1)
+ c.set_pledged_input_size(len(DAT)+1)
+
+ for start in range(0, len(DAT), CHUNK_SIZE):
+ end = min(start+CHUNK_SIZE, len(DAT))
+ _dat = c.compress(DAT[start:end])
+
+ with self.assertRaises(ZstdError):
+ c.flush()
+
+ # too much data
+ c = ZstdCompressor(level=1)
+ c.set_pledged_input_size(len(DAT))
+
+ for start in range(0, len(DAT), CHUNK_SIZE):
+ end = min(start+CHUNK_SIZE, len(DAT))
+ _dat = c.compress(DAT[start:end])
+
+ with self.assertRaises(ZstdError):
+ c.compress(b'extra', ZstdCompressor.FLUSH_FRAME)
+
+ # content size not set if content_size_flag == 0
+ c = ZstdCompressor(options={CompressionParameter.content_size_flag: 0})
+ c.set_pledged_input_size(10)
+ dat1 = c.compress(b"hello")
+ dat2 = c.compress(b"world")
+ dat3 = c.flush()
+ frame_data = get_frame_info(dat1 + dat2 + dat3)
+ self.assertIsNone(frame_data.decompressed_size)
+
+
+class DecompressorTestCase(unittest.TestCase):
+
+ def test_simple_decompress_bad_args(self):
+ # ZstdDecompressor
+ self.assertRaises(TypeError, ZstdDecompressor, ())
+ self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=123)
+ self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=b'abc')
+ self.assertRaises(TypeError, ZstdDecompressor, zstd_dict={1:2, 3:4})
+
+ self.assertRaises(TypeError, ZstdDecompressor, options=123)
+ self.assertRaises(TypeError, ZstdDecompressor, options='abc')
+ self.assertRaises(TypeError, ZstdDecompressor, options=b'abc')
+
+ with self.assertRaises(ValueError):
+ ZstdDecompressor(options={C_INT_MAX: 100})
+ with self.assertRaises(ValueError):
+ ZstdDecompressor(options={C_INT_MIN: 100})
+ with self.assertRaises(ValueError):
+ ZstdDecompressor(options={0: C_INT_MAX})
+ with self.assertRaises(OverflowError):
+ ZstdDecompressor(options={2**1000: 100})
+ with self.assertRaises(OverflowError):
+ ZstdDecompressor(options={-(2**1000): 100})
+ with self.assertRaises(OverflowError):
+ ZstdDecompressor(options={0: -(2**1000)})
+
+ with self.assertRaises(ValueError):
+ ZstdDecompressor(options={DecompressionParameter.window_log_max: 100})
+ with self.assertRaises(ValueError):
+ ZstdDecompressor(options={3333: 100})
+
+ empty = compress(b'')
+ lzd = ZstdDecompressor()
+ self.assertRaises(TypeError, lzd.decompress)
+ self.assertRaises(TypeError, lzd.decompress, b"foo", b"bar")
+ self.assertRaises(TypeError, lzd.decompress, "str")
+ lzd.decompress(empty)
+
+ def test_decompress_parameters(self):
+ d = {DecompressionParameter.window_log_max : 15}
+ ZstdDecompressor(options=d)
+
+ d1 = d.copy()
+ # larger than signed int
+ d1[DecompressionParameter.window_log_max] = 2**1000
+ with self.assertRaises(OverflowError):
+ ZstdDecompressor(None, d1)
+ # smaller than signed int
+ d1[DecompressionParameter.window_log_max] = -(2**1000)
+ with self.assertRaises(OverflowError):
+ ZstdDecompressor(None, d1)
+
+ d1[DecompressionParameter.window_log_max] = C_INT_MAX
+ with self.assertRaises(ValueError):
+ ZstdDecompressor(None, d1)
+ d1[DecompressionParameter.window_log_max] = C_INT_MIN
+ with self.assertRaises(ValueError):
+ ZstdDecompressor(None, d1)
+
+ # out of bounds error msg
+ options = {DecompressionParameter.window_log_max:100}
+ with self.assertRaisesRegex(
+ ValueError,
+ "decompression parameter 'window_log_max' received an illegal value 100; "
+ r'the valid range is \[-?\d+, -?\d+\]',
+ ):
+ decompress(b'', options=options)
+
+ # out of bounds decompression parameter
+ options[DecompressionParameter.window_log_max] = C_INT_MAX
+ with self.assertRaises(ValueError):
+ decompress(b'', options=options)
+ options[DecompressionParameter.window_log_max] = C_INT_MIN
+ with self.assertRaises(ValueError):
+ decompress(b'', options=options)
+ options[DecompressionParameter.window_log_max] = 2**1000
+ with self.assertRaises(OverflowError):
+ decompress(b'', options=options)
+ options[DecompressionParameter.window_log_max] = -(2**1000)
+ with self.assertRaises(OverflowError):
+ decompress(b'', options=options)
+
+ def test_unknown_decompression_parameter(self):
+ KEY = 100001234
+ options = {DecompressionParameter.window_log_max: DecompressionParameter.window_log_max.bounds()[1],
+ KEY: 200000000}
+ pattern = rf"invalid decompression parameter 'unknown parameter \(key {KEY}\)'"
+ with self.assertRaisesRegex(ValueError, pattern):
+ ZstdDecompressor(options=options)
+
+ def test_decompress_epilogue_flags(self):
+ # DAT_130K_C has a 4-byte checksum at the frame epilogue
+
+ # full unlimited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C)
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.needs_input)
+
+ with self.assertRaises(EOFError):
+ dat = d.decompress(b'')
+
+ # full limited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C, _130_1K)
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.needs_input)
+
+ with self.assertRaises(EOFError):
+ dat = d.decompress(b'', 0)
+
+ # [:-4] unlimited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-4])
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.needs_input)
+
+ dat = d.decompress(b'')
+ self.assertEqual(len(dat), 0)
+ self.assertTrue(d.needs_input)
+
+ # [:-4] limited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-4], _130_1K)
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.needs_input)
+
+ dat = d.decompress(b'', 0)
+ self.assertEqual(len(dat), 0)
+ self.assertFalse(d.needs_input)
+
+ # [:-3] unlimited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-3])
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.needs_input)
+
+ dat = d.decompress(b'')
+ self.assertEqual(len(dat), 0)
+ self.assertTrue(d.needs_input)
+
+ # [:-3] limited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-3], _130_1K)
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.needs_input)
+
+ dat = d.decompress(b'', 0)
+ self.assertEqual(len(dat), 0)
+ self.assertFalse(d.needs_input)
+
+ # [:-1] unlimited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-1])
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.needs_input)
+
+ dat = d.decompress(b'')
+ self.assertEqual(len(dat), 0)
+ self.assertTrue(d.needs_input)
+
+ # [:-1] limited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-1], _130_1K)
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.needs_input)
+
+ dat = d.decompress(b'', 0)
+ self.assertEqual(len(dat), 0)
+ self.assertFalse(d.needs_input)
+
+ def test_decompressor_arg(self):
+ zd = ZstdDict(b'12345678', is_raw=True)
+
+ with self.assertRaises(TypeError):
+ d = ZstdDecompressor(zstd_dict={})
+
+ with self.assertRaises(TypeError):
+ d = ZstdDecompressor(options=zd)
+
+ ZstdDecompressor()
+ ZstdDecompressor(zd, {})
+ ZstdDecompressor(zstd_dict=zd, options={DecompressionParameter.window_log_max:25})
+
+ def test_decompressor_1(self):
+ # empty
+ d = ZstdDecompressor()
+ dat = d.decompress(b'')
+
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+
+ # 130_1K full
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C)
+
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+
+ # 130_1K full, limit output
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C, _130_1K)
+
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+
+ # 130_1K, without the 4-byte checksum
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-4])
+
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+
+ # above, limit output
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-4], _130_1K)
+
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.eof)
+ self.assertFalse(d.needs_input)
+
+ # full, unused_data
+ TRAIL = b'89234893abcd'
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C + TRAIL, _130_1K)
+
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, TRAIL)
+
+ def test_decompressor_chunks_read_300(self):
+ TRAIL = b'89234893abcd'
+ DAT = DAT_130K_C + TRAIL
+ d = ZstdDecompressor()
+
+ bi = io.BytesIO(DAT)
+ lst = []
+ while True:
+ if d.needs_input:
+ dat = bi.read(300)
+ if not dat:
+ break
+ else:
+ raise Exception('should not get here')
+
+ ret = d.decompress(dat)
+ lst.append(ret)
+ if d.eof:
+ break
+
+ ret = b''.join(lst)
+
+ self.assertEqual(len(ret), _130_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data + bi.read(), TRAIL)
+
+ def test_decompressor_chunks_read_3(self):
+ TRAIL = b'89234893'
+ DAT = DAT_130K_C + TRAIL
+ d = ZstdDecompressor()
+
+ bi = io.BytesIO(DAT)
+ lst = []
+ while True:
+ if d.needs_input:
+ dat = bi.read(3)
+ if not dat:
+ break
+ else:
+ dat = b''
+
+ ret = d.decompress(dat, 1)
+ lst.append(ret)
+ if d.eof:
+ break
+
+ ret = b''.join(lst)
+
+ self.assertEqual(len(ret), _130_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data + bi.read(), TRAIL)
+
+
+ def test_decompress_empty(self):
+ with self.assertRaises(ZstdError):
+ decompress(b'')
+
+ d = ZstdDecompressor()
+ self.assertEqual(d.decompress(b''), b'')
+ self.assertFalse(d.eof)
+
+ def test_decompress_empty_content_frame(self):
+ DAT = compress(b'')
+ # decompress
+ self.assertGreaterEqual(len(DAT), 4)
+ self.assertEqual(decompress(DAT), b'')
+
+ with self.assertRaises(ZstdError):
+ decompress(DAT[:-1])
+
+ # ZstdDecompressor
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT)
+ self.assertEqual(dat, b'')
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT[:-1])
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+class DecompressorFlagsTestCase(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ options = {CompressionParameter.checksum_flag:1}
+ c = ZstdCompressor(options=options)
+
+ cls.DECOMPRESSED_42 = b'a'*42
+ cls.FRAME_42 = c.compress(cls.DECOMPRESSED_42, c.FLUSH_FRAME)
+
+ cls.DECOMPRESSED_60 = b'a'*60
+ cls.FRAME_60 = c.compress(cls.DECOMPRESSED_60, c.FLUSH_FRAME)
+
+ cls.FRAME_42_60 = cls.FRAME_42 + cls.FRAME_60
+ cls.DECOMPRESSED_42_60 = cls.DECOMPRESSED_42 + cls.DECOMPRESSED_60
+
+ cls._130_1K = 130*_1K
+
+ c = ZstdCompressor()
+ cls.UNKNOWN_FRAME_42 = c.compress(cls.DECOMPRESSED_42) + c.flush()
+ cls.UNKNOWN_FRAME_60 = c.compress(cls.DECOMPRESSED_60) + c.flush()
+ cls.UNKNOWN_FRAME_42_60 = cls.UNKNOWN_FRAME_42 + cls.UNKNOWN_FRAME_60
+
+ cls.TRAIL = b'12345678abcdefg!@#$%^&*()_+|'
+
+ def test_function_decompress(self):
+
+ self.assertEqual(len(decompress(COMPRESSED_100_PLUS_32KB)), 100+32*_1K)
+
+ # 1 frame
+ self.assertEqual(decompress(self.FRAME_42), self.DECOMPRESSED_42)
+
+ self.assertEqual(decompress(self.UNKNOWN_FRAME_42), self.DECOMPRESSED_42)
+
+ pattern = r"Compressed data ended before the end-of-stream marker"
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(self.FRAME_42[:1])
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(self.FRAME_42[:-4])
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(self.FRAME_42[:-1])
+
+ # 2 frames
+ self.assertEqual(decompress(self.FRAME_42_60), self.DECOMPRESSED_42_60)
+
+ self.assertEqual(decompress(self.UNKNOWN_FRAME_42_60), self.DECOMPRESSED_42_60)
+
+ self.assertEqual(decompress(self.FRAME_42 + self.UNKNOWN_FRAME_60),
+ self.DECOMPRESSED_42_60)
+
+ self.assertEqual(decompress(self.UNKNOWN_FRAME_42 + self.FRAME_60),
+ self.DECOMPRESSED_42_60)
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(self.FRAME_42_60[:-4])
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(self.UNKNOWN_FRAME_42_60[:-1])
+
+ # 130_1K
+ self.assertEqual(decompress(DAT_130K_C), DAT_130K_D)
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(DAT_130K_C[:-4])
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(DAT_130K_C[:-1])
+
+ # Unknown frame descriptor
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ decompress(b'aaaaaaaaa')
+
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ decompress(self.FRAME_42 + b'aaaaaaaaa')
+
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ decompress(self.UNKNOWN_FRAME_42_60 + b'aaaaaaaaa')
+
+ # doesn't match checksum
+ checksum = DAT_130K_C[-4:]
+ if checksum[0] == 255:
+ wrong_checksum = bytes([254]) + checksum[1:]
+ else:
+ wrong_checksum = bytes([checksum[0]+1]) + checksum[1:]
+
+ dat = DAT_130K_C[:-4] + wrong_checksum
+
+ with self.assertRaisesRegex(ZstdError, "doesn't match checksum"):
+ decompress(dat)
+
+ def test_function_skippable(self):
+ self.assertEqual(decompress(SKIPPABLE_FRAME), b'')
+ self.assertEqual(decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME), b'')
+
+ # 1 frame + 2 skippable
+ self.assertEqual(len(decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + DAT_130K_C)),
+ self._130_1K)
+
+ self.assertEqual(len(decompress(DAT_130K_C + SKIPPABLE_FRAME + SKIPPABLE_FRAME)),
+ self._130_1K)
+
+ self.assertEqual(len(decompress(SKIPPABLE_FRAME + DAT_130K_C + SKIPPABLE_FRAME)),
+ self._130_1K)
+
+ # unknown size
+ self.assertEqual(decompress(SKIPPABLE_FRAME + self.UNKNOWN_FRAME_60),
+ self.DECOMPRESSED_60)
+
+ self.assertEqual(decompress(self.UNKNOWN_FRAME_60 + SKIPPABLE_FRAME),
+ self.DECOMPRESSED_60)
+
+ # 2 frames + 1 skippable
+ self.assertEqual(decompress(self.FRAME_42 + SKIPPABLE_FRAME + self.FRAME_60),
+ self.DECOMPRESSED_42_60)
+
+ self.assertEqual(decompress(SKIPPABLE_FRAME + self.FRAME_42_60),
+ self.DECOMPRESSED_42_60)
+
+ self.assertEqual(decompress(self.UNKNOWN_FRAME_42_60 + SKIPPABLE_FRAME),
+ self.DECOMPRESSED_42_60)
+
+ # incomplete
+ with self.assertRaises(ZstdError):
+ decompress(SKIPPABLE_FRAME[:1])
+
+ with self.assertRaises(ZstdError):
+ decompress(SKIPPABLE_FRAME[:-1])
+
+ with self.assertRaises(ZstdError):
+ decompress(self.FRAME_42 + SKIPPABLE_FRAME[:-1])
+
+ # Unknown frame descriptor
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ decompress(b'aaaaaaaaa' + SKIPPABLE_FRAME)
+
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ decompress(SKIPPABLE_FRAME + b'aaaaaaaaa')
+
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + b'aaaaaaaaa')
+
+ def test_decompressor_1(self):
+ # empty 1
+ d = ZstdDecompressor()
+
+ dat = d.decompress(b'')
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ dat = d.decompress(b'', 0)
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ dat = d.decompress(COMPRESSED_100_PLUS_32KB + b'a')
+ self.assertEqual(dat, DECOMPRESSED_100_PLUS_32KB)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'a')
+ self.assertEqual(d.unused_data, b'a') # twice
+
+ # empty 2
+ d = ZstdDecompressor()
+
+ dat = d.decompress(b'', 0)
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ dat = d.decompress(b'')
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ dat = d.decompress(COMPRESSED_100_PLUS_32KB + b'a')
+ self.assertEqual(dat, DECOMPRESSED_100_PLUS_32KB)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'a')
+ self.assertEqual(d.unused_data, b'a') # twice
+
+ # 1 frame
+ d = ZstdDecompressor()
+ dat = d.decompress(self.FRAME_42)
+
+ self.assertEqual(dat, self.DECOMPRESSED_42)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ with self.assertRaises(EOFError):
+ d.decompress(b'')
+
+ # 1 frame, trail
+ d = ZstdDecompressor()
+ dat = d.decompress(self.FRAME_42 + self.TRAIL)
+
+ self.assertEqual(dat, self.DECOMPRESSED_42)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, self.TRAIL)
+ self.assertEqual(d.unused_data, self.TRAIL) # twice
+
+ # 1 frame, 32_1K
+ temp = compress(b'a'*(32*_1K))
+ d = ZstdDecompressor()
+ dat = d.decompress(temp, 32*_1K)
+
+ self.assertEqual(dat, b'a'*(32*_1K))
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ with self.assertRaises(EOFError):
+ d.decompress(b'')
+
+ # 1 frame, 32_1K+100, trail
+ d = ZstdDecompressor()
+ dat = d.decompress(COMPRESSED_100_PLUS_32KB+self.TRAIL, 100) # 100 bytes
+
+ self.assertEqual(len(dat), 100)
+ self.assertFalse(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+
+ dat = d.decompress(b'') # 32_1K
+
+ self.assertEqual(len(dat), 32*_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, self.TRAIL)
+ self.assertEqual(d.unused_data, self.TRAIL) # twice
+
+ with self.assertRaises(EOFError):
+ d.decompress(b'')
+
+ # incomplete 1
+ d = ZstdDecompressor()
+ dat = d.decompress(self.FRAME_60[:1])
+
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ # incomplete 2
+ d = ZstdDecompressor()
+
+ dat = d.decompress(self.FRAME_60[:-4])
+ self.assertEqual(dat, self.DECOMPRESSED_60)
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ # incomplete 3
+ d = ZstdDecompressor()
+
+ dat = d.decompress(self.FRAME_60[:-1])
+ self.assertEqual(dat, self.DECOMPRESSED_60)
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+
+ # incomplete 4
+ d = ZstdDecompressor()
+
+ dat = d.decompress(self.FRAME_60[:-4], 60)
+ self.assertEqual(dat, self.DECOMPRESSED_60)
+ self.assertFalse(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ dat = d.decompress(b'')
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ # Unknown frame descriptor
+ d = ZstdDecompressor()
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ d.decompress(b'aaaaaaaaa')
+
+ def test_decompressor_skippable(self):
+ # 1 skippable
+ d = ZstdDecompressor()
+ dat = d.decompress(SKIPPABLE_FRAME)
+
+ self.assertEqual(dat, b'')
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ # 1 skippable, max_length=0
+ d = ZstdDecompressor()
+ dat = d.decompress(SKIPPABLE_FRAME, 0)
+
+ self.assertEqual(dat, b'')
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ # 1 skippable, trail
+ d = ZstdDecompressor()
+ dat = d.decompress(SKIPPABLE_FRAME + self.TRAIL)
+
+ self.assertEqual(dat, b'')
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, self.TRAIL)
+ self.assertEqual(d.unused_data, self.TRAIL) # twice
+
+ # incomplete
+ d = ZstdDecompressor()
+ dat = d.decompress(SKIPPABLE_FRAME[:-1])
+
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ # incomplete
+ d = ZstdDecompressor()
+ dat = d.decompress(SKIPPABLE_FRAME[:-1], 0)
+
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ dat = d.decompress(b'')
+
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+
+
+class ZstdDictTestCase(unittest.TestCase):
+
+ def test_is_raw(self):
+ # must be passed as a keyword argument
+ with self.assertRaises(TypeError):
+ ZstdDict(bytes(8), True)
+
+ # content < 8
+ b = b'1234567'
+ with self.assertRaises(ValueError):
+ ZstdDict(b)
+
+ # content == 8
+ b = b'12345678'
+ zd = ZstdDict(b, is_raw=True)
+ self.assertEqual(zd.dict_id, 0)
+
+ temp = compress(b'aaa12345678', level=3, zstd_dict=zd)
+ self.assertEqual(b'aaa12345678', decompress(temp, zd))
+
+ # is_raw == False
+ b = b'12345678abcd'
+ with self.assertRaises(ValueError):
+ ZstdDict(b)
+
+ # read only attributes
+ with self.assertRaises(AttributeError):
+ zd.dict_content = b
+
+ with self.assertRaises(AttributeError):
+ zd.dict_id = 10000
+
+ # ZstdDict arguments
+ zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=False)
+ self.assertNotEqual(zd.dict_id, 0)
+
+ zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=True)
+ self.assertNotEqual(zd.dict_id, 0) # dict_id is still parsed from the content even with is_raw=True
+
+ with self.assertRaises(TypeError):
+ ZstdDict("12345678abcdef", is_raw=True)
+ with self.assertRaises(TypeError):
+ ZstdDict(TRAINED_DICT)
+
+ # invalid parameter
+ with self.assertRaises(TypeError):
+ ZstdDict(desk333=345)
+
+ def test_invalid_dict(self):
+ DICT_MAGIC = 0xEC30A437.to_bytes(4, byteorder='little')
+ dict_content = DICT_MAGIC + b'abcdefghighlmnopqrstuvwxyz'
+
+ # corrupted
+ zd = ZstdDict(dict_content, is_raw=False)
+ with self.assertRaisesRegex(ZstdError, r'ZSTD_CDict.*?content\.$'):
+ ZstdCompressor(zstd_dict=zd.as_digested_dict)
+ with self.assertRaisesRegex(ZstdError, r'ZSTD_DDict.*?content\.$'):
+ ZstdDecompressor(zd)
+
+ # wrong type
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdCompressor(zstd_dict=[zd, 1])
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdCompressor(zstd_dict=(zd, 1.0))
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdCompressor(zstd_dict=(zd,))
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdCompressor(zstd_dict=(zd, 1, 2))
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdCompressor(zstd_dict=(zd, -1))
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdCompressor(zstd_dict=(zd, 3))
+ with self.assertRaises(OverflowError):
+ ZstdCompressor(zstd_dict=(zd, 2**1000))
+ with self.assertRaises(OverflowError):
+ ZstdCompressor(zstd_dict=(zd, -2**1000))
+
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdDecompressor(zstd_dict=[zd, 1])
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdDecompressor(zstd_dict=(zd, 1.0))
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdDecompressor((zd,))
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdDecompressor((zd, 1, 2))
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdDecompressor((zd, -1))
+ with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+ ZstdDecompressor((zd, 3))
+ with self.assertRaises(OverflowError):
+ ZstdDecompressor((zd, 2**1000))
+ with self.assertRaises(OverflowError):
+ ZstdDecompressor((zd, -2**1000))
+
+ def test_train_dict(self):
+ TRAINED_DICT = train_dict(SAMPLES, DICT_SIZE1)
+ ZstdDict(TRAINED_DICT.dict_content, is_raw=False)
+
+ self.assertNotEqual(TRAINED_DICT.dict_id, 0)
+ self.assertGreater(len(TRAINED_DICT.dict_content), 0)
+ self.assertLessEqual(len(TRAINED_DICT.dict_content), DICT_SIZE1)
+ self.assertTrue(re.match(r'^<ZstdDict dict_id=\d+ dict_size=\d+>$', str(TRAINED_DICT)))
+
+ # compress/decompress
+ c = ZstdCompressor(zstd_dict=TRAINED_DICT)
+ for sample in SAMPLES:
+ dat1 = compress(sample, zstd_dict=TRAINED_DICT)
+ dat2 = decompress(dat1, TRAINED_DICT)
+ self.assertEqual(sample, dat2)
+
+ dat1 = c.compress(sample)
+ dat1 += c.flush()
+ dat2 = decompress(dat1, TRAINED_DICT)
+ self.assertEqual(sample, dat2)
+
+ def test_finalize_dict(self):
+ DICT_SIZE2 = 200*_1K
+ C_LEVEL = 6
+
+ try:
+ dic2 = finalize_dict(TRAINED_DICT, SAMPLES, DICT_SIZE2, C_LEVEL)
+ except NotImplementedError:
+ # built against libzstd < v1.4.5 but running with >= v1.4.5
+ return
+
+ self.assertNotEqual(dic2.dict_id, 0)
+ self.assertGreater(len(dic2.dict_content), 0)
+ self.assertLessEqual(len(dic2.dict_content), DICT_SIZE2)
+
+ # compress/decompress
+ c = ZstdCompressor(C_LEVEL, zstd_dict=dic2)
+ for sample in SAMPLES:
+ dat1 = compress(sample, C_LEVEL, zstd_dict=dic2)
+ dat2 = decompress(dat1, dic2)
+ self.assertEqual(sample, dat2)
+
+ dat1 = c.compress(sample)
+ dat1 += c.flush()
+ dat2 = decompress(dat1, dic2)
+ self.assertEqual(sample, dat2)
+
+ # dict mismatch
+ self.assertNotEqual(TRAINED_DICT.dict_id, dic2.dict_id)
+
+ dat1 = compress(SAMPLES[0], zstd_dict=TRAINED_DICT)
+ with self.assertRaises(ZstdError):
+ decompress(dat1, dic2)
+
+ def test_train_dict_arguments(self):
+ with self.assertRaises(ValueError):
+ train_dict([], 100*_1K)
+
+ with self.assertRaises(ValueError):
+ train_dict(SAMPLES, -100)
+
+ with self.assertRaises(ValueError):
+ train_dict(SAMPLES, 0)
+
+ def test_finalize_dict_arguments(self):
+ with self.assertRaises(TypeError):
+ finalize_dict({1:2}, (b'aaa', b'bbb'), 100*_1K, 2)
+
+ with self.assertRaises(ValueError):
+ finalize_dict(TRAINED_DICT, [], 100*_1K, 2)
+
+ with self.assertRaises(ValueError):
+ finalize_dict(TRAINED_DICT, SAMPLES, -100, 2)
+
+ with self.assertRaises(ValueError):
+ finalize_dict(TRAINED_DICT, SAMPLES, 0, 2)
+
+ def test_train_dict_c(self):
+ # argument wrong type
+ with self.assertRaises(TypeError):
+ _zstd.train_dict({}, (), 100)
+ with self.assertRaises(TypeError):
+ _zstd.train_dict(bytearray(), (), 100)
+ with self.assertRaises(TypeError):
+ _zstd.train_dict(b'', 99, 100)
+ with self.assertRaises(TypeError):
+ _zstd.train_dict(b'', [], 100)
+ with self.assertRaises(TypeError):
+ _zstd.train_dict(b'', (), 100.1)
+ with self.assertRaises(TypeError):
+ _zstd.train_dict(b'', (99.1,), 100)
+ with self.assertRaises(ValueError):
+ _zstd.train_dict(b'abc', (4, -1), 100)
+ with self.assertRaises(ValueError):
+ _zstd.train_dict(b'abc', (2,), 100)
+ with self.assertRaises(ValueError):
+ _zstd.train_dict(b'', (99,), 100)
+
+ # sample size exceeds the range of size_t
+ with self.assertRaises(ValueError):
+ _zstd.train_dict(b'', (2**1000,), 100)
+ with self.assertRaises(ValueError):
+ _zstd.train_dict(b'', (-2**1000,), 100)
+
+ # dict_size <= 0
+ with self.assertRaises(ValueError):
+ _zstd.train_dict(b'', (), 0)
+ with self.assertRaises(ValueError):
+ _zstd.train_dict(b'', (), -1)
+
+ with self.assertRaises(ZstdError):
+ _zstd.train_dict(b'', (), 1)
+
+ def test_finalize_dict_c(self):
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict(1, 2, 3, 4, 5)
+
+ # argument wrong type
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict({}, b'', (), 100, 5)
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict(bytearray(TRAINED_DICT.dict_content), b'', (), 100, 5)
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, {}, (), 100, 5)
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, bytearray(), (), 100, 5)
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', 99, 100, 5)
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', [], 100, 5)
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100.1, 5)
+ with self.assertRaises(TypeError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5.1)
+
+ with self.assertRaises(ValueError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (4, -1), 100, 5)
+ with self.assertRaises(ValueError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (2,), 100, 5)
+ with self.assertRaises(ValueError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (99,), 100, 5)
+
+ # sample size exceeds the range of size_t
+ with self.assertRaises(ValueError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**1000,), 100, 5)
+ with self.assertRaises(ValueError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (-2**1000,), 100, 5)
+
+ # dict_size <= 0
+ with self.assertRaises(ValueError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 0, 5)
+ with self.assertRaises(ValueError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -1, 5)
+ with self.assertRaises(OverflowError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 2**1000, 5)
+ with self.assertRaises(OverflowError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -2**1000, 5)
+
+ with self.assertRaises(OverflowError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 2**1000)
+ with self.assertRaises(OverflowError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, -2**1000)
+
+ with self.assertRaises(ZstdError):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5)
+
+ def test_train_buffer_protocol_samples(self):
+ def _nbytes(dat):
+ if isinstance(dat, (bytes, bytearray)):
+ return len(dat)
+ return memoryview(dat).nbytes
+
+ # prepare samples
+ chunk_lst = []
+ wrong_size_lst = []
+ correct_size_lst = []
+ for _ in range(300):
+ arr = array.array('Q', [random.randint(0, 20) for i in range(20)])
+ chunk_lst.append(arr)
+ correct_size_lst.append(_nbytes(arr))
+ wrong_size_lst.append(len(arr))
+ concatenation = b''.join(chunk_lst)
+
+ # wrong size list
+ with self.assertRaisesRegex(ValueError,
+ "The samples size tuple doesn't match the concatenation's size"):
+ _zstd.train_dict(concatenation, tuple(wrong_size_lst), 100*_1K)
+
+ # correct size list
+ _zstd.train_dict(concatenation, tuple(correct_size_lst), 3*_1K)
+
+ # wrong size list
+ with self.assertRaisesRegex(ValueError,
+ "The samples size tuple doesn't match the concatenation's size"):
+ _zstd.finalize_dict(TRAINED_DICT.dict_content,
+ concatenation, tuple(wrong_size_lst), 300*_1K, 5)
+
+ # correct size list
+ _zstd.finalize_dict(TRAINED_DICT.dict_content,
+ concatenation, tuple(correct_size_lst), 300*_1K, 5)
+
+ def test_as_prefix(self):
+ # V1
+ V1 = THIS_FILE_BYTES
+ zd = ZstdDict(V1, is_raw=True)
+
+ # V2
+ mid = len(V1) // 2
+ V2 = V1[:mid] + \
+ (b'a' if V1[mid] != int.from_bytes(b'a') else b'b') + \
+ V1[mid+1:]
+
+ # compress
+ dat = compress(V2, zstd_dict=zd.as_prefix)
+ self.assertEqual(get_frame_info(dat).dictionary_id, 0)
+
+ # decompress
+ self.assertEqual(decompress(dat, zd.as_prefix), V2)
+
+ # use wrong prefix
+ zd2 = ZstdDict(SAMPLES[0], is_raw=True)
+ try:
+ decompressed = decompress(dat, zd2.as_prefix)
+ except ZstdError: # expected
+ pass
+ else:
+ self.assertNotEqual(decompressed, V2)
+
+ # read only attribute
+ with self.assertRaises(AttributeError):
+ zd.as_prefix = b'1234'
+
+ def test_as_digested_dict(self):
+ zd = TRAINED_DICT
+
+ # test .as_digested_dict
+ dat = compress(SAMPLES[0], zstd_dict=zd.as_digested_dict)
+ self.assertEqual(decompress(dat, zd.as_digested_dict), SAMPLES[0])
+ with self.assertRaises(AttributeError):
+ zd.as_digested_dict = b'1234'
+
+ # test .as_undigested_dict
+ dat = compress(SAMPLES[0], zstd_dict=zd.as_undigested_dict)
+ self.assertEqual(decompress(dat, zd.as_undigested_dict), SAMPLES[0])
+ with self.assertRaises(AttributeError):
+ zd.as_undigested_dict = b'1234'
+
+ def test_advanced_compression_parameters(self):
+ options = {CompressionParameter.compression_level: 6,
+ CompressionParameter.window_log: 20,
+ CompressionParameter.enable_long_distance_matching: 1}
+
+ # automatically select
+ dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT)
+ self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0])
+
+ # explicitly select
+ dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT.as_digested_dict)
+ self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0])
+
+ def test_len(self):
+ self.assertEqual(len(TRAINED_DICT), len(TRAINED_DICT.dict_content))
+ self.assertIn(str(len(TRAINED_DICT)), str(TRAINED_DICT))
+
+class FileTestCase(unittest.TestCase):
+ def setUp(self):
+ self.DECOMPRESSED_42 = b'a'*42
+ self.FRAME_42 = compress(self.DECOMPRESSED_42)
+
+ def test_init(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ pass
+ with ZstdFile(io.BytesIO(), "w") as f:
+ pass
+ with ZstdFile(io.BytesIO(), "x") as f:
+ pass
+ with ZstdFile(io.BytesIO(), "a") as f:
+ pass
+
+ with ZstdFile(io.BytesIO(), "w", level=12) as f:
+ pass
+ with ZstdFile(io.BytesIO(), "w", options={CompressionParameter.checksum_flag:1}) as f:
+ pass
+ with ZstdFile(io.BytesIO(), "w", options={}) as f:
+ pass
+ with ZstdFile(io.BytesIO(), "w", level=20, zstd_dict=TRAINED_DICT) as f:
+ pass
+
+ with ZstdFile(io.BytesIO(), "r", options={DecompressionParameter.window_log_max:25}) as f:
+ pass
+ with ZstdFile(io.BytesIO(), "r", options={}, zstd_dict=TRAINED_DICT) as f:
+ pass
+
+ def test_init_with_PathLike_filename(self):
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ filename = pathlib.Path(tmp_f.name)
+
+ with ZstdFile(filename, "a") as f:
+ f.write(DECOMPRESSED_100_PLUS_32KB)
+ with ZstdFile(filename) as f:
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+
+ with ZstdFile(filename, "a") as f:
+ f.write(DECOMPRESSED_100_PLUS_32KB)
+ with ZstdFile(filename) as f:
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB * 2)
+
+ os.remove(filename)
+
+ def test_init_with_filename(self):
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ filename = pathlib.Path(tmp_f.name)
+
+ with ZstdFile(filename) as f:
+ pass
+ with ZstdFile(filename, "w") as f:
+ pass
+ with ZstdFile(filename, "a") as f:
+ pass
+
+ os.remove(filename)
+
+ def test_init_mode(self):
+ bi = io.BytesIO()
+
+ with ZstdFile(bi, "r"):
+ pass
+ with ZstdFile(bi, "rb"):
+ pass
+ with ZstdFile(bi, "w"):
+ pass
+ with ZstdFile(bi, "wb"):
+ pass
+ with ZstdFile(bi, "a"):
+ pass
+ with ZstdFile(bi, "ab"):
+ pass
+
+ def test_init_with_x_mode(self):
+ with tempfile.NamedTemporaryFile() as tmp_f:
+ filename = pathlib.Path(tmp_f.name)
+
+ for mode in ("x", "xb"):
+ with ZstdFile(filename, mode):
+ pass
+ with self.assertRaises(FileExistsError):
+ with ZstdFile(filename, mode):
+ pass
+ os.remove(filename)
+
+ def test_init_bad_mode(self):
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), (3, "x"))
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "xt")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "x+")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rx")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "wx")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rt")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r+")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "wt")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "w+")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rw")
+
+ with self.assertRaisesRegex(TypeError,
+ r"not be a CompressionParameter"):
+ ZstdFile(io.BytesIO(), 'rb',
+ options={CompressionParameter.compression_level:5})
+ with self.assertRaisesRegex(TypeError,
+ r"not be a DecompressionParameter"):
+ ZstdFile(io.BytesIO(), 'wb',
+ options={DecompressionParameter.window_log_max:21})
+
+ with self.assertRaises(TypeError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", level=12)
+
+ def test_init_bad_check(self):
+ with self.assertRaises(TypeError):
+ ZstdFile(io.BytesIO(), "w", level='asd')
+ # Unknown option keys and out-of-range parameter values should be invalid.
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(), "w", options={999:9999})
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(), "w", options={CompressionParameter.window_log:99})
+
+ with self.assertRaises(TypeError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", options=33)
+
+ with self.assertRaises(OverflowError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB),
+ options={DecompressionParameter.window_log_max:2**31})
+
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB),
+ options={444:333})
+
+ with self.assertRaises(TypeError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), zstd_dict={1:2})
+
+ with self.assertRaises(TypeError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), zstd_dict=b'dict123456')
+
+ def test_init_close_fp(self):
+ # get a temp file name
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ tmp_f.write(DAT_130K_C)
+ filename = tmp_f.name
+
+ with self.assertRaises(TypeError):
+ ZstdFile(filename, options={'a':'b'})
+
+ # for PyPy
+ gc.collect()
+
+ os.remove(filename)
+
+ def test_close(self):
+ with io.BytesIO(COMPRESSED_100_PLUS_32KB) as src:
+ f = ZstdFile(src)
+ f.close()
+ # ZstdFile.close() should not close the underlying file object.
+ self.assertFalse(src.closed)
+ # Try closing an already-closed ZstdFile.
+ f.close()
+ self.assertFalse(src.closed)
+
+ # Test with a real file on disk, opened directly by ZstdFile.
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ filename = pathlib.Path(tmp_f.name)
+
+ f = ZstdFile(filename)
+ fp = f._fp
+ f.close()
+ # Here, ZstdFile.close() *should* close the underlying file object.
+ self.assertTrue(fp.closed)
+ # Try closing an already-closed ZstdFile.
+ f.close()
+
+ os.remove(filename)
+
+ def test_closed(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ try:
+ self.assertFalse(f.closed)
+ f.read()
+ self.assertFalse(f.closed)
+ finally:
+ f.close()
+ self.assertTrue(f.closed)
+
+ f = ZstdFile(io.BytesIO(), "w")
+ try:
+ self.assertFalse(f.closed)
+ finally:
+ f.close()
+ self.assertTrue(f.closed)
+
+ def test_fileno(self):
+ # 1
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ try:
+ self.assertRaises(io.UnsupportedOperation, f.fileno)
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.fileno)
+
+ # 2
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ filename = pathlib.Path(tmp_f.name)
+
+ f = ZstdFile(filename)
+ try:
+ self.assertEqual(f.fileno(), f._fp.fileno())
+ self.assertIsInstance(f.fileno(), int)
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.fileno)
+
+ os.remove(filename)
+
+ # 3, no .fileno() method
+ class C:
+ def read(self, size=-1):
+ return b'123'
+ with ZstdFile(C(), 'rb') as f:
+ with self.assertRaisesRegex(AttributeError, r'fileno'):
+ f.fileno()
+
+ def test_name(self):
+ # 1
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ try:
+ with self.assertRaises(AttributeError):
+ f.name
+ finally:
+ f.close()
+ with self.assertRaises(ValueError):
+ f.name
+
+ # 2
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ filename = pathlib.Path(tmp_f.name)
+
+ f = ZstdFile(filename)
+ try:
+ self.assertEqual(f.name, f._fp.name)
+ self.assertIsInstance(f.name, str)
+ finally:
+ f.close()
+ with self.assertRaises(ValueError):
+ f.name
+
+ os.remove(filename)
+
+ # 3, no .filename property
+ class C:
+ def read(self, size=-1):
+ return b'123'
+ with ZstdFile(C(), 'rb') as f:
+ with self.assertRaisesRegex(AttributeError, r'name'):
+ f.name
+
+ def test_seekable(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ try:
+ self.assertTrue(f.seekable())
+ f.read()
+ self.assertTrue(f.seekable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.seekable)
+
+ f = ZstdFile(io.BytesIO(), "w")
+ try:
+ self.assertFalse(f.seekable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.seekable)
+
+ src = io.BytesIO(COMPRESSED_100_PLUS_32KB)
+ src.seekable = lambda: False
+ f = ZstdFile(src)
+ try:
+ self.assertFalse(f.seekable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.seekable)
+
+ def test_readable(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ try:
+ self.assertTrue(f.readable())
+ f.read()
+ self.assertTrue(f.readable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.readable)
+
+ f = ZstdFile(io.BytesIO(), "w")
+ try:
+ self.assertFalse(f.readable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.readable)
+
+ def test_writable(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ try:
+ self.assertFalse(f.writable())
+ f.read()
+ self.assertFalse(f.writable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.writable)
+
+ f = ZstdFile(io.BytesIO(), "w")
+ try:
+ self.assertTrue(f.writable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.writable)
+
+ def test_read_0(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ self.assertEqual(f.read(0), b"")
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB),
+ options={DecompressionParameter.window_log_max:20}) as f:
+ self.assertEqual(f.read(0), b"")
+
+ # empty file
+ with ZstdFile(io.BytesIO(b'')) as f:
+ self.assertEqual(f.read(0), b"")
+ with self.assertRaises(EOFError):
+ f.read(10)
+
+ with ZstdFile(io.BytesIO(b'')) as f:
+ with self.assertRaises(EOFError):
+ f.read(10)
+
+ def test_read_10(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ chunks = []
+ while True:
+ result = f.read(10)
+ if not result:
+ break
+ self.assertLessEqual(len(result), 10)
+ chunks.append(result)
+ self.assertEqual(b"".join(chunks), DECOMPRESSED_100_PLUS_32KB)
+
+ def test_read_multistream(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 5)) as f:
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB * 5)
+
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB + SKIPPABLE_FRAME)) as f:
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB + COMPRESSED_DAT)) as f:
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB + DECOMPRESSED_DAT)
+
+ def test_read_incomplete(self):
+ with ZstdFile(io.BytesIO(DAT_130K_C[:-200])) as f:
+ self.assertRaises(EOFError, f.read)
+
+ # Trailing data isn't a valid compressed stream
+ with ZstdFile(io.BytesIO(self.FRAME_42 + b'12345')) as f:
+ self.assertRaises(ZstdError, f.read)
+
+ with ZstdFile(io.BytesIO(SKIPPABLE_FRAME + b'12345')) as f:
+ self.assertRaises(ZstdError, f.read)
+
+ def test_read_truncated(self):
+ # Drop stream epilogue: 4 bytes checksum
+ truncated = DAT_130K_C[:-4]
+ with ZstdFile(io.BytesIO(truncated)) as f:
+ self.assertRaises(EOFError, f.read)
+
+ with ZstdFile(io.BytesIO(truncated)) as f:
+ # this is an important test, make sure it doesn't raise EOFError.
+ self.assertEqual(f.read(130*_1K), DAT_130K_D)
+ with self.assertRaises(EOFError):
+ f.read(1)
+
+ # Incomplete header
+ for i in range(1, 20):
+ with ZstdFile(io.BytesIO(truncated[:i])) as f:
+ self.assertRaises(EOFError, f.read, 1)
+
+ def test_read_bad_args(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_DAT))
+ f.close()
+ self.assertRaises(ValueError, f.read)
+ with ZstdFile(io.BytesIO(), "w") as f:
+ self.assertRaises(ValueError, f.read)
+ with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f:
+ self.assertRaises(TypeError, f.read, float())
+
+ def test_read_bad_data(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_BOGUS)) as f:
+ self.assertRaises(ZstdError, f.read)
+
+ def test_read_exception(self):
+ class C:
+ def read(self, size=-1):
+ raise OSError
+ with ZstdFile(C()) as f:
+ with self.assertRaises(OSError):
+ f.read(10)
+
+ def test_read1(self):
+ with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
+ blocks = []
+ while True:
+ result = f.read1()
+ if not result:
+ break
+ blocks.append(result)
+ self.assertEqual(b"".join(blocks), DAT_130K_D)
+ self.assertEqual(f.read1(), b"")
+
+ def test_read1_0(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f:
+ self.assertEqual(f.read1(0), b"")
+
+ def test_read1_10(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f:
+ blocks = []
+ while True:
+ result = f.read1(10)
+ if not result:
+ break
+ blocks.append(result)
+ self.assertEqual(b"".join(blocks), DECOMPRESSED_DAT)
+ self.assertEqual(f.read1(), b"")
+
+ def test_read1_multistream(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 5)) as f:
+ blocks = []
+ while True:
+ result = f.read1()
+ if not result:
+ break
+ blocks.append(result)
+ self.assertEqual(b"".join(blocks), DECOMPRESSED_100_PLUS_32KB * 5)
+ self.assertEqual(f.read1(), b"")
+
+ def test_read1_bad_args(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ f.close()
+ self.assertRaises(ValueError, f.read1)
+ with ZstdFile(io.BytesIO(), "w") as f:
+ self.assertRaises(ValueError, f.read1)
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ self.assertRaises(TypeError, f.read1, None)
+
+ def test_readinto(self):
+ arr = array.array("I", range(100))
+ self.assertEqual(len(arr), 100)
+ self.assertEqual(len(arr) * arr.itemsize, 400)
+ ba = bytearray(300)
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ # 0 length output buffer
+ self.assertEqual(f.readinto(ba[0:0]), 0)
+
+ # use correct length for buffer protocol object
+ self.assertEqual(f.readinto(arr), 400)
+ self.assertEqual(arr.tobytes(), DECOMPRESSED_100_PLUS_32KB[:400])
+
+ # normal readinto
+ self.assertEqual(f.readinto(ba), 300)
+ self.assertEqual(ba, DECOMPRESSED_100_PLUS_32KB[400:700])
+
+ def test_peek(self):
+ with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
+ result = f.peek()
+ self.assertGreater(len(result), 0)
+ self.assertTrue(DAT_130K_D.startswith(result))
+ self.assertEqual(f.read(), DAT_130K_D)
+ with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
+ result = f.peek(10)
+ self.assertGreater(len(result), 0)
+ self.assertTrue(DAT_130K_D.startswith(result))
+ self.assertEqual(f.read(), DAT_130K_D)
+
+ def test_peek_bad_args(self):
+ with ZstdFile(io.BytesIO(), "w") as f:
+ self.assertRaises(ValueError, f.peek)
+
+ def test_iterator(self):
+ with io.BytesIO(THIS_FILE_BYTES) as f:
+ lines = f.readlines()
+ compressed = compress(THIS_FILE_BYTES)
+
+ # iter
+ with ZstdFile(io.BytesIO(compressed)) as f:
+ self.assertListEqual(list(iter(f)), lines)
+
+ # readline
+ with ZstdFile(io.BytesIO(compressed)) as f:
+ for line in lines:
+ self.assertEqual(f.readline(), line)
+ self.assertEqual(f.readline(), b'')
+ self.assertEqual(f.readline(), b'')
+
+ # readlines
+ with ZstdFile(io.BytesIO(compressed)) as f:
+ self.assertListEqual(f.readlines(), lines)
+
+ def test_decompress_limited(self):
+ _ZSTD_DStreamInSize = 128*_1K + 3
+
+ bomb = compress(b'\0' * int(2e6), level=10)
+ self.assertLess(len(bomb), _ZSTD_DStreamInSize)
+
+ decomp = ZstdFile(io.BytesIO(bomb))
+ self.assertEqual(decomp.read(1), b'\0')
+
+ # BufferedReader uses 128 KiB buffer in __init__.py
+ max_decomp = 128*_1K
+ self.assertLessEqual(decomp._buffer.raw.tell(), max_decomp,
+ "Excessive amount of data was decompressed")
+
+ def test_write(self):
+ raw_data = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6]
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w") as f:
+ f.write(raw_data)
+
+ comp = ZstdCompressor()
+ expected = comp.compress(raw_data) + comp.flush()
+ self.assertEqual(dst.getvalue(), expected)
+
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w", level=12) as f:
+ f.write(raw_data)
+
+ comp = ZstdCompressor(12)
+ expected = comp.compress(raw_data) + comp.flush()
+ self.assertEqual(dst.getvalue(), expected)
+
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w", options={CompressionParameter.checksum_flag:1}) as f:
+ f.write(raw_data)
+
+ comp = ZstdCompressor(options={CompressionParameter.checksum_flag:1})
+ expected = comp.compress(raw_data) + comp.flush()
+ self.assertEqual(dst.getvalue(), expected)
+
+ with io.BytesIO() as dst:
+ options = {CompressionParameter.compression_level:-5,
+ CompressionParameter.checksum_flag:1}
+ with ZstdFile(dst, "w",
+ options=options) as f:
+ f.write(raw_data)
+
+ comp = ZstdCompressor(options=options)
+ expected = comp.compress(raw_data) + comp.flush()
+ self.assertEqual(dst.getvalue(), expected)
+
+ def test_write_empty_frame(self):
+ # .FLUSH_FRAME generates an empty content frame
+ c = ZstdCompressor()
+ self.assertNotEqual(c.flush(c.FLUSH_FRAME), b'')
+ self.assertNotEqual(c.flush(c.FLUSH_FRAME), b'')
+
+ # don't generate empty content frame
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ pass
+ self.assertEqual(bo.getvalue(), b'')
+
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.flush(f.FLUSH_FRAME)
+ self.assertEqual(bo.getvalue(), b'')
+
+ # if .write(b''), generate empty content frame
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.write(b'')
+ self.assertNotEqual(bo.getvalue(), b'')
+
+ # has an empty content frame
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.flush(f.FLUSH_BLOCK)
+ self.assertNotEqual(bo.getvalue(), b'')
+
+ def test_write_empty_block(self):
+ # If there is no internal data, .FLUSH_BLOCK returns b''.
+ c = ZstdCompressor()
+ self.assertEqual(c.flush(c.FLUSH_BLOCK), b'')
+ self.assertNotEqual(c.compress(b'123', c.FLUSH_BLOCK),
+ b'')
+ self.assertEqual(c.flush(c.FLUSH_BLOCK), b'')
+ self.assertEqual(c.compress(b''), b'')
+ self.assertEqual(c.compress(b''), b'')
+ self.assertEqual(c.flush(c.FLUSH_BLOCK), b'')
+
+ # mode == .last_mode
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.write(b'123')
+ f.flush(f.FLUSH_BLOCK)
+ fp_pos = f._fp.tell()
+ self.assertNotEqual(fp_pos, 0)
+ f.flush(f.FLUSH_BLOCK)
+ self.assertEqual(f._fp.tell(), fp_pos)
+
+ # mode != .last_mode
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.flush(f.FLUSH_BLOCK)
+ self.assertEqual(f._fp.tell(), 0)
+ f.write(b'')
+ f.flush(f.FLUSH_BLOCK)
+ self.assertEqual(f._fp.tell(), 0)
+
+ def test_write_101(self):
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w") as f:
+ for start in range(0, len(THIS_FILE_BYTES), 101):
+ f.write(THIS_FILE_BYTES[start:start+101])
+
+ comp = ZstdCompressor()
+ expected = comp.compress(THIS_FILE_BYTES) + comp.flush()
+ self.assertEqual(dst.getvalue(), expected)
+
+ def test_write_append(self):
+ def comp(data):
+ comp = ZstdCompressor()
+ return comp.compress(data) + comp.flush()
+
+ part1 = THIS_FILE_BYTES[:_1K]
+ part2 = THIS_FILE_BYTES[_1K:1536]
+ part3 = THIS_FILE_BYTES[1536:]
+ expected = b"".join(comp(x) for x in (part1, part2, part3))
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w") as f:
+ f.write(part1)
+ with ZstdFile(dst, "a") as f:
+ f.write(part2)
+ with ZstdFile(dst, "a") as f:
+ f.write(part3)
+ self.assertEqual(dst.getvalue(), expected)
+
+ def test_write_bad_args(self):
+ f = ZstdFile(io.BytesIO(), "w")
+ f.close()
+ self.assertRaises(ValueError, f.write, b"foo")
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r") as f:
+ self.assertRaises(ValueError, f.write, b"bar")
+ with ZstdFile(io.BytesIO(), "w") as f:
+ self.assertRaises(TypeError, f.write, None)
+ self.assertRaises(TypeError, f.write, "text")
+ self.assertRaises(TypeError, f.write, 789)
+
+ def test_writelines(self):
+ def comp(data):
+ comp = ZstdCompressor()
+ return comp.compress(data) + comp.flush()
+
+ with io.BytesIO(THIS_FILE_BYTES) as f:
+ lines = f.readlines()
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w") as f:
+ f.writelines(lines)
+ expected = comp(THIS_FILE_BYTES)
+ self.assertEqual(dst.getvalue(), expected)
+
+ def test_seek_forward(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.seek(555)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[555:])
+
+ def test_seek_forward_across_streams(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f:
+ f.seek(len(DECOMPRESSED_100_PLUS_32KB) + 123)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[123:])
+
+ def test_seek_forward_relative_to_current(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.read(100)
+ f.seek(1236, 1)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[1336:])
+
+ def test_seek_forward_relative_to_end(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.seek(-555, 2)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[-555:])
+
+ def test_seek_backward(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.read(1001)
+ f.seek(211)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[211:])
+
+ def test_seek_backward_across_streams(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f:
+ f.read(len(DECOMPRESSED_100_PLUS_32KB) + 333)
+ f.seek(737)
+ self.assertEqual(f.read(),
+ DECOMPRESSED_100_PLUS_32KB[737:] + DECOMPRESSED_100_PLUS_32KB)
+
+ def test_seek_backward_relative_to_end(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.seek(-150, 2)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[-150:])
+
+ def test_seek_past_end(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.seek(len(DECOMPRESSED_100_PLUS_32KB) + 9001)
+ self.assertEqual(f.tell(), len(DECOMPRESSED_100_PLUS_32KB))
+ self.assertEqual(f.read(), b"")
+
+ def test_seek_past_start(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.seek(-88)
+ self.assertEqual(f.tell(), 0)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+
+ def test_seek_bad_args(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ f.close()
+ self.assertRaises(ValueError, f.seek, 0)
+ with ZstdFile(io.BytesIO(), "w") as f:
+ self.assertRaises(ValueError, f.seek, 0)
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ self.assertRaises(ValueError, f.seek, 0, 3)
+ # io.BufferedReader raises TypeError instead of ValueError
+ self.assertRaises((TypeError, ValueError), f.seek, 9, ())
+ self.assertRaises(TypeError, f.seek, None)
+ self.assertRaises(TypeError, f.seek, b"derp")
+
+ def test_seek_not_seekable(self):
+ class C(io.BytesIO):
+ def seekable(self):
+ return False
+ obj = C(COMPRESSED_100_PLUS_32KB)
+ with ZstdFile(obj, 'r') as f:
+ d = f.read(1)
+ self.assertFalse(f.seekable())
+ with self.assertRaisesRegex(io.UnsupportedOperation,
+ 'File or stream is not seekable'):
+ f.seek(0)
+ d += f.read()
+ self.assertEqual(d, DECOMPRESSED_100_PLUS_32KB)
+
+ def test_tell(self):
+ with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
+ pos = 0
+ while True:
+ self.assertEqual(f.tell(), pos)
+ result = f.read(random.randint(171, 189))
+ if not result:
+ break
+ pos += len(result)
+ self.assertEqual(f.tell(), len(DAT_130K_D))
+ with ZstdFile(io.BytesIO(), "w") as f:
+ for pos in range(0, len(DAT_130K_D), 143):
+ self.assertEqual(f.tell(), pos)
+ f.write(DAT_130K_D[pos:pos+143])
+ self.assertEqual(f.tell(), len(DAT_130K_D))
+
+ def test_tell_bad_args(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ f.close()
+ self.assertRaises(ValueError, f.tell)
+
+ def test_file_dict(self):
+ # default
+ bi = io.BytesIO()
+ with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with ZstdFile(bi, zstd_dict=TRAINED_DICT) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ # .as_(un)digested_dict
+ bi = io.BytesIO()
+ with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT.as_digested_dict) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with ZstdFile(bi, zstd_dict=TRAINED_DICT.as_undigested_dict) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ def test_file_prefix(self):
+ bi = io.BytesIO()
+ with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT.as_prefix) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with ZstdFile(bi, zstd_dict=TRAINED_DICT.as_prefix) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ def test_UnsupportedOperation(self):
+ # 1
+ with ZstdFile(io.BytesIO(), 'r') as f:
+ with self.assertRaises(io.UnsupportedOperation):
+ f.write(b'1234')
+
+ # 2
+ class T:
+ def read(self, size):
+ return b'a' * size
+
+ with self.assertRaises(TypeError): # on creation
+ with ZstdFile(T(), 'w') as f:
+ pass
+
+ # 3
+ with ZstdFile(io.BytesIO(), 'w') as f:
+ with self.assertRaises(io.UnsupportedOperation):
+ f.read(100)
+ with self.assertRaises(io.UnsupportedOperation):
+ f.seek(100)
+ self.assertEqual(f.closed, True)
+ with self.assertRaises(ValueError):
+ f.readable()
+ with self.assertRaises(ValueError):
+ f.tell()
+ with self.assertRaises(ValueError):
+ f.read(100)
+
+ def test_read_readinto_readinto1(self):
+ lst = []
+ with ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE*5)) as f:
+ while True:
+ method = random.randint(0, 2)
+ size = random.randint(0, 300)
+
+ if method == 0:
+ dat = f.read(size)
+ if not dat and size:
+ break
+ lst.append(dat)
+ elif method == 1:
+ ba = bytearray(size)
+ read_size = f.readinto(ba)
+ if read_size == 0 and size:
+ break
+ lst.append(bytes(ba[:read_size]))
+ elif method == 2:
+ ba = bytearray(size)
+ read_size = f.readinto1(ba)
+ if read_size == 0 and size:
+ break
+ lst.append(bytes(ba[:read_size]))
+ self.assertEqual(b''.join(lst), THIS_FILE_BYTES*5)
+
+ def test_zstdfile_flush(self):
+ # closed
+ f = ZstdFile(io.BytesIO(), 'w')
+ f.close()
+ with self.assertRaises(ValueError):
+ f.flush()
+
+ # read
+ with ZstdFile(io.BytesIO(), 'r') as f:
+ # does nothing for read-only stream
+ f.flush()
+
+ # write
+ DAT = b'abcd'
+ bi = io.BytesIO()
+ with ZstdFile(bi, 'w') as f:
+ self.assertEqual(f.write(DAT), len(DAT))
+ self.assertEqual(f.tell(), len(DAT))
+ self.assertEqual(bi.tell(), 0) # not enough for a block
+
+ self.assertEqual(f.flush(), None)
+ self.assertEqual(f.tell(), len(DAT))
+ self.assertGreater(bi.tell(), 0) # flushed
+
+ # write, no .flush() method
+ class C:
+ def write(self, b):
+ return len(b)
+ with ZstdFile(C(), 'w') as f:
+ self.assertEqual(f.write(DAT), len(DAT))
+ self.assertEqual(f.tell(), len(DAT))
+
+ self.assertEqual(f.flush(), None)
+ self.assertEqual(f.tell(), len(DAT))
+
+ def test_zstdfile_flush_mode(self):
+ self.assertEqual(ZstdFile.FLUSH_BLOCK, ZstdCompressor.FLUSH_BLOCK)
+ self.assertEqual(ZstdFile.FLUSH_FRAME, ZstdCompressor.FLUSH_FRAME)
+ with self.assertRaises(AttributeError):
+ ZstdFile.CONTINUE
+
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ # flush block
+ self.assertEqual(f.write(b'123'), 3)
+ self.assertIsNone(f.flush(f.FLUSH_BLOCK))
+ p1 = bo.tell()
+ # mode == .last_mode, should return
+ self.assertIsNone(f.flush())
+ p2 = bo.tell()
+ self.assertEqual(p1, p2)
+ # flush frame
+ self.assertEqual(f.write(b'456'), 3)
+ self.assertIsNone(f.flush(mode=f.FLUSH_FRAME))
+ # flush frame
+ self.assertEqual(f.write(b'789'), 3)
+ self.assertIsNone(f.flush(f.FLUSH_FRAME))
+ p1 = bo.tell()
+ # mode == .last_mode, should return
+ self.assertIsNone(f.flush(f.FLUSH_FRAME))
+ p2 = bo.tell()
+ self.assertEqual(p1, p2)
+ self.assertEqual(decompress(bo.getvalue()), b'123456789')
+
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.write(b'123')
+ with self.assertRaisesRegex(ValueError, r'\.FLUSH_.*?\.FLUSH_'):
+ f.flush(ZstdCompressor.CONTINUE)
+ with self.assertRaises(ValueError):
+ f.flush(-1)
+ with self.assertRaises(ValueError):
+ f.flush(123456)
+ with self.assertRaises(TypeError):
+ f.flush(node=ZstdCompressor.CONTINUE)
+ with self.assertRaises((TypeError, ValueError)):
+ f.flush('FLUSH_FRAME')
+ with self.assertRaises(TypeError):
+ f.flush(b'456', f.FLUSH_BLOCK)
+
+ def test_zstdfile_truncate(self):
+ with ZstdFile(io.BytesIO(), 'w') as f:
+ with self.assertRaises(io.UnsupportedOperation):
+ f.truncate(200)
+
+ def test_zstdfile_iter_issue45475(self):
+ lines = [l for l in ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE))]
+ self.assertGreater(len(lines), 0)
+
+ def test_append_new_file(self):
+ with tempfile.NamedTemporaryFile(delete=True) as tmp_f:
+ filename = tmp_f.name
+
+ with ZstdFile(filename, 'a') as f:
+ pass
+ self.assertTrue(os.path.isfile(filename))
+
+ os.remove(filename)
+
+class OpenTestCase(unittest.TestCase):
+
+ def test_binary_modes(self):
+ with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb") as f:
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+ with io.BytesIO() as bio:
+ with open(bio, "wb") as f:
+ f.write(DECOMPRESSED_100_PLUS_32KB)
+ file_data = decompress(bio.getvalue())
+ self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB)
+ with open(bio, "ab") as f:
+ f.write(DECOMPRESSED_100_PLUS_32KB)
+ file_data = decompress(bio.getvalue())
+ self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB * 2)
+
+ def test_text_modes(self):
+ # empty input
+ with self.assertRaises(EOFError):
+ with open(io.BytesIO(b''), "rt", encoding="utf-8", newline='\n') as reader:
+ for _ in reader:
+ pass
+
+ # read
+ uncompressed = THIS_FILE_STR.replace(os.linesep, "\n")
+ with open(io.BytesIO(COMPRESSED_THIS_FILE), "rt", encoding="utf-8") as f:
+ self.assertEqual(f.read(), uncompressed)
+
+ with io.BytesIO() as bio:
+ # write
+ with open(bio, "wt", encoding="utf-8") as f:
+ f.write(uncompressed)
+ file_data = decompress(bio.getvalue()).decode("utf-8")
+ self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed)
+ # append
+ with open(bio, "at", encoding="utf-8") as f:
+ f.write(uncompressed)
+ file_data = decompress(bio.getvalue()).decode("utf-8")
+ self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed * 2)
+
+ def test_bad_params(self):
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ TESTFN = pathlib.Path(tmp_f.name)
+
+ with self.assertRaises(ValueError):
+ open(TESTFN, "")
+ with self.assertRaises(ValueError):
+ open(TESTFN, "rbt")
+ with self.assertRaises(ValueError):
+ open(TESTFN, "rb", encoding="utf-8")
+ with self.assertRaises(ValueError):
+ open(TESTFN, "rb", errors="ignore")
+ with self.assertRaises(ValueError):
+ open(TESTFN, "rb", newline="\n")
+
+ os.remove(TESTFN)
+
+ def test_option(self):
+ options = {DecompressionParameter.window_log_max:25}
+ with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb", options=options) as f:
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+
+ options = {CompressionParameter.compression_level:12}
+ with io.BytesIO() as bio:
+ with open(bio, "wb", options=options) as f:
+ f.write(DECOMPRESSED_100_PLUS_32KB)
+ file_data = decompress(bio.getvalue())
+ self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB)
+
+ def test_encoding(self):
+ uncompressed = THIS_FILE_STR.replace(os.linesep, "\n")
+
+ with io.BytesIO() as bio:
+ with open(bio, "wt", encoding="utf-16-le") as f:
+ f.write(uncompressed)
+ file_data = decompress(bio.getvalue()).decode("utf-16-le")
+ self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed)
+ bio.seek(0)
+ with open(bio, "rt", encoding="utf-16-le") as f:
+ self.assertEqual(f.read().replace(os.linesep, "\n"), uncompressed)
+
+ def test_encoding_error_handler(self):
+ with io.BytesIO(compress(b"foo\xffbar")) as bio:
+ with open(bio, "rt", encoding="ascii", errors="ignore") as f:
+ self.assertEqual(f.read(), "foobar")
+
+ def test_newline(self):
+ # Test with explicit newline (universal newline mode disabled).
+ text = THIS_FILE_STR.replace(os.linesep, "\n")
+ with io.BytesIO() as bio:
+ with open(bio, "wt", encoding="utf-8", newline="\n") as f:
+ f.write(text)
+ bio.seek(0)
+ with open(bio, "rt", encoding="utf-8", newline="\r") as f:
+ self.assertEqual(f.readlines(), [text])
+
+ def test_x_mode(self):
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ TESTFN = pathlib.Path(tmp_f.name)
+
+ for mode in ("x", "xb", "xt"):
+ os.remove(TESTFN)
+
+ if mode == "xt":
+ encoding = "utf-8"
+ else:
+ encoding = None
+ with open(TESTFN, mode, encoding=encoding):
+ pass
+ with self.assertRaises(FileExistsError):
+ with open(TESTFN, mode):
+ pass
+
+ os.remove(TESTFN)
+
+ def test_open_dict(self):
+ # default
+ bi = io.BytesIO()
+ with open(bi, 'w', zstd_dict=TRAINED_DICT) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with open(bi, zstd_dict=TRAINED_DICT) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ # .as_(un)digested_dict
+ bi = io.BytesIO()
+ with open(bi, 'w', zstd_dict=TRAINED_DICT.as_digested_dict) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with open(bi, zstd_dict=TRAINED_DICT.as_undigested_dict) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ # invalid dictionary
+ bi = io.BytesIO()
+ with self.assertRaisesRegex(TypeError, 'zstd_dict'):
+ open(bi, 'w', zstd_dict={1:2, 2:3})
+
+ with self.assertRaisesRegex(TypeError, 'zstd_dict'):
+ open(bi, 'w', zstd_dict=b'1234567890')
+
+ def test_open_prefix(self):
+ bi = io.BytesIO()
+ with open(bi, 'w', zstd_dict=TRAINED_DICT.as_prefix) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with open(bi, zstd_dict=TRAINED_DICT.as_prefix) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ def test_buffer_protocol(self):
+ # don't use len() for buffer protocol objects
+ arr = array.array("i", range(1000))
+ LENGTH = len(arr) * arr.itemsize
+
+ with open(io.BytesIO(), "wb") as f:
+ self.assertEqual(f.write(arr), LENGTH)
+ self.assertEqual(f.tell(), LENGTH)
+
+class FreeThreadingMethodTests(unittest.TestCase):
+
+ @threading_helper.reap_threads
+ @threading_helper.requires_working_threading()
+ def test_compress_locking(self):
+ input = b'a'* (16*_1K)
+ num_threads = 8
+
+ comp = ZstdCompressor()
+ parts = []
+ for _ in range(num_threads):
+ res = comp.compress(input, ZstdCompressor.FLUSH_BLOCK)
+ if res:
+ parts.append(res)
+ rest1 = comp.flush()
+ expected = b''.join(parts) + rest1
+
+ comp = ZstdCompressor()
+ output = []
+ def run_method(method, input_data, output_data):
+ res = method(input_data, ZstdCompressor.FLUSH_BLOCK)
+ if res:
+ output_data.append(res)
+ threads = []
+
+ for i in range(num_threads):
+ thread = threading.Thread(target=run_method, args=(comp.compress, input, output))
+
+ threads.append(thread)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ rest2 = comp.flush()
+ self.assertEqual(rest1, rest2)
+ actual = b''.join(output) + rest2
+ self.assertEqual(expected, actual)
+
+ @threading_helper.reap_threads
+ @threading_helper.requires_working_threading()
+ def test_decompress_locking(self):
+ input = compress(b'a'* (16*_1K))
+ num_threads = 8
+ # to ensure we decompress over multiple calls, limit the max output size per call
+ window_size = _1K * 16//num_threads
+
+ decomp = ZstdDecompressor()
+ parts = []
+ for _ in range(num_threads):
+ res = decomp.decompress(input, window_size)
+ if res:
+ parts.append(res)
+ expected = b''.join(parts)
+
+ comp = ZstdDecompressor()
+ output = []
+ def run_method(method, input_data, output_data):
+ res = method(input_data, window_size)
+ if res:
+ output_data.append(res)
+ threads = []
+
+ for i in range(num_threads):
+ thread = threading.Thread(target=run_method, args=(comp.decompress, input, output))
+
+ threads.append(thread)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ actual = b''.join(output)
+ self.assertEqual(expected, actual)
+
+ @threading_helper.reap_threads
+ @threading_helper.requires_working_threading()
+ def test_compress_shared_dict(self):
+ num_threads = 8
+
+ def run_method(b):
+ level = threading.get_ident() % 4
+ # sync threads to increase chance of contention on
+ # capsule storing dictionary levels
+ b.wait()
+ ZstdCompressor(level=level,
+ zstd_dict=TRAINED_DICT.as_digested_dict)
+ b.wait()
+ ZstdCompressor(level=level,
+ zstd_dict=TRAINED_DICT.as_undigested_dict)
+ b.wait()
+ ZstdCompressor(level=level,
+ zstd_dict=TRAINED_DICT.as_prefix)
+ threads = []
+
+ b = threading.Barrier(num_threads)
+ for i in range(num_threads):
+ thread = threading.Thread(target=run_method, args=(b,))
+
+ threads.append(thread)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ @threading_helper.reap_threads
+ @threading_helper.requires_working_threading()
+ def test_decompress_shared_dict(self):
+ num_threads = 8
+
+ def run_method(b):
+ # sync threads to increase chance of contention on
+ # decompression dictionary
+ b.wait()
+ ZstdDecompressor(zstd_dict=TRAINED_DICT.as_digested_dict)
+ b.wait()
+ ZstdDecompressor(zstd_dict=TRAINED_DICT.as_undigested_dict)
+ b.wait()
+ ZstdDecompressor(zstd_dict=TRAINED_DICT.as_prefix)
+ threads = []
+
+ b = threading.Barrier(num_threads)
+ for i in range(num_threads):
+ thread = threading.Thread(target=run_method, args=(b,))
+
+ threads.append(thread)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/textwrap.py b/Lib/textwrap.py
index 00465f67d09..5ae439f5cd3 100644
--- a/Lib/textwrap.py
+++ b/Lib/textwrap.py
@@ -86,7 +86,7 @@ class TextWrapper:
-(?: (?<=%(lt)s{2}-) | (?<=%(lt)s-%(lt)s-))
(?= %(lt)s -? %(lt)s)
| # end of word
- (?=%(ws)s|\Z)
+ (?=%(ws)s|\z)
| # em-dash
(?<=%(wp)s) (?=-{2,}\w)
)
@@ -107,7 +107,7 @@ class TextWrapper:
sentence_end_re = re.compile(r'[a-z]' # lowercase letter
r'[\.\!\?]' # sentence-ending punct.
r'[\"\']?' # optional end-of-quote
- r'\Z') # end of chunk
+ r'\z') # end of chunk
def __init__(self,
width=70,
diff --git a/Lib/threading.py b/Lib/threading.py
index 39a1a7f4cdf..b6c451d1fba 100644
--- a/Lib/threading.py
+++ b/Lib/threading.py
@@ -123,7 +123,7 @@ def gettrace():
Lock = _LockType
-def RLock(*args, **kwargs):
+def RLock():
"""Factory function that returns a new reentrant lock.
A reentrant lock must be released by the thread that acquired it. Once a
@@ -132,16 +132,9 @@ def RLock(*args, **kwargs):
acquired it.
"""
- if args or kwargs:
- import warnings
- warnings.warn(
- 'Passing arguments to RLock is deprecated and will be removed in 3.15',
- DeprecationWarning,
- stacklevel=2,
- )
if _CRLock is None:
- return _PyRLock(*args, **kwargs)
- return _CRLock(*args, **kwargs)
+ return _PyRLock()
+ return _CRLock()
class _RLock:
"""This class implements reentrant lock objects.
@@ -165,7 +158,7 @@ class _RLock:
except KeyError:
pass
return "<%s %s.%s object owner=%r count=%d at %s>" % (
- "locked" if self._block.locked() else "unlocked",
+ "locked" if self.locked() else "unlocked",
self.__class__.__module__,
self.__class__.__qualname__,
owner,
@@ -244,7 +237,7 @@ class _RLock:
def locked(self):
"""Return whether this object is locked."""
- return self._count > 0
+ return self._block.locked()
# Internal methods used by condition variables
@@ -951,6 +944,8 @@ class Thread:
# This thread is alive.
self._ident = new_ident
assert self._os_thread_handle.ident == new_ident
+ if _HAVE_THREAD_NATIVE_ID:
+ self._set_native_id()
else:
# Otherwise, the thread is dead, Jim. _PyThread_AfterFork()
# already marked our handle done.
diff --git a/Lib/tokenize.py b/Lib/tokenize.py
index edb1ed8bdb8..7e71755068e 100644
--- a/Lib/tokenize.py
+++ b/Lib/tokenize.py
@@ -86,7 +86,7 @@ def _all_string_prefixes():
# The valid string prefixes. Only contain the lower case versions,
# and don't contain any permutations (include 'fr', but not
# 'rf'). The various permutations will be generated.
- _valid_string_prefixes = ['b', 'r', 'u', 'f', 'br', 'fr']
+ _valid_string_prefixes = ['b', 'r', 'u', 'f', 't', 'br', 'fr', 'tr']
# if we add binary f-strings, add: ['fb', 'fbr']
result = {''}
for prefix in _valid_string_prefixes:
@@ -132,7 +132,7 @@ ContStr = group(StringPrefix + r"'[^\n'\\]*(?:\\.[^\n'\\]*)*" +
group("'", r'\\\r?\n'),
StringPrefix + r'"[^\n"\\]*(?:\\.[^\n"\\]*)*' +
group('"', r'\\\r?\n'))
-PseudoExtras = group(r'\\\r?\n|\Z', Comment, Triple)
+PseudoExtras = group(r'\\\r?\n|\z', Comment, Triple)
PseudoToken = Whitespace + group(PseudoExtras, Number, Funny, ContStr, Name)
# For a given string prefix plus quotes, endpats maps it to a regex
@@ -274,7 +274,7 @@ class Untokenizer:
toks_append = self.tokens.append
startline = token[0] in (NEWLINE, NL)
prevstring = False
- in_fstring = 0
+ in_fstring_or_tstring = 0
for tok in _itertools.chain([token], iterable):
toknum, tokval = tok[:2]
@@ -293,10 +293,10 @@ class Untokenizer:
else:
prevstring = False
- if toknum == FSTRING_START:
- in_fstring += 1
- elif toknum == FSTRING_END:
- in_fstring -= 1
+ if toknum in {FSTRING_START, TSTRING_START}:
+ in_fstring_or_tstring += 1
+ elif toknum in {FSTRING_END, TSTRING_END}:
+ in_fstring_or_tstring -= 1
if toknum == INDENT:
indents.append(tokval)
continue
@@ -311,8 +311,8 @@ class Untokenizer:
elif toknum in {FSTRING_MIDDLE, TSTRING_MIDDLE}:
tokval = self.escape_brackets(tokval)
- # Insert a space between two consecutive brackets if we are in an f-string
- if tokval in {"{", "}"} and self.tokens and self.tokens[-1] == tokval and in_fstring:
+ # Insert a space between two consecutive brackets if we are in an f-string or t-string
+ if tokval in {"{", "}"} and self.tokens and self.tokens[-1] == tokval and in_fstring_or_tstring:
tokval = ' ' + tokval
# Insert a space between two consecutive f-strings
@@ -518,7 +518,7 @@ def _main(args=None):
sys.exit(1)
# Parse the arguments and options
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument(dest='filename', nargs='?',
metavar='filename.py',
help='the file to tokenize; defaults to stdin')
diff --git a/Lib/tomllib/_parser.py b/Lib/tomllib/_parser.py
index da56af3f34d..3ee47aa9e0a 100644
--- a/Lib/tomllib/_parser.py
+++ b/Lib/tomllib/_parser.py
@@ -214,7 +214,7 @@ class Flags:
EXPLICIT_NEST = 1
def __init__(self) -> None:
- self._flags: dict[str, dict] = {}
+ self._flags: dict[str, dict[Any, Any]] = {}
self._pending_flags: set[tuple[Key, int]] = set()
def add_pending(self, key: Key, flag: int) -> None:
@@ -272,7 +272,7 @@ class NestedDict:
key: Key,
*,
access_lists: bool = True,
- ) -> dict:
+ ) -> dict[str, Any]:
cont: Any = self.dict
for k in key:
if k not in cont:
@@ -486,9 +486,9 @@ def parse_one_line_basic_str(src: str, pos: Pos) -> tuple[Pos, str]:
return parse_basic_str(src, pos, multiline=False)
-def parse_array(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, list]:
+def parse_array(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, list[Any]]:
pos += 1
- array: list = []
+ array: list[Any] = []
pos = skip_comments_and_array_ws(src, pos)
if src.startswith("]", pos):
@@ -510,7 +510,7 @@ def parse_array(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, list]
return pos + 1, array
-def parse_inline_table(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, dict]:
+def parse_inline_table(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, dict[str, Any]]:
pos += 1
nested_dict = NestedDict()
flags = Flags()
diff --git a/Lib/tomllib/_re.py b/Lib/tomllib/_re.py
index 1ca6bef77a0..eb8beb19747 100644
--- a/Lib/tomllib/_re.py
+++ b/Lib/tomllib/_re.py
@@ -52,7 +52,7 @@ RE_DATETIME = re.compile(
)
-def match_to_datetime(match: re.Match) -> datetime | date:
+def match_to_datetime(match: re.Match[str]) -> datetime | date:
"""Convert a `RE_DATETIME` match to `datetime.datetime` or `datetime.date`.
Raises ValueError if the match does not correspond to a valid date
@@ -101,13 +101,13 @@ def cached_tz(hour_str: str, minute_str: str, sign_str: str) -> timezone:
)
-def match_to_localtime(match: re.Match) -> time:
+def match_to_localtime(match: re.Match[str]) -> time:
hour_str, minute_str, sec_str, micros_str = match.groups()
micros = int(micros_str.ljust(6, "0")) if micros_str else 0
return time(int(hour_str), int(minute_str), int(sec_str), micros)
-def match_to_number(match: re.Match, parse_float: ParseFloat) -> Any:
+def match_to_number(match: re.Match[str], parse_float: ParseFloat) -> Any:
if match.group("floatpart"):
return parse_float(match.group())
return int(match.group(), 0)
diff --git a/Lib/tomllib/mypy.ini b/Lib/tomllib/mypy.ini
index 0297d19e2c8..1761dce4556 100644
--- a/Lib/tomllib/mypy.ini
+++ b/Lib/tomllib/mypy.ini
@@ -15,5 +15,3 @@ strict = True
strict_bytes = True
local_partial_types = True
warn_unreachable = True
-# TODO(@sobolevn): remove this setting and refactor any found problems
-disallow_any_generics = False \ No newline at end of file
diff --git a/Lib/trace.py b/Lib/trace.py
index a87bc6d61a8..cf8817f4383 100644
--- a/Lib/trace.py
+++ b/Lib/trace.py
@@ -604,7 +604,7 @@ class Trace:
def main():
import argparse
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument('--version', action='version', version='trace 2.0')
grp = parser.add_argument_group('Main options',
diff --git a/Lib/traceback.py b/Lib/traceback.py
index 16ba7fc2ee8..a1f175dbbaa 100644
--- a/Lib/traceback.py
+++ b/Lib/traceback.py
@@ -10,9 +10,9 @@ import codeop
import keyword
import tokenize
import io
-from contextlib import suppress
import _colorize
-from _colorize import ANSIColors
+
+from contextlib import suppress
__all__ = ['extract_stack', 'extract_tb', 'format_exception',
'format_exception_only', 'format_list', 'format_stack',
@@ -187,15 +187,13 @@ def _format_final_exc_line(etype, value, *, insert_final_newline=True, colorize=
valuestr = _safe_string(value, 'exception')
end_char = "\n" if insert_final_newline else ""
if colorize:
- if value is None or not valuestr:
- line = f"{ANSIColors.BOLD_MAGENTA}{etype}{ANSIColors.RESET}{end_char}"
- else:
- line = f"{ANSIColors.BOLD_MAGENTA}{etype}{ANSIColors.RESET}: {ANSIColors.MAGENTA}{valuestr}{ANSIColors.RESET}{end_char}"
+ theme = _colorize.get_theme(force_color=True).traceback
else:
- if value is None or not valuestr:
- line = f"{etype}{end_char}"
- else:
- line = f"{etype}: {valuestr}{end_char}"
+ theme = _colorize.get_theme(force_no_color=True).traceback
+ if value is None or not valuestr:
+ line = f"{theme.type}{etype}{theme.reset}{end_char}"
+ else:
+ line = f"{theme.type}{etype}{theme.reset}: {theme.message}{valuestr}{theme.reset}{end_char}"
return line
@@ -539,21 +537,22 @@ class StackSummary(list):
if frame_summary.filename.startswith("<stdin>-"):
filename = "<stdin>"
if colorize:
- row.append(' File {}"{}"{}, line {}{}{}, in {}{}{}\n'.format(
- ANSIColors.MAGENTA,
- filename,
- ANSIColors.RESET,
- ANSIColors.MAGENTA,
- frame_summary.lineno,
- ANSIColors.RESET,
- ANSIColors.MAGENTA,
- frame_summary.name,
- ANSIColors.RESET,
- )
- )
+ theme = _colorize.get_theme(force_color=True).traceback
else:
- row.append(' File "{}", line {}, in {}\n'.format(
- filename, frame_summary.lineno, frame_summary.name))
+ theme = _colorize.get_theme(force_no_color=True).traceback
+ row.append(
+ ' File {}"{}"{}, line {}{}{}, in {}{}{}\n'.format(
+ theme.filename,
+ filename,
+ theme.reset,
+ theme.line_no,
+ frame_summary.lineno,
+ theme.reset,
+ theme.frame,
+ frame_summary.name,
+ theme.reset,
+ )
+ )
if frame_summary._dedented_lines and frame_summary._dedented_lines.strip():
if (
frame_summary.colno is None or
@@ -672,11 +671,11 @@ class StackSummary(list):
for color, group in itertools.groupby(itertools.zip_longest(line, carets, fillvalue=""), key=lambda x: x[1]):
caret_group = list(group)
if color == "^":
- colorized_line_parts.append(ANSIColors.BOLD_RED + "".join(char for char, _ in caret_group) + ANSIColors.RESET)
- colorized_carets_parts.append(ANSIColors.BOLD_RED + "".join(caret for _, caret in caret_group) + ANSIColors.RESET)
+ colorized_line_parts.append(theme.error_highlight + "".join(char for char, _ in caret_group) + theme.reset)
+ colorized_carets_parts.append(theme.error_highlight + "".join(caret for _, caret in caret_group) + theme.reset)
elif color == "~":
- colorized_line_parts.append(ANSIColors.RED + "".join(char for char, _ in caret_group) + ANSIColors.RESET)
- colorized_carets_parts.append(ANSIColors.RED + "".join(caret for _, caret in caret_group) + ANSIColors.RESET)
+ colorized_line_parts.append(theme.error_range + "".join(char for char, _ in caret_group) + theme.reset)
+ colorized_carets_parts.append(theme.error_range + "".join(caret for _, caret in caret_group) + theme.reset)
else:
colorized_line_parts.append("".join(char for char, _ in caret_group))
colorized_carets_parts.append("".join(caret for _, caret in caret_group))
@@ -1378,20 +1377,20 @@ class TracebackException:
"""Format SyntaxError exceptions (internal helper)."""
# Show exactly where the problem was found.
colorize = kwargs.get("colorize", False)
+ if colorize:
+ theme = _colorize.get_theme(force_color=True).traceback
+ else:
+ theme = _colorize.get_theme(force_no_color=True).traceback
filename_suffix = ''
if self.lineno is not None:
- if colorize:
- yield ' File {}"{}"{}, line {}{}{}\n'.format(
- ANSIColors.MAGENTA,
- self.filename or "<string>",
- ANSIColors.RESET,
- ANSIColors.MAGENTA,
- self.lineno,
- ANSIColors.RESET,
- )
- else:
- yield ' File "{}", line {}\n'.format(
- self.filename or "<string>", self.lineno)
+ yield ' File {}"{}"{}, line {}{}{}\n'.format(
+ theme.filename,
+ self.filename or "<string>",
+ theme.reset,
+ theme.line_no,
+ self.lineno,
+ theme.reset,
+ )
elif self.filename is not None:
filename_suffix = ' ({})'.format(self.filename)
@@ -1441,11 +1440,11 @@ class TracebackException:
# colorize from colno to end_colno
ltext = (
ltext[:colno] +
- ANSIColors.BOLD_RED + ltext[colno:end_colno] + ANSIColors.RESET +
+ theme.error_highlight + ltext[colno:end_colno] + theme.reset +
ltext[end_colno:]
)
- start_color = ANSIColors.BOLD_RED
- end_color = ANSIColors.RESET
+ start_color = theme.error_highlight
+ end_color = theme.reset
yield ' {}\n'.format(ltext)
yield ' {}{}{}{}\n'.format(
"".join(caretspace),
@@ -1456,17 +1455,15 @@ class TracebackException:
else:
yield ' {}\n'.format(ltext)
msg = self.msg or "<no detail available>"
- if colorize:
- yield "{}{}{}: {}{}{}{}\n".format(
- ANSIColors.BOLD_MAGENTA,
- stype,
- ANSIColors.RESET,
- ANSIColors.MAGENTA,
- msg,
- ANSIColors.RESET,
- filename_suffix)
- else:
- yield "{}: {}{}\n".format(stype, msg, filename_suffix)
+ yield "{}{}{}: {}{}{}{}\n".format(
+ theme.type,
+ stype,
+ theme.reset,
+ theme.message,
+ msg,
+ theme.reset,
+ filename_suffix,
+ )
def format(self, *, chain=True, _ctx=None, **kwargs):
"""Format the exception.
@@ -1598,7 +1595,11 @@ def _compute_suggestion_error(exc_value, tb, wrong_name):
if isinstance(exc_value, AttributeError):
obj = exc_value.obj
try:
- d = dir(obj)
+ try:
+ d = dir(obj)
+ except TypeError: # Attributes are unsortable, e.g. int and str
+ d = list(obj.__class__.__dict__.keys()) + list(obj.__dict__.keys())
+ d = sorted([x for x in d if isinstance(x, str)])
hide_underscored = (wrong_name[:1] != '_')
if hide_underscored and tb is not None:
while tb.tb_next is not None:
@@ -1613,7 +1614,11 @@ def _compute_suggestion_error(exc_value, tb, wrong_name):
elif isinstance(exc_value, ImportError):
try:
mod = __import__(exc_value.name)
- d = dir(mod)
+ try:
+ d = dir(mod)
+ except TypeError: # Attributes are unsortable, e.g. int and str
+ d = list(mod.__dict__.keys())
+ d = sorted([x for x in d if isinstance(x, str)])
if wrong_name[:1] != '_':
d = [x for x in d if x[:1] != '_']
except Exception:
@@ -1631,6 +1636,7 @@ def _compute_suggestion_error(exc_value, tb, wrong_name):
+ list(frame.f_globals)
+ list(frame.f_builtins)
)
+ d = [x for x in d if isinstance(x, str)]
# Check first if we are in a method and the instance
# has the wrong name as attribute
diff --git a/Lib/types.py b/Lib/types.py
index 6efac339434..cf0549315a7 100644
--- a/Lib/types.py
+++ b/Lib/types.py
@@ -250,7 +250,6 @@ class DynamicClassAttribute:
class _GeneratorWrapper:
- # TODO: Implement this in C.
def __init__(self, gen):
self.__wrapped = gen
self.__isgen = gen.__class__ is GeneratorType
@@ -305,7 +304,6 @@ def coroutine(func):
# Check if 'func' is a generator function.
# (0x20 == CO_GENERATOR)
if co_flags & 0x20:
- # TODO: Implement this in C.
co = func.__code__
# 0x100 == CO_ITERABLE_COROUTINE
func.__code__ = co.replace(co_flags=co.co_flags | 0x100)
diff --git a/Lib/typing.py b/Lib/typing.py
index f70dcd0b5b7..ed1dd4fc641 100644
--- a/Lib/typing.py
+++ b/Lib/typing.py
@@ -956,12 +956,8 @@ def evaluate_forward_ref(
"""Evaluate a forward reference as a type hint.
This is similar to calling the ForwardRef.evaluate() method,
- but unlike that method, evaluate_forward_ref() also:
-
- * Recursively evaluates forward references nested within the type hint.
- * Rejects certain objects that are not valid type hints.
- * Replaces type hints that evaluate to None with types.NoneType.
- * Supports the *FORWARDREF* and *STRING* formats.
+ but unlike that method, evaluate_forward_ref() also
+ recursively evaluates forward references nested within the type hint.
*forward_ref* must be an instance of ForwardRef. *owner*, if given,
should be the object that holds the annotations that the forward reference
@@ -981,23 +977,24 @@ def evaluate_forward_ref(
if forward_ref.__forward_arg__ in _recursive_guard:
return forward_ref
- try:
- value = forward_ref.evaluate(globals=globals, locals=locals,
- type_params=type_params, owner=owner)
- except NameError:
- if format == _lazy_annotationlib.Format.FORWARDREF:
- return forward_ref
- else:
- raise
-
- type_ = _type_check(
- value,
- "Forward references must evaluate to types.",
- is_argument=forward_ref.__forward_is_argument__,
- allow_special_forms=forward_ref.__forward_is_class__,
- )
+ if format is None:
+ format = _lazy_annotationlib.Format.VALUE
+ value = forward_ref.evaluate(globals=globals, locals=locals,
+ type_params=type_params, owner=owner, format=format)
+
+ if (isinstance(value, _lazy_annotationlib.ForwardRef)
+ and format == _lazy_annotationlib.Format.FORWARDREF):
+ return value
+
+ if isinstance(value, str):
+ value = _make_forward_ref(value, module=forward_ref.__forward_module__,
+ owner=owner or forward_ref.__owner__,
+ is_argument=forward_ref.__forward_is_argument__,
+ is_class=forward_ref.__forward_is_class__)
+ if owner is None:
+ owner = forward_ref.__owner__
return _eval_type(
- type_,
+ value,
globals,
locals,
type_params,
@@ -1649,6 +1646,9 @@ class _UnionGenericAliasMeta(type):
return True
return NotImplemented
+ def __hash__(self):
+ return hash(Union)
+
class _UnionGenericAlias(metaclass=_UnionGenericAliasMeta):
"""Compatibility hack.
@@ -2335,12 +2335,12 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False,
# This only affects ForwardRefs.
base_globals, base_locals = base_locals, base_globals
for name, value in ann.items():
- if value is None:
- value = type(None)
if isinstance(value, str):
value = _make_forward_ref(value, is_argument=False, is_class=True)
value = _eval_type(value, base_globals, base_locals, base.__type_params__,
format=format, owner=obj)
+ if value is None:
+ value = type(None)
hints[name] = value
if include_extras or format == Format.STRING:
return hints
@@ -2374,8 +2374,6 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False,
localns = globalns
type_params = getattr(obj, "__type_params__", ())
for name, value in hints.items():
- if value is None:
- value = type(None)
if isinstance(value, str):
# class-level forward refs were handled above, this must be either
# a module-level annotation or a function argument annotation
@@ -2384,7 +2382,10 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False,
is_argument=not isinstance(obj, types.ModuleType),
is_class=False,
)
- hints[name] = _eval_type(value, globalns, localns, type_params, format=format, owner=obj)
+ value = _eval_type(value, globalns, localns, type_params, format=format, owner=obj)
+ if value is None:
+ value = type(None)
+ hints[name] = value
return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()}
@@ -2906,7 +2907,7 @@ class NamedTupleMeta(type):
types = ns["__annotations__"]
field_names = list(types)
annotate = _make_eager_annotate(types)
- elif (original_annotate := _lazy_annotationlib.get_annotate_function(ns)) is not None:
+ elif (original_annotate := _lazy_annotationlib.get_annotate_from_class_namespace(ns)) is not None:
types = _lazy_annotationlib.call_annotate_function(
original_annotate, _lazy_annotationlib.Format.FORWARDREF)
field_names = list(types)
@@ -2968,7 +2969,7 @@ class NamedTupleMeta(type):
return nm_tpl
-def NamedTuple(typename, fields=_sentinel, /, **kwargs):
+def NamedTuple(typename, fields, /):
"""Typed version of namedtuple.
Usage::
@@ -2988,48 +2989,9 @@ def NamedTuple(typename, fields=_sentinel, /, **kwargs):
Employee = NamedTuple('Employee', [('name', str), ('id', int)])
"""
- if fields is _sentinel:
- if kwargs:
- deprecated_thing = "Creating NamedTuple classes using keyword arguments"
- deprecation_msg = (
- "{name} is deprecated and will be disallowed in Python {remove}. "
- "Use the class-based or functional syntax instead."
- )
- else:
- deprecated_thing = "Failing to pass a value for the 'fields' parameter"
- example = f"`{typename} = NamedTuple({typename!r}, [])`"
- deprecation_msg = (
- "{name} is deprecated and will be disallowed in Python {remove}. "
- "To create a NamedTuple class with 0 fields "
- "using the functional syntax, "
- "pass an empty list, e.g. "
- ) + example + "."
- elif fields is None:
- if kwargs:
- raise TypeError(
- "Cannot pass `None` as the 'fields' parameter "
- "and also specify fields using keyword arguments"
- )
- else:
- deprecated_thing = "Passing `None` as the 'fields' parameter"
- example = f"`{typename} = NamedTuple({typename!r}, [])`"
- deprecation_msg = (
- "{name} is deprecated and will be disallowed in Python {remove}. "
- "To create a NamedTuple class with 0 fields "
- "using the functional syntax, "
- "pass an empty list, e.g. "
- ) + example + "."
- elif kwargs:
- raise TypeError("Either list of fields or keywords"
- " can be provided to NamedTuple, not both")
- if fields is _sentinel or fields is None:
- import warnings
- warnings._deprecated(deprecated_thing, message=deprecation_msg, remove=(3, 15))
- fields = kwargs.items()
types = {n: _type_check(t, f"field {n} annotation must be a type")
for n, t in fields}
field_names = [n for n, _ in fields]
-
nt = _make_nmtuple(typename, field_names, _make_eager_annotate(types), module=_caller())
nt.__orig_bases__ = (NamedTuple,)
return nt
@@ -3084,15 +3046,17 @@ class _TypedDictMeta(type):
else:
generic_base = ()
+ ns_annotations = ns.pop('__annotations__', None)
+
tp_dict = type.__new__(_TypedDictMeta, name, (*generic_base, dict), ns)
if not hasattr(tp_dict, '__orig_bases__'):
tp_dict.__orig_bases__ = bases
- if "__annotations__" in ns:
+ if ns_annotations is not None:
own_annotate = None
- own_annotations = ns["__annotations__"]
- elif (own_annotate := _lazy_annotationlib.get_annotate_function(ns)) is not None:
+ own_annotations = ns_annotations
+ elif (own_annotate := _lazy_annotationlib.get_annotate_from_class_namespace(ns)) is not None:
own_annotations = _lazy_annotationlib.call_annotate_function(
own_annotate, _lazy_annotationlib.Format.FORWARDREF, owner=tp_dict
)
@@ -3162,7 +3126,7 @@ class _TypedDictMeta(type):
if base_annotate is None:
continue
base_annos = _lazy_annotationlib.call_annotate_function(
- base.__annotate__, format, owner=base)
+ base_annotate, format, owner=base)
annos.update(base_annos)
if own_annotate is not None:
own = _lazy_annotationlib.call_annotate_function(
@@ -3198,7 +3162,7 @@ class _TypedDictMeta(type):
__instancecheck__ = __subclasscheck__
-def TypedDict(typename, fields=_sentinel, /, *, total=True):
+def TypedDict(typename, fields, /, *, total=True):
"""A simple typed namespace. At runtime it is equivalent to a plain dict.
TypedDict creates a dictionary type such that a type checker will expect all
@@ -3253,24 +3217,6 @@ def TypedDict(typename, fields=_sentinel, /, *, total=True):
username: str # the "username" key can be changed
"""
- if fields is _sentinel or fields is None:
- import warnings
-
- if fields is _sentinel:
- deprecated_thing = "Failing to pass a value for the 'fields' parameter"
- else:
- deprecated_thing = "Passing `None` as the 'fields' parameter"
-
- example = f"`{typename} = TypedDict({typename!r}, {{{{}}}})`"
- deprecation_msg = (
- "{name} is deprecated and will be disallowed in Python {remove}. "
- "To create a TypedDict class with 0 fields "
- "using the functional syntax, "
- "pass an empty dictionary, e.g. "
- ) + example + "."
- warnings._deprecated(deprecated_thing, message=deprecation_msg, remove=(3, 15))
- fields = {}
-
ns = {'__annotations__': dict(fields)}
module = _caller()
if module is not None:
@@ -3477,7 +3423,7 @@ class IO(Generic[AnyStr]):
pass
@abstractmethod
- def readlines(self, hint: int = -1) -> List[AnyStr]:
+ def readlines(self, hint: int = -1) -> list[AnyStr]:
pass
@abstractmethod
@@ -3493,7 +3439,7 @@ class IO(Generic[AnyStr]):
pass
@abstractmethod
- def truncate(self, size: int = None) -> int:
+ def truncate(self, size: int | None = None) -> int:
pass
@abstractmethod
@@ -3505,11 +3451,11 @@ class IO(Generic[AnyStr]):
pass
@abstractmethod
- def writelines(self, lines: List[AnyStr]) -> None:
+ def writelines(self, lines: list[AnyStr]) -> None:
pass
@abstractmethod
- def __enter__(self) -> 'IO[AnyStr]':
+ def __enter__(self) -> IO[AnyStr]:
pass
@abstractmethod
@@ -3523,11 +3469,11 @@ class BinaryIO(IO[bytes]):
__slots__ = ()
@abstractmethod
- def write(self, s: Union[bytes, bytearray]) -> int:
+ def write(self, s: bytes | bytearray) -> int:
pass
@abstractmethod
- def __enter__(self) -> 'BinaryIO':
+ def __enter__(self) -> BinaryIO:
pass
@@ -3548,7 +3494,7 @@ class TextIO(IO[str]):
@property
@abstractmethod
- def errors(self) -> Optional[str]:
+ def errors(self) -> str | None:
pass
@property
@@ -3562,7 +3508,7 @@ class TextIO(IO[str]):
pass
@abstractmethod
- def __enter__(self) -> 'TextIO':
+ def __enter__(self) -> TextIO:
pass
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py
index 884fc1b21f6..db10de68e4a 100644
--- a/Lib/unittest/case.py
+++ b/Lib/unittest/case.py
@@ -149,9 +149,7 @@ def doModuleCleanups():
except Exception as exc:
exceptions.append(exc)
if exceptions:
- # Swallows all but first exception. If a multi-exception handler
- # gets written we should use that here instead.
- raise exceptions[0]
+ raise ExceptionGroup('module cleanup failed', exceptions)
def skip(reason):
diff --git a/Lib/unittest/main.py b/Lib/unittest/main.py
index c3869de3f6f..6fd949581f3 100644
--- a/Lib/unittest/main.py
+++ b/Lib/unittest/main.py
@@ -197,7 +197,7 @@ class TestProgram(object):
return parser
def _getMainArgParser(self, parent):
- parser = argparse.ArgumentParser(parents=[parent])
+ parser = argparse.ArgumentParser(parents=[parent], color=True)
parser.prog = self.progName
parser.print_help = self._print_help
@@ -208,7 +208,7 @@ class TestProgram(object):
return parser
def _getDiscoveryArgParser(self, parent):
- parser = argparse.ArgumentParser(parents=[parent])
+ parser = argparse.ArgumentParser(parents=[parent], color=True)
parser.prog = '%s discover' % self.progName
parser.epilog = ('For test discovery all test modules must be '
'importable from the top level directory of the '
diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py
index 55cb4b1f6af..e1dbfdacf56 100644
--- a/Lib/unittest/mock.py
+++ b/Lib/unittest/mock.py
@@ -569,6 +569,11 @@ class NonCallableMock(Base):
__dict__['_mock_methods'] = spec
__dict__['_spec_asyncs'] = _spec_asyncs
+ def _mock_extend_spec_methods(self, spec_methods):
+ methods = self.__dict__.get('_mock_methods') or []
+ methods.extend(spec_methods)
+ self.__dict__['_mock_methods'] = methods
+
def __get_return_value(self):
ret = self._mock_return_value
if self._mock_delegate is not None:
@@ -981,7 +986,7 @@ class NonCallableMock(Base):
def assert_called_once_with(self, /, *args, **kwargs):
- """assert that the mock was called exactly once and that that call was
+ """assert that the mock was called exactly once and that call was
with the specified arguments."""
if not self.call_count == 1:
msg = ("Expected '%s' to be called once. Called %s times.%s"
@@ -2766,14 +2771,16 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
raise InvalidSpecError(f'Cannot autospec a Mock object. '
f'[object={spec!r}]')
is_async_func = _is_async_func(spec)
+ _kwargs = {'spec': spec}
entries = [(entry, _missing) for entry in dir(spec)]
if is_type and instance and is_dataclass(spec):
+ is_dataclass_spec = True
dataclass_fields = fields(spec)
entries.extend((f.name, f.type) for f in dataclass_fields)
- _kwargs = {'spec': [f.name for f in dataclass_fields]}
+ dataclass_spec_list = [f.name for f in dataclass_fields]
else:
- _kwargs = {'spec': spec}
+ is_dataclass_spec = False
if spec_set:
_kwargs = {'spec_set': spec}
@@ -2810,6 +2817,8 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
mock = Klass(parent=_parent, _new_parent=_parent, _new_name=_new_name,
name=_name, **_kwargs)
+ if is_dataclass_spec:
+ mock._mock_extend_spec_methods(dataclass_spec_list)
if isinstance(spec, FunctionTypes):
# should only happen at the top level because we don't
diff --git a/Lib/unittest/runner.py b/Lib/unittest/runner.py
index eb0234a2617..5f22d91aebd 100644
--- a/Lib/unittest/runner.py
+++ b/Lib/unittest/runner.py
@@ -4,7 +4,7 @@ import sys
import time
import warnings
-from _colorize import get_colors
+from _colorize import get_theme
from . import result
from .case import _SubTest
@@ -45,7 +45,7 @@ class TextTestResult(result.TestResult):
self.showAll = verbosity > 1
self.dots = verbosity == 1
self.descriptions = descriptions
- self._ansi = get_colors(file=stream)
+ self._theme = get_theme(tty_file=stream).unittest
self._newline = True
self.durations = durations
@@ -79,101 +79,99 @@ class TextTestResult(result.TestResult):
def addSubTest(self, test, subtest, err):
if err is not None:
- red, reset = self._ansi.RED, self._ansi.RESET
+ t = self._theme
if self.showAll:
if issubclass(err[0], subtest.failureException):
- self._write_status(subtest, f"{red}FAIL{reset}")
+ self._write_status(subtest, f"{t.fail}FAIL{t.reset}")
else:
- self._write_status(subtest, f"{red}ERROR{reset}")
+ self._write_status(subtest, f"{t.fail}ERROR{t.reset}")
elif self.dots:
if issubclass(err[0], subtest.failureException):
- self.stream.write(f"{red}F{reset}")
+ self.stream.write(f"{t.fail}F{t.reset}")
else:
- self.stream.write(f"{red}E{reset}")
+ self.stream.write(f"{t.fail}E{t.reset}")
self.stream.flush()
super(TextTestResult, self).addSubTest(test, subtest, err)
def addSuccess(self, test):
super(TextTestResult, self).addSuccess(test)
- green, reset = self._ansi.GREEN, self._ansi.RESET
+ t = self._theme
if self.showAll:
- self._write_status(test, f"{green}ok{reset}")
+ self._write_status(test, f"{t.passed}ok{t.reset}")
elif self.dots:
- self.stream.write(f"{green}.{reset}")
+ self.stream.write(f"{t.passed}.{t.reset}")
self.stream.flush()
def addError(self, test, err):
super(TextTestResult, self).addError(test, err)
- red, reset = self._ansi.RED, self._ansi.RESET
+ t = self._theme
if self.showAll:
- self._write_status(test, f"{red}ERROR{reset}")
+ self._write_status(test, f"{t.fail}ERROR{t.reset}")
elif self.dots:
- self.stream.write(f"{red}E{reset}")
+ self.stream.write(f"{t.fail}E{t.reset}")
self.stream.flush()
def addFailure(self, test, err):
super(TextTestResult, self).addFailure(test, err)
- red, reset = self._ansi.RED, self._ansi.RESET
+ t = self._theme
if self.showAll:
- self._write_status(test, f"{red}FAIL{reset}")
+ self._write_status(test, f"{t.fail}FAIL{t.reset}")
elif self.dots:
- self.stream.write(f"{red}F{reset}")
+ self.stream.write(f"{t.fail}F{t.reset}")
self.stream.flush()
def addSkip(self, test, reason):
super(TextTestResult, self).addSkip(test, reason)
- yellow, reset = self._ansi.YELLOW, self._ansi.RESET
+ t = self._theme
if self.showAll:
- self._write_status(test, f"{yellow}skipped{reset} {reason!r}")
+ self._write_status(test, f"{t.warn}skipped{t.reset} {reason!r}")
elif self.dots:
- self.stream.write(f"{yellow}s{reset}")
+ self.stream.write(f"{t.warn}s{t.reset}")
self.stream.flush()
def addExpectedFailure(self, test, err):
super(TextTestResult, self).addExpectedFailure(test, err)
- yellow, reset = self._ansi.YELLOW, self._ansi.RESET
+ t = self._theme
if self.showAll:
- self.stream.writeln(f"{yellow}expected failure{reset}")
+ self.stream.writeln(f"{t.warn}expected failure{t.reset}")
self.stream.flush()
elif self.dots:
- self.stream.write(f"{yellow}x{reset}")
+ self.stream.write(f"{t.warn}x{t.reset}")
self.stream.flush()
def addUnexpectedSuccess(self, test):
super(TextTestResult, self).addUnexpectedSuccess(test)
- red, reset = self._ansi.RED, self._ansi.RESET
+ t = self._theme
if self.showAll:
- self.stream.writeln(f"{red}unexpected success{reset}")
+ self.stream.writeln(f"{t.fail}unexpected success{t.reset}")
self.stream.flush()
elif self.dots:
- self.stream.write(f"{red}u{reset}")
+ self.stream.write(f"{t.fail}u{t.reset}")
self.stream.flush()
def printErrors(self):
- bold_red = self._ansi.BOLD_RED
- red = self._ansi.RED
- reset = self._ansi.RESET
+ t = self._theme
if self.dots or self.showAll:
self.stream.writeln()
self.stream.flush()
- self.printErrorList(f"{red}ERROR{reset}", self.errors)
- self.printErrorList(f"{red}FAIL{reset}", self.failures)
+ self.printErrorList(f"{t.fail}ERROR{t.reset}", self.errors)
+ self.printErrorList(f"{t.fail}FAIL{t.reset}", self.failures)
unexpectedSuccesses = getattr(self, "unexpectedSuccesses", ())
if unexpectedSuccesses:
self.stream.writeln(self.separator1)
for test in unexpectedSuccesses:
self.stream.writeln(
- f"{red}UNEXPECTED SUCCESS{bold_red}: "
- f"{self.getDescription(test)}{reset}"
+ f"{t.fail}UNEXPECTED SUCCESS{t.fail_info}: "
+ f"{self.getDescription(test)}{t.reset}"
)
self.stream.flush()
def printErrorList(self, flavour, errors):
- bold_red, reset = self._ansi.BOLD_RED, self._ansi.RESET
+ t = self._theme
for test, err in errors:
self.stream.writeln(self.separator1)
self.stream.writeln(
- f"{flavour}{bold_red}: {self.getDescription(test)}{reset}"
+ f"{flavour}{t.fail_info}: {self.getDescription(test)}{t.reset}"
)
self.stream.writeln(self.separator2)
self.stream.writeln("%s" % err)
@@ -286,31 +284,26 @@ class TextTestRunner(object):
expected_fails, unexpected_successes, skipped = results
infos = []
- ansi = get_colors(file=self.stream)
- bold_red = ansi.BOLD_RED
- green = ansi.GREEN
- red = ansi.RED
- reset = ansi.RESET
- yellow = ansi.YELLOW
+ t = get_theme(tty_file=self.stream).unittest
if not result.wasSuccessful():
- self.stream.write(f"{bold_red}FAILED{reset}")
+ self.stream.write(f"{t.fail_info}FAILED{t.reset}")
failed, errored = len(result.failures), len(result.errors)
if failed:
- infos.append(f"{bold_red}failures={failed}{reset}")
+ infos.append(f"{t.fail_info}failures={failed}{t.reset}")
if errored:
- infos.append(f"{bold_red}errors={errored}{reset}")
+ infos.append(f"{t.fail_info}errors={errored}{t.reset}")
elif run == 0 and not skipped:
- self.stream.write(f"{yellow}NO TESTS RAN{reset}")
+ self.stream.write(f"{t.warn}NO TESTS RAN{t.reset}")
else:
- self.stream.write(f"{green}OK{reset}")
+ self.stream.write(f"{t.passed}OK{t.reset}")
if skipped:
- infos.append(f"{yellow}skipped={skipped}{reset}")
+ infos.append(f"{t.warn}skipped={skipped}{t.reset}")
if expected_fails:
- infos.append(f"{yellow}expected failures={expected_fails}{reset}")
+ infos.append(f"{t.warn}expected failures={expected_fails}{t.reset}")
if unexpected_successes:
infos.append(
- f"{red}unexpected successes={unexpected_successes}{reset}"
+ f"{t.fail}unexpected successes={unexpected_successes}{t.reset}"
)
if infos:
self.stream.writeln(" (%s)" % (", ".join(infos),))
diff --git a/Lib/unittest/suite.py b/Lib/unittest/suite.py
index 6f45b6fe5f6..ae9ca2d615d 100644
--- a/Lib/unittest/suite.py
+++ b/Lib/unittest/suite.py
@@ -223,6 +223,11 @@ class TestSuite(BaseTestSuite):
if result._moduleSetUpFailed:
try:
case.doModuleCleanups()
+ except ExceptionGroup as eg:
+ for e in eg.exceptions:
+ self._createClassOrModuleLevelException(result, e,
+ 'setUpModule',
+ currentModule)
except Exception as e:
self._createClassOrModuleLevelException(result, e,
'setUpModule',
@@ -235,15 +240,15 @@ class TestSuite(BaseTestSuite):
errorName = f'{method_name} ({parent})'
self._addClassOrModuleLevelException(result, exc, errorName, info)
- def _addClassOrModuleLevelException(self, result, exception, errorName,
+ def _addClassOrModuleLevelException(self, result, exc, errorName,
info=None):
error = _ErrorHolder(errorName)
addSkip = getattr(result, 'addSkip', None)
- if addSkip is not None and isinstance(exception, case.SkipTest):
- addSkip(error, str(exception))
+ if addSkip is not None and isinstance(exc, case.SkipTest):
+ addSkip(error, str(exc))
else:
if not info:
- result.addError(error, sys.exc_info())
+ result.addError(error, (type(exc), exc, exc.__traceback__))
else:
result.addError(error, info)
@@ -273,6 +278,13 @@ class TestSuite(BaseTestSuite):
previousModule)
try:
case.doModuleCleanups()
+ except ExceptionGroup as eg:
+ if isinstance(result, _DebugResult):
+ raise
+ for e in eg.exceptions:
+ self._createClassOrModuleLevelException(result, e,
+ 'tearDownModule',
+ previousModule)
except Exception as e:
if isinstance(result, _DebugResult):
raise
diff --git a/Lib/urllib/parse.py b/Lib/urllib/parse.py
index 9d51f4c6812..67d9bbea0d3 100644
--- a/Lib/urllib/parse.py
+++ b/Lib/urllib/parse.py
@@ -460,7 +460,7 @@ def _check_bracketed_netloc(netloc):
# https://www.rfc-editor.org/rfc/rfc3986#page-49 and https://url.spec.whatwg.org/
def _check_bracketed_host(hostname):
if hostname.startswith('v'):
- if not re.match(r"\Av[a-fA-F0-9]+\..+\Z", hostname):
+ if not re.match(r"\Av[a-fA-F0-9]+\..+\z", hostname):
raise ValueError(f"IPvFuture address is invalid")
else:
ip = ipaddress.ip_address(hostname) # Throws Value Error if not IPv6 or IPv4
diff --git a/Lib/urllib/request.py b/Lib/urllib/request.py
index 9a6b29a90a2..41dc5d7b35d 100644
--- a/Lib/urllib/request.py
+++ b/Lib/urllib/request.py
@@ -1466,7 +1466,7 @@ class FileHandler(BaseHandler):
def open_local_file(self, req):
import email.utils
import mimetypes
- localfile = url2pathname(req.full_url, require_scheme=True)
+ localfile = url2pathname(req.full_url, require_scheme=True, resolve_host=True)
try:
stats = os.stat(localfile)
size = stats.st_size
@@ -1482,7 +1482,7 @@ class FileHandler(BaseHandler):
file_open = open_local_file
-def _is_local_authority(authority):
+def _is_local_authority(authority, resolve):
# Compare hostnames
if not authority or authority == 'localhost':
return True
@@ -1494,9 +1494,11 @@ def _is_local_authority(authority):
if authority == hostname:
return True
# Compare IP addresses
+ if not resolve:
+ return False
try:
address = socket.gethostbyname(authority)
- except (socket.gaierror, AttributeError):
+ except (socket.gaierror, AttributeError, UnicodeEncodeError):
return False
return address in FileHandler().get_names()
@@ -1641,13 +1643,16 @@ class DataHandler(BaseHandler):
return addinfourl(io.BytesIO(data), headers, url)
-# Code move from the old urllib module
+# Code moved from the old urllib module
-def url2pathname(url, *, require_scheme=False):
+def url2pathname(url, *, require_scheme=False, resolve_host=False):
"""Convert the given file URL to a local file system path.
The 'file:' scheme prefix must be omitted unless *require_scheme*
is set to true.
+
+ The URL authority may be resolved with gethostbyname() if
+ *resolve_host* is set to true.
"""
if require_scheme:
scheme, url = _splittype(url)
@@ -1655,7 +1660,7 @@ def url2pathname(url, *, require_scheme=False):
raise URLError("URL is missing a 'file:' scheme")
authority, url = _splithost(url)
if os.name == 'nt':
- if not _is_local_authority(authority):
+ if not _is_local_authority(authority, resolve_host):
# e.g. file://server/share/file.txt
url = '//' + authority + url
elif url[:3] == '///':
@@ -1669,7 +1674,7 @@ def url2pathname(url, *, require_scheme=False):
# Older URLs use a pipe after a drive letter
url = url[:1] + ':' + url[2:]
url = url.replace('/', '\\')
- elif not _is_local_authority(authority):
+ elif not _is_local_authority(authority, resolve_host):
raise URLError("file:// scheme is supported only on localhost")
encoding = sys.getfilesystemencoding()
errors = sys.getfilesystemencodeerrors()
diff --git a/Lib/uuid.py b/Lib/uuid.py
index 2c16c3f0f5a..313f2fc46cb 100644
--- a/Lib/uuid.py
+++ b/Lib/uuid.py
@@ -633,39 +633,43 @@ def _netstat_getnode():
try:
import _uuid
_generate_time_safe = getattr(_uuid, "generate_time_safe", None)
+ _has_stable_extractable_node = _uuid.has_stable_extractable_node
_UuidCreate = getattr(_uuid, "UuidCreate", None)
except ImportError:
_uuid = None
_generate_time_safe = None
+ _has_stable_extractable_node = False
_UuidCreate = None
def _unix_getnode():
"""Get the hardware address on Unix using the _uuid extension module."""
- if _generate_time_safe:
+ if _generate_time_safe and _has_stable_extractable_node:
uuid_time, _ = _generate_time_safe()
return UUID(bytes=uuid_time).node
def _windll_getnode():
"""Get the hardware address on Windows using the _uuid extension module."""
- if _UuidCreate:
+ if _UuidCreate and _has_stable_extractable_node:
uuid_bytes = _UuidCreate()
return UUID(bytes_le=uuid_bytes).node
def _random_getnode():
"""Get a random node ID."""
- # RFC 4122, $4.1.6 says "For systems with no IEEE address, a randomly or
- # pseudo-randomly generated value may be used; see Section 4.5. The
- # multicast bit must be set in such addresses, in order that they will
- # never conflict with addresses obtained from network cards."
+ # RFC 9562, §6.10-3 says that
+ #
+ # Implementations MAY elect to obtain a 48-bit cryptographic-quality
+ # random number as per Section 6.9 to use as the Node ID. [...] [and]
+ # implementations MUST set the least significant bit of the first octet
+ # of the Node ID to 1. This bit is the unicast or multicast bit, which
+ # will never be set in IEEE 802 addresses obtained from network cards.
#
# The "multicast bit" of a MAC address is defined to be "the least
# significant bit of the first octet". This works out to be the 41st bit
# counting from 1 being the least significant bit, or 1<<40.
#
# See https://en.wikipedia.org/w/index.php?title=MAC_address&oldid=1128764812#Universal_vs._local_(U/L_bit)
- import random
- return random.getrandbits(48) | (1 << 40)
+ return int.from_bytes(os.urandom(6)) | (1 << 40)
# _OS_GETTERS, when known, are targeted for a specific OS or platform.
@@ -949,7 +953,9 @@ def main():
import argparse
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- description="Generate a UUID using the selected UUID function.")
+ description="Generate a UUID using the selected UUID function.",
+ color=True,
+ )
parser.add_argument("-u", "--uuid",
choices=uuid_funcs.keys(),
default="uuid4",
diff --git a/Lib/venv/__init__.py b/Lib/venv/__init__.py
index dc4c9ef3531..dc9c5991df7 100644
--- a/Lib/venv/__init__.py
+++ b/Lib/venv/__init__.py
@@ -313,11 +313,8 @@ class EnvBuilder:
copier(context.executable, path)
if not os.path.islink(path):
os.chmod(path, 0o755)
-
- suffixes = ['python', 'python3', f'python3.{sys.version_info[1]}']
- if sys.version_info[:2] == (3, 14):
- suffixes.append('𝜋thon')
- for suffix in suffixes:
+ for suffix in ('python', 'python3',
+ f'python3.{sys.version_info[1]}'):
path = os.path.join(binpath, suffix)
if not os.path.exists(path):
# Issue 18807: make copies if
@@ -624,7 +621,9 @@ def main(args=None):
'created, you may wish to '
'activate it, e.g. by '
'sourcing an activate script '
- 'in its bin directory.')
+ 'in its bin directory.',
+ color=True,
+ )
parser.add_argument('dirs', metavar='ENV_DIR', nargs='+',
help='A directory to create the environment in.')
parser.add_argument('--system-site-packages', default=False,
diff --git a/Lib/wave.py b/Lib/wave.py
index a34af244c3e..929609fa524 100644
--- a/Lib/wave.py
+++ b/Lib/wave.py
@@ -20,10 +20,6 @@ This returns an instance of a class with the following public methods:
compression type ('not compressed' linear samples)
getparams() -- returns a namedtuple consisting of all of the
above in the above order
- getmarkers() -- returns None (for compatibility with the
- old aifc module)
- getmark(id) -- raises an error since the mark does not
- exist (for compatibility with the old aifc module)
readframes(n) -- returns at most n frames of audio
rewind() -- rewind to the beginning of the audio stream
setpos(pos) -- seek to the specified position
@@ -341,16 +337,6 @@ class Wave_read:
self.getframerate(), self.getnframes(),
self.getcomptype(), self.getcompname())
- def getmarkers(self):
- import warnings
- warnings._deprecated("Wave_read.getmarkers", remove=(3, 15))
- return None
-
- def getmark(self, id):
- import warnings
- warnings._deprecated("Wave_read.getmark", remove=(3, 15))
- raise Error('no marks')
-
def setpos(self, pos):
if pos < 0 or pos > self._nframes:
raise Error('position not in range')
@@ -551,21 +537,6 @@ class Wave_write:
return _wave_params(self._nchannels, self._sampwidth, self._framerate,
self._nframes, self._comptype, self._compname)
- def setmark(self, id, pos, name):
- import warnings
- warnings._deprecated("Wave_write.setmark", remove=(3, 15))
- raise Error('setmark() not supported')
-
- def getmark(self, id):
- import warnings
- warnings._deprecated("Wave_write.getmark", remove=(3, 15))
- raise Error('no marks')
-
- def getmarkers(self):
- import warnings
- warnings._deprecated("Wave_write.getmarkers", remove=(3, 15))
- return None
-
def tell(self):
return self._nframeswritten
diff --git a/Lib/webbrowser.py b/Lib/webbrowser.py
index ab50ec1ee95..f2e2394089d 100644
--- a/Lib/webbrowser.py
+++ b/Lib/webbrowser.py
@@ -719,7 +719,9 @@ if sys.platform == "ios":
def parse_args(arg_list: list[str] | None):
import argparse
- parser = argparse.ArgumentParser(description="Open URL in a web browser.")
+ parser = argparse.ArgumentParser(
+ description="Open URL in a web browser.", color=True,
+ )
parser.add_argument("url", help="URL to open")
group = parser.add_mutually_exclusive_group()
diff --git a/Lib/wsgiref/handlers.py b/Lib/wsgiref/handlers.py
index cafe872c7aa..9353fb67862 100644
--- a/Lib/wsgiref/handlers.py
+++ b/Lib/wsgiref/handlers.py
@@ -69,7 +69,8 @@ def read_environ():
# Python 3's http.server.CGIHTTPRequestHandler decodes
# using the urllib.unquote default of UTF-8, amongst other
- # issues.
+ # issues. While the CGI handler is removed in 3.15, this
+ # is kept for legacy reasons.
elif (
software.startswith('simplehttp/')
and 'python/3' in software
diff --git a/Lib/xmlrpc/server.py b/Lib/xmlrpc/server.py
index 90a356fbb8e..8130c739af2 100644
--- a/Lib/xmlrpc/server.py
+++ b/Lib/xmlrpc/server.py
@@ -578,7 +578,7 @@ class SimpleXMLRPCServer(socketserver.TCPServer,
"""
allow_reuse_address = True
- allow_reuse_port = True
+ allow_reuse_port = False
# Warning: this is for debugging purposes only! Never set this to True in
# production code, as will be sending out sensitive information (exception
diff --git a/Lib/zipapp.py b/Lib/zipapp.py
index 59b444075a6..7a4ef96ea0f 100644
--- a/Lib/zipapp.py
+++ b/Lib/zipapp.py
@@ -187,7 +187,7 @@ def main(args=None):
"""
import argparse
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument('--output', '-o', default=None,
help="The name of the output archive. "
"Required if SOURCE is an archive.")
diff --git a/Lib/zipfile/__init__.py b/Lib/zipfile/__init__.py
index b7840d0f945..18caeb3e04a 100644
--- a/Lib/zipfile/__init__.py
+++ b/Lib/zipfile/__init__.py
@@ -31,10 +31,15 @@ try:
except ImportError:
lzma = None
+try:
+ from compression import zstd # We may need its compression method
+except ImportError:
+ zstd = None
+
__all__ = ["BadZipFile", "BadZipfile", "error",
"ZIP_STORED", "ZIP_DEFLATED", "ZIP_BZIP2", "ZIP_LZMA",
- "is_zipfile", "ZipInfo", "ZipFile", "PyZipFile", "LargeZipFile",
- "Path"]
+ "ZIP_ZSTANDARD", "is_zipfile", "ZipInfo", "ZipFile", "PyZipFile",
+ "LargeZipFile", "Path"]
class BadZipFile(Exception):
pass
@@ -58,12 +63,14 @@ ZIP_STORED = 0
ZIP_DEFLATED = 8
ZIP_BZIP2 = 12
ZIP_LZMA = 14
+ZIP_ZSTANDARD = 93
# Other ZIP compression methods not supported
DEFAULT_VERSION = 20
ZIP64_VERSION = 45
BZIP2_VERSION = 46
LZMA_VERSION = 63
+ZSTANDARD_VERSION = 63
# we recognize (but not necessarily support) all features up to that version
MAX_EXTRACT_VERSION = 63
@@ -227,8 +234,19 @@ class _Extra(bytes):
def _check_zipfile(fp):
try:
- if _EndRecData(fp):
- return True # file has correct magic number
+ endrec = _EndRecData(fp)
+ if endrec:
+ if endrec[_ECD_ENTRIES_TOTAL] == 0 and endrec[_ECD_SIZE] == 0 and endrec[_ECD_OFFSET] == 0:
+ return True # Empty zipfiles are still zipfiles
+ elif endrec[_ECD_DISK_NUMBER] == endrec[_ECD_DISK_START]:
+ # Central directory is on the same disk
+ fp.seek(sum(_handle_prepended_data(endrec)))
+ if endrec[_ECD_SIZE] >= sizeCentralDir:
+ data = fp.read(sizeCentralDir) # CD is where we expect it to be
+ if len(data) == sizeCentralDir:
+ centdir = struct.unpack(structCentralDir, data) # CD is the right size
+ if centdir[_CD_SIGNATURE] == stringCentralDir:
+ return True # First central directory entry has correct magic number
except OSError:
pass
return False
@@ -251,6 +269,22 @@ def is_zipfile(filename):
pass
return result
+def _handle_prepended_data(endrec, debug=0):
+ size_cd = endrec[_ECD_SIZE] # bytes in central directory
+ offset_cd = endrec[_ECD_OFFSET] # offset of central directory
+
+ # "concat" is zero, unless zip was concatenated to another file
+ concat = endrec[_ECD_LOCATION] - size_cd - offset_cd
+ if endrec[_ECD_SIGNATURE] == stringEndArchive64:
+ # If Zip64 extension structures are present, account for them
+ concat -= (sizeEndCentDir64 + sizeEndCentDir64Locator)
+
+ if debug > 2:
+ inferred = concat + offset_cd
+ print("given, inferred, offset", offset_cd, inferred, concat)
+
+ return offset_cd, concat
+
def _EndRecData64(fpin, offset, endrec):
"""
Read the ZIP64 end-of-archive records and use that to update endrec
@@ -505,6 +539,8 @@ class ZipInfo:
min_version = max(BZIP2_VERSION, min_version)
elif self.compress_type == ZIP_LZMA:
min_version = max(LZMA_VERSION, min_version)
+ elif self.compress_type == ZIP_ZSTANDARD:
+ min_version = max(ZSTANDARD_VERSION, min_version)
self.extract_version = max(min_version, self.extract_version)
self.create_version = max(min_version, self.create_version)
@@ -766,6 +802,7 @@ compressor_names = {
14: 'lzma',
18: 'terse',
19: 'lz77',
+ 93: 'zstd',
97: 'wavpack',
98: 'ppmd',
}
@@ -785,6 +822,10 @@ def _check_compression(compression):
if not lzma:
raise RuntimeError(
"Compression requires the (missing) lzma module")
+ elif compression == ZIP_ZSTANDARD:
+ if not zstd:
+ raise RuntimeError(
+ "Compression requires the (missing) compression.zstd module")
else:
raise NotImplementedError("That compression method is not supported")
@@ -801,6 +842,8 @@ def _get_compressor(compress_type, compresslevel=None):
# compresslevel is ignored for ZIP_LZMA
elif compress_type == ZIP_LZMA:
return LZMACompressor()
+ elif compress_type == ZIP_ZSTANDARD:
+ return zstd.ZstdCompressor(level=compresslevel)
else:
return None
@@ -815,6 +858,8 @@ def _get_decompressor(compress_type):
return bz2.BZ2Decompressor()
elif compress_type == ZIP_LZMA:
return LZMADecompressor()
+ elif compress_type == ZIP_ZSTANDARD:
+ return zstd.ZstdDecompressor()
else:
descr = compressor_names.get(compress_type)
if descr:
@@ -1334,7 +1379,8 @@ class ZipFile:
mode: The mode can be either read 'r', write 'w', exclusive create 'x',
or append 'a'.
compression: ZIP_STORED (no compression), ZIP_DEFLATED (requires zlib),
- ZIP_BZIP2 (requires bz2) or ZIP_LZMA (requires lzma).
+ ZIP_BZIP2 (requires bz2), ZIP_LZMA (requires lzma), or
+ ZIP_ZSTANDARD (requires compression.zstd).
allowZip64: if True ZipFile will create files with ZIP64 extensions when
needed, otherwise it will raise an exception when this would
be necessary.
@@ -1343,6 +1389,9 @@ class ZipFile:
When using ZIP_STORED or ZIP_LZMA this keyword has no effect.
When using ZIP_DEFLATED integers 0 through 9 are accepted.
When using ZIP_BZIP2 integers 1 through 9 are accepted.
+ When using ZIP_ZSTANDARD integers -7 through 22 are common,
+ see the CompressionParameter enum in compression.zstd for
+ details.
"""
@@ -1479,28 +1528,21 @@ class ZipFile:
raise BadZipFile("File is not a zip file")
if self.debug > 1:
print(endrec)
- size_cd = endrec[_ECD_SIZE] # bytes in central directory
- offset_cd = endrec[_ECD_OFFSET] # offset of central directory
self._comment = endrec[_ECD_COMMENT] # archive comment
- # "concat" is zero, unless zip was concatenated to another file
- concat = endrec[_ECD_LOCATION] - size_cd - offset_cd
- if endrec[_ECD_SIGNATURE] == stringEndArchive64:
- # If Zip64 extension structures are present, account for them
- concat -= (sizeEndCentDir64 + sizeEndCentDir64Locator)
+ offset_cd, concat = _handle_prepended_data(endrec, self.debug)
+
+ # self.start_dir: Position of start of central directory
+ self.start_dir = offset_cd + concat
# store the offset to the beginning of data for the
# .data_offset property
self._data_offset = concat
- if self.debug > 2:
- inferred = concat + offset_cd
- print("given, inferred, offset", offset_cd, inferred, concat)
- # self.start_dir: Position of start of central directory
- self.start_dir = offset_cd + concat
if self.start_dir < 0:
raise BadZipFile("Bad offset for central directory")
fp.seek(self.start_dir, 0)
+ size_cd = endrec[_ECD_SIZE]
data = fp.read(size_cd)
fp = io.BytesIO(data)
total = 0
@@ -2075,6 +2117,8 @@ class ZipFile:
min_version = max(BZIP2_VERSION, min_version)
elif zinfo.compress_type == ZIP_LZMA:
min_version = max(LZMA_VERSION, min_version)
+ elif zinfo.compress_type == ZIP_ZSTANDARD:
+ min_version = max(ZSTANDARD_VERSION, min_version)
extract_version = max(min_version, zinfo.extract_version)
create_version = max(min_version, zinfo.create_version)
@@ -2317,7 +2361,7 @@ def main(args=None):
import argparse
description = 'A simple command-line interface for zipfile module.'
- parser = argparse.ArgumentParser(description=description)
+ parser = argparse.ArgumentParser(description=description, color=True)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('-l', '--list', metavar='<zipfile>',
help='Show listing of a zipfile')
diff --git a/Lib/zipfile/_path/__init__.py b/Lib/zipfile/_path/__init__.py
index 5ae16ec970d..faae4c84cae 100644
--- a/Lib/zipfile/_path/__init__.py
+++ b/Lib/zipfile/_path/__init__.py
@@ -7,19 +7,19 @@ https://github.com/python/importlib_metadata/wiki/Development-Methodology
for more detail.
"""
+import functools
import io
-import posixpath
-import zipfile
import itertools
-import contextlib
import pathlib
+import posixpath
import re
import stat
import sys
+import zipfile
+from ._functools import save_method_args
from .glob import Translator
-
__all__ = ['Path']
@@ -86,13 +86,12 @@ class InitializedState:
Mix-in to save the initialization state for pickling.
"""
+ @save_method_args
def __init__(self, *args, **kwargs):
- self.__args = args
- self.__kwargs = kwargs
super().__init__(*args, **kwargs)
def __getstate__(self):
- return self.__args, self.__kwargs
+ return self._saved___init__.args, self._saved___init__.kwargs
def __setstate__(self, state):
args, kwargs = state
@@ -181,22 +180,27 @@ class FastLookup(CompleteDirs):
"""
def namelist(self):
- with contextlib.suppress(AttributeError):
- return self.__names
- self.__names = super().namelist()
- return self.__names
+ return self._namelist
+
+ @functools.cached_property
+ def _namelist(self):
+ return super().namelist()
def _name_set(self):
- with contextlib.suppress(AttributeError):
- return self.__lookup
- self.__lookup = super()._name_set()
- return self.__lookup
+ return self._name_set_prop
+
+ @functools.cached_property
+ def _name_set_prop(self):
+ return super()._name_set()
def _extract_text_encoding(encoding=None, *args, **kwargs):
# compute stack level so that the caller of the caller sees any warning.
is_pypy = sys.implementation.name == 'pypy'
- stack_level = 3 + is_pypy
+ # PyPy no longer special cased after 7.3.19 (or maybe 7.3.18)
+ # See jaraco/zipp#143
+ is_old_pypi = is_pypy and sys.pypy_version_info < (7, 3, 19)
+ stack_level = 3 + is_old_pypi
return io.text_encoding(encoding, stack_level), args, kwargs
@@ -351,7 +355,7 @@ class Path:
return io.TextIOWrapper(stream, encoding, *args, **kwargs)
def _base(self):
- return pathlib.PurePosixPath(self.at or self.root.filename)
+ return pathlib.PurePosixPath(self.at) if self.at else self.filename
@property
def name(self):
diff --git a/Lib/zipfile/_path/_functools.py b/Lib/zipfile/_path/_functools.py
new file mode 100644
index 00000000000..7390be21873
--- /dev/null
+++ b/Lib/zipfile/_path/_functools.py
@@ -0,0 +1,20 @@
+import collections
+import functools
+
+
+# from jaraco.functools 4.0.2
+def save_method_args(method):
+ """
+ Wrap a method such that when it is called, the args and kwargs are
+ saved on the method.
+ """
+ args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs') # noqa: PYI024
+
+ @functools.wraps(method)
+ def wrapper(self, /, *args, **kwargs):
+ attr_name = '_saved_' + method.__name__
+ attr = args_and_kwargs(args, kwargs)
+ setattr(self, attr_name, attr)
+ return method(self, *args, **kwargs)
+
+ return wrapper
diff --git a/Lib/zipfile/_path/glob.py b/Lib/zipfile/_path/glob.py
index 4320f1c0bad..bd2839304b7 100644
--- a/Lib/zipfile/_path/glob.py
+++ b/Lib/zipfile/_path/glob.py
@@ -1,7 +1,6 @@
import os
import re
-
_default_seps = os.sep + str(os.altsep) * bool(os.altsep)
@@ -37,9 +36,9 @@ class Translator:
Apply '(?s:)' to create a non-matching group that
matches newlines (valid on Unix).
- Append '\Z' to imply fullmatch even when match is used.
+ Append '\z' to imply fullmatch even when match is used.
"""
- return rf'(?s:{pattern})\Z'
+ return rf'(?s:{pattern})\z'
def match_dirs(self, pattern):
"""
diff --git a/Lib/zoneinfo/_common.py b/Lib/zoneinfo/_common.py
index 6e05abc3239..03cc42149f9 100644
--- a/Lib/zoneinfo/_common.py
+++ b/Lib/zoneinfo/_common.py
@@ -9,9 +9,13 @@ def load_tzdata(key):
resource_name = components[-1]
try:
- return resources.files(package_name).joinpath(resource_name).open("rb")
+ path = resources.files(package_name).joinpath(resource_name)
+ # gh-85702: Prevent PermissionError on Windows
+ if path.is_dir():
+ raise IsADirectoryError
+ return path.open("rb")
except (ImportError, FileNotFoundError, UnicodeEncodeError, IsADirectoryError):
- # There are three types of exception that can be raised that all amount
+ # There are four types of exception that can be raised that all amount
# to "we cannot find this key":
#
# ImportError: If package_name doesn't exist (e.g. if tzdata is not
diff --git a/Lib/zoneinfo/_tzpath.py b/Lib/zoneinfo/_tzpath.py
index 990a5c8b6a9..5db17bea045 100644
--- a/Lib/zoneinfo/_tzpath.py
+++ b/Lib/zoneinfo/_tzpath.py
@@ -83,11 +83,6 @@ _TEST_PATH = os.path.normpath(os.path.join("_", "_"))[:-1]
def _validate_tzfile_path(path, _base=_TEST_PATH):
- if not path:
- raise ValueError(
- "ZoneInfo key must not be an empty string"
- )
-
if os.path.isabs(path):
raise ValueError(
f"ZoneInfo keys may not be absolute paths, got: {path}"
diff --git a/Lib/zoneinfo/_zoneinfo.py b/Lib/zoneinfo/_zoneinfo.py
index b77dc0ed391..3ffdb4c8371 100644
--- a/Lib/zoneinfo/_zoneinfo.py
+++ b/Lib/zoneinfo/_zoneinfo.py
@@ -75,12 +75,12 @@ class ZoneInfo(tzinfo):
return obj
@classmethod
- def from_file(cls, fobj, /, key=None):
+ def from_file(cls, file_obj, /, key=None):
obj = super().__new__(cls)
obj._key = key
obj._file_path = None
- obj._load_file(fobj)
- obj._file_repr = repr(fobj)
+ obj._load_file(file_obj)
+ obj._file_repr = repr(file_obj)
# Disable pickling for objects created from files
obj.__reduce__ = obj._file_reduce