aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/_colorize.py255
-rw-r--r--Lib/_pyrepl/base_eventqueue.py12
-rw-r--r--Lib/_pyrepl/commands.py2
-rw-r--r--Lib/_pyrepl/reader.py9
-rw-r--r--Lib/_pyrepl/simple_interact.py5
-rw-r--r--Lib/_pyrepl/unix_console.py21
-rw-r--r--Lib/_pyrepl/utils.py39
-rw-r--r--Lib/_pyrepl/windows_console.py41
-rw-r--r--Lib/_threading_local.py122
-rw-r--r--Lib/annotationlib.py2
-rw-r--r--Lib/argparse.py68
-rw-r--r--Lib/ast.py29
-rw-r--r--Lib/asyncio/__main__.py8
-rw-r--r--Lib/asyncio/base_events.py2
-rw-r--r--Lib/asyncio/taskgroups.py7
-rw-r--r--Lib/asyncio/tasks.py14
-rw-r--r--Lib/asyncio/tools.py16
-rw-r--r--Lib/calendar.py2
-rw-r--r--Lib/code.py2
-rw-r--r--Lib/compileall.py4
-rw-r--r--Lib/compression/zstd/__init__.py234
-rw-r--r--Lib/compression/zstd/_zstdfile.py349
-rw-r--r--Lib/ctypes/_layout.py21
-rw-r--r--Lib/dataclasses.py53
-rw-r--r--Lib/dis.py2
-rw-r--r--Lib/doctest.py2
-rw-r--r--Lib/ensurepip/__init__.py2
-rw-r--r--Lib/getpass.py64
-rw-r--r--Lib/gzip.py4
-rw-r--r--Lib/heapq.py51
-rw-r--r--Lib/http/server.py2
-rw-r--r--Lib/inspect.py2
-rw-r--r--Lib/json/tool.py43
-rw-r--r--Lib/mimetypes.py4
-rw-r--r--Lib/pdb.py372
-rw-r--r--Lib/pickle.py4
-rw-r--r--Lib/pickletools.py4
-rw-r--r--Lib/platform.py2
-rw-r--r--Lib/py_compile.py2
-rw-r--r--Lib/random.py2
-rw-r--r--Lib/reprlib.py2
-rw-r--r--Lib/shutil.py19
-rw-r--r--Lib/sqlite3/__main__.py1
-rw-r--r--Lib/subprocess.py14
-rw-r--r--Lib/tarfile.py63
-rw-r--r--Lib/test/_code_definitions.py51
-rw-r--r--Lib/test/libregrtest/setup.py2
-rw-r--r--Lib/test/libregrtest/utils.py40
-rw-r--r--Lib/test/support/__init__.py56
-rw-r--r--Lib/test/test_annotationlib.py15
-rw-r--r--Lib/test/test_argparse.py34
-rw-r--r--Lib/test/test_asdl_parser.py8
-rw-r--r--Lib/test/test_ast/test_ast.py103
-rw-r--r--Lib/test/test_asyncio/test_eager_task_factory.py37
-rw-r--r--Lib/test/test_asyncio/test_tasks.py33
-rw-r--r--Lib/test/test_asyncio/test_tools.py41
-rw-r--r--Lib/test/test_base64.py8
-rw-r--r--Lib/test/test_calendar.py1
-rw-r--r--Lib/test/test_capi/test_config.py2
-rw-r--r--Lib/test/test_capi/test_misc.py2
-rw-r--r--Lib/test/test_capi/test_object.py10
-rw-r--r--Lib/test/test_capi/test_opt.py4
-rw-r--r--Lib/test/test_cmd.py8
-rw-r--r--Lib/test/test_cmd_line.py18
-rw-r--r--Lib/test/test_code.py297
-rw-r--r--Lib/test/test_crossinterp.py46
-rw-r--r--Lib/test/test_csv.py7
-rw-r--r--Lib/test/test_ctypes/test_aligned_structures.py1
-rw-r--r--Lib/test/test_ctypes/test_bitfields.py5
-rw-r--r--Lib/test/test_ctypes/test_byteswap.py2
-rw-r--r--Lib/test/test_ctypes/test_generated_structs.py11
-rw-r--r--Lib/test/test_ctypes/test_pep3118.py3
-rw-r--r--Lib/test/test_ctypes/test_structunion.py18
-rw-r--r--Lib/test/test_ctypes/test_structures.py31
-rw-r--r--Lib/test/test_ctypes/test_unaligned_structures.py2
-rw-r--r--Lib/test/test_dataclasses/__init__.py54
-rw-r--r--Lib/test/test_dis.py2
-rw-r--r--Lib/test/test_email/test_utils.py10
-rw-r--r--Lib/test/test_embed.py8
-rw-r--r--Lib/test/test_enum.py7
-rw-r--r--Lib/test/test_external_inspection.py66
-rw-r--r--Lib/test/test_functools.py9
-rw-r--r--Lib/test/test_getpass.py39
-rw-r--r--Lib/test/test_gettext.py7
-rw-r--r--Lib/test/test_heapq.py197
-rw-r--r--Lib/test/test_json/test_tool.py85
-rw-r--r--Lib/test/test_locale.py9
-rw-r--r--Lib/test/test_mimetypes.py8
-rw-r--r--Lib/test/test_minidom.py182
-rw-r--r--Lib/test/test_optparse.py7
-rw-r--r--Lib/test/test_pathlib/test_pathlib.py8
-rw-r--r--Lib/test/test_pdb.py82
-rw-r--r--Lib/test/test_peepholer.py24
-rw-r--r--Lib/test/test_pickle.py10
-rw-r--r--Lib/test/test_platform.py1
-rw-r--r--Lib/test/test_posix.py4
-rw-r--r--Lib/test/test_pprint.py7
-rw-r--r--Lib/test/test_pstats.py7
-rw-r--r--Lib/test/test_pyrepl/support.py3
-rw-r--r--Lib/test/test_pyrepl/test_eventqueue.py78
-rw-r--r--Lib/test/test_pyrepl/test_reader.py39
-rw-r--r--Lib/test/test_pyrepl/test_unix_console.py12
-rw-r--r--Lib/test/test_pyrepl/test_windows_console.py225
-rw-r--r--Lib/test/test_random.py1
-rw-r--r--Lib/test/test_remote_pdb.py412
-rw-r--r--Lib/test/test_reprlib.py17
-rw-r--r--Lib/test/test_shlex.py2
-rw-r--r--Lib/test/test_shutil.py4
-rw-r--r--Lib/test/test_socket.py9
-rw-r--r--Lib/test/test_sqlite3/test_cli.py8
-rw-r--r--Lib/test/test_string/test_string.py8
-rw-r--r--Lib/test/test_subprocess.py14
-rw-r--r--Lib/test/test_support.py1
-rw-r--r--Lib/test/test_sys.py62
-rw-r--r--Lib/test/test_tarfile.py44
-rw-r--r--Lib/test/test_threading.py6
-rw-r--r--Lib/test/test_traceback.py114
-rw-r--r--Lib/test/test_urllib.py16
-rw-r--r--Lib/test/test_zipfile/test_core.py37
-rw-r--r--Lib/test/test_zstd.py2507
-rw-r--r--Lib/tokenize.py2
-rw-r--r--Lib/trace.py2
-rw-r--r--Lib/traceback.py105
-rw-r--r--Lib/typing.py16
-rw-r--r--Lib/unittest/main.py4
-rw-r--r--Lib/unittest/runner.py89
-rw-r--r--Lib/urllib/request.py19
-rw-r--r--Lib/uuid.py4
-rw-r--r--Lib/venv/__init__.py4
-rw-r--r--Lib/webbrowser.py4
-rw-r--r--Lib/zipapp.py2
-rw-r--r--Lib/zipfile/__init__.py22
132 files changed, 6468 insertions, 1131 deletions
diff --git a/Lib/_colorize.py b/Lib/_colorize.py
index 54895488e74..4a310a40235 100644
--- a/Lib/_colorize.py
+++ b/Lib/_colorize.py
@@ -1,28 +1,17 @@
-from __future__ import annotations
import io
import os
import sys
+from collections.abc import Callable, Iterator, Mapping
+from dataclasses import dataclass, field, Field
+
COLORIZE = True
+
# types
if False:
- from typing import IO, Literal
-
- type ColorTag = Literal[
- "PROMPT",
- "KEYWORD",
- "BUILTIN",
- "COMMENT",
- "STRING",
- "NUMBER",
- "OP",
- "DEFINITION",
- "SOFT_KEYWORD",
- "RESET",
- ]
-
- theme: dict[ColorTag, str]
+ from typing import IO, Self, ClassVar
+ _theme: Theme
class ANSIColors:
@@ -86,6 +75,186 @@ for attr, code in ANSIColors.__dict__.items():
setattr(NoColors, attr, "")
+#
+# Experimental theming support (see gh-133346)
+#
+
+# - Create a theme by copying an existing `Theme` with one or more sections
+# replaced, using `default_theme.copy_with()`;
+# - create a theme section by copying an existing `ThemeSection` with one or
+# more colors replaced, using for example `default_theme.syntax.copy_with()`;
+# - create a theme from scratch by instantiating a `Theme` data class with
+# the required sections (which are also dataclass instances).
+#
+# Then call `_colorize.set_theme(your_theme)` to set it.
+#
+# Put your theme configuration in $PYTHONSTARTUP for the interactive shell,
+# or sitecustomize.py in your virtual environment or Python installation for
+# other uses. Your applications can call `_colorize.set_theme()` too.
+#
+# Note that thanks to the dataclasses providing default values for all fields,
+# creating a new theme or theme section from scratch is possible without
+# specifying all keys.
+#
+# For example, here's a theme that makes punctuation and operators less prominent:
+#
+# try:
+# from _colorize import set_theme, default_theme, Syntax, ANSIColors
+# except ImportError:
+# pass
+# else:
+# theme_with_dim_operators = default_theme.copy_with(
+# syntax=Syntax(op=ANSIColors.INTENSE_BLACK),
+# )
+# set_theme(theme_with_dim_operators)
+# del set_theme, default_theme, Syntax, ANSIColors, theme_with_dim_operators
+#
+# Guarding the import ensures that your .pythonstartup file will still work in
+# Python 3.13 and older. Deleting the variables ensures they don't remain in your
+# interactive shell's global scope.
+
+class ThemeSection(Mapping[str, str]):
+ """A mixin/base class for theme sections.
+
+ It enables dictionary access to a section, as well as implements convenience
+ methods.
+ """
+
+ # The two types below are just that: types to inform the type checker that the
+ # mixin will work in context of those fields existing
+ __dataclass_fields__: ClassVar[dict[str, Field[str]]]
+ _name_to_value: Callable[[str], str]
+
+ def __post_init__(self) -> None:
+ name_to_value = {}
+ for color_name in self.__dataclass_fields__:
+ name_to_value[color_name] = getattr(self, color_name)
+ super().__setattr__('_name_to_value', name_to_value.__getitem__)
+
+ def copy_with(self, **kwargs: str) -> Self:
+ color_state: dict[str, str] = {}
+ for color_name in self.__dataclass_fields__:
+ color_state[color_name] = getattr(self, color_name)
+ color_state.update(kwargs)
+ return type(self)(**color_state)
+
+ @classmethod
+ def no_colors(cls) -> Self:
+ color_state: dict[str, str] = {}
+ for color_name in cls.__dataclass_fields__:
+ color_state[color_name] = ""
+ return cls(**color_state)
+
+ def __getitem__(self, key: str) -> str:
+ return self._name_to_value(key)
+
+ def __len__(self) -> int:
+ return len(self.__dataclass_fields__)
+
+ def __iter__(self) -> Iterator[str]:
+ return iter(self.__dataclass_fields__)
+
+
+@dataclass(frozen=True)
+class Argparse(ThemeSection):
+ usage: str = ANSIColors.BOLD_BLUE
+ prog: str = ANSIColors.BOLD_MAGENTA
+ prog_extra: str = ANSIColors.MAGENTA
+ heading: str = ANSIColors.BOLD_BLUE
+ summary_long_option: str = ANSIColors.CYAN
+ summary_short_option: str = ANSIColors.GREEN
+ summary_label: str = ANSIColors.YELLOW
+ summary_action: str = ANSIColors.GREEN
+ long_option: str = ANSIColors.BOLD_CYAN
+ short_option: str = ANSIColors.BOLD_GREEN
+ label: str = ANSIColors.BOLD_YELLOW
+ action: str = ANSIColors.BOLD_GREEN
+ reset: str = ANSIColors.RESET
+
+
+@dataclass(frozen=True)
+class Syntax(ThemeSection):
+ prompt: str = ANSIColors.BOLD_MAGENTA
+ keyword: str = ANSIColors.BOLD_BLUE
+ builtin: str = ANSIColors.CYAN
+ comment: str = ANSIColors.RED
+ string: str = ANSIColors.GREEN
+ number: str = ANSIColors.YELLOW
+ op: str = ANSIColors.RESET
+ definition: str = ANSIColors.BOLD
+ soft_keyword: str = ANSIColors.BOLD_BLUE
+ reset: str = ANSIColors.RESET
+
+
+@dataclass(frozen=True)
+class Traceback(ThemeSection):
+ type: str = ANSIColors.BOLD_MAGENTA
+ message: str = ANSIColors.MAGENTA
+ filename: str = ANSIColors.MAGENTA
+ line_no: str = ANSIColors.MAGENTA
+ frame: str = ANSIColors.MAGENTA
+ error_highlight: str = ANSIColors.BOLD_RED
+ error_range: str = ANSIColors.RED
+ reset: str = ANSIColors.RESET
+
+
+@dataclass(frozen=True)
+class Unittest(ThemeSection):
+ passed: str = ANSIColors.GREEN
+ warn: str = ANSIColors.YELLOW
+ fail: str = ANSIColors.RED
+ fail_info: str = ANSIColors.BOLD_RED
+ reset: str = ANSIColors.RESET
+
+
+@dataclass(frozen=True)
+class Theme:
+ """A suite of themes for all sections of Python.
+
+ When adding a new one, remember to also modify `copy_with` and `no_colors`
+ below.
+ """
+ argparse: Argparse = field(default_factory=Argparse)
+ syntax: Syntax = field(default_factory=Syntax)
+ traceback: Traceback = field(default_factory=Traceback)
+ unittest: Unittest = field(default_factory=Unittest)
+
+ def copy_with(
+ self,
+ *,
+ argparse: Argparse | None = None,
+ syntax: Syntax | None = None,
+ traceback: Traceback | None = None,
+ unittest: Unittest | None = None,
+ ) -> Self:
+ """Return a new Theme based on this instance with some sections replaced.
+
+ Themes are immutable to protect against accidental modifications that
+ could lead to invalid terminal states.
+ """
+ return type(self)(
+ argparse=argparse or self.argparse,
+ syntax=syntax or self.syntax,
+ traceback=traceback or self.traceback,
+ unittest=unittest or self.unittest,
+ )
+
+ @classmethod
+ def no_colors(cls) -> Self:
+ """Return a new Theme where colors in all sections are empty strings.
+
+ This allows writing user code as if colors are always used. The color
+ fields will be ANSI color code strings when colorization is desired
+ and possible, and empty strings otherwise.
+ """
+ return cls(
+ argparse=Argparse.no_colors(),
+ syntax=Syntax.no_colors(),
+ traceback=Traceback.no_colors(),
+ unittest=Unittest.no_colors(),
+ )
+
+
def get_colors(
colorize: bool = False, *, file: IO[str] | IO[bytes] | None = None
) -> ANSIColors:
@@ -138,26 +307,40 @@ def can_colorize(*, file: IO[str] | IO[bytes] | None = None) -> bool:
return hasattr(file, "isatty") and file.isatty()
-def set_theme(t: dict[ColorTag, str] | None = None) -> None:
- global theme
+default_theme = Theme()
+theme_no_color = default_theme.no_colors()
+
+
+def get_theme(
+ *,
+ tty_file: IO[str] | IO[bytes] | None = None,
+ force_color: bool = False,
+ force_no_color: bool = False,
+) -> Theme:
+ """Returns the currently set theme, potentially in a zero-color variant.
+
+ In cases where colorizing is not possible (see `can_colorize`), the returned
+ theme contains all empty strings in all color definitions.
+ See `Theme.no_colors()` for more information.
+
+ It is recommended not to cache the result of this function for extended
+ periods of time because the user might influence theme selection by
+ the interactive shell, a debugger, or application-specific code. The
+ environment (including environment variable state and console configuration
+ on Windows) can also change in the course of the application life cycle.
+ """
+ if force_color or (not force_no_color and can_colorize(file=tty_file)):
+ return _theme
+ return theme_no_color
+
+
+def set_theme(t: Theme) -> None:
+ global _theme
- if t:
- theme = t
- return
+ if not isinstance(t, Theme):
+ raise ValueError(f"Expected Theme object, found {t}")
- colors = get_colors()
- theme = {
- "PROMPT": colors.BOLD_MAGENTA,
- "KEYWORD": colors.BOLD_BLUE,
- "BUILTIN": colors.CYAN,
- "COMMENT": colors.RED,
- "STRING": colors.GREEN,
- "NUMBER": colors.YELLOW,
- "OP": colors.RESET,
- "DEFINITION": colors.BOLD,
- "SOFT_KEYWORD": colors.BOLD_BLUE,
- "RESET": colors.RESET,
- }
+ _theme = t
-set_theme()
+set_theme(default_theme)
diff --git a/Lib/_pyrepl/base_eventqueue.py b/Lib/_pyrepl/base_eventqueue.py
index e018c4fc183..842599bd187 100644
--- a/Lib/_pyrepl/base_eventqueue.py
+++ b/Lib/_pyrepl/base_eventqueue.py
@@ -69,18 +69,14 @@ class BaseEventQueue:
trace('added event {event}', event=event)
self.events.append(event)
- def push(self, char: int | bytes | str) -> None:
+ def push(self, char: int | bytes) -> None:
"""
Processes a character by updating the buffer and handling special key mappings.
"""
+ assert isinstance(char, (int, bytes))
ord_char = char if isinstance(char, int) else ord(char)
- if ord_char > 255:
- assert isinstance(char, str)
- char = bytes(char.encode(self.encoding, "replace"))
- self.buf.extend(char)
- else:
- char = bytes(bytearray((ord_char,)))
- self.buf.append(ord_char)
+ char = ord_char.to_bytes()
+ self.buf.append(ord_char)
if char in self.keymap:
if self.keymap is self.compiled_keymap:
diff --git a/Lib/_pyrepl/commands.py b/Lib/_pyrepl/commands.py
index 2054a8e400f..2354fbb2ec2 100644
--- a/Lib/_pyrepl/commands.py
+++ b/Lib/_pyrepl/commands.py
@@ -439,7 +439,7 @@ class help(Command):
import _sitebuiltins
with self.reader.suspend():
- self.reader.msg = _sitebuiltins._Helper()() # type: ignore[assignment, call-arg]
+ self.reader.msg = _sitebuiltins._Helper()() # type: ignore[assignment]
class invalid_key(Command):
diff --git a/Lib/_pyrepl/reader.py b/Lib/_pyrepl/reader.py
index 65c2230dfd6..0ebd9162eca 100644
--- a/Lib/_pyrepl/reader.py
+++ b/Lib/_pyrepl/reader.py
@@ -28,7 +28,7 @@ from contextlib import contextmanager
from dataclasses import dataclass, field, fields
from . import commands, console, input
-from .utils import wlen, unbracket, disp_str, gen_colors
+from .utils import wlen, unbracket, disp_str, gen_colors, THEME
from .trace import trace
@@ -491,11 +491,8 @@ class Reader:
prompt = self.ps1
if self.can_colorize:
- prompt = (
- f"{_colorize.theme["PROMPT"]}"
- f"{prompt}"
- f"{_colorize.theme["RESET"]}"
- )
+ t = THEME()
+ prompt = f"{t.prompt}{prompt}{t.reset}"
return prompt
def push_input_trans(self, itrans: input.KeymapTranslator) -> None:
diff --git a/Lib/_pyrepl/simple_interact.py b/Lib/_pyrepl/simple_interact.py
index e2274629b65..b3848833e14 100644
--- a/Lib/_pyrepl/simple_interact.py
+++ b/Lib/_pyrepl/simple_interact.py
@@ -162,3 +162,8 @@ def run_multiline_interactive_console(
except MemoryError:
console.write("\nMemoryError\n")
console.resetbuffer()
+ except SystemExit:
+ raise
+ except:
+ console.showtraceback()
+ console.resetbuffer()
diff --git a/Lib/_pyrepl/unix_console.py b/Lib/_pyrepl/unix_console.py
index 07b160d2324..d21cdd9b076 100644
--- a/Lib/_pyrepl/unix_console.py
+++ b/Lib/_pyrepl/unix_console.py
@@ -29,6 +29,7 @@ import signal
import struct
import termios
import time
+import types
import platform
from fcntl import ioctl
@@ -39,6 +40,12 @@ from .trace import trace
from .unix_eventqueue import EventQueue
from .utils import wlen
+# declare posix optional to allow None assignment on other platforms
+posix: types.ModuleType | None
+try:
+ import posix
+except ImportError:
+ posix = None
TYPE_CHECKING = False
@@ -197,6 +204,12 @@ class UnixConsole(Console):
self.event_queue = EventQueue(self.input_fd, self.encoding)
self.cursor_visible = 1
+ signal.signal(signal.SIGCONT, self._sigcont_handler)
+
+ def _sigcont_handler(self, signum, frame):
+ self.restore()
+ self.prepare()
+
def __read(self, n: int) -> bytes:
return os.read(self.input_fd, n)
@@ -550,11 +563,9 @@ class UnixConsole(Console):
@property
def input_hook(self):
- try:
- import posix
- except ImportError:
- return None
- if posix._is_inputhook_installed():
+ # avoid inline imports here so the repl doesn't get flooded
+ # with import logging from -X importtime=2
+ if posix is not None and posix._is_inputhook_installed():
return posix._inputhook
def __enable_bracketed_paste(self) -> None:
diff --git a/Lib/_pyrepl/utils.py b/Lib/_pyrepl/utils.py
index fe154aa59a0..38cf6b5a08e 100644
--- a/Lib/_pyrepl/utils.py
+++ b/Lib/_pyrepl/utils.py
@@ -23,6 +23,11 @@ IDENTIFIERS_AFTER = {"def", "class"}
BUILTINS = {str(name) for name in dir(builtins) if not name.startswith('_')}
+def THEME(**kwargs):
+ # Not cached: the user can modify the theme inside the interactive session.
+ return _colorize.get_theme(**kwargs).syntax
+
+
class Span(NamedTuple):
"""Span indexing that's inclusive on both ends."""
@@ -44,7 +49,7 @@ class Span(NamedTuple):
class ColorSpan(NamedTuple):
span: Span
- tag: _colorize.ColorTag
+ tag: str
@functools.cache
@@ -135,7 +140,7 @@ def recover_unterminated_string(
span = Span(start, end)
trace("yielding span {a} -> {b}", a=span.start, b=span.end)
- yield ColorSpan(span, "STRING")
+ yield ColorSpan(span, "string")
else:
trace(
"unhandled token error({buffer}) = {te}",
@@ -164,28 +169,28 @@ def gen_colors_from_token_stream(
| T.TSTRING_START | T.TSTRING_MIDDLE | T.TSTRING_END
):
span = Span.from_token(token, line_lengths)
- yield ColorSpan(span, "STRING")
+ yield ColorSpan(span, "string")
case T.COMMENT:
span = Span.from_token(token, line_lengths)
- yield ColorSpan(span, "COMMENT")
+ yield ColorSpan(span, "comment")
case T.NUMBER:
span = Span.from_token(token, line_lengths)
- yield ColorSpan(span, "NUMBER")
+ yield ColorSpan(span, "number")
case T.OP:
if token.string in "([{":
bracket_level += 1
elif token.string in ")]}":
bracket_level -= 1
span = Span.from_token(token, line_lengths)
- yield ColorSpan(span, "OP")
+ yield ColorSpan(span, "op")
case T.NAME:
if is_def_name:
is_def_name = False
span = Span.from_token(token, line_lengths)
- yield ColorSpan(span, "DEFINITION")
+ yield ColorSpan(span, "definition")
elif keyword.iskeyword(token.string):
span = Span.from_token(token, line_lengths)
- yield ColorSpan(span, "KEYWORD")
+ yield ColorSpan(span, "keyword")
if token.string in IDENTIFIERS_AFTER:
is_def_name = True
elif (
@@ -194,10 +199,10 @@ def gen_colors_from_token_stream(
and is_soft_keyword_used(prev_token, token, next_token)
):
span = Span.from_token(token, line_lengths)
- yield ColorSpan(span, "SOFT_KEYWORD")
+ yield ColorSpan(span, "soft_keyword")
elif token.string in BUILTINS:
span = Span.from_token(token, line_lengths)
- yield ColorSpan(span, "BUILTIN")
+ yield ColorSpan(span, "builtin")
keyword_first_sets_match = {"False", "None", "True", "await", "lambda", "not"}
@@ -249,7 +254,10 @@ def is_soft_keyword_used(*tokens: TI | None) -> bool:
def disp_str(
- buffer: str, colors: list[ColorSpan] | None = None, start_index: int = 0
+ buffer: str,
+ colors: list[ColorSpan] | None = None,
+ start_index: int = 0,
+ force_color: bool = False,
) -> tuple[CharBuffer, CharWidths]:
r"""Decompose the input buffer into a printable variant with applied colors.
@@ -290,15 +298,16 @@ def disp_str(
# move past irrelevant spans
colors.pop(0)
+ theme = THEME(force_color=force_color)
pre_color = ""
post_color = ""
if colors and colors[0].span.start < start_index:
# looks like we're continuing a previous color (e.g. a multiline str)
- pre_color = _colorize.theme[colors[0].tag]
+ pre_color = theme[colors[0].tag]
for i, c in enumerate(buffer, start_index):
if colors and colors[0].span.start == i: # new color starts now
- pre_color = _colorize.theme[colors[0].tag]
+ pre_color = theme[colors[0].tag]
if c == "\x1a": # CTRL-Z on Windows
chars.append(c)
@@ -315,7 +324,7 @@ def disp_str(
char_widths.append(str_width(c))
if colors and colors[0].span.end == i: # current color ends now
- post_color = _colorize.theme["RESET"]
+ post_color = theme.reset
colors.pop(0)
chars[-1] = pre_color + chars[-1] + post_color
@@ -325,7 +334,7 @@ def disp_str(
if colors and colors[0].span.start < i and colors[0].span.end > i:
# even though the current color should be continued, reset it for now.
# the next call to `disp_str()` will revive it.
- chars[-1] += _colorize.theme["RESET"]
+ chars[-1] += theme.reset
return chars, char_widths
diff --git a/Lib/_pyrepl/windows_console.py b/Lib/_pyrepl/windows_console.py
index 77985e59a93..95749198b3b 100644
--- a/Lib/_pyrepl/windows_console.py
+++ b/Lib/_pyrepl/windows_console.py
@@ -24,6 +24,7 @@ import os
import sys
import ctypes
+import types
from ctypes.wintypes import (
_COORD,
WORD,
@@ -58,6 +59,12 @@ except:
self.err = err
self.descr = descr
+# declare nt optional to allow None assignment on other platforms
+nt: types.ModuleType | None
+try:
+ import nt
+except ImportError:
+ nt = None
TYPE_CHECKING = False
@@ -121,9 +128,8 @@ class _error(Exception):
def _supports_vt():
try:
- import nt
return nt._supports_virtual_terminal()
- except (ImportError, AttributeError):
+ except AttributeError:
return False
class WindowsConsole(Console):
@@ -235,11 +241,9 @@ class WindowsConsole(Console):
@property
def input_hook(self):
- try:
- import nt
- except ImportError:
- return None
- if nt._is_inputhook_installed():
+ # avoid inline imports here so the repl doesn't get flooded
+ # with import logging from -X importtime=2
+ if nt is not None and nt._is_inputhook_installed():
return nt._inputhook
def __write_changed_line(
@@ -464,7 +468,7 @@ class WindowsConsole(Console):
if key == "\r":
# Make enter unix-like
- return Event(evt="key", data="\n", raw=b"\n")
+ return Event(evt="key", data="\n")
elif key_event.wVirtualKeyCode == 8:
# Turn backspace directly into the command
key = "backspace"
@@ -476,24 +480,29 @@ class WindowsConsole(Console):
key = f"ctrl {key}"
elif key_event.dwControlKeyState & ALT_ACTIVE:
# queue the key, return the meta command
- self.event_queue.insert(Event(evt="key", data=key, raw=key))
+ self.event_queue.insert(Event(evt="key", data=key))
return Event(evt="key", data="\033") # keymap.py uses this for meta
- return Event(evt="key", data=key, raw=key)
+ return Event(evt="key", data=key)
if block:
continue
return None
elif self.__vt_support:
# If virtual terminal is enabled, scanning VT sequences
- self.event_queue.push(rec.Event.KeyEvent.uChar.UnicodeChar)
+ for char in raw_key.encode(self.event_queue.encoding, "replace"):
+ self.event_queue.push(char)
continue
if key_event.dwControlKeyState & ALT_ACTIVE:
- # queue the key, return the meta command
- self.event_queue.insert(Event(evt="key", data=key, raw=raw_key))
- return Event(evt="key", data="\033") # keymap.py uses this for meta
-
- return Event(evt="key", data=key, raw=raw_key)
+ # Do not swallow characters that have been entered via AltGr:
+ # Windows internally converts AltGr to CTRL+ALT, see
+ # https://learn.microsoft.com/en-us/windows/win32/api/winuser/nf-winuser-vkkeyscanw
+ if not key_event.dwControlKeyState & CTRL_ACTIVE:
+ # queue the key, return the meta command
+ self.event_queue.insert(Event(evt="key", data=key))
+ return Event(evt="key", data="\033") # keymap.py uses this for meta
+
+ return Event(evt="key", data=key)
return self.event_queue.get()
def push_char(self, char: int | bytes) -> None:
diff --git a/Lib/_threading_local.py b/Lib/_threading_local.py
index b006d76c4e2..0b9e5d3bbf6 100644
--- a/Lib/_threading_local.py
+++ b/Lib/_threading_local.py
@@ -4,128 +4,6 @@
class. Depending on the version of Python you're using, there may be a
faster one available. You should always import the `local` class from
`threading`.)
-
-Thread-local objects support the management of thread-local data.
-If you have data that you want to be local to a thread, simply create
-a thread-local object and use its attributes:
-
- >>> mydata = local()
- >>> mydata.number = 42
- >>> mydata.number
- 42
-
-You can also access the local-object's dictionary:
-
- >>> mydata.__dict__
- {'number': 42}
- >>> mydata.__dict__.setdefault('widgets', [])
- []
- >>> mydata.widgets
- []
-
-What's important about thread-local objects is that their data are
-local to a thread. If we access the data in a different thread:
-
- >>> log = []
- >>> def f():
- ... items = sorted(mydata.__dict__.items())
- ... log.append(items)
- ... mydata.number = 11
- ... log.append(mydata.number)
-
- >>> import threading
- >>> thread = threading.Thread(target=f)
- >>> thread.start()
- >>> thread.join()
- >>> log
- [[], 11]
-
-we get different data. Furthermore, changes made in the other thread
-don't affect data seen in this thread:
-
- >>> mydata.number
- 42
-
-Of course, values you get from a local object, including a __dict__
-attribute, are for whatever thread was current at the time the
-attribute was read. For that reason, you generally don't want to save
-these values across threads, as they apply only to the thread they
-came from.
-
-You can create custom local objects by subclassing the local class:
-
- >>> class MyLocal(local):
- ... number = 2
- ... def __init__(self, /, **kw):
- ... self.__dict__.update(kw)
- ... def squared(self):
- ... return self.number ** 2
-
-This can be useful to support default values, methods and
-initialization. Note that if you define an __init__ method, it will be
-called each time the local object is used in a separate thread. This
-is necessary to initialize each thread's dictionary.
-
-Now if we create a local object:
-
- >>> mydata = MyLocal(color='red')
-
-Now we have a default number:
-
- >>> mydata.number
- 2
-
-an initial color:
-
- >>> mydata.color
- 'red'
- >>> del mydata.color
-
-And a method that operates on the data:
-
- >>> mydata.squared()
- 4
-
-As before, we can access the data in a separate thread:
-
- >>> log = []
- >>> thread = threading.Thread(target=f)
- >>> thread.start()
- >>> thread.join()
- >>> log
- [[('color', 'red')], 11]
-
-without affecting this thread's data:
-
- >>> mydata.number
- 2
- >>> mydata.color
- Traceback (most recent call last):
- ...
- AttributeError: 'MyLocal' object has no attribute 'color'
-
-Note that subclasses can define slots, but they are not thread
-local. They are shared across threads:
-
- >>> class MyLocal(local):
- ... __slots__ = 'number'
-
- >>> mydata = MyLocal()
- >>> mydata.number = 42
- >>> mydata.color = 'red'
-
-So, the separate thread:
-
- >>> thread = threading.Thread(target=f)
- >>> thread.start()
- >>> thread.join()
-
-affects what we see:
-
- >>> mydata.number
- 11
-
->>> del mydata
"""
from weakref import ref
diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py
index 5ad0893106a..c0b1d4395d1 100644
--- a/Lib/annotationlib.py
+++ b/Lib/annotationlib.py
@@ -868,7 +868,7 @@ def get_annotations(
# For FORWARDREF, we use __annotations__ if it exists
try:
ann = _get_dunder_annotations(obj)
- except NameError:
+ except Exception:
pass
else:
if ann is not None:
diff --git a/Lib/argparse.py b/Lib/argparse.py
index c0dcd0bbff0..f13ac82dbc5 100644
--- a/Lib/argparse.py
+++ b/Lib/argparse.py
@@ -176,13 +176,13 @@ class HelpFormatter(object):
width = shutil.get_terminal_size().columns
width -= 2
- from _colorize import ANSIColors, NoColors, can_colorize, decolor
+ from _colorize import can_colorize, decolor, get_theme
if color and can_colorize():
- self._ansi = ANSIColors()
+ self._theme = get_theme(force_color=True).argparse
self._decolor = decolor
else:
- self._ansi = NoColors
+ self._theme = get_theme(force_no_color=True).argparse
self._decolor = lambda text: text
self._prefix_chars = prefix_chars
@@ -237,14 +237,12 @@ class HelpFormatter(object):
# add the heading if the section was non-empty
if self.heading is not SUPPRESS and self.heading is not None:
- bold_blue = self.formatter._ansi.BOLD_BLUE
- reset = self.formatter._ansi.RESET
-
current_indent = self.formatter._current_indent
heading_text = _('%(heading)s:') % dict(heading=self.heading)
+ t = self.formatter._theme
heading = (
f'{" " * current_indent}'
- f'{bold_blue}{heading_text}{reset}\n'
+ f'{t.heading}{heading_text}{t.reset}\n'
)
else:
heading = ''
@@ -314,10 +312,7 @@ class HelpFormatter(object):
if part and part is not SUPPRESS])
def _format_usage(self, usage, actions, groups, prefix):
- bold_blue = self._ansi.BOLD_BLUE
- bold_magenta = self._ansi.BOLD_MAGENTA
- magenta = self._ansi.MAGENTA
- reset = self._ansi.RESET
+ t = self._theme
if prefix is None:
prefix = _('usage: ')
@@ -325,15 +320,15 @@ class HelpFormatter(object):
# if usage is specified, use that
if usage is not None:
usage = (
- magenta
+ t.prog_extra
+ usage
- % {"prog": f"{bold_magenta}{self._prog}{reset}{magenta}"}
- + reset
+ % {"prog": f"{t.prog}{self._prog}{t.reset}{t.prog_extra}"}
+ + t.reset
)
# if no optionals or positionals are available, usage is just prog
elif usage is None and not actions:
- usage = f"{bold_magenta}{self._prog}{reset}"
+ usage = f"{t.prog}{self._prog}{t.reset}"
# if optionals and positionals are available, calculate usage
elif usage is None:
@@ -411,10 +406,10 @@ class HelpFormatter(object):
usage = '\n'.join(lines)
usage = usage.removeprefix(prog)
- usage = f"{bold_magenta}{prog}{reset}{usage}"
+ usage = f"{t.prog}{prog}{t.reset}{usage}"
# prefix with 'usage:'
- return f'{bold_blue}{prefix}{reset}{usage}\n\n'
+ return f'{t.usage}{prefix}{t.reset}{usage}\n\n'
def _format_actions_usage(self, actions, groups):
return ' '.join(self._get_actions_usage_parts(actions, groups))
@@ -452,10 +447,7 @@ class HelpFormatter(object):
# collect all actions format strings
parts = []
- cyan = self._ansi.CYAN
- green = self._ansi.GREEN
- yellow = self._ansi.YELLOW
- reset = self._ansi.RESET
+ t = self._theme
for action in actions:
# suppressed arguments are marked with None
@@ -465,7 +457,11 @@ class HelpFormatter(object):
# produce all arg strings
elif not action.option_strings:
default = self._get_default_metavar_for_positional(action)
- part = green + self._format_args(action, default) + reset
+ part = (
+ t.summary_action
+ + self._format_args(action, default)
+ + t.reset
+ )
# if it's in a group, strip the outer []
if action in group_actions:
@@ -481,9 +477,9 @@ class HelpFormatter(object):
if action.nargs == 0:
part = action.format_usage()
if self._is_long_option(part):
- part = f"{cyan}{part}{reset}"
+ part = f"{t.summary_long_option}{part}{t.reset}"
elif self._is_short_option(part):
- part = f"{green}{part}{reset}"
+ part = f"{t.summary_short_option}{part}{t.reset}"
# if the Optional takes a value, format is:
# -s ARGS or --long ARGS
@@ -491,10 +487,13 @@ class HelpFormatter(object):
default = self._get_default_metavar_for_optional(action)
args_string = self._format_args(action, default)
if self._is_long_option(option_string):
- option_string = f"{cyan}{option_string}"
+ option_color = t.summary_long_option
elif self._is_short_option(option_string):
- option_string = f"{green}{option_string}"
- part = f"{option_string} {yellow}{args_string}{reset}"
+ option_color = t.summary_short_option
+ part = (
+ f"{option_color}{option_string} "
+ f"{t.summary_label}{args_string}{t.reset}"
+ )
# make it look optional if it's not required or in a group
if not action.required and action not in group_actions:
@@ -590,17 +589,14 @@ class HelpFormatter(object):
return self._join_parts(parts)
def _format_action_invocation(self, action):
- bold_green = self._ansi.BOLD_GREEN
- bold_cyan = self._ansi.BOLD_CYAN
- bold_yellow = self._ansi.BOLD_YELLOW
- reset = self._ansi.RESET
+ t = self._theme
if not action.option_strings:
default = self._get_default_metavar_for_positional(action)
return (
- bold_green
+ t.action
+ ' '.join(self._metavar_formatter(action, default)(1))
- + reset
+ + t.reset
)
else:
@@ -609,9 +605,9 @@ class HelpFormatter(object):
parts = []
for s in strings:
if self._is_long_option(s):
- parts.append(f"{bold_cyan}{s}{reset}")
+ parts.append(f"{t.long_option}{s}{t.reset}")
elif self._is_short_option(s):
- parts.append(f"{bold_green}{s}{reset}")
+ parts.append(f"{t.short_option}{s}{t.reset}")
else:
parts.append(s)
return parts
@@ -628,7 +624,7 @@ class HelpFormatter(object):
default = self._get_default_metavar_for_optional(action)
option_strings = color_option_strings(action.option_strings)
args_string = (
- f"{bold_yellow}{self._format_args(action, default)}{reset}"
+ f"{t.label}{self._format_args(action, default)}{t.reset}"
)
return ', '.join(option_strings) + ' ' + args_string
diff --git a/Lib/ast.py b/Lib/ast.py
index af4fe8ff5a8..b9791bf52d3 100644
--- a/Lib/ast.py
+++ b/Lib/ast.py
@@ -630,7 +630,7 @@ def main(args=None):
import argparse
import sys
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument('infile', nargs='?', default='-',
help='the file to parse; defaults to stdin')
parser.add_argument('-m', '--mode', default='exec',
@@ -643,6 +643,15 @@ def main(args=None):
'column offsets')
parser.add_argument('-i', '--indent', type=int, default=3,
help='indentation of nodes (number of spaces)')
+ parser.add_argument('--feature-version',
+ type=str, default=None, metavar='VERSION',
+ help='Python version in the format 3.x '
+ '(for example, 3.10)')
+ parser.add_argument('-O', '--optimize',
+ type=int, default=-1, metavar='LEVEL',
+ help='optimization level for parser (default -1)')
+ parser.add_argument('--show-empty', default=False, action='store_true',
+ help='show empty lists and fields in dump output')
args = parser.parse_args(args)
if args.infile == '-':
@@ -652,8 +661,22 @@ def main(args=None):
name = args.infile
with open(args.infile, 'rb') as infile:
source = infile.read()
- tree = parse(source, name, args.mode, type_comments=args.no_type_comments)
- print(dump(tree, include_attributes=args.include_attributes, indent=args.indent))
+
+ # Process feature_version
+ feature_version = None
+ if args.feature_version:
+ try:
+ major, minor = map(int, args.feature_version.split('.', 1))
+ except ValueError:
+ parser.error('Invalid format for --feature-version; '
+ 'expected format 3.x (for example, 3.10)')
+
+ feature_version = (major, minor)
+
+ tree = parse(source, name, args.mode, type_comments=args.no_type_comments,
+ feature_version=feature_version, optimize=args.optimize)
+ print(dump(tree, include_attributes=args.include_attributes,
+ indent=args.indent, show_empty=args.show_empty))
if __name__ == '__main__':
main()
diff --git a/Lib/asyncio/__main__.py b/Lib/asyncio/__main__.py
index 7d980bc401a..21ca5c5f62a 100644
--- a/Lib/asyncio/__main__.py
+++ b/Lib/asyncio/__main__.py
@@ -12,7 +12,7 @@ import threading
import types
import warnings
-from _colorize import can_colorize, ANSIColors # type: ignore[import-not-found]
+from _colorize import get_theme
from _pyrepl.console import InteractiveColoredConsole
from . import futures
@@ -103,8 +103,9 @@ class REPLThread(threading.Thread):
exec(startup_code, console.locals)
ps1 = getattr(sys, "ps1", ">>> ")
- if can_colorize() and CAN_USE_PYREPL:
- ps1 = f"{ANSIColors.BOLD_MAGENTA}{ps1}{ANSIColors.RESET}"
+ if CAN_USE_PYREPL:
+ theme = get_theme().syntax
+ ps1 = f"{theme.prompt}{ps1}{theme.reset}"
console.write(f"{ps1}import asyncio\n")
if CAN_USE_PYREPL:
@@ -145,6 +146,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser(
prog="python3 -m asyncio",
description="Interactive asyncio shell and CLI tools",
+ color=True,
)
subparsers = parser.add_subparsers(help="sub-commands", dest="command")
ps = subparsers.add_parser(
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index 29b872ce00e..04fb961e998 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -459,7 +459,7 @@ class BaseEventLoop(events.AbstractEventLoop):
return futures.Future(loop=self)
def create_task(self, coro, **kwargs):
- """Schedule a coroutine object.
+ """Schedule or begin executing a coroutine object.
Return a task object.
"""
diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py
index 1633478d1c8..00e8f6d5d1a 100644
--- a/Lib/asyncio/taskgroups.py
+++ b/Lib/asyncio/taskgroups.py
@@ -179,7 +179,7 @@ class TaskGroup:
exc = None
- def create_task(self, coro, *, name=None, context=None):
+ def create_task(self, coro, **kwargs):
"""Create a new task in this group and return it.
Similar to `asyncio.create_task`.
@@ -193,10 +193,7 @@ class TaskGroup:
if self._aborting:
coro.close()
raise RuntimeError(f"TaskGroup {self!r} is shutting down")
- if context is None:
- task = self._loop.create_task(coro, name=name)
- else:
- task = self._loop.create_task(coro, name=name, context=context)
+ task = self._loop.create_task(coro, **kwargs)
futures.future_add_to_awaited_by(task, self._parent_task)
diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py
index 825e91f5594..888615f8e5e 100644
--- a/Lib/asyncio/tasks.py
+++ b/Lib/asyncio/tasks.py
@@ -386,19 +386,13 @@ else:
Task = _CTask = _asyncio.Task
-def create_task(coro, *, name=None, context=None):
+def create_task(coro, **kwargs):
"""Schedule the execution of a coroutine object in a spawn task.
Return a Task object.
"""
loop = events.get_running_loop()
- if context is None:
- # Use legacy API if context is not needed
- task = loop.create_task(coro, name=name)
- else:
- task = loop.create_task(coro, name=name, context=context)
-
- return task
+ return loop.create_task(coro, **kwargs)
# wait() and as_completed() similar to those in PEP 3148.
@@ -1030,9 +1024,9 @@ def create_eager_task_factory(custom_task_constructor):
used. E.g. `loop.set_task_factory(asyncio.eager_task_factory)`.
"""
- def factory(loop, coro, *, name=None, context=None):
+ def factory(loop, coro, *, eager_start=True, **kwargs):
return custom_task_constructor(
- coro, loop=loop, name=name, context=context, eager_start=True)
+ coro, loop=loop, eager_start=eager_start, **kwargs)
return factory
diff --git a/Lib/asyncio/tools.py b/Lib/asyncio/tools.py
index 6c1f725e777..bf1cb5e64cb 100644
--- a/Lib/asyncio/tools.py
+++ b/Lib/asyncio/tools.py
@@ -5,7 +5,7 @@ from collections import defaultdict
from itertools import count
from enum import Enum
import sys
-from _remotedebugging import get_all_awaited_by
+from _remote_debugging import get_all_awaited_by
class NodeType(Enum):
@@ -21,13 +21,21 @@ class CycleFoundException(Exception):
# ─── indexing helpers ───────────────────────────────────────────
+def _format_stack_entry(elem: tuple[str, str, int] | str) -> str:
+ if isinstance(elem, tuple):
+ fqname, path, line_no = elem
+ return f"{fqname} {path}:{line_no}"
+
+ return elem
+
+
def _index(result):
id2name, awaits = {}, []
for _thr_id, tasks in result:
for tid, tname, awaited in tasks:
id2name[tid] = tname
for stack, parent_id in awaited:
- stack = [elem[0] if isinstance(elem, tuple) else elem for elem in stack]
+ stack = [_format_stack_entry(elem) for elem in stack]
awaits.append((parent_id, stack, tid))
return id2name, awaits
@@ -106,7 +114,7 @@ def _find_cycles(graph):
# ─── PRINT TREE FUNCTION ───────────────────────────────────────
def build_async_tree(result, task_emoji="(T)", cor_emoji=""):
"""
- Build a list of strings for pretty-print a async call tree.
+ Build a list of strings for pretty-print an async call tree.
The call tree is produced by `get_all_async_stacks()`, prefixing tasks
with `task_emoji` and coroutine frames with `cor_emoji`.
@@ -169,7 +177,7 @@ def build_task_table(result):
return table
def _print_cycle_exception(exception: CycleFoundException):
- print("ERROR: await-graph contains cycles – cannot print a tree!", file=sys.stderr)
+ print("ERROR: await-graph contains cycles - cannot print a tree!", file=sys.stderr)
print("", file=sys.stderr)
for c in exception.cycles:
inames = " → ".join(exception.id2name.get(tid, hex(tid)) for tid in c)
diff --git a/Lib/calendar.py b/Lib/calendar.py
index 01a76ff8e78..18f76d52ff8 100644
--- a/Lib/calendar.py
+++ b/Lib/calendar.py
@@ -810,7 +810,7 @@ def timegm(tuple):
def main(args=None):
import argparse
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
textgroup = parser.add_argument_group('text only arguments')
htmlgroup = parser.add_argument_group('html only arguments')
textgroup.add_argument(
diff --git a/Lib/code.py b/Lib/code.py
index 41331dfd071..b134886dc26 100644
--- a/Lib/code.py
+++ b/Lib/code.py
@@ -385,7 +385,7 @@ def interact(banner=None, readfunc=None, local=None, exitmsg=None, local_exit=Fa
if __name__ == "__main__":
import argparse
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument('-q', action='store_true',
help="don't print version and copyright messages")
args = parser.parse_args()
diff --git a/Lib/compileall.py b/Lib/compileall.py
index 47e2446356e..67fe370451e 100644
--- a/Lib/compileall.py
+++ b/Lib/compileall.py
@@ -317,7 +317,9 @@ def main():
import argparse
parser = argparse.ArgumentParser(
- description='Utilities to support installing Python libraries.')
+ description='Utilities to support installing Python libraries.',
+ color=True,
+ )
parser.add_argument('-l', action='store_const', const=0,
default=None, dest='maxlevels',
help="don't recurse into subdirectories")
diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py
new file mode 100644
index 00000000000..4f734eb07b0
--- /dev/null
+++ b/Lib/compression/zstd/__init__.py
@@ -0,0 +1,234 @@
+"""Python bindings to the Zstandard (zstd) compression library (RFC-8878)."""
+
+__all__ = (
+ # compression.zstd
+ "COMPRESSION_LEVEL_DEFAULT",
+ "compress",
+ "CompressionParameter",
+ "decompress",
+ "DecompressionParameter",
+ "finalize_dict",
+ "get_frame_info",
+ "Strategy",
+ "train_dict",
+
+ # compression.zstd._zstdfile
+ "open",
+ "ZstdFile",
+
+ # _zstd
+ "get_frame_size",
+ "zstd_version",
+ "zstd_version_info",
+ "ZstdCompressor",
+ "ZstdDecompressor",
+ "ZstdDict",
+ "ZstdError",
+)
+
+import _zstd
+import enum
+from _zstd import *
+from compression.zstd._zstdfile import ZstdFile, open, _nbytes
+
+COMPRESSION_LEVEL_DEFAULT = _zstd._compressionLevel_values[0]
+"""The default compression level for Zstandard, currently '3'."""
+
+
+class FrameInfo:
+ """Information about a Zstandard frame."""
+ __slots__ = 'decompressed_size', 'dictionary_id'
+
+ def __init__(self, decompressed_size, dictionary_id):
+ super().__setattr__('decompressed_size', decompressed_size)
+ super().__setattr__('dictionary_id', dictionary_id)
+
+ def __repr__(self):
+ return (f'FrameInfo(decompressed_size={self.decompressed_size}, '
+ f'dictionary_id={self.dictionary_id})')
+
+ def __setattr__(self, name, _):
+ raise AttributeError(f"can't set attribute {name!r}")
+
+
+def get_frame_info(frame_buffer):
+ """Get Zstandard frame information from a frame header.
+
+ *frame_buffer* is a bytes-like object. It should start from the beginning
+ of a frame, and needs to include at least the frame header (6 to 18 bytes).
+
+ The returned FrameInfo object has two attributes.
+ 'decompressed_size' is the size in bytes of the data in the frame when
+ decompressed, or None when the decompressed size is unknown.
+ 'dictionary_id' is an int in the range (0, 2**32). The special value 0
+ means that the dictionary ID was not recorded in the frame header,
+ the frame may or may not need a dictionary to be decoded,
+ and the ID of such a dictionary is not specified.
+ """
+ return FrameInfo(*_zstd._get_frame_info(frame_buffer))
+
+
+def train_dict(samples, dict_size):
+ """Return a ZstdDict representing a trained Zstandard dictionary.
+
+ *samples* is an iterable of samples, where a sample is a bytes-like
+ object representing a file.
+
+ *dict_size* is the dictionary's maximum size, in bytes.
+ """
+ if not isinstance(dict_size, int):
+ ds_cls = type(dict_size).__qualname__
+ raise TypeError(f'dict_size must be an int object, not {ds_cls!r}.')
+
+ samples = tuple(samples)
+ chunks = b''.join(samples)
+ chunk_sizes = tuple(_nbytes(sample) for sample in samples)
+ if not chunks:
+ raise ValueError("samples contained no data; can't train dictionary.")
+ dict_content = _zstd._train_dict(chunks, chunk_sizes, dict_size)
+ return ZstdDict(dict_content)
+
+
+def finalize_dict(zstd_dict, /, samples, dict_size, level):
+ """Return a ZstdDict representing a finalized Zstandard dictionary.
+
+ Given a custom content as a basis for dictionary, and a set of samples,
+ finalize *zstd_dict* by adding headers and statistics according to the
+ Zstandard dictionary format.
+
+ You may compose an effective dictionary content by hand, which is used as
+ basis dictionary, and use some samples to finalize a dictionary. The basis
+ dictionary may be a "raw content" dictionary. See *is_raw* in ZstdDict.
+
+ *samples* is an iterable of samples, where a sample is a bytes-like object
+ representing a file.
+ *dict_size* is the dictionary's maximum size, in bytes.
+ *level* is the expected compression level. The statistics for each
+ compression level differ, so tuning the dictionary to the compression level
+ can provide improvements.
+ """
+
+ if not isinstance(zstd_dict, ZstdDict):
+ raise TypeError('zstd_dict argument should be a ZstdDict object.')
+ if not isinstance(dict_size, int):
+ raise TypeError('dict_size argument should be an int object.')
+ if not isinstance(level, int):
+ raise TypeError('level argument should be an int object.')
+
+ samples = tuple(samples)
+ chunks = b''.join(samples)
+ chunk_sizes = tuple(_nbytes(sample) for sample in samples)
+ if not chunks:
+        raise ValueError("The samples are empty content, can't finalize the "
+ "dictionary.")
+ dict_content = _zstd._finalize_dict(zstd_dict.dict_content,
+ chunks, chunk_sizes,
+ dict_size, level)
+ return ZstdDict(dict_content)
+
+def compress(data, level=None, options=None, zstd_dict=None):
+ """Return Zstandard compressed *data* as bytes.
+
+ *level* is an int specifying the compression level to use, defaulting to
+ COMPRESSION_LEVEL_DEFAULT ('3').
+ *options* is a dict object that contains advanced compression
+ parameters. See CompressionParameter for more on options.
+ *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. See
+ the function train_dict for how to train a ZstdDict on sample data.
+
+ For incremental compression, use a ZstdCompressor instead.
+ """
+ comp = ZstdCompressor(level=level, options=options, zstd_dict=zstd_dict)
+ return comp.compress(data, mode=ZstdCompressor.FLUSH_FRAME)
+
+def decompress(data, zstd_dict=None, options=None):
+ """Decompress one or more frames of Zstandard compressed *data*.
+
+ *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. See
+ the function train_dict for how to train a ZstdDict on sample data.
+ *options* is a dict object that contains advanced compression
+ parameters. See DecompressionParameter for more on options.
+
+ For incremental decompression, use a ZstdDecompressor instead.
+ """
+ results = []
+ while True:
+ decomp = ZstdDecompressor(options=options, zstd_dict=zstd_dict)
+ results.append(decomp.decompress(data))
+ if not decomp.eof:
+ raise ZstdError("Compressed data ended before the "
+ "end-of-stream marker was reached")
+ data = decomp.unused_data
+ if not data:
+ break
+ return b"".join(results)
+
+
+class CompressionParameter(enum.IntEnum):
+ """Compression parameters."""
+
+ compression_level = _zstd._ZSTD_c_compressionLevel
+ window_log = _zstd._ZSTD_c_windowLog
+ hash_log = _zstd._ZSTD_c_hashLog
+ chain_log = _zstd._ZSTD_c_chainLog
+ search_log = _zstd._ZSTD_c_searchLog
+ min_match = _zstd._ZSTD_c_minMatch
+ target_length = _zstd._ZSTD_c_targetLength
+ strategy = _zstd._ZSTD_c_strategy
+
+ enable_long_distance_matching = _zstd._ZSTD_c_enableLongDistanceMatching
+ ldm_hash_log = _zstd._ZSTD_c_ldmHashLog
+ ldm_min_match = _zstd._ZSTD_c_ldmMinMatch
+ ldm_bucket_size_log = _zstd._ZSTD_c_ldmBucketSizeLog
+ ldm_hash_rate_log = _zstd._ZSTD_c_ldmHashRateLog
+
+ content_size_flag = _zstd._ZSTD_c_contentSizeFlag
+ checksum_flag = _zstd._ZSTD_c_checksumFlag
+ dict_id_flag = _zstd._ZSTD_c_dictIDFlag
+
+ nb_workers = _zstd._ZSTD_c_nbWorkers
+ job_size = _zstd._ZSTD_c_jobSize
+ overlap_log = _zstd._ZSTD_c_overlapLog
+
+ def bounds(self):
+ """Return the (lower, upper) int bounds of a compression parameter.
+
+ Both the lower and upper bounds are inclusive.
+ """
+ return _zstd._get_param_bounds(self.value, is_compress=True)
+
+
+class DecompressionParameter(enum.IntEnum):
+ """Decompression parameters."""
+
+ window_log_max = _zstd._ZSTD_d_windowLogMax
+
+ def bounds(self):
+ """Return the (lower, upper) int bounds of a decompression parameter.
+
+ Both the lower and upper bounds are inclusive.
+ """
+ return _zstd._get_param_bounds(self.value, is_compress=False)
+
+
+class Strategy(enum.IntEnum):
+ """Compression strategies, listed from fastest to strongest.
+
+ Note that new strategies might be added in the future.
+ Only the order (from fast to strong) is guaranteed,
+ the numeric value might change.
+ """
+
+ fast = _zstd._ZSTD_fast
+ dfast = _zstd._ZSTD_dfast
+ greedy = _zstd._ZSTD_greedy
+ lazy = _zstd._ZSTD_lazy
+ lazy2 = _zstd._ZSTD_lazy2
+ btlazy2 = _zstd._ZSTD_btlazy2
+ btopt = _zstd._ZSTD_btopt
+ btultra = _zstd._ZSTD_btultra
+ btultra2 = _zstd._ZSTD_btultra2
+
+
+# Check validity of the CompressionParameter & DecompressionParameter types
+_zstd._set_parameter_types(CompressionParameter, DecompressionParameter)
diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py
new file mode 100644
index 00000000000..fbc9e02a733
--- /dev/null
+++ b/Lib/compression/zstd/_zstdfile.py
@@ -0,0 +1,349 @@
+import io
+from os import PathLike
+from _zstd import (ZstdCompressor, ZstdDecompressor, _ZSTD_DStreamSizes,
+ ZstdError)
+from compression._common import _streams
+
+__all__ = ("ZstdFile", "open")
+
+_ZSTD_DStreamOutSize = _ZSTD_DStreamSizes[1]
+
+_MODE_CLOSED = 0
+_MODE_READ = 1
+_MODE_WRITE = 2
+
+
+def _nbytes(dat, /):
+ if isinstance(dat, (bytes, bytearray)):
+ return len(dat)
+ with memoryview(dat) as mv:
+ return mv.nbytes
+
+
+class ZstdFile(_streams.BaseStream):
+ """A file-like object providing transparent Zstandard (de)compression.
+
+ A ZstdFile can act as a wrapper for an existing file object, or refer
+ directly to a named file on disk.
+
+ ZstdFile provides a *binary* file interface. Data is read and returned as
+ bytes, and may only be written to objects that support the Buffer Protocol.
+ """
+
+ FLUSH_BLOCK = ZstdCompressor.FLUSH_BLOCK
+ FLUSH_FRAME = ZstdCompressor.FLUSH_FRAME
+
+ def __init__(self, file, /, mode="r", *,
+ level=None, options=None, zstd_dict=None):
+ """Open a Zstandard compressed file in binary mode.
+
+        *file* can be either a file-like object, or a file name to open.
+
+ *mode* can be "r" for reading (default), "w" for (over)writing, "x" for
+ creating exclusively, or "a" for appending. These can equivalently be
+ given as "rb", "wb", "xb" and "ab" respectively.
+
+ *level* is an optional int specifying the compression level to use,
+ or COMPRESSION_LEVEL_DEFAULT if not given.
+
+ *options* is an optional dict for advanced compression parameters.
+ See CompressionParameter and DecompressionParameter for the possible
+ options.
+
+ *zstd_dict* is an optional ZstdDict object, a pre-trained Zstandard
+ dictionary. See train_dict() to train ZstdDict on sample data.
+ """
+ self._fp = None
+ self._close_fp = False
+ self._mode = _MODE_CLOSED
+ self._buffer = None
+
+ if not isinstance(mode, str):
+ raise ValueError("mode must be a str")
+ if options is not None and not isinstance(options, dict):
+ raise TypeError("options must be a dict or None")
+ mode = mode.removesuffix("b") # handle rb, wb, xb, ab
+ if mode == "r":
+ if level is not None:
+ raise TypeError("level is illegal in read mode")
+ self._mode = _MODE_READ
+ elif mode in {"w", "a", "x"}:
+ if level is not None and not isinstance(level, int):
+ raise TypeError("level must be int or None")
+ self._mode = _MODE_WRITE
+ self._compressor = ZstdCompressor(level=level, options=options,
+ zstd_dict=zstd_dict)
+ self._pos = 0
+ else:
+ raise ValueError(f"Invalid mode: {mode!r}")
+
+ if isinstance(file, (str, bytes, PathLike)):
+ self._fp = io.open(file, f'{mode}b')
+ self._close_fp = True
+ elif ((mode == 'r' and hasattr(file, "read"))
+ or (mode != 'r' and hasattr(file, "write"))):
+ self._fp = file
+ else:
+ raise TypeError("file must be a file-like object "
+ "or a str, bytes, or PathLike object")
+
+ if self._mode == _MODE_READ:
+ raw = _streams.DecompressReader(
+ self._fp,
+ ZstdDecompressor,
+ trailing_error=ZstdError,
+ zstd_dict=zstd_dict,
+ options=options,
+ )
+ self._buffer = io.BufferedReader(raw)
+
+ def close(self):
+ """Flush and close the file.
+
+ May be called multiple times. Once the file has been closed,
+ any other operation on it will raise ValueError.
+ """
+ if self._fp is None:
+ return
+ try:
+ if self._mode == _MODE_READ:
+ if getattr(self, '_buffer', None):
+ self._buffer.close()
+ self._buffer = None
+ elif self._mode == _MODE_WRITE:
+ self.flush(self.FLUSH_FRAME)
+ self._compressor = None
+ finally:
+ self._mode = _MODE_CLOSED
+ try:
+ if self._close_fp:
+ self._fp.close()
+ finally:
+ self._fp = None
+ self._close_fp = False
+
+ def write(self, data, /):
+ """Write a bytes-like object *data* to the file.
+
+ Returns the number of uncompressed bytes written, which is
+ always the length of data in bytes. Note that due to buffering,
+ the file on disk may not reflect the data written until .flush()
+ or .close() is called.
+ """
+ self._check_can_write()
+
+ length = _nbytes(data)
+
+ compressed = self._compressor.compress(data)
+ self._fp.write(compressed)
+ self._pos += length
+ return length
+
+ def flush(self, mode=FLUSH_BLOCK):
+ """Flush remaining data to the underlying stream.
+
+ The mode argument can be FLUSH_BLOCK or FLUSH_FRAME. Abuse of this
+ method will reduce compression ratio, use it only when necessary.
+
+ If the program is interrupted afterwards, all data can be recovered.
+        To ensure the data is saved to disk, also call os.fsync(fd).
+
+ This method does nothing in reading mode.
+ """
+ if self._mode == _MODE_READ:
+ return
+ self._check_not_closed()
+ if mode not in {self.FLUSH_BLOCK, self.FLUSH_FRAME}:
+ raise ValueError("Invalid mode argument, expected either "
+ "ZstdFile.FLUSH_FRAME or "
+ "ZstdFile.FLUSH_BLOCK")
+ if self._compressor.last_mode == mode:
+ return
+ # Flush zstd block/frame, and write.
+ data = self._compressor.flush(mode)
+ self._fp.write(data)
+ if hasattr(self._fp, "flush"):
+ self._fp.flush()
+
+ def read(self, size=-1):
+ """Read up to size uncompressed bytes from the file.
+
+ If size is negative or omitted, read until EOF is reached.
+ Returns b"" if the file is already at EOF.
+ """
+ if size is None:
+ size = -1
+ self._check_can_read()
+ return self._buffer.read(size)
+
+ def read1(self, size=-1):
+ """Read up to size uncompressed bytes, while trying to avoid
+ making multiple reads from the underlying stream. Reads up to a
+ buffer's worth of data if size is negative.
+
+ Returns b"" if the file is at EOF.
+ """
+ self._check_can_read()
+ if size < 0:
+ # Note this should *not* be io.DEFAULT_BUFFER_SIZE.
+ # ZSTD_DStreamOutSize is the minimum amount to read guaranteeing
+ # a full block is read.
+ size = _ZSTD_DStreamOutSize
+ return self._buffer.read1(size)
+
+ def readinto(self, b):
+ """Read bytes into b.
+
+ Returns the number of bytes read (0 for EOF).
+ """
+ self._check_can_read()
+ return self._buffer.readinto(b)
+
+ def readinto1(self, b):
+ """Read bytes into b, while trying to avoid making multiple reads
+ from the underlying stream.
+
+ Returns the number of bytes read (0 for EOF).
+ """
+ self._check_can_read()
+ return self._buffer.readinto1(b)
+
+ def readline(self, size=-1):
+ """Read a line of uncompressed bytes from the file.
+
+ The terminating newline (if present) is retained. If size is
+ non-negative, no more than size bytes will be read (in which
+ case the line may be incomplete). Returns b'' if already at EOF.
+ """
+ self._check_can_read()
+ return self._buffer.readline(size)
+
+ def seek(self, offset, whence=io.SEEK_SET):
+ """Change the file position.
+
+ The new position is specified by offset, relative to the
+ position indicated by whence. Possible values for whence are:
+
+ 0: start of stream (default): offset must not be negative
+ 1: current stream position
+ 2: end of stream; offset must not be positive
+
+ Returns the new file position.
+
+ Note that seeking is emulated, so depending on the arguments,
+ this operation may be extremely slow.
+ """
+ self._check_can_read()
+
+ # BufferedReader.seek() checks seekable
+ return self._buffer.seek(offset, whence)
+
+ def peek(self, size=-1):
+ """Return buffered data without advancing the file position.
+
+ Always returns at least one byte of data, unless at EOF.
+ The exact number of bytes returned is unspecified.
+ """
+ # Relies on the undocumented fact that BufferedReader.peek() always
+ # returns at least one byte (except at EOF)
+ self._check_can_read()
+ return self._buffer.peek(size)
+
+ def __next__(self):
+ if ret := self._buffer.readline():
+ return ret
+ raise StopIteration
+
+ def tell(self):
+ """Return the current file position."""
+ self._check_not_closed()
+ if self._mode == _MODE_READ:
+ return self._buffer.tell()
+ elif self._mode == _MODE_WRITE:
+ return self._pos
+
+ def fileno(self):
+ """Return the file descriptor for the underlying file."""
+ self._check_not_closed()
+ return self._fp.fileno()
+
+ @property
+ def name(self):
+ self._check_not_closed()
+ return self._fp.name
+
+ @property
+ def mode(self):
+ return 'wb' if self._mode == _MODE_WRITE else 'rb'
+
+ @property
+ def closed(self):
+ """True if this file is closed."""
+ return self._mode == _MODE_CLOSED
+
+ def seekable(self):
+ """Return whether the file supports seeking."""
+ return self.readable() and self._buffer.seekable()
+
+ def readable(self):
+ """Return whether the file was opened for reading."""
+ self._check_not_closed()
+ return self._mode == _MODE_READ
+
+ def writable(self):
+ """Return whether the file was opened for writing."""
+ self._check_not_closed()
+ return self._mode == _MODE_WRITE
+
+
+def open(file, /, mode="rb", *, level=None, options=None, zstd_dict=None,
+ encoding=None, errors=None, newline=None):
+ """Open a Zstandard compressed file in binary or text mode.
+
+ file can be either a file name (given as a str, bytes, or PathLike object),
+ in which case the named file is opened, or it can be an existing file object
+ to read from or write to.
+
+ The mode parameter can be "r", "rb" (default), "w", "wb", "x", "xb", "a",
+ "ab" for binary mode, or "rt", "wt", "xt", "at" for text mode.
+
+ The level, options, and zstd_dict parameters specify the settings the same
+ as ZstdFile.
+
+ When using read mode (decompression), the options parameter is a dict
+ representing advanced decompression options. The level parameter is not
+ supported in this case. When using write mode (compression), only one of
+ level, an int representing the compression level, or options, a dict
+ representing advanced compression options, may be passed. In both modes,
+ zstd_dict is a ZstdDict instance containing a trained Zstandard dictionary.
+
+ For binary mode, this function is equivalent to the ZstdFile constructor:
+ ZstdFile(filename, mode, ...). In this case, the encoding, errors and
+ newline parameters must not be provided.
+
+    For text mode, a ZstdFile object is created, and wrapped in an
+ io.TextIOWrapper instance with the specified encoding, error handling
+ behavior, and line ending(s).
+ """
+
+ text_mode = "t" in mode
+ mode = mode.replace("t", "")
+
+ if text_mode:
+ if "b" in mode:
+ raise ValueError(f"Invalid mode: {mode!r}")
+ else:
+ if encoding is not None:
+ raise ValueError("Argument 'encoding' not supported in binary mode")
+ if errors is not None:
+ raise ValueError("Argument 'errors' not supported in binary mode")
+ if newline is not None:
+ raise ValueError("Argument 'newline' not supported in binary mode")
+
+ binary_file = ZstdFile(file, mode, level=level, options=options,
+ zstd_dict=zstd_dict)
+
+ if text_mode:
+ return io.TextIOWrapper(binary_file, encoding, errors, newline)
+ else:
+ return binary_file
diff --git a/Lib/ctypes/_layout.py b/Lib/ctypes/_layout.py
index 0719e72cfed..2048ccb6a1c 100644
--- a/Lib/ctypes/_layout.py
+++ b/Lib/ctypes/_layout.py
@@ -5,6 +5,7 @@ may change at any time.
"""
import sys
+import warnings
from _ctypes import CField, buffer_info
import ctypes
@@ -66,9 +67,26 @@ def get_layout(cls, input_fields, is_struct, base):
# For clarity, variables that count bits have `bit` in their names.
+ pack = getattr(cls, '_pack_', None)
+
layout = getattr(cls, '_layout_', None)
if layout is None:
- if sys.platform == 'win32' or getattr(cls, '_pack_', None):
+ if sys.platform == 'win32':
+ gcc_layout = False
+ elif pack:
+ if is_struct:
+ base_type_name = 'Structure'
+ else:
+ base_type_name = 'Union'
+ warnings._deprecated(
+ '_pack_ without _layout_',
+ f"Due to '_pack_', the '{cls.__name__}' {base_type_name} will "
+ + "use memory layout compatible with MSVC (Windows). "
+ + "If this is intended, set _layout_ to 'ms'. "
+ + "The implicit default is deprecated and slated to become "
+ + "an error in Python {remove}.",
+ remove=(3, 19),
+ )
gcc_layout = False
else:
gcc_layout = True
@@ -95,7 +113,6 @@ def get_layout(cls, input_fields, is_struct, base):
else:
big_endian = sys.byteorder == 'big'
- pack = getattr(cls, '_pack_', None)
if pack is not None:
try:
pack = int(pack)
diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py
index 0f7dc9ae6b8..86d29df0639 100644
--- a/Lib/dataclasses.py
+++ b/Lib/dataclasses.py
@@ -244,6 +244,10 @@ _ATOMIC_TYPES = frozenset({
property,
})
+# Any marker is used in `make_dataclass` to mark unannotated fields as `Any`
+# without importing `typing` module.
+_ANY_MARKER = object()
+
class InitVar:
__slots__ = ('type', )
@@ -1591,7 +1595,7 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
for item in fields:
if isinstance(item, str):
name = item
- tp = 'typing.Any'
+ tp = _ANY_MARKER
elif len(item) == 2:
name, tp, = item
elif len(item) == 3:
@@ -1610,15 +1614,49 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
seen.add(name)
annotations[name] = tp
+ # We initially block the VALUE format, because inside dataclass() we'll
+ # call get_annotations(), which will try the VALUE format first. If we don't
+ # block, that means we'd always end up eagerly importing typing here, which
+ # is what we're trying to avoid.
+ value_blocked = True
+
+ def annotate_method(format):
+ def get_any():
+ match format:
+ case annotationlib.Format.STRING:
+ return 'typing.Any'
+ case annotationlib.Format.FORWARDREF:
+ typing = sys.modules.get("typing")
+ if typing is None:
+ return annotationlib.ForwardRef("Any", module="typing")
+ else:
+ return typing.Any
+ case annotationlib.Format.VALUE:
+ if value_blocked:
+ raise NotImplementedError
+ from typing import Any
+ return Any
+ case _:
+ raise NotImplementedError
+ annos = {
+ ann: get_any() if t is _ANY_MARKER else t
+ for ann, t in annotations.items()
+ }
+ if format == annotationlib.Format.STRING:
+ return annotationlib.annotations_to_string(annos)
+ else:
+ return annos
+
# Update 'ns' with the user-supplied namespace plus our calculated values.
def exec_body_callback(ns):
ns.update(namespace)
ns.update(defaults)
- ns['__annotations__'] = annotations
# We use `types.new_class()` instead of simply `type()` to allow dynamic creation
# of generic dataclasses.
cls = types.new_class(cls_name, bases, {}, exec_body_callback)
+ # For now, set annotations including the _ANY_MARKER.
+ cls.__annotate__ = annotate_method
# For pickling to work, the __module__ variable needs to be set to the frame
# where the dataclass is created.
@@ -1634,10 +1672,13 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
cls.__module__ = module
# Apply the normal provided decorator.
- return decorator(cls, init=init, repr=repr, eq=eq, order=order,
- unsafe_hash=unsafe_hash, frozen=frozen,
- match_args=match_args, kw_only=kw_only, slots=slots,
- weakref_slot=weakref_slot)
+ cls = decorator(cls, init=init, repr=repr, eq=eq, order=order,
+ unsafe_hash=unsafe_hash, frozen=frozen,
+ match_args=match_args, kw_only=kw_only, slots=slots,
+ weakref_slot=weakref_slot)
+ # Now that the class is ready, allow the VALUE format.
+ value_blocked = False
+ return cls
def replace(obj, /, **changes):
diff --git a/Lib/dis.py b/Lib/dis.py
index cb6d077a391..d6d2c1386dd 100644
--- a/Lib/dis.py
+++ b/Lib/dis.py
@@ -1131,7 +1131,7 @@ class Bytecode:
def main(args=None):
import argparse
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument('-C', '--show-caches', action='store_true',
help='show inline caches')
parser.add_argument('-O', '--show-offsets', action='store_true',
diff --git a/Lib/doctest.py b/Lib/doctest.py
index e02e73ed722..2acb6cb79f3 100644
--- a/Lib/doctest.py
+++ b/Lib/doctest.py
@@ -2870,7 +2870,7 @@ __test__ = {"_TestClass": _TestClass,
def _test():
import argparse
- parser = argparse.ArgumentParser(description="doctest runner")
+ parser = argparse.ArgumentParser(description="doctest runner", color=True)
parser.add_argument('-v', '--verbose', action='store_true', default=False,
help='print very verbose output for all tests')
parser.add_argument('-o', '--option', action='append',
diff --git a/Lib/ensurepip/__init__.py b/Lib/ensurepip/__init__.py
index 6fc9f39b24c..aa641e94a8b 100644
--- a/Lib/ensurepip/__init__.py
+++ b/Lib/ensurepip/__init__.py
@@ -205,7 +205,7 @@ def _uninstall_helper(*, verbosity=0):
def _main(argv=None):
import argparse
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument(
"--version",
action="version",
diff --git a/Lib/getpass.py b/Lib/getpass.py
index bd0097ced94..f571425e541 100644
--- a/Lib/getpass.py
+++ b/Lib/getpass.py
@@ -1,6 +1,7 @@
"""Utilities to get a password and/or the current user name.
-getpass(prompt[, stream]) - Prompt for a password, with echo turned off.
+getpass(prompt[, stream[, echo_char]]) - Prompt for a password, with echo
+turned off and optional keyboard feedback.
getuser() - Get the user name from the environment or password database.
GetPassWarning - This UserWarning is issued when getpass() cannot prevent
@@ -25,13 +26,15 @@ __all__ = ["getpass","getuser","GetPassWarning"]
class GetPassWarning(UserWarning): pass
-def unix_getpass(prompt='Password: ', stream=None):
+def unix_getpass(prompt='Password: ', stream=None, *, echo_char=None):
"""Prompt for a password, with echo turned off.
Args:
prompt: Written on stream to ask for the input. Default: 'Password: '
stream: A writable file object to display the prompt. Defaults to
the tty. If no tty is available defaults to sys.stderr.
+ echo_char: A string used to mask input (e.g., '*'). If None, input is
+ hidden.
Returns:
The seKr3t input.
Raises:
@@ -40,6 +43,8 @@ def unix_getpass(prompt='Password: ', stream=None):
Always restores terminal settings before returning.
"""
+ _check_echo_char(echo_char)
+
passwd = None
with contextlib.ExitStack() as stack:
try:
@@ -68,12 +73,16 @@ def unix_getpass(prompt='Password: ', stream=None):
old = termios.tcgetattr(fd) # a copy to save
new = old[:]
new[3] &= ~termios.ECHO # 3 == 'lflags'
+ if echo_char:
+ new[3] &= ~termios.ICANON
tcsetattr_flags = termios.TCSAFLUSH
if hasattr(termios, 'TCSASOFT'):
tcsetattr_flags |= termios.TCSASOFT
try:
termios.tcsetattr(fd, tcsetattr_flags, new)
- passwd = _raw_input(prompt, stream, input=input)
+ passwd = _raw_input(prompt, stream, input=input,
+ echo_char=echo_char)
+
finally:
termios.tcsetattr(fd, tcsetattr_flags, old)
stream.flush() # issue7208
@@ -93,10 +102,11 @@ def unix_getpass(prompt='Password: ', stream=None):
return passwd
-def win_getpass(prompt='Password: ', stream=None):
+def win_getpass(prompt='Password: ', stream=None, *, echo_char=None):
"""Prompt for password with echo off, using Windows getwch()."""
if sys.stdin is not sys.__stdin__:
return fallback_getpass(prompt, stream)
+ _check_echo_char(echo_char)
for c in prompt:
msvcrt.putwch(c)
@@ -108,9 +118,15 @@ def win_getpass(prompt='Password: ', stream=None):
if c == '\003':
raise KeyboardInterrupt
if c == '\b':
+ if echo_char and pw:
+ msvcrt.putch('\b')
+ msvcrt.putch(' ')
+ msvcrt.putch('\b')
pw = pw[:-1]
else:
pw = pw + c
+ if echo_char:
+ msvcrt.putwch(echo_char)
msvcrt.putwch('\r')
msvcrt.putwch('\n')
return pw
@@ -126,7 +142,14 @@ def fallback_getpass(prompt='Password: ', stream=None):
return _raw_input(prompt, stream)
-def _raw_input(prompt="", stream=None, input=None):
+def _check_echo_char(echo_char):
+ # ASCII excluding control characters
+ if echo_char and not (echo_char.isprintable() and echo_char.isascii()):
+ raise ValueError("'echo_char' must be a printable ASCII string, "
+ f"got: {echo_char!r}")
+
+
+def _raw_input(prompt="", stream=None, input=None, echo_char=None):
# This doesn't save the string in the GNU readline history.
if not stream:
stream = sys.stderr
@@ -143,6 +166,8 @@ def _raw_input(prompt="", stream=None, input=None):
stream.write(prompt)
stream.flush()
# NOTE: The Python C API calls flockfile() (and unlock) during readline.
+ if echo_char:
+ return _readline_with_echo_char(stream, input, echo_char)
line = input.readline()
if not line:
raise EOFError
@@ -151,6 +176,35 @@ def _raw_input(prompt="", stream=None, input=None):
return line
+def _readline_with_echo_char(stream, input, echo_char):
+ passwd = ""
+ eof_pressed = False
+ while True:
+ char = input.read(1)
+ if char == '\n' or char == '\r':
+ break
+ elif char == '\x03':
+ raise KeyboardInterrupt
+ elif char == '\x7f' or char == '\b':
+ if passwd:
+ stream.write("\b \b")
+ stream.flush()
+ passwd = passwd[:-1]
+ elif char == '\x04':
+ if eof_pressed:
+ break
+ else:
+ eof_pressed = True
+ elif char == '\x00':
+ continue
+ else:
+ passwd += char
+ stream.write(echo_char)
+ stream.flush()
+ eof_pressed = False
+ return passwd
+
+
def getuser():
"""Get the username from the environment or password database.
diff --git a/Lib/gzip.py b/Lib/gzip.py
index b7375b25473..c00f51858de 100644
--- a/Lib/gzip.py
+++ b/Lib/gzip.py
@@ -667,7 +667,9 @@ def main():
from argparse import ArgumentParser
parser = ArgumentParser(description=
"A simple command line interface for the gzip module: act like gzip, "
- "but do not delete the input file.")
+ "but do not delete the input file.",
+ color=True,
+ )
group = parser.add_mutually_exclusive_group()
group.add_argument('--fast', action='store_true', help='compress faster')
group.add_argument('--best', action='store_true', help='compress better')
diff --git a/Lib/heapq.py b/Lib/heapq.py
index 9649da251f2..6ceb211f1ca 100644
--- a/Lib/heapq.py
+++ b/Lib/heapq.py
@@ -178,7 +178,7 @@ def heapify(x):
for i in reversed(range(n//2)):
_siftup(x, i)
-def _heappop_max(heap):
+def heappop_max(heap):
"""Maxheap version of a heappop."""
lastelt = heap.pop() # raises appropriate IndexError if heap is empty
if heap:
@@ -188,19 +188,32 @@ def _heappop_max(heap):
return returnitem
return lastelt
-def _heapreplace_max(heap, item):
+def heapreplace_max(heap, item):
"""Maxheap version of a heappop followed by a heappush."""
returnitem = heap[0] # raises appropriate IndexError if heap is empty
heap[0] = item
_siftup_max(heap, 0)
return returnitem
-def _heapify_max(x):
+def heappush_max(heap, item):
+ """Maxheap version of a heappush."""
+ heap.append(item)
+ _siftdown_max(heap, 0, len(heap)-1)
+
+def heappushpop_max(heap, item):
+ """Maxheap fast version of a heappush followed by a heappop."""
+ if heap and item < heap[0]:
+ item, heap[0] = heap[0], item
+ _siftup_max(heap, 0)
+ return item
+
+def heapify_max(x):
"""Transform list into a maxheap, in-place, in O(len(x)) time."""
n = len(x)
for i in reversed(range(n//2)):
_siftup_max(x, i)
+
# 'heap' is a heap at all indices >= startpos, except possibly for pos. pos
# is the index of a leaf with a possibly out-of-order value. Restore the
# heap invariant.
@@ -335,9 +348,9 @@ def merge(*iterables, key=None, reverse=False):
h_append = h.append
if reverse:
- _heapify = _heapify_max
- _heappop = _heappop_max
- _heapreplace = _heapreplace_max
+ _heapify = heapify_max
+ _heappop = heappop_max
+ _heapreplace = heapreplace_max
direction = -1
else:
_heapify = heapify
@@ -490,10 +503,10 @@ def nsmallest(n, iterable, key=None):
result = [(elem, i) for i, elem in zip(range(n), it)]
if not result:
return result
- _heapify_max(result)
+ heapify_max(result)
top = result[0][0]
order = n
- _heapreplace = _heapreplace_max
+ _heapreplace = heapreplace_max
for elem in it:
if elem < top:
_heapreplace(result, (elem, order))
@@ -507,10 +520,10 @@ def nsmallest(n, iterable, key=None):
result = [(key(elem), i, elem) for i, elem in zip(range(n), it)]
if not result:
return result
- _heapify_max(result)
+ heapify_max(result)
top = result[0][0]
order = n
- _heapreplace = _heapreplace_max
+ _heapreplace = heapreplace_max
for elem in it:
k = key(elem)
if k < top:
@@ -583,19 +596,13 @@ try:
from _heapq import *
except ImportError:
pass
-try:
- from _heapq import _heapreplace_max
-except ImportError:
- pass
-try:
- from _heapq import _heapify_max
-except ImportError:
- pass
-try:
- from _heapq import _heappop_max
-except ImportError:
- pass
+# For backwards compatibility
+_heappop_max = heappop_max
+_heapreplace_max = heapreplace_max
+_heappush_max = heappush_max
+_heappushpop_max = heappushpop_max
+_heapify_max = heapify_max
if __name__ == "__main__":
diff --git a/Lib/http/server.py b/Lib/http/server.py
index a2aad4c9be3..64f766f9bc2 100644
--- a/Lib/http/server.py
+++ b/Lib/http/server.py
@@ -1340,7 +1340,7 @@ if __name__ == '__main__':
import argparse
import contextlib
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument('--cgi', action='store_true',
help='run as CGI server')
parser.add_argument('-b', '--bind', metavar='ADDRESS',
diff --git a/Lib/inspect.py b/Lib/inspect.py
index 9592559ba6d..52c9bb05b31 100644
--- a/Lib/inspect.py
+++ b/Lib/inspect.py
@@ -3343,7 +3343,7 @@ def _main():
import argparse
import importlib
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument(
'object',
help="The object to be analysed. "
diff --git a/Lib/json/tool.py b/Lib/json/tool.py
index 585583da860..1967817add8 100644
--- a/Lib/json/tool.py
+++ b/Lib/json/tool.py
@@ -7,7 +7,7 @@ import argparse
import json
import re
import sys
-from _colorize import ANSIColors, can_colorize
+from _colorize import get_theme, can_colorize
# The string we are colorizing is valid JSON,
@@ -17,34 +17,34 @@ from _colorize import ANSIColors, can_colorize
_color_pattern = re.compile(r'''
(?P<key>"(\\.|[^"\\])*")(?=:) |
(?P<string>"(\\.|[^"\\])*") |
+ (?P<number>NaN|-?Infinity|[0-9\-+.Ee]+) |
(?P<boolean>true|false) |
(?P<null>null)
''', re.VERBOSE)
-
-_colors = {
- 'key': ANSIColors.INTENSE_BLUE,
- 'string': ANSIColors.BOLD_GREEN,
- 'boolean': ANSIColors.BOLD_CYAN,
- 'null': ANSIColors.BOLD_CYAN,
+_group_to_theme_color = {
+ "key": "definition",
+ "string": "string",
+ "number": "number",
+ "boolean": "keyword",
+ "null": "keyword",
}
-def _replace_match_callback(match):
- for key, color in _colors.items():
- if m := match.group(key):
- return f"{color}{m}{ANSIColors.RESET}"
- return match.group()
-
+def _colorize_json(json_str, theme):
+ def _replace_match_callback(match):
+ for group, color in _group_to_theme_color.items():
+ if m := match.group(group):
+ return f"{theme[color]}{m}{theme.reset}"
+ return match.group()
-def _colorize_json(json_str):
return re.sub(_color_pattern, _replace_match_callback, json_str)
def main():
description = ('A simple command line interface for json module '
'to validate and pretty-print JSON objects.')
- parser = argparse.ArgumentParser(description=description)
+ parser = argparse.ArgumentParser(description=description, color=True)
parser.add_argument('infile', nargs='?',
help='a JSON file to be validated or pretty-printed',
default='-')
@@ -100,13 +100,16 @@ def main():
else:
outfile = open(options.outfile, 'w', encoding='utf-8')
with outfile:
- for obj in objs:
- if can_colorize(file=outfile):
+ if can_colorize(file=outfile):
+ t = get_theme(tty_file=outfile).syntax
+ for obj in objs:
json_str = json.dumps(obj, **dump_args)
- outfile.write(_colorize_json(json_str))
- else:
+ outfile.write(_colorize_json(json_str, t))
+ outfile.write('\n')
+ else:
+ for obj in objs:
json.dump(obj, outfile, **dump_args)
- outfile.write('\n')
+ outfile.write('\n')
except ValueError as e:
raise SystemExit(e)
diff --git a/Lib/mimetypes.py b/Lib/mimetypes.py
index b5a1b8da263..33e86d51a0f 100644
--- a/Lib/mimetypes.py
+++ b/Lib/mimetypes.py
@@ -698,7 +698,9 @@ _default_mime_types()
def _parse_args(args):
from argparse import ArgumentParser
- parser = ArgumentParser(description='map filename extensions to MIME types')
+ parser = ArgumentParser(
+ description='map filename extensions to MIME types', color=True
+ )
parser.add_argument(
'-e', '--extension',
action='store_true',
diff --git a/Lib/pdb.py b/Lib/pdb.py
index 2aa60c75396..f89d104fcdd 100644
--- a/Lib/pdb.py
+++ b/Lib/pdb.py
@@ -77,6 +77,7 @@ import glob
import json
import token
import types
+import atexit
import codeop
import pprint
import signal
@@ -92,10 +93,12 @@ import tokenize
import itertools
import traceback
import linecache
+import selectors
+import threading
import _colorize
+import _pyrepl.utils
-from contextlib import closing
-from contextlib import contextmanager
+from contextlib import ExitStack, closing, contextmanager
from rlcompleter import Completer
from types import CodeType
from warnings import deprecated
@@ -339,7 +342,7 @@ class Pdb(bdb.Bdb, cmd.Cmd):
_last_pdb_instance = None
def __init__(self, completekey='tab', stdin=None, stdout=None, skip=None,
- nosigint=False, readrc=True, mode=None, backend=None):
+ nosigint=False, readrc=True, mode=None, backend=None, colorize=False):
bdb.Bdb.__init__(self, skip=skip, backend=backend if backend else get_default_backend())
cmd.Cmd.__init__(self, completekey, stdin, stdout)
sys.audit("pdb.Pdb")
@@ -352,6 +355,7 @@ class Pdb(bdb.Bdb, cmd.Cmd):
self._wait_for_mainpyfile = False
self.tb_lineno = {}
self.mode = mode
+ self.colorize = colorize and _colorize.can_colorize(file=stdout or sys.stdout)
# Try to load readline if it exists
try:
import readline
@@ -743,12 +747,34 @@ class Pdb(bdb.Bdb, cmd.Cmd):
self.message(repr(obj))
@contextmanager
- def _enable_multiline_completion(self):
+ def _enable_multiline_input(self):
+ try:
+ import readline
+ except ImportError:
+ yield
+ return
+
+ def input_auto_indent():
+ last_index = readline.get_current_history_length()
+ last_line = readline.get_history_item(last_index)
+ if last_line:
+ if last_line.isspace():
+ # If the last line is empty, we don't need to indent
+ return
+
+ last_line = last_line.rstrip('\r\n')
+ indent = len(last_line) - len(last_line.lstrip())
+ if last_line.endswith(":"):
+ indent += 4
+ readline.insert_text(' ' * indent)
+
completenames = self.completenames
try:
self.completenames = self.complete_multiline_names
+ readline.set_startup_hook(input_auto_indent)
yield
finally:
+ readline.set_startup_hook()
self.completenames = completenames
return
@@ -857,7 +883,7 @@ class Pdb(bdb.Bdb, cmd.Cmd):
try:
if (code := codeop.compile_command(line + '\n', '<stdin>', 'single')) is None:
# Multi-line mode
- with self._enable_multiline_completion():
+ with self._enable_multiline_input():
buffer = line
continue_prompt = "... "
while (code := codeop.compile_command(buffer, '<stdin>', 'single')) is None:
@@ -879,7 +905,11 @@ class Pdb(bdb.Bdb, cmd.Cmd):
return None, None, False
else:
line = line.rstrip('\r\n')
- buffer += '\n' + line
+ if line.isspace():
+ # empty line, just continue
+ buffer += '\n'
+ else:
+ buffer += '\n' + line
self.lastcmd = buffer
except SyntaxError as e:
# Maybe it's an await expression/statement
@@ -1036,6 +1066,13 @@ class Pdb(bdb.Bdb, cmd.Cmd):
return True
return False
+ def _colorize_code(self, code):
+ if self.colorize:
+ colors = list(_pyrepl.utils.gen_colors(code))
+ chars, _ = _pyrepl.utils.disp_str(code, colors=colors, force_color=True)
+ code = "".join(chars)
+ return code
+
# interface abstraction functions
def message(self, msg, end='\n'):
@@ -2166,6 +2203,8 @@ class Pdb(bdb.Bdb, cmd.Cmd):
s += '->'
elif lineno == exc_lineno:
s += '>>'
+ if self.colorize:
+ line = self._colorize_code(line)
self.message(s + '\t' + line.rstrip())
def do_whatis(self, arg):
@@ -2365,8 +2404,14 @@ class Pdb(bdb.Bdb, cmd.Cmd):
prefix = '> '
else:
prefix = ' '
- self.message(prefix +
- self.format_stack_entry(frame_lineno, prompt_prefix))
+ stack_entry = self.format_stack_entry(frame_lineno, prompt_prefix)
+ if self.colorize:
+ lines = stack_entry.split(prompt_prefix, 1)
+ if len(lines) > 1:
+ # We have some code to display
+ lines[1] = self._colorize_code(lines[1])
+ stack_entry = prompt_prefix.join(lines)
+ self.message(prefix + stack_entry)
# Provide help
@@ -2604,7 +2649,7 @@ def set_trace(*, header=None, commands=None):
if Pdb._last_pdb_instance is not None:
pdb = Pdb._last_pdb_instance
else:
- pdb = Pdb(mode='inline', backend='monitoring')
+ pdb = Pdb(mode='inline', backend='monitoring', colorize=True)
if header is not None:
pdb.message(header)
pdb.set_trace(sys._getframe().f_back, commands=commands)
@@ -2619,7 +2664,7 @@ async def set_trace_async(*, header=None, commands=None):
if Pdb._last_pdb_instance is not None:
pdb = Pdb._last_pdb_instance
else:
- pdb = Pdb(mode='inline', backend='monitoring')
+ pdb = Pdb(mode='inline', backend='monitoring', colorize=True)
if header is not None:
pdb.message(header)
await pdb.set_trace_async(sys._getframe().f_back, commands=commands)
@@ -2627,13 +2672,26 @@ async def set_trace_async(*, header=None, commands=None):
# Remote PDB
class _PdbServer(Pdb):
- def __init__(self, sockfile, owns_sockfile=True, **kwargs):
+ def __init__(
+ self,
+ sockfile,
+ signal_server=None,
+ owns_sockfile=True,
+ colorize=False,
+ **kwargs,
+ ):
self._owns_sockfile = owns_sockfile
self._interact_state = None
self._sockfile = sockfile
self._command_name_cache = []
self._write_failed = False
- super().__init__(**kwargs)
+ if signal_server:
+ # Only started by the top level _PdbServer, not recursive ones.
+ self._start_signal_listener(signal_server)
+ # Override the `colorize` attribute set by the parent constructor,
+ # because it checks the server's stdout, rather than the client's.
+ super().__init__(colorize=False, **kwargs)
+ self.colorize = colorize
@staticmethod
def protocol_version():
@@ -2688,15 +2746,49 @@ class _PdbServer(Pdb):
f"PDB message doesn't follow the schema! {msg}"
)
+ @classmethod
+ def _start_signal_listener(cls, address):
+ def listener(sock):
+ with closing(sock):
+ # Check if the interpreter is finalizing every quarter of a second.
+ # Clean up and exit if so.
+ sock.settimeout(0.25)
+ sock.shutdown(socket.SHUT_WR)
+ while not shut_down.is_set():
+ try:
+ data = sock.recv(1024)
+ except socket.timeout:
+ continue
+ if data == b"":
+ return # EOF
+ signal.raise_signal(signal.SIGINT)
+
+ def stop_thread():
+ shut_down.set()
+ thread.join()
+
+ # Use a daemon thread so that we don't detach until after all non-daemon
+ # threads are done. Use an atexit handler to stop gracefully at that point,
+ # so that our thread is stopped before the interpreter is torn down.
+ shut_down = threading.Event()
+ thread = threading.Thread(
+ target=listener,
+ args=[socket.create_connection(address, timeout=5)],
+ daemon=True,
+ )
+ atexit.register(stop_thread)
+ thread.start()
+
def _send(self, **kwargs):
self._ensure_valid_message(kwargs)
json_payload = json.dumps(kwargs)
try:
self._sockfile.write(json_payload.encode() + b"\n")
self._sockfile.flush()
- except OSError:
- # This means that the client has abruptly disconnected, but we'll
- # handle that the next time we try to read from the client instead
+ except (OSError, ValueError):
+ # We get an OSError if the network connection has dropped, and a
+ # ValueError if detach() has been called and the sockfile is closed. We'll
+ # handle this the next time we try to read from the client instead
# of trying to handle it from everywhere _send() may be called.
# Track this with a flag rather than assuming readline() will ever
# return an empty string because the socket may be half-closed.
@@ -2887,7 +2979,11 @@ class _PdbServer(Pdb):
@typing.override
def _create_recursive_debugger(self):
- return _PdbServer(self._sockfile, owns_sockfile=False)
+ return _PdbServer(
+ self._sockfile,
+ owns_sockfile=False,
+ colorize=self.colorize,
+ )
@typing.override
def _prompt_for_confirmation(self, prompt, default):
@@ -2924,10 +3020,15 @@ class _PdbServer(Pdb):
class _PdbClient:
- def __init__(self, pid, sockfile, interrupt_script):
+ def __init__(self, pid, server_socket, interrupt_sock):
self.pid = pid
- self.sockfile = sockfile
- self.interrupt_script = interrupt_script
+ self.read_buf = b""
+ self.signal_read = None
+ self.signal_write = None
+ self.sigint_received = False
+ self.raise_on_sigint = False
+ self.server_socket = server_socket
+ self.interrupt_sock = interrupt_sock
self.pdb_instance = Pdb()
self.pdb_commands = set()
self.completion_matches = []
@@ -2969,8 +3070,7 @@ class _PdbClient:
self._ensure_valid_message(kwargs)
json_payload = json.dumps(kwargs)
try:
- self.sockfile.write(json_payload.encode() + b"\n")
- self.sockfile.flush()
+ self.server_socket.sendall(json_payload.encode() + b"\n")
except OSError:
# This means that the client has abruptly disconnected, but we'll
# handle that the next time we try to read from the client instead
@@ -2979,10 +3079,44 @@ class _PdbClient:
# return an empty string because the socket may be half-closed.
self.write_failed = True
- def read_command(self, prompt):
- self.multiline_block = False
- reply = input(prompt)
+ def _readline(self):
+ if self.sigint_received:
+ # There's a pending unhandled SIGINT. Handle it now.
+ self.sigint_received = False
+ raise KeyboardInterrupt
+
+ # Wait for either a SIGINT or a line or EOF from the PDB server.
+ selector = selectors.DefaultSelector()
+ selector.register(self.signal_read, selectors.EVENT_READ)
+ selector.register(self.server_socket, selectors.EVENT_READ)
+
+ while b"\n" not in self.read_buf:
+ for key, _ in selector.select():
+ if key.fileobj == self.signal_read:
+ self.signal_read.recv(1024)
+ if self.sigint_received:
+ # If not, we're reading wakeup events for sigints that
+ # we've previously handled, and can ignore them.
+ self.sigint_received = False
+ raise KeyboardInterrupt
+ elif key.fileobj == self.server_socket:
+ data = self.server_socket.recv(16 * 1024)
+ self.read_buf += data
+ if not data and b"\n" not in self.read_buf:
+ # EOF without a full final line. Drop the partial line.
+ self.read_buf = b""
+ return b""
+
+ ret, sep, self.read_buf = self.read_buf.partition(b"\n")
+ return ret + sep
+
+ def read_input(self, prompt, multiline_block):
+ self.multiline_block = multiline_block
+ with self._sigint_raises_keyboard_interrupt():
+ return input(prompt)
+ def read_command(self, prompt):
+ reply = self.read_input(prompt, multiline_block=False)
if self.state == "dumb":
# No logic applied whatsoever, just pass the raw reply back.
return reply
@@ -3005,10 +3139,9 @@ class _PdbClient:
return prefix + reply
# Otherwise, valid first line of a multi-line statement
- self.multiline_block = True
- continue_prompt = "...".ljust(len(prompt))
+ more_prompt = "...".ljust(len(prompt))
while codeop.compile_command(reply, "<stdin>", "single") is None:
- reply += "\n" + input(continue_prompt)
+ reply += "\n" + self.read_input(more_prompt, multiline_block=True)
return prefix + reply
@@ -3033,11 +3166,70 @@ class _PdbClient:
finally:
readline.set_completer(old_completer)
+ @contextmanager
+ def _sigint_handler(self):
+ # Signal handling strategy:
+ # - When we call input() we want a SIGINT to raise KeyboardInterrupt
+ # - Otherwise we want to write to the wakeup FD and set a flag.
+ # We'll break out of select() when the wakeup FD is written to,
+ # and we'll check the flag whenever we're about to accept input.
+ def handler(signum, frame):
+ self.sigint_received = True
+ if self.raise_on_sigint:
+ # One-shot; don't raise again until the flag is set again.
+ self.raise_on_sigint = False
+ self.sigint_received = False
+ raise KeyboardInterrupt
+
+ sentinel = object()
+ old_handler = sentinel
+ old_wakeup_fd = sentinel
+
+ self.signal_read, self.signal_write = socket.socketpair()
+ with (closing(self.signal_read), closing(self.signal_write)):
+ self.signal_read.setblocking(False)
+ self.signal_write.setblocking(False)
+
+ try:
+ old_handler = signal.signal(signal.SIGINT, handler)
+
+ try:
+ old_wakeup_fd = signal.set_wakeup_fd(
+ self.signal_write.fileno(),
+ warn_on_full_buffer=False,
+ )
+ yield
+ finally:
+ # Restore the old wakeup fd if we installed a new one
+ if old_wakeup_fd is not sentinel:
+ signal.set_wakeup_fd(old_wakeup_fd)
+ finally:
+ self.signal_read = self.signal_write = None
+ if old_handler is not sentinel:
+ # Restore the old handler if we installed a new one
+ signal.signal(signal.SIGINT, old_handler)
+
+ @contextmanager
+ def _sigint_raises_keyboard_interrupt(self):
+ if self.sigint_received:
+ # There's a pending unhandled SIGINT. Handle it now.
+ self.sigint_received = False
+ raise KeyboardInterrupt
+
+ try:
+ self.raise_on_sigint = True
+ yield
+ finally:
+ self.raise_on_sigint = False
+
def cmdloop(self):
- with self.readline_completion(self.complete):
+ with (
+ self._sigint_handler(),
+ self.readline_completion(self.complete),
+ ):
while not self.write_failed:
try:
- if not (payload_bytes := self.sockfile.readline()):
+ if not (payload_bytes := self._readline()):
break
except KeyboardInterrupt:
self.send_interrupt()
@@ -3055,11 +3247,17 @@ class _PdbClient:
self.process_payload(payload)
def send_interrupt(self):
- print(
- "\n*** Program will stop at the next bytecode instruction."
- " (Use 'cont' to resume)."
- )
- sys.remote_exec(self.pid, self.interrupt_script)
+ if self.interrupt_sock is not None:
+ # Write to a socket that the PDB server listens on. This triggers
+ # the remote to raise a SIGINT for itself. We do this because
+ # Windows doesn't allow triggering SIGINT remotely.
+ # See https://stackoverflow.com/a/35792192 for many more details.
+ self.interrupt_sock.sendall(signal.SIGINT.to_bytes())
+ else:
+ # On Unix we can just send a SIGINT to the remote process.
+ # This is preferable to using the signal thread approach that we
+ # use on Windows because it can interrupt IO in the main thread.
+ os.kill(self.pid, signal.SIGINT)
def process_payload(self, payload):
match payload:
@@ -3129,7 +3327,7 @@ class _PdbClient:
if self.write_failed:
return None
- payload = self.sockfile.readline()
+ payload = self._readline()
if not payload:
return None
@@ -3146,11 +3344,31 @@ class _PdbClient:
return None
-def _connect(host, port, frame, commands, version):
+def _connect(
+ *,
+ host,
+ port,
+ frame,
+ commands,
+ version,
+ signal_raising_thread,
+ colorize,
+):
with closing(socket.create_connection((host, port))) as conn:
sockfile = conn.makefile("rwb")
- remote_pdb = _PdbServer(sockfile)
+ # The client requests this thread on Windows but not on Unix.
+ # Most tests don't request this thread, to keep them simpler.
+ if signal_raising_thread:
+ signal_server = (host, port)
+ else:
+ signal_server = None
+
+ remote_pdb = _PdbServer(
+ sockfile,
+ signal_server=signal_server,
+ colorize=colorize,
+ )
weakref.finalize(remote_pdb, sockfile.close)
if Pdb._last_pdb_instance is not None:
@@ -3171,43 +3389,50 @@ def _connect(host, port, frame, commands, version):
def attach(pid, commands=()):
"""Attach to a running process with the given PID."""
- with closing(socket.create_server(("localhost", 0))) as server:
+ with ExitStack() as stack:
+ server = stack.enter_context(
+ closing(socket.create_server(("localhost", 0)))
+ )
port = server.getsockname()[1]
- with tempfile.NamedTemporaryFile("w", delete_on_close=False) as connect_script:
- connect_script.write(
- textwrap.dedent(
- f"""
- import pdb, sys
- pdb._connect(
- host="localhost",
- port={port},
- frame=sys._getframe(1),
- commands={json.dumps("\n".join(commands))},
- version={_PdbServer.protocol_version()},
- )
- """
+ connect_script = stack.enter_context(
+ tempfile.NamedTemporaryFile("w", delete_on_close=False)
+ )
+
+ use_signal_thread = sys.platform == "win32"
+ colorize = _colorize.can_colorize()
+
+ connect_script.write(
+ textwrap.dedent(
+ f"""
+ import pdb, sys
+ pdb._connect(
+ host="localhost",
+ port={port},
+ frame=sys._getframe(1),
+ commands={json.dumps("\n".join(commands))},
+ version={_PdbServer.protocol_version()},
+ signal_raising_thread={use_signal_thread!r},
+ colorize={colorize!r},
)
+ """
)
- connect_script.close()
- sys.remote_exec(pid, connect_script.name)
-
- # TODO Add a timeout? Or don't bother since the user can ^C?
- client_sock, _ = server.accept()
+ )
+ connect_script.close()
+ sys.remote_exec(pid, connect_script.name)
- with closing(client_sock):
- sockfile = client_sock.makefile("rwb")
+ # TODO Add a timeout? Or don't bother since the user can ^C?
+ client_sock, _ = server.accept()
+ stack.enter_context(closing(client_sock))
- with closing(sockfile):
- with tempfile.NamedTemporaryFile("w", delete_on_close=False) as interrupt_script:
- interrupt_script.write(
- 'import pdb, sys\n'
- 'if inst := pdb.Pdb._last_pdb_instance:\n'
- ' inst.set_trace(sys._getframe(1))\n'
- )
- interrupt_script.close()
+ if use_signal_thread:
+ interrupt_sock, _ = server.accept()
+ stack.enter_context(closing(interrupt_sock))
+ interrupt_sock.setblocking(False)
+ else:
+ interrupt_sock = None
- _PdbClient(pid, sockfile, interrupt_script.name).cmdloop()
+ _PdbClient(pid, client_sock, interrupt_sock).cmdloop()
# Post-Mortem interface
@@ -3279,10 +3504,13 @@ To let the script run up to a given line X in the debugged file, use
def main():
import argparse
- parser = argparse.ArgumentParser(usage="%(prog)s [-h] [-c command] (-m module | -p pid | pyfile) [args ...]",
- description=_usage,
- formatter_class=argparse.RawDescriptionHelpFormatter,
- allow_abbrev=False)
+ parser = argparse.ArgumentParser(
+ usage="%(prog)s [-h] [-c command] (-m module | -p pid | pyfile) [args ...]",
+ description=_usage,
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ allow_abbrev=False,
+ color=True,
+ )
# We need to manually get the script from args, because the first positional
# arguments could be either the script we need to debug, or the argument
@@ -3345,7 +3573,7 @@ def main():
# modified by the script being debugged. It's a bad idea when it was
# changed by the user from the command line. There is a "restart" command
# which allows explicit specification of command line arguments.
- pdb = Pdb(mode='cli', backend='monitoring')
+ pdb = Pdb(mode='cli', backend='monitoring', colorize=True)
pdb.rcLines.extend(opts.commands)
while True:
try:
diff --git a/Lib/pickle.py b/Lib/pickle.py
index 4fa3632d1a7..beaefae0479 100644
--- a/Lib/pickle.py
+++ b/Lib/pickle.py
@@ -1911,7 +1911,9 @@ def _main(args=None):
import argparse
import pprint
parser = argparse.ArgumentParser(
- description='display contents of the pickle files')
+ description='display contents of the pickle files',
+ color=True,
+ )
parser.add_argument(
'pickle_file',
nargs='+', help='the pickle file')
diff --git a/Lib/pickletools.py b/Lib/pickletools.py
index 53f25ea4e46..bcddfb722bd 100644
--- a/Lib/pickletools.py
+++ b/Lib/pickletools.py
@@ -2842,7 +2842,9 @@ __test__ = {'disassembler_test': _dis_test,
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
- description='disassemble one or more pickle files')
+ description='disassemble one or more pickle files',
+ color=True,
+ )
parser.add_argument(
'pickle_file',
nargs='+', help='the pickle file')
diff --git a/Lib/platform.py b/Lib/platform.py
index 507552f360b..55e211212d4 100644
--- a/Lib/platform.py
+++ b/Lib/platform.py
@@ -1467,7 +1467,7 @@ def invalidate_caches():
def _parse_args(args: list[str] | None):
import argparse
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument("args", nargs="*", choices=["nonaliased", "terse"])
parser.add_argument(
"--terse",
diff --git a/Lib/py_compile.py b/Lib/py_compile.py
index 388614e51b1..43d8ec90ffb 100644
--- a/Lib/py_compile.py
+++ b/Lib/py_compile.py
@@ -177,7 +177,7 @@ def main():
import argparse
description = 'A simple command-line interface for py_compile module.'
- parser = argparse.ArgumentParser(description=description)
+ parser = argparse.ArgumentParser(description=description, color=True)
parser.add_argument(
'-q', '--quiet',
action='store_true',
diff --git a/Lib/random.py b/Lib/random.py
index 5e5d0c4c694..86d562f0b8a 100644
--- a/Lib/random.py
+++ b/Lib/random.py
@@ -1011,7 +1011,7 @@ if hasattr(_os, "fork"):
def _parse_args(arg_list: list[str] | None):
import argparse
parser = argparse.ArgumentParser(
- formatter_class=argparse.RawTextHelpFormatter)
+ formatter_class=argparse.RawTextHelpFormatter, color=True)
group = parser.add_mutually_exclusive_group()
group.add_argument(
"-c", "--choice", nargs="+",
diff --git a/Lib/reprlib.py b/Lib/reprlib.py
index 19dbe3a07eb..441d1be4bde 100644
--- a/Lib/reprlib.py
+++ b/Lib/reprlib.py
@@ -28,7 +28,7 @@ def recursive_repr(fillvalue='...'):
wrapper.__doc__ = getattr(user_function, '__doc__')
wrapper.__name__ = getattr(user_function, '__name__')
wrapper.__qualname__ = getattr(user_function, '__qualname__')
- wrapper.__annotations__ = getattr(user_function, '__annotations__', {})
+ wrapper.__annotate__ = getattr(user_function, '__annotate__', None)
wrapper.__type_params__ = getattr(user_function, '__type_params__', ())
wrapper.__wrapped__ = user_function
return wrapper
diff --git a/Lib/shutil.py b/Lib/shutil.py
index 510ae8c6f22..ca0a2ea2f7f 100644
--- a/Lib/shutil.py
+++ b/Lib/shutil.py
@@ -32,6 +32,13 @@ try:
except ImportError:
_LZMA_SUPPORTED = False
+try:
+ from compression import zstd
+ del zstd
+ _ZSTD_SUPPORTED = True
+except ImportError:
+ _ZSTD_SUPPORTED = False
+
_WINDOWS = os.name == 'nt'
posix = nt = None
if os.name == 'posix':
@@ -1006,6 +1013,8 @@ def _make_tarball(base_name, base_dir, compress="gzip", verbose=0, dry_run=0,
tar_compression = 'bz2'
elif _LZMA_SUPPORTED and compress == 'xz':
tar_compression = 'xz'
+ elif _ZSTD_SUPPORTED and compress == 'zst':
+ tar_compression = 'zst'
else:
raise ValueError("bad value for 'compress', or compression format not "
"supported : {0}".format(compress))
@@ -1134,6 +1143,10 @@ if _LZMA_SUPPORTED:
_ARCHIVE_FORMATS['xztar'] = (_make_tarball, [('compress', 'xz')],
"xz'ed tar-file")
+if _ZSTD_SUPPORTED:
+ _ARCHIVE_FORMATS['zstdtar'] = (_make_tarball, [('compress', 'zst')],
+ "zstd'ed tar-file")
+
def get_archive_formats():
"""Returns a list of supported formats for archiving and unarchiving.
@@ -1174,7 +1187,7 @@ def make_archive(base_name, format, root_dir=None, base_dir=None, verbose=0,
'base_name' is the name of the file to create, minus any format-specific
extension; 'format' is the archive format: one of "zip", "tar", "gztar",
- "bztar", or "xztar". Or any other registered format.
+ "bztar", "zstdtar", or "xztar". Or any other registered format.
'root_dir' is a directory that will be the root directory of the
archive; ie. we typically chdir into 'root_dir' before creating the
@@ -1359,6 +1372,10 @@ if _LZMA_SUPPORTED:
_UNPACK_FORMATS['xztar'] = (['.tar.xz', '.txz'], _unpack_tarfile, [],
"xz'ed tar-file")
+if _ZSTD_SUPPORTED:
+ _UNPACK_FORMATS['zstdtar'] = (['.tar.zst', '.tzst'], _unpack_tarfile, [],
+ "zstd'ed tar-file")
+
def _find_unpack_format(filename):
for name, info in _UNPACK_FORMATS.items():
for extension in info[0]:
diff --git a/Lib/sqlite3/__main__.py b/Lib/sqlite3/__main__.py
index 79a6209468d..002f1986cdd 100644
--- a/Lib/sqlite3/__main__.py
+++ b/Lib/sqlite3/__main__.py
@@ -65,6 +65,7 @@ class SqliteInteractiveConsole(InteractiveConsole):
def main(*args):
parser = ArgumentParser(
description="Python sqlite3 CLI",
+ color=True,
)
parser.add_argument(
"filename", type=str, default=":memory:", nargs="?",
diff --git a/Lib/subprocess.py b/Lib/subprocess.py
index da5f5729e09..54c2eb515b6 100644
--- a/Lib/subprocess.py
+++ b/Lib/subprocess.py
@@ -1235,8 +1235,11 @@ class Popen:
finally:
self._communication_started = True
-
- sts = self.wait(timeout=self._remaining_time(endtime))
+ try:
+ sts = self.wait(timeout=self._remaining_time(endtime))
+ except TimeoutExpired as exc:
+ exc.timeout = timeout
+ raise
return (stdout, stderr)
@@ -2145,8 +2148,11 @@ class Popen:
selector.unregister(key.fileobj)
key.fileobj.close()
self._fileobj2output[key.fileobj].append(data)
-
- self.wait(timeout=self._remaining_time(endtime))
+ try:
+ self.wait(timeout=self._remaining_time(endtime))
+ except TimeoutExpired as exc:
+ exc.timeout = orig_timeout
+ raise
# All data exchanged. Translate lists into strings.
if stdout is not None:
diff --git a/Lib/tarfile.py b/Lib/tarfile.py
index 82c5f6704cb..c0f5a609b9f 100644
--- a/Lib/tarfile.py
+++ b/Lib/tarfile.py
@@ -399,7 +399,17 @@ class _Stream:
self.exception = lzma.LZMAError
else:
self.cmp = lzma.LZMACompressor(preset=preset)
-
+ elif comptype == "zst":
+ try:
+ from compression import zstd
+ except ImportError:
+ raise CompressionError("compression.zstd module is not available") from None
+ if mode == "r":
+ self.dbuf = b""
+ self.cmp = zstd.ZstdDecompressor()
+ self.exception = zstd.ZstdError
+ else:
+ self.cmp = zstd.ZstdCompressor()
elif comptype != "tar":
raise CompressionError("unknown compression type %r" % comptype)
@@ -591,6 +601,8 @@ class _StreamProxy(object):
return "bz2"
elif self.buf.startswith((b"\x5d\x00\x00\x80", b"\xfd7zXZ")):
return "xz"
+ elif self.buf.startswith(b"\x28\xb5\x2f\xfd"):
+ return "zst"
else:
return "tar"
@@ -1817,11 +1829,13 @@ class TarFile(object):
'r:gz' open for reading with gzip compression
'r:bz2' open for reading with bzip2 compression
'r:xz' open for reading with lzma compression
+ 'r:zst' open for reading with zstd compression
'a' or 'a:' open for appending, creating the file if necessary
'w' or 'w:' open for writing without compression
'w:gz' open for writing with gzip compression
'w:bz2' open for writing with bzip2 compression
'w:xz' open for writing with lzma compression
+ 'w:zst' open for writing with zstd compression
'x' or 'x:' create a tarfile exclusively without compression, raise
an exception if the file is already created
@@ -1831,16 +1845,20 @@ class TarFile(object):
if the file is already created
'x:xz' create an lzma compressed tarfile, raise an exception
if the file is already created
+ 'x:zst' create a zstd compressed tarfile, raise an exception
+ if the file is already created
'r|*' open a stream of tar blocks with transparent compression
'r|' open an uncompressed stream of tar blocks for reading
'r|gz' open a gzip compressed stream of tar blocks
'r|bz2' open a bzip2 compressed stream of tar blocks
'r|xz' open an lzma compressed stream of tar blocks
+ 'r|zst' open a zstd compressed stream of tar blocks
'w|' open an uncompressed stream for writing
'w|gz' open a gzip compressed stream for writing
'w|bz2' open a bzip2 compressed stream for writing
'w|xz' open an lzma compressed stream for writing
+ 'w|zst' open a zstd compressed stream for writing
"""
if not name and not fileobj:
@@ -2006,12 +2024,48 @@ class TarFile(object):
t._extfileobj = False
return t
+ @classmethod
+ def zstopen(cls, name, mode="r", fileobj=None, level=None, options=None,
+ zstd_dict=None, **kwargs):
+ """Open zstd compressed tar archive name for reading or writing.
+ Appending is not allowed.
+ """
+ if mode not in ("r", "w", "x"):
+ raise ValueError("mode must be 'r', 'w' or 'x'")
+
+ try:
+ from compression.zstd import ZstdFile, ZstdError
+ except ImportError:
+ raise CompressionError("compression.zstd module is not available") from None
+
+ fileobj = ZstdFile(
+ fileobj or name,
+ mode,
+ level=level,
+ options=options,
+ zstd_dict=zstd_dict
+ )
+
+ try:
+ t = cls.taropen(name, mode, fileobj, **kwargs)
+ except (ZstdError, EOFError) as e:
+ fileobj.close()
+ if mode == 'r':
+ raise ReadError("not a zstd file") from e
+ raise
+ except Exception:
+ fileobj.close()
+ raise
+ t._extfileobj = False
+ return t
+
# All *open() methods are registered here.
OPEN_METH = {
"tar": "taropen", # uncompressed tar
"gz": "gzopen", # gzip compressed tar
"bz2": "bz2open", # bzip2 compressed tar
- "xz": "xzopen" # lzma compressed tar
+ "xz": "xzopen", # lzma compressed tar
+ "zst": "zstopen" # zstd compressed tar
}
#--------------------------------------------------------------------------
@@ -2883,7 +2937,7 @@ def main():
import argparse
description = 'A simple command-line interface for tarfile module.'
- parser = argparse.ArgumentParser(description=description)
+ parser = argparse.ArgumentParser(description=description, color=True)
parser.add_argument('-v', '--verbose', action='store_true', default=False,
help='Verbose output')
parser.add_argument('--filter', metavar='<filtername>',
@@ -2963,6 +3017,9 @@ def main():
'.tbz': 'bz2',
'.tbz2': 'bz2',
'.tb2': 'bz2',
+ # zstd
+ '.zst': 'zst',
+ '.tzst': 'zst',
}
tar_mode = 'w:' + compressions[ext] if ext in compressions else 'w'
tar_files = args.create
diff --git a/Lib/test/_code_definitions.py b/Lib/test/_code_definitions.py
index 06cf6a10231..c3daa0dccf5 100644
--- a/Lib/test/_code_definitions.py
+++ b/Lib/test/_code_definitions.py
@@ -12,6 +12,50 @@ def spam_minimal():
return
+def spam_with_builtins():
+ x = 42
+ values = (42,)
+ checks = tuple(callable(v) for v in values)
+ res = callable(values), tuple(values), list(values), checks
+ print(res)
+
+
+def spam_with_globals_and_builtins():
+ func1 = spam
+ func2 = spam_minimal
+ funcs = (func1, func2)
+ checks = tuple(callable(f) for f in funcs)
+ res = callable(funcs), tuple(funcs), list(funcs), checks
+ print(res)
+
+
+def spam_args_attrs_and_builtins(a, b, /, c, d, *args, e, f, **kwargs):
+ if args.__len__() > 2:
+ return None
+ return a, b, c, d, e, f, args, kwargs
+
+
+def spam_returns_arg(x):
+ return x
+
+
+def spam_with_inner_not_closure():
+ def eggs():
+ pass
+ eggs()
+
+
+def spam_with_inner_closure():
+ x = 42
+ def eggs():
+ print(x)
+ eggs()
+
+
+def spam_annotated(a: int, b: str, c: object) -> tuple:
+ return a, b, c
+
+
def spam_full(a, b, /, c, d:int=1, *args, e, f:object=None, **kwargs) -> tuple:
# arg defaults, kwarg defaults
# annotations
@@ -98,6 +142,13 @@ ham_C_closure, *_ = eggs_closure_C(2)
TOP_FUNCTIONS = [
# shallow
spam_minimal,
+ spam_with_builtins,
+ spam_with_globals_and_builtins,
+ spam_args_attrs_and_builtins,
+ spam_returns_arg,
+ spam_with_inner_not_closure,
+ spam_with_inner_closure,
+ spam_annotated,
spam_full,
spam,
# outer func
diff --git a/Lib/test/libregrtest/setup.py b/Lib/test/libregrtest/setup.py
index c0346aa934d..c3d1f60a400 100644
--- a/Lib/test/libregrtest/setup.py
+++ b/Lib/test/libregrtest/setup.py
@@ -40,7 +40,7 @@ def setup_process() -> None:
faulthandler.enable(all_threads=True, file=stderr_fd)
# Display the Python traceback on SIGALRM or SIGUSR1 signal
- signals = []
+ signals: list[signal.Signals] = []
if hasattr(signal, 'SIGALRM'):
signals.append(signal.SIGALRM)
if hasattr(signal, 'SIGUSR1'):
diff --git a/Lib/test/libregrtest/utils.py b/Lib/test/libregrtest/utils.py
index c4a1506c9a7..63a2e427d18 100644
--- a/Lib/test/libregrtest/utils.py
+++ b/Lib/test/libregrtest/utils.py
@@ -335,43 +335,11 @@ def get_build_info():
build.append('with_assert')
# --enable-experimental-jit
- tier2 = re.search('-D_Py_TIER2=([0-9]+)', cflags)
- if tier2:
- tier2 = int(tier2.group(1))
-
- if not sys.flags.ignore_environment:
- PYTHON_JIT = os.environ.get('PYTHON_JIT', None)
- if PYTHON_JIT:
- PYTHON_JIT = (PYTHON_JIT != '0')
- else:
- PYTHON_JIT = None
-
- if tier2 == 1: # =yes
- if PYTHON_JIT == False:
- jit = 'JIT=off'
- else:
- jit = 'JIT'
- elif tier2 == 3: # =yes-off
- if PYTHON_JIT:
- jit = 'JIT'
+ if sys._jit.is_available():
+ if sys._jit.is_enabled():
+ build.append("JIT")
else:
- jit = 'JIT=off'
- elif tier2 == 4: # =interpreter
- if PYTHON_JIT == False:
- jit = 'JIT-interpreter=off'
- else:
- jit = 'JIT-interpreter'
- elif tier2 == 6: # =interpreter-off (Secret option!)
- if PYTHON_JIT:
- jit = 'JIT-interpreter'
- else:
- jit = 'JIT-interpreter=off'
- elif '-D_Py_JIT' in cflags:
- jit = 'JIT'
- else:
- jit = None
- if jit:
- build.append(jit)
+ build.append("JIT (disabled)")
# --enable-framework=name
framework = sysconfig.get_config_var('PYTHONFRAMEWORK')
diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py
index 24984ad81ff..c74c3a31909 100644
--- a/Lib/test/support/__init__.py
+++ b/Lib/test/support/__init__.py
@@ -33,7 +33,7 @@ __all__ = [
"is_resource_enabled", "requires", "requires_freebsd_version",
"requires_gil_enabled", "requires_linux_version", "requires_mac_ver",
"check_syntax_error",
- "requires_gzip", "requires_bz2", "requires_lzma",
+ "requires_gzip", "requires_bz2", "requires_lzma", "requires_zstd",
"bigmemtest", "bigaddrspacetest", "cpython_only", "get_attribute",
"requires_IEEE_754", "requires_zlib",
"has_fork_support", "requires_fork",
@@ -527,6 +527,13 @@ def requires_lzma(reason='requires lzma'):
lzma = None
return unittest.skipUnless(lzma, reason)
+def requires_zstd(reason='requires zstd'):
+ try:
+ from compression import zstd
+ except ImportError:
+ zstd = None
+ return unittest.skipUnless(zstd, reason)
+
def has_no_debug_ranges():
try:
import _testcapi
@@ -2648,13 +2655,9 @@ skip_on_s390x = unittest.skipIf(is_s390x, 'skipped on s390x')
Py_TRACE_REFS = hasattr(sys, 'getobjects')
-try:
- from _testinternalcapi import jit_enabled
-except ImportError:
- requires_jit_enabled = requires_jit_disabled = unittest.skip("requires _testinternalcapi")
-else:
- requires_jit_enabled = unittest.skipUnless(jit_enabled(), "requires JIT enabled")
- requires_jit_disabled = unittest.skipIf(jit_enabled(), "requires JIT disabled")
+_JIT_ENABLED = sys._jit.is_enabled()
+requires_jit_enabled = unittest.skipUnless(_JIT_ENABLED, "requires JIT enabled")
+requires_jit_disabled = unittest.skipIf(_JIT_ENABLED, "requires JIT disabled")
_BASE_COPY_SRC_DIR_IGNORED_NAMES = frozenset({
@@ -2855,36 +2858,59 @@ def iter_slot_wrappers(cls):
@contextlib.contextmanager
-def no_color():
+def force_color(color: bool):
import _colorize
from .os_helper import EnvironmentVarGuard
with (
- swap_attr(_colorize, "can_colorize", lambda file=None: False),
+ swap_attr(_colorize, "can_colorize", lambda file=None: color),
EnvironmentVarGuard() as env,
):
env.unset("FORCE_COLOR", "NO_COLOR", "PYTHON_COLORS")
- env.set("NO_COLOR", "1")
+ env.set("FORCE_COLOR" if color else "NO_COLOR", "1")
yield
+def force_colorized(func):
+ """Force the terminal to be colorized."""
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ with force_color(True):
+ return func(*args, **kwargs)
+ return wrapper
+
+
def force_not_colorized(func):
- """Force the terminal not to be colorized."""
+ """Force the terminal NOT to be colorized."""
@functools.wraps(func)
def wrapper(*args, **kwargs):
- with no_color():
+ with force_color(False):
return func(*args, **kwargs)
return wrapper
+def force_colorized_test_class(cls):
+ """Force the terminal to be colorized for the entire test class."""
+ original_setUpClass = cls.setUpClass
+
+ @classmethod
+ @functools.wraps(cls.setUpClass)
+ def new_setUpClass(cls):
+ cls.enterClassContext(force_color(True))
+ original_setUpClass()
+
+ cls.setUpClass = new_setUpClass
+ return cls
+
+
def force_not_colorized_test_class(cls):
- """Force the terminal not to be colorized for the entire test class."""
+ """Force the terminal NOT to be colorized for the entire test class."""
original_setUpClass = cls.setUpClass
@classmethod
@functools.wraps(cls.setUpClass)
def new_setUpClass(cls):
- cls.enterClassContext(no_color())
+ cls.enterClassContext(force_color(False))
original_setUpClass()
cls.setUpClass = new_setUpClass
diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py
index 13c6a2a584b..c3c245ddaf8 100644
--- a/Lib/test/test_annotationlib.py
+++ b/Lib/test/test_annotationlib.py
@@ -1053,6 +1053,21 @@ class TestGetAnnotations(unittest.TestCase):
},
)
+ def test_partial_evaluation_error(self):
+ def f(x: range[1]):
+ pass
+ with self.assertRaisesRegex(
+ TypeError, "type 'range' is not subscriptable"
+ ):
+ f.__annotations__
+
+ self.assertEqual(
+ get_annotations(f, format=Format.FORWARDREF),
+ {
+ "x": support.EqualToForwardRef("range[1]", owner=f),
+ },
+ )
+
def test_partial_evaluation_cell(self):
obj = object()
diff --git a/Lib/test/test_argparse.py b/Lib/test/test_argparse.py
index c5a1f31aa52..5a6be1180c1 100644
--- a/Lib/test/test_argparse.py
+++ b/Lib/test/test_argparse.py
@@ -7058,7 +7058,7 @@ class TestColorized(TestCase):
super().setUp()
# Ensure color even if ran with NO_COLOR=1
_colorize.can_colorize = lambda *args, **kwargs: True
- self.ansi = _colorize.ANSIColors()
+ self.theme = _colorize.get_theme(force_color=True).argparse
def test_argparse_color(self):
# Arrange: create a parser with a bit of everything
@@ -7120,13 +7120,17 @@ class TestColorized(TestCase):
sub2 = subparsers.add_parser("sub2", deprecated=True, help="sub2 help")
sub2.add_argument("--baz", choices=("X", "Y", "Z"), help="baz help")
- heading = self.ansi.BOLD_BLUE
- label, label_b = self.ansi.YELLOW, self.ansi.BOLD_YELLOW
- long, long_b = self.ansi.CYAN, self.ansi.BOLD_CYAN
- pos, pos_b = short, short_b = self.ansi.GREEN, self.ansi.BOLD_GREEN
- sub = self.ansi.BOLD_GREEN
- prog = self.ansi.BOLD_MAGENTA
- reset = self.ansi.RESET
+ prog = self.theme.prog
+ heading = self.theme.heading
+ long = self.theme.summary_long_option
+ short = self.theme.summary_short_option
+ label = self.theme.summary_label
+ pos = self.theme.summary_action
+ long_b = self.theme.long_option
+ short_b = self.theme.short_option
+ label_b = self.theme.label
+ pos_b = self.theme.action
+ reset = self.theme.reset
# Act
help_text = parser.format_help()
@@ -7171,9 +7175,9 @@ class TestColorized(TestCase):
{heading}subcommands:{reset}
valid subcommands
- {sub}{{sub1,sub2}}{reset} additional help
- {sub}sub1{reset} sub1 help
- {sub}sub2{reset} sub2 help
+ {pos_b}{{sub1,sub2}}{reset} additional help
+ {pos_b}sub1{reset} sub1 help
+ {pos_b}sub2{reset} sub2 help
"""
),
)
@@ -7187,10 +7191,10 @@ class TestColorized(TestCase):
prog="PROG",
usage="[prefix] %(prog)s [suffix]",
)
- heading = self.ansi.BOLD_BLUE
- prog = self.ansi.BOLD_MAGENTA
- reset = self.ansi.RESET
- usage = self.ansi.MAGENTA
+ heading = self.theme.heading
+ prog = self.theme.prog
+ reset = self.theme.reset
+ usage = self.theme.prog_extra
# Act
help_text = parser.format_help()
diff --git a/Lib/test/test_asdl_parser.py b/Lib/test/test_asdl_parser.py
index 2c198a6b8b2..b9df6568123 100644
--- a/Lib/test/test_asdl_parser.py
+++ b/Lib/test/test_asdl_parser.py
@@ -62,17 +62,17 @@ class TestAsdlParser(unittest.TestCase):
alias = self.types['alias']
self.assertEqual(
str(alias),
- 'Product([Field(identifier, name), Field(identifier, asname, opt=True)], '
+ 'Product([Field(identifier, name), Field(identifier, asname, quantifiers=[OPTIONAL])], '
'[Field(int, lineno), Field(int, col_offset), '
- 'Field(int, end_lineno, opt=True), Field(int, end_col_offset, opt=True)])')
+ 'Field(int, end_lineno, quantifiers=[OPTIONAL]), Field(int, end_col_offset, quantifiers=[OPTIONAL])])')
def test_attributes(self):
stmt = self.types['stmt']
self.assertEqual(len(stmt.attributes), 4)
self.assertEqual(repr(stmt.attributes[0]), 'Field(int, lineno)')
self.assertEqual(repr(stmt.attributes[1]), 'Field(int, col_offset)')
- self.assertEqual(repr(stmt.attributes[2]), 'Field(int, end_lineno, opt=True)')
- self.assertEqual(repr(stmt.attributes[3]), 'Field(int, end_col_offset, opt=True)')
+ self.assertEqual(repr(stmt.attributes[2]), 'Field(int, end_lineno, quantifiers=[OPTIONAL])')
+ self.assertEqual(repr(stmt.attributes[3]), 'Field(int, end_col_offset, quantifiers=[OPTIONAL])')
def test_constructor_fields(self):
ehandler = self.types['excepthandler']
diff --git a/Lib/test/test_ast/test_ast.py b/Lib/test/test_ast/test_ast.py
index 6a9b7812ef6..09cf3186e05 100644
--- a/Lib/test/test_ast/test_ast.py
+++ b/Lib/test/test_ast/test_ast.py
@@ -26,6 +26,7 @@ from test import support
from test.support import os_helper
from test.support import skip_emscripten_stack_overflow, skip_wasi_stack_overflow
from test.support.ast_helper import ASTTestMixin
+from test.support.import_helper import ensure_lazy_imports
from test.test_ast.utils import to_tuple
from test.test_ast.snippets import (
eval_tests, eval_results, exec_tests, exec_results, single_tests, single_results
@@ -47,6 +48,12 @@ def ast_repr_update_snapshots() -> None:
AST_REPR_DATA_FILE.write_text("\n".join(data))
+class LazyImportTest(unittest.TestCase):
+ @support.cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("ast", {"contextlib", "enum", "inspect", "re", "collections", "argparse"})
+
+
class AST_Tests(unittest.TestCase):
maxDiff = None
@@ -3272,6 +3279,9 @@ class CommandLineTests(unittest.TestCase):
('--no-type-comments', '--no-type-comments'),
('-a', '--include-attributes'),
('-i=4', '--indent=4'),
+ ('--feature-version=3.13', '--feature-version=3.13'),
+ ('-O=-1', '--optimize=-1'),
+ ('--show-empty', '--show-empty'),
)
self.set_source('''
print(1, 2, 3)
@@ -3286,6 +3296,7 @@ class CommandLineTests(unittest.TestCase):
with self.subTest(flags=args):
self.invoke_ast(*args)
+ @support.force_not_colorized
def test_help_message(self):
for flag in ('-h', '--help', '--unknown'):
with self.subTest(flag=flag):
@@ -3389,7 +3400,7 @@ class CommandLineTests(unittest.TestCase):
self.check_output(source, expect, flag)
def test_indent_flag(self):
- # test 'python -m ast -i/--indent'
+ # test 'python -m ast -i/--indent 0'
source = 'pass'
expect = '''
Module(
@@ -3400,6 +3411,96 @@ class CommandLineTests(unittest.TestCase):
with self.subTest(flag=flag):
self.check_output(source, expect, flag)
+ def test_feature_version_flag(self):
+ # test 'python -m ast --feature-version 3.9/3.10'
+ source = '''
+ match x:
+ case 1:
+ pass
+ '''
+ expect = '''
+ Module(
+ body=[
+ Match(
+ subject=Name(id='x', ctx=Load()),
+ cases=[
+ match_case(
+ pattern=MatchValue(
+ value=Constant(value=1)),
+ body=[
+ Pass()])])])
+ '''
+ self.check_output(source, expect, '--feature-version=3.10')
+ with self.assertRaises(SyntaxError):
+ self.invoke_ast('--feature-version=3.9')
+
+ def test_no_optimize_flag(self):
+ # test 'python -m ast -O/--optimize -1/0'
+ source = '''
+ match a:
+ case 1+2j:
+ pass
+ '''
+ expect = '''
+ Module(
+ body=[
+ Match(
+ subject=Name(id='a', ctx=Load()),
+ cases=[
+ match_case(
+ pattern=MatchValue(
+ value=BinOp(
+ left=Constant(value=1),
+ op=Add(),
+ right=Constant(value=2j))),
+ body=[
+ Pass()])])])
+ '''
+ for flag in ('-O=-1', '--optimize=-1', '-O=0', '--optimize=0'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_optimize_flag(self):
+ # test 'python -m ast -O/--optimize 1/2'
+ source = '''
+ match a:
+ case 1+2j:
+ pass
+ '''
+ expect = '''
+ Module(
+ body=[
+ Match(
+ subject=Name(id='a', ctx=Load()),
+ cases=[
+ match_case(
+ pattern=MatchValue(
+ value=Constant(value=(1+2j))),
+ body=[
+ Pass()])])])
+ '''
+ for flag in ('-O=1', '--optimize=1', '-O=2', '--optimize=2'):
+ with self.subTest(flag=flag):
+ self.check_output(source, expect, flag)
+
+ def test_show_empty_flag(self):
+ # test 'python -m ast --show-empty'
+ source = 'print(1, 2, 3)'
+ expect = '''
+ Module(
+ body=[
+ Expr(
+ value=Call(
+ func=Name(id='print', ctx=Load()),
+ args=[
+ Constant(value=1),
+ Constant(value=2),
+ Constant(value=3)],
+ keywords=[]))],
+ type_ignores=[])
+ '''
+ self.check_output(source, expect, '--show-empty')
+
class ASTOptimiziationTests(unittest.TestCase):
def wrap_expr(self, expr):
diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py
index a2fb1022ae4..9f3b6f9acef 100644
--- a/Lib/test/test_asyncio/test_eager_task_factory.py
+++ b/Lib/test/test_asyncio/test_eager_task_factory.py
@@ -263,6 +263,24 @@ class EagerTaskFactoryLoopTests:
self.run_coro(run())
+ def test_eager_start_false(self):
+ name = None
+
+ async def asyncfn():
+ nonlocal name
+ name = asyncio.current_task().get_name()
+
+ async def main():
+ t = asyncio.get_running_loop().create_task(
+ asyncfn(), eager_start=False, name="example"
+ )
+ self.assertFalse(t.done())
+ self.assertIsNone(name)
+ await t
+ self.assertEqual(name, "example")
+
+ self.run_coro(main())
+
class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
Task = tasks._PyTask
@@ -505,5 +523,24 @@ class EagerCTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
asyncio.current_task = asyncio.tasks.current_task = self._current_task
return super().tearDown()
+
+class DefaultTaskFactoryEagerStart(test_utils.TestCase):
+ def test_eager_start_true_with_default_factory(self):
+ name = None
+
+ async def asyncfn():
+ nonlocal name
+ name = asyncio.current_task().get_name()
+
+ async def main():
+ t = asyncio.get_running_loop().create_task(
+ asyncfn(), eager_start=True, name="example"
+ )
+ self.assertTrue(t.done())
+ self.assertEqual(name, "example")
+ await t
+
+ asyncio.run(main(), loop_factory=asyncio.EventLoop)
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py
index 8d7f1733454..44498ef790e 100644
--- a/Lib/test/test_asyncio/test_tasks.py
+++ b/Lib/test/test_asyncio/test_tasks.py
@@ -89,8 +89,8 @@ class BaseTaskTests:
Future = None
all_tasks = None
- def new_task(self, loop, coro, name='TestTask', context=None):
- return self.__class__.Task(coro, loop=loop, name=name, context=context)
+ def new_task(self, loop, coro, name='TestTask', context=None, eager_start=None):
+ return self.__class__.Task(coro, loop=loop, name=name, context=context, eager_start=eager_start)
def new_future(self, loop):
return self.__class__.Future(loop=loop)
@@ -2686,6 +2686,37 @@ class BaseTaskTests:
 
         self.assertEqual([None, 1, 2], ret)
 
+    def test_eager_start_true(self):
+        name = None
+
+        async def asyncfn():
+            nonlocal name
+            name = self.current_task().get_name()
+
+        async def main():
+            t = self.new_task(coro=asyncfn(), loop=asyncio.get_running_loop(), eager_start=True, name="example")
+            self.assertTrue(t.done())
+            self.assertEqual(name, "example")
+            await t
+
+        asyncio.run(main(), loop_factory=asyncio.EventLoop)
+
+    def test_eager_start_false(self):
+        name = None
+
+        async def asyncfn():
+            nonlocal name
+            name = self.current_task().get_name()
+
+        async def main():
+            t = self.new_task(coro=asyncfn(), loop=asyncio.get_running_loop(), eager_start=False, name="example")
+            self.assertFalse(t.done())
+            self.assertIsNone(name)
+            await t
+            self.assertEqual(name, "example")
+
+        asyncio.run(main(), loop_factory=asyncio.EventLoop)
+
def test_get_coro(self):
loop = asyncio.new_event_loop()
coro = coroutine_function()
diff --git a/Lib/test/test_asyncio/test_tools.py b/Lib/test/test_asyncio/test_tools.py
index 2caf56172c9..0413e236c27 100644
--- a/Lib/test/test_asyncio/test_tools.py
+++ b/Lib/test/test_asyncio/test_tools.py
@@ -18,10 +18,18 @@ TEST_INPUTS_TREE = [
3,
"timer",
[
- [["awaiter3", "awaiter2", "awaiter"], 4],
- [["awaiter1_3", "awaiter1_2", "awaiter1"], 5],
- [["awaiter1_3", "awaiter1_2", "awaiter1"], 6],
- [["awaiter3", "awaiter2", "awaiter"], 7],
+ [[("awaiter3", "/path/to/app.py", 130),
+ ("awaiter2", "/path/to/app.py", 120),
+ ("awaiter", "/path/to/app.py", 110)], 4],
+ [[("awaiterB3", "/path/to/app.py", 190),
+ ("awaiterB2", "/path/to/app.py", 180),
+ ("awaiterB", "/path/to/app.py", 170)], 5],
+ [[("awaiterB3", "/path/to/app.py", 190),
+ ("awaiterB2", "/path/to/app.py", 180),
+ ("awaiterB", "/path/to/app.py", 170)], 6],
+ [[("awaiter3", "/path/to/app.py", 130),
+ ("awaiter2", "/path/to/app.py", 120),
+ ("awaiter", "/path/to/app.py", 110)], 7],
],
),
(
@@ -91,14 +99,14 @@ TEST_INPUTS_TREE = [
" │ └── __aexit__",
" │ └── _aexit",
" │ ├── (T) child1_1",
- " │ │ └── awaiter",
- " │ │ └── awaiter2",
- " │ │ └── awaiter3",
+ " │ │ └── awaiter /path/to/app.py:110",
+ " │ │ └── awaiter2 /path/to/app.py:120",
+ " │ │ └── awaiter3 /path/to/app.py:130",
" │ │ └── (T) timer",
" │ └── (T) child2_1",
- " │ └── awaiter1",
- " │ └── awaiter1_2",
- " │ └── awaiter1_3",
+ " │ └── awaiterB /path/to/app.py:170",
+ " │ └── awaiterB2 /path/to/app.py:180",
+ " │ └── awaiterB3 /path/to/app.py:190",
" │ └── (T) timer",
" └── (T) root2",
" └── bloch",
@@ -106,14 +114,14 @@ TEST_INPUTS_TREE = [
" └── __aexit__",
" └── _aexit",
" ├── (T) child1_2",
- " │ └── awaiter",
- " │ └── awaiter2",
- " │ └── awaiter3",
+ " │ └── awaiter /path/to/app.py:110",
+ " │ └── awaiter2 /path/to/app.py:120",
+ " │ └── awaiter3 /path/to/app.py:130",
" │ └── (T) timer",
" └── (T) child2_2",
- " └── awaiter1",
- " └── awaiter1_2",
- " └── awaiter1_3",
+ " └── awaiterB /path/to/app.py:170",
+ " └── awaiterB2 /path/to/app.py:180",
+ " └── awaiterB3 /path/to/app.py:190",
" └── (T) timer",
]
]
@@ -589,7 +597,6 @@ TEST_INPUTS_TABLE = [
class TestAsyncioToolsTree(unittest.TestCase):
-
def test_asyncio_utils(self):
for input_, tree in TEST_INPUTS_TREE:
with self.subTest(input_):
diff --git a/Lib/test/test_base64.py b/Lib/test/test_base64.py
index 409c8c109e8..9efebc43d91 100644
--- a/Lib/test/test_base64.py
+++ b/Lib/test/test_base64.py
@@ -3,8 +3,16 @@ import base64
import binascii
import os
from array import array
+from test.support import cpython_only
from test.support import os_helper
from test.support import script_helper
+from test.support.import_helper import ensure_lazy_imports
+
+
+class LazyImportTest(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("base64", {"re", "getopt"})
class LegacyBase64TestCase(unittest.TestCase):
diff --git a/Lib/test/test_calendar.py b/Lib/test/test_calendar.py
index 073df310bb4..cbfee604b7a 100644
--- a/Lib/test/test_calendar.py
+++ b/Lib/test/test_calendar.py
@@ -987,6 +987,7 @@ class CommandLineTestCase(unittest.TestCase):
self.assertCLIFails(*args)
self.assertCmdFails(*args)
+ @support.force_not_colorized
def test_help(self):
stdout = self.run_cmd_ok('-h')
self.assertIn(b'usage:', stdout)
diff --git a/Lib/test/test_capi/test_config.py b/Lib/test/test_capi/test_config.py
index bf351c4defa..a2d70dd3af4 100644
--- a/Lib/test/test_capi/test_config.py
+++ b/Lib/test/test_capi/test_config.py
@@ -57,7 +57,7 @@ class CAPITests(unittest.TestCase):
("home", str | None, None),
("thread_inherit_context", int, None),
("context_aware_warnings", int, None),
- ("import_time", bool, None),
+ ("import_time", int, None),
("inspect", bool, None),
("install_signal_handlers", bool, None),
("int_max_str_digits", int, None),
diff --git a/Lib/test/test_capi/test_misc.py b/Lib/test/test_capi/test_misc.py
index 98dc3b42ef0..a597f23a992 100644
--- a/Lib/test/test_capi/test_misc.py
+++ b/Lib/test/test_capi/test_misc.py
@@ -306,7 +306,7 @@ class CAPITest(unittest.TestCase):
CURRENT_THREAD_REGEX +
r' File .*, line 6 in <module>\n'
r'\n'
- r'Extension modules: _testcapi, _testinternalcapi \(total: 2\)\n')
+ r'Extension modules: _testcapi \(total: 1\)\n')
else:
# Python built with NDEBUG macro defined:
# test _Py_CheckFunctionResult() instead.
diff --git a/Lib/test/test_capi/test_object.py b/Lib/test/test_capi/test_object.py
index 54a01ac7c4a..127862546b1 100644
--- a/Lib/test/test_capi/test_object.py
+++ b/Lib/test/test_capi/test_object.py
@@ -174,6 +174,16 @@ class EnableDeferredRefcountingTest(unittest.TestCase):
self.assertTrue(_testinternalcapi.has_deferred_refcount(silly_list))
+class IsUniquelyReferencedTest(unittest.TestCase):
+ """Test PyUnstable_Object_IsUniquelyReferenced"""
+ def test_is_uniquely_referenced(self):
+ self.assertTrue(_testcapi.is_uniquely_referenced(object()))
+ self.assertTrue(_testcapi.is_uniquely_referenced([]))
+ # Immortals
+ self.assertFalse(_testcapi.is_uniquely_referenced("spanish inquisition"))
+ self.assertFalse(_testcapi.is_uniquely_referenced(42))
+ # CRASHES is_uniquely_referenced(NULL)
+
class CAPITest(unittest.TestCase):
def check_negative_refcount(self, code):
# bpo-35059: Check that Py_DECREF() reports the correct filename
diff --git a/Lib/test/test_capi/test_opt.py b/Lib/test/test_capi/test_opt.py
index 7e0c60d5522..ba7bcb4540a 100644
--- a/Lib/test/test_capi/test_opt.py
+++ b/Lib/test/test_capi/test_opt.py
@@ -1919,9 +1919,11 @@ class TestUopsOptimization(unittest.TestCase):
_, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
uops = get_opnames(ex)
+ self.assertNotIn("_GUARD_NOS_NULL", uops)
+ self.assertNotIn("_GUARD_CALLABLE_LEN", uops)
+ self.assertIn("_CALL_LEN", uops)
self.assertNotIn("_GUARD_NOS_INT", uops)
self.assertNotIn("_GUARD_TOS_INT", uops)
- self.assertIn("_CALL_LEN", uops)
def test_binary_op_subscr_tuple_int(self):
def testfunc(n):
diff --git a/Lib/test/test_cmd.py b/Lib/test/test_cmd.py
index 0ae44f3987d..dbfec42fc21 100644
--- a/Lib/test/test_cmd.py
+++ b/Lib/test/test_cmd.py
@@ -11,9 +11,15 @@ import unittest
import io
import textwrap
from test import support
-from test.support.import_helper import import_module
+from test.support.import_helper import ensure_lazy_imports, import_module
from test.support.pty_helper import run_pty
+class LazyImportTest(unittest.TestCase):
+ @support.cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("cmd", {"inspect", "string"})
+
+
class samplecmdclass(cmd.Cmd):
"""
Instance the sampleclass:
diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py
index 36f87e259e7..1b40e0d05fe 100644
--- a/Lib/test/test_cmd_line.py
+++ b/Lib/test/test_cmd_line.py
@@ -1158,6 +1158,24 @@ class CmdLineTest(unittest.TestCase):
res = assert_python_ok('-c', code, PYTHON_CPU_COUNT='default')
self.assertEqual(self.res2int(res), (os.cpu_count(), os.process_cpu_count()))
+ def test_import_time(self):
+ # os is not imported at startup
+ code = 'import os; import os'
+
+ for case in 'importtime', 'importtime=1', 'importtime=true':
+ res = assert_python_ok('-X', case, '-c', code)
+ res_err = res.err.decode('utf-8')
+ self.assertRegex(res_err, r'import time: \s*\d+ \| \s*\d+ \| \s*os')
+ self.assertNotRegex(res_err, r'import time: cached\s* \| cached\s* \| os')
+
+ res = assert_python_ok('-X', 'importtime=2', '-c', code)
+ res_err = res.err.decode('utf-8')
+ self.assertRegex(res_err, r'import time: \s*\d+ \| \s*\d+ \| \s*os')
+ self.assertRegex(res_err, r'import time: cached\s* \| cached\s* \| os')
+
+ assert_python_failure('-X', 'importtime=-1', '-c', code)
+ assert_python_failure('-X', 'importtime=3', '-c', code)
+
def res2int(self, res):
out = res.out.strip().decode("utf-8")
return tuple(int(i) for i in out.split())
diff --git a/Lib/test/test_code.py b/Lib/test/test_code.py
index 7cf09ee7847..b646042a3b8 100644
--- a/Lib/test/test_code.py
+++ b/Lib/test/test_code.py
@@ -674,6 +674,44 @@ class CodeTest(unittest.TestCase):
import test._code_definitions as defs
funcs = {
defs.spam_minimal: {},
+ defs.spam_with_builtins: {
+ 'x': CO_FAST_LOCAL,
+ 'values': CO_FAST_LOCAL,
+ 'checks': CO_FAST_LOCAL,
+ 'res': CO_FAST_LOCAL,
+ },
+ defs.spam_with_globals_and_builtins: {
+ 'func1': CO_FAST_LOCAL,
+ 'func2': CO_FAST_LOCAL,
+ 'funcs': CO_FAST_LOCAL,
+ 'checks': CO_FAST_LOCAL,
+ 'res': CO_FAST_LOCAL,
+ },
+ defs.spam_args_attrs_and_builtins: {
+ 'a': POSONLY,
+ 'b': POSONLY,
+ 'c': POSORKW,
+ 'd': POSORKW,
+ 'e': KWONLY,
+ 'f': KWONLY,
+ 'args': VARARGS,
+ 'kwargs': VARKWARGS,
+ },
+ defs.spam_returns_arg: {
+ 'x': POSORKW,
+ },
+ defs.spam_with_inner_not_closure: {
+ 'eggs': CO_FAST_LOCAL,
+ },
+ defs.spam_with_inner_closure: {
+ 'x': CO_FAST_CELL,
+ 'eggs': CO_FAST_LOCAL,
+ },
+ defs.spam_annotated: {
+ 'a': POSORKW,
+ 'b': POSORKW,
+ 'c': POSORKW,
+ },
defs.spam_full: {
'a': POSONLY,
'b': POSONLY,
@@ -777,6 +815,265 @@ class CodeTest(unittest.TestCase):
kinds = _testinternalcapi.get_co_localskinds(func.__code__)
self.assertEqual(kinds, expected)
+ @unittest.skipIf(_testinternalcapi is None, "missing _testinternalcapi")
+ def test_var_counts(self):
+ self.maxDiff = None
+ def new_var_counts(*,
+ posonly=0,
+ posorkw=0,
+ kwonly=0,
+ varargs=0,
+ varkwargs=0,
+ purelocals=0,
+ argcells=0,
+ othercells=0,
+ freevars=0,
+ globalvars=0,
+ attrs=0,
+ unknown=0,
+ ):
+ nargvars = posonly + posorkw + kwonly + varargs + varkwargs
+ nlocals = nargvars + purelocals + othercells
+ if isinstance(globalvars, int):
+ globalvars = {
+ 'total': globalvars,
+ 'numglobal': 0,
+ 'numbuiltin': 0,
+ 'numunknown': globalvars,
+ }
+ else:
+ g_numunknown = 0
+ if isinstance(globalvars, dict):
+ numglobal = globalvars['numglobal']
+ numbuiltin = globalvars['numbuiltin']
+ size = 2
+ if 'numunknown' in globalvars:
+ g_numunknown = globalvars['numunknown']
+ size += 1
+ assert len(globalvars) == size, globalvars
+ else:
+ assert not isinstance(globalvars, str), repr(globalvars)
+ try:
+ numglobal, numbuiltin = globalvars
+ except ValueError:
+ numglobal, numbuiltin, g_numunknown = globalvars
+ globalvars = {
+ 'total': numglobal + numbuiltin + g_numunknown,
+ 'numglobal': numglobal,
+ 'numbuiltin': numbuiltin,
+ 'numunknown': g_numunknown,
+ }
+ unbound = globalvars['total'] + attrs + unknown
+ return {
+ 'total': nlocals + freevars + unbound,
+ 'locals': {
+ 'total': nlocals,
+ 'args': {
+ 'total': nargvars,
+ 'numposonly': posonly,
+ 'numposorkw': posorkw,
+ 'numkwonly': kwonly,
+ 'varargs': varargs,
+ 'varkwargs': varkwargs,
+ },
+ 'numpure': purelocals,
+ 'cells': {
+ 'total': argcells + othercells,
+ 'numargs': argcells,
+ 'numothers': othercells,
+ },
+ 'hidden': {
+ 'total': 0,
+ 'numpure': 0,
+ 'numcells': 0,
+ },
+ },
+ 'numfree': freevars,
+ 'unbound': {
+ 'total': unbound,
+ 'globals': globalvars,
+ 'numattrs': attrs,
+ 'numunknown': unknown,
+ },
+ }
+
+ import test._code_definitions as defs
+ funcs = {
+ defs.spam_minimal: new_var_counts(),
+ defs.spam_with_builtins: new_var_counts(
+ purelocals=4,
+ globalvars=4,
+ ),
+ defs.spam_with_globals_and_builtins: new_var_counts(
+ purelocals=5,
+ globalvars=6,
+ ),
+ defs.spam_args_attrs_and_builtins: new_var_counts(
+ posonly=2,
+ posorkw=2,
+ kwonly=2,
+ varargs=1,
+ varkwargs=1,
+ attrs=1,
+ ),
+ defs.spam_returns_arg: new_var_counts(
+ posorkw=1,
+ ),
+ defs.spam_with_inner_not_closure: new_var_counts(
+ purelocals=1,
+ ),
+ defs.spam_with_inner_closure: new_var_counts(
+ othercells=1,
+ purelocals=1,
+ ),
+ defs.spam_annotated: new_var_counts(
+ posorkw=3,
+ ),
+ defs.spam_full: new_var_counts(
+ posonly=2,
+ posorkw=2,
+ kwonly=2,
+ varargs=1,
+ varkwargs=1,
+ purelocals=4,
+ globalvars=3,
+ attrs=1,
+ ),
+ defs.spam: new_var_counts(
+ posorkw=1,
+ ),
+ defs.spam_N: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ ),
+ defs.spam_C: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ othercells=1,
+ ),
+ defs.spam_NN: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ ),
+ defs.spam_NC: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ othercells=1,
+ ),
+ defs.spam_CN: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ othercells=1,
+ ),
+ defs.spam_CC: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ othercells=1,
+ ),
+ defs.eggs_nested: new_var_counts(
+ posorkw=1,
+ ),
+ defs.eggs_closure: new_var_counts(
+ posorkw=1,
+ freevars=2,
+ ),
+ defs.eggs_nested_N: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ ),
+ defs.eggs_nested_C: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ freevars=2,
+ ),
+ defs.eggs_closure_N: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ freevars=2,
+ ),
+ defs.eggs_closure_C: new_var_counts(
+ posorkw=1,
+ purelocals=1,
+ argcells=1,
+ othercells=1,
+ freevars=2,
+ ),
+ defs.ham_nested: new_var_counts(
+ posorkw=1,
+ ),
+ defs.ham_closure: new_var_counts(
+ posorkw=1,
+ freevars=3,
+ ),
+ defs.ham_C_nested: new_var_counts(
+ posorkw=1,
+ ),
+ defs.ham_C_closure: new_var_counts(
+ posorkw=1,
+ freevars=4,
+ ),
+ }
+ assert len(funcs) == len(defs.FUNCTIONS), (len(funcs), len(defs.FUNCTIONS))
+ for func in defs.FUNCTIONS:
+ with self.subTest(func):
+ expected = funcs[func]
+ counts = _testinternalcapi.get_code_var_counts(func.__code__)
+ self.assertEqual(counts, expected)
+
+ def func_with_globals_and_builtins():
+ mod1 = _testinternalcapi
+ mod2 = dis
+ mods = (mod1, mod2)
+ checks = tuple(callable(m) for m in mods)
+ return callable(mod2), tuple(mods), list(mods), checks
+
+ func = func_with_globals_and_builtins
+ with self.subTest(f'{func} code'):
+ expected = new_var_counts(
+ purelocals=4,
+ globalvars=5,
+ )
+ counts = _testinternalcapi.get_code_var_counts(func.__code__)
+ self.assertEqual(counts, expected)
+
+ with self.subTest(f'{func} with own globals and builtins'):
+ expected = new_var_counts(
+ purelocals=4,
+ globalvars=(2, 3),
+ )
+ counts = _testinternalcapi.get_code_var_counts(func)
+ self.assertEqual(counts, expected)
+
+ with self.subTest(f'{func} without globals'):
+ expected = new_var_counts(
+ purelocals=4,
+ globalvars=(0, 3, 2),
+ )
+ counts = _testinternalcapi.get_code_var_counts(func, globalsns={})
+ self.assertEqual(counts, expected)
+
+ with self.subTest(f'{func} without both'):
+ expected = new_var_counts(
+ purelocals=4,
+ globalvars=5,
+ )
+ counts = _testinternalcapi.get_code_var_counts(func, globalsns={},
+ builtinsns={})
+ self.assertEqual(counts, expected)
+
+ with self.subTest(f'{func} without builtins'):
+ expected = new_var_counts(
+ purelocals=4,
+ globalvars=(2, 0, 3),
+ )
+ counts = _testinternalcapi.get_code_var_counts(func, builtinsns={})
+ self.assertEqual(counts, expected)
+
def isinterned(s):
return s is sys.intern(('_' + s + '_')[1:-1])
diff --git a/Lib/test/test_crossinterp.py b/Lib/test/test_crossinterp.py
index 32d6fd4e94b..5ac0080db43 100644
--- a/Lib/test/test_crossinterp.py
+++ b/Lib/test/test_crossinterp.py
@@ -725,6 +725,39 @@ class MarshalTests(_GetXIDataTests):
])
+class CodeTests(_GetXIDataTests):
+
+ MODE = 'code'
+
+ def test_function_code(self):
+ self.assert_roundtrip_equal_not_identical([
+ *(f.__code__ for f in defs.FUNCTIONS),
+ *(f.__code__ for f in defs.FUNCTION_LIKE),
+ ])
+
+ def test_functions(self):
+ self.assert_not_shareable([
+ *defs.FUNCTIONS,
+ *defs.FUNCTION_LIKE,
+ ])
+
+ def test_other_objects(self):
+ self.assert_not_shareable([
+ None,
+ True,
+ False,
+ Ellipsis,
+ NotImplemented,
+ 9999,
+ 'spam',
+ b'spam',
+ (),
+ [],
+ {},
+ object(),
+ ])
+
+
class ShareableTypeTests(_GetXIDataTests):
MODE = 'xidata'
@@ -817,6 +850,13 @@ class ShareableTypeTests(_GetXIDataTests):
object(),
])
+ def test_code(self):
+ # types.CodeType
+ self.assert_not_shareable([
+ *(f.__code__ for f in defs.FUNCTIONS),
+ *(f.__code__ for f in defs.FUNCTION_LIKE),
+ ])
+
def test_function_object(self):
for func in defs.FUNCTIONS:
assert type(func) is types.FunctionType, func
@@ -935,12 +975,6 @@ class ShareableTypeTests(_GetXIDataTests):
self.assert_not_shareable([
types.MappingProxyType({}),
types.SimpleNamespace(),
- # types.CodeType
- defs.spam_minimal.__code__,
- defs.spam_full.__code__,
- defs.spam_CC.__code__,
- defs.eggs_closure_C.__code__,
- defs.ham_C_closure.__code__,
# types.CellType
types.CellType(),
# types.FrameType
diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py
index 4af8f7f480e..9aace57633b 100644
--- a/Lib/test/test_csv.py
+++ b/Lib/test/test_csv.py
@@ -10,7 +10,8 @@ import csv
import gc
import pickle
from test import support
-from test.support import import_helper, check_disallow_instantiation
+from test.support import cpython_only, import_helper, check_disallow_instantiation
+from test.support.import_helper import ensure_lazy_imports
from itertools import permutations
from textwrap import dedent
from collections import OrderedDict
@@ -1565,6 +1566,10 @@ class MiscTestCase(unittest.TestCase):
def test__all__(self):
support.check__all__(self, csv, ('csv', '_csv'))
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("csv", {"re"})
+
def test_subclassable(self):
# issue 44089
class Foo(csv.Error): ...
diff --git a/Lib/test/test_ctypes/test_aligned_structures.py b/Lib/test/test_ctypes/test_aligned_structures.py
index 0c563ab8055..50b4d729b9d 100644
--- a/Lib/test/test_ctypes/test_aligned_structures.py
+++ b/Lib/test/test_ctypes/test_aligned_structures.py
@@ -316,6 +316,7 @@ class TestAlignedStructures(unittest.TestCase, StructCheckMixin):
class Main(sbase):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [
("a", c_ubyte),
("b", Inner),
diff --git a/Lib/test/test_ctypes/test_bitfields.py b/Lib/test/test_ctypes/test_bitfields.py
index dc81e752567..518f838219e 100644
--- a/Lib/test/test_ctypes/test_bitfields.py
+++ b/Lib/test/test_ctypes/test_bitfields.py
@@ -430,6 +430,7 @@ class BitFieldTest(unittest.TestCase, StructCheckMixin):
def test_gh_84039(self):
class Bad(Structure):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [
("a0", c_uint8, 1),
("a1", c_uint8, 1),
@@ -443,9 +444,9 @@ class BitFieldTest(unittest.TestCase, StructCheckMixin):
("b1", c_uint16, 12),
]
-
class GoodA(Structure):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [
("a0", c_uint8, 1),
("a1", c_uint8, 1),
@@ -460,6 +461,7 @@ class BitFieldTest(unittest.TestCase, StructCheckMixin):
class Good(Structure):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [
("a", GoodA),
("b0", c_uint16, 4),
@@ -475,6 +477,7 @@ class BitFieldTest(unittest.TestCase, StructCheckMixin):
def test_gh_73939(self):
class MyStructure(Structure):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [
("P", c_uint16),
("L", c_uint16, 9),
diff --git a/Lib/test/test_ctypes/test_byteswap.py b/Lib/test/test_ctypes/test_byteswap.py
index 9f9904282e4..ea5951603f9 100644
--- a/Lib/test/test_ctypes/test_byteswap.py
+++ b/Lib/test/test_ctypes/test_byteswap.py
@@ -269,6 +269,7 @@ class Test(unittest.TestCase, StructCheckMixin):
class S(base):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [("b", c_byte),
("h", c_short),
@@ -296,6 +297,7 @@ class Test(unittest.TestCase, StructCheckMixin):
class S(Structure):
_pack_ = 1
+ _layout_ = "ms"
_fields_ = [("b", c_byte),
("h", c_short),
diff --git a/Lib/test/test_ctypes/test_generated_structs.py b/Lib/test/test_ctypes/test_generated_structs.py
index 9a8102219d8..aa448fad5bb 100644
--- a/Lib/test/test_ctypes/test_generated_structs.py
+++ b/Lib/test/test_ctypes/test_generated_structs.py
@@ -125,18 +125,21 @@ class Nested(Structure):
class Packed1(Structure):
_fields_ = [('a', c_int8), ('b', c_int64)]
_pack_ = 1
+ _layout_ = 'ms'
@register()
class Packed2(Structure):
_fields_ = [('a', c_int8), ('b', c_int64)]
_pack_ = 2
+ _layout_ = 'ms'
@register()
class Packed3(Structure):
_fields_ = [('a', c_int8), ('b', c_int64)]
_pack_ = 4
+ _layout_ = 'ms'
@register()
@@ -155,6 +158,7 @@ class Packed4(Structure):
_fields_ = [('a', c_int8), ('b', c_int64)]
_pack_ = 8
+ _layout_ = 'ms'
@register()
class X86_32EdgeCase(Structure):
@@ -366,6 +370,7 @@ class Example_gh_95496(Structure):
@register()
class Example_gh_84039_bad(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("a0", c_uint8, 1),
("a1", c_uint8, 1),
("a2", c_uint8, 1),
@@ -380,6 +385,7 @@ class Example_gh_84039_bad(Structure):
@register()
class Example_gh_84039_good_a(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("a0", c_uint8, 1),
("a1", c_uint8, 1),
("a2", c_uint8, 1),
@@ -392,6 +398,7 @@ class Example_gh_84039_good_a(Structure):
@register()
class Example_gh_84039_good(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("a", Example_gh_84039_good_a),
("b0", c_uint16, 4),
("b1", c_uint16, 12)]
@@ -399,6 +406,7 @@ class Example_gh_84039_good(Structure):
@register()
class Example_gh_73939(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("P", c_uint16),
("L", c_uint16, 9),
("Pro", c_uint16, 1),
@@ -419,6 +427,7 @@ class Example_gh_86098(Structure):
@register()
class Example_gh_86098_pack(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("a", c_uint8, 8),
("b", c_uint8, 8),
("c", c_uint32, 16)]
@@ -528,7 +537,7 @@ def dump_ctype(tp, struct_or_union_tag='', variable_name='', semi=''):
pushes.append(f'#pragma pack(push, {pack})')
pops.append(f'#pragma pack(pop)')
layout = getattr(tp, '_layout_', None)
- if layout == 'ms' or pack:
+ if layout == 'ms':
# The 'ms_struct' attribute only works on x86 and PowerPC
requires.add(
'defined(MS_WIN32) || ('
diff --git a/Lib/test/test_ctypes/test_pep3118.py b/Lib/test/test_ctypes/test_pep3118.py
index 06b2ccecade..11a0744f5a8 100644
--- a/Lib/test/test_ctypes/test_pep3118.py
+++ b/Lib/test/test_ctypes/test_pep3118.py
@@ -81,6 +81,7 @@ class Point(Structure):
class PackedPoint(Structure):
_pack_ = 2
+ _layout_ = 'ms'
_fields_ = [("x", c_long), ("y", c_long)]
class PointMidPad(Structure):
@@ -88,6 +89,7 @@ class PointMidPad(Structure):
class PackedPointMidPad(Structure):
_pack_ = 2
+ _layout_ = 'ms'
_fields_ = [("x", c_byte), ("y", c_uint64)]
class PointEndPad(Structure):
@@ -95,6 +97,7 @@ class PointEndPad(Structure):
class PackedPointEndPad(Structure):
_pack_ = 2
+ _layout_ = 'ms'
_fields_ = [("x", c_uint64), ("y", c_byte)]
class Point2(Structure):
diff --git a/Lib/test/test_ctypes/test_structunion.py b/Lib/test/test_ctypes/test_structunion.py
index 8d8b7e5e995..5b21d48d99c 100644
--- a/Lib/test/test_ctypes/test_structunion.py
+++ b/Lib/test/test_ctypes/test_structunion.py
@@ -11,6 +11,8 @@ from ._support import (_CData, PyCStructType, UnionType,
Py_TPFLAGS_DISALLOW_INSTANTIATION,
Py_TPFLAGS_IMMUTABLETYPE)
from struct import calcsize
+import contextlib
+from test.support import MS_WINDOWS
class StructUnionTestBase:
@@ -335,6 +337,22 @@ class StructUnionTestBase:
self.assertIn("from_address", dir(type(self.cls)))
self.assertIn("in_dll", dir(type(self.cls)))
+ def test_pack_layout_switch(self):
+ # Setting _pack_ implicitly sets default layout to MSVC;
+ # this is deprecated on non-Windows platforms.
+ if MS_WINDOWS:
+ warn_context = contextlib.nullcontext()
+ else:
+ warn_context = self.assertWarns(DeprecationWarning)
+ with warn_context:
+ class X(self.cls):
+ _pack_ = 1
+ # _layout_ missing
+ _fields_ = [('a', c_int8, 1), ('b', c_int16, 2)]
+
+ # Check MSVC layout (bitfields of different types aren't combined)
+ self.check_sizeof(X, struct_size=3, union_size=2)
+
class StructureTestCase(unittest.TestCase, StructUnionTestBase):
cls = Structure
diff --git a/Lib/test/test_ctypes/test_structures.py b/Lib/test/test_ctypes/test_structures.py
index 221319642e8..92d4851d739 100644
--- a/Lib/test/test_ctypes/test_structures.py
+++ b/Lib/test/test_ctypes/test_structures.py
@@ -25,6 +25,7 @@ class StructureTestCase(unittest.TestCase, StructCheckMixin):
_fields_ = [("a", c_byte),
("b", c_longlong)]
_pack_ = 1
+ _layout_ = 'ms'
self.check_struct(X)
self.assertEqual(sizeof(X), 9)
@@ -34,6 +35,7 @@ class StructureTestCase(unittest.TestCase, StructCheckMixin):
_fields_ = [("a", c_byte),
("b", c_longlong)]
_pack_ = 2
+ _layout_ = 'ms'
self.check_struct(X)
self.assertEqual(sizeof(X), 10)
self.assertEqual(X.b.offset, 2)
@@ -45,6 +47,7 @@ class StructureTestCase(unittest.TestCase, StructCheckMixin):
_fields_ = [("a", c_byte),
("b", c_longlong)]
_pack_ = 4
+ _layout_ = 'ms'
self.check_struct(X)
self.assertEqual(sizeof(X), min(4, longlong_align) + longlong_size)
self.assertEqual(X.b.offset, min(4, longlong_align))
@@ -53,27 +56,33 @@ class StructureTestCase(unittest.TestCase, StructCheckMixin):
_fields_ = [("a", c_byte),
("b", c_longlong)]
_pack_ = 8
+ _layout_ = 'ms'
self.check_struct(X)
self.assertEqual(sizeof(X), min(8, longlong_align) + longlong_size)
self.assertEqual(X.b.offset, min(8, longlong_align))
-
- d = {"_fields_": [("a", "b"),
- ("b", "q")],
- "_pack_": -1}
- self.assertRaises(ValueError, type(Structure), "X", (Structure,), d)
+ with self.assertRaises(ValueError):
+ class X(Structure):
+ _fields_ = [("a", "b"), ("b", "q")]
+ _pack_ = -1
+ _layout_ = "ms"
@support.cpython_only
def test_packed_c_limits(self):
# Issue 15989
import _testcapi
- d = {"_fields_": [("a", c_byte)],
- "_pack_": _testcapi.INT_MAX + 1}
- self.assertRaises(ValueError, type(Structure), "X", (Structure,), d)
- d = {"_fields_": [("a", c_byte)],
- "_pack_": _testcapi.UINT_MAX + 2}
- self.assertRaises(ValueError, type(Structure), "X", (Structure,), d)
+ with self.assertRaises(ValueError):
+ class X(Structure):
+ _fields_ = [("a", c_byte)]
+ _pack_ = _testcapi.INT_MAX + 1
+ _layout_ = "ms"
+
+ with self.assertRaises(ValueError):
+ class X(Structure):
+ _fields_ = [("a", c_byte)]
+ _pack_ = _testcapi.UINT_MAX + 2
+ _layout_ = "ms"
def test_initializers(self):
class Person(Structure):
diff --git a/Lib/test/test_ctypes/test_unaligned_structures.py b/Lib/test/test_ctypes/test_unaligned_structures.py
index 58a00597ef5..b5fb4c0df77 100644
--- a/Lib/test/test_ctypes/test_unaligned_structures.py
+++ b/Lib/test/test_ctypes/test_unaligned_structures.py
@@ -19,10 +19,12 @@ for typ in [c_short, c_int, c_long, c_longlong,
c_ushort, c_uint, c_ulong, c_ulonglong]:
class X(Structure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("pad", c_byte),
("value", typ)]
class Y(SwappedStructure):
_pack_ = 1
+ _layout_ = 'ms'
_fields_ = [("pad", c_byte),
("value", typ)]
structures.append(X)
diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py
index 99fefb57fd0..ac78f8327b8 100644
--- a/Lib/test/test_dataclasses/__init__.py
+++ b/Lib/test/test_dataclasses/__init__.py
@@ -5,6 +5,7 @@
from dataclasses import *
import abc
+import annotationlib
import io
import pickle
import inspect
@@ -12,6 +13,7 @@ import builtins
import types
import weakref
import traceback
+import sys
import textwrap
import unittest
from unittest.mock import Mock
@@ -25,6 +27,7 @@ import typing # Needed for the string "typing.ClassVar[int]" to work as an
import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
from test import support
+from test.support import import_helper
# Just any custom exception we can catch.
class CustomError(Exception): pass
@@ -3754,7 +3757,6 @@ class TestSlots(unittest.TestCase):
@support.cpython_only
def test_dataclass_slot_dict_ctype(self):
# https://github.com/python/cpython/issues/123935
- from test.support import import_helper
# Skips test if `_testcapi` is not present:
_testcapi = import_helper.import_module('_testcapi')
@@ -4246,16 +4248,56 @@ class TestMakeDataclass(unittest.TestCase):
C = make_dataclass('Point', ['x', 'y', 'z'])
c = C(1, 2, 3)
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
- self.assertEqual(C.__annotations__, {'x': 'typing.Any',
- 'y': 'typing.Any',
- 'z': 'typing.Any'})
+ self.assertEqual(C.__annotations__, {'x': typing.Any,
+ 'y': typing.Any,
+ 'z': typing.Any})
C = make_dataclass('Point', ['x', ('y', int), 'z'])
c = C(1, 2, 3)
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
- self.assertEqual(C.__annotations__, {'x': 'typing.Any',
+ self.assertEqual(C.__annotations__, {'x': typing.Any,
'y': int,
- 'z': 'typing.Any'})
+ 'z': typing.Any})
+
+ def test_no_types_get_annotations(self):
+ C = make_dataclass('C', ['x', ('y', int), 'z'])
+
+ self.assertEqual(
+ annotationlib.get_annotations(C, format=annotationlib.Format.VALUE),
+ {'x': typing.Any, 'y': int, 'z': typing.Any},
+ )
+ self.assertEqual(
+ annotationlib.get_annotations(
+ C, format=annotationlib.Format.FORWARDREF),
+ {'x': typing.Any, 'y': int, 'z': typing.Any},
+ )
+ self.assertEqual(
+ annotationlib.get_annotations(
+ C, format=annotationlib.Format.STRING),
+ {'x': 'typing.Any', 'y': 'int', 'z': 'typing.Any'},
+ )
+
+ def test_no_types_no_typing_import(self):
+ with import_helper.CleanImport('typing'):
+ self.assertNotIn('typing', sys.modules)
+ C = make_dataclass('C', ['x', ('y', int)])
+
+ self.assertNotIn('typing', sys.modules)
+ self.assertEqual(
+ C.__annotate__(annotationlib.Format.FORWARDREF),
+ {
+ 'x': annotationlib.ForwardRef('Any', module='typing'),
+ 'y': int,
+ },
+ )
+ self.assertNotIn('typing', sys.modules)
+
+ for field in fields(C):
+ if field.name == "x":
+ self.assertEqual(field.type, annotationlib.ForwardRef('Any', module='typing'))
+ else:
+ self.assertEqual(field.name, "y")
+ self.assertIs(field.type, int)
def test_module_attr(self):
self.assertEqual(ByMakeDataClass.__module__, __name__)
diff --git a/Lib/test/test_dis.py b/Lib/test/test_dis.py
index f2586fcee57..ae68c1dd75c 100644
--- a/Lib/test/test_dis.py
+++ b/Lib/test/test_dis.py
@@ -1336,7 +1336,7 @@ class DisTests(DisTestBase):
# Loop can trigger a quicken where the loop is located
self.code_quicken(loop_test)
got = self.get_disassembly(loop_test, adaptive=True)
- jit = import_helper.import_module("_testinternalcapi").jit_enabled()
+ jit = sys._jit.is_enabled()
expected = dis_loop_test_quickened_code.format("JIT" if jit else "NO_JIT")
self.do_disassembly_compare(got, expected)
diff --git a/Lib/test/test_email/test_utils.py b/Lib/test/test_email/test_utils.py
index 4e6201e13c8..c9d09098b50 100644
--- a/Lib/test/test_email/test_utils.py
+++ b/Lib/test/test_email/test_utils.py
@@ -4,6 +4,16 @@ import test.support
import time
import unittest
+from test.support import cpython_only
+from test.support.import_helper import ensure_lazy_imports
+
+
+class TestImportTime(unittest.TestCase):
+
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("email.utils", {"random", "socket"})
+
class DateTimeTests(unittest.TestCase):
diff --git a/Lib/test/test_embed.py b/Lib/test/test_embed.py
index e06e684408c..95b2d80464c 100644
--- a/Lib/test/test_embed.py
+++ b/Lib/test/test_embed.py
@@ -585,7 +585,7 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase):
'faulthandler': False,
'tracemalloc': 0,
'perf_profiling': 0,
- 'import_time': False,
+ 'import_time': 0,
'thread_inherit_context': DEFAULT_THREAD_INHERIT_CONTEXT,
'context_aware_warnings': DEFAULT_CONTEXT_AWARE_WARNINGS,
'code_debug_ranges': True,
@@ -998,7 +998,7 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase):
'hash_seed': 123,
'tracemalloc': 2,
'perf_profiling': 0,
- 'import_time': True,
+ 'import_time': 2,
'code_debug_ranges': False,
'show_ref_count': True,
'malloc_stats': True,
@@ -1064,7 +1064,7 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase):
'use_hash_seed': True,
'hash_seed': 42,
'tracemalloc': 2,
- 'import_time': True,
+ 'import_time': 1,
'code_debug_ranges': False,
'malloc_stats': True,
'inspect': True,
@@ -1100,7 +1100,7 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase):
'use_hash_seed': True,
'hash_seed': 42,
'tracemalloc': 2,
- 'import_time': True,
+ 'import_time': 1,
'code_debug_ranges': False,
'malloc_stats': True,
'inspect': True,
diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py
index 68cedc666a5..d8cb5261244 100644
--- a/Lib/test/test_enum.py
+++ b/Lib/test/test_enum.py
@@ -19,7 +19,8 @@ from io import StringIO
from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL
from test import support
from test.support import ALWAYS_EQ, REPO_ROOT
-from test.support import threading_helper
+from test.support import threading_helper, cpython_only
+from test.support.import_helper import ensure_lazy_imports
from datetime import timedelta
python_version = sys.version_info[:2]
@@ -5288,6 +5289,10 @@ class MiscTestCase(unittest.TestCase):
def test__all__(self):
support.check__all__(self, enum, not_exported={'bin', 'show_flag_values'})
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("enum", {"functools", "warnings", "inspect", "re"})
+
def test_doc_1(self):
class Single(Enum):
ONE = 1
diff --git a/Lib/test/test_external_inspection.py b/Lib/test/test_external_inspection.py
index f787190b1ae..ad3f669a030 100644
--- a/Lib/test/test_external_inspection.py
+++ b/Lib/test/test_external_inspection.py
@@ -15,13 +15,12 @@ import subprocess
PROCESS_VM_READV_SUPPORTED = False
try:
- from _remotedebugging import PROCESS_VM_READV_SUPPORTED
- from _remotedebugging import get_stack_trace
- from _remotedebugging import get_async_stack_trace
- from _remotedebugging import get_all_awaited_by
+ from _remote_debugging import PROCESS_VM_READV_SUPPORTED
+ from _remote_debugging import get_stack_trace
+ from _remote_debugging import get_async_stack_trace
+ from _remote_debugging import get_all_awaited_by
except ImportError:
- raise unittest.SkipTest("Test only runs when _remotedebuggingmodule is available")
-
+ raise unittest.SkipTest("Test only runs when _remote_debugging is available")
def _make_test_script(script_dir, script_basename, source):
to_return = make_script(script_dir, script_basename, source)
@@ -60,8 +59,7 @@ class TestGetStackTrace(unittest.TestCase):
foo()
def foo():
- sock.sendall(b"ready")
- time.sleep(1000)
+ sock.sendall(b"ready"); time.sleep(10_000) # same line number
bar()
"""
@@ -97,10 +95,10 @@ class TestGetStackTrace(unittest.TestCase):
p.wait(timeout=SHORT_TIMEOUT)
expected_stack_trace = [
- ("foo", script_name, 15),
+ ("foo", script_name, 14),
("baz", script_name, 11),
("bar", script_name, 9),
- ("<module>", script_name, 17),
+ ("<module>", script_name, 16),
]
self.assertEqual(stack_trace, expected_stack_trace)
@@ -123,8 +121,7 @@ class TestGetStackTrace(unittest.TestCase):
sock.connect(('localhost', {port}))
def c5():
- sock.sendall(b"ready")
- time.sleep(10000)
+ sock.sendall(b"ready"); time.sleep(10_000) # same line number
async def c4():
await asyncio.sleep(0)
@@ -196,10 +193,10 @@ class TestGetStackTrace(unittest.TestCase):
root_task = "Task-1"
expected_stack_trace = [
[
- ("c5", script_name, 11),
- ("c4", script_name, 15),
- ("c3", script_name, 18),
- ("c2", script_name, 21),
+ ("c5", script_name, 10),
+ ("c4", script_name, 14),
+ ("c3", script_name, 17),
+ ("c2", script_name, 20),
],
"c2_root",
[
@@ -215,13 +212,13 @@ class TestGetStackTrace(unittest.TestCase):
taskgroups.__file__,
ANY,
),
- ("main", script_name, 27),
+ ("main", script_name, 26),
],
"Task-1",
[],
],
[
- [("c1", script_name, 24)],
+ [("c1", script_name, 23)],
"sub_main_1",
[
[
@@ -236,7 +233,7 @@ class TestGetStackTrace(unittest.TestCase):
taskgroups.__file__,
ANY,
),
- ("main", script_name, 27),
+ ("main", script_name, 26),
],
"Task-1",
[],
@@ -244,7 +241,7 @@ class TestGetStackTrace(unittest.TestCase):
],
],
[
- [("c1", script_name, 24)],
+ [("c1", script_name, 23)],
"sub_main_2",
[
[
@@ -259,7 +256,7 @@ class TestGetStackTrace(unittest.TestCase):
taskgroups.__file__,
ANY,
),
- ("main", script_name, 27),
+ ("main", script_name, 26),
],
"Task-1",
[],
@@ -289,8 +286,7 @@ class TestGetStackTrace(unittest.TestCase):
sock.connect(('localhost', {port}))
async def gen_nested_call():
- sock.sendall(b"ready")
- time.sleep(10000)
+ sock.sendall(b"ready"); time.sleep(10_000) # same line number
async def gen():
for num in range(2):
@@ -338,9 +334,9 @@ class TestGetStackTrace(unittest.TestCase):
expected_stack_trace = [
[
- ("gen_nested_call", script_name, 11),
- ("gen", script_name, 17),
- ("main", script_name, 20),
+ ("gen_nested_call", script_name, 10),
+ ("gen", script_name, 16),
+ ("main", script_name, 19),
],
"Task-1",
[],
@@ -367,8 +363,7 @@ class TestGetStackTrace(unittest.TestCase):
async def deep():
await asyncio.sleep(0)
- sock.sendall(b"ready")
- time.sleep(10000)
+ sock.sendall(b"ready"); time.sleep(10_000) # same line number
async def c1():
await asyncio.sleep(0)
@@ -415,9 +410,9 @@ class TestGetStackTrace(unittest.TestCase):
stack_trace[2].sort(key=lambda x: x[1])
expected_stack_trace = [
- [("deep", script_name, ANY), ("c1", script_name, 16)],
+ [("deep", script_name, 11), ("c1", script_name, 15)],
"Task-2",
- [[[("main", script_name, 22)], "Task-1", []]],
+ [[[("main", script_name, 21)], "Task-1", []]],
]
self.assertEqual(stack_trace, expected_stack_trace)
@@ -441,15 +436,14 @@ class TestGetStackTrace(unittest.TestCase):
async def deep():
await asyncio.sleep(0)
- sock.sendall(b"ready")
- time.sleep(10000)
+ sock.sendall(b"ready"); time.sleep(10_000) # same line number
async def c1():
await asyncio.sleep(0)
await deep()
async def c2():
- await asyncio.sleep(10000)
+ await asyncio.sleep(10_000)
async def main():
await asyncio.staggered.staggered_race(
@@ -492,8 +486,8 @@ class TestGetStackTrace(unittest.TestCase):
stack_trace[2].sort(key=lambda x: x[1])
expected_stack_trace = [
[
- ("deep", script_name, ANY),
- ("c1", script_name, 16),
+ ("deep", script_name, 11),
+ ("c1", script_name, 15),
("staggered_race.<locals>.run_one_coro", staggered.__file__, ANY),
],
"Task-2",
@@ -501,7 +495,7 @@ class TestGetStackTrace(unittest.TestCase):
[
[
("staggered_race", staggered.__file__, ANY),
- ("main", script_name, 22),
+ ("main", script_name, 21),
],
"Task-1",
[],
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index e3b449f2d24..2e794b0fc95 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -23,6 +23,7 @@ from inspect import Signature
from test.support import import_helper
from test.support import threading_helper
+from test.support import cpython_only
from test.support import EqualToForwardRef
import functools
@@ -63,6 +64,14 @@ class BadTuple(tuple):
class MyDict(dict):
pass
+class TestImportTime(unittest.TestCase):
+
+ @cpython_only
+ def test_lazy_import(self):
+ import_helper.ensure_lazy_imports(
+ "functools", {"os", "weakref", "typing", "annotationlib", "warnings"}
+ )
+
class TestPartial:
diff --git a/Lib/test/test_getpass.py b/Lib/test/test_getpass.py
index 80dda2caaa3..ab36535a1cf 100644
--- a/Lib/test/test_getpass.py
+++ b/Lib/test/test_getpass.py
@@ -161,6 +161,45 @@ class UnixGetpassTest(unittest.TestCase):
self.assertIn('Warning', stderr.getvalue())
self.assertIn('Password:', stderr.getvalue())
+ def test_echo_char_replaces_input_with_asterisks(self):
+ mock_result = '*************'
+ with mock.patch('os.open') as os_open, \
+ mock.patch('io.FileIO'), \
+ mock.patch('io.TextIOWrapper') as textio, \
+ mock.patch('termios.tcgetattr'), \
+ mock.patch('termios.tcsetattr'), \
+ mock.patch('getpass._raw_input') as mock_input:
+ os_open.return_value = 3
+ mock_input.return_value = mock_result
+
+ result = getpass.unix_getpass(echo_char='*')
+ mock_input.assert_called_once_with('Password: ', textio(),
+ input=textio(), echo_char='*')
+ self.assertEqual(result, mock_result)
+
+ def test_raw_input_with_echo_char(self):
+ passwd = 'my1pa$$word!'
+ mock_input = StringIO(f'{passwd}\n')
+ mock_output = StringIO()
+ with mock.patch('sys.stdin', mock_input), \
+ mock.patch('sys.stdout', mock_output):
+ result = getpass._raw_input('Password: ', mock_output, mock_input,
+ '*')
+ self.assertEqual(result, passwd)
+ self.assertEqual('Password: ************', mock_output.getvalue())
+
+ def test_control_chars_with_echo_char(self):
+ passwd = 'pass\twd\b'
+ expect_result = 'pass\tw'
+ mock_input = StringIO(f'{passwd}\n')
+ mock_output = StringIO()
+ with mock.patch('sys.stdin', mock_input), \
+ mock.patch('sys.stdout', mock_output):
+ result = getpass._raw_input('Password: ', mock_output, mock_input,
+ '*')
+ self.assertEqual(result, expect_result)
+ self.assertEqual('Password: *******\x08 \x08', mock_output.getvalue())
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_gettext.py b/Lib/test/test_gettext.py
index 585ed08ea14..33b7d75e3ff 100644
--- a/Lib/test/test_gettext.py
+++ b/Lib/test/test_gettext.py
@@ -6,7 +6,8 @@ import unittest.mock
from functools import partial
from test import support
-from test.support import os_helper
+from test.support import cpython_only, os_helper
+from test.support.import_helper import ensure_lazy_imports
# TODO:
@@ -931,6 +932,10 @@ class MiscTestCase(unittest.TestCase):
support.check__all__(self, gettext,
not_exported={'c2py', 'ENOENT'})
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("gettext", {"re", "warnings", "locale"})
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py
index 1aa8e4e2897..d6623fee9bb 100644
--- a/Lib/test/test_heapq.py
+++ b/Lib/test/test_heapq.py
@@ -13,8 +13,9 @@ c_heapq = import_helper.import_fresh_module('heapq', fresh=['_heapq'])
# _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when
# _heapq is imported, so check them there
-func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace',
- '_heappop_max', '_heapreplace_max', '_heapify_max']
+func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace']
+# Add max-heap variants
+func_names += [func + '_max' for func in func_names]
class TestModules(TestCase):
def test_py_functions(self):
@@ -24,7 +25,7 @@ class TestModules(TestCase):
@skipUnless(c_heapq, 'requires _heapq')
def test_c_functions(self):
for fname in func_names:
- self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
+ self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq', fname)
def load_tests(loader, tests, ignore):
@@ -74,6 +75,34 @@ class TestHeap:
except AttributeError:
pass
+ def test_max_push_pop(self):
+ # 1) Push 256 random numbers and pop them off, verifying all's OK.
+ heap = []
+ data = []
+ self.check_max_invariant(heap)
+ for i in range(256):
+ item = random.random()
+ data.append(item)
+ self.module.heappush_max(heap, item)
+ self.check_max_invariant(heap)
+ results = []
+ while heap:
+ item = self.module.heappop_max(heap)
+ self.check_max_invariant(heap)
+ results.append(item)
+ data_sorted = data[:]
+ data_sorted.sort(reverse=True)
+
+ self.assertEqual(data_sorted, results)
+ # 2) Check that the invariant holds for a sorted array
+ self.check_max_invariant(results)
+
+ self.assertRaises(TypeError, self.module.heappush_max, [])
+
+ exc_types = (AttributeError, TypeError)
+ self.assertRaises(exc_types, self.module.heappush_max, None, None)
+ self.assertRaises(exc_types, self.module.heappop_max, None)
+
def check_invariant(self, heap):
# Check the heap invariant.
for pos, item in enumerate(heap):
@@ -81,6 +110,11 @@ class TestHeap:
parentpos = (pos-1) >> 1
self.assertTrue(heap[parentpos] <= item)
+ def check_max_invariant(self, heap):
+ for pos, item in enumerate(heap[1:], start=1):
+ parentpos = (pos - 1) >> 1
+ self.assertGreaterEqual(heap[parentpos], item)
+
def test_heapify(self):
for size in list(range(30)) + [20000]:
heap = [random.random() for dummy in range(size)]
@@ -89,6 +123,14 @@ class TestHeap:
self.assertRaises(TypeError, self.module.heapify, None)
+ def test_heapify_max(self):
+ for size in list(range(30)) + [20000]:
+ heap = [random.random() for dummy in range(size)]
+ self.module.heapify_max(heap)
+ self.check_max_invariant(heap)
+
+ self.assertRaises(TypeError, self.module.heapify_max, None)
+
def test_naive_nbest(self):
data = [random.randrange(2000) for i in range(1000)]
heap = []
@@ -109,10 +151,7 @@ class TestHeap:
def test_nbest(self):
# Less-naive "N-best" algorithm, much faster (if len(data) is big
- # enough <wink>) than sorting all of data. However, if we had a max
- # heap instead of a min heap, it could go faster still via
- # heapify'ing all of data (linear time), then doing 10 heappops
- # (10 log-time steps).
+ # enough <wink>) than sorting all of data.
data = [random.randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify(heap)
@@ -125,6 +164,17 @@ class TestHeap:
self.assertRaises(TypeError, self.module.heapreplace, None, None)
self.assertRaises(IndexError, self.module.heapreplace, [], None)
+ def test_nbest_maxheap(self):
+ # With a max heap instead of a min heap, the "N-best" algorithm can
+ # go even faster still via heapify'ing all of data (linear time), then
+ # doing 10 heappops (10 log-time steps).
+ data = [random.randrange(2000) for i in range(1000)]
+ heap = data[:]
+ self.module.heapify_max(heap)
+ result = [self.module.heappop_max(heap) for _ in range(10)]
+ result.reverse()
+ self.assertEqual(result, sorted(data)[-10:])
+
def test_nbest_with_pushpop(self):
data = [random.randrange(2000) for i in range(1000)]
heap = data[:10]
@@ -134,6 +184,62 @@ class TestHeap:
self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
self.assertEqual(self.module.heappushpop([], 'x'), 'x')
+ def test_naive_nworst(self):
+ # Max-heap variant of "test_naive_nbest"
+ data = [random.randrange(2000) for i in range(1000)]
+ heap = []
+ for item in data:
+ self.module.heappush_max(heap, item)
+ if len(heap) > 10:
+ self.module.heappop_max(heap)
+ heap.sort()
+ expected = sorted(data)[:10]
+ self.assertEqual(heap, expected)
+
+ def heapiter_max(self, heap):
+ # An iterator returning a max-heap's elements, largest-first.
+ try:
+ while 1:
+ yield self.module.heappop_max(heap)
+ except IndexError:
+ pass
+
+ def test_nworst(self):
+ # Max-heap variant of "test_nbest"
+ data = [random.randrange(2000) for i in range(1000)]
+ heap = data[:10]
+ self.module.heapify_max(heap)
+ for item in data[10:]:
+ if item < heap[0]: # this gets rarer the longer we run
+ self.module.heapreplace_max(heap, item)
+ expected = sorted(data, reverse=True)[-10:]
+ self.assertEqual(list(self.heapiter_max(heap)), expected)
+
+ self.assertRaises(TypeError, self.module.heapreplace_max, None)
+ self.assertRaises(TypeError, self.module.heapreplace_max, None, None)
+ self.assertRaises(IndexError, self.module.heapreplace_max, [], None)
+
+ def test_nworst_minheap(self):
+ # Min-heap variant of "test_nbest_maxheap"
+ data = [random.randrange(2000) for i in range(1000)]
+ heap = data[:]
+ self.module.heapify(heap)
+ result = [self.module.heappop(heap) for _ in range(10)]
+ result.reverse()
+ expected = sorted(data, reverse=True)[-10:]
+ self.assertEqual(result, expected)
+
+ def test_nworst_with_pushpop(self):
+ # Max-heap variant of "test_nbest_with_pushpop"
+ data = [random.randrange(2000) for i in range(1000)]
+ heap = data[:10]
+ self.module.heapify_max(heap)
+ for item in data[10:]:
+ self.module.heappushpop_max(heap, item)
+ expected = sorted(data, reverse=True)[-10:]
+ self.assertEqual(list(self.heapiter_max(heap)), expected)
+ self.assertEqual(self.module.heappushpop_max([], 'x'), 'x')
+
def test_heappushpop(self):
h = []
x = self.module.heappushpop(h, 10)
@@ -153,12 +259,31 @@ class TestHeap:
x = self.module.heappushpop(h, 11)
self.assertEqual((h, x), ([11], 10))
+ def test_heappushpop_max(self):
+ h = []
+ x = self.module.heappushpop_max(h, 10)
+ self.assertTupleEqual((h, x), ([], 10))
+
+ h = [10]
+ x = self.module.heappushpop_max(h, 10.0)
+ self.assertTupleEqual((h, x), ([10], 10.0))
+ self.assertIsInstance(h[0], int)
+ self.assertIsInstance(x, float)
+
+ h = [10]
+ x = self.module.heappushpop_max(h, 11)
+ self.assertTupleEqual((h, x), ([10], 11))
+
+ h = [10]
+ x = self.module.heappushpop_max(h, 9)
+ self.assertTupleEqual((h, x), ([9], 10))
+
def test_heappop_max(self):
- # _heapop_max has an optimization for one-item lists which isn't
+ # heappop_max has an optimization for one-item lists which isn't
# covered in other tests, so test that case explicitly here
h = [3, 2]
- self.assertEqual(self.module._heappop_max(h), 3)
- self.assertEqual(self.module._heappop_max(h), 2)
+ self.assertEqual(self.module.heappop_max(h), 3)
+ self.assertEqual(self.module.heappop_max(h), 2)
def test_heapsort(self):
# Exercise everything with repeated heapsort checks
@@ -175,6 +300,20 @@ class TestHeap:
heap_sorted = [self.module.heappop(heap) for i in range(size)]
self.assertEqual(heap_sorted, sorted(data))
+ def test_heapsort_max(self):
+ for trial in range(100):
+ size = random.randrange(50)
+ data = [random.randrange(25) for i in range(size)]
+ if trial & 1: # Half of the time, use heapify_max
+ heap = data[:]
+ self.module.heapify_max(heap)
+ else: # The rest of the time, use heappush_max
+ heap = []
+ for item in data:
+ self.module.heappush_max(heap, item)
+ heap_sorted = [self.module.heappop_max(heap) for i in range(size)]
+ self.assertEqual(heap_sorted, sorted(data, reverse=True))
+
def test_merge(self):
inputs = []
for i in range(random.randrange(25)):
@@ -377,16 +516,20 @@ class SideEffectLT:
class TestErrorHandling:
def test_non_sequence(self):
- for f in (self.module.heapify, self.module.heappop):
+ for f in (self.module.heapify, self.module.heappop,
+ self.module.heapify_max, self.module.heappop_max):
self.assertRaises((TypeError, AttributeError), f, 10)
for f in (self.module.heappush, self.module.heapreplace,
+ self.module.heappush_max, self.module.heapreplace_max,
self.module.nlargest, self.module.nsmallest):
self.assertRaises((TypeError, AttributeError), f, 10, 10)
def test_len_only(self):
- for f in (self.module.heapify, self.module.heappop):
+ for f in (self.module.heapify, self.module.heappop,
+ self.module.heapify_max, self.module.heappop_max):
self.assertRaises((TypeError, AttributeError), f, LenOnly())
- for f in (self.module.heappush, self.module.heapreplace):
+ for f in (self.module.heappush, self.module.heapreplace,
+ self.module.heappush_max, self.module.heapreplace_max):
self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10)
for f in (self.module.nlargest, self.module.nsmallest):
self.assertRaises(TypeError, f, 2, LenOnly())
@@ -395,7 +538,8 @@ class TestErrorHandling:
seq = [CmpErr(), CmpErr(), CmpErr()]
for f in (self.module.heapify, self.module.heappop):
self.assertRaises(ZeroDivisionError, f, seq)
- for f in (self.module.heappush, self.module.heapreplace):
+ for f in (self.module.heappush, self.module.heapreplace,
+ self.module.heappush_max, self.module.heapreplace_max):
self.assertRaises(ZeroDivisionError, f, seq, 10)
for f in (self.module.nlargest, self.module.nsmallest):
self.assertRaises(ZeroDivisionError, f, 2, seq)
@@ -403,6 +547,8 @@ class TestErrorHandling:
def test_arg_parsing(self):
for f in (self.module.heapify, self.module.heappop,
self.module.heappush, self.module.heapreplace,
+ self.module.heapify_max, self.module.heappop_max,
+ self.module.heappush_max, self.module.heapreplace_max,
self.module.nlargest, self.module.nsmallest):
self.assertRaises((TypeError, AttributeError), f, 10)
@@ -424,6 +570,10 @@ class TestErrorHandling:
# Python version raises IndexError, C version RuntimeError
with self.assertRaises((IndexError, RuntimeError)):
self.module.heappush(heap, SideEffectLT(5, heap))
+ heap = []
+ heap.extend(SideEffectLT(i, heap) for i in range(200))
+ with self.assertRaises((IndexError, RuntimeError)):
+ self.module.heappush_max(heap, SideEffectLT(5, heap))
def test_heappop_mutating_heap(self):
heap = []
@@ -431,8 +581,12 @@ class TestErrorHandling:
# Python version raises IndexError, C version RuntimeError
with self.assertRaises((IndexError, RuntimeError)):
self.module.heappop(heap)
+ heap = []
+ heap.extend(SideEffectLT(i, heap) for i in range(200))
+ with self.assertRaises((IndexError, RuntimeError)):
+ self.module.heappop_max(heap)
- def test_comparison_operator_modifiying_heap(self):
+ def test_comparison_operator_modifying_heap(self):
# See bpo-39421: Strong references need to be taken
# when comparing objects as they can alter the heap
class EvilClass(int):
@@ -444,7 +598,7 @@ class TestErrorHandling:
self.module.heappush(heap, EvilClass(0))
self.assertRaises(IndexError, self.module.heappushpop, heap, 1)
- def test_comparison_operator_modifiying_heap_two_heaps(self):
+ def test_comparison_operator_modifying_heap_two_heaps(self):
class h(int):
def __lt__(self, o):
@@ -464,6 +618,17 @@ class TestErrorHandling:
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))
+ list1, list2 = [], []
+
+ self.module.heappush_max(list1, h(0))
+ self.module.heappush_max(list2, g(0))
+ self.module.heappush_max(list1, g(1))
+ self.module.heappush_max(list2, h(1))
+
+ self.assertRaises((IndexError, RuntimeError), self.module.heappush_max, list1, g(1))
+ self.assertRaises((IndexError, RuntimeError), self.module.heappush_max, list2, h(1))
+
+
class TestErrorHandlingPython(TestErrorHandling, TestCase):
module = py_heapq
diff --git a/Lib/test/test_json/test_tool.py b/Lib/test/test_json/test_tool.py
index ba9c42f758e..72cde3f0d6c 100644
--- a/Lib/test/test_json/test_tool.py
+++ b/Lib/test/test_json/test_tool.py
@@ -6,9 +6,11 @@ import unittest
import subprocess
from test import support
-from test.support import force_not_colorized, os_helper
+from test.support import force_colorized, force_not_colorized, os_helper
from test.support.script_helper import assert_python_ok
+from _colorize import get_theme
+
@support.requires_subprocess()
class TestMain(unittest.TestCase):
@@ -246,34 +248,39 @@ class TestMain(unittest.TestCase):
proc.communicate(b'"{}"')
self.assertEqual(proc.returncode, errno.EPIPE)
+ @force_colorized
def test_colors(self):
infile = os_helper.TESTFN
self.addCleanup(os.remove, infile)
+ t = get_theme().syntax
+ ob = "{"
+ cb = "}"
+
cases = (
- ('{}', b'{}'),
- ('[]', b'[]'),
- ('null', b'\x1b[1;36mnull\x1b[0m'),
- ('true', b'\x1b[1;36mtrue\x1b[0m'),
- ('false', b'\x1b[1;36mfalse\x1b[0m'),
- ('NaN', b'NaN'),
- ('Infinity', b'Infinity'),
- ('-Infinity', b'-Infinity'),
- ('"foo"', b'\x1b[1;32m"foo"\x1b[0m'),
- (r'" \"foo\" "', b'\x1b[1;32m" \\"foo\\" "\x1b[0m'),
- ('"α"', b'\x1b[1;32m"\\u03b1"\x1b[0m'),
- ('123', b'123'),
- ('-1.2345e+23', b'-1.2345e+23'),
+ ('{}', '{}'),
+ ('[]', '[]'),
+ ('null', f'{t.keyword}null{t.reset}'),
+ ('true', f'{t.keyword}true{t.reset}'),
+ ('false', f'{t.keyword}false{t.reset}'),
+ ('NaN', f'{t.number}NaN{t.reset}'),
+ ('Infinity', f'{t.number}Infinity{t.reset}'),
+ ('-Infinity', f'{t.number}-Infinity{t.reset}'),
+ ('"foo"', f'{t.string}"foo"{t.reset}'),
+ (r'" \"foo\" "', f'{t.string}" \\"foo\\" "{t.reset}'),
+ ('"α"', f'{t.string}"\\u03b1"{t.reset}'),
+ ('123', f'{t.number}123{t.reset}'),
+ ('-1.2345e+23', f'{t.number}-1.2345e+23{t.reset}'),
(r'{"\\": ""}',
- b'''\
-{
- \x1b[94m"\\\\"\x1b[0m: \x1b[1;32m""\x1b[0m
-}'''),
+ f'''\
+{ob}
+ {t.definition}"\\\\"{t.reset}: {t.string}""{t.reset}
+{cb}'''),
(r'{"\\\\": ""}',
- b'''\
-{
- \x1b[94m"\\\\\\\\"\x1b[0m: \x1b[1;32m""\x1b[0m
-}'''),
+ f'''\
+{ob}
+ {t.definition}"\\\\\\\\"{t.reset}: {t.string}""{t.reset}
+{cb}'''),
('''\
{
"foo": "bar",
@@ -281,30 +288,32 @@ class TestMain(unittest.TestCase):
"qux": [true, false, null],
"xyz": [NaN, -Infinity, Infinity]
}''',
- b'''\
-{
- \x1b[94m"foo"\x1b[0m: \x1b[1;32m"bar"\x1b[0m,
- \x1b[94m"baz"\x1b[0m: 1234,
- \x1b[94m"qux"\x1b[0m: [
- \x1b[1;36mtrue\x1b[0m,
- \x1b[1;36mfalse\x1b[0m,
- \x1b[1;36mnull\x1b[0m
+ f'''\
+{ob}
+ {t.definition}"foo"{t.reset}: {t.string}"bar"{t.reset},
+ {t.definition}"baz"{t.reset}: {t.number}1234{t.reset},
+ {t.definition}"qux"{t.reset}: [
+ {t.keyword}true{t.reset},
+ {t.keyword}false{t.reset},
+ {t.keyword}null{t.reset}
],
- \x1b[94m"xyz"\x1b[0m: [
- NaN,
- -Infinity,
- Infinity
+ {t.definition}"xyz"{t.reset}: [
+ {t.number}NaN{t.reset},
+ {t.number}-Infinity{t.reset},
+ {t.number}Infinity{t.reset}
]
-}'''),
+{cb}'''),
)
for input_, expected in cases:
with self.subTest(input=input_):
with open(infile, "w", encoding="utf-8") as fp:
fp.write(input_)
- _, stdout, _ = assert_python_ok('-m', self.module, infile,
- PYTHON_COLORS='1')
- stdout = stdout.replace(b'\r\n', b'\n') # normalize line endings
+ _, stdout_b, _ = assert_python_ok(
+ '-m', self.module, infile, FORCE_COLOR='1', __isolated='1'
+ )
+ stdout = stdout_b.decode()
+ stdout = stdout.replace('\r\n', '\n') # normalize line endings
stdout = stdout.strip()
self.assertEqual(stdout, expected)
diff --git a/Lib/test/test_locale.py b/Lib/test/test_locale.py
index 528ceef5281..455d2af37ef 100644
--- a/Lib/test/test_locale.py
+++ b/Lib/test/test_locale.py
@@ -1,13 +1,18 @@
from decimal import Decimal
-from test.support import verbose, is_android, linked_to_musl, os_helper
+from test.support import cpython_only, verbose, is_android, linked_to_musl, os_helper
from test.support.warnings_helper import check_warnings
-from test.support.import_helper import import_fresh_module
+from test.support.import_helper import ensure_lazy_imports, import_fresh_module
from unittest import mock
import unittest
import locale
import sys
import codecs
+class LazyImportTest(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("locale", {"re", "warnings"})
+
class BaseLocalizedTest(unittest.TestCase):
#
diff --git a/Lib/test/test_mimetypes.py b/Lib/test/test_mimetypes.py
index dad5dbde7cd..fb57d5e5544 100644
--- a/Lib/test/test_mimetypes.py
+++ b/Lib/test/test_mimetypes.py
@@ -6,7 +6,8 @@ import sys
import unittest.mock
from platform import win32_edition
from test import support
-from test.support import os_helper
+from test.support import cpython_only, force_not_colorized, os_helper
+from test.support.import_helper import ensure_lazy_imports
try:
import _winapi
@@ -435,8 +436,13 @@ class MiscTestCase(unittest.TestCase):
def test__all__(self):
support.check__all__(self, mimetypes)
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("mimetypes", {"os", "posixpath", "urllib.parse", "argparse"})
+
class CommandLineTest(unittest.TestCase):
+ @force_not_colorized
def test_parse_args(self):
args, help_text = mimetypes._parse_args("-h")
self.assertTrue(help_text.startswith("usage: "))
diff --git a/Lib/test/test_minidom.py b/Lib/test/test_minidom.py
index 6679c0a4fbe..4f25e9c2a03 100644
--- a/Lib/test/test_minidom.py
+++ b/Lib/test/test_minidom.py
@@ -102,41 +102,38 @@ class MinidomTest(unittest.TestCase):
elem = root.childNodes[0]
nelem = dom.createElement("element")
root.insertBefore(nelem, elem)
- self.confirm(len(root.childNodes) == 2
- and root.childNodes.length == 2
- and root.childNodes[0] is nelem
- and root.childNodes.item(0) is nelem
- and root.childNodes[1] is elem
- and root.childNodes.item(1) is elem
- and root.firstChild is nelem
- and root.lastChild is elem
- and root.toxml() == "<doc><element/><foo/></doc>"
- , "testInsertBefore -- node properly placed in tree")
+ self.assertEqual(len(root.childNodes), 2)
+ self.assertEqual(root.childNodes.length, 2)
+ self.assertIs(root.childNodes[0], nelem)
+ self.assertIs(root.childNodes.item(0), nelem)
+ self.assertIs(root.childNodes[1], elem)
+ self.assertIs(root.childNodes.item(1), elem)
+ self.assertIs(root.firstChild, nelem)
+ self.assertIs(root.lastChild, elem)
+ self.assertEqual(root.toxml(), "<doc><element/><foo/></doc>")
nelem = dom.createElement("element")
root.insertBefore(nelem, None)
- self.confirm(len(root.childNodes) == 3
- and root.childNodes.length == 3
- and root.childNodes[1] is elem
- and root.childNodes.item(1) is elem
- and root.childNodes[2] is nelem
- and root.childNodes.item(2) is nelem
- and root.lastChild is nelem
- and nelem.previousSibling is elem
- and root.toxml() == "<doc><element/><foo/><element/></doc>"
- , "testInsertBefore -- node properly placed in tree")
+ self.assertEqual(len(root.childNodes), 3)
+ self.assertEqual(root.childNodes.length, 3)
+ self.assertIs(root.childNodes[1], elem)
+ self.assertIs(root.childNodes.item(1), elem)
+ self.assertIs(root.childNodes[2], nelem)
+ self.assertIs(root.childNodes.item(2), nelem)
+ self.assertIs(root.lastChild, nelem)
+ self.assertIs(nelem.previousSibling, elem)
+ self.assertEqual(root.toxml(), "<doc><element/><foo/><element/></doc>")
nelem2 = dom.createElement("bar")
root.insertBefore(nelem2, nelem)
- self.confirm(len(root.childNodes) == 4
- and root.childNodes.length == 4
- and root.childNodes[2] is nelem2
- and root.childNodes.item(2) is nelem2
- and root.childNodes[3] is nelem
- and root.childNodes.item(3) is nelem
- and nelem2.nextSibling is nelem
- and nelem.previousSibling is nelem2
- and root.toxml() ==
- "<doc><element/><foo/><bar/><element/></doc>"
- , "testInsertBefore -- node properly placed in tree")
+ self.assertEqual(len(root.childNodes), 4)
+ self.assertEqual(root.childNodes.length, 4)
+ self.assertIs(root.childNodes[2], nelem2)
+ self.assertIs(root.childNodes.item(2), nelem2)
+ self.assertIs(root.childNodes[3], nelem)
+ self.assertIs(root.childNodes.item(3), nelem)
+ self.assertIs(nelem2.nextSibling, nelem)
+ self.assertIs(nelem.previousSibling, nelem2)
+ self.assertEqual(root.toxml(),
+ "<doc><element/><foo/><bar/><element/></doc>")
dom.unlink()
def _create_fragment_test_nodes(self):
@@ -342,8 +339,8 @@ class MinidomTest(unittest.TestCase):
self.assertRaises(xml.dom.NotFoundErr, child.removeAttributeNode,
None)
self.assertIs(node, child.removeAttributeNode(node))
- self.confirm(len(child.attributes) == 0
- and child.getAttributeNode("spam") is None)
+ self.assertEqual(len(child.attributes), 0)
+ self.assertIsNone(child.getAttributeNode("spam"))
dom2 = Document()
child2 = dom2.appendChild(dom2.createElement("foo"))
node2 = child2.getAttributeNode("spam")
@@ -366,33 +363,34 @@ class MinidomTest(unittest.TestCase):
# Set this attribute to be an ID and make sure that doesn't change
# when changing the value:
el.setIdAttribute("spam")
- self.confirm(len(el.attributes) == 1
- and el.attributes["spam"].value == "bam"
- and el.attributes["spam"].nodeValue == "bam"
- and el.getAttribute("spam") == "bam"
- and el.getAttributeNode("spam").isId)
+ self.assertEqual(len(el.attributes), 1)
+ self.assertEqual(el.attributes["spam"].value, "bam")
+ self.assertEqual(el.attributes["spam"].nodeValue, "bam")
+ self.assertEqual(el.getAttribute("spam"), "bam")
+ self.assertTrue(el.getAttributeNode("spam").isId)
el.attributes["spam"] = "ham"
- self.confirm(len(el.attributes) == 1
- and el.attributes["spam"].value == "ham"
- and el.attributes["spam"].nodeValue == "ham"
- and el.getAttribute("spam") == "ham"
- and el.attributes["spam"].isId)
+ self.assertEqual(len(el.attributes), 1)
+ self.assertEqual(el.attributes["spam"].value, "ham")
+ self.assertEqual(el.attributes["spam"].nodeValue, "ham")
+ self.assertEqual(el.getAttribute("spam"), "ham")
+ self.assertTrue(el.attributes["spam"].isId)
el.setAttribute("spam2", "bam")
- self.confirm(len(el.attributes) == 2
- and el.attributes["spam"].value == "ham"
- and el.attributes["spam"].nodeValue == "ham"
- and el.getAttribute("spam") == "ham"
- and el.attributes["spam2"].value == "bam"
- and el.attributes["spam2"].nodeValue == "bam"
- and el.getAttribute("spam2") == "bam")
+ self.assertEqual(len(el.attributes), 2)
+ self.assertEqual(el.attributes["spam"].value, "ham")
+ self.assertEqual(el.attributes["spam"].nodeValue, "ham")
+ self.assertEqual(el.getAttribute("spam"), "ham")
+ self.assertEqual(el.attributes["spam2"].value, "bam")
+ self.assertEqual(el.attributes["spam2"].nodeValue, "bam")
+ self.assertEqual(el.getAttribute("spam2"), "bam")
el.attributes["spam2"] = "bam2"
- self.confirm(len(el.attributes) == 2
- and el.attributes["spam"].value == "ham"
- and el.attributes["spam"].nodeValue == "ham"
- and el.getAttribute("spam") == "ham"
- and el.attributes["spam2"].value == "bam2"
- and el.attributes["spam2"].nodeValue == "bam2"
- and el.getAttribute("spam2") == "bam2")
+
+ self.assertEqual(len(el.attributes), 2)
+ self.assertEqual(el.attributes["spam"].value, "ham")
+ self.assertEqual(el.attributes["spam"].nodeValue, "ham")
+ self.assertEqual(el.getAttribute("spam"), "ham")
+ self.assertEqual(el.attributes["spam2"].value, "bam2")
+ self.assertEqual(el.attributes["spam2"].nodeValue, "bam2")
+ self.assertEqual(el.getAttribute("spam2"), "bam2")
dom.unlink()
def testGetAttrList(self):
@@ -448,12 +446,12 @@ class MinidomTest(unittest.TestCase):
dom = parseString(d)
elems = dom.getElementsByTagNameNS("http://pyxml.sf.net/minidom",
"myelem")
- self.confirm(len(elems) == 1
- and elems[0].namespaceURI == "http://pyxml.sf.net/minidom"
- and elems[0].localName == "myelem"
- and elems[0].prefix == "minidom"
- and elems[0].tagName == "minidom:myelem"
- and elems[0].nodeName == "minidom:myelem")
+ self.assertEqual(len(elems), 1)
+ self.assertEqual(elems[0].namespaceURI, "http://pyxml.sf.net/minidom")
+ self.assertEqual(elems[0].localName, "myelem")
+ self.assertEqual(elems[0].prefix, "minidom")
+ self.assertEqual(elems[0].tagName, "minidom:myelem")
+ self.assertEqual(elems[0].nodeName, "minidom:myelem")
dom.unlink()
def get_empty_nodelist_from_elements_by_tagName_ns_helper(self, doc, nsuri,
@@ -602,17 +600,17 @@ class MinidomTest(unittest.TestCase):
def testProcessingInstruction(self):
dom = parseString('<e><?mypi \t\n data \t\n ?></e>')
pi = dom.documentElement.firstChild
- self.confirm(pi.target == "mypi"
- and pi.data == "data \t\n "
- and pi.nodeName == "mypi"
- and pi.nodeType == Node.PROCESSING_INSTRUCTION_NODE
- and pi.attributes is None
- and not pi.hasChildNodes()
- and len(pi.childNodes) == 0
- and pi.firstChild is None
- and pi.lastChild is None
- and pi.localName is None
- and pi.namespaceURI == xml.dom.EMPTY_NAMESPACE)
+ self.assertEqual(pi.target, "mypi")
+ self.assertEqual(pi.data, "data \t\n ")
+ self.assertEqual(pi.nodeName, "mypi")
+ self.assertEqual(pi.nodeType, Node.PROCESSING_INSTRUCTION_NODE)
+ self.assertIsNone(pi.attributes)
+ self.assertFalse(pi.hasChildNodes())
+ self.assertEqual(len(pi.childNodes), 0)
+ self.assertIsNone(pi.firstChild)
+ self.assertIsNone(pi.lastChild)
+ self.assertIsNone(pi.localName)
+ self.assertEqual(pi.namespaceURI, xml.dom.EMPTY_NAMESPACE)
def testProcessingInstructionRepr(self):
dom = parseString('<e><?mypi \t\n data \t\n ?></e>')
@@ -718,19 +716,16 @@ class MinidomTest(unittest.TestCase):
keys2 = list(attrs2.keys())
keys1.sort()
keys2.sort()
- self.assertEqual(keys1, keys2,
- "clone of element has same attribute keys")
+ self.assertEqual(keys1, keys2)
for i in range(len(keys1)):
a1 = attrs1.item(i)
a2 = attrs2.item(i)
- self.confirm(a1 is not a2
- and a1.value == a2.value
- and a1.nodeValue == a2.nodeValue
- and a1.namespaceURI == a2.namespaceURI
- and a1.localName == a2.localName
- , "clone of attribute node has proper attribute values")
- self.assertIs(a2.ownerElement, e2,
- "clone of attribute node correctly owned")
+ self.assertIsNot(a1, a2)
+ self.assertEqual(a1.value, a2.value)
+ self.assertEqual(a1.nodeValue, a2.nodeValue)
+ self.assertEqual(a1.namespaceURI, a2.namespaceURI)
+ self.assertEqual(a1.localName, a2.localName)
+ self.assertIs(a2.ownerElement, e2)
def _setupCloneElement(self, deep):
dom = parseString("<doc attr='value'><foo/></doc>")
@@ -746,20 +741,19 @@ class MinidomTest(unittest.TestCase):
def testCloneElementShallow(self):
dom, clone = self._setupCloneElement(0)
- self.confirm(len(clone.childNodes) == 0
- and clone.childNodes.length == 0
- and clone.parentNode is None
- and clone.toxml() == '<doc attr="value"/>'
- , "testCloneElementShallow")
+ self.assertEqual(len(clone.childNodes), 0)
+ self.assertEqual(clone.childNodes.length, 0)
+ self.assertIsNone(clone.parentNode)
+ self.assertEqual(clone.toxml(), '<doc attr="value"/>')
+
dom.unlink()
def testCloneElementDeep(self):
dom, clone = self._setupCloneElement(1)
- self.confirm(len(clone.childNodes) == 1
- and clone.childNodes.length == 1
- and clone.parentNode is None
- and clone.toxml() == '<doc attr="value"><foo/></doc>'
- , "testCloneElementDeep")
+ self.assertEqual(len(clone.childNodes), 1)
+ self.assertEqual(clone.childNodes.length, 1)
+ self.assertIsNone(clone.parentNode)
+ self.assertEqual(clone.toxml(), '<doc attr="value"><foo/></doc>')
dom.unlink()
def testCloneDocumentShallow(self):
diff --git a/Lib/test/test_optparse.py b/Lib/test/test_optparse.py
index 8655a0537a5..e6ffd2b0ffe 100644
--- a/Lib/test/test_optparse.py
+++ b/Lib/test/test_optparse.py
@@ -14,8 +14,9 @@ import unittest
from io import StringIO
from test import support
-from test.support import os_helper
+from test.support import cpython_only, os_helper
from test.support.i18n_helper import TestTranslationsBase, update_translation_snapshots
+from test.support.import_helper import ensure_lazy_imports
import optparse
from optparse import make_option, Option, \
@@ -1655,6 +1656,10 @@ class MiscTestCase(unittest.TestCase):
not_exported = {'check_builtin', 'AmbiguousOptionError', 'NO_DEFAULT'}
support.check__all__(self, optparse, not_exported=not_exported)
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("optparse", {"textwrap"})
+
class TestTranslations(TestTranslationsBase):
def test_translations(self):
diff --git a/Lib/test/test_pathlib/test_pathlib.py b/Lib/test/test_pathlib/test_pathlib.py
index 41a79d0dceb..8a313cc4292 100644
--- a/Lib/test/test_pathlib/test_pathlib.py
+++ b/Lib/test/test_pathlib/test_pathlib.py
@@ -16,6 +16,7 @@ from unittest import mock
from urllib.request import pathname2url
from test.support import import_helper
+from test.support import cpython_only
from test.support import is_emscripten, is_wasi
from test.support import infinite_recursion
from test.support import os_helper
@@ -80,6 +81,12 @@ class UnsupportedOperationTest(unittest.TestCase):
self.assertTrue(isinstance(pathlib.UnsupportedOperation(), NotImplementedError))
+class LazyImportTest(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ import_helper.ensure_lazy_imports("pathlib", {"shutil"})
+
+
#
# Tests for the pure classes.
#
@@ -3290,7 +3297,6 @@ class PathTest(PurePathTest):
self.assertEqual(P.from_uri('file:////foo/bar'), P('//foo/bar'))
self.assertEqual(P.from_uri('file://localhost/foo/bar'), P('/foo/bar'))
if not is_wasi:
- self.assertEqual(P.from_uri('file://127.0.0.1/foo/bar'), P('/foo/bar'))
self.assertEqual(P.from_uri(f'file://{socket.gethostname()}/foo/bar'),
P('/foo/bar'))
self.assertRaises(ValueError, P.from_uri, 'foo/bar')
diff --git a/Lib/test/test_pdb.py b/Lib/test/test_pdb.py
index be365a5a3dd..54797d7898f 100644
--- a/Lib/test/test_pdb.py
+++ b/Lib/test/test_pdb.py
@@ -1,7 +1,9 @@
# A test suite for pdb; not very comprehensive at the moment.
+import _colorize
import doctest
import gc
+import io
import os
import pdb
import sys
@@ -18,7 +20,7 @@ from asyncio.events import _set_event_loop_policy
from contextlib import ExitStack, redirect_stdout
from io import StringIO
from test import support
-from test.support import force_not_colorized, has_socket_support, os_helper
+from test.support import has_socket_support, os_helper
from test.support.import_helper import import_module
from test.support.pty_helper import run_pty, FakeInput
from test.support.script_helper import kill_python
@@ -3446,6 +3448,7 @@ def test_pdb_issue_gh_65052():
"""
+@support.force_not_colorized_test_class
@support.requires_subprocess()
class PdbTestCase(unittest.TestCase):
def tearDown(self):
@@ -3740,7 +3743,6 @@ def bœr():
self.assertNotIn(b'Error', stdout,
"Got an error running test script under PDB")
- @force_not_colorized
def test_issue16180(self):
# A syntax error in the debuggee.
script = "def f: pass\n"
@@ -3754,7 +3756,6 @@ def bœr():
'Fail to handle a syntax error in the debuggee.'
.format(expected, stderr))
- @force_not_colorized
def test_issue84583(self):
# A syntax error from ast.literal_eval should not make pdb exit.
script = "import ast; ast.literal_eval('')\n"
@@ -4688,6 +4689,40 @@ class PdbTestInline(unittest.TestCase):
self.assertIn("42", stdout)
+@support.force_colorized_test_class
+class PdbTestColorize(unittest.TestCase):
+ def setUp(self):
+ self._original_can_colorize = _colorize.can_colorize
+ # Force colorize to be enabled because we are sending data
+ # to a StringIO
+ _colorize.can_colorize = lambda *args, **kwargs: True
+
+ def tearDown(self):
+ _colorize.can_colorize = self._original_can_colorize
+
+ def test_code_display(self):
+ output = io.StringIO()
+ p = pdb.Pdb(stdout=output, colorize=True)
+ p.set_trace(commands=['ll', 'c'])
+ self.assertIn("\x1b", output.getvalue())
+
+ output = io.StringIO()
+ p = pdb.Pdb(stdout=output, colorize=False)
+ p.set_trace(commands=['ll', 'c'])
+ self.assertNotIn("\x1b", output.getvalue())
+
+ output = io.StringIO()
+ p = pdb.Pdb(stdout=output)
+ p.set_trace(commands=['ll', 'c'])
+ self.assertNotIn("\x1b", output.getvalue())
+
+ def test_stack_entry(self):
+ output = io.StringIO()
+ p = pdb.Pdb(stdout=output, colorize=True)
+ p.set_trace(commands=['w', 'c'])
+ self.assertIn("\x1b", output.getvalue())
+
+
@support.force_not_colorized_test_class
@support.requires_subprocess()
class TestREPLSession(unittest.TestCase):
@@ -4711,6 +4746,7 @@ class TestREPLSession(unittest.TestCase):
self.assertEqual(p.returncode, 0)
+@support.force_not_colorized_test_class
@support.requires_subprocess()
class PdbTestReadline(unittest.TestCase):
def setUpClass():
@@ -4812,14 +4848,35 @@ class PdbTestReadline(unittest.TestCase):
self.assertIn(b'I love Python', output)
+ def test_multiline_auto_indent(self):
+ script = textwrap.dedent("""
+ import pdb; pdb.Pdb().set_trace()
+ """)
+
+ input = b"def f(x):\n"
+ input += b"if x > 0:\n"
+ input += b"x += 1\n"
+ input += b"return x\n"
+ # We need to do backspaces to remove the auto-indentation
+ input += b"\x08\x08\x08\x08else:\n"
+ input += b"return -x\n"
+ input += b"\n"
+ input += b"f(-21-21)\n"
+ input += b"c\n"
+
+ output = run_pty(script, input)
+
+ self.assertIn(b'42', output)
+
def test_multiline_completion(self):
script = textwrap.dedent("""
import pdb; pdb.Pdb().set_trace()
""")
input = b"def func():\n"
- # Complete: \treturn 40 + 2
- input += b"\tret\t 40 + 2\n"
+ # Auto-indent
+ # Complete: return 40 + 2
+ input += b"ret\t 40 + 2\n"
input += b"\n"
# Complete: func()
input += b"fun\t()\n"
@@ -4839,12 +4896,13 @@ class PdbTestReadline(unittest.TestCase):
# if the completion is not working as expected
input = textwrap.dedent("""\
def func():
- \ta = 1
- \ta += 1
- \ta += 1
- \tif a > 0:
- a += 1
- \t\treturn a
+ a = 1
+ \x08\ta += 1
+ \x08\x08\ta += 1
+ \x08\x08\x08\ta += 1
+ \x08\x08\x08\x08\tif a > 0:
+ a += 1
+ \x08\x08\x08\x08return a
func()
c
@@ -4852,7 +4910,7 @@ class PdbTestReadline(unittest.TestCase):
output = run_pty(script, input)
- self.assertIn(b'4', output)
+ self.assertIn(b'5', output)
self.assertNotIn(b'Error', output)
def test_interact_completion(self):
diff --git a/Lib/test/test_peepholer.py b/Lib/test/test_peepholer.py
index 565e42b04a6..47f51f1979f 100644
--- a/Lib/test/test_peepholer.py
+++ b/Lib/test/test_peepholer.py
@@ -1,4 +1,5 @@
import dis
+import gc
from itertools import combinations, product
import opcode
import sys
@@ -2472,6 +2473,13 @@ class OptimizeLoadFastTestCase(DirectCfgOptimizerTests):
]
self.check(insts, insts)
+ insts = [
+ ("LOAD_FAST", 0, 1),
+ ("DELETE_FAST", 0, 2),
+ ("POP_TOP", None, 3),
+ ]
+ self.check(insts, insts)
+
def test_unoptimized_if_aliased(self):
insts = [
("LOAD_FAST", 0, 1),
@@ -2606,6 +2614,22 @@ class OptimizeLoadFastTestCase(DirectCfgOptimizerTests):
]
self.cfg_optimization_test(insts, expected, consts=[None])
+ def test_del_in_finally(self):
+ # This loads `obj` onto the stack, executes `del obj`, then returns the
+ # `obj` from the stack. See gh-133371 for more details.
+ def create_obj():
+ obj = [42]
+ try:
+ return obj
+ finally:
+ del obj
+
+ obj = create_obj()
+ # The crash in the linked issue happens while running GC during
+ # interpreter finalization, so run it here manually.
+ gc.collect()
+ self.assertEqual(obj, [42])
+
if __name__ == "__main__":
diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py
index 296d4b882e1..742ca8de1be 100644
--- a/Lib/test/test_pickle.py
+++ b/Lib/test/test_pickle.py
@@ -15,7 +15,8 @@ from textwrap import dedent
import doctest
import unittest
from test import support
-from test.support import import_helper, os_helper
+from test.support import cpython_only, import_helper, os_helper
+from test.support.import_helper import ensure_lazy_imports
from test.pickletester import AbstractHookTests
from test.pickletester import AbstractUnpickleTests
@@ -36,6 +37,12 @@ except ImportError:
has_c_implementation = False
+class LazyImportTest(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("pickle", {"re"})
+
+
class PyPickleTests(AbstractPickleModuleTests, unittest.TestCase):
dump = staticmethod(pickle._dump)
dumps = staticmethod(pickle._dumps)
@@ -745,6 +752,7 @@ class CommandLineTest(unittest.TestCase):
expect = self.text_normalize(expect)
self.assertListEqual(res.splitlines(), expect.splitlines())
+ @support.force_not_colorized
def test_unknown_flag(self):
stderr = io.StringIO()
with self.assertRaises(SystemExit):
diff --git a/Lib/test/test_platform.py b/Lib/test/test_platform.py
index b90edc05e04..818e807dd3a 100644
--- a/Lib/test/test_platform.py
+++ b/Lib/test/test_platform.py
@@ -794,6 +794,7 @@ class CommandLineTest(unittest.TestCase):
self.invoke_platform(*flags)
obj.assert_called_once_with(aliased, terse)
+ @support.force_not_colorized
def test_help(self):
output = io.StringIO()
diff --git a/Lib/test/test_posix.py b/Lib/test/test_posix.py
index b6a07f214fa..0817d0a87a3 100644
--- a/Lib/test/test_posix.py
+++ b/Lib/test/test_posix.py
@@ -1521,8 +1521,8 @@ class PosixTester(unittest.TestCase):
self.assertEqual(cm.exception.errno, errno.EINVAL)
os.close(os.pidfd_open(os.getpid(), 0))
- @unittest.skipUnless(hasattr(os, "link"), "test needs os.link()")
- @support.skip_android_selinux('hard links to symbolic links')
+ @os_helper.skip_unless_hardlink
+ @os_helper.skip_unless_symlink
def test_link_follow_symlinks(self):
default_follow = sys.platform.startswith(
('darwin', 'freebsd', 'netbsd', 'openbsd', 'dragonfly', 'sunos5'))
diff --git a/Lib/test/test_pprint.py b/Lib/test/test_pprint.py
index dfbc2a06e73..f68996f72b1 100644
--- a/Lib/test/test_pprint.py
+++ b/Lib/test/test_pprint.py
@@ -11,6 +11,9 @@ import re
import types
import unittest
+from test.support import cpython_only
+from test.support.import_helper import ensure_lazy_imports
+
# list, tuple and dict subclasses that do or don't overwrite __repr__
class list2(list):
pass
@@ -129,6 +132,10 @@ class QueryTestCase(unittest.TestCase):
self.b = list(range(200))
self.a[-12] = self.b
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("pprint", {"dataclasses", "re"})
+
def test_init(self):
pp = pprint.PrettyPrinter()
pp = pprint.PrettyPrinter(indent=4, width=40, depth=5,
diff --git a/Lib/test/test_pstats.py b/Lib/test/test_pstats.py
index d5a5a9738c2..a26a8c1d522 100644
--- a/Lib/test/test_pstats.py
+++ b/Lib/test/test_pstats.py
@@ -1,6 +1,7 @@
import unittest
from test import support
+from test.support.import_helper import ensure_lazy_imports
from io import StringIO
from pstats import SortKey
from enum import StrEnum, _test_simple_enum
@@ -10,6 +11,12 @@ import pstats
import tempfile
import cProfile
+class LazyImportTest(unittest.TestCase):
+ @support.cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("pstats", {"typing"})
+
+
class AddCallersTestCase(unittest.TestCase):
"""Tests for pstats.add_callers helper."""
diff --git a/Lib/test/test_pyrepl/support.py b/Lib/test/test_pyrepl/support.py
index 3692e164cb9..4f7f9d77933 100644
--- a/Lib/test/test_pyrepl/support.py
+++ b/Lib/test/test_pyrepl/support.py
@@ -113,9 +113,6 @@ handle_events_narrow_console = partial(
prepare_console=partial(prepare_console, width=10),
)
-reader_no_colors = partial(prepare_reader, can_colorize=False)
-reader_force_colors = partial(prepare_reader, can_colorize=True)
-
class FakeConsole(Console):
def __init__(self, events, encoding="utf-8") -> None:
diff --git a/Lib/test/test_pyrepl/test_eventqueue.py b/Lib/test/test_pyrepl/test_eventqueue.py
index afb55710342..edfe6ac4748 100644
--- a/Lib/test/test_pyrepl/test_eventqueue.py
+++ b/Lib/test/test_pyrepl/test_eventqueue.py
@@ -53,7 +53,7 @@ class EventQueueTestBase:
mock_keymap.compile_keymap.return_value = {"a": "b"}
eq = self.make_eventqueue()
eq.keymap = {b"a": "b"}
- eq.push("a")
+ eq.push(b"a")
mock_keymap.compile_keymap.assert_called()
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "b")
@@ -63,7 +63,7 @@ class EventQueueTestBase:
mock_keymap.compile_keymap.return_value = {"a": "b"}
eq = self.make_eventqueue()
eq.keymap = {b"c": "d"}
- eq.push("a")
+ eq.push(b"a")
mock_keymap.compile_keymap.assert_called()
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "a")
@@ -73,13 +73,13 @@ class EventQueueTestBase:
mock_keymap.compile_keymap.return_value = {"a": "b"}
eq = self.make_eventqueue()
eq.keymap = {b"a": {b"b": "c"}}
- eq.push("a")
+ eq.push(b"a")
mock_keymap.compile_keymap.assert_called()
self.assertTrue(eq.empty())
- eq.push("b")
+ eq.push(b"b")
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "c")
- eq.push("d")
+ eq.push(b"d")
self.assertEqual(eq.events[1].evt, "key")
self.assertEqual(eq.events[1].data, "d")
@@ -88,32 +88,32 @@ class EventQueueTestBase:
mock_keymap.compile_keymap.return_value = {"a": "b"}
eq = self.make_eventqueue()
eq.keymap = {b"a": {b"b": "c"}}
- eq.push("a")
+ eq.push(b"a")
mock_keymap.compile_keymap.assert_called()
self.assertTrue(eq.empty())
eq.flush_buf()
- eq.push("\033")
+ eq.push(b"\033")
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "\033")
- eq.push("b")
+ eq.push(b"b")
self.assertEqual(eq.events[1].evt, "key")
self.assertEqual(eq.events[1].data, "b")
def test_push_special_key(self):
eq = self.make_eventqueue()
eq.keymap = {}
- eq.push("\x1b")
- eq.push("[")
- eq.push("A")
+ eq.push(b"\x1b")
+ eq.push(b"[")
+ eq.push(b"A")
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "\x1b")
def test_push_unrecognized_escape_sequence(self):
eq = self.make_eventqueue()
eq.keymap = {}
- eq.push("\x1b")
- eq.push("[")
- eq.push("Z")
+ eq.push(b"\x1b")
+ eq.push(b"[")
+ eq.push(b"Z")
self.assertEqual(len(eq.events), 3)
self.assertEqual(eq.events[0].evt, "key")
self.assertEqual(eq.events[0].data, "\x1b")
@@ -122,12 +122,54 @@ class EventQueueTestBase:
self.assertEqual(eq.events[2].evt, "key")
self.assertEqual(eq.events[2].data, "Z")
- def test_push_unicode_character(self):
+ def test_push_unicode_character_as_str(self):
eq = self.make_eventqueue()
eq.keymap = {}
- eq.push("ч")
- self.assertEqual(eq.events[0].evt, "key")
- self.assertEqual(eq.events[0].data, "ч")
+ with self.assertRaises(AssertionError):
+ eq.push("ч")
+ with self.assertRaises(AssertionError):
+ eq.push("ñ")
+
+ def test_push_unicode_character_two_bytes(self):
+ eq = self.make_eventqueue()
+ eq.keymap = {}
+
+ encoded = "ч".encode(eq.encoding, "replace")
+ self.assertEqual(len(encoded), 2)
+
+ eq.push(encoded[0])
+ e = eq.get()
+ self.assertIsNone(e)
+
+ eq.push(encoded[1])
+ e = eq.get()
+ self.assertEqual(e.evt, "key")
+ self.assertEqual(e.data, "ч")
+
+ def test_push_single_chars_and_unicode_character_as_str(self):
+ eq = self.make_eventqueue()
+ eq.keymap = {}
+
+ def _event(evt, data, raw=None):
+ r = raw if raw is not None else data.encode(eq.encoding)
+ e = Event(evt, data, r)
+ return e
+
+ def _push(keys):
+ for k in keys:
+ eq.push(k)
+
+ self.assertIsInstance("ñ", str)
+
+ # If an exception happens during push, the existing events must be
+ # preserved and we can continue to push.
+ _push(b"b")
+ with self.assertRaises(AssertionError):
+ _push("ñ")
+ _push(b"a")
+
+ self.assertEqual(eq.get(), _event("key", "b"))
+ self.assertEqual(eq.get(), _event("key", "a"))
@unittest.skipIf(support.MS_WINDOWS, "No Unix event queue on Windows")
diff --git a/Lib/test/test_pyrepl/test_reader.py b/Lib/test/test_pyrepl/test_reader.py
index 8d7fcf538d2..4ee320a5a4d 100644
--- a/Lib/test/test_pyrepl/test_reader.py
+++ b/Lib/test/test_pyrepl/test_reader.py
@@ -4,20 +4,21 @@ import rlcompleter
from textwrap import dedent
from unittest import TestCase
from unittest.mock import MagicMock
+from test.support import force_colorized_test_class, force_not_colorized_test_class
from .support import handle_all_events, handle_events_narrow_console
from .support import ScreenEqualMixin, code_to_events
-from .support import prepare_console, reader_force_colors
-from .support import reader_no_colors as prepare_reader
+from .support import prepare_reader, prepare_console
from _pyrepl.console import Event
from _pyrepl.reader import Reader
-from _colorize import theme
+from _colorize import default_theme
-overrides = {"RESET": "z", "SOFT_KEYWORD": "K"}
-colors = {overrides.get(k, k[0].lower()): v for k, v in theme.items()}
+overrides = {"reset": "z", "soft_keyword": "K"}
+colors = {overrides.get(k, k[0].lower()): v for k, v in default_theme.syntax.items()}
+@force_not_colorized_test_class
class TestReader(ScreenEqualMixin, TestCase):
def test_calc_screen_wrap_simple(self):
events = code_to_events(10 * "a")
@@ -127,13 +128,6 @@ class TestReader(ScreenEqualMixin, TestCase):
reader.setpos_from_xy(0, 0)
self.assertEqual(reader.pos, 0)
- def test_control_characters(self):
- code = 'flag = "🏳️‍🌈"'
- events = code_to_events(code)
- reader, _ = handle_all_events(events, prepare_reader=reader_force_colors)
- self.assert_screen_equal(reader, 'flag = "🏳️\\u200d🌈"', clean=True)
- self.assert_screen_equal(reader, 'flag {o}={z} {s}"🏳️\\u200d🌈"{z}'.format(**colors))
-
def test_setpos_from_xy_multiple_lines(self):
# fmt: off
code = (
@@ -364,6 +358,8 @@ class TestReader(ScreenEqualMixin, TestCase):
reader.setpos_from_xy(8, 0)
self.assertEqual(reader.pos, 7)
+@force_colorized_test_class
+class TestReaderInColor(ScreenEqualMixin, TestCase):
def test_syntax_highlighting_basic(self):
code = dedent(
"""\
@@ -403,7 +399,7 @@ class TestReader(ScreenEqualMixin, TestCase):
)
expected_sync = expected.format(a="", **colors)
events = code_to_events(code)
- reader, _ = handle_all_events(events, prepare_reader=reader_force_colors)
+ reader, _ = handle_all_events(events)
self.assert_screen_equal(reader, code, clean=True)
self.assert_screen_equal(reader, expected_sync)
self.assertEqual(reader.pos, 2**7 + 2**8)
@@ -416,7 +412,7 @@ class TestReader(ScreenEqualMixin, TestCase):
[Event(evt="key", data="up", raw=bytearray(b"\x1bOA"))] * 13,
code_to_events("async "),
)
- reader, _ = handle_all_events(more_events, prepare_reader=reader_force_colors)
+ reader, _ = handle_all_events(more_events)
self.assert_screen_equal(reader, expected_async)
self.assertEqual(reader.pos, 21)
self.assertEqual(reader.cxy, (6, 1))
@@ -433,7 +429,7 @@ class TestReader(ScreenEqualMixin, TestCase):
"""
).format(**colors)
events = code_to_events(code)
- reader, _ = handle_all_events(events, prepare_reader=reader_force_colors)
+ reader, _ = handle_all_events(events)
self.assert_screen_equal(reader, code, clean=True)
self.assert_screen_equal(reader, expected)
@@ -451,7 +447,7 @@ class TestReader(ScreenEqualMixin, TestCase):
"""
).format(**colors)
events = code_to_events(code)
- reader, _ = handle_all_events(events, prepare_reader=reader_force_colors)
+ reader, _ = handle_all_events(events)
self.assert_screen_equal(reader, code, clean=True)
self.assert_screen_equal(reader, expected)
@@ -471,7 +467,7 @@ class TestReader(ScreenEqualMixin, TestCase):
"""
).format(**colors)
events = code_to_events(code)
- reader, _ = handle_all_events(events, prepare_reader=reader_force_colors)
+ reader, _ = handle_all_events(events)
self.assert_screen_equal(reader, code, clean=True)
self.assert_screen_equal(reader, expected)
@@ -497,6 +493,13 @@ class TestReader(ScreenEqualMixin, TestCase):
"""
).format(OB="{", CB="}", **colors)
events = code_to_events(code)
- reader, _ = handle_all_events(events, prepare_reader=reader_force_colors)
+ reader, _ = handle_all_events(events)
self.assert_screen_equal(reader, code, clean=True)
self.assert_screen_equal(reader, expected)
+
+ def test_control_characters(self):
+ code = 'flag = "🏳️‍🌈"'
+ events = code_to_events(code)
+ reader, _ = handle_all_events(events)
+ self.assert_screen_equal(reader, 'flag = "🏳️\\u200d🌈"', clean=True)
+ self.assert_screen_equal(reader, 'flag {o}={z} {s}"🏳️\\u200d🌈"{z}'.format(**colors))
diff --git a/Lib/test/test_pyrepl/test_unix_console.py b/Lib/test/test_pyrepl/test_unix_console.py
index 7acb84a94f7..c447b310c49 100644
--- a/Lib/test/test_pyrepl/test_unix_console.py
+++ b/Lib/test/test_pyrepl/test_unix_console.py
@@ -3,11 +3,12 @@ import os
import sys
import unittest
from functools import partial
-from test.support import os_helper
+from test.support import os_helper, force_not_colorized_test_class
+
from unittest import TestCase
from unittest.mock import MagicMock, call, patch, ANY
-from .support import handle_all_events, code_to_events, reader_no_colors
+from .support import handle_all_events, code_to_events
try:
from _pyrepl.console import Event
@@ -33,12 +34,10 @@ def unix_console(events, **kwargs):
handle_events_unix_console = partial(
handle_all_events,
- prepare_reader=reader_no_colors,
prepare_console=unix_console,
)
handle_events_narrow_unix_console = partial(
handle_all_events,
- prepare_reader=reader_no_colors,
prepare_console=partial(unix_console, width=5),
)
handle_events_short_unix_console = partial(
@@ -120,6 +119,7 @@ TERM_CAPABILITIES = {
)
@patch("termios.tcsetattr", lambda a, b, c: None)
@patch("os.write")
+@force_not_colorized_test_class
class TestConsole(TestCase):
def test_simple_addition(self, _os_write):
code = "12+34"
@@ -255,9 +255,7 @@ class TestConsole(TestCase):
# fmt: on
events = itertools.chain(code_to_events(code))
- reader, console = handle_events_short_unix_console(
- events, prepare_reader=reader_no_colors
- )
+ reader, console = handle_events_short_unix_console(events)
console.height = 2
console.getheightwidth = MagicMock(lambda _: (2, 80))
diff --git a/Lib/test/test_pyrepl/test_windows_console.py b/Lib/test/test_pyrepl/test_windows_console.py
index e95fec46a85..e7bab226b31 100644
--- a/Lib/test/test_pyrepl/test_windows_console.py
+++ b/Lib/test/test_pyrepl/test_windows_console.py
@@ -7,12 +7,13 @@ if sys.platform != "win32":
import itertools
from functools import partial
+from test.support import force_not_colorized_test_class
from typing import Iterable
from unittest import TestCase
from unittest.mock import MagicMock, call
from .support import handle_all_events, code_to_events
-from .support import reader_no_colors as default_prepare_reader
+from .support import prepare_reader as default_prepare_reader
try:
from _pyrepl.console import Event, Console
@@ -24,10 +25,12 @@ try:
MOVE_DOWN,
ERASE_IN_LINE,
)
+ import _pyrepl.windows_console as wc
except ImportError:
pass
+@force_not_colorized_test_class
class WindowsConsoleTests(TestCase):
def console(self, events, **kwargs) -> Console:
console = WindowsConsole()
@@ -350,8 +353,226 @@ class WindowsConsoleTests(TestCase):
Event(evt="key", data='\x1a', raw=bytearray(b'\x1a')),
],
)
- reader, _ = self.handle_events_narrow(events)
+ reader, con = self.handle_events_narrow(events)
self.assertEqual(reader.cxy, (2, 3))
+ con.restore()
+
+
+class WindowsConsoleGetEventTests(TestCase):
+ # Virtual-Key Codes: https://learn.microsoft.com/en-us/windows/win32/inputdev/virtual-key-codes
+ VK_BACK = 0x08
+ VK_RETURN = 0x0D
+ VK_LEFT = 0x25
+ VK_7 = 0x37
+ VK_M = 0x4D
+ # Used for miscellaneous characters; it can vary by keyboard.
+ # For the US standard keyboard, the '" key.
+ # For the German keyboard, the Ä key.
+ VK_OEM_7 = 0xDE
+
+ # State of control keys: https://learn.microsoft.com/en-us/windows/console/key-event-record-str
+ RIGHT_ALT_PRESSED = 0x0001
+ RIGHT_CTRL_PRESSED = 0x0004
+ LEFT_ALT_PRESSED = 0x0002
+ LEFT_CTRL_PRESSED = 0x0008
+ ENHANCED_KEY = 0x0100
+ SHIFT_PRESSED = 0x0010
+
+
+ def get_event(self, input_records, **kwargs) -> Console:
+ self.console = WindowsConsole(encoding='utf-8')
+ self.mock = MagicMock(side_effect=input_records)
+ self.console._read_input = self.mock
+ self.console._WindowsConsole__vt_support = kwargs.get("vt_support",
+ False)
+ event = self.console.get_event(block=False)
+ return event
+
+ def get_input_record(self, unicode_char, vcode=0, control=0):
+ return wc.INPUT_RECORD(
+ wc.KEY_EVENT,
+ wc.ConsoleEvent(KeyEvent=
+ wc.KeyEvent(
+ bKeyDown=True,
+ wRepeatCount=1,
+ wVirtualKeyCode=vcode,
+ wVirtualScanCode=0, # not used
+ uChar=wc.Char(unicode_char),
+ dwControlKeyState=control
+ )))
+
+ def test_EmptyBuffer(self):
+ self.assertEqual(self.get_event([None]), None)
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_WINDOW_BUFFER_SIZE_EVENT(self):
+ ir = wc.INPUT_RECORD(
+ wc.WINDOW_BUFFER_SIZE_EVENT,
+ wc.ConsoleEvent(WindowsBufferSizeEvent=
+ wc.WindowsBufferSizeEvent(
+ wc._COORD(0, 0))))
+ self.assertEqual(self.get_event([ir]), Event("resize", ""))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_KEY_EVENT_up_ignored(self):
+ ir = wc.INPUT_RECORD(
+ wc.KEY_EVENT,
+ wc.ConsoleEvent(KeyEvent=
+ wc.KeyEvent(bKeyDown=False)))
+ self.assertEqual(self.get_event([ir]), None)
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_unhandled_events(self):
+ for event in (wc.FOCUS_EVENT, wc.MENU_EVENT, wc.MOUSE_EVENT):
+ ir = wc.INPUT_RECORD(
+ event,
+ # fake data, nothing is read except bKeyDown
+ wc.ConsoleEvent(KeyEvent=
+ wc.KeyEvent(bKeyDown=False)))
+ self.assertEqual(self.get_event([ir]), None)
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_enter(self):
+ ir = self.get_input_record("\r", self.VK_RETURN)
+ self.assertEqual(self.get_event([ir]), Event("key", "\n"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_backspace(self):
+ ir = self.get_input_record("\x08", self.VK_BACK)
+ self.assertEqual(
+ self.get_event([ir]), Event("key", "backspace"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_m(self):
+ ir = self.get_input_record("m", self.VK_M)
+ self.assertEqual(self.get_event([ir]), Event("key", "m"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_M(self):
+ ir = self.get_input_record("M", self.VK_M, self.SHIFT_PRESSED)
+ self.assertEqual(self.get_event([ir]), Event("key", "M"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_left(self):
+ # VK_LEFT is sent as ENHANCED_KEY
+ ir = self.get_input_record("\x00", self.VK_LEFT, self.ENHANCED_KEY)
+ self.assertEqual(self.get_event([ir]), Event("key", "left"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_left_RIGHT_CTRL_PRESSED(self):
+ ir = self.get_input_record(
+ "\x00", self.VK_LEFT, self.RIGHT_CTRL_PRESSED | self.ENHANCED_KEY)
+ self.assertEqual(
+ self.get_event([ir]), Event("key", "ctrl left"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_left_LEFT_CTRL_PRESSED(self):
+ ir = self.get_input_record(
+ "\x00", self.VK_LEFT, self.LEFT_CTRL_PRESSED | self.ENHANCED_KEY)
+ self.assertEqual(
+ self.get_event([ir]), Event("key", "ctrl left"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_left_RIGHT_ALT_PRESSED(self):
+ ir = self.get_input_record(
+ "\x00", self.VK_LEFT, self.RIGHT_ALT_PRESSED | self.ENHANCED_KEY)
+ self.assertEqual(self.get_event([ir]), Event(evt="key", data="\033"))
+ self.assertEqual(
+ self.console.get_event(), Event("key", "left"))
+ # self.mock is not called again, since the second time we read from the
+ # command queue
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_left_LEFT_ALT_PRESSED(self):
+ ir = self.get_input_record(
+ "\x00", self.VK_LEFT, self.LEFT_ALT_PRESSED | self.ENHANCED_KEY)
+ self.assertEqual(self.get_event([ir]), Event(evt="key", data="\033"))
+ self.assertEqual(
+ self.console.get_event(), Event("key", "left"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_m_LEFT_ALT_PRESSED_and_LEFT_CTRL_PRESSED(self):
+ # For the shift keys, Windows does not send anything when
+ # ALT and CTRL are both pressed, so let's test with VK_M.
+ # get_event() receives this input, but does not
+ # generate an event.
+ # This is for e.g. an English keyboard layout, for a
+ # German layout this returns `µ`, see test_AltGr_m.
+ ir = self.get_input_record(
+ "\x00", self.VK_M, self.LEFT_ALT_PRESSED | self.LEFT_CTRL_PRESSED)
+ self.assertEqual(self.get_event([ir]), None)
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_m_LEFT_ALT_PRESSED(self):
+ ir = self.get_input_record(
+ "m", vcode=self.VK_M, control=self.LEFT_ALT_PRESSED)
+ self.assertEqual(self.get_event([ir]), Event(evt="key", data="\033"))
+ self.assertEqual(self.console.get_event(), Event("key", "m"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_m_RIGHT_ALT_PRESSED(self):
+ ir = self.get_input_record(
+ "m", vcode=self.VK_M, control=self.RIGHT_ALT_PRESSED)
+ self.assertEqual(self.get_event([ir]), Event(evt="key", data="\033"))
+ self.assertEqual(self.console.get_event(), Event("key", "m"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_AltGr_7(self):
+ # E.g. on a German keyboard layout, '{' is entered via
+ # AltGr + 7, where AltGr is the right Alt key on the keyboard.
+ # In this case, Windows automatically sets
+ # RIGHT_ALT_PRESSED = 0x0001 + LEFT_CTRL_PRESSED = 0x0008
+ # This can also be entered like
+ # LeftAlt + LeftCtrl + 7 or
+ # LeftAlt + RightCtrl + 7
+ # See https://learn.microsoft.com/en-us/windows/console/key-event-record-str
+ # https://learn.microsoft.com/en-us/windows/win32/api/winuser/nf-winuser-vkkeyscanw
+ ir = self.get_input_record(
+ "{", vcode=self.VK_7,
+ control=self.RIGHT_ALT_PRESSED | self.LEFT_CTRL_PRESSED)
+ self.assertEqual(self.get_event([ir]), Event("key", "{"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_AltGr_m(self):
+ # E.g. on a German keyboard layout, this yields 'µ'
+ # Let's use LEFT_ALT_PRESSED and RIGHT_CTRL_PRESSED this
+ # time, to cover that, too. See above in test_AltGr_7.
+ ir = self.get_input_record(
+ "µ", vcode=self.VK_M, control=self.LEFT_ALT_PRESSED | self.RIGHT_CTRL_PRESSED)
+ self.assertEqual(self.get_event([ir]), Event("key", "µ"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_umlaut_a_german(self):
+ ir = self.get_input_record("ä", self.VK_OEM_7)
+ self.assertEqual(self.get_event([ir]), Event("key", "ä"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ # virtual terminal tests
+ # Note: wVirtualKeyCode, wVirtualScanCode and dwControlKeyState
+ # are always zero in this case.
+ # "\r" and backspace are handled specially, everything else
+ # is handled in "elif self.__vt_support:" in WindowsConsole.get_event().
+ # Hence, only one regular key ("m") and a terminal sequence
+ # are sufficient to test here, the real tests happen in test_eventqueue
+ # and test_keymap.
+
+ def test_enter_vt(self):
+ ir = self.get_input_record("\r")
+ self.assertEqual(self.get_event([ir], vt_support=True),
+ Event("key", "\n"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_backspace_vt(self):
+ ir = self.get_input_record("\x7f")
+ self.assertEqual(self.get_event([ir], vt_support=True),
+ Event("key", "backspace", b"\x7f"))
+ self.assertEqual(self.mock.call_count, 1)
+
+ def test_up_vt(self):
+ irs = [self.get_input_record(x) for x in "\x1b[A"]
+ self.assertEqual(self.get_event(irs, vt_support=True),
+ Event(evt='key', data='up', raw=bytearray(b'\x1b[A')))
+ self.assertEqual(self.mock.call_count, 3)
if __name__ == "__main__":
diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py
index 96f6cc86219..43957f525f1 100644
--- a/Lib/test/test_random.py
+++ b/Lib/test/test_random.py
@@ -1411,6 +1411,7 @@ class TestModule(unittest.TestCase):
class CommandLineTest(unittest.TestCase):
+ @support.force_not_colorized
def test_parse_args(self):
args, help_text = random._parse_args(shlex.split("--choice a b c"))
self.assertEqual(args.choice, ["a", "b", "c"])
diff --git a/Lib/test/test_remote_pdb.py b/Lib/test/test_remote_pdb.py
index 9fbe94fcdd6..aef8a6b0129 100644
--- a/Lib/test/test_remote_pdb.py
+++ b/Lib/test/test_remote_pdb.py
@@ -3,6 +3,7 @@ import time
import itertools
import json
import os
+import re
import signal
import socket
import subprocess
@@ -12,9 +13,9 @@ import textwrap
import threading
import unittest
import unittest.mock
-from contextlib import contextmanager, redirect_stdout, ExitStack
+from contextlib import closing, contextmanager, redirect_stdout, redirect_stderr, ExitStack
from pathlib import Path
-from test.support import is_wasi, os_helper, requires_subprocess, SHORT_TIMEOUT
+from test.support import is_wasi, cpython_only, force_color, requires_subprocess, SHORT_TIMEOUT
from test.support.os_helper import temp_dir, TESTFN, unlink
from typing import Dict, List, Optional, Tuple, Union, Any
@@ -79,44 +80,6 @@ class MockSocketFile:
return results
-class MockDebuggerSocket:
- """Mock file-like simulating a connection to a _RemotePdb instance"""
-
- def __init__(self, incoming):
- self.incoming = iter(incoming)
- self.outgoing = []
- self.buffered = bytearray()
-
- def write(self, data: bytes) -> None:
- """Simulate write to socket."""
- self.buffered += data
-
- def flush(self) -> None:
- """Ensure each line is valid JSON."""
- lines = self.buffered.splitlines(keepends=True)
- self.buffered.clear()
- for line in lines:
- assert line.endswith(b"\n")
- self.outgoing.append(json.loads(line))
-
- def readline(self) -> bytes:
- """Read a line from the prepared input queue."""
- # Anything written must be flushed before trying to read,
- # since the read will be dependent upon the last write.
- assert not self.buffered
- try:
- item = next(self.incoming)
- if not isinstance(item, bytes):
- item = json.dumps(item).encode()
- return item + b"\n"
- except StopIteration:
- return b""
-
- def close(self) -> None:
- """No-op close implementation."""
- pass
-
-
class PdbClientTestCase(unittest.TestCase):
"""Tests for the _PdbClient class."""
@@ -124,8 +87,11 @@ class PdbClientTestCase(unittest.TestCase):
self,
*,
incoming,
- simulate_failure=None,
+ simulate_send_failure=False,
+ simulate_sigint_during_stdout_write=False,
+ use_interrupt_socket=False,
expected_outgoing=None,
+ expected_outgoing_signals=None,
expected_completions=None,
expected_exception=None,
expected_stdout="",
@@ -134,6 +100,8 @@ class PdbClientTestCase(unittest.TestCase):
):
if expected_outgoing is None:
expected_outgoing = []
+ if expected_outgoing_signals is None:
+ expected_outgoing_signals = []
if expected_completions is None:
expected_completions = []
if expected_state is None:
@@ -142,16 +110,6 @@ class PdbClientTestCase(unittest.TestCase):
expected_state.setdefault("write_failed", False)
messages = [m for source, m in incoming if source == "server"]
prompts = [m["prompt"] for source, m in incoming if source == "user"]
- sockfile = MockDebuggerSocket(messages)
- stdout = io.StringIO()
-
- if simulate_failure:
- sockfile.write = unittest.mock.Mock()
- sockfile.flush = unittest.mock.Mock()
- if simulate_failure == "write":
- sockfile.write.side_effect = OSError("write failed")
- elif simulate_failure == "flush":
- sockfile.flush.side_effect = OSError("flush failed")
input_iter = (m for source, m in incoming if source == "user")
completions = []
@@ -178,18 +136,60 @@ class PdbClientTestCase(unittest.TestCase):
reply = message["input"]
if isinstance(reply, BaseException):
raise reply
- return reply
+ if isinstance(reply, str):
+ return reply
+ return reply()
with ExitStack() as stack:
+ client_sock, server_sock = socket.socketpair()
+ stack.enter_context(closing(client_sock))
+ stack.enter_context(closing(server_sock))
+
+ server_sock = unittest.mock.Mock(wraps=server_sock)
+
+ client_sock.sendall(
+ b"".join(
+ (m if isinstance(m, bytes) else json.dumps(m).encode()) + b"\n"
+ for m in messages
+ )
+ )
+ client_sock.shutdown(socket.SHUT_WR)
+
+ if simulate_send_failure:
+ server_sock.sendall = unittest.mock.Mock(
+ side_effect=OSError("sendall failed")
+ )
+ client_sock.shutdown(socket.SHUT_RD)
+
+ stdout = io.StringIO()
+
+ if simulate_sigint_during_stdout_write:
+ orig_stdout_write = stdout.write
+
+ def sigint_stdout_write(s):
+ signal.raise_signal(signal.SIGINT)
+ return orig_stdout_write(s)
+
+ stdout.write = sigint_stdout_write
+
input_mock = stack.enter_context(
unittest.mock.patch("pdb.input", side_effect=mock_input)
)
stack.enter_context(redirect_stdout(stdout))
+ if use_interrupt_socket:
+ interrupt_sock = unittest.mock.Mock(spec=socket.socket)
+ mock_kill = None
+ else:
+ interrupt_sock = None
+ mock_kill = stack.enter_context(
+ unittest.mock.patch("os.kill", spec=os.kill)
+ )
+
client = _PdbClient(
- pid=0,
- sockfile=sockfile,
- interrupt_script="/a/b.py",
+ pid=12345,
+ server_socket=server_sock,
+ interrupt_sock=interrupt_sock,
)
if expected_exception is not None:
@@ -199,13 +199,12 @@ class PdbClientTestCase(unittest.TestCase):
client.cmdloop()
- actual_outgoing = sockfile.outgoing
- if simulate_failure:
- actual_outgoing += [
- json.loads(msg.args[0]) for msg in sockfile.write.mock_calls
- ]
+ sent_msgs = [msg.args[0] for msg in server_sock.sendall.mock_calls]
+ for msg in sent_msgs:
+ assert msg.endswith(b"\n")
+ actual_outgoing = [json.loads(msg) for msg in sent_msgs]
- self.assertEqual(sockfile.outgoing, expected_outgoing)
+ self.assertEqual(actual_outgoing, expected_outgoing)
self.assertEqual(completions, expected_completions)
if expected_stdout_substring and not expected_stdout:
self.assertIn(expected_stdout_substring, stdout.getvalue())
@@ -215,6 +214,20 @@ class PdbClientTestCase(unittest.TestCase):
actual_state = {k: getattr(client, k) for k in expected_state}
self.assertEqual(actual_state, expected_state)
+ if use_interrupt_socket:
+ outgoing_signals = [
+ signal.Signals(int.from_bytes(call.args[0]))
+ for call in interrupt_sock.sendall.call_args_list
+ ]
+ else:
+ assert mock_kill is not None
+ outgoing_signals = []
+ for call in mock_kill.call_args_list:
+ pid, signum = call.args
+ self.assertEqual(pid, 12345)
+ outgoing_signals.append(signal.Signals(signum))
+ self.assertEqual(outgoing_signals, expected_outgoing_signals)
+
def test_remote_immediately_closing_the_connection(self):
"""Test the behavior when the remote closes the connection immediately."""
incoming = []
@@ -409,11 +422,17 @@ class PdbClientTestCase(unittest.TestCase):
expected_state={"state": "dumb"},
)
- def test_keyboard_interrupt_at_prompt(self):
- """Test signaling when a prompt gets a KeyboardInterrupt."""
+ def test_sigint_at_prompt(self):
+ """Test signaling when a prompt gets interrupted."""
incoming = [
("server", {"prompt": "(Pdb) ", "state": "pdb"}),
- ("user", {"prompt": "(Pdb) ", "input": KeyboardInterrupt()}),
+ (
+ "user",
+ {
+ "prompt": "(Pdb) ",
+ "input": lambda: signal.raise_signal(signal.SIGINT),
+ },
+ ),
]
self.do_test(
incoming=incoming,
@@ -423,6 +442,43 @@ class PdbClientTestCase(unittest.TestCase):
expected_state={"state": "pdb"},
)
+ def test_sigint_at_continuation_prompt(self):
+ """Test signaling when a continuation prompt gets interrupted."""
+ incoming = [
+ ("server", {"prompt": "(Pdb) ", "state": "pdb"}),
+ ("user", {"prompt": "(Pdb) ", "input": "if True:"}),
+ (
+ "user",
+ {
+ "prompt": "... ",
+ "input": lambda: signal.raise_signal(signal.SIGINT),
+ },
+ ),
+ ]
+ self.do_test(
+ incoming=incoming,
+ expected_outgoing=[
+ {"signal": "INT"},
+ ],
+ expected_state={"state": "pdb"},
+ )
+
+ def test_sigint_when_writing(self):
+        """Test signaling when sys.stdout.write() gets interrupted."""
+ incoming = [
+ ("server", {"message": "Some message or other\n", "type": "info"}),
+ ]
+ for use_interrupt_socket in [False, True]:
+ with self.subTest(use_interrupt_socket=use_interrupt_socket):
+ self.do_test(
+ incoming=incoming,
+ simulate_sigint_during_stdout_write=True,
+ use_interrupt_socket=use_interrupt_socket,
+ expected_outgoing=[],
+ expected_outgoing_signals=[signal.SIGINT],
+ expected_stdout="Some message or other\n",
+ )
+
def test_eof_at_prompt(self):
"""Test signaling when a prompt gets an EOFError."""
incoming = [
@@ -478,20 +534,7 @@ class PdbClientTestCase(unittest.TestCase):
self.do_test(
incoming=incoming,
expected_outgoing=[{"signal": "INT"}],
- simulate_failure="write",
- expected_state={"write_failed": True},
- )
-
- def test_flush_failing(self):
- """Test terminating if flush fails due to a half closed socket."""
- incoming = [
- ("server", {"prompt": "(Pdb) ", "state": "pdb"}),
- ("user", {"prompt": "(Pdb) ", "input": KeyboardInterrupt()}),
- ]
- self.do_test(
- incoming=incoming,
- expected_outgoing=[{"signal": "INT"}],
- simulate_failure="flush",
+ simulate_send_failure=True,
expected_state={"write_failed": True},
)
@@ -660,42 +703,7 @@ class PdbClientTestCase(unittest.TestCase):
},
{"reply": "xyz"},
],
- simulate_failure="write",
- expected_completions=[],
- expected_state={"state": "interact", "write_failed": True},
- )
-
- def test_flush_failure_during_completion(self):
- """Test failing to flush to the socket to request tab completions."""
- incoming = [
- ("server", {"prompt": ">>> ", "state": "interact"}),
- (
- "user",
- {
- "prompt": ">>> ",
- "completion_request": {
- "line": "xy",
- "begidx": 0,
- "endidx": 2,
- },
- "input": "xyz",
- },
- ),
- ]
- self.do_test(
- incoming=incoming,
- expected_outgoing=[
- {
- "complete": {
- "text": "xy",
- "line": "xy",
- "begidx": 0,
- "endidx": 2,
- }
- },
- {"reply": "xyz"},
- ],
- simulate_failure="flush",
+ simulate_send_failure=True,
expected_completions=[],
expected_state={"state": "interact", "write_failed": True},
)
@@ -1032,6 +1040,8 @@ class PdbConnectTestCase(unittest.TestCase):
frame=frame,
commands="",
version=pdb._PdbServer.protocol_version(),
+ signal_raising_thread=False,
+ colorize=False,
)
return x # This line won't be reached in debugging
@@ -1089,23 +1099,6 @@ class PdbConnectTestCase(unittest.TestCase):
client_file.write(json.dumps({"reply": command}).encode() + b"\n")
client_file.flush()
- def _send_interrupt(self, pid):
- """Helper to send an interrupt signal to the debugger."""
- # with tempfile.NamedTemporaryFile("w", delete_on_close=False) as interrupt_script:
- interrupt_script = TESTFN + "_interrupt_script.py"
- with open(interrupt_script, 'w') as f:
- f.write(
- 'import pdb, sys\n'
- 'print("Hello, world!")\n'
- 'if inst := pdb.Pdb._last_pdb_instance:\n'
- ' inst.set_trace(sys._getframe(1))\n'
- )
- self.addCleanup(unlink, interrupt_script)
- try:
- sys.remote_exec(pid, interrupt_script)
- except PermissionError:
- self.skipTest("Insufficient permissions to execute code in remote process")
-
def test_connect_and_basic_commands(self):
"""Test connecting to a remote debugger and sending basic commands."""
self._create_script()
@@ -1218,6 +1211,8 @@ class PdbConnectTestCase(unittest.TestCase):
frame=frame,
commands="",
version=pdb._PdbServer.protocol_version(),
+ signal_raising_thread=True,
+ colorize=False,
)
print("Connected to debugger")
iterations = 50
@@ -1233,6 +1228,10 @@ class PdbConnectTestCase(unittest.TestCase):
self._create_script(script=script)
process, client_file = self._connect_and_get_client_file()
+ # Accept a 2nd connection from the subprocess to tell it about signals
+ signal_sock, _ = self.server_sock.accept()
+ self.addCleanup(signal_sock.close)
+
with kill_on_error(process):
# Skip initial messages until we get to the prompt
self._read_until_prompt(client_file)
@@ -1248,7 +1247,7 @@ class PdbConnectTestCase(unittest.TestCase):
break
# Inject a script to interrupt the running process
- self._send_interrupt(process.pid)
+ signal_sock.sendall(signal.SIGINT.to_bytes())
messages = self._read_until_prompt(client_file)
# Verify we got the keyboard interrupt message.
@@ -1304,6 +1303,8 @@ class PdbConnectTestCase(unittest.TestCase):
frame=frame,
commands="",
version=fake_version,
+ signal_raising_thread=False,
+ colorize=False,
)
# This should print if the debugger detaches correctly
@@ -1431,5 +1432,152 @@ class PdbConnectTestCase(unittest.TestCase):
self.assertIn("Function returned: 42", stdout)
self.assertEqual(process.returncode, 0)
+
+def _supports_remote_attaching():
+    # True only when the _remote_debugging module reports process_vm_readv support.
+    PROCESS_VM_READV_SUPPORTED = False
+
+    try:
+        from _remote_debugging import PROCESS_VM_READV_SUPPORTED
+    except ImportError:
+        pass
+
+    return PROCESS_VM_READV_SUPPORTED
+
+
+@unittest.skipIf(not sys.is_remote_debug_enabled(), "Remote debugging is not enabled")
+@unittest.skipIf(sys.platform != "darwin" and sys.platform != "linux" and sys.platform != "win32",
+ "Test only runs on Linux, Windows and MacOS")
+@unittest.skipIf(sys.platform == "linux" and not _supports_remote_attaching(),
+ "Testing on Linux requires process_vm_readv support")
+@cpython_only
+@requires_subprocess()
+class PdbAttachTestCase(unittest.TestCase):
+ def setUp(self):
+ # Create a server socket that will wait for the debugger to connect
+ self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.sock.bind(('127.0.0.1', 0)) # Let OS assign port
+ self.sock.listen(1)
+ self.port = self.sock.getsockname()[1]
+ self._create_script()
+
+ def _create_script(self, script=None):
+ # Create a file for subprocess script
+ script = textwrap.dedent(
+ f"""
+ import socket
+ import time
+
+ def foo():
+ return bar()
+
+ def bar():
+ return baz()
+
+ def baz():
+ x = 1
+ # Trigger attach
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.connect(('127.0.0.1', {self.port}))
+ sock.close()
+ count = 0
+ while x == 1 and count < 100:
+ count += 1
+ time.sleep(0.1)
+ return x
+
+ result = foo()
+ print(f"Function returned: {{result}}")
+ """
+ )
+
+ self.script_path = TESTFN + "_connect_test.py"
+ with open(self.script_path, 'w') as f:
+ f.write(script)
+
+ def tearDown(self):
+ self.sock.close()
+ try:
+ unlink(self.script_path)
+ except OSError:
+ pass
+
+ def do_integration_test(self, client_stdin):
+ process = subprocess.Popen(
+ [sys.executable, self.script_path],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True
+ )
+ self.addCleanup(process.stdout.close)
+ self.addCleanup(process.stderr.close)
+
+ # Wait for the process to reach our attachment point
+ self.sock.settimeout(10)
+ conn, _ = self.sock.accept()
+ conn.close()
+
+ client_stdin = io.StringIO(client_stdin)
+ client_stdout = io.StringIO()
+ client_stderr = io.StringIO()
+
+ self.addCleanup(client_stdin.close)
+ self.addCleanup(client_stdout.close)
+ self.addCleanup(client_stderr.close)
+ self.addCleanup(process.wait)
+
+ with (
+ unittest.mock.patch("sys.stdin", client_stdin),
+ redirect_stdout(client_stdout),
+ redirect_stderr(client_stderr),
+ unittest.mock.patch("sys.argv", ["pdb", "-p", str(process.pid)]),
+ ):
+ try:
+ pdb.main()
+ except PermissionError:
+ self.skipTest("Insufficient permissions for remote execution")
+
+ process.wait()
+ server_stdout = process.stdout.read()
+ server_stderr = process.stderr.read()
+
+ if process.returncode != 0:
+ print("server failed")
+ print(f"server stdout:\n{server_stdout}")
+ print(f"server stderr:\n{server_stderr}")
+
+ self.assertEqual(process.returncode, 0)
+ return {
+ "client": {
+ "stdout": client_stdout.getvalue(),
+ "stderr": client_stderr.getvalue(),
+ },
+ "server": {
+ "stdout": server_stdout,
+ "stderr": server_stderr,
+ },
+ }
+
+ def test_attach_to_process_without_colors(self):
+ with force_color(False):
+ output = self.do_integration_test("ll\nx=42\n")
+ self.assertEqual(output["client"]["stderr"], "")
+ self.assertEqual(output["server"]["stderr"], "")
+
+ self.assertEqual(output["server"]["stdout"], "Function returned: 42\n")
+ self.assertIn("while x == 1", output["client"]["stdout"])
+ self.assertNotIn("\x1b", output["client"]["stdout"])
+
+ def test_attach_to_process_with_colors(self):
+ with force_color(True):
+ output = self.do_integration_test("ll\nx=42\n")
+ self.assertEqual(output["client"]["stderr"], "")
+ self.assertEqual(output["server"]["stderr"], "")
+
+ self.assertEqual(output["server"]["stdout"], "Function returned: 42\n")
+ self.assertIn("\x1b", output["client"]["stdout"])
+ self.assertNotIn("while x == 1", output["client"]["stdout"])
+ self.assertIn("while x == 1", re.sub("\x1b[^m]*m", "", output["client"]["stdout"]))
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_reprlib.py b/Lib/test/test_reprlib.py
index ffeb1fba7b8..ffad35092f9 100644
--- a/Lib/test/test_reprlib.py
+++ b/Lib/test/test_reprlib.py
@@ -3,6 +3,7 @@
Nick Mathewson
"""
+import annotationlib
import sys
import os
import shutil
@@ -11,7 +12,7 @@ import importlib.util
import unittest
import textwrap
-from test.support import verbose
+from test.support import verbose, EqualToForwardRef
from test.support.os_helper import create_empty_file
from reprlib import repr as r # Don't shadow builtin repr
from reprlib import Repr
@@ -829,5 +830,19 @@ class TestRecursiveRepr(unittest.TestCase):
self.assertEqual(type_params[0].__name__, 'T')
self.assertEqual(type_params[0].__bound__, str)
+ def test_annotations(self):
+ class My:
+ @recursive_repr()
+ def __repr__(self, default: undefined = ...):
+ return default
+
+ annotations = annotationlib.get_annotations(
+ My.__repr__, format=annotationlib.Format.FORWARDREF
+ )
+ self.assertEqual(
+ annotations,
+ {'default': EqualToForwardRef("undefined", owner=My.__repr__)}
+ )
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_shlex.py b/Lib/test/test_shlex.py
index f35571ea886..a13ddcb76b7 100644
--- a/Lib/test/test_shlex.py
+++ b/Lib/test/test_shlex.py
@@ -3,6 +3,7 @@ import itertools
import shlex
import string
import unittest
+from test.support import cpython_only
from test.support import import_helper
@@ -364,6 +365,7 @@ class ShlexTest(unittest.TestCase):
with self.assertRaises(AttributeError):
shlex_instance.punctuation_chars = False
+ @cpython_only
def test_lazy_imports(self):
import_helper.ensure_lazy_imports('shlex', {'collections', 're', 'os'})
diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py
index ed01163074a..87991fbda4c 100644
--- a/Lib/test/test_shutil.py
+++ b/Lib/test/test_shutil.py
@@ -2153,6 +2153,10 @@ class TestArchives(BaseTest, unittest.TestCase):
def test_unpack_archive_bztar(self):
self.check_unpack_tarball('bztar')
+ @support.requires_zstd()
+ def test_unpack_archive_zstdtar(self):
+ self.check_unpack_tarball('zstdtar')
+
@support.requires_lzma()
@unittest.skipIf(AIX and not _maxdataOK(), "AIX MAXDATA must be 0x20000000 or larger")
def test_unpack_archive_xztar(self):
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index ace97ce0cbe..03c54151a22 100644
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -2,8 +2,9 @@ import unittest
from unittest import mock
from test import support
from test.support import (
- is_apple, os_helper, refleak_helper, socket_helper, threading_helper
+ cpython_only, is_apple, os_helper, refleak_helper, socket_helper, threading_helper
)
+from test.support.import_helper import ensure_lazy_imports
import _thread as thread
import array
import contextlib
@@ -257,6 +258,12 @@ HAVE_SOCKET_HYPERV = _have_socket_hyperv()
# Size in bytes of the int type
SIZEOF_INT = array.array("i").itemsize
+class TestLazyImport(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("socket", {"array", "selectors"})
+
+
class SocketTCPTest(unittest.TestCase):
def setUp(self):
diff --git a/Lib/test/test_sqlite3/test_cli.py b/Lib/test/test_sqlite3/test_cli.py
index dcd90d11d46..ad0dcb3cccb 100644
--- a/Lib/test/test_sqlite3/test_cli.py
+++ b/Lib/test/test_sqlite3/test_cli.py
@@ -4,7 +4,12 @@ import unittest
from sqlite3.__main__ import main as cli
from test.support.os_helper import TESTFN, unlink
-from test.support import captured_stdout, captured_stderr, captured_stdin
+from test.support import (
+ captured_stdout,
+ captured_stderr,
+ captured_stdin,
+ force_not_colorized,
+)
class CommandLineInterface(unittest.TestCase):
@@ -32,6 +37,7 @@ class CommandLineInterface(unittest.TestCase):
self.assertEqual(out, "")
return err
+ @force_not_colorized
def test_cli_help(self):
out = self.expect_success("-h")
self.assertIn("usage: ", out)
diff --git a/Lib/test/test_string/test_string.py b/Lib/test/test_string/test_string.py
index f6d112d8a93..5394fe4e12c 100644
--- a/Lib/test/test_string/test_string.py
+++ b/Lib/test/test_string/test_string.py
@@ -2,6 +2,14 @@ import unittest
import string
from string import Template
import types
+from test.support import cpython_only
+from test.support.import_helper import ensure_lazy_imports
+
+
+class LazyImportTest(unittest.TestCase):
+    @cpython_only
+    def test_lazy_import(self):
+        ensure_lazy_imports("string", {"re", "collections"})  # check string, the module under test
class ModuleTest(unittest.TestCase):
diff --git a/Lib/test/test_subprocess.py b/Lib/test/test_subprocess.py
index 3cb755cd56c..ca35804fb36 100644
--- a/Lib/test/test_subprocess.py
+++ b/Lib/test/test_subprocess.py
@@ -162,6 +162,20 @@ class ProcessTestCase(BaseTestCase):
[sys.executable, "-c", "while True: pass"],
timeout=0.1)
+ def test_timeout_exception(self):
+ try:
+ subprocess.run([sys.executable, '-c', 'import time;time.sleep(9)'], timeout = -1)
+ except subprocess.TimeoutExpired as e:
+ self.assertIn("-1 seconds", str(e))
+ else:
+ self.fail("Expected TimeoutExpired exception not raised")
+ try:
+ subprocess.run([sys.executable, '-c', 'import time;time.sleep(9)'], timeout = 0)
+ except subprocess.TimeoutExpired as e:
+ self.assertIn("0 seconds", str(e))
+ else:
+ self.fail("Expected TimeoutExpired exception not raised")
+
def test_check_call_zero(self):
# check_call() function with zero return code
rc = subprocess.check_call(ZERO_RETURN_CMD)
diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py
index 468bac82924..8446da03e36 100644
--- a/Lib/test/test_support.py
+++ b/Lib/test/test_support.py
@@ -561,6 +561,7 @@ class TestSupport(unittest.TestCase):
['-Wignore', '-X', 'dev'],
['-X', 'faulthandler'],
['-X', 'importtime'],
+ ['-X', 'importtime=2'],
['-X', 'showrefcount'],
['-X', 'tracemalloc'],
['-X', 'tracemalloc=3'],
diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py
index 10c3e0e9a1d..59ef5c99309 100644
--- a/Lib/test/test_sys.py
+++ b/Lib/test/test_sys.py
@@ -1960,7 +1960,7 @@ def _supports_remote_attaching():
PROCESS_VM_READV_SUPPORTED = False
try:
- from _remotedebuggingmodule import PROCESS_VM_READV_SUPPORTED
+ from _remote_debugging import PROCESS_VM_READV_SUPPORTED
except ImportError:
pass
@@ -2101,7 +2101,7 @@ print("Remote script executed successfully!")
prologue = '''\
import sys
def audit_hook(event, arg):
- print(f"Audit event: {event}, arg: {arg}")
+ print(f"Audit event: {event}, arg: {arg}".encode("ascii", errors="replace"))
sys.addaudithook(audit_hook)
'''
script = '''
@@ -2196,6 +2196,64 @@ this is invalid python code
self.assertIn(b"Remote debugging is not enabled", err)
self.assertEqual(out, b"")
+class TestSysJIT(unittest.TestCase):
+
+ def test_jit_is_available(self):
+ available = sys._jit.is_available()
+ script = f"import sys; assert sys._jit.is_available() is {available}"
+ assert_python_ok("-c", script, PYTHON_JIT="0")
+ assert_python_ok("-c", script, PYTHON_JIT="1")
+
+ def test_jit_is_enabled(self):
+ available = sys._jit.is_available()
+ script = "import sys; assert sys._jit.is_enabled() is {enabled}"
+ assert_python_ok("-c", script.format(enabled=False), PYTHON_JIT="0")
+ assert_python_ok("-c", script.format(enabled=available), PYTHON_JIT="1")
+
+ def test_jit_is_active(self):
+ available = sys._jit.is_available()
+ script = textwrap.dedent(
+ """
+ import _testcapi
+ import _testinternalcapi
+ import sys
+
+ def frame_0_interpreter() -> None:
+ assert sys._jit.is_active() is False
+
+ def frame_1_interpreter() -> None:
+ assert sys._jit.is_active() is False
+ frame_0_interpreter()
+ assert sys._jit.is_active() is False
+
+ def frame_2_jit(expected: bool) -> None:
+ # Inlined into the last loop of frame_3_jit:
+ assert sys._jit.is_active() is expected
+ # Insert C frame:
+ _testcapi.pyobject_vectorcall(frame_1_interpreter, None, None)
+ assert sys._jit.is_active() is expected
+
+ def frame_3_jit() -> None:
+ # JITs just before the last loop:
+ for i in range(_testinternalcapi.TIER2_THRESHOLD + 1):
+ # Careful, doing this in the reverse order breaks tracing:
+ expected = {enabled} and i == _testinternalcapi.TIER2_THRESHOLD
+ assert sys._jit.is_active() is expected
+ frame_2_jit(expected)
+ assert sys._jit.is_active() is expected
+
+ def frame_4_interpreter() -> None:
+ assert sys._jit.is_active() is False
+ frame_3_jit()
+ assert sys._jit.is_active() is False
+
+ assert sys._jit.is_active() is False
+ frame_4_interpreter()
+ assert sys._jit.is_active() is False
+ """
+ )
+ assert_python_ok("-c", script.format(enabled=False), PYTHON_JIT="0")
+ assert_python_ok("-c", script.format(enabled=available), PYTHON_JIT="1")
if __name__ == "__main__":
diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py
index fcbaf854cc2..2d9649237a9 100644
--- a/Lib/test/test_tarfile.py
+++ b/Lib/test/test_tarfile.py
@@ -38,6 +38,10 @@ try:
import lzma
except ImportError:
lzma = None
+try:
+ from compression import zstd
+except ImportError:
+ zstd = None
def sha256sum(data):
return sha256(data).hexdigest()
@@ -48,6 +52,7 @@ tarname = support.findfile("testtar.tar", subdir="archivetestdata")
gzipname = os.path.join(TEMPDIR, "testtar.tar.gz")
bz2name = os.path.join(TEMPDIR, "testtar.tar.bz2")
xzname = os.path.join(TEMPDIR, "testtar.tar.xz")
+zstname = os.path.join(TEMPDIR, "testtar.tar.zst")
tmpname = os.path.join(TEMPDIR, "tmp.tar")
dotlessname = os.path.join(TEMPDIR, "testtar")
@@ -90,6 +95,12 @@ class LzmaTest:
open = lzma.LZMAFile if lzma else None
taropen = tarfile.TarFile.xzopen
+@support.requires_zstd()
+class ZstdTest:
+ tarname = zstname
+ suffix = 'zst'
+ open = zstd.ZstdFile if zstd else None
+ taropen = tarfile.TarFile.zstopen
class ReadTest(TarTest):
@@ -271,6 +282,8 @@ class Bz2UstarReadTest(Bz2Test, UstarReadTest):
class LzmaUstarReadTest(LzmaTest, UstarReadTest):
pass
+class ZstdUstarReadTest(ZstdTest, UstarReadTest):
+ pass
class ListTest(ReadTest, unittest.TestCase):
@@ -375,6 +388,8 @@ class Bz2ListTest(Bz2Test, ListTest):
class LzmaListTest(LzmaTest, ListTest):
pass
+class ZstdListTest(ZstdTest, ListTest):
+ pass
class CommonReadTest(ReadTest):
@@ -837,6 +852,8 @@ class Bz2MiscReadTest(Bz2Test, MiscReadTestBase, unittest.TestCase):
class LzmaMiscReadTest(LzmaTest, MiscReadTestBase, unittest.TestCase):
pass
+class ZstdMiscReadTest(ZstdTest, MiscReadTestBase, unittest.TestCase):
+ pass
class StreamReadTest(CommonReadTest, unittest.TestCase):
@@ -909,6 +926,9 @@ class Bz2StreamReadTest(Bz2Test, StreamReadTest):
class LzmaStreamReadTest(LzmaTest, StreamReadTest):
pass
+class ZstdStreamReadTest(ZstdTest, StreamReadTest):
+ pass
+
class TarStreamModeReadTest(StreamModeTest, unittest.TestCase):
def test_stream_mode_no_cache(self):
@@ -925,6 +945,9 @@ class Bz2StreamModeReadTest(Bz2Test, TarStreamModeReadTest):
class LzmaStreamModeReadTest(LzmaTest, TarStreamModeReadTest):
pass
+class ZstdStreamModeReadTest(ZstdTest, TarStreamModeReadTest):
+ pass
+
class DetectReadTest(TarTest, unittest.TestCase):
def _testfunc_file(self, name, mode):
try:
@@ -986,6 +1009,8 @@ class Bz2DetectReadTest(Bz2Test, DetectReadTest):
class LzmaDetectReadTest(LzmaTest, DetectReadTest):
pass
+class ZstdDetectReadTest(ZstdTest, DetectReadTest):
+ pass
class GzipBrokenHeaderCorrectException(GzipTest, unittest.TestCase):
"""
@@ -1666,6 +1691,8 @@ class Bz2WriteTest(Bz2Test, WriteTest):
class LzmaWriteTest(LzmaTest, WriteTest):
pass
+class ZstdWriteTest(ZstdTest, WriteTest):
+ pass
class StreamWriteTest(WriteTestBase, unittest.TestCase):
@@ -1727,6 +1754,9 @@ class Bz2StreamWriteTest(Bz2Test, StreamWriteTest):
class LzmaStreamWriteTest(LzmaTest, StreamWriteTest):
decompressor = lzma.LZMADecompressor if lzma else None
+class ZstdStreamWriteTest(ZstdTest, StreamWriteTest):
+ decompressor = zstd.ZstdDecompressor if zstd else None
+
class _CompressedWriteTest(TarTest):
# This is not actually a standalone test.
# It does not inherit WriteTest because it only makes sense with gz,bz2
@@ -2042,6 +2072,14 @@ class LzmaCreateTest(LzmaTest, CreateTest):
tobj.add(self.file_path)
+class ZstdCreateTest(ZstdTest, CreateTest):
+
+ # Unlike gz and bz2, zstd uses the level keyword instead of compresslevel.
+ # It does not allow for level to be specified when reading.
+ def test_create_with_level(self):
+ with tarfile.open(tmpname, self.mode, level=1) as tobj:
+ tobj.add(self.file_path)
+
class CreateWithXModeTest(CreateTest):
prefix = "x"
@@ -2523,6 +2561,8 @@ class Bz2AppendTest(Bz2Test, AppendTestBase, unittest.TestCase):
class LzmaAppendTest(LzmaTest, AppendTestBase, unittest.TestCase):
pass
+class ZstdAppendTest(ZstdTest, AppendTestBase, unittest.TestCase):
+ pass
class LimitsTest(unittest.TestCase):
@@ -2835,7 +2875,7 @@ class CommandLineTest(unittest.TestCase):
support.findfile('tokenize_tests-no-coding-cookie-'
'and-utf8-bom-sig-only.txt',
subdir='tokenizedata')]
- for filetype in (GzipTest, Bz2Test, LzmaTest):
+ for filetype in (GzipTest, Bz2Test, LzmaTest, ZstdTest):
if not filetype.open:
continue
try:
@@ -4257,7 +4297,7 @@ def setUpModule():
data = fobj.read()
# Create compressed tarfiles.
- for c in GzipTest, Bz2Test, LzmaTest:
+ for c in GzipTest, Bz2Test, LzmaTest, ZstdTest:
if c.open:
os_helper.unlink(c.tarname)
testtarnames.append(c.tarname)
diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py
index 814c00ca0fd..4ab38c2598b 100644
--- a/Lib/test/test_threading.py
+++ b/Lib/test/test_threading.py
@@ -5,7 +5,7 @@ Tests for the threading module.
import test.support
from test.support import threading_helper, requires_subprocess, requires_gil_enabled
from test.support import verbose, cpython_only, os_helper
-from test.support.import_helper import import_module
+from test.support.import_helper import ensure_lazy_imports, import_module
from test.support.script_helper import assert_python_ok, assert_python_failure
from test.support import force_not_colorized
@@ -121,6 +121,10 @@ class ThreadTests(BaseTestCase):
maxDiff = 9999
@cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("threading", {"functools", "warnings"})
+
+ @cpython_only
def test_name(self):
def func(): pass
diff --git a/Lib/test/test_traceback.py b/Lib/test/test_traceback.py
index 683486e9aca..b9be87f357f 100644
--- a/Lib/test/test_traceback.py
+++ b/Lib/test/test_traceback.py
@@ -37,6 +37,12 @@ test_code.co_positions = lambda _: iter([(6, 6, 0, 0)])
test_frame = namedtuple('frame', ['f_code', 'f_globals', 'f_locals'])
test_tb = namedtuple('tb', ['tb_frame', 'tb_lineno', 'tb_next', 'tb_lasti'])
+color_overrides = {"reset": "z", "filename": "fn", "error_highlight": "E"}
+colors = {
+ color_overrides.get(k, k[0].lower()): v
+ for k, v in _colorize.default_theme.traceback.items()
+}
+
LEVENSHTEIN_DATA_FILE = Path(__file__).parent / 'levenshtein_examples.json'
@@ -4721,6 +4727,8 @@ class MiscTest(unittest.TestCase):
class TestColorizedTraceback(unittest.TestCase):
+ maxDiff = None
+
def test_colorized_traceback(self):
def foo(*args):
x = {'a':{'b': None}}
@@ -4743,9 +4751,9 @@ class TestColorizedTraceback(unittest.TestCase):
e, capture_locals=True
)
lines = "".join(exc.format(colorize=True))
- red = _colorize.ANSIColors.RED
- boldr = _colorize.ANSIColors.BOLD_RED
- reset = _colorize.ANSIColors.RESET
+ red = colors["e"]
+ boldr = colors["E"]
+ reset = colors["z"]
self.assertIn("y = " + red + "x['a']['b']" + reset + boldr + "['c']" + reset, lines)
self.assertIn("return " + red + "(lambda *args: foo(*args))" + reset + boldr + "(1,2,3,4)" + reset, lines)
self.assertIn("return (lambda *args: " + red + "foo" + reset + boldr + "(*args)" + reset + ")(1,2,3,4)", lines)
@@ -4761,18 +4769,16 @@ class TestColorizedTraceback(unittest.TestCase):
e, capture_locals=True
)
actual = "".join(exc.format(colorize=True))
- red = _colorize.ANSIColors.RED
- magenta = _colorize.ANSIColors.MAGENTA
- boldm = _colorize.ANSIColors.BOLD_MAGENTA
- boldr = _colorize.ANSIColors.BOLD_RED
- reset = _colorize.ANSIColors.RESET
- expected = "".join([
- f' File {magenta}"<string>"{reset}, line {magenta}1{reset}\n',
- f' a {boldr}${reset} b\n',
- f' {boldr}^{reset}\n',
- f'{boldm}SyntaxError{reset}: {magenta}invalid syntax{reset}\n']
- )
- self.assertIn(expected, actual)
+ def expected(t, m, fn, l, f, E, e, z):
+ return "".join(
+ [
+ f' File {fn}"<string>"{z}, line {l}1{z}\n',
+ f' a {E}${z} b\n',
+ f' {E}^{z}\n',
+ f'{t}SyntaxError{z}: {m}invalid syntax{z}\n'
+ ]
+ )
+ self.assertIn(expected(**colors), actual)
def test_colorized_traceback_is_the_default(self):
def foo():
@@ -4788,23 +4794,21 @@ class TestColorizedTraceback(unittest.TestCase):
exception_print(e)
actual = tbstderr.getvalue().splitlines()
- red = _colorize.ANSIColors.RED
- boldr = _colorize.ANSIColors.BOLD_RED
- magenta = _colorize.ANSIColors.MAGENTA
- boldm = _colorize.ANSIColors.BOLD_MAGENTA
- reset = _colorize.ANSIColors.RESET
lno_foo = foo.__code__.co_firstlineno
- expected = ['Traceback (most recent call last):',
- f' File {magenta}"{__file__}"{reset}, '
- f'line {magenta}{lno_foo+5}{reset}, in {magenta}test_colorized_traceback_is_the_default{reset}',
- f' {red}foo{reset+boldr}(){reset}',
- f' {red}~~~{reset+boldr}^^{reset}',
- f' File {magenta}"{__file__}"{reset}, '
- f'line {magenta}{lno_foo+1}{reset}, in {magenta}foo{reset}',
- f' {red}1{reset+boldr}/{reset+red}0{reset}',
- f' {red}~{reset+boldr}^{reset+red}~{reset}',
- f'{boldm}ZeroDivisionError{reset}: {magenta}division by zero{reset}']
- self.assertEqual(actual, expected)
+ def expected(t, m, fn, l, f, E, e, z):
+ return [
+ 'Traceback (most recent call last):',
+ f' File {fn}"{__file__}"{z}, '
+ f'line {l}{lno_foo+5}{z}, in {f}test_colorized_traceback_is_the_default{z}',
+ f' {e}foo{z}{E}(){z}',
+ f' {e}~~~{z}{E}^^{z}',
+ f' File {fn}"{__file__}"{z}, '
+ f'line {l}{lno_foo+1}{z}, in {f}foo{z}',
+ f' {e}1{z}{E}/{z}{e}0{z}',
+ f' {e}~{z}{E}^{z}{e}~{z}',
+ f'{t}ZeroDivisionError{z}: {m}division by zero{z}',
+ ]
+ self.assertEqual(actual, expected(**colors))
def test_colorized_traceback_from_exception_group(self):
def foo():
@@ -4822,33 +4826,31 @@ class TestColorizedTraceback(unittest.TestCase):
e, capture_locals=True
)
- red = _colorize.ANSIColors.RED
- boldr = _colorize.ANSIColors.BOLD_RED
- magenta = _colorize.ANSIColors.MAGENTA
- boldm = _colorize.ANSIColors.BOLD_MAGENTA
- reset = _colorize.ANSIColors.RESET
lno_foo = foo.__code__.co_firstlineno
actual = "".join(exc.format(colorize=True)).splitlines()
- expected = [f" + Exception Group Traceback (most recent call last):",
- f' | File {magenta}"{__file__}"{reset}, line {magenta}{lno_foo+9}{reset}, in {magenta}test_colorized_traceback_from_exception_group{reset}',
- f' | {red}foo{reset}{boldr}(){reset}',
- f' | {red}~~~{reset}{boldr}^^{reset}',
- f" | e = ExceptionGroup('test', [ZeroDivisionError('division by zero')])",
- f" | foo = {foo}",
- f' | self = <{__name__}.TestColorizedTraceback testMethod=test_colorized_traceback_from_exception_group>',
- f' | File {magenta}"{__file__}"{reset}, line {magenta}{lno_foo+6}{reset}, in {magenta}foo{reset}',
- f' | raise ExceptionGroup("test", exceptions)',
- f" | exceptions = [ZeroDivisionError('division by zero')]",
- f' | {boldm}ExceptionGroup{reset}: {magenta}test (1 sub-exception){reset}',
- f' +-+---------------- 1 ----------------',
- f' | Traceback (most recent call last):',
- f' | File {magenta}"{__file__}"{reset}, line {magenta}{lno_foo+3}{reset}, in {magenta}foo{reset}',
- f' | {red}1 {reset}{boldr}/{reset}{red} 0{reset}',
- f' | {red}~~{reset}{boldr}^{reset}{red}~~{reset}',
- f" | exceptions = [ZeroDivisionError('division by zero')]",
- f' | {boldm}ZeroDivisionError{reset}: {magenta}division by zero{reset}',
- f' +------------------------------------']
- self.assertEqual(actual, expected)
+ def expected(t, m, fn, l, f, E, e, z):
+ return [
+ f" + Exception Group Traceback (most recent call last):",
+ f' | File {fn}"{__file__}"{z}, line {l}{lno_foo+9}{z}, in {f}test_colorized_traceback_from_exception_group{z}',
+ f' | {e}foo{z}{E}(){z}',
+ f' | {e}~~~{z}{E}^^{z}',
+ f" | e = ExceptionGroup('test', [ZeroDivisionError('division by zero')])",
+ f" | foo = {foo}",
+ f' | self = <{__name__}.TestColorizedTraceback testMethod=test_colorized_traceback_from_exception_group>',
+ f' | File {fn}"{__file__}"{z}, line {l}{lno_foo+6}{z}, in {f}foo{z}',
+ f' | raise ExceptionGroup("test", exceptions)',
+ f" | exceptions = [ZeroDivisionError('division by zero')]",
+ f' | {t}ExceptionGroup{z}: {m}test (1 sub-exception){z}',
+ f' +-+---------------- 1 ----------------',
+ f' | Traceback (most recent call last):',
+ f' | File {fn}"{__file__}"{z}, line {l}{lno_foo+3}{z}, in {f}foo{z}',
+ f' | {e}1 {z}{E}/{z}{e} 0{z}',
+ f' | {e}~~{z}{E}^{z}{e}~~{z}',
+ f" | exceptions = [ZeroDivisionError('division by zero')]",
+ f' | {t}ZeroDivisionError{z}: {m}division by zero{z}',
+ f' +------------------------------------',
+ ]
+ self.assertEqual(actual, expected(**colors))
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_urllib.py b/Lib/test/test_urllib.py
index 90de828cc71..c965860fbb1 100644
--- a/Lib/test/test_urllib.py
+++ b/Lib/test/test_urllib.py
@@ -1551,7 +1551,8 @@ class Pathname_Tests(unittest.TestCase):
urllib.request.url2pathname(url, require_scheme=True),
expected_path)
- error_subtests = [
+ def test_url2pathname_require_scheme_errors(self):
+ subtests = [
'',
':',
'foo',
@@ -1561,13 +1562,20 @@ class Pathname_Tests(unittest.TestCase):
'data:file:foo',
'data:file://foo',
]
- for url in error_subtests:
+ for url in subtests:
with self.subTest(url=url):
self.assertRaises(
urllib.error.URLError,
urllib.request.url2pathname,
url, require_scheme=True)
+ def test_url2pathname_resolve_host(self):
+ fn = urllib.request.url2pathname
+ sep = os.path.sep
+ self.assertEqual(fn('//127.0.0.1/foo/bar', resolve_host=True), f'{sep}foo{sep}bar')
+ self.assertEqual(fn(f'//{socket.gethostname()}/foo/bar'), f'{sep}foo{sep}bar')
+ self.assertEqual(fn(f'//{socket.gethostname()}/foo/bar', resolve_host=True), f'{sep}foo{sep}bar')
+
@unittest.skipUnless(sys.platform == 'win32',
'test specific to Windows pathnames.')
def test_url2pathname_win(self):
@@ -1598,6 +1606,7 @@ class Pathname_Tests(unittest.TestCase):
self.assertEqual(fn('//server/path/to/file'), '\\\\server\\path\\to\\file')
self.assertEqual(fn('////server/path/to/file'), '\\\\server\\path\\to\\file')
self.assertEqual(fn('/////server/path/to/file'), '\\\\server\\path\\to\\file')
+ self.assertEqual(fn('//127.0.0.1/path/to/file'), '\\\\127.0.0.1\\path\\to\\file')
# Localhost paths
self.assertEqual(fn('//localhost/C:/path/to/file'), 'C:\\path\\to\\file')
self.assertEqual(fn('//localhost/C|/path/to/file'), 'C:\\path\\to\\file')
@@ -1622,8 +1631,7 @@ class Pathname_Tests(unittest.TestCase):
self.assertRaises(urllib.error.URLError, fn, '//:80/foo/bar')
self.assertRaises(urllib.error.URLError, fn, '//:/foo/bar')
self.assertRaises(urllib.error.URLError, fn, '//c:80/foo/bar')
- self.assertEqual(fn('//127.0.0.1/foo/bar'), '/foo/bar')
- self.assertEqual(fn(f'//{socket.gethostname()}/foo/bar'), '/foo/bar')
+ self.assertRaises(urllib.error.URLError, fn, '//127.0.0.1/foo/bar')
@unittest.skipUnless(os_helper.FS_NONASCII, 'need os_helper.FS_NONASCII')
def test_url2pathname_nonascii(self):
diff --git a/Lib/test/test_zipfile/test_core.py b/Lib/test/test_zipfile/test_core.py
index 7c8a82d821a..ae898150658 100644
--- a/Lib/test/test_zipfile/test_core.py
+++ b/Lib/test/test_zipfile/test_core.py
@@ -23,11 +23,13 @@ from test import archiver_tests
from test.support import script_helper, os_helper
from test.support import (
findfile, requires_zlib, requires_bz2, requires_lzma,
- captured_stdout, captured_stderr, requires_subprocess,
+ requires_zstd, captured_stdout, captured_stderr, requires_subprocess,
+ cpython_only
)
from test.support.os_helper import (
TESTFN, unlink, rmtree, temp_dir, temp_cwd, fd_count, FakePath
)
+from test.support.import_helper import ensure_lazy_imports
TESTFN2 = TESTFN + "2"
@@ -49,6 +51,13 @@ def get_files(test):
yield f
test.assertFalse(f.closed)
+
+class LazyImportTest(unittest.TestCase):
+ @cpython_only
+ def test_lazy_import(self):
+ ensure_lazy_imports("zipfile", {"typing"})
+
+
class AbstractTestsWithSourceFile:
@classmethod
def setUpClass(cls):
@@ -693,6 +702,10 @@ class LzmaTestsWithSourceFile(AbstractTestsWithSourceFile,
unittest.TestCase):
compression = zipfile.ZIP_LZMA
+@requires_zstd()
+class ZstdTestsWithSourceFile(AbstractTestsWithSourceFile,
+ unittest.TestCase):
+ compression = zipfile.ZIP_ZSTANDARD
class AbstractTestZip64InSmallFiles:
# These tests test the ZIP64 functionality without using large files,
@@ -1270,6 +1283,10 @@ class LzmaTestZip64InSmallFiles(AbstractTestZip64InSmallFiles,
unittest.TestCase):
compression = zipfile.ZIP_LZMA
+@requires_zstd()
+class ZstdTestZip64InSmallFiles(AbstractTestZip64InSmallFiles,
+ unittest.TestCase):
+ compression = zipfile.ZIP_ZSTANDARD
class AbstractWriterTests:
@@ -1339,6 +1356,9 @@ class Bzip2WriterTests(AbstractWriterTests, unittest.TestCase):
class LzmaWriterTests(AbstractWriterTests, unittest.TestCase):
compression = zipfile.ZIP_LZMA
+@requires_zstd()
+class ZstdWriterTests(AbstractWriterTests, unittest.TestCase):
+ compression = zipfile.ZIP_ZSTANDARD
class PyZipFileTests(unittest.TestCase):
def assertCompiledIn(self, name, namelist):
@@ -2669,6 +2689,17 @@ class LzmaBadCrcTests(AbstractBadCrcTests, unittest.TestCase):
b'ePK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00'
b'\x00>\x00\x00\x00\x00\x00')
+@requires_zstd()
+class ZstdBadCrcTests(AbstractBadCrcTests, unittest.TestCase):
+ compression = zipfile.ZIP_ZSTANDARD
+ zip_with_bad_crc = (
+ b'PK\x03\x04?\x00\x00\x00]\x00\x00\x00!\x00V\xb1\x17J\x14\x00'
+ b'\x00\x00\x0b\x00\x00\x00\x05\x00\x00\x00afile(\xb5/\xfd\x00'
+ b'XY\x00\x00Hello WorldPK\x01\x02?\x03?\x00\x00\x00]\x00\x00\x00'
+ b'!\x00V\xb0\x17J\x14\x00\x00\x00\x0b\x00\x00\x00\x05\x00\x00\x00'
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00\x00afilePK'
+ b'\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00\x007\x00\x00\x00'
+ b'\x00\x00')
class DecryptionTests(unittest.TestCase):
"""Check that ZIP decryption works. Since the library does not
@@ -2896,6 +2927,10 @@ class LzmaTestsWithRandomBinaryFiles(AbstractTestsWithRandomBinaryFiles,
unittest.TestCase):
compression = zipfile.ZIP_LZMA
+@requires_zstd()
+class ZstdTestsWithRandomBinaryFiles(AbstractTestsWithRandomBinaryFiles,
+ unittest.TestCase):
+ compression = zipfile.ZIP_ZSTANDARD
# Provide the tell() method but not seek()
class Tellable:
diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py
new file mode 100644
index 00000000000..f4a25376e52
--- /dev/null
+++ b/Lib/test/test_zstd.py
@@ -0,0 +1,2507 @@
+import array
+import gc
+import io
+import pathlib
+import random
+import re
+import os
+import unittest
+import tempfile
+import threading
+
+from test.support.import_helper import import_module
+from test.support import threading_helper
+from test.support import _1M
+from test.support import Py_GIL_DISABLED
+
+_zstd = import_module("_zstd")
+zstd = import_module("compression.zstd")
+
+from compression.zstd import (
+ open,
+ compress,
+ decompress,
+ ZstdCompressor,
+ ZstdDecompressor,
+ ZstdDict,
+ ZstdError,
+ zstd_version,
+ zstd_version_info,
+ COMPRESSION_LEVEL_DEFAULT,
+ get_frame_info,
+ get_frame_size,
+ finalize_dict,
+ train_dict,
+ CompressionParameter,
+ DecompressionParameter,
+ Strategy,
+ ZstdFile,
+)
+
+_1K = 1024
+_130_1K = 130 * _1K
+DICT_SIZE1 = 3*_1K
+
+DAT_130K_D = None
+DAT_130K_C = None
+
+DECOMPRESSED_DAT = None
+COMPRESSED_DAT = None
+
+DECOMPRESSED_100_PLUS_32KB = None
+COMPRESSED_100_PLUS_32KB = None
+
+SKIPPABLE_FRAME = None
+
+THIS_FILE_BYTES = None
+THIS_FILE_STR = None
+COMPRESSED_THIS_FILE = None
+
+COMPRESSED_BOGUS = None
+
+SAMPLES = None
+
+TRAINED_DICT = None
+
+SUPPORT_MULTITHREADING = False
+
+def setUpModule():
+ global SUPPORT_MULTITHREADING
+ SUPPORT_MULTITHREADING = CompressionParameter.nb_workers.bounds() != (0, 0)
+ # uncompressed size 130KB, more than a zstd block.
+ # with a frame epilogue, 4 bytes checksum.
+ global DAT_130K_D
+ DAT_130K_D = bytes([random.randint(0, 127) for _ in range(130*_1K)])
+
+ global DAT_130K_C
+ DAT_130K_C = compress(DAT_130K_D, options={CompressionParameter.checksum_flag:1})
+
+ global DECOMPRESSED_DAT
+ DECOMPRESSED_DAT = b'abcdefg123456' * 1000
+
+ global COMPRESSED_DAT
+ COMPRESSED_DAT = compress(DECOMPRESSED_DAT)
+
+ global DECOMPRESSED_100_PLUS_32KB
+ DECOMPRESSED_100_PLUS_32KB = b'a' * (100 + 32*_1K)
+
+ global COMPRESSED_100_PLUS_32KB
+ COMPRESSED_100_PLUS_32KB = compress(DECOMPRESSED_100_PLUS_32KB)
+
+ global SKIPPABLE_FRAME
+ SKIPPABLE_FRAME = (0x184D2A50).to_bytes(4, byteorder='little') + \
+ (32*_1K).to_bytes(4, byteorder='little') + \
+ b'a' * (32*_1K)
+
+ global THIS_FILE_BYTES, THIS_FILE_STR
+ with io.open(os.path.abspath(__file__), 'rb') as f:
+ THIS_FILE_BYTES = f.read()
+ THIS_FILE_BYTES = re.sub(rb'\r?\n', rb'\n', THIS_FILE_BYTES)
+ THIS_FILE_STR = THIS_FILE_BYTES.decode('utf-8')
+
+ global COMPRESSED_THIS_FILE
+ COMPRESSED_THIS_FILE = compress(THIS_FILE_BYTES)
+
+ global COMPRESSED_BOGUS
+ COMPRESSED_BOGUS = DECOMPRESSED_DAT
+
+ # dict data
+ words = [b'red', b'green', b'yellow', b'black', b'withe', b'blue',
+ b'lilac', b'purple', b'navy', b'glod', b'silver', b'olive',
+ b'dog', b'cat', b'tiger', b'lion', b'fish', b'bird']
+ lst = []
+ for i in range(300):
+ sample = [b'%s = %d' % (random.choice(words), random.randrange(100))
+ for j in range(20)]
+ sample = b'\n'.join(sample)
+
+ lst.append(sample)
+ global SAMPLES
+ SAMPLES = lst
+ assert len(SAMPLES) > 10
+
+ global TRAINED_DICT
+ TRAINED_DICT = train_dict(SAMPLES, 3*_1K)
+ assert len(TRAINED_DICT.dict_content) <= 3*_1K
+
+
+class FunctionsTestCase(unittest.TestCase):
+
+ def test_version(self):
+ s = ".".join((str(i) for i in zstd_version_info))
+ self.assertEqual(s, zstd_version)
+
+ def test_compressionLevel_values(self):
+ min, max = CompressionParameter.compression_level.bounds()
+ self.assertIs(type(COMPRESSION_LEVEL_DEFAULT), int)
+ self.assertIs(type(min), int)
+ self.assertIs(type(max), int)
+ self.assertLess(min, max)
+
+ def test_roundtrip_default(self):
+ raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6]
+ dat1 = compress(raw_dat)
+ dat2 = decompress(dat1)
+ self.assertEqual(dat2, raw_dat)
+
+ def test_roundtrip_level(self):
+ raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6]
+ level_min, level_max = CompressionParameter.compression_level.bounds()
+
+ for level in range(max(-20, level_min), level_max + 1):
+ dat1 = compress(raw_dat, level)
+ dat2 = decompress(dat1)
+ self.assertEqual(dat2, raw_dat)
+
+ def test_get_frame_info(self):
+ # no dict
+ info = get_frame_info(COMPRESSED_100_PLUS_32KB[:20])
+ self.assertEqual(info.decompressed_size, 32 * _1K + 100)
+ self.assertEqual(info.dictionary_id, 0)
+
+ # use dict
+ dat = compress(b"a" * 345, zstd_dict=TRAINED_DICT)
+ info = get_frame_info(dat)
+ self.assertEqual(info.decompressed_size, 345)
+ self.assertEqual(info.dictionary_id, TRAINED_DICT.dict_id)
+
+ with self.assertRaisesRegex(ZstdError, "not less than the frame header"):
+ get_frame_info(b"aaaaaaaaaaaaaa")
+
+ def test_get_frame_size(self):
+ size = get_frame_size(COMPRESSED_100_PLUS_32KB)
+ self.assertEqual(size, len(COMPRESSED_100_PLUS_32KB))
+
+ with self.assertRaisesRegex(ZstdError, "not less than this complete frame"):
+ get_frame_size(b"aaaaaaaaaaaaaa")
+
+ def test_decompress_2x130_1K(self):
+ decompressed_size = get_frame_info(DAT_130K_C).decompressed_size
+ self.assertEqual(decompressed_size, _130_1K)
+
+ dat = decompress(DAT_130K_C + DAT_130K_C)
+ self.assertEqual(len(dat), 2 * _130_1K)
+
+
+class CompressorTestCase(unittest.TestCase):
+
+ def test_simple_compress_bad_args(self):
+ # ZstdCompressor
+ self.assertRaises(TypeError, ZstdCompressor, [])
+ self.assertRaises(TypeError, ZstdCompressor, level=3.14)
+ self.assertRaises(TypeError, ZstdCompressor, level="abc")
+ self.assertRaises(TypeError, ZstdCompressor, options=b"abc")
+
+ self.assertRaises(TypeError, ZstdCompressor, zstd_dict=123)
+ self.assertRaises(TypeError, ZstdCompressor, zstd_dict=b"abcd1234")
+ self.assertRaises(TypeError, ZstdCompressor, zstd_dict={1: 2, 3: 4})
+
+ with self.assertRaises(ValueError):
+ ZstdCompressor(2**31)
+ with self.assertRaises(ValueError):
+ ZstdCompressor(options={2**31: 100})
+
+ with self.assertRaises(ZstdError):
+ ZstdCompressor(options={CompressionParameter.window_log: 100})
+ with self.assertRaises(ZstdError):
+ ZstdCompressor(options={3333: 100})
+
+ # Method bad arguments
+ zc = ZstdCompressor()
+ self.assertRaises(TypeError, zc.compress)
+ self.assertRaises((TypeError, ValueError), zc.compress, b"foo", b"bar")
+ self.assertRaises(TypeError, zc.compress, "str")
+ self.assertRaises((TypeError, ValueError), zc.flush, b"foo")
+ self.assertRaises(TypeError, zc.flush, b"blah", 1)
+
+ self.assertRaises(ValueError, zc.compress, b'', -1)
+ self.assertRaises(ValueError, zc.compress, b'', 3)
+ self.assertRaises(ValueError, zc.flush, zc.CONTINUE) # 0
+ self.assertRaises(ValueError, zc.flush, 3)
+
+ zc.compress(b'')
+ zc.compress(b'', zc.CONTINUE)
+ zc.compress(b'', zc.FLUSH_BLOCK)
+ zc.compress(b'', zc.FLUSH_FRAME)
+ empty = zc.flush()
+ zc.flush(zc.FLUSH_BLOCK)
+ zc.flush(zc.FLUSH_FRAME)
+
+ def test_compress_parameters(self):
+ d = {CompressionParameter.compression_level : 10,
+
+ CompressionParameter.window_log : 12,
+ CompressionParameter.hash_log : 10,
+ CompressionParameter.chain_log : 12,
+ CompressionParameter.search_log : 12,
+ CompressionParameter.min_match : 4,
+ CompressionParameter.target_length : 12,
+ CompressionParameter.strategy : Strategy.lazy,
+
+ CompressionParameter.enable_long_distance_matching : 1,
+ CompressionParameter.ldm_hash_log : 12,
+ CompressionParameter.ldm_min_match : 11,
+ CompressionParameter.ldm_bucket_size_log : 5,
+ CompressionParameter.ldm_hash_rate_log : 12,
+
+ CompressionParameter.content_size_flag : 1,
+ CompressionParameter.checksum_flag : 1,
+ CompressionParameter.dict_id_flag : 0,
+
+ CompressionParameter.nb_workers : 2 if SUPPORT_MULTITHREADING else 0,
+ CompressionParameter.job_size : 5*_1M if SUPPORT_MULTITHREADING else 0,
+ CompressionParameter.overlap_log : 9 if SUPPORT_MULTITHREADING else 0,
+ }
+ ZstdCompressor(options=d)
+
+ # larger than signed int, ValueError
+ d1 = d.copy()
+ d1[CompressionParameter.ldm_bucket_size_log] = 2**31
+ self.assertRaises(ValueError, ZstdCompressor, options=d1)
+
+ # clamp compressionLevel
+ level_min, level_max = CompressionParameter.compression_level.bounds()
+ compress(b'', level_max+1)
+ compress(b'', level_min-1)
+
+ compress(b'', options={CompressionParameter.compression_level:level_max+1})
+ compress(b'', options={CompressionParameter.compression_level:level_min-1})
+
+ # zstd lib doesn't support MT compression
+ if not SUPPORT_MULTITHREADING:
+ with self.assertRaises(ZstdError):
+ ZstdCompressor(options={CompressionParameter.nb_workers:4})
+ with self.assertRaises(ZstdError):
+ ZstdCompressor(options={CompressionParameter.job_size:4})
+ with self.assertRaises(ZstdError):
+ ZstdCompressor(options={CompressionParameter.overlap_log:4})
+
+ # out of bounds error msg
+ option = {CompressionParameter.window_log:100}
+ with self.assertRaisesRegex(ZstdError,
+ (r'Error when setting zstd compression parameter "window_log", '
+ r'it should \d+ <= value <= \d+, provided value is 100\. '
+ r'\(zstd v\d\.\d\.\d, (?:32|64)-bit build\)')):
+ compress(b'', options=option)
+
+ def test_unknown_compression_parameter(self):
+ KEY = 100001234
+ option = {CompressionParameter.compression_level: 10,
+ KEY: 200000000}
+ pattern = r'Zstd compression parameter.*?"unknown parameter \(key %d\)"' \
+ % KEY
+ with self.assertRaisesRegex(ZstdError, pattern):
+ ZstdCompressor(options=option)
+
+ @unittest.skipIf(not SUPPORT_MULTITHREADING,
+ "zstd build doesn't support multi-threaded compression")
+ def test_zstd_multithread_compress(self):
+ size = 40*_1M
+ b = THIS_FILE_BYTES * (size // len(THIS_FILE_BYTES))
+
+ options = {CompressionParameter.compression_level : 4,
+ CompressionParameter.nb_workers : 2}
+
+ # compress()
+ dat1 = compress(b, options=options)
+ dat2 = decompress(dat1)
+ self.assertEqual(dat2, b)
+
+ # ZstdCompressor
+ c = ZstdCompressor(options=options)
+ dat1 = c.compress(b, c.CONTINUE)
+ dat2 = c.compress(b, c.FLUSH_BLOCK)
+ dat3 = c.compress(b, c.FLUSH_FRAME)
+ dat4 = decompress(dat1+dat2+dat3)
+ self.assertEqual(dat4, b * 3)
+
+ # ZstdFile
+ with ZstdFile(io.BytesIO(), 'w', options=options) as f:
+ f.write(b)
+
+ def test_compress_flushblock(self):
+ point = len(THIS_FILE_BYTES) // 2
+
+ c = ZstdCompressor()
+ self.assertEqual(c.last_mode, c.FLUSH_FRAME)
+ dat1 = c.compress(THIS_FILE_BYTES[:point])
+ self.assertEqual(c.last_mode, c.CONTINUE)
+ dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_BLOCK)
+ self.assertEqual(c.last_mode, c.FLUSH_BLOCK)
+ dat2 = c.flush()
+ pattern = "Compressed data ended before the end-of-stream marker"
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(dat1)
+
+ dat3 = decompress(dat1 + dat2)
+
+ self.assertEqual(dat3, THIS_FILE_BYTES)
+
+ def test_compress_flushframe(self):
+ # test compress & decompress
+ point = len(THIS_FILE_BYTES) // 2
+
+ c = ZstdCompressor()
+
+ dat1 = c.compress(THIS_FILE_BYTES[:point])
+ self.assertEqual(c.last_mode, c.CONTINUE)
+
+ dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_FRAME)
+ self.assertEqual(c.last_mode, c.FLUSH_FRAME)
+
+ nt = get_frame_info(dat1)
+ self.assertEqual(nt.decompressed_size, None) # no content size
+
+ dat2 = decompress(dat1)
+
+ self.assertEqual(dat2, THIS_FILE_BYTES)
+
+ # single .FLUSH_FRAME mode has content size
+ c = ZstdCompressor()
+ dat = c.compress(THIS_FILE_BYTES, mode=c.FLUSH_FRAME)
+ self.assertEqual(c.last_mode, c.FLUSH_FRAME)
+
+ nt = get_frame_info(dat)
+ self.assertEqual(nt.decompressed_size, len(THIS_FILE_BYTES))
+
+ def test_compress_empty(self):
+ # output empty content frame
+ self.assertNotEqual(compress(b''), b'')
+
+ c = ZstdCompressor()
+ self.assertNotEqual(c.compress(b'', c.FLUSH_FRAME), b'')
+
+class DecompressorTestCase(unittest.TestCase):
+
+ def test_simple_decompress_bad_args(self):
+ # ZstdDecompressor
+ self.assertRaises(TypeError, ZstdDecompressor, ())
+ self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=123)
+ self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=b'abc')
+ self.assertRaises(TypeError, ZstdDecompressor, zstd_dict={1:2, 3:4})
+
+ self.assertRaises(TypeError, ZstdDecompressor, options=123)
+ self.assertRaises(TypeError, ZstdDecompressor, options='abc')
+ self.assertRaises(TypeError, ZstdDecompressor, options=b'abc')
+
+ with self.assertRaises(ValueError):
+ ZstdDecompressor(options={2**31 : 100})
+
+ with self.assertRaises(ZstdError):
+ ZstdDecompressor(options={DecompressionParameter.window_log_max:100})
+ with self.assertRaises(ZstdError):
+ ZstdDecompressor(options={3333 : 100})
+
+ empty = compress(b'')
+ lzd = ZstdDecompressor()
+ self.assertRaises(TypeError, lzd.decompress)
+ self.assertRaises(TypeError, lzd.decompress, b"foo", b"bar")
+ self.assertRaises(TypeError, lzd.decompress, "str")
+ lzd.decompress(empty)
+
+ def test_decompress_parameters(self):
+ d = {DecompressionParameter.window_log_max : 15}
+ ZstdDecompressor(options=d)
+
+ # larger than signed int, ValueError
+ d1 = d.copy()
+ d1[DecompressionParameter.window_log_max] = 2**31
+ self.assertRaises(ValueError, ZstdDecompressor, None, d1)
+
+ # out of bounds error msg
+ options = {DecompressionParameter.window_log_max:100}
+ with self.assertRaisesRegex(ZstdError,
+ (r'Error when setting zstd decompression parameter "window_log_max", '
+ r'it should \d+ <= value <= \d+, provided value is 100\. '
+ r'\(zstd v\d\.\d\.\d, (?:32|64)-bit build\)')):
+ decompress(b'', options=options)
+
+ def test_unknown_decompression_parameter(self):
+ KEY = 100001234
+ options = {DecompressionParameter.window_log_max: DecompressionParameter.window_log_max.bounds()[1],
+ KEY: 200000000}
+ pattern = r'Zstd decompression parameter.*?"unknown parameter \(key %d\)"' \
+ % KEY
+ with self.assertRaisesRegex(ZstdError, pattern):
+ ZstdDecompressor(options=options)
+
+ def test_decompress_epilogue_flags(self):
+ # DAT_130K_C has a 4 bytes checksum at frame epilogue
+
+ # full unlimited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C)
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.needs_input)
+
+ with self.assertRaises(EOFError):
+ dat = d.decompress(b'')
+
+ # full limited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C, _130_1K)
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.needs_input)
+
+ with self.assertRaises(EOFError):
+ dat = d.decompress(b'', 0)
+
+ # [:-4] unlimited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-4])
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.needs_input)
+
+ dat = d.decompress(b'')
+ self.assertEqual(len(dat), 0)
+ self.assertTrue(d.needs_input)
+
+ # [:-4] limited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-4], _130_1K)
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.needs_input)
+
+ dat = d.decompress(b'', 0)
+ self.assertEqual(len(dat), 0)
+ self.assertFalse(d.needs_input)
+
+ # [:-3] unlimited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-3])
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.needs_input)
+
+ dat = d.decompress(b'')
+ self.assertEqual(len(dat), 0)
+ self.assertTrue(d.needs_input)
+
+ # [:-3] limited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-3], _130_1K)
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.needs_input)
+
+ dat = d.decompress(b'', 0)
+ self.assertEqual(len(dat), 0)
+ self.assertFalse(d.needs_input)
+
+ # [:-1] unlimited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-1])
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.needs_input)
+
+ dat = d.decompress(b'')
+ self.assertEqual(len(dat), 0)
+ self.assertTrue(d.needs_input)
+
+ # [:-1] limited
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-1], _130_1K)
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.needs_input)
+
+ dat = d.decompress(b'', 0)
+ self.assertEqual(len(dat), 0)
+ self.assertFalse(d.needs_input)
+
+ def test_decompressor_arg(self):
+ zd = ZstdDict(b'12345678', True)
+
+ with self.assertRaises(TypeError):
+ d = ZstdDecompressor(zstd_dict={})
+
+ with self.assertRaises(TypeError):
+ d = ZstdDecompressor(options=zd)
+
+ ZstdDecompressor()
+ ZstdDecompressor(zd, {})
+ ZstdDecompressor(zstd_dict=zd, options={DecompressionParameter.window_log_max:25})
+
+ def test_decompressor_1(self):
+ # empty
+ d = ZstdDecompressor()
+ dat = d.decompress(b'')
+
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+
+ # 130_1K full
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C)
+
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+
+ # 130_1K full, limit output
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C, _130_1K)
+
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+
+ # 130_1K, without 4 bytes checksum
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-4])
+
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+
+ # above, limit output
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C[:-4], _130_1K)
+
+ self.assertEqual(len(dat), _130_1K)
+ self.assertFalse(d.eof)
+ self.assertFalse(d.needs_input)
+
+ # full, unused_data
+ TRAIL = b'89234893abcd'
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT_130K_C + TRAIL, _130_1K)
+
+ self.assertEqual(len(dat), _130_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, TRAIL)
+
+ def test_decompressor_chunks_read_300(self):
+ TRAIL = b'89234893abcd'
+ DAT = DAT_130K_C + TRAIL
+ d = ZstdDecompressor()
+
+ bi = io.BytesIO(DAT)
+ lst = []
+ while True:
+ if d.needs_input:
+ dat = bi.read(300)
+ if not dat:
+ break
+ else:
+ raise Exception('should not get here')
+
+ ret = d.decompress(dat)
+ lst.append(ret)
+ if d.eof:
+ break
+
+ ret = b''.join(lst)
+
+ self.assertEqual(len(ret), _130_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data + bi.read(), TRAIL)
+
+ def test_decompressor_chunks_read_3(self):
+ TRAIL = b'89234893'
+ DAT = DAT_130K_C + TRAIL
+ d = ZstdDecompressor()
+
+ bi = io.BytesIO(DAT)
+ lst = []
+ while True:
+ if d.needs_input:
+ dat = bi.read(3)
+ if not dat:
+ break
+ else:
+ dat = b''
+
+ ret = d.decompress(dat, 1)
+ lst.append(ret)
+ if d.eof:
+ break
+
+ ret = b''.join(lst)
+
+ self.assertEqual(len(ret), _130_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data + bi.read(), TRAIL)
+
+
+ def test_decompress_empty(self):
+ with self.assertRaises(ZstdError):
+ decompress(b'')
+
+ d = ZstdDecompressor()
+ self.assertEqual(d.decompress(b''), b'')
+ self.assertFalse(d.eof)
+
+ def test_decompress_empty_content_frame(self):
+ DAT = compress(b'')
+ # decompress
+ self.assertGreaterEqual(len(DAT), 4)
+ self.assertEqual(decompress(DAT), b'')
+
+ with self.assertRaises(ZstdError):
+ decompress(DAT[:-1])
+
+ # ZstdDecompressor
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT)
+ self.assertEqual(dat, b'')
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ d = ZstdDecompressor()
+ dat = d.decompress(DAT[:-1])
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+class DecompressorFlagsTestCase(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ options = {CompressionParameter.checksum_flag:1}
+ c = ZstdCompressor(options=options)
+
+ cls.DECOMPRESSED_42 = b'a'*42
+ cls.FRAME_42 = c.compress(cls.DECOMPRESSED_42, c.FLUSH_FRAME)
+
+ cls.DECOMPRESSED_60 = b'a'*60
+ cls.FRAME_60 = c.compress(cls.DECOMPRESSED_60, c.FLUSH_FRAME)
+
+ cls.FRAME_42_60 = cls.FRAME_42 + cls.FRAME_60
+ cls.DECOMPRESSED_42_60 = cls.DECOMPRESSED_42 + cls.DECOMPRESSED_60
+
+ cls._130_1K = 130*_1K
+
+ c = ZstdCompressor()
+ cls.UNKNOWN_FRAME_42 = c.compress(cls.DECOMPRESSED_42) + c.flush()
+ cls.UNKNOWN_FRAME_60 = c.compress(cls.DECOMPRESSED_60) + c.flush()
+ cls.UNKNOWN_FRAME_42_60 = cls.UNKNOWN_FRAME_42 + cls.UNKNOWN_FRAME_60
+
+ cls.TRAIL = b'12345678abcdefg!@#$%^&*()_+|'
+
+ def test_function_decompress(self):
+
+ self.assertEqual(len(decompress(COMPRESSED_100_PLUS_32KB)), 100+32*_1K)
+
+ # 1 frame
+ self.assertEqual(decompress(self.FRAME_42), self.DECOMPRESSED_42)
+
+ self.assertEqual(decompress(self.UNKNOWN_FRAME_42), self.DECOMPRESSED_42)
+
+ pattern = r"Compressed data ended before the end-of-stream marker"
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(self.FRAME_42[:1])
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(self.FRAME_42[:-4])
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(self.FRAME_42[:-1])
+
+ # 2 frames
+ self.assertEqual(decompress(self.FRAME_42_60), self.DECOMPRESSED_42_60)
+
+ self.assertEqual(decompress(self.UNKNOWN_FRAME_42_60), self.DECOMPRESSED_42_60)
+
+ self.assertEqual(decompress(self.FRAME_42 + self.UNKNOWN_FRAME_60),
+ self.DECOMPRESSED_42_60)
+
+ self.assertEqual(decompress(self.UNKNOWN_FRAME_42 + self.FRAME_60),
+ self.DECOMPRESSED_42_60)
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(self.FRAME_42_60[:-4])
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(self.UNKNOWN_FRAME_42_60[:-1])
+
+ # 130_1K
+ self.assertEqual(decompress(DAT_130K_C), DAT_130K_D)
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(DAT_130K_C[:-4])
+
+ with self.assertRaisesRegex(ZstdError, pattern):
+ decompress(DAT_130K_C[:-1])
+
+ # Unknown frame descriptor
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ decompress(b'aaaaaaaaa')
+
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ decompress(self.FRAME_42 + b'aaaaaaaaa')
+
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ decompress(self.UNKNOWN_FRAME_42_60 + b'aaaaaaaaa')
+
+ # doesn't match checksum
+ checksum = DAT_130K_C[-4:]
+ if checksum[0] == 255:
+ wrong_checksum = bytes([254]) + checksum[1:]
+ else:
+ wrong_checksum = bytes([checksum[0]+1]) + checksum[1:]
+
+ dat = DAT_130K_C[:-4] + wrong_checksum
+
+ with self.assertRaisesRegex(ZstdError, "doesn't match checksum"):
+ decompress(dat)
+
    def test_function_skippable(self):
        # decompress() must silently skip zstd skippable frames (they carry
        # no content), wherever they appear relative to normal frames.
        self.assertEqual(decompress(SKIPPABLE_FRAME), b'')
        self.assertEqual(decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME), b'')

        # 1 frame + 2 skippable
        self.assertEqual(len(decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + DAT_130K_C)),
                         self._130_1K)

        self.assertEqual(len(decompress(DAT_130K_C + SKIPPABLE_FRAME + SKIPPABLE_FRAME)),
                         self._130_1K)

        self.assertEqual(len(decompress(SKIPPABLE_FRAME + DAT_130K_C + SKIPPABLE_FRAME)),
                         self._130_1K)

        # unknown size frames combined with a skippable frame
        self.assertEqual(decompress(SKIPPABLE_FRAME + self.UNKNOWN_FRAME_60),
                         self.DECOMPRESSED_60)

        self.assertEqual(decompress(self.UNKNOWN_FRAME_60 + SKIPPABLE_FRAME),
                         self.DECOMPRESSED_60)

        # 2 frames + 1 skippable
        self.assertEqual(decompress(self.FRAME_42 + SKIPPABLE_FRAME + self.FRAME_60),
                         self.DECOMPRESSED_42_60)

        self.assertEqual(decompress(SKIPPABLE_FRAME + self.FRAME_42_60),
                         self.DECOMPRESSED_42_60)

        self.assertEqual(decompress(self.UNKNOWN_FRAME_42_60 + SKIPPABLE_FRAME),
                         self.DECOMPRESSED_42_60)

        # incomplete skippable frames must raise, not be ignored
        with self.assertRaises(ZstdError):
            decompress(SKIPPABLE_FRAME[:1])

        with self.assertRaises(ZstdError):
            decompress(SKIPPABLE_FRAME[:-1])

        with self.assertRaises(ZstdError):
            decompress(self.FRAME_42 + SKIPPABLE_FRAME[:-1])

        # Unknown frame descriptor
        with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
            decompress(b'aaaaaaaaa' + SKIPPABLE_FRAME)

        with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
            decompress(SKIPPABLE_FRAME + b'aaaaaaaaa')

        with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
            decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + b'aaaaaaaaa')
+
+ def test_decompressor_1(self):
+ # empty 1
+ d = ZstdDecompressor()
+
+ dat = d.decompress(b'')
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ dat = d.decompress(b'', 0)
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ dat = d.decompress(COMPRESSED_100_PLUS_32KB + b'a')
+ self.assertEqual(dat, DECOMPRESSED_100_PLUS_32KB)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'a')
+ self.assertEqual(d.unused_data, b'a') # twice
+
+ # empty 2
+ d = ZstdDecompressor()
+
+ dat = d.decompress(b'', 0)
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ dat = d.decompress(b'')
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ dat = d.decompress(COMPRESSED_100_PLUS_32KB + b'a')
+ self.assertEqual(dat, DECOMPRESSED_100_PLUS_32KB)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'a')
+ self.assertEqual(d.unused_data, b'a') # twice
+
+ # 1 frame
+ d = ZstdDecompressor()
+ dat = d.decompress(self.FRAME_42)
+
+ self.assertEqual(dat, self.DECOMPRESSED_42)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ with self.assertRaises(EOFError):
+ d.decompress(b'')
+
+ # 1 frame, trail
+ d = ZstdDecompressor()
+ dat = d.decompress(self.FRAME_42 + self.TRAIL)
+
+ self.assertEqual(dat, self.DECOMPRESSED_42)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, self.TRAIL)
+ self.assertEqual(d.unused_data, self.TRAIL) # twice
+
+ # 1 frame, 32_1K
+ temp = compress(b'a'*(32*_1K))
+ d = ZstdDecompressor()
+ dat = d.decompress(temp, 32*_1K)
+
+ self.assertEqual(dat, b'a'*(32*_1K))
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ with self.assertRaises(EOFError):
+ d.decompress(b'')
+
+ # 1 frame, 32_1K+100, trail
+ d = ZstdDecompressor()
+ dat = d.decompress(COMPRESSED_100_PLUS_32KB+self.TRAIL, 100) # 100 bytes
+
+ self.assertEqual(len(dat), 100)
+ self.assertFalse(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+
+ dat = d.decompress(b'') # 32_1K
+
+ self.assertEqual(len(dat), 32*_1K)
+ self.assertTrue(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, self.TRAIL)
+ self.assertEqual(d.unused_data, self.TRAIL) # twice
+
+ with self.assertRaises(EOFError):
+ d.decompress(b'')
+
+ # incomplete 1
+ d = ZstdDecompressor()
+ dat = d.decompress(self.FRAME_60[:1])
+
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ # incomplete 2
+ d = ZstdDecompressor()
+
+ dat = d.decompress(self.FRAME_60[:-4])
+ self.assertEqual(dat, self.DECOMPRESSED_60)
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ # incomplete 3
+ d = ZstdDecompressor()
+
+ dat = d.decompress(self.FRAME_60[:-1])
+ self.assertEqual(dat, self.DECOMPRESSED_60)
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+
+ # incomplete 4
+ d = ZstdDecompressor()
+
+ dat = d.decompress(self.FRAME_60[:-4], 60)
+ self.assertEqual(dat, self.DECOMPRESSED_60)
+ self.assertFalse(d.eof)
+ self.assertFalse(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ dat = d.decompress(b'')
+ self.assertEqual(dat, b'')
+ self.assertFalse(d.eof)
+ self.assertTrue(d.needs_input)
+ self.assertEqual(d.unused_data, b'')
+ self.assertEqual(d.unused_data, b'') # twice
+
+ # Unknown frame descriptor
+ d = ZstdDecompressor()
+ with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+ d.decompress(b'aaaaaaaaa')
+
    def test_decompressor_skippable(self):
        # ZstdDecompressor treats a skippable frame as a complete frame with
        # empty content: .eof becomes True, nothing is returned.

        # 1 skippable
        d = ZstdDecompressor()
        dat = d.decompress(SKIPPABLE_FRAME)

        self.assertEqual(dat, b'')
        self.assertTrue(d.eof)
        self.assertFalse(d.needs_input)
        self.assertEqual(d.unused_data, b'')
        self.assertEqual(d.unused_data, b'') # twice

        # 1 skippable, max_length=0
        d = ZstdDecompressor()
        dat = d.decompress(SKIPPABLE_FRAME, 0)

        self.assertEqual(dat, b'')
        self.assertTrue(d.eof)
        self.assertFalse(d.needs_input)
        self.assertEqual(d.unused_data, b'')
        self.assertEqual(d.unused_data, b'') # twice

        # 1 skippable, trail: the trail is surfaced as unused_data
        d = ZstdDecompressor()
        dat = d.decompress(SKIPPABLE_FRAME + self.TRAIL)

        self.assertEqual(dat, b'')
        self.assertTrue(d.eof)
        self.assertFalse(d.needs_input)
        self.assertEqual(d.unused_data, self.TRAIL)
        self.assertEqual(d.unused_data, self.TRAIL) # twice

        # incomplete
        d = ZstdDecompressor()
        dat = d.decompress(SKIPPABLE_FRAME[:-1])

        self.assertEqual(dat, b'')
        self.assertFalse(d.eof)
        self.assertTrue(d.needs_input)
        self.assertEqual(d.unused_data, b'')
        self.assertEqual(d.unused_data, b'') # twice

        # incomplete, max_length=0 forces needs_input to False
        d = ZstdDecompressor()
        dat = d.decompress(SKIPPABLE_FRAME[:-1], 0)

        self.assertEqual(dat, b'')
        self.assertFalse(d.eof)
        self.assertFalse(d.needs_input)
        self.assertEqual(d.unused_data, b'')
        self.assertEqual(d.unused_data, b'') # twice

        dat = d.decompress(b'')

        self.assertEqual(dat, b'')
        self.assertFalse(d.eof)
        self.assertTrue(d.needs_input)
        self.assertEqual(d.unused_data, b'')
        self.assertEqual(d.unused_data, b'') # twice
+
+
+
class ZstdDictTestCase(unittest.TestCase):
    """Tests for ZstdDict: raw/trained dictionaries, training/finalizing,
    the as_digested_dict/as_undigested_dict/as_prefix views, and the
    argument validation done by the _zstd C helpers."""

    def test_is_raw(self):
        # content < 8 bytes is rejected outright
        b = b'1234567'
        with self.assertRaises(ValueError):
            ZstdDict(b)

        # content == 8: minimum size for a raw dictionary; raw dicts have id 0
        b = b'12345678'
        zd = ZstdDict(b, is_raw=True)
        self.assertEqual(zd.dict_id, 0)

        temp = compress(b'aaa12345678', level=3, zstd_dict=zd)
        self.assertEqual(b'aaa12345678', decompress(temp, zd))

        # is_raw == False: content without the dict magic is rejected
        b = b'12345678abcd'
        with self.assertRaises(ValueError):
            ZstdDict(b)

        # read only attributes
        with self.assertRaises(AttributeError):
            zd.dict_content = b

        with self.assertRaises(AttributeError):
            zd.dict_id = 10000

        # ZstdDict arguments
        zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=False)
        self.assertNotEqual(zd.dict_id, 0)

        zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=True)
        # note this assertion: even with is_raw=True, trained content still
        # reports a non-zero dict_id here
        self.assertNotEqual(zd.dict_id, 0)

        # str content is not accepted, only bytes-like
        with self.assertRaises(TypeError):
            ZstdDict("12345678abcdef", is_raw=True)
        with self.assertRaises(TypeError):
            ZstdDict(TRAINED_DICT)

        # invalid parameter
        with self.assertRaises(TypeError):
            ZstdDict(desk333=345)

    def test_invalid_dict(self):
        # valid magic followed by garbage: accepted at construction time,
        # but fails when actually loaded into a (de)compressor
        DICT_MAGIC = 0xEC30A437.to_bytes(4, byteorder='little')
        dict_content = DICT_MAGIC + b'abcdefghighlmnopqrstuvwxyz'

        # corrupted
        zd = ZstdDict(dict_content, is_raw=False)
        with self.assertRaisesRegex(ZstdError, r'ZSTD_CDict.*?corrupted'):
            ZstdCompressor(zstd_dict=zd.as_digested_dict)
        with self.assertRaisesRegex(ZstdError, r'ZSTD_DDict.*?corrupted'):
            ZstdDecompressor(zd)

        # wrong type: tuples that are not the supported (ZstdDict, mode) form
        with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
            ZstdCompressor(zstd_dict=(zd, b'123'))
        with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
            ZstdCompressor(zstd_dict=(zd, 1, 2))
        with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
            ZstdCompressor(zstd_dict=(zd, -1))
        with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
            ZstdCompressor(zstd_dict=(zd, 3))

        with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
            ZstdDecompressor(zstd_dict=(zd, b'123'))
        with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
            ZstdDecompressor((zd, 1, 2))
        with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
            ZstdDecompressor((zd, -1))
        with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
            ZstdDecompressor((zd, 3))

    def test_train_dict(self):
        # shadows the module-level TRAINED_DICT on purpose
        TRAINED_DICT = train_dict(SAMPLES, DICT_SIZE1)
        # trained content must round-trip through the ZstdDict constructor
        ZstdDict(TRAINED_DICT.dict_content, False)

        self.assertNotEqual(TRAINED_DICT.dict_id, 0)
        self.assertGreater(len(TRAINED_DICT.dict_content), 0)
        self.assertLessEqual(len(TRAINED_DICT.dict_content), DICT_SIZE1)
        self.assertTrue(re.match(r'^<ZstdDict dict_id=\d+ dict_size=\d+>$', str(TRAINED_DICT)))

        # compress/decompress round-trip with the trained dictionary
        c = ZstdCompressor(zstd_dict=TRAINED_DICT)
        for sample in SAMPLES:
            dat1 = compress(sample, zstd_dict=TRAINED_DICT)
            dat2 = decompress(dat1, TRAINED_DICT)
            self.assertEqual(sample, dat2)

            dat1 = c.compress(sample)
            dat1 += c.flush()
            dat2 = decompress(dat1, TRAINED_DICT)
            self.assertEqual(sample, dat2)

    def test_finalize_dict(self):
        DICT_SIZE2 = 200*_1K
        C_LEVEL = 6

        try:
            dic2 = finalize_dict(TRAINED_DICT, SAMPLES, DICT_SIZE2, C_LEVEL)
        except NotImplementedError:
            # < v1.4.5 at compile-time, >= v.1.4.5 at run-time
            return

        self.assertNotEqual(dic2.dict_id, 0)
        self.assertGreater(len(dic2.dict_content), 0)
        self.assertLessEqual(len(dic2.dict_content), DICT_SIZE2)

        # compress/decompress round-trip with the finalized dictionary
        c = ZstdCompressor(C_LEVEL, zstd_dict=dic2)
        for sample in SAMPLES:
            dat1 = compress(sample, C_LEVEL, zstd_dict=dic2)
            dat2 = decompress(dat1, dic2)
            self.assertEqual(sample, dat2)

            dat1 = c.compress(sample)
            dat1 += c.flush()
            dat2 = decompress(dat1, dic2)
            self.assertEqual(sample, dat2)

        # dict mismatch: decompressing with the wrong dictionary must fail
        self.assertNotEqual(TRAINED_DICT.dict_id, dic2.dict_id)

        dat1 = compress(SAMPLES[0], zstd_dict=TRAINED_DICT)
        with self.assertRaises(ZstdError):
            decompress(dat1, dic2)

    def test_train_dict_arguments(self):
        # empty sample list
        with self.assertRaises(ValueError):
            train_dict([], 100*_1K)

        # non-positive dict_size
        with self.assertRaises(ValueError):
            train_dict(SAMPLES, -100)

        with self.assertRaises(ValueError):
            train_dict(SAMPLES, 0)

    def test_finalize_dict_arguments(self):
        # wrong zstd_dict type
        with self.assertRaises(TypeError):
            finalize_dict({1:2}, (b'aaa', b'bbb'), 100*_1K, 2)

        # empty samples / non-positive dict_size
        with self.assertRaises(ValueError):
            finalize_dict(TRAINED_DICT, [], 100*_1K, 2)

        with self.assertRaises(ValueError):
            finalize_dict(TRAINED_DICT, SAMPLES, -100, 2)

        with self.assertRaises(ValueError):
            finalize_dict(TRAINED_DICT, SAMPLES, 0, 2)

    def test_train_dict_c(self):
        # exercise the low-level _zstd._train_dict directly

        # argument wrong type
        with self.assertRaises(TypeError):
            _zstd._train_dict({}, (), 100)
        with self.assertRaises(TypeError):
            _zstd._train_dict(b'', 99, 100)
        with self.assertRaises(TypeError):
            _zstd._train_dict(b'', (), 100.1)

        # size > size_t
        with self.assertRaises(ValueError):
            _zstd._train_dict(b'', (2**64+1,), 100)

        # dict_size <= 0
        with self.assertRaises(ValueError):
            _zstd._train_dict(b'', (), 0)

    def test_finalize_dict_c(self):
        # exercise the low-level _zstd._finalize_dict directly
        with self.assertRaises(TypeError):
            _zstd._finalize_dict(1, 2, 3, 4, 5)

        # argument wrong type
        with self.assertRaises(TypeError):
            _zstd._finalize_dict({}, b'', (), 100, 5)
        with self.assertRaises(TypeError):
            _zstd._finalize_dict(TRAINED_DICT.dict_content, {}, (), 100, 5)
        with self.assertRaises(TypeError):
            _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', 99, 100, 5)
        with self.assertRaises(TypeError):
            _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', (), 100.1, 5)
        with self.assertRaises(TypeError):
            _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5.1)

        # size > size_t
        with self.assertRaises(ValueError):
            _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', (2**64+1,), 100, 5)

        # dict_size <= 0
        with self.assertRaises(ValueError):
            _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', (), 0, 5)

    def test_train_buffer_protocol_samples(self):
        def _nbytes(dat):
            # byte size of a buffer-protocol object (not its element count)
            if isinstance(dat, (bytes, bytearray)):
                return len(dat)
            return memoryview(dat).nbytes

        # prepare samples: arrays whose element count != byte count
        chunk_lst = []
        wrong_size_lst = []
        correct_size_lst = []
        for _ in range(300):
            arr = array.array('Q', [random.randint(0, 20) for i in range(20)])
            chunk_lst.append(arr)
            correct_size_lst.append(_nbytes(arr))
            wrong_size_lst.append(len(arr))
        concatenation = b''.join(chunk_lst)

        # wrong size list
        with self.assertRaisesRegex(ValueError,
                "The samples size tuple doesn't match the concatenation's size"):
            _zstd._train_dict(concatenation, tuple(wrong_size_lst), 100*_1K)

        # correct size list
        _zstd._train_dict(concatenation, tuple(correct_size_lst), 3*_1K)

        # wrong size list
        with self.assertRaisesRegex(ValueError,
                "The samples size tuple doesn't match the concatenation's size"):
            _zstd._finalize_dict(TRAINED_DICT.dict_content,
                                 concatenation, tuple(wrong_size_lst), 300*_1K, 5)

        # correct size list
        _zstd._finalize_dict(TRAINED_DICT.dict_content,
                             concatenation, tuple(correct_size_lst), 300*_1K, 5)

    def test_as_prefix(self):
        # V1
        V1 = THIS_FILE_BYTES
        zd = ZstdDict(V1, True)

        # V2: V1 with one byte flipped in the middle
        mid = len(V1) // 2
        V2 = V1[:mid] + \
             (b'a' if V1[mid] != int.from_bytes(b'a') else b'b') + \
             V1[mid+1:]

        # compress: a prefix is not a dictionary, so no dictionary_id in frame
        dat = compress(V2, zstd_dict=zd.as_prefix)
        self.assertEqual(get_frame_info(dat).dictionary_id, 0)

        # decompress
        self.assertEqual(decompress(dat, zd.as_prefix), V2)

        # use wrong prefix: may raise, or decode to something different
        zd2 = ZstdDict(SAMPLES[0], True)
        try:
            decompressed = decompress(dat, zd2.as_prefix)
        except ZstdError: # expected
            pass
        else:
            self.assertNotEqual(decompressed, V2)

        # read only attribute
        with self.assertRaises(AttributeError):
            zd.as_prefix = b'1234'

    def test_as_digested_dict(self):
        zd = TRAINED_DICT

        # test .as_digested_dict
        dat = compress(SAMPLES[0], zstd_dict=zd.as_digested_dict)
        self.assertEqual(decompress(dat, zd.as_digested_dict), SAMPLES[0])
        with self.assertRaises(AttributeError):
            zd.as_digested_dict = b'1234'

        # test .as_undigested_dict
        dat = compress(SAMPLES[0], zstd_dict=zd.as_undigested_dict)
        self.assertEqual(decompress(dat, zd.as_undigested_dict), SAMPLES[0])
        with self.assertRaises(AttributeError):
            zd.as_undigested_dict = b'1234'

    def test_advanced_compression_parameters(self):
        options = {CompressionParameter.compression_level: 6,
                   CompressionParameter.window_log: 20,
                   CompressionParameter.enable_long_distance_matching: 1}

        # automatically select
        dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT)
        self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0])

        # explicitly select
        dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT.as_digested_dict)
        self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0])

    def test_len(self):
        # len(ZstdDict) is the content size, and appears in its repr
        self.assertEqual(len(TRAINED_DICT), len(TRAINED_DICT.dict_content))
        self.assertIn(str(len(TRAINED_DICT)), str(TRAINED_DICT))
+
+class FileTestCase(unittest.TestCase):
+ def setUp(self):
+ self.DECOMPRESSED_42 = b'a'*42
+ self.FRAME_42 = compress(self.DECOMPRESSED_42)
+
    def test_init(self):
        # Smoke-test construction in every mode and with the keyword
        # arguments each mode accepts.
        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
            pass
        with ZstdFile(io.BytesIO(), "w") as f:
            pass
        with ZstdFile(io.BytesIO(), "x") as f:
            pass
        with ZstdFile(io.BytesIO(), "a") as f:
            pass

        # write modes accept level / options / zstd_dict
        with ZstdFile(io.BytesIO(), "w", level=12) as f:
            pass
        with ZstdFile(io.BytesIO(), "w", options={CompressionParameter.checksum_flag:1}) as f:
            pass
        with ZstdFile(io.BytesIO(), "w", options={}) as f:
            pass
        with ZstdFile(io.BytesIO(), "w", level=20, zstd_dict=TRAINED_DICT) as f:
            pass

        # read mode accepts decompression options / zstd_dict
        with ZstdFile(io.BytesIO(), "r", options={DecompressionParameter.window_log_max:25}) as f:
            pass
        with ZstdFile(io.BytesIO(), "r", options={}, zstd_dict=TRAINED_DICT) as f:
            pass
+
    def test_init_with_PathLike_filename(self):
        # ZstdFile accepts os.PathLike filenames; append mode adds a second
        # frame that read mode then decodes as concatenated content.
        with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
            filename = pathlib.Path(tmp_f.name)

        with ZstdFile(filename, "a") as f:
            f.write(DECOMPRESSED_100_PLUS_32KB)
        with ZstdFile(filename) as f:
            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)

        with ZstdFile(filename, "a") as f:
            f.write(DECOMPRESSED_100_PLUS_32KB)
        with ZstdFile(filename) as f:
            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB * 2)

        os.remove(filename)
+
+ def test_init_with_filename(self):
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ filename = pathlib.Path(tmp_f.name)
+
+ with ZstdFile(filename) as f:
+ pass
+ with ZstdFile(filename, "w") as f:
+ pass
+ with ZstdFile(filename, "a") as f:
+ pass
+
+ os.remove(filename)
+
+ def test_init_mode(self):
+ bi = io.BytesIO()
+
+ with ZstdFile(bi, "r"):
+ pass
+ with ZstdFile(bi, "rb"):
+ pass
+ with ZstdFile(bi, "w"):
+ pass
+ with ZstdFile(bi, "wb"):
+ pass
+ with ZstdFile(bi, "a"):
+ pass
+ with ZstdFile(bi, "ab"):
+ pass
+
    def test_init_with_x_mode(self):
        # "x"/"xb" create exclusively: the second open of an existing file
        # must raise FileExistsError.
        with tempfile.NamedTemporaryFile() as tmp_f:
            filename = pathlib.Path(tmp_f.name)

        for mode in ("x", "xb"):
            with ZstdFile(filename, mode):
                pass
            with self.assertRaises(FileExistsError):
                with ZstdFile(filename, mode):
                    pass
            os.remove(filename)
+
+ def test_init_bad_mode(self):
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), (3, "x"))
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "xt")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "x+")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rx")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "wx")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rt")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r+")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "wt")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "w+")
+ with self.assertRaises(ValueError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rw")
+
+ with self.assertRaisesRegex(TypeError, r"NOT be CompressionParameter"):
+ ZstdFile(io.BytesIO(), 'rb',
+ options={CompressionParameter.compression_level:5})
+ with self.assertRaisesRegex(TypeError,
+ r"NOT be DecompressionParameter"):
+ ZstdFile(io.BytesIO(), 'wb',
+ options={DecompressionParameter.window_log_max:21})
+
+ with self.assertRaises(TypeError):
+ ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", level=12)
+
    def test_init_bad_check(self):
        # non-int level
        with self.assertRaises(TypeError):
            ZstdFile(io.BytesIO(), "w", level='asd')
        # CHECK_UNKNOWN and anything above CHECK_ID_MAX should be invalid.
        with self.assertRaises(ZstdError):
            ZstdFile(io.BytesIO(), "w", options={999:9999})
        with self.assertRaises(ZstdError):
            ZstdFile(io.BytesIO(), "w", options={CompressionParameter.window_log:99})

        # options must be a dict
        with self.assertRaises(TypeError):
            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", options=33)

        # out-of-range / unknown decompression parameters
        with self.assertRaises(ValueError):
            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB),
                             options={DecompressionParameter.window_log_max:2**31})

        with self.assertRaises(ZstdError):
            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB),
                             options={444:333})

        # zstd_dict must be a ZstdDict, not a dict or raw bytes
        with self.assertRaises(TypeError):
            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), zstd_dict={1:2})

        with self.assertRaises(TypeError):
            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), zstd_dict=b'dict123456')
+
    def test_init_close_fp(self):
        # When __init__ fails after opening the file by name, the internal
        # file object must not be leaked.

        # get a temp file name
        with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
            tmp_f.write(DAT_130K_C)
            filename = tmp_f.name

        with self.assertRaises(ValueError):
            ZstdFile(filename, options={'a':'b'})

        # for PyPy
        gc.collect()

        os.remove(filename)
+
    def test_close(self):
        # close() semantics differ by how the ZstdFile was constructed:
        # a caller-supplied file object stays open, a ZstdFile-opened one
        # is closed.  close() must also be idempotent.
        with io.BytesIO(COMPRESSED_100_PLUS_32KB) as src:
            f = ZstdFile(src)
            f.close()
            # ZstdFile.close() should not close the underlying file object.
            self.assertFalse(src.closed)
            # Try closing an already-closed ZstdFile.
            f.close()
            self.assertFalse(src.closed)

        # Test with a real file on disk, opened directly by ZstdFile.
        with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
            filename = pathlib.Path(tmp_f.name)

        f = ZstdFile(filename)
        fp = f._fp
        f.close()
        # Here, ZstdFile.close() *should* close the underlying file object.
        self.assertTrue(fp.closed)
        # Try closing an already-closed ZstdFile.
        f.close()

        os.remove(filename)
+
    def test_closed(self):
        # .closed reflects the file state in both read and write modes.
        f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
        try:
            self.assertFalse(f.closed)
            f.read()
            # reading to EOF does not close the file
            self.assertFalse(f.closed)
        finally:
            f.close()
        self.assertTrue(f.closed)

        f = ZstdFile(io.BytesIO(), "w")
        try:
            self.assertFalse(f.closed)
        finally:
            f.close()
        self.assertTrue(f.closed)
+
    def test_fileno(self):
        # 1: a BytesIO-backed ZstdFile has no OS-level file descriptor
        f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
        try:
            self.assertRaises(io.UnsupportedOperation, f.fileno)
        finally:
            f.close()
        # after close(), fileno() raises ValueError instead
        self.assertRaises(ValueError, f.fileno)

        # 2: a real file delegates fileno() to the underlying file object
        with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
            filename = pathlib.Path(tmp_f.name)

        f = ZstdFile(filename)
        try:
            self.assertEqual(f.fileno(), f._fp.fileno())
            self.assertIsInstance(f.fileno(), int)
        finally:
            f.close()
        self.assertRaises(ValueError, f.fileno)

        os.remove(filename)

        # 3, no .fileno() method on the wrapped object
        class C:
            def read(self, size=-1):
                return b'123'
        with ZstdFile(C(), 'rb') as f:
            with self.assertRaisesRegex(AttributeError, r'fileno'):
                f.fileno()
+
    def test_name(self):
        # 1: a BytesIO-backed ZstdFile has no .name to delegate to
        f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
        try:
            with self.assertRaises(AttributeError):
                f.name
        finally:
            f.close()
        # after close(), accessing .name raises ValueError instead
        with self.assertRaises(ValueError):
            f.name

        # 2: a real file delegates .name to the underlying file object
        with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
            filename = pathlib.Path(tmp_f.name)

        f = ZstdFile(filename)
        try:
            self.assertEqual(f.name, f._fp.name)
            self.assertIsInstance(f.name, str)
        finally:
            f.close()
        with self.assertRaises(ValueError):
            f.name

        os.remove(filename)

        # 3, wrapped object has no .name attribute
        class C:
            def read(self, size=-1):
                return b'123'
        with ZstdFile(C(), 'rb') as f:
            with self.assertRaisesRegex(AttributeError, r'name'):
                f.name
+
    def test_seekable(self):
        # read mode is seekable (emulated by re-decompressing), write mode
        # is not, and seekability of the wrapped object is respected.
        f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
        try:
            self.assertTrue(f.seekable())
            f.read()
            self.assertTrue(f.seekable())
        finally:
            f.close()
        # closed file: the query itself raises ValueError
        self.assertRaises(ValueError, f.seekable)

        f = ZstdFile(io.BytesIO(), "w")
        try:
            self.assertFalse(f.seekable())
        finally:
            f.close()
        self.assertRaises(ValueError, f.seekable)

        # a non-seekable source makes the ZstdFile non-seekable too
        src = io.BytesIO(COMPRESSED_100_PLUS_32KB)
        src.seekable = lambda: False
        f = ZstdFile(src)
        try:
            self.assertFalse(f.seekable())
        finally:
            f.close()
        self.assertRaises(ValueError, f.seekable)
+
+ def test_readable(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ try:
+ self.assertTrue(f.readable())
+ f.read()
+ self.assertTrue(f.readable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.readable)
+
+ f = ZstdFile(io.BytesIO(), "w")
+ try:
+ self.assertFalse(f.readable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.readable)
+
+ def test_writable(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ try:
+ self.assertFalse(f.writable())
+ f.read()
+ self.assertFalse(f.writable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.writable)
+
+ f = ZstdFile(io.BytesIO(), "w")
+ try:
+ self.assertTrue(f.writable())
+ finally:
+ f.close()
+ self.assertRaises(ValueError, f.writable)
+
    def test_read_0(self):
        # read(0) returns b"" without consuming the stream.
        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
            self.assertEqual(f.read(0), b"")
            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB),
                      options={DecompressionParameter.window_log_max:20}) as f:
            self.assertEqual(f.read(0), b"")

        # empty file: read(0) succeeds, any real read raises EOFError
        with ZstdFile(io.BytesIO(b'')) as f:
            self.assertEqual(f.read(0), b"")
            with self.assertRaises(EOFError):
                f.read(10)

        with ZstdFile(io.BytesIO(b'')) as f:
            with self.assertRaises(EOFError):
                f.read(10)
+
+ def test_read_10(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ chunks = []
+ while True:
+ result = f.read(10)
+ if not result:
+ break
+ self.assertLessEqual(len(result), 10)
+ chunks.append(result)
+ self.assertEqual(b"".join(chunks), DECOMPRESSED_100_PLUS_32KB)
+
    def test_read_multistream(self):
        # Concatenated frames are decoded as one continuous stream;
        # skippable frames contribute nothing.
        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 5)) as f:
            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB * 5)

        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB + SKIPPABLE_FRAME)) as f:
            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)

        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB + COMPRESSED_DAT)) as f:
            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB + DECOMPRESSED_DAT)
+
    def test_read_incomplete(self):
        # A frame cut short mid-stream raises EOFError on read.
        with ZstdFile(io.BytesIO(DAT_130K_C[:-200])) as f:
            self.assertRaises(EOFError, f.read)

        # Trailing data isn't a valid compressed stream
        with ZstdFile(io.BytesIO(self.FRAME_42 + b'12345')) as f:
            self.assertEqual(f.read(), self.DECOMPRESSED_42)

        with ZstdFile(io.BytesIO(SKIPPABLE_FRAME + b'12345')) as f:
            self.assertEqual(f.read(), b'')
+
    def test_read_truncated(self):
        # Drop stream epilogue: 4 bytes checksum
        truncated = DAT_130K_C[:-4]
        with ZstdFile(io.BytesIO(truncated)) as f:
            self.assertRaises(EOFError, f.read)

        with ZstdFile(io.BytesIO(truncated)) as f:
            # this is an important test, make sure it doesn't raise EOFError.
            self.assertEqual(f.read(130*_1K), DAT_130K_D)
            with self.assertRaises(EOFError):
                f.read(1)

        # Incomplete header: truncation at any point in the header must
        # raise EOFError rather than return data.
        for i in range(1, 20):
            with ZstdFile(io.BytesIO(truncated[:i])) as f:
                self.assertRaises(EOFError, f.read, 1)
+
    def test_read_bad_args(self):
        # reading a closed file raises ValueError
        f = ZstdFile(io.BytesIO(COMPRESSED_DAT))
        f.close()
        self.assertRaises(ValueError, f.read)
        # reading a write-mode file raises ValueError
        with ZstdFile(io.BytesIO(), "w") as f:
            self.assertRaises(ValueError, f.read)
        # size argument must be an int
        with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f:
            self.assertRaises(TypeError, f.read, float())
+
+ def test_read_bad_data(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_BOGUS)) as f:
+ self.assertRaises(ZstdError, f.read)
+
    def test_read_exception(self):
        # An exception raised by the wrapped object's read() propagates
        # unchanged through ZstdFile.read().
        class C:
            def read(self, size=-1):
                raise OSError
        with ZstdFile(C()) as f:
            with self.assertRaises(OSError):
                f.read(10)
+
+ def test_read1(self):
+ with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
+ blocks = []
+ while True:
+ result = f.read1()
+ if not result:
+ break
+ blocks.append(result)
+ self.assertEqual(b"".join(blocks), DAT_130K_D)
+ self.assertEqual(f.read1(), b"")
+
+ def test_read1_0(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f:
+ self.assertEqual(f.read1(0), b"")
+
+ def test_read1_10(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f:
+ blocks = []
+ while True:
+ result = f.read1(10)
+ if not result:
+ break
+ blocks.append(result)
+ self.assertEqual(b"".join(blocks), DECOMPRESSED_DAT)
+ self.assertEqual(f.read1(), b"")
+
+ def test_read1_multistream(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 5)) as f:
+ blocks = []
+ while True:
+ result = f.read1()
+ if not result:
+ break
+ blocks.append(result)
+ self.assertEqual(b"".join(blocks), DECOMPRESSED_100_PLUS_32KB * 5)
+ self.assertEqual(f.read1(), b"")
+
    def test_read1_bad_args(self):
        # read1() on a closed file raises ValueError
        f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
        f.close()
        self.assertRaises(ValueError, f.read1)
        # read1() on a write-mode file raises ValueError
        with ZstdFile(io.BytesIO(), "w") as f:
            self.assertRaises(ValueError, f.read1)
        # size must be an int, not None
        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
            self.assertRaises(TypeError, f.read1, None)
+
    def test_readinto(self):
        # readinto() must honor the buffer protocol's *byte* size, not the
        # element count (the array has 100 4-byte items = 400 bytes).
        arr = array.array("I", range(100))
        self.assertEqual(len(arr), 100)
        self.assertEqual(len(arr) * arr.itemsize, 400)
        ba = bytearray(300)
        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
            # 0 length output buffer
            self.assertEqual(f.readinto(ba[0:0]), 0)

            # use correct length for buffer protocol object
            self.assertEqual(f.readinto(arr), 400)
            self.assertEqual(arr.tobytes(), DECOMPRESSED_100_PLUS_32KB[:400])

            # normal readinto continues from the current position
            self.assertEqual(f.readinto(ba), 300)
            self.assertEqual(ba, DECOMPRESSED_100_PLUS_32KB[400:700])
+
+ def test_peek(self):
+ with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
+ result = f.peek()
+ self.assertGreater(len(result), 0)
+ self.assertTrue(DAT_130K_D.startswith(result))
+ self.assertEqual(f.read(), DAT_130K_D)
+ with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
+ result = f.peek(10)
+ self.assertGreater(len(result), 0)
+ self.assertTrue(DAT_130K_D.startswith(result))
+ self.assertEqual(f.read(), DAT_130K_D)
+
+ def test_peek_bad_args(self):
+ with ZstdFile(io.BytesIO(), "w") as f:
+ self.assertRaises(ValueError, f.peek)
+
    def test_iterator(self):
        # Line-oriented access (iteration, readline, readlines) over the
        # decompressed stream matches the plain-text equivalents.
        with io.BytesIO(THIS_FILE_BYTES) as f:
            lines = f.readlines()
        compressed = compress(THIS_FILE_BYTES)

        # iter
        with ZstdFile(io.BytesIO(compressed)) as f:
            self.assertListEqual(list(iter(f)), lines)

        # readline; returns b'' (repeatedly) at EOF
        with ZstdFile(io.BytesIO(compressed)) as f:
            for line in lines:
                self.assertEqual(f.readline(), line)
            self.assertEqual(f.readline(), b'')
            self.assertEqual(f.readline(), b'')

        # readlines
        with ZstdFile(io.BytesIO(compressed)) as f:
            self.assertListEqual(f.readlines(), lines)
+
    def test_decompress_limited(self):
        # A small read from a highly-compressible "bomb" must not trigger
        # decompression of the whole stream into memory.
        _ZSTD_DStreamInSize = 128*_1K + 3

        bomb = compress(b'\0' * int(2e6), level=10)
        self.assertLess(len(bomb), _ZSTD_DStreamInSize)

        decomp = ZstdFile(io.BytesIO(bomb))
        self.assertEqual(decomp.read(1), b'\0')

        # BufferedReader uses 128 KiB buffer in __init__.py
        max_decomp = 128*_1K
        self.assertLessEqual(decomp._buffer.raw.tell(), max_decomp,
                             "Excessive amount of data was decompressed")
+
+ def test_write(self):
+ raw_data = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6]
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w") as f:
+ f.write(raw_data)
+
+ comp = ZstdCompressor()
+ expected = comp.compress(raw_data) + comp.flush()
+ self.assertEqual(dst.getvalue(), expected)
+
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w", level=12) as f:
+ f.write(raw_data)
+
+ comp = ZstdCompressor(12)
+ expected = comp.compress(raw_data) + comp.flush()
+ self.assertEqual(dst.getvalue(), expected)
+
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w", options={CompressionParameter.checksum_flag:1}) as f:
+ f.write(raw_data)
+
+ comp = ZstdCompressor(options={CompressionParameter.checksum_flag:1})
+ expected = comp.compress(raw_data) + comp.flush()
+ self.assertEqual(dst.getvalue(), expected)
+
+ with io.BytesIO() as dst:
+ options = {CompressionParameter.compression_level:-5,
+ CompressionParameter.checksum_flag:1}
+ with ZstdFile(dst, "w",
+ options=options) as f:
+ f.write(raw_data)
+
+ comp = ZstdCompressor(options=options)
+ expected = comp.compress(raw_data) + comp.flush()
+ self.assertEqual(dst.getvalue(), expected)
+
+ def test_write_empty_frame(self):
+ # .FLUSH_FRAME generates an empty content frame
+ c = ZstdCompressor()
+ self.assertNotEqual(c.flush(c.FLUSH_FRAME), b'')
+ self.assertNotEqual(c.flush(c.FLUSH_FRAME), b'')
+
+ # don't generate empty content frame
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ pass
+ self.assertEqual(bo.getvalue(), b'')
+
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.flush(f.FLUSH_FRAME)
+ self.assertEqual(bo.getvalue(), b'')
+
+ # if .write(b'') is called, an empty content frame is generated
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.write(b'')
+ self.assertNotEqual(bo.getvalue(), b'')
+
+ # has an empty content frame
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.flush(f.FLUSH_BLOCK)
+ self.assertNotEqual(bo.getvalue(), b'')
+
+ def test_write_empty_block(self):
+ # If there is no internal data, .FLUSH_BLOCK returns b''.
+ c = ZstdCompressor()
+ self.assertEqual(c.flush(c.FLUSH_BLOCK), b'')
+ self.assertNotEqual(c.compress(b'123', c.FLUSH_BLOCK),
+ b'')
+ self.assertEqual(c.flush(c.FLUSH_BLOCK), b'')
+ self.assertEqual(c.compress(b''), b'')
+ self.assertEqual(c.compress(b''), b'')
+ self.assertEqual(c.flush(c.FLUSH_BLOCK), b'')
+
+ # mode = .last_mode
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.write(b'123')
+ f.flush(f.FLUSH_BLOCK)
+ fp_pos = f._fp.tell()
+ self.assertNotEqual(fp_pos, 0)
+ f.flush(f.FLUSH_BLOCK)
+ self.assertEqual(f._fp.tell(), fp_pos)
+
+ # mode != .last_mode
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.flush(f.FLUSH_BLOCK)
+ self.assertEqual(f._fp.tell(), 0)
+ f.write(b'')
+ f.flush(f.FLUSH_BLOCK)
+ self.assertEqual(f._fp.tell(), 0)
+
+ def test_write_101(self):
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w") as f:
+ for start in range(0, len(THIS_FILE_BYTES), 101):
+ f.write(THIS_FILE_BYTES[start:start+101])
+
+ comp = ZstdCompressor()
+ expected = comp.compress(THIS_FILE_BYTES) + comp.flush()
+ self.assertEqual(dst.getvalue(), expected)
+
+ def test_write_append(self):
+ def comp(data):
+ comp = ZstdCompressor()
+ return comp.compress(data) + comp.flush()
+
+ part1 = THIS_FILE_BYTES[:_1K]
+ part2 = THIS_FILE_BYTES[_1K:1536]
+ part3 = THIS_FILE_BYTES[1536:]
+ expected = b"".join(comp(x) for x in (part1, part2, part3))
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w") as f:
+ f.write(part1)
+ with ZstdFile(dst, "a") as f:
+ f.write(part2)
+ with ZstdFile(dst, "a") as f:
+ f.write(part3)
+ self.assertEqual(dst.getvalue(), expected)
+
+ def test_write_bad_args(self):
+ f = ZstdFile(io.BytesIO(), "w")
+ f.close()
+ self.assertRaises(ValueError, f.write, b"foo")
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r") as f:
+ self.assertRaises(ValueError, f.write, b"bar")
+ with ZstdFile(io.BytesIO(), "w") as f:
+ self.assertRaises(TypeError, f.write, None)
+ self.assertRaises(TypeError, f.write, "text")
+ self.assertRaises(TypeError, f.write, 789)
+
+ def test_writelines(self):
+ def comp(data):
+ comp = ZstdCompressor()
+ return comp.compress(data) + comp.flush()
+
+ with io.BytesIO(THIS_FILE_BYTES) as f:
+ lines = f.readlines()
+ with io.BytesIO() as dst:
+ with ZstdFile(dst, "w") as f:
+ f.writelines(lines)
+ expected = comp(THIS_FILE_BYTES)
+ self.assertEqual(dst.getvalue(), expected)
+
+ def test_seek_forward(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.seek(555)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[555:])
+
+ def test_seek_forward_across_streams(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f:
+ f.seek(len(DECOMPRESSED_100_PLUS_32KB) + 123)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[123:])
+
+ def test_seek_forward_relative_to_current(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.read(100)
+ f.seek(1236, 1)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[1336:])
+
+ def test_seek_forward_relative_to_end(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.seek(-555, 2)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[-555:])
+
+ def test_seek_backward(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.read(1001)
+ f.seek(211)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[211:])
+
+ def test_seek_backward_across_streams(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f:
+ f.read(len(DECOMPRESSED_100_PLUS_32KB) + 333)
+ f.seek(737)
+ self.assertEqual(f.read(),
+ DECOMPRESSED_100_PLUS_32KB[737:] + DECOMPRESSED_100_PLUS_32KB)
+
+ def test_seek_backward_relative_to_end(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.seek(-150, 2)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[-150:])
+
+ def test_seek_past_end(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.seek(len(DECOMPRESSED_100_PLUS_32KB) + 9001)
+ self.assertEqual(f.tell(), len(DECOMPRESSED_100_PLUS_32KB))
+ self.assertEqual(f.read(), b"")
+
+ def test_seek_past_start(self):
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ f.seek(-88)
+ self.assertEqual(f.tell(), 0)
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+
+ def test_seek_bad_args(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ f.close()
+ self.assertRaises(ValueError, f.seek, 0)
+ with ZstdFile(io.BytesIO(), "w") as f:
+ self.assertRaises(ValueError, f.seek, 0)
+ with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+ self.assertRaises(ValueError, f.seek, 0, 3)
+ # io.BufferedReader raises TypeError instead of ValueError
+ self.assertRaises((TypeError, ValueError), f.seek, 9, ())
+ self.assertRaises(TypeError, f.seek, None)
+ self.assertRaises(TypeError, f.seek, b"derp")
+
+ def test_seek_not_seekable(self):
+ class C(io.BytesIO):
+ def seekable(self):
+ return False
+ obj = C(COMPRESSED_100_PLUS_32KB)
+ with ZstdFile(obj, 'r') as f:
+ d = f.read(1)
+ self.assertFalse(f.seekable())
+ with self.assertRaisesRegex(io.UnsupportedOperation,
+ 'File or stream is not seekable'):
+ f.seek(0)
+ d += f.read()
+ self.assertEqual(d, DECOMPRESSED_100_PLUS_32KB)
+
+ def test_tell(self):
+ with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
+ pos = 0
+ while True:
+ self.assertEqual(f.tell(), pos)
+ result = f.read(random.randint(171, 189))
+ if not result:
+ break
+ pos += len(result)
+ self.assertEqual(f.tell(), len(DAT_130K_D))
+ with ZstdFile(io.BytesIO(), "w") as f:
+ for pos in range(0, len(DAT_130K_D), 143):
+ self.assertEqual(f.tell(), pos)
+ f.write(DAT_130K_D[pos:pos+143])
+ self.assertEqual(f.tell(), len(DAT_130K_D))
+
+ def test_tell_bad_args(self):
+ f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+ f.close()
+ self.assertRaises(ValueError, f.tell)
+
+ def test_file_dict(self):
+ # default
+ bi = io.BytesIO()
+ with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with ZstdFile(bi, zstd_dict=TRAINED_DICT) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ # .as_(un)digested_dict
+ bi = io.BytesIO()
+ with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT.as_digested_dict) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with ZstdFile(bi, zstd_dict=TRAINED_DICT.as_undigested_dict) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ def test_file_prefix(self):
+ bi = io.BytesIO()
+ with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT.as_prefix) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with ZstdFile(bi, zstd_dict=TRAINED_DICT.as_prefix) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ def test_UnsupportedOperation(self):
+ # 1
+ with ZstdFile(io.BytesIO(), 'r') as f:
+ with self.assertRaises(io.UnsupportedOperation):
+ f.write(b'1234')
+
+ # 2
+ class T:
+ def read(self, size):
+ return b'a' * size
+
+ with self.assertRaises(TypeError): # on creation
+ with ZstdFile(T(), 'w') as f:
+ pass
+
+ # 3
+ with ZstdFile(io.BytesIO(), 'w') as f:
+ with self.assertRaises(io.UnsupportedOperation):
+ f.read(100)
+ with self.assertRaises(io.UnsupportedOperation):
+ f.seek(100)
+ self.assertEqual(f.closed, True)
+ with self.assertRaises(ValueError):
+ f.readable()
+ with self.assertRaises(ValueError):
+ f.tell()
+ with self.assertRaises(ValueError):
+ f.read(100)
+
+ def test_read_readinto_readinto1(self):
+ lst = []
+ with ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE*5)) as f:
+ while True:
+ method = random.randint(0, 2)
+ size = random.randint(0, 300)
+
+ if method == 0:
+ dat = f.read(size)
+ if not dat and size:
+ break
+ lst.append(dat)
+ elif method == 1:
+ ba = bytearray(size)
+ read_size = f.readinto(ba)
+ if read_size == 0 and size:
+ break
+ lst.append(bytes(ba[:read_size]))
+ elif method == 2:
+ ba = bytearray(size)
+ read_size = f.readinto1(ba)
+ if read_size == 0 and size:
+ break
+ lst.append(bytes(ba[:read_size]))
+ self.assertEqual(b''.join(lst), THIS_FILE_BYTES*5)
+
+ def test_zstdfile_flush(self):
+ # closed
+ f = ZstdFile(io.BytesIO(), 'w')
+ f.close()
+ with self.assertRaises(ValueError):
+ f.flush()
+
+ # read
+ with ZstdFile(io.BytesIO(), 'r') as f:
+ # does nothing for read-only stream
+ f.flush()
+
+ # write
+ DAT = b'abcd'
+ bi = io.BytesIO()
+ with ZstdFile(bi, 'w') as f:
+ self.assertEqual(f.write(DAT), len(DAT))
+ self.assertEqual(f.tell(), len(DAT))
+ self.assertEqual(bi.tell(), 0) # not enough for a block
+
+ self.assertEqual(f.flush(), None)
+ self.assertEqual(f.tell(), len(DAT))
+ self.assertGreater(bi.tell(), 0) # flushed
+
+ # write, no .flush() method
+ class C:
+ def write(self, b):
+ return len(b)
+ with ZstdFile(C(), 'w') as f:
+ self.assertEqual(f.write(DAT), len(DAT))
+ self.assertEqual(f.tell(), len(DAT))
+
+ self.assertEqual(f.flush(), None)
+ self.assertEqual(f.tell(), len(DAT))
+
+ def test_zstdfile_flush_mode(self):
+ self.assertEqual(ZstdFile.FLUSH_BLOCK, ZstdCompressor.FLUSH_BLOCK)
+ self.assertEqual(ZstdFile.FLUSH_FRAME, ZstdCompressor.FLUSH_FRAME)
+ with self.assertRaises(AttributeError):
+ ZstdFile.CONTINUE
+
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ # flush block
+ self.assertEqual(f.write(b'123'), 3)
+ self.assertIsNone(f.flush(f.FLUSH_BLOCK))
+ p1 = bo.tell()
+ # mode == .last_mode, should return
+ self.assertIsNone(f.flush())
+ p2 = bo.tell()
+ self.assertEqual(p1, p2)
+ # flush frame
+ self.assertEqual(f.write(b'456'), 3)
+ self.assertIsNone(f.flush(mode=f.FLUSH_FRAME))
+ # flush frame
+ self.assertEqual(f.write(b'789'), 3)
+ self.assertIsNone(f.flush(f.FLUSH_FRAME))
+ p1 = bo.tell()
+ # mode == .last_mode, should return
+ self.assertIsNone(f.flush(f.FLUSH_FRAME))
+ p2 = bo.tell()
+ self.assertEqual(p1, p2)
+ self.assertEqual(decompress(bo.getvalue()), b'123456789')
+
+ bo = io.BytesIO()
+ with ZstdFile(bo, 'w') as f:
+ f.write(b'123')
+ with self.assertRaisesRegex(ValueError, r'\.FLUSH_.*?\.FLUSH_'):
+ f.flush(ZstdCompressor.CONTINUE)
+ with self.assertRaises(ValueError):
+ f.flush(-1)
+ with self.assertRaises(ValueError):
+ f.flush(123456)
+ with self.assertRaises(TypeError):
+ f.flush(node=ZstdCompressor.CONTINUE)
+ with self.assertRaises((TypeError, ValueError)):
+ f.flush('FLUSH_FRAME')
+ with self.assertRaises(TypeError):
+ f.flush(b'456', f.FLUSH_BLOCK)
+
+ def test_zstdfile_truncate(self):
+ with ZstdFile(io.BytesIO(), 'w') as f:
+ with self.assertRaises(io.UnsupportedOperation):
+ f.truncate(200)
+
+ def test_zstdfile_iter_issue45475(self):
+ lines = [l for l in ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE))]
+ self.assertGreater(len(lines), 0)
+
+ def test_append_new_file(self):
+ with tempfile.NamedTemporaryFile(delete=True) as tmp_f:
+ filename = tmp_f.name
+
+ with ZstdFile(filename, 'a') as f:
+ pass
+ self.assertTrue(os.path.isfile(filename))
+
+ os.remove(filename)
+
+class OpenTestCase(unittest.TestCase):
+
+ def test_binary_modes(self):
+ with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb") as f:
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+ with io.BytesIO() as bio:
+ with open(bio, "wb") as f:
+ f.write(DECOMPRESSED_100_PLUS_32KB)
+ file_data = decompress(bio.getvalue())
+ self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB)
+ with open(bio, "ab") as f:
+ f.write(DECOMPRESSED_100_PLUS_32KB)
+ file_data = decompress(bio.getvalue())
+ self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB * 2)
+
+ def test_text_modes(self):
+ # empty input
+ with self.assertRaises(EOFError):
+ with open(io.BytesIO(b''), "rt", encoding="utf-8", newline='\n') as reader:
+ for _ in reader:
+ pass
+
+ # read
+ uncompressed = THIS_FILE_STR.replace(os.linesep, "\n")
+ with open(io.BytesIO(COMPRESSED_THIS_FILE), "rt", encoding="utf-8") as f:
+ self.assertEqual(f.read(), uncompressed)
+
+ with io.BytesIO() as bio:
+ # write
+ with open(bio, "wt", encoding="utf-8") as f:
+ f.write(uncompressed)
+ file_data = decompress(bio.getvalue()).decode("utf-8")
+ self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed)
+ # append
+ with open(bio, "at", encoding="utf-8") as f:
+ f.write(uncompressed)
+ file_data = decompress(bio.getvalue()).decode("utf-8")
+ self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed * 2)
+
+ def test_bad_params(self):
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ TESTFN = pathlib.Path(tmp_f.name)
+
+ with self.assertRaises(ValueError):
+ open(TESTFN, "")
+ with self.assertRaises(ValueError):
+ open(TESTFN, "rbt")
+ with self.assertRaises(ValueError):
+ open(TESTFN, "rb", encoding="utf-8")
+ with self.assertRaises(ValueError):
+ open(TESTFN, "rb", errors="ignore")
+ with self.assertRaises(ValueError):
+ open(TESTFN, "rb", newline="\n")
+
+ os.remove(TESTFN)
+
+ def test_option(self):
+ options = {DecompressionParameter.window_log_max:25}
+ with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb", options=options) as f:
+ self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+
+ options = {CompressionParameter.compression_level:12}
+ with io.BytesIO() as bio:
+ with open(bio, "wb", options=options) as f:
+ f.write(DECOMPRESSED_100_PLUS_32KB)
+ file_data = decompress(bio.getvalue())
+ self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB)
+
+ def test_encoding(self):
+ uncompressed = THIS_FILE_STR.replace(os.linesep, "\n")
+
+ with io.BytesIO() as bio:
+ with open(bio, "wt", encoding="utf-16-le") as f:
+ f.write(uncompressed)
+ file_data = decompress(bio.getvalue()).decode("utf-16-le")
+ self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed)
+ bio.seek(0)
+ with open(bio, "rt", encoding="utf-16-le") as f:
+ self.assertEqual(f.read().replace(os.linesep, "\n"), uncompressed)
+
+ def test_encoding_error_handler(self):
+ with io.BytesIO(compress(b"foo\xffbar")) as bio:
+ with open(bio, "rt", encoding="ascii", errors="ignore") as f:
+ self.assertEqual(f.read(), "foobar")
+
+ def test_newline(self):
+ # Test with explicit newline (universal newline mode disabled).
+ text = THIS_FILE_STR.replace(os.linesep, "\n")
+ with io.BytesIO() as bio:
+ with open(bio, "wt", encoding="utf-8", newline="\n") as f:
+ f.write(text)
+ bio.seek(0)
+ with open(bio, "rt", encoding="utf-8", newline="\r") as f:
+ self.assertEqual(f.readlines(), [text])
+
+ def test_x_mode(self):
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+ TESTFN = pathlib.Path(tmp_f.name)
+
+ for mode in ("x", "xb", "xt"):
+ os.remove(TESTFN)
+
+ if mode == "xt":
+ encoding = "utf-8"
+ else:
+ encoding = None
+ with open(TESTFN, mode, encoding=encoding):
+ pass
+ with self.assertRaises(FileExistsError):
+ with open(TESTFN, mode):
+ pass
+
+ os.remove(TESTFN)
+
+ def test_open_dict(self):
+ # default
+ bi = io.BytesIO()
+ with open(bi, 'w', zstd_dict=TRAINED_DICT) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with open(bi, zstd_dict=TRAINED_DICT) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ # .as_(un)digested_dict
+ bi = io.BytesIO()
+ with open(bi, 'w', zstd_dict=TRAINED_DICT.as_digested_dict) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with open(bi, zstd_dict=TRAINED_DICT.as_undigested_dict) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ # invalid dictionary
+ bi = io.BytesIO()
+ with self.assertRaisesRegex(TypeError, 'zstd_dict'):
+ open(bi, 'w', zstd_dict={1:2, 2:3})
+
+ with self.assertRaisesRegex(TypeError, 'zstd_dict'):
+ open(bi, 'w', zstd_dict=b'1234567890')
+
+ def test_open_prefix(self):
+ bi = io.BytesIO()
+ with open(bi, 'w', zstd_dict=TRAINED_DICT.as_prefix) as f:
+ f.write(SAMPLES[0])
+ bi.seek(0)
+ with open(bi, zstd_dict=TRAINED_DICT.as_prefix) as f:
+ dat = f.read()
+ self.assertEqual(dat, SAMPLES[0])
+
+ def test_buffer_protocol(self):
+ # don't use len() for buffer protocol objects
+ arr = array.array("i", range(1000))
+ LENGTH = len(arr) * arr.itemsize
+
+ with open(io.BytesIO(), "wb") as f:
+ self.assertEqual(f.write(arr), LENGTH)
+ self.assertEqual(f.tell(), LENGTH)
+
+class FreeThreadingMethodTests(unittest.TestCase):
+
+ @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
+ @threading_helper.reap_threads
+ @threading_helper.requires_working_threading()
+ def test_compress_locking(self):
+ input = b'a'* (16*_1K)
+ num_threads = 8
+
+ comp = ZstdCompressor()
+ parts = []
+ for _ in range(num_threads):
+ res = comp.compress(input, ZstdCompressor.FLUSH_BLOCK)
+ if res:
+ parts.append(res)
+ rest1 = comp.flush()
+ expected = b''.join(parts) + rest1
+
+ comp = ZstdCompressor()
+ output = []
+ def run_method(method, input_data, output_data):
+ res = method(input_data, ZstdCompressor.FLUSH_BLOCK)
+ if res:
+ output_data.append(res)
+ threads = []
+
+ for i in range(num_threads):
+ thread = threading.Thread(target=run_method, args=(comp.compress, input, output))
+
+ threads.append(thread)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ rest2 = comp.flush()
+ self.assertEqual(rest1, rest2)
+ actual = b''.join(output) + rest2
+ self.assertEqual(expected, actual)
+
+ @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
+ @threading_helper.reap_threads
+ @threading_helper.requires_working_threading()
+ def test_decompress_locking(self):
+ input = compress(b'a'* (16*_1K))
+ num_threads = 8
+ # to ensure we decompress over multiple calls, limit the output size per call
+ window_size = _1K * 16//num_threads
+
+ decomp = ZstdDecompressor()
+ parts = []
+ for _ in range(num_threads):
+ res = decomp.decompress(input, window_size)
+ if res:
+ parts.append(res)
+ expected = b''.join(parts)
+
+ comp = ZstdDecompressor()
+ output = []
+ def run_method(method, input_data, output_data):
+ res = method(input_data, window_size)
+ if res:
+ output_data.append(res)
+ threads = []
+
+ for i in range(num_threads):
+ thread = threading.Thread(target=run_method, args=(comp.decompress, input, output))
+
+ threads.append(thread)
+
+ with threading_helper.start_threads(threads):
+ pass
+
+ actual = b''.join(output)
+ self.assertEqual(expected, actual)
+
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/tokenize.py b/Lib/tokenize.py
index 117b485b934..8d01fd7bce4 100644
--- a/Lib/tokenize.py
+++ b/Lib/tokenize.py
@@ -518,7 +518,7 @@ def _main(args=None):
sys.exit(1)
# Parse the arguments and options
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument(dest='filename', nargs='?',
metavar='filename.py',
help='the file to tokenize; defaults to stdin')
diff --git a/Lib/trace.py b/Lib/trace.py
index a87bc6d61a8..cf8817f4383 100644
--- a/Lib/trace.py
+++ b/Lib/trace.py
@@ -604,7 +604,7 @@ class Trace:
def main():
import argparse
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument('--version', action='version', version='trace 2.0')
grp = parser.add_argument_group('Main options',
diff --git a/Lib/traceback.py b/Lib/traceback.py
index 16ba7fc2ee8..17b082eced6 100644
--- a/Lib/traceback.py
+++ b/Lib/traceback.py
@@ -10,9 +10,9 @@ import codeop
import keyword
import tokenize
import io
-from contextlib import suppress
import _colorize
-from _colorize import ANSIColors
+
+from contextlib import suppress
__all__ = ['extract_stack', 'extract_tb', 'format_exception',
'format_exception_only', 'format_list', 'format_stack',
@@ -187,15 +187,13 @@ def _format_final_exc_line(etype, value, *, insert_final_newline=True, colorize=
valuestr = _safe_string(value, 'exception')
end_char = "\n" if insert_final_newline else ""
if colorize:
- if value is None or not valuestr:
- line = f"{ANSIColors.BOLD_MAGENTA}{etype}{ANSIColors.RESET}{end_char}"
- else:
- line = f"{ANSIColors.BOLD_MAGENTA}{etype}{ANSIColors.RESET}: {ANSIColors.MAGENTA}{valuestr}{ANSIColors.RESET}{end_char}"
+ theme = _colorize.get_theme(force_color=True).traceback
else:
- if value is None or not valuestr:
- line = f"{etype}{end_char}"
- else:
- line = f"{etype}: {valuestr}{end_char}"
+ theme = _colorize.get_theme(force_no_color=True).traceback
+ if value is None or not valuestr:
+ line = f"{theme.type}{etype}{theme.reset}{end_char}"
+ else:
+ line = f"{theme.type}{etype}{theme.reset}: {theme.message}{valuestr}{theme.reset}{end_char}"
return line
@@ -539,21 +537,22 @@ class StackSummary(list):
if frame_summary.filename.startswith("<stdin>-"):
filename = "<stdin>"
if colorize:
- row.append(' File {}"{}"{}, line {}{}{}, in {}{}{}\n'.format(
- ANSIColors.MAGENTA,
- filename,
- ANSIColors.RESET,
- ANSIColors.MAGENTA,
- frame_summary.lineno,
- ANSIColors.RESET,
- ANSIColors.MAGENTA,
- frame_summary.name,
- ANSIColors.RESET,
- )
- )
+ theme = _colorize.get_theme(force_color=True).traceback
else:
- row.append(' File "{}", line {}, in {}\n'.format(
- filename, frame_summary.lineno, frame_summary.name))
+ theme = _colorize.get_theme(force_no_color=True).traceback
+ row.append(
+ ' File {}"{}"{}, line {}{}{}, in {}{}{}\n'.format(
+ theme.filename,
+ filename,
+ theme.reset,
+ theme.line_no,
+ frame_summary.lineno,
+ theme.reset,
+ theme.frame,
+ frame_summary.name,
+ theme.reset,
+ )
+ )
if frame_summary._dedented_lines and frame_summary._dedented_lines.strip():
if (
frame_summary.colno is None or
@@ -672,11 +671,11 @@ class StackSummary(list):
for color, group in itertools.groupby(itertools.zip_longest(line, carets, fillvalue=""), key=lambda x: x[1]):
caret_group = list(group)
if color == "^":
- colorized_line_parts.append(ANSIColors.BOLD_RED + "".join(char for char, _ in caret_group) + ANSIColors.RESET)
- colorized_carets_parts.append(ANSIColors.BOLD_RED + "".join(caret for _, caret in caret_group) + ANSIColors.RESET)
+ colorized_line_parts.append(theme.error_highlight + "".join(char for char, _ in caret_group) + theme.reset)
+ colorized_carets_parts.append(theme.error_highlight + "".join(caret for _, caret in caret_group) + theme.reset)
elif color == "~":
- colorized_line_parts.append(ANSIColors.RED + "".join(char for char, _ in caret_group) + ANSIColors.RESET)
- colorized_carets_parts.append(ANSIColors.RED + "".join(caret for _, caret in caret_group) + ANSIColors.RESET)
+ colorized_line_parts.append(theme.error_range + "".join(char for char, _ in caret_group) + theme.reset)
+ colorized_carets_parts.append(theme.error_range + "".join(caret for _, caret in caret_group) + theme.reset)
else:
colorized_line_parts.append("".join(char for char, _ in caret_group))
colorized_carets_parts.append("".join(caret for _, caret in caret_group))
@@ -1378,20 +1377,20 @@ class TracebackException:
"""Format SyntaxError exceptions (internal helper)."""
# Show exactly where the problem was found.
colorize = kwargs.get("colorize", False)
+ if colorize:
+ theme = _colorize.get_theme(force_color=True).traceback
+ else:
+ theme = _colorize.get_theme(force_no_color=True).traceback
filename_suffix = ''
if self.lineno is not None:
- if colorize:
- yield ' File {}"{}"{}, line {}{}{}\n'.format(
- ANSIColors.MAGENTA,
- self.filename or "<string>",
- ANSIColors.RESET,
- ANSIColors.MAGENTA,
- self.lineno,
- ANSIColors.RESET,
- )
- else:
- yield ' File "{}", line {}\n'.format(
- self.filename or "<string>", self.lineno)
+ yield ' File {}"{}"{}, line {}{}{}\n'.format(
+ theme.filename,
+ self.filename or "<string>",
+ theme.reset,
+ theme.line_no,
+ self.lineno,
+ theme.reset,
+ )
elif self.filename is not None:
filename_suffix = ' ({})'.format(self.filename)
@@ -1441,11 +1440,11 @@ class TracebackException:
# colorize from colno to end_colno
ltext = (
ltext[:colno] +
- ANSIColors.BOLD_RED + ltext[colno:end_colno] + ANSIColors.RESET +
+ theme.error_highlight + ltext[colno:end_colno] + theme.reset +
ltext[end_colno:]
)
- start_color = ANSIColors.BOLD_RED
- end_color = ANSIColors.RESET
+ start_color = theme.error_highlight
+ end_color = theme.reset
yield ' {}\n'.format(ltext)
yield ' {}{}{}{}\n'.format(
"".join(caretspace),
@@ -1456,17 +1455,15 @@ class TracebackException:
else:
yield ' {}\n'.format(ltext)
msg = self.msg or "<no detail available>"
- if colorize:
- yield "{}{}{}: {}{}{}{}\n".format(
- ANSIColors.BOLD_MAGENTA,
- stype,
- ANSIColors.RESET,
- ANSIColors.MAGENTA,
- msg,
- ANSIColors.RESET,
- filename_suffix)
- else:
- yield "{}: {}{}\n".format(stype, msg, filename_suffix)
+ yield "{}{}{}: {}{}{}{}\n".format(
+ theme.type,
+ stype,
+ theme.reset,
+ theme.message,
+ msg,
+ theme.reset,
+ filename_suffix,
+ )
def format(self, *, chain=True, _ctx=None, **kwargs):
"""Format the exception.
diff --git a/Lib/typing.py b/Lib/typing.py
index e019c597580..2baf655256d 100644
--- a/Lib/typing.py
+++ b/Lib/typing.py
@@ -3477,7 +3477,7 @@ class IO(Generic[AnyStr]):
pass
@abstractmethod
- def readlines(self, hint: int = -1) -> List[AnyStr]:
+ def readlines(self, hint: int = -1) -> list[AnyStr]:
pass
@abstractmethod
@@ -3493,7 +3493,7 @@ class IO(Generic[AnyStr]):
pass
@abstractmethod
- def truncate(self, size: int = None) -> int:
+ def truncate(self, size: int | None = None) -> int:
pass
@abstractmethod
@@ -3505,11 +3505,11 @@ class IO(Generic[AnyStr]):
pass
@abstractmethod
- def writelines(self, lines: List[AnyStr]) -> None:
+ def writelines(self, lines: list[AnyStr]) -> None:
pass
@abstractmethod
- def __enter__(self) -> 'IO[AnyStr]':
+ def __enter__(self) -> IO[AnyStr]:
pass
@abstractmethod
@@ -3523,11 +3523,11 @@ class BinaryIO(IO[bytes]):
__slots__ = ()
@abstractmethod
- def write(self, s: Union[bytes, bytearray]) -> int:
+ def write(self, s: bytes | bytearray) -> int:
pass
@abstractmethod
- def __enter__(self) -> 'BinaryIO':
+ def __enter__(self) -> BinaryIO:
pass
@@ -3548,7 +3548,7 @@ class TextIO(IO[str]):
@property
@abstractmethod
- def errors(self) -> Optional[str]:
+ def errors(self) -> str | None:
pass
@property
@@ -3562,7 +3562,7 @@ class TextIO(IO[str]):
pass
@abstractmethod
- def __enter__(self) -> 'TextIO':
+ def __enter__(self) -> TextIO:
pass
diff --git a/Lib/unittest/main.py b/Lib/unittest/main.py
index c3869de3f6f..6fd949581f3 100644
--- a/Lib/unittest/main.py
+++ b/Lib/unittest/main.py
@@ -197,7 +197,7 @@ class TestProgram(object):
return parser
def _getMainArgParser(self, parent):
- parser = argparse.ArgumentParser(parents=[parent])
+ parser = argparse.ArgumentParser(parents=[parent], color=True)
parser.prog = self.progName
parser.print_help = self._print_help
@@ -208,7 +208,7 @@ class TestProgram(object):
return parser
def _getDiscoveryArgParser(self, parent):
- parser = argparse.ArgumentParser(parents=[parent])
+ parser = argparse.ArgumentParser(parents=[parent], color=True)
parser.prog = '%s discover' % self.progName
parser.epilog = ('For test discovery all test modules must be '
'importable from the top level directory of the '
diff --git a/Lib/unittest/runner.py b/Lib/unittest/runner.py
index eb0234a2617..5f22d91aebd 100644
--- a/Lib/unittest/runner.py
+++ b/Lib/unittest/runner.py
@@ -4,7 +4,7 @@ import sys
import time
import warnings
-from _colorize import get_colors
+from _colorize import get_theme
from . import result
from .case import _SubTest
@@ -45,7 +45,7 @@ class TextTestResult(result.TestResult):
self.showAll = verbosity > 1
self.dots = verbosity == 1
self.descriptions = descriptions
- self._ansi = get_colors(file=stream)
+ self._theme = get_theme(tty_file=stream).unittest
self._newline = True
self.durations = durations
@@ -79,101 +79,99 @@ class TextTestResult(result.TestResult):
def addSubTest(self, test, subtest, err):
if err is not None:
- red, reset = self._ansi.RED, self._ansi.RESET
+ t = self._theme
if self.showAll:
if issubclass(err[0], subtest.failureException):
- self._write_status(subtest, f"{red}FAIL{reset}")
+ self._write_status(subtest, f"{t.fail}FAIL{t.reset}")
else:
- self._write_status(subtest, f"{red}ERROR{reset}")
+ self._write_status(subtest, f"{t.fail}ERROR{t.reset}")
elif self.dots:
if issubclass(err[0], subtest.failureException):
- self.stream.write(f"{red}F{reset}")
+ self.stream.write(f"{t.fail}F{t.reset}")
else:
- self.stream.write(f"{red}E{reset}")
+ self.stream.write(f"{t.fail}E{t.reset}")
self.stream.flush()
super(TextTestResult, self).addSubTest(test, subtest, err)
def addSuccess(self, test):
super(TextTestResult, self).addSuccess(test)
- green, reset = self._ansi.GREEN, self._ansi.RESET
+ t = self._theme
if self.showAll:
- self._write_status(test, f"{green}ok{reset}")
+ self._write_status(test, f"{t.passed}ok{t.reset}")
elif self.dots:
- self.stream.write(f"{green}.{reset}")
+ self.stream.write(f"{t.passed}.{t.reset}")
self.stream.flush()
def addError(self, test, err):
super(TextTestResult, self).addError(test, err)
- red, reset = self._ansi.RED, self._ansi.RESET
+ t = self._theme
if self.showAll:
- self._write_status(test, f"{red}ERROR{reset}")
+ self._write_status(test, f"{t.fail}ERROR{t.reset}")
elif self.dots:
- self.stream.write(f"{red}E{reset}")
+ self.stream.write(f"{t.fail}E{t.reset}")
self.stream.flush()
def addFailure(self, test, err):
super(TextTestResult, self).addFailure(test, err)
- red, reset = self._ansi.RED, self._ansi.RESET
+ t = self._theme
if self.showAll:
- self._write_status(test, f"{red}FAIL{reset}")
+ self._write_status(test, f"{t.fail}FAIL{t.reset}")
elif self.dots:
- self.stream.write(f"{red}F{reset}")
+ self.stream.write(f"{t.fail}F{t.reset}")
self.stream.flush()
def addSkip(self, test, reason):
super(TextTestResult, self).addSkip(test, reason)
- yellow, reset = self._ansi.YELLOW, self._ansi.RESET
+ t = self._theme
if self.showAll:
- self._write_status(test, f"{yellow}skipped{reset} {reason!r}")
+ self._write_status(test, f"{t.warn}skipped{t.reset} {reason!r}")
elif self.dots:
- self.stream.write(f"{yellow}s{reset}")
+ self.stream.write(f"{t.warn}s{t.reset}")
self.stream.flush()
def addExpectedFailure(self, test, err):
super(TextTestResult, self).addExpectedFailure(test, err)
- yellow, reset = self._ansi.YELLOW, self._ansi.RESET
+ t = self._theme
if self.showAll:
- self.stream.writeln(f"{yellow}expected failure{reset}")
+ self.stream.writeln(f"{t.warn}expected failure{t.reset}")
self.stream.flush()
elif self.dots:
- self.stream.write(f"{yellow}x{reset}")
+ self.stream.write(f"{t.warn}x{t.reset}")
self.stream.flush()
def addUnexpectedSuccess(self, test):
super(TextTestResult, self).addUnexpectedSuccess(test)
- red, reset = self._ansi.RED, self._ansi.RESET
+ t = self._theme
if self.showAll:
- self.stream.writeln(f"{red}unexpected success{reset}")
+ self.stream.writeln(f"{t.fail}unexpected success{t.reset}")
self.stream.flush()
elif self.dots:
- self.stream.write(f"{red}u{reset}")
+ self.stream.write(f"{t.fail}u{t.reset}")
self.stream.flush()
def printErrors(self):
- bold_red = self._ansi.BOLD_RED
- red = self._ansi.RED
- reset = self._ansi.RESET
+ t = self._theme
if self.dots or self.showAll:
self.stream.writeln()
self.stream.flush()
- self.printErrorList(f"{red}ERROR{reset}", self.errors)
- self.printErrorList(f"{red}FAIL{reset}", self.failures)
+ self.printErrorList(f"{t.fail}ERROR{t.reset}", self.errors)
+ self.printErrorList(f"{t.fail}FAIL{t.reset}", self.failures)
unexpectedSuccesses = getattr(self, "unexpectedSuccesses", ())
if unexpectedSuccesses:
self.stream.writeln(self.separator1)
for test in unexpectedSuccesses:
self.stream.writeln(
- f"{red}UNEXPECTED SUCCESS{bold_red}: "
- f"{self.getDescription(test)}{reset}"
+ f"{t.fail}UNEXPECTED SUCCESS{t.fail_info}: "
+ f"{self.getDescription(test)}{t.reset}"
)
self.stream.flush()
def printErrorList(self, flavour, errors):
- bold_red, reset = self._ansi.BOLD_RED, self._ansi.RESET
+ t = self._theme
for test, err in errors:
self.stream.writeln(self.separator1)
self.stream.writeln(
- f"{flavour}{bold_red}: {self.getDescription(test)}{reset}"
+ f"{flavour}{t.fail_info}: {self.getDescription(test)}{t.reset}"
)
self.stream.writeln(self.separator2)
self.stream.writeln("%s" % err)
@@ -286,31 +284,26 @@ class TextTestRunner(object):
expected_fails, unexpected_successes, skipped = results
infos = []
- ansi = get_colors(file=self.stream)
- bold_red = ansi.BOLD_RED
- green = ansi.GREEN
- red = ansi.RED
- reset = ansi.RESET
- yellow = ansi.YELLOW
+ t = get_theme(tty_file=self.stream).unittest
if not result.wasSuccessful():
- self.stream.write(f"{bold_red}FAILED{reset}")
+ self.stream.write(f"{t.fail_info}FAILED{t.reset}")
failed, errored = len(result.failures), len(result.errors)
if failed:
- infos.append(f"{bold_red}failures={failed}{reset}")
+ infos.append(f"{t.fail_info}failures={failed}{t.reset}")
if errored:
- infos.append(f"{bold_red}errors={errored}{reset}")
+ infos.append(f"{t.fail_info}errors={errored}{t.reset}")
elif run == 0 and not skipped:
- self.stream.write(f"{yellow}NO TESTS RAN{reset}")
+ self.stream.write(f"{t.warn}NO TESTS RAN{t.reset}")
else:
- self.stream.write(f"{green}OK{reset}")
+ self.stream.write(f"{t.passed}OK{t.reset}")
if skipped:
- infos.append(f"{yellow}skipped={skipped}{reset}")
+ infos.append(f"{t.warn}skipped={skipped}{t.reset}")
if expected_fails:
- infos.append(f"{yellow}expected failures={expected_fails}{reset}")
+ infos.append(f"{t.warn}expected failures={expected_fails}{t.reset}")
if unexpected_successes:
infos.append(
- f"{red}unexpected successes={unexpected_successes}{reset}"
+ f"{t.fail}unexpected successes={unexpected_successes}{t.reset}"
)
if infos:
self.stream.writeln(" (%s)" % (", ".join(infos),))
diff --git a/Lib/urllib/request.py b/Lib/urllib/request.py
index 9a6b29a90a2..41dc5d7b35d 100644
--- a/Lib/urllib/request.py
+++ b/Lib/urllib/request.py
@@ -1466,7 +1466,7 @@ class FileHandler(BaseHandler):
def open_local_file(self, req):
import email.utils
import mimetypes
- localfile = url2pathname(req.full_url, require_scheme=True)
+ localfile = url2pathname(req.full_url, require_scheme=True, resolve_host=True)
try:
stats = os.stat(localfile)
size = stats.st_size
@@ -1482,7 +1482,7 @@ class FileHandler(BaseHandler):
file_open = open_local_file
-def _is_local_authority(authority):
+def _is_local_authority(authority, resolve):
# Compare hostnames
if not authority or authority == 'localhost':
return True
@@ -1494,9 +1494,11 @@ def _is_local_authority(authority):
if authority == hostname:
return True
# Compare IP addresses
+ if not resolve:
+ return False
try:
address = socket.gethostbyname(authority)
- except (socket.gaierror, AttributeError):
+ except (socket.gaierror, AttributeError, UnicodeEncodeError):
return False
return address in FileHandler().get_names()
@@ -1641,13 +1643,16 @@ class DataHandler(BaseHandler):
return addinfourl(io.BytesIO(data), headers, url)
-# Code move from the old urllib module
+# Code moved from the old urllib module
-def url2pathname(url, *, require_scheme=False):
+def url2pathname(url, *, require_scheme=False, resolve_host=False):
"""Convert the given file URL to a local file system path.
The 'file:' scheme prefix must be omitted unless *require_scheme*
is set to true.
+
+ The URL authority may be resolved with gethostbyname() if
+ *resolve_host* is set to true.
"""
if require_scheme:
scheme, url = _splittype(url)
@@ -1655,7 +1660,7 @@ def url2pathname(url, *, require_scheme=False):
raise URLError("URL is missing a 'file:' scheme")
authority, url = _splithost(url)
if os.name == 'nt':
- if not _is_local_authority(authority):
+ if not _is_local_authority(authority, resolve_host):
# e.g. file://server/share/file.txt
url = '//' + authority + url
elif url[:3] == '///':
@@ -1669,7 +1674,7 @@ def url2pathname(url, *, require_scheme=False):
# Older URLs use a pipe after a drive letter
url = url[:1] + ':' + url[2:]
url = url.replace('/', '\\')
- elif not _is_local_authority(authority):
+ elif not _is_local_authority(authority, resolve_host):
raise URLError("file:// scheme is supported only on localhost")
encoding = sys.getfilesystemencoding()
errors = sys.getfilesystemencodeerrors()
diff --git a/Lib/uuid.py b/Lib/uuid.py
index 2c16c3f0f5a..036ffebf67a 100644
--- a/Lib/uuid.py
+++ b/Lib/uuid.py
@@ -949,7 +949,9 @@ def main():
import argparse
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- description="Generate a UUID using the selected UUID function.")
+ description="Generate a UUID using the selected UUID function.",
+ color=True,
+ )
parser.add_argument("-u", "--uuid",
choices=uuid_funcs.keys(),
default="uuid4",
diff --git a/Lib/venv/__init__.py b/Lib/venv/__init__.py
index dc4c9ef3531..15e15b7a518 100644
--- a/Lib/venv/__init__.py
+++ b/Lib/venv/__init__.py
@@ -624,7 +624,9 @@ def main(args=None):
'created, you may wish to '
'activate it, e.g. by '
'sourcing an activate script '
- 'in its bin directory.')
+ 'in its bin directory.',
+ color=True,
+ )
parser.add_argument('dirs', metavar='ENV_DIR', nargs='+',
help='A directory to create the environment in.')
parser.add_argument('--system-site-packages', default=False,
diff --git a/Lib/webbrowser.py b/Lib/webbrowser.py
index ab50ec1ee95..f2e2394089d 100644
--- a/Lib/webbrowser.py
+++ b/Lib/webbrowser.py
@@ -719,7 +719,9 @@ if sys.platform == "ios":
def parse_args(arg_list: list[str] | None):
import argparse
- parser = argparse.ArgumentParser(description="Open URL in a web browser.")
+ parser = argparse.ArgumentParser(
+ description="Open URL in a web browser.", color=True,
+ )
parser.add_argument("url", help="URL to open")
group = parser.add_mutually_exclusive_group()
diff --git a/Lib/zipapp.py b/Lib/zipapp.py
index 59b444075a6..7a4ef96ea0f 100644
--- a/Lib/zipapp.py
+++ b/Lib/zipapp.py
@@ -187,7 +187,7 @@ def main(args=None):
"""
import argparse
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(color=True)
parser.add_argument('--output', '-o', default=None,
help="The name of the output archive. "
"Required if SOURCE is an archive.")
diff --git a/Lib/zipfile/__init__.py b/Lib/zipfile/__init__.py
index b7840d0f945..88356abe8cb 100644
--- a/Lib/zipfile/__init__.py
+++ b/Lib/zipfile/__init__.py
@@ -31,6 +31,11 @@ try:
except ImportError:
lzma = None
+try:
+ from compression import zstd # We may need its compression method
+except ImportError:
+ zstd = None
+
__all__ = ["BadZipFile", "BadZipfile", "error",
"ZIP_STORED", "ZIP_DEFLATED", "ZIP_BZIP2", "ZIP_LZMA",
"is_zipfile", "ZipInfo", "ZipFile", "PyZipFile", "LargeZipFile",
@@ -58,12 +63,14 @@ ZIP_STORED = 0
ZIP_DEFLATED = 8
ZIP_BZIP2 = 12
ZIP_LZMA = 14
+ZIP_ZSTANDARD = 93
# Other ZIP compression methods not supported
DEFAULT_VERSION = 20
ZIP64_VERSION = 45
BZIP2_VERSION = 46
LZMA_VERSION = 63
+ZSTANDARD_VERSION = 63
# we recognize (but not necessarily support) all features up to that version
MAX_EXTRACT_VERSION = 63
@@ -505,6 +512,8 @@ class ZipInfo:
min_version = max(BZIP2_VERSION, min_version)
elif self.compress_type == ZIP_LZMA:
min_version = max(LZMA_VERSION, min_version)
+ elif self.compress_type == ZIP_ZSTANDARD:
+ min_version = max(ZSTANDARD_VERSION, min_version)
self.extract_version = max(min_version, self.extract_version)
self.create_version = max(min_version, self.create_version)
@@ -766,6 +775,7 @@ compressor_names = {
14: 'lzma',
18: 'terse',
19: 'lz77',
+ 93: 'zstd',
97: 'wavpack',
98: 'ppmd',
}
@@ -785,6 +795,10 @@ def _check_compression(compression):
if not lzma:
raise RuntimeError(
"Compression requires the (missing) lzma module")
+ elif compression == ZIP_ZSTANDARD:
+ if not zstd:
+ raise RuntimeError(
+ "Compression requires the (missing) compression.zstd module")
else:
raise NotImplementedError("That compression method is not supported")
@@ -798,9 +812,11 @@ def _get_compressor(compress_type, compresslevel=None):
if compresslevel is not None:
return bz2.BZ2Compressor(compresslevel)
return bz2.BZ2Compressor()
- # compresslevel is ignored for ZIP_LZMA
+ # compresslevel is ignored for ZIP_LZMA and ZIP_ZSTANDARD
elif compress_type == ZIP_LZMA:
return LZMACompressor()
+ elif compress_type == ZIP_ZSTANDARD:
+ return zstd.ZstdCompressor()
else:
return None
@@ -815,6 +831,8 @@ def _get_decompressor(compress_type):
return bz2.BZ2Decompressor()
elif compress_type == ZIP_LZMA:
return LZMADecompressor()
+ elif compress_type == ZIP_ZSTANDARD:
+ return zstd.ZstdDecompressor()
else:
descr = compressor_names.get(compress_type)
if descr:
@@ -2317,7 +2335,7 @@ def main(args=None):
import argparse
description = 'A simple command-line interface for zipfile module.'
- parser = argparse.ArgumentParser(description=description)
+ parser = argparse.ArgumentParser(description=description, color=True)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('-l', '--list', metavar='<zipfile>',
help='Show listing of a zipfile')