diff options
Diffstat (limited to 'Lib')
132 files changed, 6468 insertions, 1131 deletions
diff --git a/Lib/_colorize.py b/Lib/_colorize.py index 54895488e74..4a310a40235 100644 --- a/Lib/_colorize.py +++ b/Lib/_colorize.py @@ -1,28 +1,17 @@ -from __future__ import annotations import io import os import sys +from collections.abc import Callable, Iterator, Mapping +from dataclasses import dataclass, field, Field + COLORIZE = True + # types if False: - from typing import IO, Literal - - type ColorTag = Literal[ - "PROMPT", - "KEYWORD", - "BUILTIN", - "COMMENT", - "STRING", - "NUMBER", - "OP", - "DEFINITION", - "SOFT_KEYWORD", - "RESET", - ] - - theme: dict[ColorTag, str] + from typing import IO, Self, ClassVar + _theme: Theme class ANSIColors: @@ -86,6 +75,186 @@ for attr, code in ANSIColors.__dict__.items(): setattr(NoColors, attr, "") +# +# Experimental theming support (see gh-133346) +# + +# - Create a theme by copying an existing `Theme` with one or more sections +# replaced, using `default_theme.copy_with()`; +# - create a theme section by copying an existing `ThemeSection` with one or +# more colors replaced, using for example `default_theme.syntax.copy_with()`; +# - create a theme from scratch by instantiating a `Theme` data class with +# the required sections (which are also dataclass instances). +# +# Then call `_colorize.set_theme(your_theme)` to set it. +# +# Put your theme configuration in $PYTHONSTARTUP for the interactive shell, +# or sitecustomize.py in your virtual environment or Python installation for +# other uses. Your applications can call `_colorize.set_theme()` too. +# +# Note that thanks to the dataclasses providing default values for all fields, +# creating a new theme or theme section from scratch is possible without +# specifying all keys. +# +# For example, here's a theme that makes punctuation and operators less prominent: +# +# try: +# from _colorize import set_theme, default_theme, Syntax, ANSIColors +# except ImportError: +# pass +# else: +# theme_with_dim_operators = default_theme.copy_with( +# syntax=Syntax(op=ANSIColors.INTENSE_BLACK), +# ) +# set_theme(theme_with_dim_operators) +# del set_theme, default_theme, Syntax, ANSIColors, theme_with_dim_operators +# +# Guarding the import ensures that your .pythonstartup file will still work in +# Python 3.13 and older. Deleting the variables ensures they don't remain in your +# interactive shell's global scope. + +class ThemeSection(Mapping[str, str]): + """A mixin/base class for theme sections. + + It enables dictionary access to a section, as well as implements convenience + methods. + """ + + # The two types below are just that: types to inform the type checker that the + # mixin will work in context of those fields existing + __dataclass_fields__: ClassVar[dict[str, Field[str]]] + _name_to_value: Callable[[str], str] + + def __post_init__(self) -> None: + name_to_value = {} + for color_name in self.__dataclass_fields__: + name_to_value[color_name] = getattr(self, color_name) + super().__setattr__('_name_to_value', name_to_value.__getitem__) + + def copy_with(self, **kwargs: str) -> Self: + color_state: dict[str, str] = {} + for color_name in self.__dataclass_fields__: + color_state[color_name] = getattr(self, color_name) + color_state.update(kwargs) + return type(self)(**color_state) + + @classmethod + def no_colors(cls) -> Self: + color_state: dict[str, str] = {} + for color_name in cls.__dataclass_fields__: + color_state[color_name] = "" + return cls(**color_state) + + def __getitem__(self, key: str) -> str: + return self._name_to_value(key) + + def __len__(self) -> int: + return len(self.__dataclass_fields__) + + def __iter__(self) -> Iterator[str]: + return iter(self.__dataclass_fields__) + + +@dataclass(frozen=True) +class Argparse(ThemeSection): + usage: str = ANSIColors.BOLD_BLUE + prog: str = ANSIColors.BOLD_MAGENTA + prog_extra: str = ANSIColors.MAGENTA + heading: str = ANSIColors.BOLD_BLUE + summary_long_option: str = ANSIColors.CYAN + summary_short_option: str = ANSIColors.GREEN + summary_label: str = ANSIColors.YELLOW + summary_action: str = ANSIColors.GREEN + long_option: str = ANSIColors.BOLD_CYAN + short_option: str = ANSIColors.BOLD_GREEN + label: str = ANSIColors.BOLD_YELLOW + action: str = ANSIColors.BOLD_GREEN + reset: str = ANSIColors.RESET + + +@dataclass(frozen=True) +class Syntax(ThemeSection): + prompt: str = ANSIColors.BOLD_MAGENTA + keyword: str = ANSIColors.BOLD_BLUE + builtin: str = ANSIColors.CYAN + comment: str = ANSIColors.RED + string: str = ANSIColors.GREEN + number: str = ANSIColors.YELLOW + op: str = ANSIColors.RESET + definition: str = ANSIColors.BOLD + soft_keyword: str = ANSIColors.BOLD_BLUE + reset: str = ANSIColors.RESET + + +@dataclass(frozen=True) +class Traceback(ThemeSection): + type: str = ANSIColors.BOLD_MAGENTA + message: str = ANSIColors.MAGENTA + filename: str = ANSIColors.MAGENTA + line_no: str = ANSIColors.MAGENTA + frame: str = ANSIColors.MAGENTA + error_highlight: str = ANSIColors.BOLD_RED + error_range: str = ANSIColors.RED + reset: str = ANSIColors.RESET + + +@dataclass(frozen=True) +class Unittest(ThemeSection): + passed: str = ANSIColors.GREEN + warn: str = ANSIColors.YELLOW + fail: str = ANSIColors.RED + fail_info: str = ANSIColors.BOLD_RED + reset: str = ANSIColors.RESET + + +@dataclass(frozen=True) +class Theme: + """A suite of themes for all sections of Python. + + When adding a new one, remember to also modify `copy_with` and `no_colors` + below. + """ + argparse: Argparse = field(default_factory=Argparse) + syntax: Syntax = field(default_factory=Syntax) + traceback: Traceback = field(default_factory=Traceback) + unittest: Unittest = field(default_factory=Unittest) + + def copy_with( + self, + *, + argparse: Argparse | None = None, + syntax: Syntax | None = None, + traceback: Traceback | None = None, + unittest: Unittest | None = None, + ) -> Self: + """Return a new Theme based on this instance with some sections replaced. + + Themes are immutable to protect against accidental modifications that + could lead to invalid terminal states. + """ + return type(self)( + argparse=argparse or self.argparse, + syntax=syntax or self.syntax, + traceback=traceback or self.traceback, + unittest=unittest or self.unittest, + ) + + @classmethod + def no_colors(cls) -> Self: + """Return a new Theme where colors in all sections are empty strings. + + This allows writing user code as if colors are always used. The color + fields will be ANSI color code strings when colorization is desired + and possible, and empty strings otherwise. + """ + return cls( + argparse=Argparse.no_colors(), + syntax=Syntax.no_colors(), + traceback=Traceback.no_colors(), + unittest=Unittest.no_colors(), + ) + + def get_colors( colorize: bool = False, *, file: IO[str] | IO[bytes] | None = None ) -> ANSIColors: @@ -138,26 +307,40 @@ def can_colorize(*, file: IO[str] | IO[bytes] | None = None) -> bool: return hasattr(file, "isatty") and file.isatty() -def set_theme(t: dict[ColorTag, str] | None = None) -> None: - global theme +default_theme = Theme() +theme_no_color = default_theme.no_colors() + + +def get_theme( + *, + tty_file: IO[str] | IO[bytes] | None = None, + force_color: bool = False, + force_no_color: bool = False, +) -> Theme: + """Returns the currently set theme, potentially in a zero-color variant. + + In cases where colorizing is not possible (see `can_colorize`), the returned + theme contains all empty strings in all color definitions. + See `Theme.no_colors()` for more information. + + It is recommended not to cache the result of this function for extended + periods of time because the user might influence theme selection by + the interactive shell, a debugger, or application-specific code. The + environment (including environment variable state and console configuration + on Windows) can also change in the course of the application life cycle. + """ + if force_color or (not force_no_color and can_colorize(file=tty_file)): + return _theme + return theme_no_color + + +def set_theme(t: Theme) -> None: + global _theme - if t: - theme = t - return + if not isinstance(t, Theme): + raise ValueError(f"Expected Theme object, found {t}") - colors = get_colors() - theme = { - "PROMPT": colors.BOLD_MAGENTA, - "KEYWORD": colors.BOLD_BLUE, - "BUILTIN": colors.CYAN, - "COMMENT": colors.RED, - "STRING": colors.GREEN, - "NUMBER": colors.YELLOW, - "OP": colors.RESET, - "DEFINITION": colors.BOLD, - "SOFT_KEYWORD": colors.BOLD_BLUE, - "RESET": colors.RESET, - } + _theme = t -set_theme() +set_theme(default_theme) diff --git a/Lib/_pyrepl/base_eventqueue.py b/Lib/_pyrepl/base_eventqueue.py index e018c4fc183..842599bd187 100644 --- a/Lib/_pyrepl/base_eventqueue.py +++ b/Lib/_pyrepl/base_eventqueue.py @@ -69,18 +69,14 @@ class BaseEventQueue: trace('added event {event}', event=event) self.events.append(event) - def push(self, char: int | bytes | str) -> None: + def push(self, char: int | bytes) -> None: """ Processes a character by updating the buffer and handling special key mappings. """ + assert isinstance(char, (int, bytes)) ord_char = char if isinstance(char, int) else ord(char) - if ord_char > 255: - assert isinstance(char, str) - char = bytes(char.encode(self.encoding, "replace")) - self.buf.extend(char) - else: - char = bytes(bytearray((ord_char,))) - self.buf.append(ord_char) + char = ord_char.to_bytes() + self.buf.append(ord_char) if char in self.keymap: if self.keymap is self.compiled_keymap: diff --git a/Lib/_pyrepl/commands.py b/Lib/_pyrepl/commands.py index 2054a8e400f..2354fbb2ec2 100644 --- a/Lib/_pyrepl/commands.py +++ b/Lib/_pyrepl/commands.py @@ -439,7 +439,7 @@ class help(Command): import _sitebuiltins with self.reader.suspend(): - self.reader.msg = _sitebuiltins._Helper()() # type: ignore[assignment, call-arg] + self.reader.msg = _sitebuiltins._Helper()() # type: ignore[assignment] class invalid_key(Command): diff --git a/Lib/_pyrepl/reader.py b/Lib/_pyrepl/reader.py index 65c2230dfd6..0ebd9162eca 100644 --- a/Lib/_pyrepl/reader.py +++ b/Lib/_pyrepl/reader.py @@ -28,7 +28,7 @@ from contextlib import contextmanager from dataclasses import dataclass, field, fields from . import commands, console, input -from .utils import wlen, unbracket, disp_str, gen_colors +from .utils import wlen, unbracket, disp_str, gen_colors, THEME from .trace import trace @@ -491,11 +491,8 @@ class Reader: prompt = self.ps1 if self.can_colorize: - prompt = ( - f"{_colorize.theme["PROMPT"]}" - f"{prompt}" - f"{_colorize.theme["RESET"]}" - ) + t = THEME() + prompt = f"{t.prompt}{prompt}{t.reset}" return prompt def push_input_trans(self, itrans: input.KeymapTranslator) -> None: diff --git a/Lib/_pyrepl/simple_interact.py b/Lib/_pyrepl/simple_interact.py index e2274629b65..b3848833e14 100644 --- a/Lib/_pyrepl/simple_interact.py +++ b/Lib/_pyrepl/simple_interact.py @@ -162,3 +162,8 @@ def run_multiline_interactive_console( except MemoryError: console.write("\nMemoryError\n") console.resetbuffer() + except SystemExit: + raise + except: + console.showtraceback() + console.resetbuffer() diff --git a/Lib/_pyrepl/unix_console.py b/Lib/_pyrepl/unix_console.py index 07b160d2324..d21cdd9b076 100644 --- a/Lib/_pyrepl/unix_console.py +++ b/Lib/_pyrepl/unix_console.py @@ -29,6 +29,7 @@ import signal import struct import termios import time +import types import platform from fcntl import ioctl @@ -39,6 +40,12 @@ from .trace import trace from .unix_eventqueue import EventQueue from .utils import wlen +# declare posix optional to allow None assignment on other platforms +posix: types.ModuleType | None +try: + import posix +except ImportError: + posix = None TYPE_CHECKING = False @@ -197,6 +204,12 @@ class UnixConsole(Console): self.event_queue = EventQueue(self.input_fd, self.encoding) self.cursor_visible = 1 + signal.signal(signal.SIGCONT, self._sigcont_handler) + + def _sigcont_handler(self, signum, frame): + self.restore() + self.prepare() + def __read(self, n: int) -> bytes: return os.read(self.input_fd, n) @@ -550,11 +563,9 @@ class UnixConsole(Console): @property def input_hook(self): - try: - import posix - except ImportError: - return None - if posix._is_inputhook_installed(): + # avoid inline imports here so the repl doesn't get flooded + # with import logging from -X importtime=2 + if posix is not None and posix._is_inputhook_installed(): return posix._inputhook def __enable_bracketed_paste(self) -> None: diff --git a/Lib/_pyrepl/utils.py b/Lib/_pyrepl/utils.py index fe154aa59a0..38cf6b5a08e 100644 --- a/Lib/_pyrepl/utils.py +++ b/Lib/_pyrepl/utils.py @@ -23,6 +23,11 @@ IDENTIFIERS_AFTER = {"def", "class"} BUILTINS = {str(name) for name in dir(builtins) if not name.startswith('_')} +def THEME(**kwargs): + # Not cached: the user can modify the theme inside the interactive session. + return _colorize.get_theme(**kwargs).syntax + + class Span(NamedTuple): """Span indexing that's inclusive on both ends.""" @@ -44,7 +49,7 @@ class Span(NamedTuple): class ColorSpan(NamedTuple): span: Span - tag: _colorize.ColorTag + tag: str @functools.cache @@ -135,7 +140,7 @@ def recover_unterminated_string( span = Span(start, end) trace("yielding span {a} -> {b}", a=span.start, b=span.end) - yield ColorSpan(span, "STRING") + yield ColorSpan(span, "string") else: trace( "unhandled token error({buffer}) = {te}", @@ -164,28 +169,28 @@ def gen_colors_from_token_stream( | T.TSTRING_START | T.TSTRING_MIDDLE | T.TSTRING_END ): span = Span.from_token(token, line_lengths) - yield ColorSpan(span, "STRING") + yield ColorSpan(span, "string") case T.COMMENT: span = Span.from_token(token, line_lengths) - yield ColorSpan(span, "COMMENT") + yield ColorSpan(span, "comment") case T.NUMBER: span = Span.from_token(token, line_lengths) - yield ColorSpan(span, "NUMBER") + yield ColorSpan(span, "number") case T.OP: if token.string in "([{": bracket_level += 1 elif token.string in ")]}": bracket_level -= 1 span = Span.from_token(token, line_lengths) - yield ColorSpan(span, "OP") + yield ColorSpan(span, "op") case T.NAME: if is_def_name: is_def_name = False span = Span.from_token(token, line_lengths) - yield ColorSpan(span, "DEFINITION") + yield ColorSpan(span, "definition") elif keyword.iskeyword(token.string): span = Span.from_token(token, line_lengths) - yield ColorSpan(span, "KEYWORD") + yield ColorSpan(span, "keyword") if token.string in IDENTIFIERS_AFTER: is_def_name = True elif ( @@ -194,10 +199,10 @@ def gen_colors_from_token_stream( and is_soft_keyword_used(prev_token, token, next_token) ): span = Span.from_token(token, line_lengths) - yield ColorSpan(span, "SOFT_KEYWORD") + yield ColorSpan(span, "soft_keyword") elif token.string in BUILTINS: span = Span.from_token(token, line_lengths) - yield ColorSpan(span, "BUILTIN") + yield ColorSpan(span, "builtin") keyword_first_sets_match = {"False", "None", "True", "await", "lambda", "not"} @@ -249,7 +254,10 @@ def is_soft_keyword_used(*tokens: TI | None) -> bool: def disp_str( - buffer: str, colors: list[ColorSpan] | None = None, start_index: int = 0 + buffer: str, + colors: list[ColorSpan] | None = None, + start_index: int = 0, + force_color: bool = False, ) -> tuple[CharBuffer, CharWidths]: r"""Decompose the input buffer into a printable variant with applied colors. @@ -290,15 +298,16 @@ def disp_str( # move past irrelevant spans colors.pop(0) + theme = THEME(force_color=force_color) pre_color = "" post_color = "" if colors and colors[0].span.start < start_index: # looks like we're continuing a previous color (e.g. a multiline str) - pre_color = _colorize.theme[colors[0].tag] + pre_color = theme[colors[0].tag] for i, c in enumerate(buffer, start_index): if colors and colors[0].span.start == i: # new color starts now - pre_color = _colorize.theme[colors[0].tag] + pre_color = theme[colors[0].tag] if c == "\x1a": # CTRL-Z on Windows chars.append(c) @@ -315,7 +324,7 @@ def disp_str( char_widths.append(str_width(c)) if colors and colors[0].span.end == i: # current color ends now - post_color = _colorize.theme["RESET"] + post_color = theme.reset colors.pop(0) chars[-1] = pre_color + chars[-1] + post_color @@ -325,7 +334,7 @@ def disp_str( if colors and colors[0].span.start < i and colors[0].span.end > i: # even though the current color should be continued, reset it for now. # the next call to `disp_str()` will revive it. - chars[-1] += _colorize.theme["RESET"] + chars[-1] += theme.reset return chars, char_widths diff --git a/Lib/_pyrepl/windows_console.py b/Lib/_pyrepl/windows_console.py index 77985e59a93..95749198b3b 100644 --- a/Lib/_pyrepl/windows_console.py +++ b/Lib/_pyrepl/windows_console.py @@ -24,6 +24,7 @@ import os import sys import ctypes +import types from ctypes.wintypes import ( _COORD, WORD, @@ -58,6 +59,12 @@ except: self.err = err self.descr = descr +# declare nt optional to allow None assignment on other platforms +nt: types.ModuleType | None +try: + import nt +except ImportError: + nt = None TYPE_CHECKING = False @@ -121,9 +128,8 @@ class _error(Exception): def _supports_vt(): try: - import nt return nt._supports_virtual_terminal() - except (ImportError, AttributeError): + except AttributeError: return False class WindowsConsole(Console): @@ -235,11 +241,9 @@ class WindowsConsole(Console): @property def input_hook(self): - try: - import nt - except ImportError: - return None - if nt._is_inputhook_installed(): + # avoid inline imports here so the repl doesn't get flooded + # with import logging from -X importtime=2 + if nt is not None and nt._is_inputhook_installed(): return nt._inputhook def __write_changed_line( @@ -464,7 +468,7 @@ class WindowsConsole(Console): if key == "\r": # Make enter unix-like - return Event(evt="key", data="\n", raw=b"\n") + return Event(evt="key", data="\n") elif key_event.wVirtualKeyCode == 8: # Turn backspace directly into the command key = "backspace" @@ -476,24 +480,29 @@ class WindowsConsole(Console): key = f"ctrl {key}" elif key_event.dwControlKeyState & ALT_ACTIVE: # queue the key, return the meta command - self.event_queue.insert(Event(evt="key", data=key, raw=key)) + self.event_queue.insert(Event(evt="key", data=key)) return Event(evt="key", data="\033") # keymap.py uses this for meta - return Event(evt="key", data=key, raw=key) + return Event(evt="key", data=key) if block: continue return None elif self.__vt_support: # If virtual terminal is enabled, scanning VT sequences - self.event_queue.push(rec.Event.KeyEvent.uChar.UnicodeChar) + for char in raw_key.encode(self.event_queue.encoding, "replace"): + self.event_queue.push(char) continue if key_event.dwControlKeyState & ALT_ACTIVE: - # queue the key, return the meta command - self.event_queue.insert(Event(evt="key", data=key, raw=raw_key)) - return Event(evt="key", data="\033") # keymap.py uses this for meta - - return Event(evt="key", data=key, raw=raw_key) + # Do not swallow characters that have been entered via AltGr: + # Windows internally converts AltGr to CTRL+ALT, see + # https://learn.microsoft.com/en-us/windows/win32/api/winuser/nf-winuser-vkkeyscanw + if not key_event.dwControlKeyState & CTRL_ACTIVE: + # queue the key, return the meta command + self.event_queue.insert(Event(evt="key", data=key)) + return Event(evt="key", data="\033") # keymap.py uses this for meta + + return Event(evt="key", data=key) return self.event_queue.get() def push_char(self, char: int | bytes) -> None: diff --git a/Lib/_threading_local.py b/Lib/_threading_local.py index b006d76c4e2..0b9e5d3bbf6 100644 --- a/Lib/_threading_local.py +++ b/Lib/_threading_local.py @@ -4,128 +4,6 @@ class. Depending on the version of Python you're using, there may be a faster one available. You should always import the `local` class from `threading`.) - -Thread-local objects support the management of thread-local data. -If you have data that you want to be local to a thread, simply create -a thread-local object and use its attributes: - - >>> mydata = local() - >>> mydata.number = 42 - >>> mydata.number - 42 - -You can also access the local-object's dictionary: - - >>> mydata.__dict__ - {'number': 42} - >>> mydata.__dict__.setdefault('widgets', []) - [] - >>> mydata.widgets - [] - -What's important about thread-local objects is that their data are -local to a thread. If we access the data in a different thread: - - >>> log = [] - >>> def f(): - ... items = sorted(mydata.__dict__.items()) - ... log.append(items) - ... mydata.number = 11 - ... log.append(mydata.number) - - >>> import threading - >>> thread = threading.Thread(target=f) - >>> thread.start() - >>> thread.join() - >>> log - [[], 11] - -we get different data. Furthermore, changes made in the other thread -don't affect data seen in this thread: - - >>> mydata.number - 42 - -Of course, values you get from a local object, including a __dict__ -attribute, are for whatever thread was current at the time the -attribute was read. For that reason, you generally don't want to save -these values across threads, as they apply only to the thread they -came from. - -You can create custom local objects by subclassing the local class: - - >>> class MyLocal(local): - ... number = 2 - ... def __init__(self, /, **kw): - ... self.__dict__.update(kw) - ... def squared(self): - ... return self.number ** 2 - -This can be useful to support default values, methods and -initialization. Note that if you define an __init__ method, it will be -called each time the local object is used in a separate thread. This -is necessary to initialize each thread's dictionary. - -Now if we create a local object: - - >>> mydata = MyLocal(color='red') - -Now we have a default number: - - >>> mydata.number - 2 - -an initial color: - - >>> mydata.color - 'red' - >>> del mydata.color - -And a method that operates on the data: - - >>> mydata.squared() - 4 - -As before, we can access the data in a separate thread: - - >>> log = [] - >>> thread = threading.Thread(target=f) - >>> thread.start() - >>> thread.join() - >>> log - [[('color', 'red')], 11] - -without affecting this thread's data: - - >>> mydata.number - 2 - >>> mydata.color - Traceback (most recent call last): - ... - AttributeError: 'MyLocal' object has no attribute 'color' - -Note that subclasses can define slots, but they are not thread -local. They are shared across threads: - - >>> class MyLocal(local): - ... __slots__ = 'number' - - >>> mydata = MyLocal() - >>> mydata.number = 42 - >>> mydata.color = 'red' - -So, the separate thread: - - >>> thread = threading.Thread(target=f) - >>> thread.start() - >>> thread.join() - -affects what we see: - - >>> mydata.number - 11 - ->>> del mydata """ from weakref import ref diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py index 5ad0893106a..c0b1d4395d1 100644 --- a/Lib/annotationlib.py +++ b/Lib/annotationlib.py @@ -868,7 +868,7 @@ def get_annotations( # For FORWARDREF, we use __annotations__ if it exists try: ann = _get_dunder_annotations(obj) - except NameError: + except Exception: pass else: if ann is not None: diff --git a/Lib/argparse.py b/Lib/argparse.py index c0dcd0bbff0..f13ac82dbc5 100644 --- a/Lib/argparse.py +++ b/Lib/argparse.py @@ -176,13 +176,13 @@ class HelpFormatter(object): width = shutil.get_terminal_size().columns width -= 2 - from _colorize import ANSIColors, NoColors, can_colorize, decolor + from _colorize import can_colorize, decolor, get_theme if color and can_colorize(): - self._ansi = ANSIColors() + self._theme = get_theme(force_color=True).argparse self._decolor = decolor else: - self._ansi = NoColors + self._theme = get_theme(force_no_color=True).argparse self._decolor = lambda text: text self._prefix_chars = prefix_chars @@ -237,14 +237,12 @@ class HelpFormatter(object): # add the heading if the section was non-empty if self.heading is not SUPPRESS and self.heading is not None: - bold_blue = self.formatter._ansi.BOLD_BLUE - reset = self.formatter._ansi.RESET - current_indent = self.formatter._current_indent heading_text = _('%(heading)s:') % dict(heading=self.heading) + t = self.formatter._theme heading = ( f'{" " * current_indent}' - f'{bold_blue}{heading_text}{reset}\n' + f'{t.heading}{heading_text}{t.reset}\n' ) else: heading = '' @@ -314,10 +312,7 @@ class HelpFormatter(object): if part and part is not SUPPRESS]) def _format_usage(self, usage, actions, groups, prefix): - bold_blue = self._ansi.BOLD_BLUE - bold_magenta = self._ansi.BOLD_MAGENTA - magenta = self._ansi.MAGENTA - reset = self._ansi.RESET + t = self._theme if prefix is None: prefix = _('usage: ') @@ -325,15 +320,15 @@ class HelpFormatter(object): # if usage is specified, use that if usage is not None: usage = ( - magenta + t.prog_extra + usage - % {"prog": f"{bold_magenta}{self._prog}{reset}{magenta}"} - + reset + % {"prog": f"{t.prog}{self._prog}{t.reset}{t.prog_extra}"} + + t.reset ) # if no optionals or positionals are available, usage is just prog elif usage is None and not actions: - usage = f"{bold_magenta}{self._prog}{reset}" + usage = f"{t.prog}{self._prog}{t.reset}" # if optionals and positionals are available, calculate usage elif usage is None: @@ -411,10 +406,10 @@ class HelpFormatter(object): usage = '\n'.join(lines) usage = usage.removeprefix(prog) - usage = f"{bold_magenta}{prog}{reset}{usage}" + usage = f"{t.prog}{prog}{t.reset}{usage}" # prefix with 'usage:' - return f'{bold_blue}{prefix}{reset}{usage}\n\n' + return f'{t.usage}{prefix}{t.reset}{usage}\n\n' def _format_actions_usage(self, actions, groups): return ' '.join(self._get_actions_usage_parts(actions, groups)) @@ -452,10 +447,7 @@ class HelpFormatter(object): # collect all actions format strings parts = [] - cyan = self._ansi.CYAN - green = self._ansi.GREEN - yellow = self._ansi.YELLOW - reset = self._ansi.RESET + t = self._theme for action in actions: # suppressed arguments are marked with None @@ -465,7 +457,11 @@ class HelpFormatter(object): # produce all arg strings elif not action.option_strings: default = self._get_default_metavar_for_positional(action) - part = green + self._format_args(action, default) + reset + part = ( + t.summary_action + + self._format_args(action, default) + + t.reset + ) # if it's in a group, strip the outer [] if action in group_actions: @@ -481,9 +477,9 @@ class HelpFormatter(object): if action.nargs == 0: part = action.format_usage() if self._is_long_option(part): - part = f"{cyan}{part}{reset}" + part = f"{t.summary_long_option}{part}{t.reset}" elif self._is_short_option(part): - part = f"{green}{part}{reset}" + part = f"{t.summary_short_option}{part}{t.reset}" # if the Optional takes a value, format is: # -s ARGS or --long ARGS @@ -491,10 +487,13 @@ class HelpFormatter(object): default = self._get_default_metavar_for_optional(action) args_string = self._format_args(action, default) if self._is_long_option(option_string): - option_string = f"{cyan}{option_string}" + option_color = t.summary_long_option elif self._is_short_option(option_string): - option_string = f"{green}{option_string}" - part = f"{option_string} {yellow}{args_string}{reset}" + option_color = t.summary_short_option + part = ( + f"{option_color}{option_string} " + f"{t.summary_label}{args_string}{t.reset}" + ) # make it look optional if it's not required or in a group if not action.required and action not in group_actions: @@ -590,17 +589,14 @@ class HelpFormatter(object): return self._join_parts(parts) def _format_action_invocation(self, action): - bold_green = self._ansi.BOLD_GREEN - bold_cyan = self._ansi.BOLD_CYAN - bold_yellow = self._ansi.BOLD_YELLOW - reset = self._ansi.RESET + t = self._theme if not action.option_strings: default = self._get_default_metavar_for_positional(action) return ( - bold_green + t.action + ' '.join(self._metavar_formatter(action, default)(1)) - + reset + + t.reset ) else: @@ -609,9 +605,9 @@ class HelpFormatter(object): parts = [] for s in strings: if self._is_long_option(s): - parts.append(f"{bold_cyan}{s}{reset}") + parts.append(f"{t.long_option}{s}{t.reset}") elif self._is_short_option(s): - parts.append(f"{bold_green}{s}{reset}") + parts.append(f"{t.short_option}{s}{t.reset}") else: parts.append(s) return parts @@ -628,7 +624,7 @@ class HelpFormatter(object): default = self._get_default_metavar_for_optional(action) option_strings = color_option_strings(action.option_strings) args_string = ( - f"{bold_yellow}{self._format_args(action, default)}{reset}" + f"{t.label}{self._format_args(action, default)}{t.reset}" ) return ', '.join(option_strings) + ' ' + args_string diff --git a/Lib/ast.py b/Lib/ast.py index af4fe8ff5a8..b9791bf52d3 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -630,7 +630,7 @@ def main(args=None): import argparse import sys - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(color=True) parser.add_argument('infile', nargs='?', default='-', help='the file to parse; defaults to stdin') parser.add_argument('-m', '--mode', default='exec', @@ -643,6 +643,15 @@ def main(args=None): 'column offsets') parser.add_argument('-i', '--indent', type=int, default=3, help='indentation of nodes (number of spaces)') + parser.add_argument('--feature-version', + type=str, default=None, metavar='VERSION', + help='Python version in the format 3.x ' + '(for example, 3.10)') + parser.add_argument('-O', '--optimize', + type=int, default=-1, metavar='LEVEL', + help='optimization level for parser (default -1)') + parser.add_argument('--show-empty', default=False, action='store_true', + help='show empty lists and fields in dump output') args = parser.parse_args(args) if args.infile == '-': @@ -652,8 +661,22 @@ def main(args=None): name = args.infile with open(args.infile, 'rb') as infile: source = infile.read() - tree = parse(source, name, args.mode, type_comments=args.no_type_comments) - print(dump(tree, include_attributes=args.include_attributes, indent=args.indent)) + + # Process feature_version + feature_version = None + if args.feature_version: + try: + major, minor = map(int, args.feature_version.split('.', 1)) + except ValueError: + parser.error('Invalid format for --feature-version; ' + 'expected format 3.x (for example, 3.10)') + + feature_version = (major, minor) + + tree = parse(source, name, args.mode, type_comments=args.no_type_comments, + feature_version=feature_version, optimize=args.optimize) + print(dump(tree, include_attributes=args.include_attributes, + indent=args.indent, show_empty=args.show_empty)) if __name__ == '__main__': main() diff --git a/Lib/asyncio/__main__.py b/Lib/asyncio/__main__.py index 7d980bc401a..21ca5c5f62a 100644 --- a/Lib/asyncio/__main__.py +++ b/Lib/asyncio/__main__.py @@ -12,7 +12,7 @@ import threading import types import warnings -from _colorize import can_colorize, ANSIColors # type: ignore[import-not-found] +from _colorize import get_theme from _pyrepl.console import InteractiveColoredConsole from . import futures @@ -103,8 +103,9 @@ class REPLThread(threading.Thread): exec(startup_code, console.locals) ps1 = getattr(sys, "ps1", ">>> ") - if can_colorize() and CAN_USE_PYREPL: - ps1 = f"{ANSIColors.BOLD_MAGENTA}{ps1}{ANSIColors.RESET}" + if CAN_USE_PYREPL: + theme = get_theme().syntax + ps1 = f"{theme.prompt}{ps1}{theme.reset}" console.write(f"{ps1}import asyncio\n") if CAN_USE_PYREPL: @@ -145,6 +146,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser( prog="python3 -m asyncio", description="Interactive asyncio shell and CLI tools", + color=True, ) subparsers = parser.add_subparsers(help="sub-commands", dest="command") ps = subparsers.add_parser( diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 29b872ce00e..04fb961e998 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -459,7 +459,7 @@ class BaseEventLoop(events.AbstractEventLoop): return futures.Future(loop=self) def create_task(self, coro, **kwargs): - """Schedule a coroutine object. + """Schedule or begin executing a coroutine object. Return a task object. """ diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py index 1633478d1c8..00e8f6d5d1a 100644 --- a/Lib/asyncio/taskgroups.py +++ b/Lib/asyncio/taskgroups.py @@ -179,7 +179,7 @@ class TaskGroup: exc = None - def create_task(self, coro, *, name=None, context=None): + def create_task(self, coro, **kwargs): """Create a new task in this group and return it. Similar to `asyncio.create_task`. @@ -193,10 +193,7 @@ class TaskGroup: if self._aborting: coro.close() raise RuntimeError(f"TaskGroup {self!r} is shutting down") - if context is None: - task = self._loop.create_task(coro, name=name) - else: - task = self._loop.create_task(coro, name=name, context=context) + task = self._loop.create_task(coro, **kwargs) futures.future_add_to_awaited_by(task, self._parent_task) diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py index 825e91f5594..888615f8e5e 100644 --- a/Lib/asyncio/tasks.py +++ b/Lib/asyncio/tasks.py @@ -386,19 +386,13 @@ else: Task = _CTask = _asyncio.Task -def create_task(coro, *, name=None, context=None): +def create_task(coro, **kwargs): """Schedule the execution of a coroutine object in a spawn task. Return a Task object. """ loop = events.get_running_loop() - if context is None: - # Use legacy API if context is not needed - task = loop.create_task(coro, name=name) - else: - task = loop.create_task(coro, name=name, context=context) - - return task + return loop.create_task(coro, **kwargs) # wait() and as_completed() similar to those in PEP 3148. @@ -1030,9 +1024,9 @@ def create_eager_task_factory(custom_task_constructor): used. E.g. `loop.set_task_factory(asyncio.eager_task_factory)`. """ - def factory(loop, coro, *, name=None, context=None): + def factory(loop, coro, *, eager_start=True, **kwargs): return custom_task_constructor( - coro, loop=loop, name=name, context=context, eager_start=True) + coro, loop=loop, eager_start=eager_start, **kwargs) return factory diff --git a/Lib/asyncio/tools.py b/Lib/asyncio/tools.py index 6c1f725e777..bf1cb5e64cb 100644 --- a/Lib/asyncio/tools.py +++ b/Lib/asyncio/tools.py @@ -5,7 +5,7 @@ from collections import defaultdict from itertools import count from enum import Enum import sys -from _remotedebugging import get_all_awaited_by +from _remote_debugging import get_all_awaited_by class NodeType(Enum): @@ -21,13 +21,21 @@ class CycleFoundException(Exception): # ─── indexing helpers ─────────────────────────────────────────── +def _format_stack_entry(elem: tuple[str, str, int] | str) -> str: + if isinstance(elem, tuple): + fqname, path, line_no = elem + return f"{fqname} {path}:{line_no}" + + return elem + + def _index(result): id2name, awaits = {}, [] for _thr_id, tasks in result: for tid, tname, awaited in tasks: id2name[tid] = tname for stack, parent_id in awaited: - stack = [elem[0] if isinstance(elem, tuple) else elem for elem in stack] + stack = [_format_stack_entry(elem) for elem in stack] awaits.append((parent_id, stack, tid)) return id2name, awaits @@ -106,7 +114,7 @@ def _find_cycles(graph): # ─── PRINT TREE FUNCTION ─────────────────────────────────────── def build_async_tree(result, task_emoji="(T)", cor_emoji=""): """ - Build a list of strings for pretty-print a async call tree. + Build a list of strings for pretty-print an async call tree. The call tree is produced by `get_all_async_stacks()`, prefixing tasks with `task_emoji` and coroutine frames with `cor_emoji`. @@ -169,7 +177,7 @@ def build_task_table(result): return table def _print_cycle_exception(exception: CycleFoundException): - print("ERROR: await-graph contains cycles – cannot print a tree!", file=sys.stderr) + print("ERROR: await-graph contains cycles - cannot print a tree!", file=sys.stderr) print("", file=sys.stderr) for c in exception.cycles: inames = " → ".join(exception.id2name.get(tid, hex(tid)) for tid in c) diff --git a/Lib/calendar.py b/Lib/calendar.py index 01a76ff8e78..18f76d52ff8 100644 --- a/Lib/calendar.py +++ b/Lib/calendar.py @@ -810,7 +810,7 @@ def timegm(tuple): def main(args=None): import argparse - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(color=True) textgroup = parser.add_argument_group('text only arguments') htmlgroup = parser.add_argument_group('html only arguments') textgroup.add_argument( diff --git a/Lib/code.py b/Lib/code.py index 41331dfd071..b134886dc26 100644 --- a/Lib/code.py +++ b/Lib/code.py @@ -385,7 +385,7 @@ def interact(banner=None, readfunc=None, local=None, exitmsg=None, local_exit=Fa if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(color=True) parser.add_argument('-q', action='store_true', help="don't print version and copyright messages") args = parser.parse_args() diff --git a/Lib/compileall.py b/Lib/compileall.py index 47e2446356e..67fe370451e 100644 --- a/Lib/compileall.py +++ b/Lib/compileall.py @@ -317,7 +317,9 @@ def main(): import argparse parser = argparse.ArgumentParser( - description='Utilities to support installing Python libraries.') + description='Utilities to support installing Python libraries.', + color=True, + ) parser.add_argument('-l', action='store_const', const=0, default=None, dest='maxlevels', help="don't recurse into subdirectories") diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py new file mode 100644 index 00000000000..4f734eb07b0 --- /dev/null +++ b/Lib/compression/zstd/__init__.py @@ -0,0 +1,234 @@ +"""Python bindings to the Zstandard (zstd) compression library (RFC-8878).""" + +__all__ = ( + # compression.zstd + "COMPRESSION_LEVEL_DEFAULT", + "compress", + "CompressionParameter", + "decompress", + "DecompressionParameter", + "finalize_dict", + "get_frame_info", + "Strategy", + "train_dict", + + # compression.zstd._zstdfile + "open", + "ZstdFile", + + # _zstd + "get_frame_size", + "zstd_version", + "zstd_version_info", + "ZstdCompressor", + "ZstdDecompressor", + "ZstdDict", + "ZstdError", +) + +import _zstd +import enum +from _zstd import * +from compression.zstd._zstdfile import ZstdFile, open, _nbytes + +COMPRESSION_LEVEL_DEFAULT = _zstd._compressionLevel_values[0] +"""The default compression level for Zstandard, currently '3'.""" + + +class FrameInfo: + """Information about a Zstandard frame.""" + __slots__ = 'decompressed_size', 'dictionary_id' + + def __init__(self, decompressed_size, dictionary_id): + super().__setattr__('decompressed_size', decompressed_size) + super().__setattr__('dictionary_id', dictionary_id) + + def __repr__(self): + return (f'FrameInfo(decompressed_size={self.decompressed_size}, ' + f'dictionary_id={self.dictionary_id})') + + def __setattr__(self, name, _): + raise AttributeError(f"can't set attribute {name!r}") + + +def get_frame_info(frame_buffer): + """Get Zstandard frame information from a frame header. + + *frame_buffer* is a bytes-like object. It should start from the beginning + of a frame, and needs to include at least the frame header (6 to 18 bytes). + + The returned FrameInfo object has two attributes. + 'decompressed_size' is the size in bytes of the data in the frame when + decompressed, or None when the decompressed size is unknown. + 'dictionary_id' is an int in the range (0, 2**32). The special value 0 + means that the dictionary ID was not recorded in the frame header, + the frame may or may not need a dictionary to be decoded, + and the ID of such a dictionary is not specified. + """ + return FrameInfo(*_zstd._get_frame_info(frame_buffer)) + + +def train_dict(samples, dict_size): + """Return a ZstdDict representing a trained Zstandard dictionary. + + *samples* is an iterable of samples, where a sample is a bytes-like + object representing a file. + + *dict_size* is the dictionary's maximum size, in bytes. + """ + if not isinstance(dict_size, int): + ds_cls = type(dict_size).__qualname__ + raise TypeError(f'dict_size must be an int object, not {ds_cls!r}.') + + samples = tuple(samples) + chunks = b''.join(samples) + chunk_sizes = tuple(_nbytes(sample) for sample in samples) + if not chunks: + raise ValueError("samples contained no data; can't train dictionary.") + dict_content = _zstd._train_dict(chunks, chunk_sizes, dict_size) + return ZstdDict(dict_content) + + +def finalize_dict(zstd_dict, /, samples, dict_size, level): + """Return a ZstdDict representing a finalized Zstandard dictionary. + + Given a custom content as a basis for dictionary, and a set of samples, + finalize *zstd_dict* by adding headers and statistics according to the + Zstandard dictionary format. + + You may compose an effective dictionary content by hand, which is used as + basis dictionary, and use some samples to finalize a dictionary. The basis + dictionary may be a "raw content" dictionary. See *is_raw* in ZstdDict. + + *samples* is an iterable of samples, where a sample is a bytes-like object + representing a file. + *dict_size* is the dictionary's maximum size, in bytes. + *level* is the expected compression level. The statistics for each + compression level differ, so tuning the dictionary to the compression level + can provide improvements. + """ + + if not isinstance(zstd_dict, ZstdDict): + raise TypeError('zstd_dict argument should be a ZstdDict object.') + if not isinstance(dict_size, int): + raise TypeError('dict_size argument should be an int object.') + if not isinstance(level, int): + raise TypeError('level argument should be an int object.') + + samples = tuple(samples) + chunks = b''.join(samples) + chunk_sizes = tuple(_nbytes(sample) for sample in samples) + if not chunks: + raise ValueError("The samples are empty content, can't finalize the" + "dictionary.") + dict_content = _zstd._finalize_dict(zstd_dict.dict_content, + chunks, chunk_sizes, + dict_size, level) + return ZstdDict(dict_content) + +def compress(data, level=None, options=None, zstd_dict=None): + """Return Zstandard compressed *data* as bytes. + + *level* is an int specifying the compression level to use, defaulting to + COMPRESSION_LEVEL_DEFAULT ('3'). + *options* is a dict object that contains advanced compression + parameters. See CompressionParameter for more on options. + *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. See + the function train_dict for how to train a ZstdDict on sample data. + + For incremental compression, use a ZstdCompressor instead. + """ + comp = ZstdCompressor(level=level, options=options, zstd_dict=zstd_dict) + return comp.compress(data, mode=ZstdCompressor.FLUSH_FRAME) + +def decompress(data, zstd_dict=None, options=None): + """Decompress one or more frames of Zstandard compressed *data*. + + *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. See + the function train_dict for how to train a ZstdDict on sample data. + *options* is a dict object that contains advanced compression + parameters. See DecompressionParameter for more on options. + + For incremental decompression, use a ZstdDecompressor instead. + """ + results = [] + while True: + decomp = ZstdDecompressor(options=options, zstd_dict=zstd_dict) + results.append(decomp.decompress(data)) + if not decomp.eof: + raise ZstdError("Compressed data ended before the " + "end-of-stream marker was reached") + data = decomp.unused_data + if not data: + break + return b"".join(results) + + +class CompressionParameter(enum.IntEnum): + """Compression parameters.""" + + compression_level = _zstd._ZSTD_c_compressionLevel + window_log = _zstd._ZSTD_c_windowLog + hash_log = _zstd._ZSTD_c_hashLog + chain_log = _zstd._ZSTD_c_chainLog + search_log = _zstd._ZSTD_c_searchLog + min_match = _zstd._ZSTD_c_minMatch + target_length = _zstd._ZSTD_c_targetLength + strategy = _zstd._ZSTD_c_strategy + + enable_long_distance_matching = _zstd._ZSTD_c_enableLongDistanceMatching + ldm_hash_log = _zstd._ZSTD_c_ldmHashLog + ldm_min_match = _zstd._ZSTD_c_ldmMinMatch + ldm_bucket_size_log = _zstd._ZSTD_c_ldmBucketSizeLog + ldm_hash_rate_log = _zstd._ZSTD_c_ldmHashRateLog + + content_size_flag = _zstd._ZSTD_c_contentSizeFlag + checksum_flag = _zstd._ZSTD_c_checksumFlag + dict_id_flag = _zstd._ZSTD_c_dictIDFlag + + nb_workers = _zstd._ZSTD_c_nbWorkers + job_size = _zstd._ZSTD_c_jobSize + overlap_log = _zstd._ZSTD_c_overlapLog + + def bounds(self): + """Return the (lower, upper) int bounds of a compression parameter. + + Both the lower and upper bounds are inclusive. + """ + return _zstd._get_param_bounds(self.value, is_compress=True) + + +class DecompressionParameter(enum.IntEnum): + """Decompression parameters.""" + + window_log_max = _zstd._ZSTD_d_windowLogMax + + def bounds(self): + """Return the (lower, upper) int bounds of a decompression parameter. + + Both the lower and upper bounds are inclusive. + """ + return _zstd._get_param_bounds(self.value, is_compress=False) + + +class Strategy(enum.IntEnum): + """Compression strategies, listed from fastest to strongest. + + Note that new strategies might be added in the future. + Only the order (from fast to strong) is guaranteed, + the numeric value might change. + """ + + fast = _zstd._ZSTD_fast + dfast = _zstd._ZSTD_dfast + greedy = _zstd._ZSTD_greedy + lazy = _zstd._ZSTD_lazy + lazy2 = _zstd._ZSTD_lazy2 + btlazy2 = _zstd._ZSTD_btlazy2 + btopt = _zstd._ZSTD_btopt + btultra = _zstd._ZSTD_btultra + btultra2 = _zstd._ZSTD_btultra2 + + +# Check validity of the CompressionParameter & DecompressionParameter types +_zstd._set_parameter_types(CompressionParameter, DecompressionParameter) diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py new file mode 100644 index 00000000000..fbc9e02a733 --- /dev/null +++ b/Lib/compression/zstd/_zstdfile.py @@ -0,0 +1,349 @@ +import io +from os import PathLike +from _zstd import (ZstdCompressor, ZstdDecompressor, _ZSTD_DStreamSizes, + ZstdError) +from compression._common import _streams + +__all__ = ("ZstdFile", "open") + +_ZSTD_DStreamOutSize = _ZSTD_DStreamSizes[1] + +_MODE_CLOSED = 0 +_MODE_READ = 1 +_MODE_WRITE = 2 + + +def _nbytes(dat, /): + if isinstance(dat, (bytes, bytearray)): + return len(dat) + with memoryview(dat) as mv: + return mv.nbytes + + +class ZstdFile(_streams.BaseStream): + """A file-like object providing transparent Zstandard (de)compression. + + A ZstdFile can act as a wrapper for an existing file object, or refer + directly to a named file on disk. + + ZstdFile provides a *binary* file interface. Data is read and returned as + bytes, and may only be written to objects that support the Buffer Protocol. + """ + + FLUSH_BLOCK = ZstdCompressor.FLUSH_BLOCK + FLUSH_FRAME = ZstdCompressor.FLUSH_FRAME + + def __init__(self, file, /, mode="r", *, + level=None, options=None, zstd_dict=None): + """Open a Zstandard compressed file in binary mode. + + *file* can be either an file-like object, or a file name to open. + + *mode* can be "r" for reading (default), "w" for (over)writing, "x" for + creating exclusively, or "a" for appending. These can equivalently be + given as "rb", "wb", "xb" and "ab" respectively. + + *level* is an optional int specifying the compression level to use, + or COMPRESSION_LEVEL_DEFAULT if not given. + + *options* is an optional dict for advanced compression parameters. + See CompressionParameter and DecompressionParameter for the possible + options. + + *zstd_dict* is an optional ZstdDict object, a pre-trained Zstandard + dictionary. See train_dict() to train ZstdDict on sample data. + """ + self._fp = None + self._close_fp = False + self._mode = _MODE_CLOSED + self._buffer = None + + if not isinstance(mode, str): + raise ValueError("mode must be a str") + if options is not None and not isinstance(options, dict): + raise TypeError("options must be a dict or None") + mode = mode.removesuffix("b") # handle rb, wb, xb, ab + if mode == "r": + if level is not None: + raise TypeError("level is illegal in read mode") + self._mode = _MODE_READ + elif mode in {"w", "a", "x"}: + if level is not None and not isinstance(level, int): + raise TypeError("level must be int or None") + self._mode = _MODE_WRITE + self._compressor = ZstdCompressor(level=level, options=options, + zstd_dict=zstd_dict) + self._pos = 0 + else: + raise ValueError(f"Invalid mode: {mode!r}") + + if isinstance(file, (str, bytes, PathLike)): + self._fp = io.open(file, f'{mode}b') + self._close_fp = True + elif ((mode == 'r' and hasattr(file, "read")) + or (mode != 'r' and hasattr(file, "write"))): + self._fp = file + else: + raise TypeError("file must be a file-like object " + "or a str, bytes, or PathLike object") + + if self._mode == _MODE_READ: + raw = _streams.DecompressReader( + self._fp, + ZstdDecompressor, + trailing_error=ZstdError, + zstd_dict=zstd_dict, + options=options, + ) + self._buffer = io.BufferedReader(raw) + + def close(self): + """Flush and close the file. + + May be called multiple times. Once the file has been closed, + any other operation on it will raise ValueError. + """ + if self._fp is None: + return + try: + if self._mode == _MODE_READ: + if getattr(self, '_buffer', None): + self._buffer.close() + self._buffer = None + elif self._mode == _MODE_WRITE: + self.flush(self.FLUSH_FRAME) + self._compressor = None + finally: + self._mode = _MODE_CLOSED + try: + if self._close_fp: + self._fp.close() + finally: + self._fp = None + self._close_fp = False + + def write(self, data, /): + """Write a bytes-like object *data* to the file. + + Returns the number of uncompressed bytes written, which is + always the length of data in bytes. Note that due to buffering, + the file on disk may not reflect the data written until .flush() + or .close() is called. + """ + self._check_can_write() + + length = _nbytes(data) + + compressed = self._compressor.compress(data) + self._fp.write(compressed) + self._pos += length + return length + + def flush(self, mode=FLUSH_BLOCK): + """Flush remaining data to the underlying stream. + + The mode argument can be FLUSH_BLOCK or FLUSH_FRAME. Abuse of this + method will reduce compression ratio, use it only when necessary. + + If the program is interrupted afterwards, all data can be recovered. + To ensure saving to disk, also need to use os.fsync(fd). + + This method does nothing in reading mode. + """ + if self._mode == _MODE_READ: + return + self._check_not_closed() + if mode not in {self.FLUSH_BLOCK, self.FLUSH_FRAME}: + raise ValueError("Invalid mode argument, expected either " + "ZstdFile.FLUSH_FRAME or " + "ZstdFile.FLUSH_BLOCK") + if self._compressor.last_mode == mode: + return + # Flush zstd block/frame, and write. + data = self._compressor.flush(mode) + self._fp.write(data) + if hasattr(self._fp, "flush"): + self._fp.flush() + + def read(self, size=-1): + """Read up to size uncompressed bytes from the file. + + If size is negative or omitted, read until EOF is reached. + Returns b"" if the file is already at EOF. + """ + if size is None: + size = -1 + self._check_can_read() + return self._buffer.read(size) + + def read1(self, size=-1): + """Read up to size uncompressed bytes, while trying to avoid + making multiple reads from the underlying stream. Reads up to a + buffer's worth of data if size is negative. + + Returns b"" if the file is at EOF. + """ + self._check_can_read() + if size < 0: + # Note this should *not* be io.DEFAULT_BUFFER_SIZE. + # ZSTD_DStreamOutSize is the minimum amount to read guaranteeing + # a full block is read. + size = _ZSTD_DStreamOutSize + return self._buffer.read1(size) + + def readinto(self, b): + """Read bytes into b. + + Returns the number of bytes read (0 for EOF). + """ + self._check_can_read() + return self._buffer.readinto(b) + + def readinto1(self, b): + """Read bytes into b, while trying to avoid making multiple reads + from the underlying stream. + + Returns the number of bytes read (0 for EOF). + """ + self._check_can_read() + return self._buffer.readinto1(b) + + def readline(self, size=-1): + """Read a line of uncompressed bytes from the file. + + The terminating newline (if present) is retained. If size is + non-negative, no more than size bytes will be read (in which + case the line may be incomplete). Returns b'' if already at EOF. + """ + self._check_can_read() + return self._buffer.readline(size) + + def seek(self, offset, whence=io.SEEK_SET): + """Change the file position. + + The new position is specified by offset, relative to the + position indicated by whence. Possible values for whence are: + + 0: start of stream (default): offset must not be negative + 1: current stream position + 2: end of stream; offset must not be positive + + Returns the new file position. + + Note that seeking is emulated, so depending on the arguments, + this operation may be extremely slow. + """ + self._check_can_read() + + # BufferedReader.seek() checks seekable + return self._buffer.seek(offset, whence) + + def peek(self, size=-1): + """Return buffered data without advancing the file position. + + Always returns at least one byte of data, unless at EOF. + The exact number of bytes returned is unspecified. + """ + # Relies on the undocumented fact that BufferedReader.peek() always + # returns at least one byte (except at EOF) + self._check_can_read() + return self._buffer.peek(size) + + def __next__(self): + if ret := self._buffer.readline(): + return ret + raise StopIteration + + def tell(self): + """Return the current file position.""" + self._check_not_closed() + if self._mode == _MODE_READ: + return self._buffer.tell() + elif self._mode == _MODE_WRITE: + return self._pos + + def fileno(self): + """Return the file descriptor for the underlying file.""" + self._check_not_closed() + return self._fp.fileno() + + @property + def name(self): + self._check_not_closed() + return self._fp.name + + @property + def mode(self): + return 'wb' if self._mode == _MODE_WRITE else 'rb' + + @property + def closed(self): + """True if this file is closed.""" + return self._mode == _MODE_CLOSED + + def seekable(self): + """Return whether the file supports seeking.""" + return self.readable() and self._buffer.seekable() + + def readable(self): + """Return whether the file was opened for reading.""" + self._check_not_closed() + return self._mode == _MODE_READ + + def writable(self): + """Return whether the file was opened for writing.""" + self._check_not_closed() + return self._mode == _MODE_WRITE + + +def open(file, /, mode="rb", *, level=None, options=None, zstd_dict=None, + encoding=None, errors=None, newline=None): + """Open a Zstandard compressed file in binary or text mode. + + file can be either a file name (given as a str, bytes, or PathLike object), + in which case the named file is opened, or it can be an existing file object + to read from or write to. + + The mode parameter can be "r", "rb" (default), "w", "wb", "x", "xb", "a", + "ab" for binary mode, or "rt", "wt", "xt", "at" for text mode. + + The level, options, and zstd_dict parameters specify the settings the same + as ZstdFile. + + When using read mode (decompression), the options parameter is a dict + representing advanced decompression options. The level parameter is not + supported in this case. When using write mode (compression), only one of + level, an int representing the compression level, or options, a dict + representing advanced compression options, may be passed. In both modes, + zstd_dict is a ZstdDict instance containing a trained Zstandard dictionary. + + For binary mode, this function is equivalent to the ZstdFile constructor: + ZstdFile(filename, mode, ...). In this case, the encoding, errors and + newline parameters must not be provided. + + For text mode, an ZstdFile object is created, and wrapped in an + io.TextIOWrapper instance with the specified encoding, error handling + behavior, and line ending(s). + """ + + text_mode = "t" in mode + mode = mode.replace("t", "") + + if text_mode: + if "b" in mode: + raise ValueError(f"Invalid mode: {mode!r}") + else: + if encoding is not None: + raise ValueError("Argument 'encoding' not supported in binary mode") + if errors is not None: + raise ValueError("Argument 'errors' not supported in binary mode") + if newline is not None: + raise ValueError("Argument 'newline' not supported in binary mode") + + binary_file = ZstdFile(file, mode, level=level, options=options, + zstd_dict=zstd_dict) + + if text_mode: + return io.TextIOWrapper(binary_file, encoding, errors, newline) + else: + return binary_file diff --git a/Lib/ctypes/_layout.py b/Lib/ctypes/_layout.py index 0719e72cfed..2048ccb6a1c 100644 --- a/Lib/ctypes/_layout.py +++ b/Lib/ctypes/_layout.py @@ -5,6 +5,7 @@ may change at any time. """ import sys +import warnings from _ctypes import CField, buffer_info import ctypes @@ -66,9 +67,26 @@ def get_layout(cls, input_fields, is_struct, base): # For clarity, variables that count bits have `bit` in their names. + pack = getattr(cls, '_pack_', None) + layout = getattr(cls, '_layout_', None) if layout is None: - if sys.platform == 'win32' or getattr(cls, '_pack_', None): + if sys.platform == 'win32': + gcc_layout = False + elif pack: + if is_struct: + base_type_name = 'Structure' + else: + base_type_name = 'Union' + warnings._deprecated( + '_pack_ without _layout_', + f"Due to '_pack_', the '{cls.__name__}' {base_type_name} will " + + "use memory layout compatible with MSVC (Windows). " + + "If this is intended, set _layout_ to 'ms'. " + + "The implicit default is deprecated and slated to become " + + "an error in Python {remove}.", + remove=(3, 19), + ) gcc_layout = False else: gcc_layout = True @@ -95,7 +113,6 @@ def get_layout(cls, input_fields, is_struct, base): else: big_endian = sys.byteorder == 'big' - pack = getattr(cls, '_pack_', None) if pack is not None: try: pack = int(pack) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 0f7dc9ae6b8..86d29df0639 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -244,6 +244,10 @@ _ATOMIC_TYPES = frozenset({ property, }) +# Any marker is used in `make_dataclass` to mark unannotated fields as `Any` +# without importing `typing` module. +_ANY_MARKER = object() + class InitVar: __slots__ = ('type', ) @@ -1591,7 +1595,7 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True, for item in fields: if isinstance(item, str): name = item - tp = 'typing.Any' + tp = _ANY_MARKER elif len(item) == 2: name, tp, = item elif len(item) == 3: @@ -1610,15 +1614,49 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True, seen.add(name) annotations[name] = tp + # We initially block the VALUE format, because inside dataclass() we'll + # call get_annotations(), which will try the VALUE format first. If we don't + # block, that means we'd always end up eagerly importing typing here, which + # is what we're trying to avoid. + value_blocked = True + + def annotate_method(format): + def get_any(): + match format: + case annotationlib.Format.STRING: + return 'typing.Any' + case annotationlib.Format.FORWARDREF: + typing = sys.modules.get("typing") + if typing is None: + return annotationlib.ForwardRef("Any", module="typing") + else: + return typing.Any + case annotationlib.Format.VALUE: + if value_blocked: + raise NotImplementedError + from typing import Any + return Any + case _: + raise NotImplementedError + annos = { + ann: get_any() if t is _ANY_MARKER else t + for ann, t in annotations.items() + } + if format == annotationlib.Format.STRING: + return annotationlib.annotations_to_string(annos) + else: + return annos + # Update 'ns' with the user-supplied namespace plus our calculated values. def exec_body_callback(ns): ns.update(namespace) ns.update(defaults) - ns['__annotations__'] = annotations # We use `types.new_class()` instead of simply `type()` to allow dynamic creation # of generic dataclasses. cls = types.new_class(cls_name, bases, {}, exec_body_callback) + # For now, set annotations including the _ANY_MARKER. + cls.__annotate__ = annotate_method # For pickling to work, the __module__ variable needs to be set to the frame # where the dataclass is created. @@ -1634,10 +1672,13 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True, cls.__module__ = module # Apply the normal provided decorator. - return decorator(cls, init=init, repr=repr, eq=eq, order=order, - unsafe_hash=unsafe_hash, frozen=frozen, - match_args=match_args, kw_only=kw_only, slots=slots, - weakref_slot=weakref_slot) + cls = decorator(cls, init=init, repr=repr, eq=eq, order=order, + unsafe_hash=unsafe_hash, frozen=frozen, + match_args=match_args, kw_only=kw_only, slots=slots, + weakref_slot=weakref_slot) + # Now that the class is ready, allow the VALUE format. + value_blocked = False + return cls def replace(obj, /, **changes): diff --git a/Lib/dis.py b/Lib/dis.py index cb6d077a391..d6d2c1386dd 100644 --- a/Lib/dis.py +++ b/Lib/dis.py @@ -1131,7 +1131,7 @@ class Bytecode: def main(args=None): import argparse - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(color=True) parser.add_argument('-C', '--show-caches', action='store_true', help='show inline caches') parser.add_argument('-O', '--show-offsets', action='store_true', diff --git a/Lib/doctest.py b/Lib/doctest.py index e02e73ed722..2acb6cb79f3 100644 --- a/Lib/doctest.py +++ b/Lib/doctest.py @@ -2870,7 +2870,7 @@ __test__ = {"_TestClass": _TestClass, def _test(): import argparse - parser = argparse.ArgumentParser(description="doctest runner") + parser = argparse.ArgumentParser(description="doctest runner", color=True) parser.add_argument('-v', '--verbose', action='store_true', default=False, help='print very verbose output for all tests') parser.add_argument('-o', '--option', action='append', diff --git a/Lib/ensurepip/__init__.py b/Lib/ensurepip/__init__.py index 6fc9f39b24c..aa641e94a8b 100644 --- a/Lib/ensurepip/__init__.py +++ b/Lib/ensurepip/__init__.py @@ -205,7 +205,7 @@ def _uninstall_helper(*, verbosity=0): def _main(argv=None): import argparse - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(color=True) parser.add_argument( "--version", action="version", diff --git a/Lib/getpass.py b/Lib/getpass.py index bd0097ced94..f571425e541 100644 --- a/Lib/getpass.py +++ b/Lib/getpass.py @@ -1,6 +1,7 @@ """Utilities to get a password and/or the current user name. -getpass(prompt[, stream]) - Prompt for a password, with echo turned off. +getpass(prompt[, stream[, echo_char]]) - Prompt for a password, with echo +turned off and optional keyboard feedback. getuser() - Get the user name from the environment or password database. GetPassWarning - This UserWarning is issued when getpass() cannot prevent @@ -25,13 +26,15 @@ __all__ = ["getpass","getuser","GetPassWarning"] class GetPassWarning(UserWarning): pass -def unix_getpass(prompt='Password: ', stream=None): +def unix_getpass(prompt='Password: ', stream=None, *, echo_char=None): """Prompt for a password, with echo turned off. Args: prompt: Written on stream to ask for the input. Default: 'Password: ' stream: A writable file object to display the prompt. Defaults to the tty. If no tty is available defaults to sys.stderr. + echo_char: A string used to mask input (e.g., '*'). If None, input is + hidden. Returns: The seKr3t input. Raises: @@ -40,6 +43,8 @@ def unix_getpass(prompt='Password: ', stream=None): Always restores terminal settings before returning. """ + _check_echo_char(echo_char) + passwd = None with contextlib.ExitStack() as stack: try: @@ -68,12 +73,16 @@ def unix_getpass(prompt='Password: ', stream=None): old = termios.tcgetattr(fd) # a copy to save new = old[:] new[3] &= ~termios.ECHO # 3 == 'lflags' + if echo_char: + new[3] &= ~termios.ICANON tcsetattr_flags = termios.TCSAFLUSH if hasattr(termios, 'TCSASOFT'): tcsetattr_flags |= termios.TCSASOFT try: termios.tcsetattr(fd, tcsetattr_flags, new) - passwd = _raw_input(prompt, stream, input=input) + passwd = _raw_input(prompt, stream, input=input, + echo_char=echo_char) + finally: termios.tcsetattr(fd, tcsetattr_flags, old) stream.flush() # issue7208 @@ -93,10 +102,11 @@ def unix_getpass(prompt='Password: ', stream=None): return passwd -def win_getpass(prompt='Password: ', stream=None): +def win_getpass(prompt='Password: ', stream=None, *, echo_char=None): """Prompt for password with echo off, using Windows getwch().""" if sys.stdin is not sys.__stdin__: return fallback_getpass(prompt, stream) + _check_echo_char(echo_char) for c in prompt: msvcrt.putwch(c) @@ -108,9 +118,15 @@ def win_getpass(prompt='Password: ', stream=None): if c == '\003': raise KeyboardInterrupt if c == '\b': + if echo_char and pw: + msvcrt.putch('\b') + msvcrt.putch(' ') + msvcrt.putch('\b') pw = pw[:-1] else: pw = pw + c + if echo_char: + msvcrt.putwch(echo_char) msvcrt.putwch('\r') msvcrt.putwch('\n') return pw @@ -126,7 +142,14 @@ def fallback_getpass(prompt='Password: ', stream=None): return _raw_input(prompt, stream) -def _raw_input(prompt="", stream=None, input=None): +def _check_echo_char(echo_char): + # ASCII excluding control characters + if echo_char and not (echo_char.isprintable() and echo_char.isascii()): + raise ValueError("'echo_char' must be a printable ASCII string, " + f"got: {echo_char!r}") + + +def _raw_input(prompt="", stream=None, input=None, echo_char=None): # This doesn't save the string in the GNU readline history. if not stream: stream = sys.stderr @@ -143,6 +166,8 @@ def _raw_input(prompt="", stream=None, input=None): stream.write(prompt) stream.flush() # NOTE: The Python C API calls flockfile() (and unlock) during readline. + if echo_char: + return _readline_with_echo_char(stream, input, echo_char) line = input.readline() if not line: raise EOFError @@ -151,6 +176,35 @@ def _raw_input(prompt="", stream=None, input=None): return line +def _readline_with_echo_char(stream, input, echo_char): + passwd = "" + eof_pressed = False + while True: + char = input.read(1) + if char == '\n' or char == '\r': + break + elif char == '\x03': + raise KeyboardInterrupt + elif char == '\x7f' or char == '\b': + if passwd: + stream.write("\b \b") + stream.flush() + passwd = passwd[:-1] + elif char == '\x04': + if eof_pressed: + break + else: + eof_pressed = True + elif char == '\x00': + continue + else: + passwd += char + stream.write(echo_char) + stream.flush() + eof_pressed = False + return passwd + + def getuser(): """Get the username from the environment or password database. diff --git a/Lib/gzip.py b/Lib/gzip.py index b7375b25473..c00f51858de 100644 --- a/Lib/gzip.py +++ b/Lib/gzip.py @@ -667,7 +667,9 @@ def main(): from argparse import ArgumentParser parser = ArgumentParser(description= "A simple command line interface for the gzip module: act like gzip, " - "but do not delete the input file.") + "but do not delete the input file.", + color=True, + ) group = parser.add_mutually_exclusive_group() group.add_argument('--fast', action='store_true', help='compress faster') group.add_argument('--best', action='store_true', help='compress better') diff --git a/Lib/heapq.py b/Lib/heapq.py index 9649da251f2..6ceb211f1ca 100644 --- a/Lib/heapq.py +++ b/Lib/heapq.py @@ -178,7 +178,7 @@ def heapify(x): for i in reversed(range(n//2)): _siftup(x, i) -def _heappop_max(heap): +def heappop_max(heap): """Maxheap version of a heappop.""" lastelt = heap.pop() # raises appropriate IndexError if heap is empty if heap: @@ -188,19 +188,32 @@ def _heappop_max(heap): return returnitem return lastelt -def _heapreplace_max(heap, item): +def heapreplace_max(heap, item): """Maxheap version of a heappop followed by a heappush.""" returnitem = heap[0] # raises appropriate IndexError if heap is empty heap[0] = item _siftup_max(heap, 0) return returnitem -def _heapify_max(x): +def heappush_max(heap, item): + """Maxheap version of a heappush.""" + heap.append(item) + _siftdown_max(heap, 0, len(heap)-1) + +def heappushpop_max(heap, item): + """Maxheap fast version of a heappush followed by a heappop.""" + if heap and item < heap[0]: + item, heap[0] = heap[0], item + _siftup_max(heap, 0) + return item + +def heapify_max(x): """Transform list into a maxheap, in-place, in O(len(x)) time.""" n = len(x) for i in reversed(range(n//2)): _siftup_max(x, i) + # 'heap' is a heap at all indices >= startpos, except possibly for pos. pos # is the index of a leaf with a possibly out-of-order value. Restore the # heap invariant. @@ -335,9 +348,9 @@ def merge(*iterables, key=None, reverse=False): h_append = h.append if reverse: - _heapify = _heapify_max - _heappop = _heappop_max - _heapreplace = _heapreplace_max + _heapify = heapify_max + _heappop = heappop_max + _heapreplace = heapreplace_max direction = -1 else: _heapify = heapify @@ -490,10 +503,10 @@ def nsmallest(n, iterable, key=None): result = [(elem, i) for i, elem in zip(range(n), it)] if not result: return result - _heapify_max(result) + heapify_max(result) top = result[0][0] order = n - _heapreplace = _heapreplace_max + _heapreplace = heapreplace_max for elem in it: if elem < top: _heapreplace(result, (elem, order)) @@ -507,10 +520,10 @@ def nsmallest(n, iterable, key=None): result = [(key(elem), i, elem) for i, elem in zip(range(n), it)] if not result: return result - _heapify_max(result) + heapify_max(result) top = result[0][0] order = n - _heapreplace = _heapreplace_max + _heapreplace = heapreplace_max for elem in it: k = key(elem) if k < top: @@ -583,19 +596,13 @@ try: from _heapq import * except ImportError: pass -try: - from _heapq import _heapreplace_max -except ImportError: - pass -try: - from _heapq import _heapify_max -except ImportError: - pass -try: - from _heapq import _heappop_max -except ImportError: - pass +# For backwards compatibility +_heappop_max = heappop_max +_heapreplace_max = heapreplace_max +_heappush_max = heappush_max +_heappushpop_max = heappushpop_max +_heapify_max = heapify_max if __name__ == "__main__": diff --git a/Lib/http/server.py b/Lib/http/server.py index a2aad4c9be3..64f766f9bc2 100644 --- a/Lib/http/server.py +++ b/Lib/http/server.py @@ -1340,7 +1340,7 @@ if __name__ == '__main__': import argparse import contextlib - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(color=True) parser.add_argument('--cgi', action='store_true', help='run as CGI server') parser.add_argument('-b', '--bind', metavar='ADDRESS', diff --git a/Lib/inspect.py b/Lib/inspect.py index 9592559ba6d..52c9bb05b31 100644 --- a/Lib/inspect.py +++ b/Lib/inspect.py @@ -3343,7 +3343,7 @@ def _main(): import argparse import importlib - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(color=True) parser.add_argument( 'object', help="The object to be analysed. " diff --git a/Lib/json/tool.py b/Lib/json/tool.py index 585583da860..1967817add8 100644 --- a/Lib/json/tool.py +++ b/Lib/json/tool.py @@ -7,7 +7,7 @@ import argparse import json import re import sys -from _colorize import ANSIColors, can_colorize +from _colorize import get_theme, can_colorize # The string we are colorizing is valid JSON, @@ -17,34 +17,34 @@ from _colorize import ANSIColors, can_colorize _color_pattern = re.compile(r''' (?P<key>"(\\.|[^"\\])*")(?=:) | (?P<string>"(\\.|[^"\\])*") | + (?P<number>NaN|-?Infinity|[0-9\-+.Ee]+) | (?P<boolean>true|false) | (?P<null>null) ''', re.VERBOSE) - -_colors = { - 'key': ANSIColors.INTENSE_BLUE, - 'string': ANSIColors.BOLD_GREEN, - 'boolean': ANSIColors.BOLD_CYAN, - 'null': ANSIColors.BOLD_CYAN, +_group_to_theme_color = { + "key": "definition", + "string": "string", + "number": "number", + "boolean": "keyword", + "null": "keyword", } -def _replace_match_callback(match): - for key, color in _colors.items(): - if m := match.group(key): - return f"{color}{m}{ANSIColors.RESET}" - return match.group() - +def _colorize_json(json_str, theme): + def _replace_match_callback(match): + for group, color in _group_to_theme_color.items(): + if m := match.group(group): + return f"{theme[color]}{m}{theme.reset}" + return match.group() -def _colorize_json(json_str): return re.sub(_color_pattern, _replace_match_callback, json_str) def main(): description = ('A simple command line interface for json module ' 'to validate and pretty-print JSON objects.') - parser = argparse.ArgumentParser(description=description) + parser = argparse.ArgumentParser(description=description, color=True) parser.add_argument('infile', nargs='?', help='a JSON file to be validated or pretty-printed', default='-') @@ -100,13 +100,16 @@ def main(): else: outfile = open(options.outfile, 'w', encoding='utf-8') with outfile: - for obj in objs: - if can_colorize(file=outfile): + if can_colorize(file=outfile): + t = get_theme(tty_file=outfile).syntax + for obj in objs: json_str = json.dumps(obj, **dump_args) - outfile.write(_colorize_json(json_str)) - else: + outfile.write(_colorize_json(json_str, t)) + outfile.write('\n') + else: + for obj in objs: json.dump(obj, outfile, **dump_args) - outfile.write('\n') + outfile.write('\n') except ValueError as e: raise SystemExit(e) diff --git a/Lib/mimetypes.py b/Lib/mimetypes.py index b5a1b8da263..33e86d51a0f 100644 --- a/Lib/mimetypes.py +++ b/Lib/mimetypes.py @@ -698,7 +698,9 @@ _default_mime_types() def _parse_args(args): from argparse import ArgumentParser - parser = ArgumentParser(description='map filename extensions to MIME types') + parser = ArgumentParser( + description='map filename extensions to MIME types', color=True + ) parser.add_argument( '-e', '--extension', action='store_true', diff --git a/Lib/pdb.py b/Lib/pdb.py index 2aa60c75396..f89d104fcdd 100644 --- a/Lib/pdb.py +++ b/Lib/pdb.py @@ -77,6 +77,7 @@ import glob import json import token import types +import atexit import codeop import pprint import signal @@ -92,10 +93,12 @@ import tokenize import itertools import traceback import linecache +import selectors +import threading import _colorize +import _pyrepl.utils -from contextlib import closing -from contextlib import contextmanager +from contextlib import ExitStack, closing, contextmanager from rlcompleter import Completer from types import CodeType from warnings import deprecated @@ -339,7 +342,7 @@ class Pdb(bdb.Bdb, cmd.Cmd): _last_pdb_instance = None def __init__(self, completekey='tab', stdin=None, stdout=None, skip=None, - nosigint=False, readrc=True, mode=None, backend=None): + nosigint=False, readrc=True, mode=None, backend=None, colorize=False): bdb.Bdb.__init__(self, skip=skip, backend=backend if backend else get_default_backend()) cmd.Cmd.__init__(self, completekey, stdin, stdout) sys.audit("pdb.Pdb") @@ -352,6 +355,7 @@ class Pdb(bdb.Bdb, cmd.Cmd): self._wait_for_mainpyfile = False self.tb_lineno = {} self.mode = mode + self.colorize = colorize and _colorize.can_colorize(file=stdout or sys.stdout) # Try to load readline if it exists try: import readline @@ -743,12 +747,34 @@ class Pdb(bdb.Bdb, cmd.Cmd): self.message(repr(obj)) @contextmanager - def _enable_multiline_completion(self): + def _enable_multiline_input(self): + try: + import readline + except ImportError: + yield + return + + def input_auto_indent(): + last_index = readline.get_current_history_length() + last_line = readline.get_history_item(last_index) + if last_line: + if last_line.isspace(): + # If the last line is empty, we don't need to indent + return + + last_line = last_line.rstrip('\r\n') + indent = len(last_line) - len(last_line.lstrip()) + if last_line.endswith(":"): + indent += 4 + readline.insert_text(' ' * indent) + completenames = self.completenames try: self.completenames = self.complete_multiline_names + readline.set_startup_hook(input_auto_indent) yield finally: + readline.set_startup_hook() self.completenames = completenames return @@ -857,7 +883,7 @@ class Pdb(bdb.Bdb, cmd.Cmd): try: if (code := codeop.compile_command(line + '\n', '<stdin>', 'single')) is None: # Multi-line mode - with self._enable_multiline_completion(): + with self._enable_multiline_input(): buffer = line continue_prompt = "... " while (code := codeop.compile_command(buffer, '<stdin>', 'single')) is None: @@ -879,7 +905,11 @@ class Pdb(bdb.Bdb, cmd.Cmd): return None, None, False else: line = line.rstrip('\r\n') - buffer += '\n' + line + if line.isspace(): + # empty line, just continue + buffer += '\n' + else: + buffer += '\n' + line self.lastcmd = buffer except SyntaxError as e: # Maybe it's an await expression/statement @@ -1036,6 +1066,13 @@ class Pdb(bdb.Bdb, cmd.Cmd): return True return False + def _colorize_code(self, code): + if self.colorize: + colors = list(_pyrepl.utils.gen_colors(code)) + chars, _ = _pyrepl.utils.disp_str(code, colors=colors, force_color=True) + code = "".join(chars) + return code + # interface abstraction functions def message(self, msg, end='\n'): @@ -2166,6 +2203,8 @@ class Pdb(bdb.Bdb, cmd.Cmd): s += '->' elif lineno == exc_lineno: s += '>>' + if self.colorize: + line = self._colorize_code(line) self.message(s + '\t' + line.rstrip()) def do_whatis(self, arg): @@ -2365,8 +2404,14 @@ class Pdb(bdb.Bdb, cmd.Cmd): prefix = '> ' else: prefix = ' ' - self.message(prefix + - self.format_stack_entry(frame_lineno, prompt_prefix)) + stack_entry = self.format_stack_entry(frame_lineno, prompt_prefix) + if self.colorize: + lines = stack_entry.split(prompt_prefix, 1) + if len(lines) > 1: + # We have some code to display + lines[1] = self._colorize_code(lines[1]) + stack_entry = prompt_prefix.join(lines) + self.message(prefix + stack_entry) # Provide help @@ -2604,7 +2649,7 @@ def set_trace(*, header=None, commands=None): if Pdb._last_pdb_instance is not None: pdb = Pdb._last_pdb_instance else: - pdb = Pdb(mode='inline', backend='monitoring') + pdb = Pdb(mode='inline', backend='monitoring', colorize=True) if header is not None: pdb.message(header) pdb.set_trace(sys._getframe().f_back, commands=commands) @@ -2619,7 +2664,7 @@ async def set_trace_async(*, header=None, commands=None): if Pdb._last_pdb_instance is not None: pdb = Pdb._last_pdb_instance else: - pdb = Pdb(mode='inline', backend='monitoring') + pdb = Pdb(mode='inline', backend='monitoring', colorize=True) if header is not None: pdb.message(header) await pdb.set_trace_async(sys._getframe().f_back, commands=commands) @@ -2627,13 +2672,26 @@ async def set_trace_async(*, header=None, commands=None): # Remote PDB class _PdbServer(Pdb): - def __init__(self, sockfile, owns_sockfile=True, **kwargs): + def __init__( + self, + sockfile, + signal_server=None, + owns_sockfile=True, + colorize=False, + **kwargs, + ): self._owns_sockfile = owns_sockfile self._interact_state = None self._sockfile = sockfile self._command_name_cache = [] self._write_failed = False - super().__init__(**kwargs) + if signal_server: + # Only started by the top level _PdbServer, not recursive ones. + self._start_signal_listener(signal_server) + # Override the `colorize` attribute set by the parent constructor, + # because it checks the server's stdout, rather than the client's. + super().__init__(colorize=False, **kwargs) + self.colorize = colorize @staticmethod def protocol_version(): @@ -2688,15 +2746,49 @@ class _PdbServer(Pdb): f"PDB message doesn't follow the schema! {msg}" ) + @classmethod + def _start_signal_listener(cls, address): + def listener(sock): + with closing(sock): + # Check if the interpreter is finalizing every quarter of a second. + # Clean up and exit if so. + sock.settimeout(0.25) + sock.shutdown(socket.SHUT_WR) + while not shut_down.is_set(): + try: + data = sock.recv(1024) + except socket.timeout: + continue + if data == b"": + return # EOF + signal.raise_signal(signal.SIGINT) + + def stop_thread(): + shut_down.set() + thread.join() + + # Use a daemon thread so that we don't detach until after all non-daemon + # threads are done. Use an atexit handler to stop gracefully at that point, + # so that our thread is stopped before the interpreter is torn down. + shut_down = threading.Event() + thread = threading.Thread( + target=listener, + args=[socket.create_connection(address, timeout=5)], + daemon=True, + ) + atexit.register(stop_thread) + thread.start() + def _send(self, **kwargs): self._ensure_valid_message(kwargs) json_payload = json.dumps(kwargs) try: self._sockfile.write(json_payload.encode() + b"\n") self._sockfile.flush() - except OSError: - # This means that the client has abruptly disconnected, but we'll - # handle that the next time we try to read from the client instead + except (OSError, ValueError): + # We get an OSError if the network connection has dropped, and a + # ValueError if detach() if the sockfile has been closed. We'll + # handle this the next time we try to read from the client instead # of trying to handle it from everywhere _send() may be called. # Track this with a flag rather than assuming readline() will ever # return an empty string because the socket may be half-closed. @@ -2887,7 +2979,11 @@ class _PdbServer(Pdb): @typing.override def _create_recursive_debugger(self): - return _PdbServer(self._sockfile, owns_sockfile=False) + return _PdbServer( + self._sockfile, + owns_sockfile=False, + colorize=self.colorize, + ) @typing.override def _prompt_for_confirmation(self, prompt, default): @@ -2924,10 +3020,15 @@ class _PdbServer(Pdb): class _PdbClient: - def __init__(self, pid, sockfile, interrupt_script): + def __init__(self, pid, server_socket, interrupt_sock): self.pid = pid - self.sockfile = sockfile - self.interrupt_script = interrupt_script + self.read_buf = b"" + self.signal_read = None + self.signal_write = None + self.sigint_received = False + self.raise_on_sigint = False + self.server_socket = server_socket + self.interrupt_sock = interrupt_sock self.pdb_instance = Pdb() self.pdb_commands = set() self.completion_matches = [] @@ -2969,8 +3070,7 @@ class _PdbClient: self._ensure_valid_message(kwargs) json_payload = json.dumps(kwargs) try: - self.sockfile.write(json_payload.encode() + b"\n") - self.sockfile.flush() + self.server_socket.sendall(json_payload.encode() + b"\n") except OSError: # This means that the client has abruptly disconnected, but we'll # handle that the next time we try to read from the client instead @@ -2979,10 +3079,44 @@ class _PdbClient: # return an empty string because the socket may be half-closed. self.write_failed = True - def read_command(self, prompt): - self.multiline_block = False - reply = input(prompt) + def _readline(self): + if self.sigint_received: + # There's a pending unhandled SIGINT. Handle it now. + self.sigint_received = False + raise KeyboardInterrupt + + # Wait for either a SIGINT or a line or EOF from the PDB server. + selector = selectors.DefaultSelector() + selector.register(self.signal_read, selectors.EVENT_READ) + selector.register(self.server_socket, selectors.EVENT_READ) + + while b"\n" not in self.read_buf: + for key, _ in selector.select(): + if key.fileobj == self.signal_read: + self.signal_read.recv(1024) + if self.sigint_received: + # If not, we're reading wakeup events for sigints that + # we've previously handled, and can ignore them. + self.sigint_received = False + raise KeyboardInterrupt + elif key.fileobj == self.server_socket: + data = self.server_socket.recv(16 * 1024) + self.read_buf += data + if not data and b"\n" not in self.read_buf: + # EOF without a full final line. Drop the partial line. + self.read_buf = b"" + return b"" + + ret, sep, self.read_buf = self.read_buf.partition(b"\n") + return ret + sep + + def read_input(self, prompt, multiline_block): + self.multiline_block = multiline_block + with self._sigint_raises_keyboard_interrupt(): + return input(prompt) + def read_command(self, prompt): + reply = self.read_input(prompt, multiline_block=False) if self.state == "dumb": # No logic applied whatsoever, just pass the raw reply back. return reply @@ -3005,10 +3139,9 @@ class _PdbClient: return prefix + reply # Otherwise, valid first line of a multi-line statement - self.multiline_block = True - continue_prompt = "...".ljust(len(prompt)) + more_prompt = "...".ljust(len(prompt)) while codeop.compile_command(reply, "<stdin>", "single") is None: - reply += "\n" + input(continue_prompt) + reply += "\n" + self.read_input(more_prompt, multiline_block=True) return prefix + reply @@ -3033,11 +3166,70 @@ class _PdbClient: finally: readline.set_completer(old_completer) + @contextmanager + def _sigint_handler(self): + # Signal handling strategy: + # - When we call input() we want a SIGINT to raise KeyboardInterrupt + # - Otherwise we want to write to the wakeup FD and set a flag. + # We'll break out of select() when the wakeup FD is written to, + # and we'll check the flag whenever we're about to accept input. + def handler(signum, frame): + self.sigint_received = True + if self.raise_on_sigint: + # One-shot; don't raise again until the flag is set again. + self.raise_on_sigint = False + self.sigint_received = False + raise KeyboardInterrupt + + sentinel = object() + old_handler = sentinel + old_wakeup_fd = sentinel + + self.signal_read, self.signal_write = socket.socketpair() + with (closing(self.signal_read), closing(self.signal_write)): + self.signal_read.setblocking(False) + self.signal_write.setblocking(False) + + try: + old_handler = signal.signal(signal.SIGINT, handler) + + try: + old_wakeup_fd = signal.set_wakeup_fd( + self.signal_write.fileno(), + warn_on_full_buffer=False, + ) + yield + finally: + # Restore the old wakeup fd if we installed a new one + if old_wakeup_fd is not sentinel: + signal.set_wakeup_fd(old_wakeup_fd) + finally: + self.signal_read = self.signal_write = None + if old_handler is not sentinel: + # Restore the old handler if we installed a new one + signal.signal(signal.SIGINT, old_handler) + + @contextmanager + def _sigint_raises_keyboard_interrupt(self): + if self.sigint_received: + # There's a pending unhandled SIGINT. Handle it now. + self.sigint_received = False + raise KeyboardInterrupt + + try: + self.raise_on_sigint = True + yield + finally: + self.raise_on_sigint = False + def cmdloop(self): - with self.readline_completion(self.complete): + with ( + self._sigint_handler(), + self.readline_completion(self.complete), + ): while not self.write_failed: try: - if not (payload_bytes := self.sockfile.readline()): + if not (payload_bytes := self._readline()): break except KeyboardInterrupt: self.send_interrupt() @@ -3055,11 +3247,17 @@ class _PdbClient: self.process_payload(payload) def send_interrupt(self): - print( - "\n*** Program will stop at the next bytecode instruction." - " (Use 'cont' to resume)." - ) - sys.remote_exec(self.pid, self.interrupt_script) + if self.interrupt_sock is not None: + # Write to a socket that the PDB server listens on. This triggers + # the remote to raise a SIGINT for itself. We do this because + # Windows doesn't allow triggering SIGINT remotely. + # See https://stackoverflow.com/a/35792192 for many more details. + self.interrupt_sock.sendall(signal.SIGINT.to_bytes()) + else: + # On Unix we can just send a SIGINT to the remote process. + # This is preferable to using the signal thread approach that we + # use on Windows because it can interrupt IO in the main thread. + os.kill(self.pid, signal.SIGINT) def process_payload(self, payload): match payload: @@ -3129,7 +3327,7 @@ class _PdbClient: if self.write_failed: return None - payload = self.sockfile.readline() + payload = self._readline() if not payload: return None @@ -3146,11 +3344,31 @@ class _PdbClient: return None -def _connect(host, port, frame, commands, version): +def _connect( + *, + host, + port, + frame, + commands, + version, + signal_raising_thread, + colorize, +): with closing(socket.create_connection((host, port))) as conn: sockfile = conn.makefile("rwb") - remote_pdb = _PdbServer(sockfile) + # The client requests this thread on Windows but not on Unix. + # Most tests don't request this thread, to keep them simpler. + if signal_raising_thread: + signal_server = (host, port) + else: + signal_server = None + + remote_pdb = _PdbServer( + sockfile, + signal_server=signal_server, + colorize=colorize, + ) weakref.finalize(remote_pdb, sockfile.close) if Pdb._last_pdb_instance is not None: @@ -3171,43 +3389,50 @@ def _connect(host, port, frame, commands, version): def attach(pid, commands=()): """Attach to a running process with the given PID.""" - with closing(socket.create_server(("localhost", 0))) as server: + with ExitStack() as stack: + server = stack.enter_context( + closing(socket.create_server(("localhost", 0))) + ) port = server.getsockname()[1] - with tempfile.NamedTemporaryFile("w", delete_on_close=False) as connect_script: - connect_script.write( - textwrap.dedent( - f""" - import pdb, sys - pdb._connect( - host="localhost", - port={port}, - frame=sys._getframe(1), - commands={json.dumps("\n".join(commands))}, - version={_PdbServer.protocol_version()}, - ) - """ + connect_script = stack.enter_context( + tempfile.NamedTemporaryFile("w", delete_on_close=False) + ) + + use_signal_thread = sys.platform == "win32" + colorize = _colorize.can_colorize() + + connect_script.write( + textwrap.dedent( + f""" + import pdb, sys + pdb._connect( + host="localhost", + port={port}, + frame=sys._getframe(1), + commands={json.dumps("\n".join(commands))}, + version={_PdbServer.protocol_version()}, + signal_raising_thread={use_signal_thread!r}, + colorize={colorize!r}, ) + """ ) - connect_script.close() - sys.remote_exec(pid, connect_script.name) - - # TODO Add a timeout? Or don't bother since the user can ^C? - client_sock, _ = server.accept() + ) + connect_script.close() + sys.remote_exec(pid, connect_script.name) - with closing(client_sock): - sockfile = client_sock.makefile("rwb") + # TODO Add a timeout? Or don't bother since the user can ^C? + client_sock, _ = server.accept() + stack.enter_context(closing(client_sock)) - with closing(sockfile): - with tempfile.NamedTemporaryFile("w", delete_on_close=False) as interrupt_script: - interrupt_script.write( - 'import pdb, sys\n' - 'if inst := pdb.Pdb._last_pdb_instance:\n' - ' inst.set_trace(sys._getframe(1))\n' - ) - interrupt_script.close() + if use_signal_thread: + interrupt_sock, _ = server.accept() + stack.enter_context(closing(interrupt_sock)) + interrupt_sock.setblocking(False) + else: + interrupt_sock = None - _PdbClient(pid, sockfile, interrupt_script.name).cmdloop() + _PdbClient(pid, client_sock, interrupt_sock).cmdloop() # Post-Mortem interface @@ -3279,10 +3504,13 @@ To let the script run up to a given line X in the debugged file, use def main(): import argparse - parser = argparse.ArgumentParser(usage="%(prog)s [-h] [-c command] (-m module | -p pid | pyfile) [args ...]", - description=_usage, - formatter_class=argparse.RawDescriptionHelpFormatter, - allow_abbrev=False) + parser = argparse.ArgumentParser( + usage="%(prog)s [-h] [-c command] (-m module | -p pid | pyfile) [args ...]", + description=_usage, + formatter_class=argparse.RawDescriptionHelpFormatter, + allow_abbrev=False, + color=True, + ) # We need to maunally get the script from args, because the first positional # arguments could be either the script we need to debug, or the argument @@ -3345,7 +3573,7 @@ def main(): # modified by the script being debugged. It's a bad idea when it was # changed by the user from the command line. There is a "restart" command # which allows explicit specification of command line arguments. - pdb = Pdb(mode='cli', backend='monitoring') + pdb = Pdb(mode='cli', backend='monitoring', colorize=True) pdb.rcLines.extend(opts.commands) while True: try: diff --git a/Lib/pickle.py b/Lib/pickle.py index 4fa3632d1a7..beaefae0479 100644 --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -1911,7 +1911,9 @@ def _main(args=None): import argparse import pprint parser = argparse.ArgumentParser( - description='display contents of the pickle files') + description='display contents of the pickle files', + color=True, + ) parser.add_argument( 'pickle_file', nargs='+', help='the pickle file') diff --git a/Lib/pickletools.py b/Lib/pickletools.py index 53f25ea4e46..bcddfb722bd 100644 --- a/Lib/pickletools.py +++ b/Lib/pickletools.py @@ -2842,7 +2842,9 @@ __test__ = {'disassembler_test': _dis_test, if __name__ == "__main__": import argparse parser = argparse.ArgumentParser( - description='disassemble one or more pickle files') + description='disassemble one or more pickle files', + color=True, + ) parser.add_argument( 'pickle_file', nargs='+', help='the pickle file') diff --git a/Lib/platform.py b/Lib/platform.py index 507552f360b..55e211212d4 100644 --- a/Lib/platform.py +++ b/Lib/platform.py @@ -1467,7 +1467,7 @@ def invalidate_caches(): def _parse_args(args: list[str] | None): import argparse - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(color=True) parser.add_argument("args", nargs="*", choices=["nonaliased", "terse"]) parser.add_argument( "--terse", diff --git a/Lib/py_compile.py b/Lib/py_compile.py index 388614e51b1..43d8ec90ffb 100644 --- a/Lib/py_compile.py +++ b/Lib/py_compile.py @@ -177,7 +177,7 @@ def main(): import argparse description = 'A simple command-line interface for py_compile module.' - parser = argparse.ArgumentParser(description=description) + parser = argparse.ArgumentParser(description=description, color=True) parser.add_argument( '-q', '--quiet', action='store_true', diff --git a/Lib/random.py b/Lib/random.py index 5e5d0c4c694..86d562f0b8a 100644 --- a/Lib/random.py +++ b/Lib/random.py @@ -1011,7 +1011,7 @@ if hasattr(_os, "fork"): def _parse_args(arg_list: list[str] | None): import argparse parser = argparse.ArgumentParser( - formatter_class=argparse.RawTextHelpFormatter) + formatter_class=argparse.RawTextHelpFormatter, color=True) group = parser.add_mutually_exclusive_group() group.add_argument( "-c", "--choice", nargs="+", diff --git a/Lib/reprlib.py b/Lib/reprlib.py index 19dbe3a07eb..441d1be4bde 100644 --- a/Lib/reprlib.py +++ b/Lib/reprlib.py @@ -28,7 +28,7 @@ def recursive_repr(fillvalue='...'): wrapper.__doc__ = getattr(user_function, '__doc__') wrapper.__name__ = getattr(user_function, '__name__') wrapper.__qualname__ = getattr(user_function, '__qualname__') - wrapper.__annotations__ = getattr(user_function, '__annotations__', {}) + wrapper.__annotate__ = getattr(user_function, '__annotate__', None) wrapper.__type_params__ = getattr(user_function, '__type_params__', ()) wrapper.__wrapped__ = user_function return wrapper diff --git a/Lib/shutil.py b/Lib/shutil.py index 510ae8c6f22..ca0a2ea2f7f 100644 --- a/Lib/shutil.py +++ b/Lib/shutil.py @@ -32,6 +32,13 @@ try: except ImportError: _LZMA_SUPPORTED = False +try: + from compression import zstd + del zstd + _ZSTD_SUPPORTED = True +except ImportError: + _ZSTD_SUPPORTED = False + _WINDOWS = os.name == 'nt' posix = nt = None if os.name == 'posix': @@ -1006,6 +1013,8 @@ def _make_tarball(base_name, base_dir, compress="gzip", verbose=0, dry_run=0, tar_compression = 'bz2' elif _LZMA_SUPPORTED and compress == 'xz': tar_compression = 'xz' + elif _ZSTD_SUPPORTED and compress == 'zst': + tar_compression = 'zst' else: raise ValueError("bad value for 'compress', or compression format not " "supported : {0}".format(compress)) @@ -1134,6 +1143,10 @@ if _LZMA_SUPPORTED: _ARCHIVE_FORMATS['xztar'] = (_make_tarball, [('compress', 'xz')], "xz'ed tar-file") +if _ZSTD_SUPPORTED: + _ARCHIVE_FORMATS['zstdtar'] = (_make_tarball, [('compress', 'zst')], + "zstd'ed tar-file") + def get_archive_formats(): """Returns a list of supported formats for archiving and unarchiving. @@ -1174,7 +1187,7 @@ def make_archive(base_name, format, root_dir=None, base_dir=None, verbose=0, 'base_name' is the name of the file to create, minus any format-specific extension; 'format' is the archive format: one of "zip", "tar", "gztar", - "bztar", or "xztar". Or any other registered format. + "bztar", "zstdtar", or "xztar". Or any other registered format. 'root_dir' is a directory that will be the root directory of the archive; ie. we typically chdir into 'root_dir' before creating the @@ -1359,6 +1372,10 @@ if _LZMA_SUPPORTED: _UNPACK_FORMATS['xztar'] = (['.tar.xz', '.txz'], _unpack_tarfile, [], "xz'ed tar-file") +if _ZSTD_SUPPORTED: + _UNPACK_FORMATS['zstdtar'] = (['.tar.zst', '.tzst'], _unpack_tarfile, [], + "zstd'ed tar-file") + def _find_unpack_format(filename): for name, info in _UNPACK_FORMATS.items(): for extension in info[0]: diff --git a/Lib/sqlite3/__main__.py b/Lib/sqlite3/__main__.py index 79a6209468d..002f1986cdd 100644 --- a/Lib/sqlite3/__main__.py +++ b/Lib/sqlite3/__main__.py @@ -65,6 +65,7 @@ class SqliteInteractiveConsole(InteractiveConsole): def main(*args): parser = ArgumentParser( description="Python sqlite3 CLI", + color=True, ) parser.add_argument( "filename", type=str, default=":memory:", nargs="?", diff --git a/Lib/subprocess.py b/Lib/subprocess.py index da5f5729e09..54c2eb515b6 100644 --- a/Lib/subprocess.py +++ b/Lib/subprocess.py @@ -1235,8 +1235,11 @@ class Popen: finally: self._communication_started = True - - sts = self.wait(timeout=self._remaining_time(endtime)) + try: + sts = self.wait(timeout=self._remaining_time(endtime)) + except TimeoutExpired as exc: + exc.timeout = timeout + raise return (stdout, stderr) @@ -2145,8 +2148,11 @@ class Popen: selector.unregister(key.fileobj) key.fileobj.close() self._fileobj2output[key.fileobj].append(data) - - self.wait(timeout=self._remaining_time(endtime)) + try: + self.wait(timeout=self._remaining_time(endtime)) + except TimeoutExpired as exc: + exc.timeout = orig_timeout + raise # All data exchanged. Translate lists into strings. if stdout is not None: diff --git a/Lib/tarfile.py b/Lib/tarfile.py index 82c5f6704cb..c0f5a609b9f 100644 --- a/Lib/tarfile.py +++ b/Lib/tarfile.py @@ -399,7 +399,17 @@ class _Stream: self.exception = lzma.LZMAError else: self.cmp = lzma.LZMACompressor(preset=preset) - + elif comptype == "zst": + try: + from compression import zstd + except ImportError: + raise CompressionError("compression.zstd module is not available") from None + if mode == "r": + self.dbuf = b"" + self.cmp = zstd.ZstdDecompressor() + self.exception = zstd.ZstdError + else: + self.cmp = zstd.ZstdCompressor() elif comptype != "tar": raise CompressionError("unknown compression type %r" % comptype) @@ -591,6 +601,8 @@ class _StreamProxy(object): return "bz2" elif self.buf.startswith((b"\x5d\x00\x00\x80", b"\xfd7zXZ")): return "xz" + elif self.buf.startswith(b"\x28\xb5\x2f\xfd"): + return "zst" else: return "tar" @@ -1817,11 +1829,13 @@ class TarFile(object): 'r:gz' open for reading with gzip compression 'r:bz2' open for reading with bzip2 compression 'r:xz' open for reading with lzma compression + 'r:zst' open for reading with zstd compression 'a' or 'a:' open for appending, creating the file if necessary 'w' or 'w:' open for writing without compression 'w:gz' open for writing with gzip compression 'w:bz2' open for writing with bzip2 compression 'w:xz' open for writing with lzma compression + 'w:zst' open for writing with zstd compression 'x' or 'x:' create a tarfile exclusively without compression, raise an exception if the file is already created @@ -1831,16 +1845,20 @@ class TarFile(object): if the file is already created 'x:xz' create an lzma compressed tarfile, raise an exception if the file is already created + 'x:zst' create a zstd compressed tarfile, raise an exception + if the file is already created 'r|*' open a stream of tar blocks with transparent compression 'r|' open an uncompressed stream of tar blocks for reading 'r|gz' open a gzip compressed stream of tar blocks 'r|bz2' open a bzip2 compressed stream of tar blocks 'r|xz' open an lzma compressed stream of tar blocks + 'r|zst' open a zstd compressed stream of tar blocks 'w|' open an uncompressed stream for writing 'w|gz' open a gzip compressed stream for writing 'w|bz2' open a bzip2 compressed stream for writing 'w|xz' open an lzma compressed stream for writing + 'w|zst' open a zstd compressed stream for writing """ if not name and not fileobj: @@ -2006,12 +2024,48 @@ class TarFile(object): t._extfileobj = False return t + @classmethod + def zstopen(cls, name, mode="r", fileobj=None, level=None, options=None, + zstd_dict=None, **kwargs): + """Open zstd compressed tar archive name for reading or writing. + Appending is not allowed. + """ + if mode not in ("r", "w", "x"): + raise ValueError("mode must be 'r', 'w' or 'x'") + + try: + from compression.zstd import ZstdFile, ZstdError + except ImportError: + raise CompressionError("compression.zstd module is not available") from None + + fileobj = ZstdFile( + fileobj or name, + mode, + level=level, + options=options, + zstd_dict=zstd_dict + ) + + try: + t = cls.taropen(name, mode, fileobj, **kwargs) + except (ZstdError, EOFError) as e: + fileobj.close() + if mode == 'r': + raise ReadError("not a zstd file") from e + raise + except Exception: + fileobj.close() + raise + t._extfileobj = False + return t + # All *open() methods are registered here. OPEN_METH = { "tar": "taropen", # uncompressed tar "gz": "gzopen", # gzip compressed tar "bz2": "bz2open", # bzip2 compressed tar - "xz": "xzopen" # lzma compressed tar + "xz": "xzopen", # lzma compressed tar + "zst": "zstopen" # zstd compressed tar } #-------------------------------------------------------------------------- @@ -2883,7 +2937,7 @@ def main(): import argparse description = 'A simple command-line interface for tarfile module.' - parser = argparse.ArgumentParser(description=description) + parser = argparse.ArgumentParser(description=description, color=True) parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Verbose output') parser.add_argument('--filter', metavar='<filtername>', @@ -2963,6 +3017,9 @@ def main(): '.tbz': 'bz2', '.tbz2': 'bz2', '.tb2': 'bz2', + # zstd + '.zst': 'zst', + '.tzst': 'zst', } tar_mode = 'w:' + compressions[ext] if ext in compressions else 'w' tar_files = args.create diff --git a/Lib/test/_code_definitions.py b/Lib/test/_code_definitions.py index 06cf6a10231..c3daa0dccf5 100644 --- a/Lib/test/_code_definitions.py +++ b/Lib/test/_code_definitions.py @@ -12,6 +12,50 @@ def spam_minimal(): return +def spam_with_builtins(): + x = 42 + values = (42,) + checks = tuple(callable(v) for v in values) + res = callable(values), tuple(values), list(values), checks + print(res) + + +def spam_with_globals_and_builtins(): + func1 = spam + func2 = spam_minimal + funcs = (func1, func2) + checks = tuple(callable(f) for f in funcs) + res = callable(funcs), tuple(funcs), list(funcs), checks + print(res) + + +def spam_args_attrs_and_builtins(a, b, /, c, d, *args, e, f, **kwargs): + if args.__len__() > 2: + return None + return a, b, c, d, e, f, args, kwargs + + +def spam_returns_arg(x): + return x + + +def spam_with_inner_not_closure(): + def eggs(): + pass + eggs() + + +def spam_with_inner_closure(): + x = 42 + def eggs(): + print(x) + eggs() + + +def spam_annotated(a: int, b: str, c: object) -> tuple: + return a, b, c + + def spam_full(a, b, /, c, d:int=1, *args, e, f:object=None, **kwargs) -> tuple: # arg defaults, kwarg defaults # annotations @@ -98,6 +142,13 @@ ham_C_closure, *_ = eggs_closure_C(2) TOP_FUNCTIONS = [ # shallow spam_minimal, + spam_with_builtins, + spam_with_globals_and_builtins, + spam_args_attrs_and_builtins, + spam_returns_arg, + spam_with_inner_not_closure, + spam_with_inner_closure, + spam_annotated, spam_full, spam, # outer func diff --git a/Lib/test/libregrtest/setup.py b/Lib/test/libregrtest/setup.py index c0346aa934d..c3d1f60a400 100644 --- a/Lib/test/libregrtest/setup.py +++ b/Lib/test/libregrtest/setup.py @@ -40,7 +40,7 @@ def setup_process() -> None: faulthandler.enable(all_threads=True, file=stderr_fd) # Display the Python traceback on SIGALRM or SIGUSR1 signal - signals = [] + signals: list[signal.Signals] = [] if hasattr(signal, 'SIGALRM'): signals.append(signal.SIGALRM) if hasattr(signal, 'SIGUSR1'): diff --git a/Lib/test/libregrtest/utils.py b/Lib/test/libregrtest/utils.py index c4a1506c9a7..63a2e427d18 100644 --- a/Lib/test/libregrtest/utils.py +++ b/Lib/test/libregrtest/utils.py @@ -335,43 +335,11 @@ def get_build_info(): build.append('with_assert') # --enable-experimental-jit - tier2 = re.search('-D_Py_TIER2=([0-9]+)', cflags) - if tier2: - tier2 = int(tier2.group(1)) - - if not sys.flags.ignore_environment: - PYTHON_JIT = os.environ.get('PYTHON_JIT', None) - if PYTHON_JIT: - PYTHON_JIT = (PYTHON_JIT != '0') - else: - PYTHON_JIT = None - - if tier2 == 1: # =yes - if PYTHON_JIT == False: - jit = 'JIT=off' - else: - jit = 'JIT' - elif tier2 == 3: # =yes-off - if PYTHON_JIT: - jit = 'JIT' + if sys._jit.is_available(): + if sys._jit.is_enabled(): + build.append("JIT") else: - jit = 'JIT=off' - elif tier2 == 4: # =interpreter - if PYTHON_JIT == False: - jit = 'JIT-interpreter=off' - else: - jit = 'JIT-interpreter' - elif tier2 == 6: # =interpreter-off (Secret option!) - if PYTHON_JIT: - jit = 'JIT-interpreter' - else: - jit = 'JIT-interpreter=off' - elif '-D_Py_JIT' in cflags: - jit = 'JIT' - else: - jit = None - if jit: - build.append(jit) + build.append("JIT (disabled)") # --enable-framework=name framework = sysconfig.get_config_var('PYTHONFRAMEWORK') diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index 24984ad81ff..c74c3a31909 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -33,7 +33,7 @@ __all__ = [ "is_resource_enabled", "requires", "requires_freebsd_version", "requires_gil_enabled", "requires_linux_version", "requires_mac_ver", "check_syntax_error", - "requires_gzip", "requires_bz2", "requires_lzma", + "requires_gzip", "requires_bz2", "requires_lzma", "requires_zstd", "bigmemtest", "bigaddrspacetest", "cpython_only", "get_attribute", "requires_IEEE_754", "requires_zlib", "has_fork_support", "requires_fork", @@ -527,6 +527,13 @@ def requires_lzma(reason='requires lzma'): lzma = None return unittest.skipUnless(lzma, reason) +def requires_zstd(reason='requires zstd'): + try: + from compression import zstd + except ImportError: + zstd = None + return unittest.skipUnless(zstd, reason) + def has_no_debug_ranges(): try: import _testcapi @@ -2648,13 +2655,9 @@ skip_on_s390x = unittest.skipIf(is_s390x, 'skipped on s390x') Py_TRACE_REFS = hasattr(sys, 'getobjects') -try: - from _testinternalcapi import jit_enabled -except ImportError: - requires_jit_enabled = requires_jit_disabled = unittest.skip("requires _testinternalcapi") -else: - requires_jit_enabled = unittest.skipUnless(jit_enabled(), "requires JIT enabled") - requires_jit_disabled = unittest.skipIf(jit_enabled(), "requires JIT disabled") +_JIT_ENABLED = sys._jit.is_enabled() +requires_jit_enabled = unittest.skipUnless(_JIT_ENABLED, "requires JIT enabled") +requires_jit_disabled = unittest.skipIf(_JIT_ENABLED, "requires JIT disabled") _BASE_COPY_SRC_DIR_IGNORED_NAMES = frozenset({ @@ -2855,36 +2858,59 @@ def iter_slot_wrappers(cls): @contextlib.contextmanager -def no_color(): +def force_color(color: bool): import _colorize from .os_helper import EnvironmentVarGuard with ( - swap_attr(_colorize, "can_colorize", lambda file=None: False), + swap_attr(_colorize, "can_colorize", lambda file=None: color), EnvironmentVarGuard() as env, ): env.unset("FORCE_COLOR", "NO_COLOR", "PYTHON_COLORS") - env.set("NO_COLOR", "1") + env.set("FORCE_COLOR" if color else "NO_COLOR", "1") yield +def force_colorized(func): + """Force the terminal to be colorized.""" + @functools.wraps(func) + def wrapper(*args, **kwargs): + with force_color(True): + return func(*args, **kwargs) + return wrapper + + def force_not_colorized(func): - """Force the terminal not to be colorized.""" + """Force the terminal NOT to be colorized.""" @functools.wraps(func) def wrapper(*args, **kwargs): - with no_color(): + with force_color(False): return func(*args, **kwargs) return wrapper +def force_colorized_test_class(cls): + """Force the terminal to be colorized for the entire test class.""" + original_setUpClass = cls.setUpClass + + @classmethod + @functools.wraps(cls.setUpClass) + def new_setUpClass(cls): + cls.enterClassContext(force_color(True)) + original_setUpClass() + + cls.setUpClass = new_setUpClass + return cls + + def force_not_colorized_test_class(cls): - """Force the terminal not to be colorized for the entire test class.""" + """Force the terminal NOT to be colorized for the entire test class.""" original_setUpClass = cls.setUpClass @classmethod @functools.wraps(cls.setUpClass) def new_setUpClass(cls): - cls.enterClassContext(no_color()) + cls.enterClassContext(force_color(False)) original_setUpClass() cls.setUpClass = new_setUpClass diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index 13c6a2a584b..c3c245ddaf8 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -1053,6 +1053,21 @@ class TestGetAnnotations(unittest.TestCase): }, ) + def test_partial_evaluation_error(self): + def f(x: range[1]): + pass + with self.assertRaisesRegex( + TypeError, "type 'range' is not subscriptable" + ): + f.__annotations__ + + self.assertEqual( + get_annotations(f, format=Format.FORWARDREF), + { + "x": support.EqualToForwardRef("range[1]", owner=f), + }, + ) + def test_partial_evaluation_cell(self): obj = object() diff --git a/Lib/test/test_argparse.py b/Lib/test/test_argparse.py index c5a1f31aa52..5a6be1180c1 100644 --- a/Lib/test/test_argparse.py +++ b/Lib/test/test_argparse.py @@ -7058,7 +7058,7 @@ class TestColorized(TestCase): super().setUp() # Ensure color even if ran with NO_COLOR=1 _colorize.can_colorize = lambda *args, **kwargs: True - self.ansi = _colorize.ANSIColors() + self.theme = _colorize.get_theme(force_color=True).argparse def test_argparse_color(self): # Arrange: create a parser with a bit of everything @@ -7120,13 +7120,17 @@ class TestColorized(TestCase): sub2 = subparsers.add_parser("sub2", deprecated=True, help="sub2 help") sub2.add_argument("--baz", choices=("X", "Y", "Z"), help="baz help") - heading = self.ansi.BOLD_BLUE - label, label_b = self.ansi.YELLOW, self.ansi.BOLD_YELLOW - long, long_b = self.ansi.CYAN, self.ansi.BOLD_CYAN - pos, pos_b = short, short_b = self.ansi.GREEN, self.ansi.BOLD_GREEN - sub = self.ansi.BOLD_GREEN - prog = self.ansi.BOLD_MAGENTA - reset = self.ansi.RESET + prog = self.theme.prog + heading = self.theme.heading + long = self.theme.summary_long_option + short = self.theme.summary_short_option + label = self.theme.summary_label + pos = self.theme.summary_action + long_b = self.theme.long_option + short_b = self.theme.short_option + label_b = self.theme.label + pos_b = self.theme.action + reset = self.theme.reset # Act help_text = parser.format_help() @@ -7171,9 +7175,9 @@ class TestColorized(TestCase): {heading}subcommands:{reset} valid subcommands - {sub}{{sub1,sub2}}{reset} additional help - {sub}sub1{reset} sub1 help - {sub}sub2{reset} sub2 help + {pos_b}{{sub1,sub2}}{reset} additional help + {pos_b}sub1{reset} sub1 help + {pos_b}sub2{reset} sub2 help """ ), ) @@ -7187,10 +7191,10 @@ class TestColorized(TestCase): prog="PROG", usage="[prefix] %(prog)s [suffix]", ) - heading = self.ansi.BOLD_BLUE - prog = self.ansi.BOLD_MAGENTA - reset = self.ansi.RESET - usage = self.ansi.MAGENTA + heading = self.theme.heading + prog = self.theme.prog + reset = self.theme.reset + usage = self.theme.prog_extra # Act help_text = parser.format_help() diff --git a/Lib/test/test_asdl_parser.py b/Lib/test/test_asdl_parser.py index 2c198a6b8b2..b9df6568123 100644 --- a/Lib/test/test_asdl_parser.py +++ b/Lib/test/test_asdl_parser.py @@ -62,17 +62,17 @@ class TestAsdlParser(unittest.TestCase): alias = self.types['alias'] self.assertEqual( str(alias), - 'Product([Field(identifier, name), Field(identifier, asname, opt=True)], ' + 'Product([Field(identifier, name), Field(identifier, asname, quantifiers=[OPTIONAL])], ' '[Field(int, lineno), Field(int, col_offset), ' - 'Field(int, end_lineno, opt=True), Field(int, end_col_offset, opt=True)])') + 'Field(int, end_lineno, quantifiers=[OPTIONAL]), Field(int, end_col_offset, quantifiers=[OPTIONAL])])') def test_attributes(self): stmt = self.types['stmt'] self.assertEqual(len(stmt.attributes), 4) self.assertEqual(repr(stmt.attributes[0]), 'Field(int, lineno)') self.assertEqual(repr(stmt.attributes[1]), 'Field(int, col_offset)') - self.assertEqual(repr(stmt.attributes[2]), 'Field(int, end_lineno, opt=True)') - self.assertEqual(repr(stmt.attributes[3]), 'Field(int, end_col_offset, opt=True)') + self.assertEqual(repr(stmt.attributes[2]), 'Field(int, end_lineno, quantifiers=[OPTIONAL])') + self.assertEqual(repr(stmt.attributes[3]), 'Field(int, end_col_offset, quantifiers=[OPTIONAL])') def test_constructor_fields(self): ehandler = self.types['excepthandler'] diff --git a/Lib/test/test_ast/test_ast.py b/Lib/test/test_ast/test_ast.py index 6a9b7812ef6..09cf3186e05 100644 --- a/Lib/test/test_ast/test_ast.py +++ b/Lib/test/test_ast/test_ast.py @@ -26,6 +26,7 @@ from test import support from test.support import os_helper from test.support import skip_emscripten_stack_overflow, skip_wasi_stack_overflow from test.support.ast_helper import ASTTestMixin +from test.support.import_helper import ensure_lazy_imports from test.test_ast.utils import to_tuple from test.test_ast.snippets import ( eval_tests, eval_results, exec_tests, exec_results, single_tests, single_results @@ -47,6 +48,12 @@ def ast_repr_update_snapshots() -> None: AST_REPR_DATA_FILE.write_text("\n".join(data)) +class LazyImportTest(unittest.TestCase): + @support.cpython_only + def test_lazy_import(self): + ensure_lazy_imports("ast", {"contextlib", "enum", "inspect", "re", "collections", "argparse"}) + + class AST_Tests(unittest.TestCase): maxDiff = None @@ -3272,6 +3279,9 @@ class CommandLineTests(unittest.TestCase): ('--no-type-comments', '--no-type-comments'), ('-a', '--include-attributes'), ('-i=4', '--indent=4'), + ('--feature-version=3.13', '--feature-version=3.13'), + ('-O=-1', '--optimize=-1'), + ('--show-empty', '--show-empty'), ) self.set_source(''' print(1, 2, 3) @@ -3286,6 +3296,7 @@ class CommandLineTests(unittest.TestCase): with self.subTest(flags=args): self.invoke_ast(*args) + @support.force_not_colorized def test_help_message(self): for flag in ('-h', '--help', '--unknown'): with self.subTest(flag=flag): @@ -3389,7 +3400,7 @@ class CommandLineTests(unittest.TestCase): self.check_output(source, expect, flag) def test_indent_flag(self): - # test 'python -m ast -i/--indent' + # test 'python -m ast -i/--indent 0' source = 'pass' expect = ''' Module( @@ -3400,6 +3411,96 @@ class CommandLineTests(unittest.TestCase): with self.subTest(flag=flag): self.check_output(source, expect, flag) + def test_feature_version_flag(self): + # test 'python -m ast --feature-version 3.9/3.10' + source = ''' + match x: + case 1: + pass + ''' + expect = ''' + Module( + body=[ + Match( + subject=Name(id='x', ctx=Load()), + cases=[ + match_case( + pattern=MatchValue( + value=Constant(value=1)), + body=[ + Pass()])])]) + ''' + self.check_output(source, expect, '--feature-version=3.10') + with self.assertRaises(SyntaxError): + self.invoke_ast('--feature-version=3.9') + + def test_no_optimize_flag(self): + # test 'python -m ast -O/--optimize -1/0' + source = ''' + match a: + case 1+2j: + pass + ''' + expect = ''' + Module( + body=[ + Match( + subject=Name(id='a', ctx=Load()), + cases=[ + match_case( + pattern=MatchValue( + value=BinOp( + left=Constant(value=1), + op=Add(), + right=Constant(value=2j))), + body=[ + Pass()])])]) + ''' + for flag in ('-O=-1', '--optimize=-1', '-O=0', '--optimize=0'): + with self.subTest(flag=flag): + self.check_output(source, expect, flag) + + def test_optimize_flag(self): + # test 'python -m ast -O/--optimize 1/2' + source = ''' + match a: + case 1+2j: + pass + ''' + expect = ''' + Module( + body=[ + Match( + subject=Name(id='a', ctx=Load()), + cases=[ + match_case( + pattern=MatchValue( + value=Constant(value=(1+2j))), + body=[ + Pass()])])]) + ''' + for flag in ('-O=1', '--optimize=1', '-O=2', '--optimize=2'): + with self.subTest(flag=flag): + self.check_output(source, expect, flag) + + def test_show_empty_flag(self): + # test 'python -m ast --show-empty' + source = 'print(1, 2, 3)' + expect = ''' + Module( + body=[ + Expr( + value=Call( + func=Name(id='print', ctx=Load()), + args=[ + Constant(value=1), + Constant(value=2), + Constant(value=3)], + keywords=[]))], + type_ignores=[]) + ''' + self.check_output(source, expect, '--show-empty') + class ASTOptimiziationTests(unittest.TestCase): def wrap_expr(self, expr): diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py index a2fb1022ae4..9f3b6f9acef 100644 --- a/Lib/test/test_asyncio/test_eager_task_factory.py +++ b/Lib/test/test_asyncio/test_eager_task_factory.py @@ -263,6 +263,24 @@ class EagerTaskFactoryLoopTests: self.run_coro(run()) + def test_eager_start_false(self): + name = None + + async def asyncfn(): + nonlocal name + name = asyncio.current_task().get_name() + + async def main(): + t = asyncio.get_running_loop().create_task( + asyncfn(), eager_start=False, name="example" + ) + self.assertFalse(t.done()) + self.assertIsNone(name) + await t + self.assertEqual(name, "example") + + self.run_coro(main()) + class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase): Task = tasks._PyTask @@ -505,5 +523,24 @@ class EagerCTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase): asyncio.current_task = asyncio.tasks.current_task = self._current_task return super().tearDown() + +class DefaultTaskFactoryEagerStart(test_utils.TestCase): + def test_eager_start_true_with_default_factory(self): + name = None + + async def asyncfn(): + nonlocal name + name = asyncio.current_task().get_name() + + async def main(): + t = asyncio.get_running_loop().create_task( + asyncfn(), eager_start=True, name="example" + ) + self.assertTrue(t.done()) + self.assertEqual(name, "example") + await t + + asyncio.run(main(), loop_factory=asyncio.EventLoop) + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py index 8d7f1733454..44498ef790e 100644 --- a/Lib/test/test_asyncio/test_tasks.py +++ b/Lib/test/test_asyncio/test_tasks.py @@ -89,8 +89,8 @@ class BaseTaskTests: Future = None all_tasks = None - def new_task(self, loop, coro, name='TestTask', context=None): - return self.__class__.Task(coro, loop=loop, name=name, context=context) + def new_task(self, loop, coro, name='TestTask', context=None, eager_start=None): + return self.__class__.Task(coro, loop=loop, name=name, context=context, eager_start=eager_start) def new_future(self, loop): return self.__class__.Future(loop=loop) @@ -2686,6 +2686,35 @@ class BaseTaskTests: self.assertEqual([None, 1, 2], ret) + def test_eager_start_true(self): + name = None + + async def asyncfn(): + nonlocal name + name = self.current_task().get_name() + + async def main(): + t = self.new_task(coro=asyncfn(), loop=asyncio.get_running_loop(), eager_start=True, name="example") + self.assertTrue(t.done()) + self.assertEqual(name, "example") + await t + + def test_eager_start_false(self): + name = None + + async def asyncfn(): + nonlocal name + name = self.current_task().get_name() + + async def main(): + t = self.new_task(coro=asyncfn(), loop=asyncio.get_running_loop(), eager_start=False, name="example") + self.assertFalse(t.done()) + self.assertIsNone(name) + await t + self.assertEqual(name, "example") + + asyncio.run(main(), loop_factory=asyncio.EventLoop) + def test_get_coro(self): loop = asyncio.new_event_loop() coro = coroutine_function() diff --git a/Lib/test/test_asyncio/test_tools.py b/Lib/test/test_asyncio/test_tools.py index 2caf56172c9..0413e236c27 100644 --- a/Lib/test/test_asyncio/test_tools.py +++ b/Lib/test/test_asyncio/test_tools.py @@ -18,10 +18,18 @@ TEST_INPUTS_TREE = [ 3, "timer", [ - [["awaiter3", "awaiter2", "awaiter"], 4], - [["awaiter1_3", "awaiter1_2", "awaiter1"], 5], - [["awaiter1_3", "awaiter1_2", "awaiter1"], 6], - [["awaiter3", "awaiter2", "awaiter"], 7], + [[("awaiter3", "/path/to/app.py", 130), + ("awaiter2", "/path/to/app.py", 120), + ("awaiter", "/path/to/app.py", 110)], 4], + [[("awaiterB3", "/path/to/app.py", 190), + ("awaiterB2", "/path/to/app.py", 180), + ("awaiterB", "/path/to/app.py", 170)], 5], + [[("awaiterB3", "/path/to/app.py", 190), + ("awaiterB2", "/path/to/app.py", 180), + ("awaiterB", "/path/to/app.py", 170)], 6], + [[("awaiter3", "/path/to/app.py", 130), + ("awaiter2", "/path/to/app.py", 120), + ("awaiter", "/path/to/app.py", 110)], 7], ], ), ( @@ -91,14 +99,14 @@ TEST_INPUTS_TREE = [ " │ └── __aexit__", " │ └── _aexit", " │ ├── (T) child1_1", - " │ │ └── awaiter", - " │ │ └── awaiter2", - " │ │ └── awaiter3", + " │ │ └── awaiter /path/to/app.py:110", + " │ │ └── awaiter2 /path/to/app.py:120", + " │ │ └── awaiter3 /path/to/app.py:130", " │ │ └── (T) timer", " │ └── (T) child2_1", - " │ └── awaiter1", - " │ └── awaiter1_2", - " │ └── awaiter1_3", + " │ └── awaiterB /path/to/app.py:170", + " │ └── awaiterB2 /path/to/app.py:180", + " │ └── awaiterB3 /path/to/app.py:190", " │ └── (T) timer", " └── (T) root2", " └── bloch", @@ -106,14 +114,14 @@ TEST_INPUTS_TREE = [ " └── __aexit__", " └── _aexit", " ├── (T) child1_2", - " │ └── awaiter", - " │ └── awaiter2", - " │ └── awaiter3", + " │ └── awaiter /path/to/app.py:110", + " │ └── awaiter2 /path/to/app.py:120", + " │ └── awaiter3 /path/to/app.py:130", " │ └── (T) timer", " └── (T) child2_2", - " └── awaiter1", - " └── awaiter1_2", - " └── awaiter1_3", + " └── awaiterB /path/to/app.py:170", + " └── awaiterB2 /path/to/app.py:180", + " └── awaiterB3 /path/to/app.py:190", " └── (T) timer", ] ] @@ -589,7 +597,6 @@ TEST_INPUTS_TABLE = [ class TestAsyncioToolsTree(unittest.TestCase): - def test_asyncio_utils(self): for input_, tree in TEST_INPUTS_TREE: with self.subTest(input_): diff --git a/Lib/test/test_base64.py b/Lib/test/test_base64.py index 409c8c109e8..9efebc43d91 100644 --- a/Lib/test/test_base64.py +++ b/Lib/test/test_base64.py @@ -3,8 +3,16 @@ import base64 import binascii import os from array import array +from test.support import cpython_only from test.support import os_helper from test.support import script_helper +from test.support.import_helper import ensure_lazy_imports + + +class LazyImportTest(unittest.TestCase): + @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("base64", {"re", "getopt"}) class LegacyBase64TestCase(unittest.TestCase): diff --git a/Lib/test/test_calendar.py b/Lib/test/test_calendar.py index 073df310bb4..cbfee604b7a 100644 --- a/Lib/test/test_calendar.py +++ b/Lib/test/test_calendar.py @@ -987,6 +987,7 @@ class CommandLineTestCase(unittest.TestCase): self.assertCLIFails(*args) self.assertCmdFails(*args) + @support.force_not_colorized def test_help(self): stdout = self.run_cmd_ok('-h') self.assertIn(b'usage:', stdout) diff --git a/Lib/test/test_capi/test_config.py b/Lib/test/test_capi/test_config.py index bf351c4defa..a2d70dd3af4 100644 --- a/Lib/test/test_capi/test_config.py +++ b/Lib/test/test_capi/test_config.py @@ -57,7 +57,7 @@ class CAPITests(unittest.TestCase): ("home", str | None, None), ("thread_inherit_context", int, None), ("context_aware_warnings", int, None), - ("import_time", bool, None), + ("import_time", int, None), ("inspect", bool, None), ("install_signal_handlers", bool, None), ("int_max_str_digits", int, None), diff --git a/Lib/test/test_capi/test_misc.py b/Lib/test/test_capi/test_misc.py index 98dc3b42ef0..a597f23a992 100644 --- a/Lib/test/test_capi/test_misc.py +++ b/Lib/test/test_capi/test_misc.py @@ -306,7 +306,7 @@ class CAPITest(unittest.TestCase): CURRENT_THREAD_REGEX + r' File .*, line 6 in <module>\n' r'\n' - r'Extension modules: _testcapi, _testinternalcapi \(total: 2\)\n') + r'Extension modules: _testcapi \(total: 1\)\n') else: # Python built with NDEBUG macro defined: # test _Py_CheckFunctionResult() instead. diff --git a/Lib/test/test_capi/test_object.py b/Lib/test/test_capi/test_object.py index 54a01ac7c4a..127862546b1 100644 --- a/Lib/test/test_capi/test_object.py +++ b/Lib/test/test_capi/test_object.py @@ -174,6 +174,16 @@ class EnableDeferredRefcountingTest(unittest.TestCase): self.assertTrue(_testinternalcapi.has_deferred_refcount(silly_list)) +class IsUniquelyReferencedTest(unittest.TestCase): + """Test PyUnstable_Object_IsUniquelyReferenced""" + def test_is_uniquely_referenced(self): + self.assertTrue(_testcapi.is_uniquely_referenced(object())) + self.assertTrue(_testcapi.is_uniquely_referenced([])) + # Immortals + self.assertFalse(_testcapi.is_uniquely_referenced("spanish inquisition")) + self.assertFalse(_testcapi.is_uniquely_referenced(42)) + # CRASHES is_uniquely_referenced(NULL) + class CAPITest(unittest.TestCase): def check_negative_refcount(self, code): # bpo-35059: Check that Py_DECREF() reports the correct filename diff --git a/Lib/test/test_capi/test_opt.py b/Lib/test/test_capi/test_opt.py index 7e0c60d5522..ba7bcb4540a 100644 --- a/Lib/test/test_capi/test_opt.py +++ b/Lib/test/test_capi/test_opt.py @@ -1919,9 +1919,11 @@ class TestUopsOptimization(unittest.TestCase): _, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD) uops = get_opnames(ex) + self.assertNotIn("_GUARD_NOS_NULL", uops) + self.assertNotIn("_GUARD_CALLABLE_LEN", uops) + self.assertIn("_CALL_LEN", uops) self.assertNotIn("_GUARD_NOS_INT", uops) self.assertNotIn("_GUARD_TOS_INT", uops) - self.assertIn("_CALL_LEN", uops) def test_binary_op_subscr_tuple_int(self): def testfunc(n): diff --git a/Lib/test/test_cmd.py b/Lib/test/test_cmd.py index 0ae44f3987d..dbfec42fc21 100644 --- a/Lib/test/test_cmd.py +++ b/Lib/test/test_cmd.py @@ -11,9 +11,15 @@ import unittest import io import textwrap from test import support -from test.support.import_helper import import_module +from test.support.import_helper import ensure_lazy_imports, import_module from test.support.pty_helper import run_pty +class LazyImportTest(unittest.TestCase): + @support.cpython_only + def test_lazy_import(self): + ensure_lazy_imports("cmd", {"inspect", "string"}) + + class samplecmdclass(cmd.Cmd): """ Instance the sampleclass: diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py index 36f87e259e7..1b40e0d05fe 100644 --- a/Lib/test/test_cmd_line.py +++ b/Lib/test/test_cmd_line.py @@ -1158,6 +1158,24 @@ class CmdLineTest(unittest.TestCase): res = assert_python_ok('-c', code, PYTHON_CPU_COUNT='default') self.assertEqual(self.res2int(res), (os.cpu_count(), os.process_cpu_count())) + def test_import_time(self): + # os is not imported at startup + code = 'import os; import os' + + for case in 'importtime', 'importtime=1', 'importtime=true': + res = assert_python_ok('-X', case, '-c', code) + res_err = res.err.decode('utf-8') + self.assertRegex(res_err, r'import time: \s*\d+ \| \s*\d+ \| \s*os') + self.assertNotRegex(res_err, r'import time: cached\s* \| cached\s* \| os') + + res = assert_python_ok('-X', 'importtime=2', '-c', code) + res_err = res.err.decode('utf-8') + self.assertRegex(res_err, r'import time: \s*\d+ \| \s*\d+ \| \s*os') + self.assertRegex(res_err, r'import time: cached\s* \| cached\s* \| os') + + assert_python_failure('-X', 'importtime=-1', '-c', code) + assert_python_failure('-X', 'importtime=3', '-c', code) + def res2int(self, res): out = res.out.strip().decode("utf-8") return tuple(int(i) for i in out.split()) diff --git a/Lib/test/test_code.py b/Lib/test/test_code.py index 7cf09ee7847..b646042a3b8 100644 --- a/Lib/test/test_code.py +++ b/Lib/test/test_code.py @@ -674,6 +674,44 @@ class CodeTest(unittest.TestCase): import test._code_definitions as defs funcs = { defs.spam_minimal: {}, + defs.spam_with_builtins: { + 'x': CO_FAST_LOCAL, + 'values': CO_FAST_LOCAL, + 'checks': CO_FAST_LOCAL, + 'res': CO_FAST_LOCAL, + }, + defs.spam_with_globals_and_builtins: { + 'func1': CO_FAST_LOCAL, + 'func2': CO_FAST_LOCAL, + 'funcs': CO_FAST_LOCAL, + 'checks': CO_FAST_LOCAL, + 'res': CO_FAST_LOCAL, + }, + defs.spam_args_attrs_and_builtins: { + 'a': POSONLY, + 'b': POSONLY, + 'c': POSORKW, + 'd': POSORKW, + 'e': KWONLY, + 'f': KWONLY, + 'args': VARARGS, + 'kwargs': VARKWARGS, + }, + defs.spam_returns_arg: { + 'x': POSORKW, + }, + defs.spam_with_inner_not_closure: { + 'eggs': CO_FAST_LOCAL, + }, + defs.spam_with_inner_closure: { + 'x': CO_FAST_CELL, + 'eggs': CO_FAST_LOCAL, + }, + defs.spam_annotated: { + 'a': POSORKW, + 'b': POSORKW, + 'c': POSORKW, + }, defs.spam_full: { 'a': POSONLY, 'b': POSONLY, @@ -777,6 +815,265 @@ class CodeTest(unittest.TestCase): kinds = _testinternalcapi.get_co_localskinds(func.__code__) self.assertEqual(kinds, expected) + @unittest.skipIf(_testinternalcapi is None, "missing _testinternalcapi") + def test_var_counts(self): + self.maxDiff = None + def new_var_counts(*, + posonly=0, + posorkw=0, + kwonly=0, + varargs=0, + varkwargs=0, + purelocals=0, + argcells=0, + othercells=0, + freevars=0, + globalvars=0, + attrs=0, + unknown=0, + ): + nargvars = posonly + posorkw + kwonly + varargs + varkwargs + nlocals = nargvars + purelocals + othercells + if isinstance(globalvars, int): + globalvars = { + 'total': globalvars, + 'numglobal': 0, + 'numbuiltin': 0, + 'numunknown': globalvars, + } + else: + g_numunknown = 0 + if isinstance(globalvars, dict): + numglobal = globalvars['numglobal'] + numbuiltin = globalvars['numbuiltin'] + size = 2 + if 'numunknown' in globalvars: + g_numunknown = globalvars['numunknown'] + size += 1 + assert len(globalvars) == size, globalvars + else: + assert not isinstance(globalvars, str), repr(globalvars) + try: + numglobal, numbuiltin = globalvars + except ValueError: + numglobal, numbuiltin, g_numunknown = globalvars + globalvars = { + 'total': numglobal + numbuiltin + g_numunknown, + 'numglobal': numglobal, + 'numbuiltin': numbuiltin, + 'numunknown': g_numunknown, + } + unbound = globalvars['total'] + attrs + unknown + return { + 'total': nlocals + freevars + unbound, + 'locals': { + 'total': nlocals, + 'args': { + 'total': nargvars, + 'numposonly': posonly, + 'numposorkw': posorkw, + 'numkwonly': kwonly, + 'varargs': varargs, + 'varkwargs': varkwargs, + }, + 'numpure': purelocals, + 'cells': { + 'total': argcells + othercells, + 'numargs': argcells, + 'numothers': othercells, + }, + 'hidden': { + 'total': 0, + 'numpure': 0, + 'numcells': 0, + }, + }, + 'numfree': freevars, + 'unbound': { + 'total': unbound, + 'globals': globalvars, + 'numattrs': attrs, + 'numunknown': unknown, + }, + } + + import test._code_definitions as defs + funcs = { + defs.spam_minimal: new_var_counts(), + defs.spam_with_builtins: new_var_counts( + purelocals=4, + globalvars=4, + ), + defs.spam_with_globals_and_builtins: new_var_counts( + purelocals=5, + globalvars=6, + ), + defs.spam_args_attrs_and_builtins: new_var_counts( + posonly=2, + posorkw=2, + kwonly=2, + varargs=1, + varkwargs=1, + attrs=1, + ), + defs.spam_returns_arg: new_var_counts( + posorkw=1, + ), + defs.spam_with_inner_not_closure: new_var_counts( + purelocals=1, + ), + defs.spam_with_inner_closure: new_var_counts( + othercells=1, + purelocals=1, + ), + defs.spam_annotated: new_var_counts( + posorkw=3, + ), + defs.spam_full: new_var_counts( + posonly=2, + posorkw=2, + kwonly=2, + varargs=1, + varkwargs=1, + purelocals=4, + globalvars=3, + attrs=1, + ), + defs.spam: new_var_counts( + posorkw=1, + ), + defs.spam_N: new_var_counts( + posorkw=1, + purelocals=1, + ), + defs.spam_C: new_var_counts( + posorkw=1, + purelocals=1, + argcells=1, + othercells=1, + ), + defs.spam_NN: new_var_counts( + posorkw=1, + purelocals=1, + ), + defs.spam_NC: new_var_counts( + posorkw=1, + purelocals=1, + argcells=1, + othercells=1, + ), + defs.spam_CN: new_var_counts( + posorkw=1, + purelocals=1, + argcells=1, + othercells=1, + ), + defs.spam_CC: new_var_counts( + posorkw=1, + purelocals=1, + argcells=1, + othercells=1, + ), + defs.eggs_nested: new_var_counts( + posorkw=1, + ), + defs.eggs_closure: new_var_counts( + posorkw=1, + freevars=2, + ), + defs.eggs_nested_N: new_var_counts( + posorkw=1, + purelocals=1, + ), + defs.eggs_nested_C: new_var_counts( + posorkw=1, + purelocals=1, + argcells=1, + freevars=2, + ), + defs.eggs_closure_N: new_var_counts( + posorkw=1, + purelocals=1, + freevars=2, + ), + defs.eggs_closure_C: new_var_counts( + posorkw=1, + purelocals=1, + argcells=1, + othercells=1, + freevars=2, + ), + defs.ham_nested: new_var_counts( + posorkw=1, + ), + defs.ham_closure: new_var_counts( + posorkw=1, + freevars=3, + ), + defs.ham_C_nested: new_var_counts( + posorkw=1, + ), + defs.ham_C_closure: new_var_counts( + posorkw=1, + freevars=4, + ), + } + assert len(funcs) == len(defs.FUNCTIONS), (len(funcs), len(defs.FUNCTIONS)) + for func in defs.FUNCTIONS: + with self.subTest(func): + expected = funcs[func] + counts = _testinternalcapi.get_code_var_counts(func.__code__) + self.assertEqual(counts, expected) + + def func_with_globals_and_builtins(): + mod1 = _testinternalcapi + mod2 = dis + mods = (mod1, mod2) + checks = tuple(callable(m) for m in mods) + return callable(mod2), tuple(mods), list(mods), checks + + func = func_with_globals_and_builtins + with self.subTest(f'{func} code'): + expected = new_var_counts( + purelocals=4, + globalvars=5, + ) + counts = _testinternalcapi.get_code_var_counts(func.__code__) + self.assertEqual(counts, expected) + + with self.subTest(f'{func} with own globals and builtins'): + expected = new_var_counts( + purelocals=4, + globalvars=(2, 3), + ) + counts = _testinternalcapi.get_code_var_counts(func) + self.assertEqual(counts, expected) + + with self.subTest(f'{func} without globals'): + expected = new_var_counts( + purelocals=4, + globalvars=(0, 3, 2), + ) + counts = _testinternalcapi.get_code_var_counts(func, globalsns={}) + self.assertEqual(counts, expected) + + with self.subTest(f'{func} without both'): + expected = new_var_counts( + purelocals=4, + globalvars=5, + ) + counts = _testinternalcapi.get_code_var_counts(func, globalsns={}, + builtinsns={}) + self.assertEqual(counts, expected) + + with self.subTest(f'{func} without builtins'): + expected = new_var_counts( + purelocals=4, + globalvars=(2, 0, 3), + ) + counts = _testinternalcapi.get_code_var_counts(func, builtinsns={}) + self.assertEqual(counts, expected) + def isinterned(s): return s is sys.intern(('_' + s + '_')[1:-1]) diff --git a/Lib/test/test_crossinterp.py b/Lib/test/test_crossinterp.py index 32d6fd4e94b..5ac0080db43 100644 --- a/Lib/test/test_crossinterp.py +++ b/Lib/test/test_crossinterp.py @@ -725,6 +725,39 @@ class MarshalTests(_GetXIDataTests): ]) +class CodeTests(_GetXIDataTests): + + MODE = 'code' + + def test_function_code(self): + self.assert_roundtrip_equal_not_identical([ + *(f.__code__ for f in defs.FUNCTIONS), + *(f.__code__ for f in defs.FUNCTION_LIKE), + ]) + + def test_functions(self): + self.assert_not_shareable([ + *defs.FUNCTIONS, + *defs.FUNCTION_LIKE, + ]) + + def test_other_objects(self): + self.assert_not_shareable([ + None, + True, + False, + Ellipsis, + NotImplemented, + 9999, + 'spam', + b'spam', + (), + [], + {}, + object(), + ]) + + class ShareableTypeTests(_GetXIDataTests): MODE = 'xidata' @@ -817,6 +850,13 @@ class ShareableTypeTests(_GetXIDataTests): object(), ]) + def test_code(self): + # types.CodeType + self.assert_not_shareable([ + *(f.__code__ for f in defs.FUNCTIONS), + *(f.__code__ for f in defs.FUNCTION_LIKE), + ]) + def test_function_object(self): for func in defs.FUNCTIONS: assert type(func) is types.FunctionType, func @@ -935,12 +975,6 @@ class ShareableTypeTests(_GetXIDataTests): self.assert_not_shareable([ types.MappingProxyType({}), types.SimpleNamespace(), - # types.CodeType - defs.spam_minimal.__code__, - defs.spam_full.__code__, - defs.spam_CC.__code__, - defs.eggs_closure_C.__code__, - defs.ham_C_closure.__code__, # types.CellType types.CellType(), # types.FrameType diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py index 4af8f7f480e..9aace57633b 100644 --- a/Lib/test/test_csv.py +++ b/Lib/test/test_csv.py @@ -10,7 +10,8 @@ import csv import gc import pickle from test import support -from test.support import import_helper, check_disallow_instantiation +from test.support import cpython_only, import_helper, check_disallow_instantiation +from test.support.import_helper import ensure_lazy_imports from itertools import permutations from textwrap import dedent from collections import OrderedDict @@ -1565,6 +1566,10 @@ class MiscTestCase(unittest.TestCase): def test__all__(self): support.check__all__(self, csv, ('csv', '_csv')) + @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("csv", {"re"}) + def test_subclassable(self): # issue 44089 class Foo(csv.Error): ... diff --git a/Lib/test/test_ctypes/test_aligned_structures.py b/Lib/test/test_ctypes/test_aligned_structures.py index 0c563ab8055..50b4d729b9d 100644 --- a/Lib/test/test_ctypes/test_aligned_structures.py +++ b/Lib/test/test_ctypes/test_aligned_structures.py @@ -316,6 +316,7 @@ class TestAlignedStructures(unittest.TestCase, StructCheckMixin): class Main(sbase): _pack_ = 1 + _layout_ = "ms" _fields_ = [ ("a", c_ubyte), ("b", Inner), diff --git a/Lib/test/test_ctypes/test_bitfields.py b/Lib/test/test_ctypes/test_bitfields.py index dc81e752567..518f838219e 100644 --- a/Lib/test/test_ctypes/test_bitfields.py +++ b/Lib/test/test_ctypes/test_bitfields.py @@ -430,6 +430,7 @@ class BitFieldTest(unittest.TestCase, StructCheckMixin): def test_gh_84039(self): class Bad(Structure): _pack_ = 1 + _layout_ = "ms" _fields_ = [ ("a0", c_uint8, 1), ("a1", c_uint8, 1), @@ -443,9 +444,9 @@ class BitFieldTest(unittest.TestCase, StructCheckMixin): ("b1", c_uint16, 12), ] - class GoodA(Structure): _pack_ = 1 + _layout_ = "ms" _fields_ = [ ("a0", c_uint8, 1), ("a1", c_uint8, 1), @@ -460,6 +461,7 @@ class BitFieldTest(unittest.TestCase, StructCheckMixin): class Good(Structure): _pack_ = 1 + _layout_ = "ms" _fields_ = [ ("a", GoodA), ("b0", c_uint16, 4), @@ -475,6 +477,7 @@ class BitFieldTest(unittest.TestCase, StructCheckMixin): def test_gh_73939(self): class MyStructure(Structure): _pack_ = 1 + _layout_ = "ms" _fields_ = [ ("P", c_uint16), ("L", c_uint16, 9), diff --git a/Lib/test/test_ctypes/test_byteswap.py b/Lib/test/test_ctypes/test_byteswap.py index 9f9904282e4..ea5951603f9 100644 --- a/Lib/test/test_ctypes/test_byteswap.py +++ b/Lib/test/test_ctypes/test_byteswap.py @@ -269,6 +269,7 @@ class Test(unittest.TestCase, StructCheckMixin): class S(base): _pack_ = 1 + _layout_ = "ms" _fields_ = [("b", c_byte), ("h", c_short), @@ -296,6 +297,7 @@ class Test(unittest.TestCase, StructCheckMixin): class S(Structure): _pack_ = 1 + _layout_ = "ms" _fields_ = [("b", c_byte), ("h", c_short), diff --git a/Lib/test/test_ctypes/test_generated_structs.py b/Lib/test/test_ctypes/test_generated_structs.py index 9a8102219d8..aa448fad5bb 100644 --- a/Lib/test/test_ctypes/test_generated_structs.py +++ b/Lib/test/test_ctypes/test_generated_structs.py @@ -125,18 +125,21 @@ class Nested(Structure): class Packed1(Structure): _fields_ = [('a', c_int8), ('b', c_int64)] _pack_ = 1 + _layout_ = 'ms' @register() class Packed2(Structure): _fields_ = [('a', c_int8), ('b', c_int64)] _pack_ = 2 + _layout_ = 'ms' @register() class Packed3(Structure): _fields_ = [('a', c_int8), ('b', c_int64)] _pack_ = 4 + _layout_ = 'ms' @register() @@ -155,6 +158,7 @@ class Packed4(Structure): _fields_ = [('a', c_int8), ('b', c_int64)] _pack_ = 8 + _layout_ = 'ms' @register() class X86_32EdgeCase(Structure): @@ -366,6 +370,7 @@ class Example_gh_95496(Structure): @register() class Example_gh_84039_bad(Structure): _pack_ = 1 + _layout_ = 'ms' _fields_ = [("a0", c_uint8, 1), ("a1", c_uint8, 1), ("a2", c_uint8, 1), @@ -380,6 +385,7 @@ class Example_gh_84039_bad(Structure): @register() class Example_gh_84039_good_a(Structure): _pack_ = 1 + _layout_ = 'ms' _fields_ = [("a0", c_uint8, 1), ("a1", c_uint8, 1), ("a2", c_uint8, 1), @@ -392,6 +398,7 @@ class Example_gh_84039_good_a(Structure): @register() class Example_gh_84039_good(Structure): _pack_ = 1 + _layout_ = 'ms' _fields_ = [("a", Example_gh_84039_good_a), ("b0", c_uint16, 4), ("b1", c_uint16, 12)] @@ -399,6 +406,7 @@ class Example_gh_84039_good(Structure): @register() class Example_gh_73939(Structure): _pack_ = 1 + _layout_ = 'ms' _fields_ = [("P", c_uint16), ("L", c_uint16, 9), ("Pro", c_uint16, 1), @@ -419,6 +427,7 @@ class Example_gh_86098(Structure): @register() class Example_gh_86098_pack(Structure): _pack_ = 1 + _layout_ = 'ms' _fields_ = [("a", c_uint8, 8), ("b", c_uint8, 8), ("c", c_uint32, 16)] @@ -528,7 +537,7 @@ def dump_ctype(tp, struct_or_union_tag='', variable_name='', semi=''): pushes.append(f'#pragma pack(push, {pack})') pops.append(f'#pragma pack(pop)') layout = getattr(tp, '_layout_', None) - if layout == 'ms' or pack: + if layout == 'ms': # The 'ms_struct' attribute only works on x86 and PowerPC requires.add( 'defined(MS_WIN32) || (' diff --git a/Lib/test/test_ctypes/test_pep3118.py b/Lib/test/test_ctypes/test_pep3118.py index 06b2ccecade..11a0744f5a8 100644 --- a/Lib/test/test_ctypes/test_pep3118.py +++ b/Lib/test/test_ctypes/test_pep3118.py @@ -81,6 +81,7 @@ class Point(Structure): class PackedPoint(Structure): _pack_ = 2 + _layout_ = 'ms' _fields_ = [("x", c_long), ("y", c_long)] class PointMidPad(Structure): @@ -88,6 +89,7 @@ class PointMidPad(Structure): class PackedPointMidPad(Structure): _pack_ = 2 + _layout_ = 'ms' _fields_ = [("x", c_byte), ("y", c_uint64)] class PointEndPad(Structure): @@ -95,6 +97,7 @@ class PointEndPad(Structure): class PackedPointEndPad(Structure): _pack_ = 2 + _layout_ = 'ms' _fields_ = [("x", c_uint64), ("y", c_byte)] class Point2(Structure): diff --git a/Lib/test/test_ctypes/test_structunion.py b/Lib/test/test_ctypes/test_structunion.py index 8d8b7e5e995..5b21d48d99c 100644 --- a/Lib/test/test_ctypes/test_structunion.py +++ b/Lib/test/test_ctypes/test_structunion.py @@ -11,6 +11,8 @@ from ._support import (_CData, PyCStructType, UnionType, Py_TPFLAGS_DISALLOW_INSTANTIATION, Py_TPFLAGS_IMMUTABLETYPE) from struct import calcsize +import contextlib +from test.support import MS_WINDOWS class StructUnionTestBase: @@ -335,6 +337,22 @@ class StructUnionTestBase: self.assertIn("from_address", dir(type(self.cls))) self.assertIn("in_dll", dir(type(self.cls))) + def test_pack_layout_switch(self): + # Setting _pack_ implicitly sets default layout to MSVC; + # this is deprecated on non-Windows platforms. + if MS_WINDOWS: + warn_context = contextlib.nullcontext() + else: + warn_context = self.assertWarns(DeprecationWarning) + with warn_context: + class X(self.cls): + _pack_ = 1 + # _layout_ missing + _fields_ = [('a', c_int8, 1), ('b', c_int16, 2)] + + # Check MSVC layout (bitfields of different types aren't combined) + self.check_sizeof(X, struct_size=3, union_size=2) + class StructureTestCase(unittest.TestCase, StructUnionTestBase): cls = Structure diff --git a/Lib/test/test_ctypes/test_structures.py b/Lib/test/test_ctypes/test_structures.py index 221319642e8..92d4851d739 100644 --- a/Lib/test/test_ctypes/test_structures.py +++ b/Lib/test/test_ctypes/test_structures.py @@ -25,6 +25,7 @@ class StructureTestCase(unittest.TestCase, StructCheckMixin): _fields_ = [("a", c_byte), ("b", c_longlong)] _pack_ = 1 + _layout_ = 'ms' self.check_struct(X) self.assertEqual(sizeof(X), 9) @@ -34,6 +35,7 @@ class StructureTestCase(unittest.TestCase, StructCheckMixin): _fields_ = [("a", c_byte), ("b", c_longlong)] _pack_ = 2 + _layout_ = 'ms' self.check_struct(X) self.assertEqual(sizeof(X), 10) self.assertEqual(X.b.offset, 2) @@ -45,6 +47,7 @@ class StructureTestCase(unittest.TestCase, StructCheckMixin): _fields_ = [("a", c_byte), ("b", c_longlong)] _pack_ = 4 + _layout_ = 'ms' self.check_struct(X) self.assertEqual(sizeof(X), min(4, longlong_align) + longlong_size) self.assertEqual(X.b.offset, min(4, longlong_align)) @@ -53,27 +56,33 @@ class StructureTestCase(unittest.TestCase, StructCheckMixin): _fields_ = [("a", c_byte), ("b", c_longlong)] _pack_ = 8 + _layout_ = 'ms' self.check_struct(X) self.assertEqual(sizeof(X), min(8, longlong_align) + longlong_size) self.assertEqual(X.b.offset, min(8, longlong_align)) - - d = {"_fields_": [("a", "b"), - ("b", "q")], - "_pack_": -1} - self.assertRaises(ValueError, type(Structure), "X", (Structure,), d) + with self.assertRaises(ValueError): + class X(Structure): + _fields_ = [("a", "b"), ("b", "q")] + _pack_ = -1 + _layout_ = "ms" @support.cpython_only def test_packed_c_limits(self): # Issue 15989 import _testcapi - d = {"_fields_": [("a", c_byte)], - "_pack_": _testcapi.INT_MAX + 1} - self.assertRaises(ValueError, type(Structure), "X", (Structure,), d) - d = {"_fields_": [("a", c_byte)], - "_pack_": _testcapi.UINT_MAX + 2} - self.assertRaises(ValueError, type(Structure), "X", (Structure,), d) + with self.assertRaises(ValueError): + class X(Structure): + _fields_ = [("a", c_byte)] + _pack_ = _testcapi.INT_MAX + 1 + _layout_ = "ms" + + with self.assertRaises(ValueError): + class X(Structure): + _fields_ = [("a", c_byte)] + _pack_ = _testcapi.UINT_MAX + 2 + _layout_ = "ms" def test_initializers(self): class Person(Structure): diff --git a/Lib/test/test_ctypes/test_unaligned_structures.py b/Lib/test/test_ctypes/test_unaligned_structures.py index 58a00597ef5..b5fb4c0df77 100644 --- a/Lib/test/test_ctypes/test_unaligned_structures.py +++ b/Lib/test/test_ctypes/test_unaligned_structures.py @@ -19,10 +19,12 @@ for typ in [c_short, c_int, c_long, c_longlong, c_ushort, c_uint, c_ulong, c_ulonglong]: class X(Structure): _pack_ = 1 + _layout_ = 'ms' _fields_ = [("pad", c_byte), ("value", typ)] class Y(SwappedStructure): _pack_ = 1 + _layout_ = 'ms' _fields_ = [("pad", c_byte), ("value", typ)] structures.append(X) diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index 99fefb57fd0..ac78f8327b8 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -5,6 +5,7 @@ from dataclasses import * import abc +import annotationlib import io import pickle import inspect @@ -12,6 +13,7 @@ import builtins import types import weakref import traceback +import sys import textwrap import unittest from unittest.mock import Mock @@ -25,6 +27,7 @@ import typing # Needed for the string "typing.ClassVar[int]" to work as an import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation. from test import support +from test.support import import_helper # Just any custom exception we can catch. class CustomError(Exception): pass @@ -3754,7 +3757,6 @@ class TestSlots(unittest.TestCase): @support.cpython_only def test_dataclass_slot_dict_ctype(self): # https://github.com/python/cpython/issues/123935 - from test.support import import_helper # Skips test if `_testcapi` is not present: _testcapi = import_helper.import_module('_testcapi') @@ -4246,16 +4248,56 @@ class TestMakeDataclass(unittest.TestCase): C = make_dataclass('Point', ['x', 'y', 'z']) c = C(1, 2, 3) self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) - self.assertEqual(C.__annotations__, {'x': 'typing.Any', - 'y': 'typing.Any', - 'z': 'typing.Any'}) + self.assertEqual(C.__annotations__, {'x': typing.Any, + 'y': typing.Any, + 'z': typing.Any}) C = make_dataclass('Point', ['x', ('y', int), 'z']) c = C(1, 2, 3) self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) - self.assertEqual(C.__annotations__, {'x': 'typing.Any', + self.assertEqual(C.__annotations__, {'x': typing.Any, 'y': int, - 'z': 'typing.Any'}) + 'z': typing.Any}) + + def test_no_types_get_annotations(self): + C = make_dataclass('C', ['x', ('y', int), 'z']) + + self.assertEqual( + annotationlib.get_annotations(C, format=annotationlib.Format.VALUE), + {'x': typing.Any, 'y': int, 'z': typing.Any}, + ) + self.assertEqual( + annotationlib.get_annotations( + C, format=annotationlib.Format.FORWARDREF), + {'x': typing.Any, 'y': int, 'z': typing.Any}, + ) + self.assertEqual( + annotationlib.get_annotations( + C, format=annotationlib.Format.STRING), + {'x': 'typing.Any', 'y': 'int', 'z': 'typing.Any'}, + ) + + def test_no_types_no_typing_import(self): + with import_helper.CleanImport('typing'): + self.assertNotIn('typing', sys.modules) + C = make_dataclass('C', ['x', ('y', int)]) + + self.assertNotIn('typing', sys.modules) + self.assertEqual( + C.__annotate__(annotationlib.Format.FORWARDREF), + { + 'x': annotationlib.ForwardRef('Any', module='typing'), + 'y': int, + }, + ) + self.assertNotIn('typing', sys.modules) + + for field in fields(C): + if field.name == "x": + self.assertEqual(field.type, annotationlib.ForwardRef('Any', module='typing')) + else: + self.assertEqual(field.name, "y") + self.assertIs(field.type, int) def test_module_attr(self): self.assertEqual(ByMakeDataClass.__module__, __name__) diff --git a/Lib/test/test_dis.py b/Lib/test/test_dis.py index f2586fcee57..ae68c1dd75c 100644 --- a/Lib/test/test_dis.py +++ b/Lib/test/test_dis.py @@ -1336,7 +1336,7 @@ class DisTests(DisTestBase): # Loop can trigger a quicken where the loop is located self.code_quicken(loop_test) got = self.get_disassembly(loop_test, adaptive=True) - jit = import_helper.import_module("_testinternalcapi").jit_enabled() + jit = sys._jit.is_enabled() expected = dis_loop_test_quickened_code.format("JIT" if jit else "NO_JIT") self.do_disassembly_compare(got, expected) diff --git a/Lib/test/test_email/test_utils.py b/Lib/test/test_email/test_utils.py index 4e6201e13c8..c9d09098b50 100644 --- a/Lib/test/test_email/test_utils.py +++ b/Lib/test/test_email/test_utils.py @@ -4,6 +4,16 @@ import test.support import time import unittest +from test.support import cpython_only +from test.support.import_helper import ensure_lazy_imports + + +class TestImportTime(unittest.TestCase): + + @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("email.utils", {"random", "socket"}) + class DateTimeTests(unittest.TestCase): diff --git a/Lib/test/test_embed.py b/Lib/test/test_embed.py index e06e684408c..95b2d80464c 100644 --- a/Lib/test/test_embed.py +++ b/Lib/test/test_embed.py @@ -585,7 +585,7 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase): 'faulthandler': False, 'tracemalloc': 0, 'perf_profiling': 0, - 'import_time': False, + 'import_time': 0, 'thread_inherit_context': DEFAULT_THREAD_INHERIT_CONTEXT, 'context_aware_warnings': DEFAULT_CONTEXT_AWARE_WARNINGS, 'code_debug_ranges': True, @@ -998,7 +998,7 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase): 'hash_seed': 123, 'tracemalloc': 2, 'perf_profiling': 0, - 'import_time': True, + 'import_time': 2, 'code_debug_ranges': False, 'show_ref_count': True, 'malloc_stats': True, @@ -1064,7 +1064,7 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase): 'use_hash_seed': True, 'hash_seed': 42, 'tracemalloc': 2, - 'import_time': True, + 'import_time': 1, 'code_debug_ranges': False, 'malloc_stats': True, 'inspect': True, @@ -1100,7 +1100,7 @@ class InitConfigTests(EmbeddingTestsMixin, unittest.TestCase): 'use_hash_seed': True, 'hash_seed': 42, 'tracemalloc': 2, - 'import_time': True, + 'import_time': 1, 'code_debug_ranges': False, 'malloc_stats': True, 'inspect': True, diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index 68cedc666a5..d8cb5261244 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -19,7 +19,8 @@ from io import StringIO from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL from test import support from test.support import ALWAYS_EQ, REPO_ROOT -from test.support import threading_helper +from test.support import threading_helper, cpython_only +from test.support.import_helper import ensure_lazy_imports from datetime import timedelta python_version = sys.version_info[:2] @@ -5288,6 +5289,10 @@ class MiscTestCase(unittest.TestCase): def test__all__(self): support.check__all__(self, enum, not_exported={'bin', 'show_flag_values'}) + @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("enum", {"functools", "warnings", "inspect", "re"}) + def test_doc_1(self): class Single(Enum): ONE = 1 diff --git a/Lib/test/test_external_inspection.py b/Lib/test/test_external_inspection.py index f787190b1ae..ad3f669a030 100644 --- a/Lib/test/test_external_inspection.py +++ b/Lib/test/test_external_inspection.py @@ -15,13 +15,12 @@ import subprocess PROCESS_VM_READV_SUPPORTED = False try: - from _remotedebugging import PROCESS_VM_READV_SUPPORTED - from _remotedebugging import get_stack_trace - from _remotedebugging import get_async_stack_trace - from _remotedebugging import get_all_awaited_by + from _remote_debugging import PROCESS_VM_READV_SUPPORTED + from _remote_debugging import get_stack_trace + from _remote_debugging import get_async_stack_trace + from _remote_debugging import get_all_awaited_by except ImportError: - raise unittest.SkipTest("Test only runs when _remotedebuggingmodule is available") - + raise unittest.SkipTest("Test only runs when _remote_debugging is available") def _make_test_script(script_dir, script_basename, source): to_return = make_script(script_dir, script_basename, source) @@ -60,8 +59,7 @@ class TestGetStackTrace(unittest.TestCase): foo() def foo(): - sock.sendall(b"ready") - time.sleep(1000) + sock.sendall(b"ready"); time.sleep(10_000) # same line number bar() """ @@ -97,10 +95,10 @@ class TestGetStackTrace(unittest.TestCase): p.wait(timeout=SHORT_TIMEOUT) expected_stack_trace = [ - ("foo", script_name, 15), + ("foo", script_name, 14), ("baz", script_name, 11), ("bar", script_name, 9), - ("<module>", script_name, 17), + ("<module>", script_name, 16), ] self.assertEqual(stack_trace, expected_stack_trace) @@ -123,8 +121,7 @@ class TestGetStackTrace(unittest.TestCase): sock.connect(('localhost', {port})) def c5(): - sock.sendall(b"ready") - time.sleep(10000) + sock.sendall(b"ready"); time.sleep(10_000) # same line number async def c4(): await asyncio.sleep(0) @@ -196,10 +193,10 @@ class TestGetStackTrace(unittest.TestCase): root_task = "Task-1" expected_stack_trace = [ [ - ("c5", script_name, 11), - ("c4", script_name, 15), - ("c3", script_name, 18), - ("c2", script_name, 21), + ("c5", script_name, 10), + ("c4", script_name, 14), + ("c3", script_name, 17), + ("c2", script_name, 20), ], "c2_root", [ @@ -215,13 +212,13 @@ class TestGetStackTrace(unittest.TestCase): taskgroups.__file__, ANY, ), - ("main", script_name, 27), + ("main", script_name, 26), ], "Task-1", [], ], [ - [("c1", script_name, 24)], + [("c1", script_name, 23)], "sub_main_1", [ [ @@ -236,7 +233,7 @@ class TestGetStackTrace(unittest.TestCase): taskgroups.__file__, ANY, ), - ("main", script_name, 27), + ("main", script_name, 26), ], "Task-1", [], @@ -244,7 +241,7 @@ class TestGetStackTrace(unittest.TestCase): ], ], [ - [("c1", script_name, 24)], + [("c1", script_name, 23)], "sub_main_2", [ [ @@ -259,7 +256,7 @@ class TestGetStackTrace(unittest.TestCase): taskgroups.__file__, ANY, ), - ("main", script_name, 27), + ("main", script_name, 26), ], "Task-1", [], @@ -289,8 +286,7 @@ class TestGetStackTrace(unittest.TestCase): sock.connect(('localhost', {port})) async def gen_nested_call(): - sock.sendall(b"ready") - time.sleep(10000) + sock.sendall(b"ready"); time.sleep(10_000) # same line number async def gen(): for num in range(2): @@ -338,9 +334,9 @@ class TestGetStackTrace(unittest.TestCase): expected_stack_trace = [ [ - ("gen_nested_call", script_name, 11), - ("gen", script_name, 17), - ("main", script_name, 20), + ("gen_nested_call", script_name, 10), + ("gen", script_name, 16), + ("main", script_name, 19), ], "Task-1", [], @@ -367,8 +363,7 @@ class TestGetStackTrace(unittest.TestCase): async def deep(): await asyncio.sleep(0) - sock.sendall(b"ready") - time.sleep(10000) + sock.sendall(b"ready"); time.sleep(10_000) # same line number async def c1(): await asyncio.sleep(0) @@ -415,9 +410,9 @@ class TestGetStackTrace(unittest.TestCase): stack_trace[2].sort(key=lambda x: x[1]) expected_stack_trace = [ - [("deep", script_name, ANY), ("c1", script_name, 16)], + [("deep", script_name, 11), ("c1", script_name, 15)], "Task-2", - [[[("main", script_name, 22)], "Task-1", []]], + [[[("main", script_name, 21)], "Task-1", []]], ] self.assertEqual(stack_trace, expected_stack_trace) @@ -441,15 +436,14 @@ class TestGetStackTrace(unittest.TestCase): async def deep(): await asyncio.sleep(0) - sock.sendall(b"ready") - time.sleep(10000) + sock.sendall(b"ready"); time.sleep(10_000) # same line number async def c1(): await asyncio.sleep(0) await deep() async def c2(): - await asyncio.sleep(10000) + await asyncio.sleep(10_000) async def main(): await asyncio.staggered.staggered_race( @@ -492,8 +486,8 @@ class TestGetStackTrace(unittest.TestCase): stack_trace[2].sort(key=lambda x: x[1]) expected_stack_trace = [ [ - ("deep", script_name, ANY), - ("c1", script_name, 16), + ("deep", script_name, 11), + ("c1", script_name, 15), ("staggered_race.<locals>.run_one_coro", staggered.__file__, ANY), ], "Task-2", @@ -501,7 +495,7 @@ class TestGetStackTrace(unittest.TestCase): [ [ ("staggered_race", staggered.__file__, ANY), - ("main", script_name, 22), + ("main", script_name, 21), ], "Task-1", [], diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index e3b449f2d24..2e794b0fc95 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -23,6 +23,7 @@ from inspect import Signature from test.support import import_helper from test.support import threading_helper +from test.support import cpython_only from test.support import EqualToForwardRef import functools @@ -63,6 +64,14 @@ class BadTuple(tuple): class MyDict(dict): pass +class TestImportTime(unittest.TestCase): + + @cpython_only + def test_lazy_import(self): + import_helper.ensure_lazy_imports( + "functools", {"os", "weakref", "typing", "annotationlib", "warnings"} + ) + class TestPartial: diff --git a/Lib/test/test_getpass.py b/Lib/test/test_getpass.py index 80dda2caaa3..ab36535a1cf 100644 --- a/Lib/test/test_getpass.py +++ b/Lib/test/test_getpass.py @@ -161,6 +161,45 @@ class UnixGetpassTest(unittest.TestCase): self.assertIn('Warning', stderr.getvalue()) self.assertIn('Password:', stderr.getvalue()) + def test_echo_char_replaces_input_with_asterisks(self): + mock_result = '*************' + with mock.patch('os.open') as os_open, \ + mock.patch('io.FileIO'), \ + mock.patch('io.TextIOWrapper') as textio, \ + mock.patch('termios.tcgetattr'), \ + mock.patch('termios.tcsetattr'), \ + mock.patch('getpass._raw_input') as mock_input: + os_open.return_value = 3 + mock_input.return_value = mock_result + + result = getpass.unix_getpass(echo_char='*') + mock_input.assert_called_once_with('Password: ', textio(), + input=textio(), echo_char='*') + self.assertEqual(result, mock_result) + + def test_raw_input_with_echo_char(self): + passwd = 'my1pa$$word!' + mock_input = StringIO(f'{passwd}\n') + mock_output = StringIO() + with mock.patch('sys.stdin', mock_input), \ + mock.patch('sys.stdout', mock_output): + result = getpass._raw_input('Password: ', mock_output, mock_input, + '*') + self.assertEqual(result, passwd) + self.assertEqual('Password: ************', mock_output.getvalue()) + + def test_control_chars_with_echo_char(self): + passwd = 'pass\twd\b' + expect_result = 'pass\tw' + mock_input = StringIO(f'{passwd}\n') + mock_output = StringIO() + with mock.patch('sys.stdin', mock_input), \ + mock.patch('sys.stdout', mock_output): + result = getpass._raw_input('Password: ', mock_output, mock_input, + '*') + self.assertEqual(result, expect_result) + self.assertEqual('Password: *******\x08 \x08', mock_output.getvalue()) + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_gettext.py b/Lib/test/test_gettext.py index 585ed08ea14..33b7d75e3ff 100644 --- a/Lib/test/test_gettext.py +++ b/Lib/test/test_gettext.py @@ -6,7 +6,8 @@ import unittest.mock from functools import partial from test import support -from test.support import os_helper +from test.support import cpython_only, os_helper +from test.support.import_helper import ensure_lazy_imports # TODO: @@ -931,6 +932,10 @@ class MiscTestCase(unittest.TestCase): support.check__all__(self, gettext, not_exported={'c2py', 'ENOENT'}) + @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("gettext", {"re", "warnings", "locale"}) + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py index 1aa8e4e2897..d6623fee9bb 100644 --- a/Lib/test/test_heapq.py +++ b/Lib/test/test_heapq.py @@ -13,8 +13,9 @@ c_heapq = import_helper.import_fresh_module('heapq', fresh=['_heapq']) # _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when # _heapq is imported, so check them there -func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace', - '_heappop_max', '_heapreplace_max', '_heapify_max'] +func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace'] +# Add max-heap variants +func_names += [func + '_max' for func in func_names] class TestModules(TestCase): def test_py_functions(self): @@ -24,7 +25,7 @@ class TestModules(TestCase): @skipUnless(c_heapq, 'requires _heapq') def test_c_functions(self): for fname in func_names: - self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq') + self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq', fname) def load_tests(loader, tests, ignore): @@ -74,6 +75,34 @@ class TestHeap: except AttributeError: pass + def test_max_push_pop(self): + # 1) Push 256 random numbers and pop them off, verifying all's OK. + heap = [] + data = [] + self.check_max_invariant(heap) + for i in range(256): + item = random.random() + data.append(item) + self.module.heappush_max(heap, item) + self.check_max_invariant(heap) + results = [] + while heap: + item = self.module.heappop_max(heap) + self.check_max_invariant(heap) + results.append(item) + data_sorted = data[:] + data_sorted.sort(reverse=True) + + self.assertEqual(data_sorted, results) + # 2) Check that the invariant holds for a sorted array + self.check_max_invariant(results) + + self.assertRaises(TypeError, self.module.heappush_max, []) + + exc_types = (AttributeError, TypeError) + self.assertRaises(exc_types, self.module.heappush_max, None, None) + self.assertRaises(exc_types, self.module.heappop_max, None) + def check_invariant(self, heap): # Check the heap invariant. for pos, item in enumerate(heap): @@ -81,6 +110,11 @@ class TestHeap: parentpos = (pos-1) >> 1 self.assertTrue(heap[parentpos] <= item) + def check_max_invariant(self, heap): + for pos, item in enumerate(heap[1:], start=1): + parentpos = (pos - 1) >> 1 + self.assertGreaterEqual(heap[parentpos], item) + def test_heapify(self): for size in list(range(30)) + [20000]: heap = [random.random() for dummy in range(size)] @@ -89,6 +123,14 @@ class TestHeap: self.assertRaises(TypeError, self.module.heapify, None) + def test_heapify_max(self): + for size in list(range(30)) + [20000]: + heap = [random.random() for dummy in range(size)] + self.module.heapify_max(heap) + self.check_max_invariant(heap) + + self.assertRaises(TypeError, self.module.heapify_max, None) + def test_naive_nbest(self): data = [random.randrange(2000) for i in range(1000)] heap = [] @@ -109,10 +151,7 @@ class TestHeap: def test_nbest(self): # Less-naive "N-best" algorithm, much faster (if len(data) is big - # enough <wink>) than sorting all of data. However, if we had a max - # heap instead of a min heap, it could go faster still via - # heapify'ing all of data (linear time), then doing 10 heappops - # (10 log-time steps). + # enough <wink>) than sorting all of data. data = [random.randrange(2000) for i in range(1000)] heap = data[:10] self.module.heapify(heap) @@ -125,6 +164,17 @@ class TestHeap: self.assertRaises(TypeError, self.module.heapreplace, None, None) self.assertRaises(IndexError, self.module.heapreplace, [], None) + def test_nbest_maxheap(self): + # With a max heap instead of a min heap, the "N-best" algorithm can + # go even faster still via heapify'ing all of data (linear time), then + # doing 10 heappops (10 log-time steps). + data = [random.randrange(2000) for i in range(1000)] + heap = data[:] + self.module.heapify_max(heap) + result = [self.module.heappop_max(heap) for _ in range(10)] + result.reverse() + self.assertEqual(result, sorted(data)[-10:]) + def test_nbest_with_pushpop(self): data = [random.randrange(2000) for i in range(1000)] heap = data[:10] @@ -134,6 +184,62 @@ class TestHeap: self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:]) self.assertEqual(self.module.heappushpop([], 'x'), 'x') + def test_naive_nworst(self): + # Max-heap variant of "test_naive_nbest" + data = [random.randrange(2000) for i in range(1000)] + heap = [] + for item in data: + self.module.heappush_max(heap, item) + if len(heap) > 10: + self.module.heappop_max(heap) + heap.sort() + expected = sorted(data)[:10] + self.assertEqual(heap, expected) + + def heapiter_max(self, heap): + # An iterator returning a max-heap's elements, largest-first. + try: + while 1: + yield self.module.heappop_max(heap) + except IndexError: + pass + + def test_nworst(self): + # Max-heap variant of "test_nbest" + data = [random.randrange(2000) for i in range(1000)] + heap = data[:10] + self.module.heapify_max(heap) + for item in data[10:]: + if item < heap[0]: # this gets rarer the longer we run + self.module.heapreplace_max(heap, item) + expected = sorted(data, reverse=True)[-10:] + self.assertEqual(list(self.heapiter_max(heap)), expected) + + self.assertRaises(TypeError, self.module.heapreplace_max, None) + self.assertRaises(TypeError, self.module.heapreplace_max, None, None) + self.assertRaises(IndexError, self.module.heapreplace_max, [], None) + + def test_nworst_minheap(self): + # Min-heap variant of "test_nbest_maxheap" + data = [random.randrange(2000) for i in range(1000)] + heap = data[:] + self.module.heapify(heap) + result = [self.module.heappop(heap) for _ in range(10)] + result.reverse() + expected = sorted(data, reverse=True)[-10:] + self.assertEqual(result, expected) + + def test_nworst_with_pushpop(self): + # Max-heap variant of "test_nbest_with_pushpop" + data = [random.randrange(2000) for i in range(1000)] + heap = data[:10] + self.module.heapify_max(heap) + for item in data[10:]: + self.module.heappushpop_max(heap, item) + expected = sorted(data, reverse=True)[-10:] + self.assertEqual(list(self.heapiter_max(heap)), expected) + self.assertEqual(self.module.heappushpop_max([], 'x'), 'x') + def test_heappushpop(self): h = [] x = self.module.heappushpop(h, 10) @@ -153,12 +259,31 @@ class TestHeap: x = self.module.heappushpop(h, 11) self.assertEqual((h, x), ([11], 10)) + def test_heappushpop_max(self): + h = [] + x = self.module.heappushpop_max(h, 10) + self.assertTupleEqual((h, x), ([], 10)) + + h = [10] + x = self.module.heappushpop_max(h, 10.0) + self.assertTupleEqual((h, x), ([10], 10.0)) + self.assertIsInstance(h[0], int) + self.assertIsInstance(x, float) + + h = [10] + x = self.module.heappushpop_max(h, 11) + self.assertTupleEqual((h, x), ([10], 11)) + + h = [10] + x = self.module.heappushpop_max(h, 9) + self.assertTupleEqual((h, x), ([9], 10)) + def test_heappop_max(self): - # _heapop_max has an optimization for one-item lists which isn't + # heapop_max has an optimization for one-item lists which isn't # covered in other tests, so test that case explicitly here h = [3, 2] - self.assertEqual(self.module._heappop_max(h), 3) - self.assertEqual(self.module._heappop_max(h), 2) + self.assertEqual(self.module.heappop_max(h), 3) + self.assertEqual(self.module.heappop_max(h), 2) def test_heapsort(self): # Exercise everything with repeated heapsort checks @@ -175,6 +300,20 @@ class TestHeap: heap_sorted = [self.module.heappop(heap) for i in range(size)] self.assertEqual(heap_sorted, sorted(data)) + def test_heapsort_max(self): + for trial in range(100): + size = random.randrange(50) + data = [random.randrange(25) for i in range(size)] + if trial & 1: # Half of the time, use heapify_max + heap = data[:] + self.module.heapify_max(heap) + else: # The rest of the time, use heappush_max + heap = [] + for item in data: + self.module.heappush_max(heap, item) + heap_sorted = [self.module.heappop_max(heap) for i in range(size)] + self.assertEqual(heap_sorted, sorted(data, reverse=True)) + def test_merge(self): inputs = [] for i in range(random.randrange(25)): @@ -377,16 +516,20 @@ class SideEffectLT: class TestErrorHandling: def test_non_sequence(self): - for f in (self.module.heapify, self.module.heappop): + for f in (self.module.heapify, self.module.heappop, + self.module.heapify_max, self.module.heappop_max): self.assertRaises((TypeError, AttributeError), f, 10) for f in (self.module.heappush, self.module.heapreplace, + self.module.heappush_max, self.module.heapreplace_max, self.module.nlargest, self.module.nsmallest): self.assertRaises((TypeError, AttributeError), f, 10, 10) def test_len_only(self): - for f in (self.module.heapify, self.module.heappop): + for f in (self.module.heapify, self.module.heappop, + self.module.heapify_max, self.module.heappop_max): self.assertRaises((TypeError, AttributeError), f, LenOnly()) - for f in (self.module.heappush, self.module.heapreplace): + for f in (self.module.heappush, self.module.heapreplace, + self.module.heappush_max, self.module.heapreplace_max): self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10) for f in (self.module.nlargest, self.module.nsmallest): self.assertRaises(TypeError, f, 2, LenOnly()) @@ -395,7 +538,8 @@ class TestErrorHandling: seq = [CmpErr(), CmpErr(), CmpErr()] for f in (self.module.heapify, self.module.heappop): self.assertRaises(ZeroDivisionError, f, seq) - for f in (self.module.heappush, self.module.heapreplace): + for f in (self.module.heappush, self.module.heapreplace, + self.module.heappush_max, self.module.heapreplace_max): self.assertRaises(ZeroDivisionError, f, seq, 10) for f in (self.module.nlargest, self.module.nsmallest): self.assertRaises(ZeroDivisionError, f, 2, seq) @@ -403,6 +547,8 @@ class TestErrorHandling: def test_arg_parsing(self): for f in (self.module.heapify, self.module.heappop, self.module.heappush, self.module.heapreplace, + self.module.heapify_max, self.module.heappop_max, + self.module.heappush_max, self.module.heapreplace_max, self.module.nlargest, self.module.nsmallest): self.assertRaises((TypeError, AttributeError), f, 10) @@ -424,6 +570,10 @@ class TestErrorHandling: # Python version raises IndexError, C version RuntimeError with self.assertRaises((IndexError, RuntimeError)): self.module.heappush(heap, SideEffectLT(5, heap)) + heap = [] + heap.extend(SideEffectLT(i, heap) for i in range(200)) + with self.assertRaises((IndexError, RuntimeError)): + self.module.heappush_max(heap, SideEffectLT(5, heap)) def test_heappop_mutating_heap(self): heap = [] @@ -431,8 +581,12 @@ class TestErrorHandling: # Python version raises IndexError, C version RuntimeError with self.assertRaises((IndexError, RuntimeError)): self.module.heappop(heap) + heap = [] + heap.extend(SideEffectLT(i, heap) for i in range(200)) + with self.assertRaises((IndexError, RuntimeError)): + self.module.heappop_max(heap) - def test_comparison_operator_modifiying_heap(self): + def test_comparison_operator_modifying_heap(self): # See bpo-39421: Strong references need to be taken # when comparing objects as they can alter the heap class EvilClass(int): @@ -444,7 +598,7 @@ class TestErrorHandling: self.module.heappush(heap, EvilClass(0)) self.assertRaises(IndexError, self.module.heappushpop, heap, 1) - def test_comparison_operator_modifiying_heap_two_heaps(self): + def test_comparison_operator_modifying_heap_two_heaps(self): class h(int): def __lt__(self, o): @@ -464,6 +618,17 @@ class TestErrorHandling: self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1)) self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1)) + list1, list2 = [], [] + + self.module.heappush_max(list1, h(0)) + self.module.heappush_max(list2, g(0)) + self.module.heappush_max(list1, g(1)) + self.module.heappush_max(list2, h(1)) + + self.assertRaises((IndexError, RuntimeError), self.module.heappush_max, list1, g(1)) + self.assertRaises((IndexError, RuntimeError), self.module.heappush_max, list2, h(1)) + + class TestErrorHandlingPython(TestErrorHandling, TestCase): module = py_heapq diff --git a/Lib/test/test_json/test_tool.py b/Lib/test/test_json/test_tool.py index ba9c42f758e..72cde3f0d6c 100644 --- a/Lib/test/test_json/test_tool.py +++ b/Lib/test/test_json/test_tool.py @@ -6,9 +6,11 @@ import unittest import subprocess from test import support -from test.support import force_not_colorized, os_helper +from test.support import force_colorized, force_not_colorized, os_helper from test.support.script_helper import assert_python_ok +from _colorize import get_theme + @support.requires_subprocess() class TestMain(unittest.TestCase): @@ -246,34 +248,39 @@ class TestMain(unittest.TestCase): proc.communicate(b'"{}"') self.assertEqual(proc.returncode, errno.EPIPE) + @force_colorized def test_colors(self): infile = os_helper.TESTFN self.addCleanup(os.remove, infile) + t = get_theme().syntax + ob = "{" + cb = "}" + cases = ( - ('{}', b'{}'), - ('[]', b'[]'), - ('null', b'\x1b[1;36mnull\x1b[0m'), - ('true', b'\x1b[1;36mtrue\x1b[0m'), - ('false', b'\x1b[1;36mfalse\x1b[0m'), - ('NaN', b'NaN'), - ('Infinity', b'Infinity'), - ('-Infinity', b'-Infinity'), - ('"foo"', b'\x1b[1;32m"foo"\x1b[0m'), - (r'" \"foo\" "', b'\x1b[1;32m" \\"foo\\" "\x1b[0m'), - ('"α"', b'\x1b[1;32m"\\u03b1"\x1b[0m'), - ('123', b'123'), - ('-1.2345e+23', b'-1.2345e+23'), + ('{}', '{}'), + ('[]', '[]'), + ('null', f'{t.keyword}null{t.reset}'), + ('true', f'{t.keyword}true{t.reset}'), + ('false', f'{t.keyword}false{t.reset}'), + ('NaN', f'{t.number}NaN{t.reset}'), + ('Infinity', f'{t.number}Infinity{t.reset}'), + ('-Infinity', f'{t.number}-Infinity{t.reset}'), + ('"foo"', f'{t.string}"foo"{t.reset}'), + (r'" \"foo\" "', f'{t.string}" \\"foo\\" "{t.reset}'), + ('"α"', f'{t.string}"\\u03b1"{t.reset}'), + ('123', f'{t.number}123{t.reset}'), + ('-1.2345e+23', f'{t.number}-1.2345e+23{t.reset}'), (r'{"\\": ""}', - b'''\ -{ - \x1b[94m"\\\\"\x1b[0m: \x1b[1;32m""\x1b[0m -}'''), + f'''\ +{ob} + {t.definition}"\\\\"{t.reset}: {t.string}""{t.reset} +{cb}'''), (r'{"\\\\": ""}', - b'''\ -{ - \x1b[94m"\\\\\\\\"\x1b[0m: \x1b[1;32m""\x1b[0m -}'''), + f'''\ +{ob} + {t.definition}"\\\\\\\\"{t.reset}: {t.string}""{t.reset} +{cb}'''), ('''\ { "foo": "bar", @@ -281,30 +288,32 @@ class TestMain(unittest.TestCase): "qux": [true, false, null], "xyz": [NaN, -Infinity, Infinity] }''', - b'''\ -{ - \x1b[94m"foo"\x1b[0m: \x1b[1;32m"bar"\x1b[0m, - \x1b[94m"baz"\x1b[0m: 1234, - \x1b[94m"qux"\x1b[0m: [ - \x1b[1;36mtrue\x1b[0m, - \x1b[1;36mfalse\x1b[0m, - \x1b[1;36mnull\x1b[0m + f'''\ +{ob} + {t.definition}"foo"{t.reset}: {t.string}"bar"{t.reset}, + {t.definition}"baz"{t.reset}: {t.number}1234{t.reset}, + {t.definition}"qux"{t.reset}: [ + {t.keyword}true{t.reset}, + {t.keyword}false{t.reset}, + {t.keyword}null{t.reset} ], - \x1b[94m"xyz"\x1b[0m: [ - NaN, - -Infinity, - Infinity + {t.definition}"xyz"{t.reset}: [ + {t.number}NaN{t.reset}, + {t.number}-Infinity{t.reset}, + {t.number}Infinity{t.reset} ] -}'''), +{cb}'''), ) for input_, expected in cases: with self.subTest(input=input_): with open(infile, "w", encoding="utf-8") as fp: fp.write(input_) - _, stdout, _ = assert_python_ok('-m', self.module, infile, - PYTHON_COLORS='1') - stdout = stdout.replace(b'\r\n', b'\n') # normalize line endings + _, stdout_b, _ = assert_python_ok( + '-m', self.module, infile, FORCE_COLOR='1', __isolated='1' + ) + stdout = stdout_b.decode() + stdout = stdout.replace('\r\n', '\n') # normalize line endings stdout = stdout.strip() self.assertEqual(stdout, expected) diff --git a/Lib/test/test_locale.py b/Lib/test/test_locale.py index 528ceef5281..455d2af37ef 100644 --- a/Lib/test/test_locale.py +++ b/Lib/test/test_locale.py @@ -1,13 +1,18 @@ from decimal import Decimal -from test.support import verbose, is_android, linked_to_musl, os_helper +from test.support import cpython_only, verbose, is_android, linked_to_musl, os_helper from test.support.warnings_helper import check_warnings -from test.support.import_helper import import_fresh_module +from test.support.import_helper import ensure_lazy_imports, import_fresh_module from unittest import mock import unittest import locale import sys import codecs +class LazyImportTest(unittest.TestCase): + @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("locale", {"re", "warnings"}) + class BaseLocalizedTest(unittest.TestCase): # diff --git a/Lib/test/test_mimetypes.py b/Lib/test/test_mimetypes.py index dad5dbde7cd..fb57d5e5544 100644 --- a/Lib/test/test_mimetypes.py +++ b/Lib/test/test_mimetypes.py @@ -6,7 +6,8 @@ import sys import unittest.mock from platform import win32_edition from test import support -from test.support import os_helper +from test.support import cpython_only, force_not_colorized, os_helper +from test.support.import_helper import ensure_lazy_imports try: import _winapi @@ -435,8 +436,13 @@ class MiscTestCase(unittest.TestCase): def test__all__(self): support.check__all__(self, mimetypes) + @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("mimetypes", {"os", "posixpath", "urllib.parse", "argparse"}) + class CommandLineTest(unittest.TestCase): + @force_not_colorized def test_parse_args(self): args, help_text = mimetypes._parse_args("-h") self.assertTrue(help_text.startswith("usage: ")) diff --git a/Lib/test/test_minidom.py b/Lib/test/test_minidom.py index 6679c0a4fbe..4f25e9c2a03 100644 --- a/Lib/test/test_minidom.py +++ b/Lib/test/test_minidom.py @@ -102,41 +102,38 @@ class MinidomTest(unittest.TestCase): elem = root.childNodes[0] nelem = dom.createElement("element") root.insertBefore(nelem, elem) - self.confirm(len(root.childNodes) == 2 - and root.childNodes.length == 2 - and root.childNodes[0] is nelem - and root.childNodes.item(0) is nelem - and root.childNodes[1] is elem - and root.childNodes.item(1) is elem - and root.firstChild is nelem - and root.lastChild is elem - and root.toxml() == "<doc><element/><foo/></doc>" - , "testInsertBefore -- node properly placed in tree") + self.assertEqual(len(root.childNodes), 2) + self.assertEqual(root.childNodes.length, 2) + self.assertIs(root.childNodes[0], nelem) + self.assertIs(root.childNodes.item(0), nelem) + self.assertIs(root.childNodes[1], elem) + self.assertIs(root.childNodes.item(1), elem) + self.assertIs(root.firstChild, nelem) + self.assertIs(root.lastChild, elem) + self.assertEqual(root.toxml(), "<doc><element/><foo/></doc>") nelem = dom.createElement("element") root.insertBefore(nelem, None) - self.confirm(len(root.childNodes) == 3 - and root.childNodes.length == 3 - and root.childNodes[1] is elem - and root.childNodes.item(1) is elem - and root.childNodes[2] is nelem - and root.childNodes.item(2) is nelem - and root.lastChild is nelem - and nelem.previousSibling is elem - and root.toxml() == "<doc><element/><foo/><element/></doc>" - , "testInsertBefore -- node properly placed in tree") + self.assertEqual(len(root.childNodes), 3) + self.assertEqual(root.childNodes.length, 3) + self.assertIs(root.childNodes[1], elem) + self.assertIs(root.childNodes.item(1), elem) + self.assertIs(root.childNodes[2], nelem) + self.assertIs(root.childNodes.item(2), nelem) + self.assertIs(root.lastChild, nelem) + self.assertIs(nelem.previousSibling, elem) + self.assertEqual(root.toxml(), "<doc><element/><foo/><element/></doc>") nelem2 = dom.createElement("bar") root.insertBefore(nelem2, nelem) - self.confirm(len(root.childNodes) == 4 - and root.childNodes.length == 4 - and root.childNodes[2] is nelem2 - and root.childNodes.item(2) is nelem2 - and root.childNodes[3] is nelem - and root.childNodes.item(3) is nelem - and nelem2.nextSibling is nelem - and nelem.previousSibling is nelem2 - and root.toxml() == - "<doc><element/><foo/><bar/><element/></doc>" - , "testInsertBefore -- node properly placed in tree") + self.assertEqual(len(root.childNodes), 4) + self.assertEqual(root.childNodes.length, 4) + self.assertIs(root.childNodes[2], nelem2) + self.assertIs(root.childNodes.item(2), nelem2) + self.assertIs(root.childNodes[3], nelem) + self.assertIs(root.childNodes.item(3), nelem) + self.assertIs(nelem2.nextSibling, nelem) + self.assertIs(nelem.previousSibling, nelem2) + self.assertEqual(root.toxml(), + "<doc><element/><foo/><bar/><element/></doc>") dom.unlink() def _create_fragment_test_nodes(self): @@ -342,8 +339,8 @@ class MinidomTest(unittest.TestCase): self.assertRaises(xml.dom.NotFoundErr, child.removeAttributeNode, None) self.assertIs(node, child.removeAttributeNode(node)) - self.confirm(len(child.attributes) == 0 - and child.getAttributeNode("spam") is None) + self.assertEqual(len(child.attributes), 0) + self.assertIsNone(child.getAttributeNode("spam")) dom2 = Document() child2 = dom2.appendChild(dom2.createElement("foo")) node2 = child2.getAttributeNode("spam") @@ -366,33 +363,34 @@ class MinidomTest(unittest.TestCase): # Set this attribute to be an ID and make sure that doesn't change # when changing the value: el.setIdAttribute("spam") - self.confirm(len(el.attributes) == 1 - and el.attributes["spam"].value == "bam" - and el.attributes["spam"].nodeValue == "bam" - and el.getAttribute("spam") == "bam" - and el.getAttributeNode("spam").isId) + self.assertEqual(len(el.attributes), 1) + self.assertEqual(el.attributes["spam"].value, "bam") + self.assertEqual(el.attributes["spam"].nodeValue, "bam") + self.assertEqual(el.getAttribute("spam"), "bam") + self.assertTrue(el.getAttributeNode("spam").isId) el.attributes["spam"] = "ham" - self.confirm(len(el.attributes) == 1 - and el.attributes["spam"].value == "ham" - and el.attributes["spam"].nodeValue == "ham" - and el.getAttribute("spam") == "ham" - and el.attributes["spam"].isId) + self.assertEqual(len(el.attributes), 1) + self.assertEqual(el.attributes["spam"].value, "ham") + self.assertEqual(el.attributes["spam"].nodeValue, "ham") + self.assertEqual(el.getAttribute("spam"), "ham") + self.assertTrue(el.attributes["spam"].isId) el.setAttribute("spam2", "bam") - self.confirm(len(el.attributes) == 2 - and el.attributes["spam"].value == "ham" - and el.attributes["spam"].nodeValue == "ham" - and el.getAttribute("spam") == "ham" - and el.attributes["spam2"].value == "bam" - and el.attributes["spam2"].nodeValue == "bam" - and el.getAttribute("spam2") == "bam") + self.assertEqual(len(el.attributes), 2) + self.assertEqual(el.attributes["spam"].value, "ham") + self.assertEqual(el.attributes["spam"].nodeValue, "ham") + self.assertEqual(el.getAttribute("spam"), "ham") + self.assertEqual(el.attributes["spam2"].value, "bam") + self.assertEqual(el.attributes["spam2"].nodeValue, "bam") + self.assertEqual(el.getAttribute("spam2"), "bam") el.attributes["spam2"] = "bam2" - self.confirm(len(el.attributes) == 2 - and el.attributes["spam"].value == "ham" - and el.attributes["spam"].nodeValue == "ham" - and el.getAttribute("spam") == "ham" - and el.attributes["spam2"].value == "bam2" - and el.attributes["spam2"].nodeValue == "bam2" - and el.getAttribute("spam2") == "bam2") + + self.assertEqual(len(el.attributes), 2) + self.assertEqual(el.attributes["spam"].value, "ham") + self.assertEqual(el.attributes["spam"].nodeValue, "ham") + self.assertEqual(el.getAttribute("spam"), "ham") + self.assertEqual(el.attributes["spam2"].value, "bam2") + self.assertEqual(el.attributes["spam2"].nodeValue, "bam2") + self.assertEqual(el.getAttribute("spam2"), "bam2") dom.unlink() def testGetAttrList(self): @@ -448,12 +446,12 @@ class MinidomTest(unittest.TestCase): dom = parseString(d) elems = dom.getElementsByTagNameNS("http://pyxml.sf.net/minidom", "myelem") - self.confirm(len(elems) == 1 - and elems[0].namespaceURI == "http://pyxml.sf.net/minidom" - and elems[0].localName == "myelem" - and elems[0].prefix == "minidom" - and elems[0].tagName == "minidom:myelem" - and elems[0].nodeName == "minidom:myelem") + self.assertEqual(len(elems), 1) + self.assertEqual(elems[0].namespaceURI, "http://pyxml.sf.net/minidom") + self.assertEqual(elems[0].localName, "myelem") + self.assertEqual(elems[0].prefix, "minidom") + self.assertEqual(elems[0].tagName, "minidom:myelem") + self.assertEqual(elems[0].nodeName, "minidom:myelem") dom.unlink() def get_empty_nodelist_from_elements_by_tagName_ns_helper(self, doc, nsuri, @@ -602,17 +600,17 @@ class MinidomTest(unittest.TestCase): def testProcessingInstruction(self): dom = parseString('<e><?mypi \t\n data \t\n ?></e>') pi = dom.documentElement.firstChild - self.confirm(pi.target == "mypi" - and pi.data == "data \t\n " - and pi.nodeName == "mypi" - and pi.nodeType == Node.PROCESSING_INSTRUCTION_NODE - and pi.attributes is None - and not pi.hasChildNodes() - and len(pi.childNodes) == 0 - and pi.firstChild is None - and pi.lastChild is None - and pi.localName is None - and pi.namespaceURI == xml.dom.EMPTY_NAMESPACE) + self.assertEqual(pi.target, "mypi") + self.assertEqual(pi.data, "data \t\n ") + self.assertEqual(pi.nodeName, "mypi") + self.assertEqual(pi.nodeType, Node.PROCESSING_INSTRUCTION_NODE) + self.assertIsNone(pi.attributes) + self.assertFalse(pi.hasChildNodes()) + self.assertEqual(len(pi.childNodes), 0) + self.assertIsNone(pi.firstChild) + self.assertIsNone(pi.lastChild) + self.assertIsNone(pi.localName) + self.assertEqual(pi.namespaceURI, xml.dom.EMPTY_NAMESPACE) def testProcessingInstructionRepr(self): dom = parseString('<e><?mypi \t\n data \t\n ?></e>') @@ -718,19 +716,16 @@ class MinidomTest(unittest.TestCase): keys2 = list(attrs2.keys()) keys1.sort() keys2.sort() - self.assertEqual(keys1, keys2, - "clone of element has same attribute keys") + self.assertEqual(keys1, keys2) for i in range(len(keys1)): a1 = attrs1.item(i) a2 = attrs2.item(i) - self.confirm(a1 is not a2 - and a1.value == a2.value - and a1.nodeValue == a2.nodeValue - and a1.namespaceURI == a2.namespaceURI - and a1.localName == a2.localName - , "clone of attribute node has proper attribute values") - self.assertIs(a2.ownerElement, e2, - "clone of attribute node correctly owned") + self.assertIsNot(a1, a2) + self.assertEqual(a1.value, a2.value) + self.assertEqual(a1.nodeValue, a2.nodeValue) + self.assertEqual(a1.namespaceURI,a2.namespaceURI) + self.assertEqual(a1.localName, a2.localName) + self.assertIs(a2.ownerElement, e2) def _setupCloneElement(self, deep): dom = parseString("<doc attr='value'><foo/></doc>") @@ -746,20 +741,19 @@ class MinidomTest(unittest.TestCase): def testCloneElementShallow(self): dom, clone = self._setupCloneElement(0) - self.confirm(len(clone.childNodes) == 0 - and clone.childNodes.length == 0 - and clone.parentNode is None - and clone.toxml() == '<doc attr="value"/>' - , "testCloneElementShallow") + self.assertEqual(len(clone.childNodes), 0) + self.assertEqual(clone.childNodes.length, 0) + self.assertIsNone(clone.parentNode) + self.assertEqual(clone.toxml(), '<doc attr="value"/>') + dom.unlink() def testCloneElementDeep(self): dom, clone = self._setupCloneElement(1) - self.confirm(len(clone.childNodes) == 1 - and clone.childNodes.length == 1 - and clone.parentNode is None - and clone.toxml() == '<doc attr="value"><foo/></doc>' - , "testCloneElementDeep") + self.assertEqual(len(clone.childNodes), 1) + self.assertEqual(clone.childNodes.length, 1) + self.assertIsNone(clone.parentNode) + self.assertTrue(clone.toxml(), '<doc attr="value"><foo/></doc>') dom.unlink() def testCloneDocumentShallow(self): diff --git a/Lib/test/test_optparse.py b/Lib/test/test_optparse.py index 8655a0537a5..e6ffd2b0ffe 100644 --- a/Lib/test/test_optparse.py +++ b/Lib/test/test_optparse.py @@ -14,8 +14,9 @@ import unittest from io import StringIO from test import support -from test.support import os_helper +from test.support import cpython_only, os_helper from test.support.i18n_helper import TestTranslationsBase, update_translation_snapshots +from test.support.import_helper import ensure_lazy_imports import optparse from optparse import make_option, Option, \ @@ -1655,6 +1656,10 @@ class MiscTestCase(unittest.TestCase): not_exported = {'check_builtin', 'AmbiguousOptionError', 'NO_DEFAULT'} support.check__all__(self, optparse, not_exported=not_exported) + @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("optparse", {"textwrap"}) + class TestTranslations(TestTranslationsBase): def test_translations(self): diff --git a/Lib/test/test_pathlib/test_pathlib.py b/Lib/test/test_pathlib/test_pathlib.py index 41a79d0dceb..8a313cc4292 100644 --- a/Lib/test/test_pathlib/test_pathlib.py +++ b/Lib/test/test_pathlib/test_pathlib.py @@ -16,6 +16,7 @@ from unittest import mock from urllib.request import pathname2url from test.support import import_helper +from test.support import cpython_only from test.support import is_emscripten, is_wasi from test.support import infinite_recursion from test.support import os_helper @@ -80,6 +81,12 @@ class UnsupportedOperationTest(unittest.TestCase): self.assertTrue(isinstance(pathlib.UnsupportedOperation(), NotImplementedError)) +class LazyImportTest(unittest.TestCase): + @cpython_only + def test_lazy_import(self): + import_helper.ensure_lazy_imports("pathlib", {"shutil"}) + + # # Tests for the pure classes. # @@ -3290,7 +3297,6 @@ class PathTest(PurePathTest): self.assertEqual(P.from_uri('file:////foo/bar'), P('//foo/bar')) self.assertEqual(P.from_uri('file://localhost/foo/bar'), P('/foo/bar')) if not is_wasi: - self.assertEqual(P.from_uri('file://127.0.0.1/foo/bar'), P('/foo/bar')) self.assertEqual(P.from_uri(f'file://{socket.gethostname()}/foo/bar'), P('/foo/bar')) self.assertRaises(ValueError, P.from_uri, 'foo/bar') diff --git a/Lib/test/test_pdb.py b/Lib/test/test_pdb.py index be365a5a3dd..54797d7898f 100644 --- a/Lib/test/test_pdb.py +++ b/Lib/test/test_pdb.py @@ -1,7 +1,9 @@ # A test suite for pdb; not very comprehensive at the moment. +import _colorize import doctest import gc +import io import os import pdb import sys @@ -18,7 +20,7 @@ from asyncio.events import _set_event_loop_policy from contextlib import ExitStack, redirect_stdout from io import StringIO from test import support -from test.support import force_not_colorized, has_socket_support, os_helper +from test.support import has_socket_support, os_helper from test.support.import_helper import import_module from test.support.pty_helper import run_pty, FakeInput from test.support.script_helper import kill_python @@ -3446,6 +3448,7 @@ def test_pdb_issue_gh_65052(): """ +@support.force_not_colorized_test_class @support.requires_subprocess() class PdbTestCase(unittest.TestCase): def tearDown(self): @@ -3740,7 +3743,6 @@ def bœr(): self.assertNotIn(b'Error', stdout, "Got an error running test script under PDB") - @force_not_colorized def test_issue16180(self): # A syntax error in the debuggee. script = "def f: pass\n" @@ -3754,7 +3756,6 @@ def bœr(): 'Fail to handle a syntax error in the debuggee.' .format(expected, stderr)) - @force_not_colorized def test_issue84583(self): # A syntax error from ast.literal_eval should not make pdb exit. script = "import ast; ast.literal_eval('')\n" @@ -4688,6 +4689,40 @@ class PdbTestInline(unittest.TestCase): self.assertIn("42", stdout) +@support.force_colorized_test_class +class PdbTestColorize(unittest.TestCase): + def setUp(self): + self._original_can_colorize = _colorize.can_colorize + # Force colorize to be enabled because we are sending data + # to a StringIO + _colorize.can_colorize = lambda *args, **kwargs: True + + def tearDown(self): + _colorize.can_colorize = self._original_can_colorize + + def test_code_display(self): + output = io.StringIO() + p = pdb.Pdb(stdout=output, colorize=True) + p.set_trace(commands=['ll', 'c']) + self.assertIn("\x1b", output.getvalue()) + + output = io.StringIO() + p = pdb.Pdb(stdout=output, colorize=False) + p.set_trace(commands=['ll', 'c']) + self.assertNotIn("\x1b", output.getvalue()) + + output = io.StringIO() + p = pdb.Pdb(stdout=output) + p.set_trace(commands=['ll', 'c']) + self.assertNotIn("\x1b", output.getvalue()) + + def test_stack_entry(self): + output = io.StringIO() + p = pdb.Pdb(stdout=output, colorize=True) + p.set_trace(commands=['w', 'c']) + self.assertIn("\x1b", output.getvalue()) + + @support.force_not_colorized_test_class @support.requires_subprocess() class TestREPLSession(unittest.TestCase): @@ -4711,6 +4746,7 @@ class TestREPLSession(unittest.TestCase): self.assertEqual(p.returncode, 0) +@support.force_not_colorized_test_class @support.requires_subprocess() class PdbTestReadline(unittest.TestCase): def setUpClass(): @@ -4812,14 +4848,35 @@ class PdbTestReadline(unittest.TestCase): self.assertIn(b'I love Python', output) + def test_multiline_auto_indent(self): + script = textwrap.dedent(""" + import pdb; pdb.Pdb().set_trace() + """) + + input = b"def f(x):\n" + input += b"if x > 0:\n" + input += b"x += 1\n" + input += b"return x\n" + # We need to do backspaces to remove the auto-indentation + input += b"\x08\x08\x08\x08else:\n" + input += b"return -x\n" + input += b"\n" + input += b"f(-21-21)\n" + input += b"c\n" + + output = run_pty(script, input) + + self.assertIn(b'42', output) + def test_multiline_completion(self): script = textwrap.dedent(""" import pdb; pdb.Pdb().set_trace() """) input = b"def func():\n" - # Complete: \treturn 40 + 2 - input += b"\tret\t 40 + 2\n" + # Auto-indent + # Complete: return 40 + 2 + input += b"ret\t 40 + 2\n" input += b"\n" # Complete: func() input += b"fun\t()\n" @@ -4839,12 +4896,13 @@ class PdbTestReadline(unittest.TestCase): # if the completion is not working as expected input = textwrap.dedent("""\ def func(): - \ta = 1 - \ta += 1 - \ta += 1 - \tif a > 0: - a += 1 - \t\treturn a + a = 1 + \x08\ta += 1 + \x08\x08\ta += 1 + \x08\x08\x08\ta += 1 + \x08\x08\x08\x08\tif a > 0: + a += 1 + \x08\x08\x08\x08return a func() c @@ -4852,7 +4910,7 @@ class PdbTestReadline(unittest.TestCase): output = run_pty(script, input) - self.assertIn(b'4', output) + self.assertIn(b'5', output) self.assertNotIn(b'Error', output) def test_interact_completion(self): diff --git a/Lib/test/test_peepholer.py b/Lib/test/test_peepholer.py index 565e42b04a6..47f51f1979f 100644 --- a/Lib/test/test_peepholer.py +++ b/Lib/test/test_peepholer.py @@ -1,4 +1,5 @@ import dis +import gc from itertools import combinations, product import opcode import sys @@ -2472,6 +2473,13 @@ class OptimizeLoadFastTestCase(DirectCfgOptimizerTests): ] self.check(insts, insts) + insts = [ + ("LOAD_FAST", 0, 1), + ("DELETE_FAST", 0, 2), + ("POP_TOP", None, 3), + ] + self.check(insts, insts) + def test_unoptimized_if_aliased(self): insts = [ ("LOAD_FAST", 0, 1), @@ -2606,6 +2614,22 @@ class OptimizeLoadFastTestCase(DirectCfgOptimizerTests): ] self.cfg_optimization_test(insts, expected, consts=[None]) + def test_del_in_finally(self): + # This loads `obj` onto the stack, executes `del obj`, then returns the + # `obj` from the stack. See gh-133371 for more details. + def create_obj(): + obj = [42] + try: + return obj + finally: + del obj + + obj = create_obj() + # The crash in the linked issue happens while running GC during + # interpreter finalization, so run it here manually. + gc.collect() + self.assertEqual(obj, [42]) + if __name__ == "__main__": diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py index 296d4b882e1..742ca8de1be 100644 --- a/Lib/test/test_pickle.py +++ b/Lib/test/test_pickle.py @@ -15,7 +15,8 @@ from textwrap import dedent import doctest import unittest from test import support -from test.support import import_helper, os_helper +from test.support import cpython_only, import_helper, os_helper +from test.support.import_helper import ensure_lazy_imports from test.pickletester import AbstractHookTests from test.pickletester import AbstractUnpickleTests @@ -36,6 +37,12 @@ except ImportError: has_c_implementation = False +class LazyImportTest(unittest.TestCase): + @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("pickle", {"re"}) + + class PyPickleTests(AbstractPickleModuleTests, unittest.TestCase): dump = staticmethod(pickle._dump) dumps = staticmethod(pickle._dumps) @@ -745,6 +752,7 @@ class CommandLineTest(unittest.TestCase): expect = self.text_normalize(expect) self.assertListEqual(res.splitlines(), expect.splitlines()) + @support.force_not_colorized def test_unknown_flag(self): stderr = io.StringIO() with self.assertRaises(SystemExit): diff --git a/Lib/test/test_platform.py b/Lib/test/test_platform.py index b90edc05e04..818e807dd3a 100644 --- a/Lib/test/test_platform.py +++ b/Lib/test/test_platform.py @@ -794,6 +794,7 @@ class CommandLineTest(unittest.TestCase): self.invoke_platform(*flags) obj.assert_called_once_with(aliased, terse) + @support.force_not_colorized def test_help(self): output = io.StringIO() diff --git a/Lib/test/test_posix.py b/Lib/test/test_posix.py index b6a07f214fa..0817d0a87a3 100644 --- a/Lib/test/test_posix.py +++ b/Lib/test/test_posix.py @@ -1521,8 +1521,8 @@ class PosixTester(unittest.TestCase): self.assertEqual(cm.exception.errno, errno.EINVAL) os.close(os.pidfd_open(os.getpid(), 0)) - @unittest.skipUnless(hasattr(os, "link"), "test needs os.link()") - @support.skip_android_selinux('hard links to symbolic links') + @os_helper.skip_unless_hardlink + @os_helper.skip_unless_symlink def test_link_follow_symlinks(self): default_follow = sys.platform.startswith( ('darwin', 'freebsd', 'netbsd', 'openbsd', 'dragonfly', 'sunos5')) diff --git a/Lib/test/test_pprint.py b/Lib/test/test_pprint.py index dfbc2a06e73..f68996f72b1 100644 --- a/Lib/test/test_pprint.py +++ b/Lib/test/test_pprint.py @@ -11,6 +11,9 @@ import re import types import unittest +from test.support import cpython_only +from test.support.import_helper import ensure_lazy_imports + # list, tuple and dict subclasses that do or don't overwrite __repr__ class list2(list): pass @@ -129,6 +132,10 @@ class QueryTestCase(unittest.TestCase): self.b = list(range(200)) self.a[-12] = self.b + @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("pprint", {"dataclasses", "re"}) + def test_init(self): pp = pprint.PrettyPrinter() pp = pprint.PrettyPrinter(indent=4, width=40, depth=5, diff --git a/Lib/test/test_pstats.py b/Lib/test/test_pstats.py index d5a5a9738c2..a26a8c1d522 100644 --- a/Lib/test/test_pstats.py +++ b/Lib/test/test_pstats.py @@ -1,6 +1,7 @@ import unittest from test import support +from test.support.import_helper import ensure_lazy_imports from io import StringIO from pstats import SortKey from enum import StrEnum, _test_simple_enum @@ -10,6 +11,12 @@ import pstats import tempfile import cProfile +class LazyImportTest(unittest.TestCase): + @support.cpython_only + def test_lazy_import(self): + ensure_lazy_imports("pstats", {"typing"}) + + class AddCallersTestCase(unittest.TestCase): """Tests for pstats.add_callers helper.""" diff --git a/Lib/test/test_pyrepl/support.py b/Lib/test/test_pyrepl/support.py index 3692e164cb9..4f7f9d77933 100644 --- a/Lib/test/test_pyrepl/support.py +++ b/Lib/test/test_pyrepl/support.py @@ -113,9 +113,6 @@ handle_events_narrow_console = partial( prepare_console=partial(prepare_console, width=10), ) -reader_no_colors = partial(prepare_reader, can_colorize=False) -reader_force_colors = partial(prepare_reader, can_colorize=True) - class FakeConsole(Console): def __init__(self, events, encoding="utf-8") -> None: diff --git a/Lib/test/test_pyrepl/test_eventqueue.py b/Lib/test/test_pyrepl/test_eventqueue.py index afb55710342..edfe6ac4748 100644 --- a/Lib/test/test_pyrepl/test_eventqueue.py +++ b/Lib/test/test_pyrepl/test_eventqueue.py @@ -53,7 +53,7 @@ class EventQueueTestBase: mock_keymap.compile_keymap.return_value = {"a": "b"} eq = self.make_eventqueue() eq.keymap = {b"a": "b"} - eq.push("a") + eq.push(b"a") mock_keymap.compile_keymap.assert_called() self.assertEqual(eq.events[0].evt, "key") self.assertEqual(eq.events[0].data, "b") @@ -63,7 +63,7 @@ class EventQueueTestBase: mock_keymap.compile_keymap.return_value = {"a": "b"} eq = self.make_eventqueue() eq.keymap = {b"c": "d"} - eq.push("a") + eq.push(b"a") mock_keymap.compile_keymap.assert_called() self.assertEqual(eq.events[0].evt, "key") self.assertEqual(eq.events[0].data, "a") @@ -73,13 +73,13 @@ class EventQueueTestBase: mock_keymap.compile_keymap.return_value = {"a": "b"} eq = self.make_eventqueue() eq.keymap = {b"a": {b"b": "c"}} - eq.push("a") + eq.push(b"a") mock_keymap.compile_keymap.assert_called() self.assertTrue(eq.empty()) - eq.push("b") + eq.push(b"b") self.assertEqual(eq.events[0].evt, "key") self.assertEqual(eq.events[0].data, "c") - eq.push("d") + eq.push(b"d") self.assertEqual(eq.events[1].evt, "key") self.assertEqual(eq.events[1].data, "d") @@ -88,32 +88,32 @@ class EventQueueTestBase: mock_keymap.compile_keymap.return_value = {"a": "b"} eq = self.make_eventqueue() eq.keymap = {b"a": {b"b": "c"}} - eq.push("a") + eq.push(b"a") mock_keymap.compile_keymap.assert_called() self.assertTrue(eq.empty()) eq.flush_buf() - eq.push("\033") + eq.push(b"\033") self.assertEqual(eq.events[0].evt, "key") self.assertEqual(eq.events[0].data, "\033") - eq.push("b") + eq.push(b"b") self.assertEqual(eq.events[1].evt, "key") self.assertEqual(eq.events[1].data, "b") def test_push_special_key(self): eq = self.make_eventqueue() eq.keymap = {} - eq.push("\x1b") - eq.push("[") - eq.push("A") + eq.push(b"\x1b") + eq.push(b"[") + eq.push(b"A") self.assertEqual(eq.events[0].evt, "key") self.assertEqual(eq.events[0].data, "\x1b") def test_push_unrecognized_escape_sequence(self): eq = self.make_eventqueue() eq.keymap = {} - eq.push("\x1b") - eq.push("[") - eq.push("Z") + eq.push(b"\x1b") + eq.push(b"[") + eq.push(b"Z") self.assertEqual(len(eq.events), 3) self.assertEqual(eq.events[0].evt, "key") self.assertEqual(eq.events[0].data, "\x1b") @@ -122,12 +122,54 @@ class EventQueueTestBase: self.assertEqual(eq.events[2].evt, "key") self.assertEqual(eq.events[2].data, "Z") - def test_push_unicode_character(self): + def test_push_unicode_character_as_str(self): eq = self.make_eventqueue() eq.keymap = {} - eq.push("ч") - self.assertEqual(eq.events[0].evt, "key") - self.assertEqual(eq.events[0].data, "ч") + with self.assertRaises(AssertionError): + eq.push("ч") + with self.assertRaises(AssertionError): + eq.push("ñ") + + def test_push_unicode_character_two_bytes(self): + eq = self.make_eventqueue() + eq.keymap = {} + + encoded = "ч".encode(eq.encoding, "replace") + self.assertEqual(len(encoded), 2) + + eq.push(encoded[0]) + e = eq.get() + self.assertIsNone(e) + + eq.push(encoded[1]) + e = eq.get() + self.assertEqual(e.evt, "key") + self.assertEqual(e.data, "ч") + + def test_push_single_chars_and_unicode_character_as_str(self): + eq = self.make_eventqueue() + eq.keymap = {} + + def _event(evt, data, raw=None): + r = raw if raw is not None else data.encode(eq.encoding) + e = Event(evt, data, r) + return e + + def _push(keys): + for k in keys: + eq.push(k) + + self.assertIsInstance("ñ", str) + + # If an exception happens during push, the existing events must be + # preserved and we can continue to push. + _push(b"b") + with self.assertRaises(AssertionError): + _push("ñ") + _push(b"a") + + self.assertEqual(eq.get(), _event("key", "b")) + self.assertEqual(eq.get(), _event("key", "a")) @unittest.skipIf(support.MS_WINDOWS, "No Unix event queue on Windows") diff --git a/Lib/test/test_pyrepl/test_reader.py b/Lib/test/test_pyrepl/test_reader.py index 8d7fcf538d2..4ee320a5a4d 100644 --- a/Lib/test/test_pyrepl/test_reader.py +++ b/Lib/test/test_pyrepl/test_reader.py @@ -4,20 +4,21 @@ import rlcompleter from textwrap import dedent from unittest import TestCase from unittest.mock import MagicMock +from test.support import force_colorized_test_class, force_not_colorized_test_class from .support import handle_all_events, handle_events_narrow_console from .support import ScreenEqualMixin, code_to_events -from .support import prepare_console, reader_force_colors -from .support import reader_no_colors as prepare_reader +from .support import prepare_reader, prepare_console from _pyrepl.console import Event from _pyrepl.reader import Reader -from _colorize import theme +from _colorize import default_theme -overrides = {"RESET": "z", "SOFT_KEYWORD": "K"} -colors = {overrides.get(k, k[0].lower()): v for k, v in theme.items()} +overrides = {"reset": "z", "soft_keyword": "K"} +colors = {overrides.get(k, k[0].lower()): v for k, v in default_theme.syntax.items()} +@force_not_colorized_test_class class TestReader(ScreenEqualMixin, TestCase): def test_calc_screen_wrap_simple(self): events = code_to_events(10 * "a") @@ -127,13 +128,6 @@ class TestReader(ScreenEqualMixin, TestCase): reader.setpos_from_xy(0, 0) self.assertEqual(reader.pos, 0) - def test_control_characters(self): - code = 'flag = "🏳️🌈"' - events = code_to_events(code) - reader, _ = handle_all_events(events, prepare_reader=reader_force_colors) - self.assert_screen_equal(reader, 'flag = "🏳️\\u200d🌈"', clean=True) - self.assert_screen_equal(reader, 'flag {o}={z} {s}"🏳️\\u200d🌈"{z}'.format(**colors)) - def test_setpos_from_xy_multiple_lines(self): # fmt: off code = ( @@ -364,6 +358,8 @@ class TestReader(ScreenEqualMixin, TestCase): reader.setpos_from_xy(8, 0) self.assertEqual(reader.pos, 7) +@force_colorized_test_class +class TestReaderInColor(ScreenEqualMixin, TestCase): def test_syntax_highlighting_basic(self): code = dedent( """\ @@ -403,7 +399,7 @@ class TestReader(ScreenEqualMixin, TestCase): ) expected_sync = expected.format(a="", **colors) events = code_to_events(code) - reader, _ = handle_all_events(events, prepare_reader=reader_force_colors) + reader, _ = handle_all_events(events) self.assert_screen_equal(reader, code, clean=True) self.assert_screen_equal(reader, expected_sync) self.assertEqual(reader.pos, 2**7 + 2**8) @@ -416,7 +412,7 @@ class TestReader(ScreenEqualMixin, TestCase): [Event(evt="key", data="up", raw=bytearray(b"\x1bOA"))] * 13, code_to_events("async "), ) - reader, _ = handle_all_events(more_events, prepare_reader=reader_force_colors) + reader, _ = handle_all_events(more_events) self.assert_screen_equal(reader, expected_async) self.assertEqual(reader.pos, 21) self.assertEqual(reader.cxy, (6, 1)) @@ -433,7 +429,7 @@ class TestReader(ScreenEqualMixin, TestCase): """ ).format(**colors) events = code_to_events(code) - reader, _ = handle_all_events(events, prepare_reader=reader_force_colors) + reader, _ = handle_all_events(events) self.assert_screen_equal(reader, code, clean=True) self.assert_screen_equal(reader, expected) @@ -451,7 +447,7 @@ class TestReader(ScreenEqualMixin, TestCase): """ ).format(**colors) events = code_to_events(code) - reader, _ = handle_all_events(events, prepare_reader=reader_force_colors) + reader, _ = handle_all_events(events) self.assert_screen_equal(reader, code, clean=True) self.assert_screen_equal(reader, expected) @@ -471,7 +467,7 @@ class TestReader(ScreenEqualMixin, TestCase): """ ).format(**colors) events = code_to_events(code) - reader, _ = handle_all_events(events, prepare_reader=reader_force_colors) + reader, _ = handle_all_events(events) self.assert_screen_equal(reader, code, clean=True) self.assert_screen_equal(reader, expected) @@ -497,6 +493,13 @@ class TestReader(ScreenEqualMixin, TestCase): """ ).format(OB="{", CB="}", **colors) events = code_to_events(code) - reader, _ = handle_all_events(events, prepare_reader=reader_force_colors) + reader, _ = handle_all_events(events) self.assert_screen_equal(reader, code, clean=True) self.assert_screen_equal(reader, expected) + + def test_control_characters(self): + code = 'flag = "🏳️🌈"' + events = code_to_events(code) + reader, _ = handle_all_events(events) + self.assert_screen_equal(reader, 'flag = "🏳️\\u200d🌈"', clean=True) + self.assert_screen_equal(reader, 'flag {o}={z} {s}"🏳️\\u200d🌈"{z}'.format(**colors)) diff --git a/Lib/test/test_pyrepl/test_unix_console.py b/Lib/test/test_pyrepl/test_unix_console.py index 7acb84a94f7..c447b310c49 100644 --- a/Lib/test/test_pyrepl/test_unix_console.py +++ b/Lib/test/test_pyrepl/test_unix_console.py @@ -3,11 +3,12 @@ import os import sys import unittest from functools import partial -from test.support import os_helper +from test.support import os_helper, force_not_colorized_test_class + from unittest import TestCase from unittest.mock import MagicMock, call, patch, ANY -from .support import handle_all_events, code_to_events, reader_no_colors +from .support import handle_all_events, code_to_events try: from _pyrepl.console import Event @@ -33,12 +34,10 @@ def unix_console(events, **kwargs): handle_events_unix_console = partial( handle_all_events, - prepare_reader=reader_no_colors, prepare_console=unix_console, ) handle_events_narrow_unix_console = partial( handle_all_events, - prepare_reader=reader_no_colors, prepare_console=partial(unix_console, width=5), ) handle_events_short_unix_console = partial( @@ -120,6 +119,7 @@ TERM_CAPABILITIES = { ) @patch("termios.tcsetattr", lambda a, b, c: None) @patch("os.write") +@force_not_colorized_test_class class TestConsole(TestCase): def test_simple_addition(self, _os_write): code = "12+34" @@ -255,9 +255,7 @@ class TestConsole(TestCase): # fmt: on events = itertools.chain(code_to_events(code)) - reader, console = handle_events_short_unix_console( - events, prepare_reader=reader_no_colors - ) + reader, console = handle_events_short_unix_console(events) console.height = 2 console.getheightwidth = MagicMock(lambda _: (2, 80)) diff --git a/Lib/test/test_pyrepl/test_windows_console.py b/Lib/test/test_pyrepl/test_windows_console.py index e95fec46a85..e7bab226b31 100644 --- a/Lib/test/test_pyrepl/test_windows_console.py +++ b/Lib/test/test_pyrepl/test_windows_console.py @@ -7,12 +7,13 @@ if sys.platform != "win32": import itertools from functools import partial +from test.support import force_not_colorized_test_class from typing import Iterable from unittest import TestCase from unittest.mock import MagicMock, call from .support import handle_all_events, code_to_events -from .support import reader_no_colors as default_prepare_reader +from .support import prepare_reader as default_prepare_reader try: from _pyrepl.console import Event, Console @@ -24,10 +25,12 @@ try: MOVE_DOWN, ERASE_IN_LINE, ) + import _pyrepl.windows_console as wc except ImportError: pass +@force_not_colorized_test_class class WindowsConsoleTests(TestCase): def console(self, events, **kwargs) -> Console: console = WindowsConsole() @@ -350,8 +353,226 @@ class WindowsConsoleTests(TestCase): Event(evt="key", data='\x1a', raw=bytearray(b'\x1a')), ], ) - reader, _ = self.handle_events_narrow(events) + reader, con = self.handle_events_narrow(events) self.assertEqual(reader.cxy, (2, 3)) + con.restore() + + +class WindowsConsoleGetEventTests(TestCase): + # Virtual-Key Codes: https://learn.microsoft.com/en-us/windows/win32/inputdev/virtual-key-codes + VK_BACK = 0x08 + VK_RETURN = 0x0D + VK_LEFT = 0x25 + VK_7 = 0x37 + VK_M = 0x4D + # Used for miscellaneous characters; it can vary by keyboard. + # For the US standard keyboard, the '" key. + # For the German keyboard, the Ä key. + VK_OEM_7 = 0xDE + + # State of control keys: https://learn.microsoft.com/en-us/windows/console/key-event-record-str + RIGHT_ALT_PRESSED = 0x0001 + RIGHT_CTRL_PRESSED = 0x0004 + LEFT_ALT_PRESSED = 0x0002 + LEFT_CTRL_PRESSED = 0x0008 + ENHANCED_KEY = 0x0100 + SHIFT_PRESSED = 0x0010 + + + def get_event(self, input_records, **kwargs) -> Console: + self.console = WindowsConsole(encoding='utf-8') + self.mock = MagicMock(side_effect=input_records) + self.console._read_input = self.mock + self.console._WindowsConsole__vt_support = kwargs.get("vt_support", + False) + event = self.console.get_event(block=False) + return event + + def get_input_record(self, unicode_char, vcode=0, control=0): + return wc.INPUT_RECORD( + wc.KEY_EVENT, + wc.ConsoleEvent(KeyEvent= + wc.KeyEvent( + bKeyDown=True, + wRepeatCount=1, + wVirtualKeyCode=vcode, + wVirtualScanCode=0, # not used + uChar=wc.Char(unicode_char), + dwControlKeyState=control + ))) + + def test_EmptyBuffer(self): + self.assertEqual(self.get_event([None]), None) + self.assertEqual(self.mock.call_count, 1) + + def test_WINDOW_BUFFER_SIZE_EVENT(self): + ir = wc.INPUT_RECORD( + wc.WINDOW_BUFFER_SIZE_EVENT, + wc.ConsoleEvent(WindowsBufferSizeEvent= + wc.WindowsBufferSizeEvent( + wc._COORD(0, 0)))) + self.assertEqual(self.get_event([ir]), Event("resize", "")) + self.assertEqual(self.mock.call_count, 1) + + def test_KEY_EVENT_up_ignored(self): + ir = wc.INPUT_RECORD( + wc.KEY_EVENT, + wc.ConsoleEvent(KeyEvent= + wc.KeyEvent(bKeyDown=False))) + self.assertEqual(self.get_event([ir]), None) + self.assertEqual(self.mock.call_count, 1) + + def test_unhandled_events(self): + for event in (wc.FOCUS_EVENT, wc.MENU_EVENT, wc.MOUSE_EVENT): + ir = wc.INPUT_RECORD( + event, + # fake data, nothing is read except bKeyDown + wc.ConsoleEvent(KeyEvent= + wc.KeyEvent(bKeyDown=False))) + self.assertEqual(self.get_event([ir]), None) + self.assertEqual(self.mock.call_count, 1) + + def test_enter(self): + ir = self.get_input_record("\r", self.VK_RETURN) + self.assertEqual(self.get_event([ir]), Event("key", "\n")) + self.assertEqual(self.mock.call_count, 1) + + def test_backspace(self): + ir = self.get_input_record("\x08", self.VK_BACK) + self.assertEqual( + self.get_event([ir]), Event("key", "backspace")) + self.assertEqual(self.mock.call_count, 1) + + def test_m(self): + ir = self.get_input_record("m", self.VK_M) + self.assertEqual(self.get_event([ir]), Event("key", "m")) + self.assertEqual(self.mock.call_count, 1) + + def test_M(self): + ir = self.get_input_record("M", self.VK_M, self.SHIFT_PRESSED) + self.assertEqual(self.get_event([ir]), Event("key", "M")) + self.assertEqual(self.mock.call_count, 1) + + def test_left(self): + # VK_LEFT is sent as ENHANCED_KEY + ir = self.get_input_record("\x00", self.VK_LEFT, self.ENHANCED_KEY) + self.assertEqual(self.get_event([ir]), Event("key", "left")) + self.assertEqual(self.mock.call_count, 1) + + def test_left_RIGHT_CTRL_PRESSED(self): + ir = self.get_input_record( + "\x00", self.VK_LEFT, self.RIGHT_CTRL_PRESSED | self.ENHANCED_KEY) + self.assertEqual( + self.get_event([ir]), Event("key", "ctrl left")) + self.assertEqual(self.mock.call_count, 1) + + def test_left_LEFT_CTRL_PRESSED(self): + ir = self.get_input_record( + "\x00", self.VK_LEFT, self.LEFT_CTRL_PRESSED | self.ENHANCED_KEY) + self.assertEqual( + self.get_event([ir]), Event("key", "ctrl left")) + self.assertEqual(self.mock.call_count, 1) + + def test_left_RIGHT_ALT_PRESSED(self): + ir = self.get_input_record( + "\x00", self.VK_LEFT, self.RIGHT_ALT_PRESSED | self.ENHANCED_KEY) + self.assertEqual(self.get_event([ir]), Event(evt="key", data="\033")) + self.assertEqual( + self.console.get_event(), Event("key", "left")) + # self.mock is not called again, since the second time we read from the + # command queue + self.assertEqual(self.mock.call_count, 1) + + def test_left_LEFT_ALT_PRESSED(self): + ir = self.get_input_record( + "\x00", self.VK_LEFT, self.LEFT_ALT_PRESSED | self.ENHANCED_KEY) + self.assertEqual(self.get_event([ir]), Event(evt="key", data="\033")) + self.assertEqual( + self.console.get_event(), Event("key", "left")) + self.assertEqual(self.mock.call_count, 1) + + def test_m_LEFT_ALT_PRESSED_and_LEFT_CTRL_PRESSED(self): + # For the shift keys, Windows does not send anything when + # ALT and CTRL are both pressed, so let's test with VK_M. + # get_event() receives this input, but does not + # generate an event. + # This is for e.g. an English keyboard layout, for a + # German layout this returns `µ`, see test_AltGr_m. + ir = self.get_input_record( + "\x00", self.VK_M, self.LEFT_ALT_PRESSED | self.LEFT_CTRL_PRESSED) + self.assertEqual(self.get_event([ir]), None) + self.assertEqual(self.mock.call_count, 1) + + def test_m_LEFT_ALT_PRESSED(self): + ir = self.get_input_record( + "m", vcode=self.VK_M, control=self.LEFT_ALT_PRESSED) + self.assertEqual(self.get_event([ir]), Event(evt="key", data="\033")) + self.assertEqual(self.console.get_event(), Event("key", "m")) + self.assertEqual(self.mock.call_count, 1) + + def test_m_RIGHT_ALT_PRESSED(self): + ir = self.get_input_record( + "m", vcode=self.VK_M, control=self.RIGHT_ALT_PRESSED) + self.assertEqual(self.get_event([ir]), Event(evt="key", data="\033")) + self.assertEqual(self.console.get_event(), Event("key", "m")) + self.assertEqual(self.mock.call_count, 1) + + def test_AltGr_7(self): + # E.g. on a German keyboard layout, '{' is entered via + # AltGr + 7, where AltGr is the right Alt key on the keyboard. + # In this case, Windows automatically sets + # RIGHT_ALT_PRESSED = 0x0001 + LEFT_CTRL_PRESSED = 0x0008 + # This can also be entered like + # LeftAlt + LeftCtrl + 7 or + # LeftAlt + RightCtrl + 7 + # See https://learn.microsoft.com/en-us/windows/console/key-event-record-str + # https://learn.microsoft.com/en-us/windows/win32/api/winuser/nf-winuser-vkkeyscanw + ir = self.get_input_record( + "{", vcode=self.VK_7, + control=self.RIGHT_ALT_PRESSED | self.LEFT_CTRL_PRESSED) + self.assertEqual(self.get_event([ir]), Event("key", "{")) + self.assertEqual(self.mock.call_count, 1) + + def test_AltGr_m(self): + # E.g. on a German keyboard layout, this yields 'µ' + # Let's use LEFT_ALT_PRESSED and RIGHT_CTRL_PRESSED this + # time, to cover that, too. See above in test_AltGr_7. + ir = self.get_input_record( + "µ", vcode=self.VK_M, control=self.LEFT_ALT_PRESSED | self.RIGHT_CTRL_PRESSED) + self.assertEqual(self.get_event([ir]), Event("key", "µ")) + self.assertEqual(self.mock.call_count, 1) + + def test_umlaut_a_german(self): + ir = self.get_input_record("ä", self.VK_OEM_7) + self.assertEqual(self.get_event([ir]), Event("key", "ä")) + self.assertEqual(self.mock.call_count, 1) + + # virtual terminal tests + # Note: wVirtualKeyCode, wVirtualScanCode and dwControlKeyState + # are always zero in this case. + # "\r" and backspace are handled specially, everything else + # is handled in "elif self.__vt_support:" in WindowsConsole.get_event(). + # Hence, only one regular key ("m") and a terminal sequence + # are sufficient to test here, the real tests happen in test_eventqueue + # and test_keymap. + + def test_enter_vt(self): + ir = self.get_input_record("\r") + self.assertEqual(self.get_event([ir], vt_support=True), + Event("key", "\n")) + self.assertEqual(self.mock.call_count, 1) + + def test_backspace_vt(self): + ir = self.get_input_record("\x7f") + self.assertEqual(self.get_event([ir], vt_support=True), + Event("key", "backspace", b"\x7f")) + self.assertEqual(self.mock.call_count, 1) + + def test_up_vt(self): + irs = [self.get_input_record(x) for x in "\x1b[A"] + self.assertEqual(self.get_event(irs, vt_support=True), + Event(evt='key', data='up', raw=bytearray(b'\x1b[A'))) + self.assertEqual(self.mock.call_count, 3) if __name__ == "__main__": diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py index 96f6cc86219..43957f525f1 100644 --- a/Lib/test/test_random.py +++ b/Lib/test/test_random.py @@ -1411,6 +1411,7 @@ class TestModule(unittest.TestCase): class CommandLineTest(unittest.TestCase): + @support.force_not_colorized def test_parse_args(self): args, help_text = random._parse_args(shlex.split("--choice a b c")) self.assertEqual(args.choice, ["a", "b", "c"]) diff --git a/Lib/test/test_remote_pdb.py b/Lib/test/test_remote_pdb.py index 9fbe94fcdd6..aef8a6b0129 100644 --- a/Lib/test/test_remote_pdb.py +++ b/Lib/test/test_remote_pdb.py @@ -3,6 +3,7 @@ import time import itertools import json import os +import re import signal import socket import subprocess @@ -12,9 +13,9 @@ import textwrap import threading import unittest import unittest.mock -from contextlib import contextmanager, redirect_stdout, ExitStack +from contextlib import closing, contextmanager, redirect_stdout, redirect_stderr, ExitStack from pathlib import Path -from test.support import is_wasi, os_helper, requires_subprocess, SHORT_TIMEOUT +from test.support import is_wasi, cpython_only, force_color, requires_subprocess, SHORT_TIMEOUT from test.support.os_helper import temp_dir, TESTFN, unlink from typing import Dict, List, Optional, Tuple, Union, Any @@ -79,44 +80,6 @@ class MockSocketFile: return results -class MockDebuggerSocket: - """Mock file-like simulating a connection to a _RemotePdb instance""" - - def __init__(self, incoming): - self.incoming = iter(incoming) - self.outgoing = [] - self.buffered = bytearray() - - def write(self, data: bytes) -> None: - """Simulate write to socket.""" - self.buffered += data - - def flush(self) -> None: - """Ensure each line is valid JSON.""" - lines = self.buffered.splitlines(keepends=True) - self.buffered.clear() - for line in lines: - assert line.endswith(b"\n") - self.outgoing.append(json.loads(line)) - - def readline(self) -> bytes: - """Read a line from the prepared input queue.""" - # Anything written must be flushed before trying to read, - # since the read will be dependent upon the last write. - assert not self.buffered - try: - item = next(self.incoming) - if not isinstance(item, bytes): - item = json.dumps(item).encode() - return item + b"\n" - except StopIteration: - return b"" - - def close(self) -> None: - """No-op close implementation.""" - pass - - class PdbClientTestCase(unittest.TestCase): """Tests for the _PdbClient class.""" @@ -124,8 +87,11 @@ class PdbClientTestCase(unittest.TestCase): self, *, incoming, - simulate_failure=None, + simulate_send_failure=False, + simulate_sigint_during_stdout_write=False, + use_interrupt_socket=False, expected_outgoing=None, + expected_outgoing_signals=None, expected_completions=None, expected_exception=None, expected_stdout="", @@ -134,6 +100,8 @@ class PdbClientTestCase(unittest.TestCase): ): if expected_outgoing is None: expected_outgoing = [] + if expected_outgoing_signals is None: + expected_outgoing_signals = [] if expected_completions is None: expected_completions = [] if expected_state is None: @@ -142,16 +110,6 @@ class PdbClientTestCase(unittest.TestCase): expected_state.setdefault("write_failed", False) messages = [m for source, m in incoming if source == "server"] prompts = [m["prompt"] for source, m in incoming if source == "user"] - sockfile = MockDebuggerSocket(messages) - stdout = io.StringIO() - - if simulate_failure: - sockfile.write = unittest.mock.Mock() - sockfile.flush = unittest.mock.Mock() - if simulate_failure == "write": - sockfile.write.side_effect = OSError("write failed") - elif simulate_failure == "flush": - sockfile.flush.side_effect = OSError("flush failed") input_iter = (m for source, m in incoming if source == "user") completions = [] @@ -178,18 +136,60 @@ class PdbClientTestCase(unittest.TestCase): reply = message["input"] if isinstance(reply, BaseException): raise reply - return reply + if isinstance(reply, str): + return reply + return reply() with ExitStack() as stack: + client_sock, server_sock = socket.socketpair() + stack.enter_context(closing(client_sock)) + stack.enter_context(closing(server_sock)) + + server_sock = unittest.mock.Mock(wraps=server_sock) + + client_sock.sendall( + b"".join( + (m if isinstance(m, bytes) else json.dumps(m).encode()) + b"\n" + for m in messages + ) + ) + client_sock.shutdown(socket.SHUT_WR) + + if simulate_send_failure: + server_sock.sendall = unittest.mock.Mock( + side_effect=OSError("sendall failed") + ) + client_sock.shutdown(socket.SHUT_RD) + + stdout = io.StringIO() + + if simulate_sigint_during_stdout_write: + orig_stdout_write = stdout.write + + def sigint_stdout_write(s): + signal.raise_signal(signal.SIGINT) + return orig_stdout_write(s) + + stdout.write = sigint_stdout_write + input_mock = stack.enter_context( unittest.mock.patch("pdb.input", side_effect=mock_input) ) stack.enter_context(redirect_stdout(stdout)) + if use_interrupt_socket: + interrupt_sock = unittest.mock.Mock(spec=socket.socket) + mock_kill = None + else: + interrupt_sock = None + mock_kill = stack.enter_context( + unittest.mock.patch("os.kill", spec=os.kill) + ) + client = _PdbClient( - pid=0, - sockfile=sockfile, - interrupt_script="/a/b.py", + pid=12345, + server_socket=server_sock, + interrupt_sock=interrupt_sock, ) if expected_exception is not None: @@ -199,13 +199,12 @@ class PdbClientTestCase(unittest.TestCase): client.cmdloop() - actual_outgoing = sockfile.outgoing - if simulate_failure: - actual_outgoing += [ - json.loads(msg.args[0]) for msg in sockfile.write.mock_calls - ] + sent_msgs = [msg.args[0] for msg in server_sock.sendall.mock_calls] + for msg in sent_msgs: + assert msg.endswith(b"\n") + actual_outgoing = [json.loads(msg) for msg in sent_msgs] - self.assertEqual(sockfile.outgoing, expected_outgoing) + self.assertEqual(actual_outgoing, expected_outgoing) self.assertEqual(completions, expected_completions) if expected_stdout_substring and not expected_stdout: self.assertIn(expected_stdout_substring, stdout.getvalue()) @@ -215,6 +214,20 @@ class PdbClientTestCase(unittest.TestCase): actual_state = {k: getattr(client, k) for k in expected_state} self.assertEqual(actual_state, expected_state) + if use_interrupt_socket: + outgoing_signals = [ + signal.Signals(int.from_bytes(call.args[0])) + for call in interrupt_sock.sendall.call_args_list + ] + else: + assert mock_kill is not None + outgoing_signals = [] + for call in mock_kill.call_args_list: + pid, signum = call.args + self.assertEqual(pid, 12345) + outgoing_signals.append(signal.Signals(signum)) + self.assertEqual(outgoing_signals, expected_outgoing_signals) + def test_remote_immediately_closing_the_connection(self): """Test the behavior when the remote closes the connection immediately.""" incoming = [] @@ -409,11 +422,17 @@ class PdbClientTestCase(unittest.TestCase): expected_state={"state": "dumb"}, ) - def test_keyboard_interrupt_at_prompt(self): - """Test signaling when a prompt gets a KeyboardInterrupt.""" + def test_sigint_at_prompt(self): + """Test signaling when a prompt gets interrupted.""" incoming = [ ("server", {"prompt": "(Pdb) ", "state": "pdb"}), - ("user", {"prompt": "(Pdb) ", "input": KeyboardInterrupt()}), + ( + "user", + { + "prompt": "(Pdb) ", + "input": lambda: signal.raise_signal(signal.SIGINT), + }, + ), ] self.do_test( incoming=incoming, @@ -423,6 +442,43 @@ class PdbClientTestCase(unittest.TestCase): expected_state={"state": "pdb"}, ) + def test_sigint_at_continuation_prompt(self): + """Test signaling when a continuation prompt gets interrupted.""" + incoming = [ + ("server", {"prompt": "(Pdb) ", "state": "pdb"}), + ("user", {"prompt": "(Pdb) ", "input": "if True:"}), + ( + "user", + { + "prompt": "... ", + "input": lambda: signal.raise_signal(signal.SIGINT), + }, + ), + ] + self.do_test( + incoming=incoming, + expected_outgoing=[ + {"signal": "INT"}, + ], + expected_state={"state": "pdb"}, + ) + + def test_sigint_when_writing(self): + """Test siginaling when sys.stdout.write() gets interrupted.""" + incoming = [ + ("server", {"message": "Some message or other\n", "type": "info"}), + ] + for use_interrupt_socket in [False, True]: + with self.subTest(use_interrupt_socket=use_interrupt_socket): + self.do_test( + incoming=incoming, + simulate_sigint_during_stdout_write=True, + use_interrupt_socket=use_interrupt_socket, + expected_outgoing=[], + expected_outgoing_signals=[signal.SIGINT], + expected_stdout="Some message or other\n", + ) + def test_eof_at_prompt(self): """Test signaling when a prompt gets an EOFError.""" incoming = [ @@ -478,20 +534,7 @@ class PdbClientTestCase(unittest.TestCase): self.do_test( incoming=incoming, expected_outgoing=[{"signal": "INT"}], - simulate_failure="write", - expected_state={"write_failed": True}, - ) - - def test_flush_failing(self): - """Test terminating if flush fails due to a half closed socket.""" - incoming = [ - ("server", {"prompt": "(Pdb) ", "state": "pdb"}), - ("user", {"prompt": "(Pdb) ", "input": KeyboardInterrupt()}), - ] - self.do_test( - incoming=incoming, - expected_outgoing=[{"signal": "INT"}], - simulate_failure="flush", + simulate_send_failure=True, expected_state={"write_failed": True}, ) @@ -660,42 +703,7 @@ class PdbClientTestCase(unittest.TestCase): }, {"reply": "xyz"}, ], - simulate_failure="write", - expected_completions=[], - expected_state={"state": "interact", "write_failed": True}, - ) - - def test_flush_failure_during_completion(self): - """Test failing to flush to the socket to request tab completions.""" - incoming = [ - ("server", {"prompt": ">>> ", "state": "interact"}), - ( - "user", - { - "prompt": ">>> ", - "completion_request": { - "line": "xy", - "begidx": 0, - "endidx": 2, - }, - "input": "xyz", - }, - ), - ] - self.do_test( - incoming=incoming, - expected_outgoing=[ - { - "complete": { - "text": "xy", - "line": "xy", - "begidx": 0, - "endidx": 2, - } - }, - {"reply": "xyz"}, - ], - simulate_failure="flush", + simulate_send_failure=True, expected_completions=[], expected_state={"state": "interact", "write_failed": True}, ) @@ -1032,6 +1040,8 @@ class PdbConnectTestCase(unittest.TestCase): frame=frame, commands="", version=pdb._PdbServer.protocol_version(), + signal_raising_thread=False, + colorize=False, ) return x # This line won't be reached in debugging @@ -1089,23 +1099,6 @@ class PdbConnectTestCase(unittest.TestCase): client_file.write(json.dumps({"reply": command}).encode() + b"\n") client_file.flush() - def _send_interrupt(self, pid): - """Helper to send an interrupt signal to the debugger.""" - # with tempfile.NamedTemporaryFile("w", delete_on_close=False) as interrupt_script: - interrupt_script = TESTFN + "_interrupt_script.py" - with open(interrupt_script, 'w') as f: - f.write( - 'import pdb, sys\n' - 'print("Hello, world!")\n' - 'if inst := pdb.Pdb._last_pdb_instance:\n' - ' inst.set_trace(sys._getframe(1))\n' - ) - self.addCleanup(unlink, interrupt_script) - try: - sys.remote_exec(pid, interrupt_script) - except PermissionError: - self.skipTest("Insufficient permissions to execute code in remote process") - def test_connect_and_basic_commands(self): """Test connecting to a remote debugger and sending basic commands.""" self._create_script() @@ -1218,6 +1211,8 @@ class PdbConnectTestCase(unittest.TestCase): frame=frame, commands="", version=pdb._PdbServer.protocol_version(), + signal_raising_thread=True, + colorize=False, ) print("Connected to debugger") iterations = 50 @@ -1233,6 +1228,10 @@ class PdbConnectTestCase(unittest.TestCase): self._create_script(script=script) process, client_file = self._connect_and_get_client_file() + # Accept a 2nd connection from the subprocess to tell it about signals + signal_sock, _ = self.server_sock.accept() + self.addCleanup(signal_sock.close) + with kill_on_error(process): # Skip initial messages until we get to the prompt self._read_until_prompt(client_file) @@ -1248,7 +1247,7 @@ class PdbConnectTestCase(unittest.TestCase): break # Inject a script to interrupt the running process - self._send_interrupt(process.pid) + signal_sock.sendall(signal.SIGINT.to_bytes()) messages = self._read_until_prompt(client_file) # Verify we got the keyboard interrupt message. @@ -1304,6 +1303,8 @@ class PdbConnectTestCase(unittest.TestCase): frame=frame, commands="", version=fake_version, + signal_raising_thread=False, + colorize=False, ) # This should print if the debugger detaches correctly @@ -1431,5 +1432,152 @@ class PdbConnectTestCase(unittest.TestCase): self.assertIn("Function returned: 42", stdout) self.assertEqual(process.returncode, 0) + +def _supports_remote_attaching(): + from contextlib import suppress + PROCESS_VM_READV_SUPPORTED = False + + try: + from _remote_debugging import PROCESS_VM_READV_SUPPORTED + except ImportError: + pass + + return PROCESS_VM_READV_SUPPORTED + + +@unittest.skipIf(not sys.is_remote_debug_enabled(), "Remote debugging is not enabled") +@unittest.skipIf(sys.platform != "darwin" and sys.platform != "linux" and sys.platform != "win32", + "Test only runs on Linux, Windows and MacOS") +@unittest.skipIf(sys.platform == "linux" and not _supports_remote_attaching(), + "Testing on Linux requires process_vm_readv support") +@cpython_only +@requires_subprocess() +class PdbAttachTestCase(unittest.TestCase): + def setUp(self): + # Create a server socket that will wait for the debugger to connect + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock.bind(('127.0.0.1', 0)) # Let OS assign port + self.sock.listen(1) + self.port = self.sock.getsockname()[1] + self._create_script() + + def _create_script(self, script=None): + # Create a file for subprocess script + script = textwrap.dedent( + f""" + import socket + import time + + def foo(): + return bar() + + def bar(): + return baz() + + def baz(): + x = 1 + # Trigger attach + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect(('127.0.0.1', {self.port})) + sock.close() + count = 0 + while x == 1 and count < 100: + count += 1 + time.sleep(0.1) + return x + + result = foo() + print(f"Function returned: {{result}}") + """ + ) + + self.script_path = TESTFN + "_connect_test.py" + with open(self.script_path, 'w') as f: + f.write(script) + + def tearDown(self): + self.sock.close() + try: + unlink(self.script_path) + except OSError: + pass + + def do_integration_test(self, client_stdin): + process = subprocess.Popen( + [sys.executable, self.script_path], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + self.addCleanup(process.stdout.close) + self.addCleanup(process.stderr.close) + + # Wait for the process to reach our attachment point + self.sock.settimeout(10) + conn, _ = self.sock.accept() + conn.close() + + client_stdin = io.StringIO(client_stdin) + client_stdout = io.StringIO() + client_stderr = io.StringIO() + + self.addCleanup(client_stdin.close) + self.addCleanup(client_stdout.close) + self.addCleanup(client_stderr.close) + self.addCleanup(process.wait) + + with ( + unittest.mock.patch("sys.stdin", client_stdin), + redirect_stdout(client_stdout), + redirect_stderr(client_stderr), + unittest.mock.patch("sys.argv", ["pdb", "-p", str(process.pid)]), + ): + try: + pdb.main() + except PermissionError: + self.skipTest("Insufficient permissions for remote execution") + + process.wait() + server_stdout = process.stdout.read() + server_stderr = process.stderr.read() + + if process.returncode != 0: + print("server failed") + print(f"server stdout:\n{server_stdout}") + print(f"server stderr:\n{server_stderr}") + + self.assertEqual(process.returncode, 0) + return { + "client": { + "stdout": client_stdout.getvalue(), + "stderr": client_stderr.getvalue(), + }, + "server": { + "stdout": server_stdout, + "stderr": server_stderr, + }, + } + + def test_attach_to_process_without_colors(self): + with force_color(False): + output = self.do_integration_test("ll\nx=42\n") + self.assertEqual(output["client"]["stderr"], "") + self.assertEqual(output["server"]["stderr"], "") + + self.assertEqual(output["server"]["stdout"], "Function returned: 42\n") + self.assertIn("while x == 1", output["client"]["stdout"]) + self.assertNotIn("\x1b", output["client"]["stdout"]) + + def test_attach_to_process_with_colors(self): + with force_color(True): + output = self.do_integration_test("ll\nx=42\n") + self.assertEqual(output["client"]["stderr"], "") + self.assertEqual(output["server"]["stderr"], "") + + self.assertEqual(output["server"]["stdout"], "Function returned: 42\n") + self.assertIn("\x1b", output["client"]["stdout"]) + self.assertNotIn("while x == 1", output["client"]["stdout"]) + self.assertIn("while x == 1", re.sub("\x1b[^m]*m", "", output["client"]["stdout"])) + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_reprlib.py b/Lib/test/test_reprlib.py index ffeb1fba7b8..ffad35092f9 100644 --- a/Lib/test/test_reprlib.py +++ b/Lib/test/test_reprlib.py @@ -3,6 +3,7 @@ Nick Mathewson """ +import annotationlib import sys import os import shutil @@ -11,7 +12,7 @@ import importlib.util import unittest import textwrap -from test.support import verbose +from test.support import verbose, EqualToForwardRef from test.support.os_helper import create_empty_file from reprlib import repr as r # Don't shadow builtin repr from reprlib import Repr @@ -829,5 +830,19 @@ class TestRecursiveRepr(unittest.TestCase): self.assertEqual(type_params[0].__name__, 'T') self.assertEqual(type_params[0].__bound__, str) + def test_annotations(self): + class My: + @recursive_repr() + def __repr__(self, default: undefined = ...): + return default + + annotations = annotationlib.get_annotations( + My.__repr__, format=annotationlib.Format.FORWARDREF + ) + self.assertEqual( + annotations, + {'default': EqualToForwardRef("undefined", owner=My.__repr__)} + ) + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_shlex.py b/Lib/test/test_shlex.py index f35571ea886..a13ddcb76b7 100644 --- a/Lib/test/test_shlex.py +++ b/Lib/test/test_shlex.py @@ -3,6 +3,7 @@ import itertools import shlex import string import unittest +from test.support import cpython_only from test.support import import_helper @@ -364,6 +365,7 @@ class ShlexTest(unittest.TestCase): with self.assertRaises(AttributeError): shlex_instance.punctuation_chars = False + @cpython_only def test_lazy_imports(self): import_helper.ensure_lazy_imports('shlex', {'collections', 're', 'os'}) diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py index ed01163074a..87991fbda4c 100644 --- a/Lib/test/test_shutil.py +++ b/Lib/test/test_shutil.py @@ -2153,6 +2153,10 @@ class TestArchives(BaseTest, unittest.TestCase): def test_unpack_archive_bztar(self): self.check_unpack_tarball('bztar') + @support.requires_zstd() + def test_unpack_archive_zstdtar(self): + self.check_unpack_tarball('zstdtar') + @support.requires_lzma() @unittest.skipIf(AIX and not _maxdataOK(), "AIX MAXDATA must be 0x20000000 or larger") def test_unpack_archive_xztar(self): diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index ace97ce0cbe..03c54151a22 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -2,8 +2,9 @@ import unittest from unittest import mock from test import support from test.support import ( - is_apple, os_helper, refleak_helper, socket_helper, threading_helper + cpython_only, is_apple, os_helper, refleak_helper, socket_helper, threading_helper ) +from test.support.import_helper import ensure_lazy_imports import _thread as thread import array import contextlib @@ -257,6 +258,12 @@ HAVE_SOCKET_HYPERV = _have_socket_hyperv() # Size in bytes of the int type SIZEOF_INT = array.array("i").itemsize +class TestLazyImport(unittest.TestCase): + @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("socket", {"array", "selectors"}) + + class SocketTCPTest(unittest.TestCase): def setUp(self): diff --git a/Lib/test/test_sqlite3/test_cli.py b/Lib/test/test_sqlite3/test_cli.py index dcd90d11d46..ad0dcb3cccb 100644 --- a/Lib/test/test_sqlite3/test_cli.py +++ b/Lib/test/test_sqlite3/test_cli.py @@ -4,7 +4,12 @@ import unittest from sqlite3.__main__ import main as cli from test.support.os_helper import TESTFN, unlink -from test.support import captured_stdout, captured_stderr, captured_stdin +from test.support import ( + captured_stdout, + captured_stderr, + captured_stdin, + force_not_colorized, +) class CommandLineInterface(unittest.TestCase): @@ -32,6 +37,7 @@ class CommandLineInterface(unittest.TestCase): self.assertEqual(out, "") return err + @force_not_colorized def test_cli_help(self): out = self.expect_success("-h") self.assertIn("usage: ", out) diff --git a/Lib/test/test_string/test_string.py b/Lib/test/test_string/test_string.py index f6d112d8a93..5394fe4e12c 100644 --- a/Lib/test/test_string/test_string.py +++ b/Lib/test/test_string/test_string.py @@ -2,6 +2,14 @@ import unittest import string from string import Template import types +from test.support import cpython_only +from test.support.import_helper import ensure_lazy_imports + + +class LazyImportTest(unittest.TestCase): + @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("base64", {"re", "collections"}) class ModuleTest(unittest.TestCase): diff --git a/Lib/test/test_subprocess.py b/Lib/test/test_subprocess.py index 3cb755cd56c..ca35804fb36 100644 --- a/Lib/test/test_subprocess.py +++ b/Lib/test/test_subprocess.py @@ -162,6 +162,20 @@ class ProcessTestCase(BaseTestCase): [sys.executable, "-c", "while True: pass"], timeout=0.1) + def test_timeout_exception(self): + try: + subprocess.run([sys.executable, '-c', 'import time;time.sleep(9)'], timeout = -1) + except subprocess.TimeoutExpired as e: + self.assertIn("-1 seconds", str(e)) + else: + self.fail("Expected TimeoutExpired exception not raised") + try: + subprocess.run([sys.executable, '-c', 'import time;time.sleep(9)'], timeout = 0) + except subprocess.TimeoutExpired as e: + self.assertIn("0 seconds", str(e)) + else: + self.fail("Expected TimeoutExpired exception not raised") + def test_check_call_zero(self): # check_call() function with zero return code rc = subprocess.check_call(ZERO_RETURN_CMD) diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py index 468bac82924..8446da03e36 100644 --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -561,6 +561,7 @@ class TestSupport(unittest.TestCase): ['-Wignore', '-X', 'dev'], ['-X', 'faulthandler'], ['-X', 'importtime'], + ['-X', 'importtime=2'], ['-X', 'showrefcount'], ['-X', 'tracemalloc'], ['-X', 'tracemalloc=3'], diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py index 10c3e0e9a1d..59ef5c99309 100644 --- a/Lib/test/test_sys.py +++ b/Lib/test/test_sys.py @@ -1960,7 +1960,7 @@ def _supports_remote_attaching(): PROCESS_VM_READV_SUPPORTED = False try: - from _remotedebuggingmodule import PROCESS_VM_READV_SUPPORTED + from _remote_debugging import PROCESS_VM_READV_SUPPORTED except ImportError: pass @@ -2101,7 +2101,7 @@ print("Remote script executed successfully!") prologue = '''\ import sys def audit_hook(event, arg): - print(f"Audit event: {event}, arg: {arg}") + print(f"Audit event: {event}, arg: {arg}".encode("ascii", errors="replace")) sys.addaudithook(audit_hook) ''' script = ''' @@ -2196,6 +2196,64 @@ this is invalid python code self.assertIn(b"Remote debugging is not enabled", err) self.assertEqual(out, b"") +class TestSysJIT(unittest.TestCase): + + def test_jit_is_available(self): + available = sys._jit.is_available() + script = f"import sys; assert sys._jit.is_available() is {available}" + assert_python_ok("-c", script, PYTHON_JIT="0") + assert_python_ok("-c", script, PYTHON_JIT="1") + + def test_jit_is_enabled(self): + available = sys._jit.is_available() + script = "import sys; assert sys._jit.is_enabled() is {enabled}" + assert_python_ok("-c", script.format(enabled=False), PYTHON_JIT="0") + assert_python_ok("-c", script.format(enabled=available), PYTHON_JIT="1") + + def test_jit_is_active(self): + available = sys._jit.is_available() + script = textwrap.dedent( + """ + import _testcapi + import _testinternalcapi + import sys + + def frame_0_interpreter() -> None: + assert sys._jit.is_active() is False + + def frame_1_interpreter() -> None: + assert sys._jit.is_active() is False + frame_0_interpreter() + assert sys._jit.is_active() is False + + def frame_2_jit(expected: bool) -> None: + # Inlined into the last loop of frame_3_jit: + assert sys._jit.is_active() is expected + # Insert C frame: + _testcapi.pyobject_vectorcall(frame_1_interpreter, None, None) + assert sys._jit.is_active() is expected + + def frame_3_jit() -> None: + # JITs just before the last loop: + for i in range(_testinternalcapi.TIER2_THRESHOLD + 1): + # Careful, doing this in the reverse order breaks tracing: + expected = {enabled} and i == _testinternalcapi.TIER2_THRESHOLD + assert sys._jit.is_active() is expected + frame_2_jit(expected) + assert sys._jit.is_active() is expected + + def frame_4_interpreter() -> None: + assert sys._jit.is_active() is False + frame_3_jit() + assert sys._jit.is_active() is False + + assert sys._jit.is_active() is False + frame_4_interpreter() + assert sys._jit.is_active() is False + """ + ) + assert_python_ok("-c", script.format(enabled=False), PYTHON_JIT="0") + assert_python_ok("-c", script.format(enabled=available), PYTHON_JIT="1") if __name__ == "__main__": diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py index fcbaf854cc2..2d9649237a9 100644 --- a/Lib/test/test_tarfile.py +++ b/Lib/test/test_tarfile.py @@ -38,6 +38,10 @@ try: import lzma except ImportError: lzma = None +try: + from compression import zstd +except ImportError: + zstd = None def sha256sum(data): return sha256(data).hexdigest() @@ -48,6 +52,7 @@ tarname = support.findfile("testtar.tar", subdir="archivetestdata") gzipname = os.path.join(TEMPDIR, "testtar.tar.gz") bz2name = os.path.join(TEMPDIR, "testtar.tar.bz2") xzname = os.path.join(TEMPDIR, "testtar.tar.xz") +zstname = os.path.join(TEMPDIR, "testtar.tar.zst") tmpname = os.path.join(TEMPDIR, "tmp.tar") dotlessname = os.path.join(TEMPDIR, "testtar") @@ -90,6 +95,12 @@ class LzmaTest: open = lzma.LZMAFile if lzma else None taropen = tarfile.TarFile.xzopen +@support.requires_zstd() +class ZstdTest: + tarname = zstname + suffix = 'zst' + open = zstd.ZstdFile if zstd else None + taropen = tarfile.TarFile.zstopen class ReadTest(TarTest): @@ -271,6 +282,8 @@ class Bz2UstarReadTest(Bz2Test, UstarReadTest): class LzmaUstarReadTest(LzmaTest, UstarReadTest): pass +class ZstdUstarReadTest(ZstdTest, UstarReadTest): + pass class ListTest(ReadTest, unittest.TestCase): @@ -375,6 +388,8 @@ class Bz2ListTest(Bz2Test, ListTest): class LzmaListTest(LzmaTest, ListTest): pass +class ZstdListTest(ZstdTest, ListTest): + pass class CommonReadTest(ReadTest): @@ -837,6 +852,8 @@ class Bz2MiscReadTest(Bz2Test, MiscReadTestBase, unittest.TestCase): class LzmaMiscReadTest(LzmaTest, MiscReadTestBase, unittest.TestCase): pass +class ZstdMiscReadTest(ZstdTest, MiscReadTestBase, unittest.TestCase): + pass class StreamReadTest(CommonReadTest, unittest.TestCase): @@ -909,6 +926,9 @@ class Bz2StreamReadTest(Bz2Test, StreamReadTest): class LzmaStreamReadTest(LzmaTest, StreamReadTest): pass +class ZstdStreamReadTest(ZstdTest, StreamReadTest): + pass + class TarStreamModeReadTest(StreamModeTest, unittest.TestCase): def test_stream_mode_no_cache(self): @@ -925,6 +945,9 @@ class Bz2StreamModeReadTest(Bz2Test, TarStreamModeReadTest): class LzmaStreamModeReadTest(LzmaTest, TarStreamModeReadTest): pass +class ZstdStreamModeReadTest(ZstdTest, TarStreamModeReadTest): + pass + class DetectReadTest(TarTest, unittest.TestCase): def _testfunc_file(self, name, mode): try: @@ -986,6 +1009,8 @@ class Bz2DetectReadTest(Bz2Test, DetectReadTest): class LzmaDetectReadTest(LzmaTest, DetectReadTest): pass +class ZstdDetectReadTest(ZstdTest, DetectReadTest): + pass class GzipBrokenHeaderCorrectException(GzipTest, unittest.TestCase): """ @@ -1666,6 +1691,8 @@ class Bz2WriteTest(Bz2Test, WriteTest): class LzmaWriteTest(LzmaTest, WriteTest): pass +class ZstdWriteTest(ZstdTest, WriteTest): + pass class StreamWriteTest(WriteTestBase, unittest.TestCase): @@ -1727,6 +1754,9 @@ class Bz2StreamWriteTest(Bz2Test, StreamWriteTest): class LzmaStreamWriteTest(LzmaTest, StreamWriteTest): decompressor = lzma.LZMADecompressor if lzma else None +class ZstdStreamWriteTest(ZstdTest, StreamWriteTest): + decompressor = zstd.ZstdDecompressor if zstd else None + class _CompressedWriteTest(TarTest): # This is not actually a standalone test. # It does not inherit WriteTest because it only makes sense with gz,bz2 @@ -2042,6 +2072,14 @@ class LzmaCreateTest(LzmaTest, CreateTest): tobj.add(self.file_path) +class ZstdCreateTest(ZstdTest, CreateTest): + + # Unlike gz and bz2, zstd uses the level keyword instead of compresslevel. + # It does not allow for level to be specified when reading. + def test_create_with_level(self): + with tarfile.open(tmpname, self.mode, level=1) as tobj: + tobj.add(self.file_path) + class CreateWithXModeTest(CreateTest): prefix = "x" @@ -2523,6 +2561,8 @@ class Bz2AppendTest(Bz2Test, AppendTestBase, unittest.TestCase): class LzmaAppendTest(LzmaTest, AppendTestBase, unittest.TestCase): pass +class ZstdAppendTest(ZstdTest, AppendTestBase, unittest.TestCase): + pass class LimitsTest(unittest.TestCase): @@ -2835,7 +2875,7 @@ class CommandLineTest(unittest.TestCase): support.findfile('tokenize_tests-no-coding-cookie-' 'and-utf8-bom-sig-only.txt', subdir='tokenizedata')] - for filetype in (GzipTest, Bz2Test, LzmaTest): + for filetype in (GzipTest, Bz2Test, LzmaTest, ZstdTest): if not filetype.open: continue try: @@ -4257,7 +4297,7 @@ def setUpModule(): data = fobj.read() # Create compressed tarfiles. - for c in GzipTest, Bz2Test, LzmaTest: + for c in GzipTest, Bz2Test, LzmaTest, ZstdTest: if c.open: os_helper.unlink(c.tarname) testtarnames.append(c.tarname) diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py index 814c00ca0fd..4ab38c2598b 100644 --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -5,7 +5,7 @@ Tests for the threading module. import test.support from test.support import threading_helper, requires_subprocess, requires_gil_enabled from test.support import verbose, cpython_only, os_helper -from test.support.import_helper import import_module +from test.support.import_helper import ensure_lazy_imports, import_module from test.support.script_helper import assert_python_ok, assert_python_failure from test.support import force_not_colorized @@ -121,6 +121,10 @@ class ThreadTests(BaseTestCase): maxDiff = 9999 @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("threading", {"functools", "warnings"}) + + @cpython_only def test_name(self): def func(): pass diff --git a/Lib/test/test_traceback.py b/Lib/test/test_traceback.py index 683486e9aca..b9be87f357f 100644 --- a/Lib/test/test_traceback.py +++ b/Lib/test/test_traceback.py @@ -37,6 +37,12 @@ test_code.co_positions = lambda _: iter([(6, 6, 0, 0)]) test_frame = namedtuple('frame', ['f_code', 'f_globals', 'f_locals']) test_tb = namedtuple('tb', ['tb_frame', 'tb_lineno', 'tb_next', 'tb_lasti']) +color_overrides = {"reset": "z", "filename": "fn", "error_highlight": "E"} +colors = { + color_overrides.get(k, k[0].lower()): v + for k, v in _colorize.default_theme.traceback.items() +} + LEVENSHTEIN_DATA_FILE = Path(__file__).parent / 'levenshtein_examples.json' @@ -4721,6 +4727,8 @@ class MiscTest(unittest.TestCase): class TestColorizedTraceback(unittest.TestCase): + maxDiff = None + def test_colorized_traceback(self): def foo(*args): x = {'a':{'b': None}} @@ -4743,9 +4751,9 @@ class TestColorizedTraceback(unittest.TestCase): e, capture_locals=True ) lines = "".join(exc.format(colorize=True)) - red = _colorize.ANSIColors.RED - boldr = _colorize.ANSIColors.BOLD_RED - reset = _colorize.ANSIColors.RESET + red = colors["e"] + boldr = colors["E"] + reset = colors["z"] self.assertIn("y = " + red + "x['a']['b']" + reset + boldr + "['c']" + reset, lines) self.assertIn("return " + red + "(lambda *args: foo(*args))" + reset + boldr + "(1,2,3,4)" + reset, lines) self.assertIn("return (lambda *args: " + red + "foo" + reset + boldr + "(*args)" + reset + ")(1,2,3,4)", lines) @@ -4761,18 +4769,16 @@ class TestColorizedTraceback(unittest.TestCase): e, capture_locals=True ) actual = "".join(exc.format(colorize=True)) - red = _colorize.ANSIColors.RED - magenta = _colorize.ANSIColors.MAGENTA - boldm = _colorize.ANSIColors.BOLD_MAGENTA - boldr = _colorize.ANSIColors.BOLD_RED - reset = _colorize.ANSIColors.RESET - expected = "".join([ - f' File {magenta}"<string>"{reset}, line {magenta}1{reset}\n', - f' a {boldr}${reset} b\n', - f' {boldr}^{reset}\n', - f'{boldm}SyntaxError{reset}: {magenta}invalid syntax{reset}\n'] - ) - self.assertIn(expected, actual) + def expected(t, m, fn, l, f, E, e, z): + return "".join( + [ + f' File {fn}"<string>"{z}, line {l}1{z}\n', + f' a {E}${z} b\n', + f' {E}^{z}\n', + f'{t}SyntaxError{z}: {m}invalid syntax{z}\n' + ] + ) + self.assertIn(expected(**colors), actual) def test_colorized_traceback_is_the_default(self): def foo(): @@ -4788,23 +4794,21 @@ class TestColorizedTraceback(unittest.TestCase): exception_print(e) actual = tbstderr.getvalue().splitlines() - red = _colorize.ANSIColors.RED - boldr = _colorize.ANSIColors.BOLD_RED - magenta = _colorize.ANSIColors.MAGENTA - boldm = _colorize.ANSIColors.BOLD_MAGENTA - reset = _colorize.ANSIColors.RESET lno_foo = foo.__code__.co_firstlineno - expected = ['Traceback (most recent call last):', - f' File {magenta}"{__file__}"{reset}, ' - f'line {magenta}{lno_foo+5}{reset}, in {magenta}test_colorized_traceback_is_the_default{reset}', - f' {red}foo{reset+boldr}(){reset}', - f' {red}~~~{reset+boldr}^^{reset}', - f' File {magenta}"{__file__}"{reset}, ' - f'line {magenta}{lno_foo+1}{reset}, in {magenta}foo{reset}', - f' {red}1{reset+boldr}/{reset+red}0{reset}', - f' {red}~{reset+boldr}^{reset+red}~{reset}', - f'{boldm}ZeroDivisionError{reset}: {magenta}division by zero{reset}'] - self.assertEqual(actual, expected) + def expected(t, m, fn, l, f, E, e, z): + return [ + 'Traceback (most recent call last):', + f' File {fn}"{__file__}"{z}, ' + f'line {l}{lno_foo+5}{z}, in {f}test_colorized_traceback_is_the_default{z}', + f' {e}foo{z}{E}(){z}', + f' {e}~~~{z}{E}^^{z}', + f' File {fn}"{__file__}"{z}, ' + f'line {l}{lno_foo+1}{z}, in {f}foo{z}', + f' {e}1{z}{E}/{z}{e}0{z}', + f' {e}~{z}{E}^{z}{e}~{z}', + f'{t}ZeroDivisionError{z}: {m}division by zero{z}', + ] + self.assertEqual(actual, expected(**colors)) def test_colorized_traceback_from_exception_group(self): def foo(): @@ -4822,33 +4826,31 @@ class TestColorizedTraceback(unittest.TestCase): e, capture_locals=True ) - red = _colorize.ANSIColors.RED - boldr = _colorize.ANSIColors.BOLD_RED - magenta = _colorize.ANSIColors.MAGENTA - boldm = _colorize.ANSIColors.BOLD_MAGENTA - reset = _colorize.ANSIColors.RESET lno_foo = foo.__code__.co_firstlineno actual = "".join(exc.format(colorize=True)).splitlines() - expected = [f" + Exception Group Traceback (most recent call last):", - f' | File {magenta}"{__file__}"{reset}, line {magenta}{lno_foo+9}{reset}, in {magenta}test_colorized_traceback_from_exception_group{reset}', - f' | {red}foo{reset}{boldr}(){reset}', - f' | {red}~~~{reset}{boldr}^^{reset}', - f" | e = ExceptionGroup('test', [ZeroDivisionError('division by zero')])", - f" | foo = {foo}", - f' | self = <{__name__}.TestColorizedTraceback testMethod=test_colorized_traceback_from_exception_group>', - f' | File {magenta}"{__file__}"{reset}, line {magenta}{lno_foo+6}{reset}, in {magenta}foo{reset}', - f' | raise ExceptionGroup("test", exceptions)', - f" | exceptions = [ZeroDivisionError('division by zero')]", - f' | {boldm}ExceptionGroup{reset}: {magenta}test (1 sub-exception){reset}', - f' +-+---------------- 1 ----------------', - f' | Traceback (most recent call last):', - f' | File {magenta}"{__file__}"{reset}, line {magenta}{lno_foo+3}{reset}, in {magenta}foo{reset}', - f' | {red}1 {reset}{boldr}/{reset}{red} 0{reset}', - f' | {red}~~{reset}{boldr}^{reset}{red}~~{reset}', - f" | exceptions = [ZeroDivisionError('division by zero')]", - f' | {boldm}ZeroDivisionError{reset}: {magenta}division by zero{reset}', - f' +------------------------------------'] - self.assertEqual(actual, expected) + def expected(t, m, fn, l, f, E, e, z): + return [ + f" + Exception Group Traceback (most recent call last):", + f' | File {fn}"{__file__}"{z}, line {l}{lno_foo+9}{z}, in {f}test_colorized_traceback_from_exception_group{z}', + f' | {e}foo{z}{E}(){z}', + f' | {e}~~~{z}{E}^^{z}', + f" | e = ExceptionGroup('test', [ZeroDivisionError('division by zero')])", + f" | foo = {foo}", + f' | self = <{__name__}.TestColorizedTraceback testMethod=test_colorized_traceback_from_exception_group>', + f' | File {fn}"{__file__}"{z}, line {l}{lno_foo+6}{z}, in {f}foo{z}', + f' | raise ExceptionGroup("test", exceptions)', + f" | exceptions = [ZeroDivisionError('division by zero')]", + f' | {t}ExceptionGroup{z}: {m}test (1 sub-exception){z}', + f' +-+---------------- 1 ----------------', + f' | Traceback (most recent call last):', + f' | File {fn}"{__file__}"{z}, line {l}{lno_foo+3}{z}, in {f}foo{z}', + f' | {e}1 {z}{E}/{z}{e} 0{z}', + f' | {e}~~{z}{E}^{z}{e}~~{z}', + f" | exceptions = [ZeroDivisionError('division by zero')]", + f' | {t}ZeroDivisionError{z}: {m}division by zero{z}', + f' +------------------------------------', + ] + self.assertEqual(actual, expected(**colors)) if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_urllib.py b/Lib/test/test_urllib.py index 90de828cc71..c965860fbb1 100644 --- a/Lib/test/test_urllib.py +++ b/Lib/test/test_urllib.py @@ -1551,7 +1551,8 @@ class Pathname_Tests(unittest.TestCase): urllib.request.url2pathname(url, require_scheme=True), expected_path) - error_subtests = [ + def test_url2pathname_require_scheme_errors(self): + subtests = [ '', ':', 'foo', @@ -1561,13 +1562,20 @@ class Pathname_Tests(unittest.TestCase): 'data:file:foo', 'data:file://foo', ] - for url in error_subtests: + for url in subtests: with self.subTest(url=url): self.assertRaises( urllib.error.URLError, urllib.request.url2pathname, url, require_scheme=True) + def test_url2pathname_resolve_host(self): + fn = urllib.request.url2pathname + sep = os.path.sep + self.assertEqual(fn('//127.0.0.1/foo/bar', resolve_host=True), f'{sep}foo{sep}bar') + self.assertEqual(fn(f'//{socket.gethostname()}/foo/bar'), f'{sep}foo{sep}bar') + self.assertEqual(fn(f'//{socket.gethostname()}/foo/bar', resolve_host=True), f'{sep}foo{sep}bar') + @unittest.skipUnless(sys.platform == 'win32', 'test specific to Windows pathnames.') def test_url2pathname_win(self): @@ -1598,6 +1606,7 @@ class Pathname_Tests(unittest.TestCase): self.assertEqual(fn('//server/path/to/file'), '\\\\server\\path\\to\\file') self.assertEqual(fn('////server/path/to/file'), '\\\\server\\path\\to\\file') self.assertEqual(fn('/////server/path/to/file'), '\\\\server\\path\\to\\file') + self.assertEqual(fn('//127.0.0.1/path/to/file'), '\\\\127.0.0.1\\path\\to\\file') # Localhost paths self.assertEqual(fn('//localhost/C:/path/to/file'), 'C:\\path\\to\\file') self.assertEqual(fn('//localhost/C|/path/to/file'), 'C:\\path\\to\\file') @@ -1622,8 +1631,7 @@ class Pathname_Tests(unittest.TestCase): self.assertRaises(urllib.error.URLError, fn, '//:80/foo/bar') self.assertRaises(urllib.error.URLError, fn, '//:/foo/bar') self.assertRaises(urllib.error.URLError, fn, '//c:80/foo/bar') - self.assertEqual(fn('//127.0.0.1/foo/bar'), '/foo/bar') - self.assertEqual(fn(f'//{socket.gethostname()}/foo/bar'), '/foo/bar') + self.assertRaises(urllib.error.URLError, fn, '//127.0.0.1/foo/bar') @unittest.skipUnless(os_helper.FS_NONASCII, 'need os_helper.FS_NONASCII') def test_url2pathname_nonascii(self): diff --git a/Lib/test/test_zipfile/test_core.py b/Lib/test/test_zipfile/test_core.py index 7c8a82d821a..ae898150658 100644 --- a/Lib/test/test_zipfile/test_core.py +++ b/Lib/test/test_zipfile/test_core.py @@ -23,11 +23,13 @@ from test import archiver_tests from test.support import script_helper, os_helper from test.support import ( findfile, requires_zlib, requires_bz2, requires_lzma, - captured_stdout, captured_stderr, requires_subprocess, + requires_zstd, captured_stdout, captured_stderr, requires_subprocess, + cpython_only ) from test.support.os_helper import ( TESTFN, unlink, rmtree, temp_dir, temp_cwd, fd_count, FakePath ) +from test.support.import_helper import ensure_lazy_imports TESTFN2 = TESTFN + "2" @@ -49,6 +51,13 @@ def get_files(test): yield f test.assertFalse(f.closed) + +class LazyImportTest(unittest.TestCase): + @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("zipfile", {"typing"}) + + class AbstractTestsWithSourceFile: @classmethod def setUpClass(cls): @@ -693,6 +702,10 @@ class LzmaTestsWithSourceFile(AbstractTestsWithSourceFile, unittest.TestCase): compression = zipfile.ZIP_LZMA +@requires_zstd() +class ZstdTestsWithSourceFile(AbstractTestsWithSourceFile, + unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD class AbstractTestZip64InSmallFiles: # These tests test the ZIP64 functionality without using large files, @@ -1270,6 +1283,10 @@ class LzmaTestZip64InSmallFiles(AbstractTestZip64InSmallFiles, unittest.TestCase): compression = zipfile.ZIP_LZMA +@requires_zstd() +class ZstdTestZip64InSmallFiles(AbstractTestZip64InSmallFiles, + unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD class AbstractWriterTests: @@ -1339,6 +1356,9 @@ class Bzip2WriterTests(AbstractWriterTests, unittest.TestCase): class LzmaWriterTests(AbstractWriterTests, unittest.TestCase): compression = zipfile.ZIP_LZMA +@requires_zstd() +class ZstdWriterTests(AbstractWriterTests, unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD class PyZipFileTests(unittest.TestCase): def assertCompiledIn(self, name, namelist): @@ -2669,6 +2689,17 @@ class LzmaBadCrcTests(AbstractBadCrcTests, unittest.TestCase): b'ePK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00' b'\x00>\x00\x00\x00\x00\x00') +@requires_zstd() +class ZstdBadCrcTests(AbstractBadCrcTests, unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD + zip_with_bad_crc = ( + b'PK\x03\x04?\x00\x00\x00]\x00\x00\x00!\x00V\xb1\x17J\x14\x00' + b'\x00\x00\x0b\x00\x00\x00\x05\x00\x00\x00afile(\xb5/\xfd\x00' + b'XY\x00\x00Hello WorldPK\x01\x02?\x03?\x00\x00\x00]\x00\x00\x00' + b'!\x00V\xb0\x17J\x14\x00\x00\x00\x0b\x00\x00\x00\x05\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00\x00afilePK' + b'\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00\x007\x00\x00\x00' + b'\x00\x00') class DecryptionTests(unittest.TestCase): """Check that ZIP decryption works. Since the library does not @@ -2896,6 +2927,10 @@ class LzmaTestsWithRandomBinaryFiles(AbstractTestsWithRandomBinaryFiles, unittest.TestCase): compression = zipfile.ZIP_LZMA +@requires_zstd() +class ZstdTestsWithRandomBinaryFiles(AbstractTestsWithRandomBinaryFiles, + unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD # Provide the tell() method but not seek() class Tellable: diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py new file mode 100644 index 00000000000..f4a25376e52 --- /dev/null +++ b/Lib/test/test_zstd.py @@ -0,0 +1,2507 @@ +import array +import gc +import io +import pathlib +import random +import re +import os +import unittest +import tempfile +import threading + +from test.support.import_helper import import_module +from test.support import threading_helper +from test.support import _1M +from test.support import Py_GIL_DISABLED + +_zstd = import_module("_zstd") +zstd = import_module("compression.zstd") + +from compression.zstd import ( + open, + compress, + decompress, + ZstdCompressor, + ZstdDecompressor, + ZstdDict, + ZstdError, + zstd_version, + zstd_version_info, + COMPRESSION_LEVEL_DEFAULT, + get_frame_info, + get_frame_size, + finalize_dict, + train_dict, + CompressionParameter, + DecompressionParameter, + Strategy, + ZstdFile, +) + +_1K = 1024 +_130_1K = 130 * _1K +DICT_SIZE1 = 3*_1K + +DAT_130K_D = None +DAT_130K_C = None + +DECOMPRESSED_DAT = None +COMPRESSED_DAT = None + +DECOMPRESSED_100_PLUS_32KB = None +COMPRESSED_100_PLUS_32KB = None + +SKIPPABLE_FRAME = None + +THIS_FILE_BYTES = None +THIS_FILE_STR = None +COMPRESSED_THIS_FILE = None + +COMPRESSED_BOGUS = None + +SAMPLES = None + +TRAINED_DICT = None + +SUPPORT_MULTITHREADING = False + +def setUpModule(): + global SUPPORT_MULTITHREADING + SUPPORT_MULTITHREADING = CompressionParameter.nb_workers.bounds() != (0, 0) + # uncompressed size 130KB, more than a zstd block. + # with a frame epilogue, 4 bytes checksum. + global DAT_130K_D + DAT_130K_D = bytes([random.randint(0, 127) for _ in range(130*_1K)]) + + global DAT_130K_C + DAT_130K_C = compress(DAT_130K_D, options={CompressionParameter.checksum_flag:1}) + + global DECOMPRESSED_DAT + DECOMPRESSED_DAT = b'abcdefg123456' * 1000 + + global COMPRESSED_DAT + COMPRESSED_DAT = compress(DECOMPRESSED_DAT) + + global DECOMPRESSED_100_PLUS_32KB + DECOMPRESSED_100_PLUS_32KB = b'a' * (100 + 32*_1K) + + global COMPRESSED_100_PLUS_32KB + COMPRESSED_100_PLUS_32KB = compress(DECOMPRESSED_100_PLUS_32KB) + + global SKIPPABLE_FRAME + SKIPPABLE_FRAME = (0x184D2A50).to_bytes(4, byteorder='little') + \ + (32*_1K).to_bytes(4, byteorder='little') + \ + b'a' * (32*_1K) + + global THIS_FILE_BYTES, THIS_FILE_STR + with io.open(os.path.abspath(__file__), 'rb') as f: + THIS_FILE_BYTES = f.read() + THIS_FILE_BYTES = re.sub(rb'\r?\n', rb'\n', THIS_FILE_BYTES) + THIS_FILE_STR = THIS_FILE_BYTES.decode('utf-8') + + global COMPRESSED_THIS_FILE + COMPRESSED_THIS_FILE = compress(THIS_FILE_BYTES) + + global COMPRESSED_BOGUS + COMPRESSED_BOGUS = DECOMPRESSED_DAT + + # dict data + words = [b'red', b'green', b'yellow', b'black', b'withe', b'blue', + b'lilac', b'purple', b'navy', b'glod', b'silver', b'olive', + b'dog', b'cat', b'tiger', b'lion', b'fish', b'bird'] + lst = [] + for i in range(300): + sample = [b'%s = %d' % (random.choice(words), random.randrange(100)) + for j in range(20)] + sample = b'\n'.join(sample) + + lst.append(sample) + global SAMPLES + SAMPLES = lst + assert len(SAMPLES) > 10 + + global TRAINED_DICT + TRAINED_DICT = train_dict(SAMPLES, 3*_1K) + assert len(TRAINED_DICT.dict_content) <= 3*_1K + + +class FunctionsTestCase(unittest.TestCase): + + def test_version(self): + s = ".".join((str(i) for i in zstd_version_info)) + self.assertEqual(s, zstd_version) + + def test_compressionLevel_values(self): + min, max = CompressionParameter.compression_level.bounds() + self.assertIs(type(COMPRESSION_LEVEL_DEFAULT), int) + self.assertIs(type(min), int) + self.assertIs(type(max), int) + self.assertLess(min, max) + + def test_roundtrip_default(self): + raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6] + dat1 = compress(raw_dat) + dat2 = decompress(dat1) + self.assertEqual(dat2, raw_dat) + + def test_roundtrip_level(self): + raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6] + level_min, level_max = CompressionParameter.compression_level.bounds() + + for level in range(max(-20, level_min), level_max + 1): + dat1 = compress(raw_dat, level) + dat2 = decompress(dat1) + self.assertEqual(dat2, raw_dat) + + def test_get_frame_info(self): + # no dict + info = get_frame_info(COMPRESSED_100_PLUS_32KB[:20]) + self.assertEqual(info.decompressed_size, 32 * _1K + 100) + self.assertEqual(info.dictionary_id, 0) + + # use dict + dat = compress(b"a" * 345, zstd_dict=TRAINED_DICT) + info = get_frame_info(dat) + self.assertEqual(info.decompressed_size, 345) + self.assertEqual(info.dictionary_id, TRAINED_DICT.dict_id) + + with self.assertRaisesRegex(ZstdError, "not less than the frame header"): + get_frame_info(b"aaaaaaaaaaaaaa") + + def test_get_frame_size(self): + size = get_frame_size(COMPRESSED_100_PLUS_32KB) + self.assertEqual(size, len(COMPRESSED_100_PLUS_32KB)) + + with self.assertRaisesRegex(ZstdError, "not less than this complete frame"): + get_frame_size(b"aaaaaaaaaaaaaa") + + def test_decompress_2x130_1K(self): + decompressed_size = get_frame_info(DAT_130K_C).decompressed_size + self.assertEqual(decompressed_size, _130_1K) + + dat = decompress(DAT_130K_C + DAT_130K_C) + self.assertEqual(len(dat), 2 * _130_1K) + + +class CompressorTestCase(unittest.TestCase): + + def test_simple_compress_bad_args(self): + # ZstdCompressor + self.assertRaises(TypeError, ZstdCompressor, []) + self.assertRaises(TypeError, ZstdCompressor, level=3.14) + self.assertRaises(TypeError, ZstdCompressor, level="abc") + self.assertRaises(TypeError, ZstdCompressor, options=b"abc") + + self.assertRaises(TypeError, ZstdCompressor, zstd_dict=123) + self.assertRaises(TypeError, ZstdCompressor, zstd_dict=b"abcd1234") + self.assertRaises(TypeError, ZstdCompressor, zstd_dict={1: 2, 3: 4}) + + with self.assertRaises(ValueError): + ZstdCompressor(2**31) + with self.assertRaises(ValueError): + ZstdCompressor(options={2**31: 100}) + + with self.assertRaises(ZstdError): + ZstdCompressor(options={CompressionParameter.window_log: 100}) + with self.assertRaises(ZstdError): + ZstdCompressor(options={3333: 100}) + + # Method bad arguments + zc = ZstdCompressor() + self.assertRaises(TypeError, zc.compress) + self.assertRaises((TypeError, ValueError), zc.compress, b"foo", b"bar") + self.assertRaises(TypeError, zc.compress, "str") + self.assertRaises((TypeError, ValueError), zc.flush, b"foo") + self.assertRaises(TypeError, zc.flush, b"blah", 1) + + self.assertRaises(ValueError, zc.compress, b'', -1) + self.assertRaises(ValueError, zc.compress, b'', 3) + self.assertRaises(ValueError, zc.flush, zc.CONTINUE) # 0 + self.assertRaises(ValueError, zc.flush, 3) + + zc.compress(b'') + zc.compress(b'', zc.CONTINUE) + zc.compress(b'', zc.FLUSH_BLOCK) + zc.compress(b'', zc.FLUSH_FRAME) + empty = zc.flush() + zc.flush(zc.FLUSH_BLOCK) + zc.flush(zc.FLUSH_FRAME) + + def test_compress_parameters(self): + d = {CompressionParameter.compression_level : 10, + + CompressionParameter.window_log : 12, + CompressionParameter.hash_log : 10, + CompressionParameter.chain_log : 12, + CompressionParameter.search_log : 12, + CompressionParameter.min_match : 4, + CompressionParameter.target_length : 12, + CompressionParameter.strategy : Strategy.lazy, + + CompressionParameter.enable_long_distance_matching : 1, + CompressionParameter.ldm_hash_log : 12, + CompressionParameter.ldm_min_match : 11, + CompressionParameter.ldm_bucket_size_log : 5, + CompressionParameter.ldm_hash_rate_log : 12, + + CompressionParameter.content_size_flag : 1, + CompressionParameter.checksum_flag : 1, + CompressionParameter.dict_id_flag : 0, + + CompressionParameter.nb_workers : 2 if SUPPORT_MULTITHREADING else 0, + CompressionParameter.job_size : 5*_1M if SUPPORT_MULTITHREADING else 0, + CompressionParameter.overlap_log : 9 if SUPPORT_MULTITHREADING else 0, + } + ZstdCompressor(options=d) + + # larger than signed int, ValueError + d1 = d.copy() + d1[CompressionParameter.ldm_bucket_size_log] = 2**31 + self.assertRaises(ValueError, ZstdCompressor, options=d1) + + # clamp compressionLevel + level_min, level_max = CompressionParameter.compression_level.bounds() + compress(b'', level_max+1) + compress(b'', level_min-1) + + compress(b'', options={CompressionParameter.compression_level:level_max+1}) + compress(b'', options={CompressionParameter.compression_level:level_min-1}) + + # zstd lib doesn't support MT compression + if not SUPPORT_MULTITHREADING: + with self.assertRaises(ZstdError): + ZstdCompressor(options={CompressionParameter.nb_workers:4}) + with self.assertRaises(ZstdError): + ZstdCompressor(options={CompressionParameter.job_size:4}) + with self.assertRaises(ZstdError): + ZstdCompressor(options={CompressionParameter.overlap_log:4}) + + # out of bounds error msg + option = {CompressionParameter.window_log:100} + with self.assertRaisesRegex(ZstdError, + (r'Error when setting zstd compression parameter "window_log", ' + r'it should \d+ <= value <= \d+, provided value is 100\. ' + r'\(zstd v\d\.\d\.\d, (?:32|64)-bit build\)')): + compress(b'', options=option) + + def test_unknown_compression_parameter(self): + KEY = 100001234 + option = {CompressionParameter.compression_level: 10, + KEY: 200000000} + pattern = r'Zstd compression parameter.*?"unknown parameter \(key %d\)"' \ + % KEY + with self.assertRaisesRegex(ZstdError, pattern): + ZstdCompressor(options=option) + + @unittest.skipIf(not SUPPORT_MULTITHREADING, + "zstd build doesn't support multi-threaded compression") + def test_zstd_multithread_compress(self): + size = 40*_1M + b = THIS_FILE_BYTES * (size // len(THIS_FILE_BYTES)) + + options = {CompressionParameter.compression_level : 4, + CompressionParameter.nb_workers : 2} + + # compress() + dat1 = compress(b, options=options) + dat2 = decompress(dat1) + self.assertEqual(dat2, b) + + # ZstdCompressor + c = ZstdCompressor(options=options) + dat1 = c.compress(b, c.CONTINUE) + dat2 = c.compress(b, c.FLUSH_BLOCK) + dat3 = c.compress(b, c.FLUSH_FRAME) + dat4 = decompress(dat1+dat2+dat3) + self.assertEqual(dat4, b * 3) + + # ZstdFile + with ZstdFile(io.BytesIO(), 'w', options=options) as f: + f.write(b) + + def test_compress_flushblock(self): + point = len(THIS_FILE_BYTES) // 2 + + c = ZstdCompressor() + self.assertEqual(c.last_mode, c.FLUSH_FRAME) + dat1 = c.compress(THIS_FILE_BYTES[:point]) + self.assertEqual(c.last_mode, c.CONTINUE) + dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_BLOCK) + self.assertEqual(c.last_mode, c.FLUSH_BLOCK) + dat2 = c.flush() + pattern = "Compressed data ended before the end-of-stream marker" + with self.assertRaisesRegex(ZstdError, pattern): + decompress(dat1) + + dat3 = decompress(dat1 + dat2) + + self.assertEqual(dat3, THIS_FILE_BYTES) + + def test_compress_flushframe(self): + # test compress & decompress + point = len(THIS_FILE_BYTES) // 2 + + c = ZstdCompressor() + + dat1 = c.compress(THIS_FILE_BYTES[:point]) + self.assertEqual(c.last_mode, c.CONTINUE) + + dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_FRAME) + self.assertEqual(c.last_mode, c.FLUSH_FRAME) + + nt = get_frame_info(dat1) + self.assertEqual(nt.decompressed_size, None) # no content size + + dat2 = decompress(dat1) + + self.assertEqual(dat2, THIS_FILE_BYTES) + + # single .FLUSH_FRAME mode has content size + c = ZstdCompressor() + dat = c.compress(THIS_FILE_BYTES, mode=c.FLUSH_FRAME) + self.assertEqual(c.last_mode, c.FLUSH_FRAME) + + nt = get_frame_info(dat) + self.assertEqual(nt.decompressed_size, len(THIS_FILE_BYTES)) + + def test_compress_empty(self): + # output empty content frame + self.assertNotEqual(compress(b''), b'') + + c = ZstdCompressor() + self.assertNotEqual(c.compress(b'', c.FLUSH_FRAME), b'') + +class DecompressorTestCase(unittest.TestCase): + + def test_simple_decompress_bad_args(self): + # ZstdDecompressor + self.assertRaises(TypeError, ZstdDecompressor, ()) + self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=123) + self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=b'abc') + self.assertRaises(TypeError, ZstdDecompressor, zstd_dict={1:2, 3:4}) + + self.assertRaises(TypeError, ZstdDecompressor, options=123) + self.assertRaises(TypeError, ZstdDecompressor, options='abc') + self.assertRaises(TypeError, ZstdDecompressor, options=b'abc') + + with self.assertRaises(ValueError): + ZstdDecompressor(options={2**31 : 100}) + + with self.assertRaises(ZstdError): + ZstdDecompressor(options={DecompressionParameter.window_log_max:100}) + with self.assertRaises(ZstdError): + ZstdDecompressor(options={3333 : 100}) + + empty = compress(b'') + lzd = ZstdDecompressor() + self.assertRaises(TypeError, lzd.decompress) + self.assertRaises(TypeError, lzd.decompress, b"foo", b"bar") + self.assertRaises(TypeError, lzd.decompress, "str") + lzd.decompress(empty) + + def test_decompress_parameters(self): + d = {DecompressionParameter.window_log_max : 15} + ZstdDecompressor(options=d) + + # larger than signed int, ValueError + d1 = d.copy() + d1[DecompressionParameter.window_log_max] = 2**31 + self.assertRaises(ValueError, ZstdDecompressor, None, d1) + + # out of bounds error msg + options = {DecompressionParameter.window_log_max:100} + with self.assertRaisesRegex(ZstdError, + (r'Error when setting zstd decompression parameter "window_log_max", ' + r'it should \d+ <= value <= \d+, provided value is 100\. ' + r'\(zstd v\d\.\d\.\d, (?:32|64)-bit build\)')): + decompress(b'', options=options) + + def test_unknown_decompression_parameter(self): + KEY = 100001234 + options = {DecompressionParameter.window_log_max: DecompressionParameter.window_log_max.bounds()[1], + KEY: 200000000} + pattern = r'Zstd decompression parameter.*?"unknown parameter \(key %d\)"' \ + % KEY + with self.assertRaisesRegex(ZstdError, pattern): + ZstdDecompressor(options=options) + + def test_decompress_epilogue_flags(self): + # DAT_130K_C has a 4 bytes checksum at frame epilogue + + # full unlimited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + with self.assertRaises(EOFError): + dat = d.decompress(b'') + + # full limited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C, _130_1K) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + with self.assertRaises(EOFError): + dat = d.decompress(b'', 0) + + # [:-4] unlimited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-4]) + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.needs_input) + + dat = d.decompress(b'') + self.assertEqual(len(dat), 0) + self.assertTrue(d.needs_input) + + # [:-4] limited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-4], _130_1K) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + dat = d.decompress(b'', 0) + self.assertEqual(len(dat), 0) + self.assertFalse(d.needs_input) + + # [:-3] unlimited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-3]) + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.needs_input) + + dat = d.decompress(b'') + self.assertEqual(len(dat), 0) + self.assertTrue(d.needs_input) + + # [:-3] limited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-3], _130_1K) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + dat = d.decompress(b'', 0) + self.assertEqual(len(dat), 0) + self.assertFalse(d.needs_input) + + # [:-1] unlimited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-1]) + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.needs_input) + + dat = d.decompress(b'') + self.assertEqual(len(dat), 0) + self.assertTrue(d.needs_input) + + # [:-1] limited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-1], _130_1K) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + dat = d.decompress(b'', 0) + self.assertEqual(len(dat), 0) + self.assertFalse(d.needs_input) + + def test_decompressor_arg(self): + zd = ZstdDict(b'12345678', True) + + with self.assertRaises(TypeError): + d = ZstdDecompressor(zstd_dict={}) + + with self.assertRaises(TypeError): + d = ZstdDecompressor(options=zd) + + ZstdDecompressor() + ZstdDecompressor(zd, {}) + ZstdDecompressor(zstd_dict=zd, options={DecompressionParameter.window_log_max:25}) + + def test_decompressor_1(self): + # empty + d = ZstdDecompressor() + dat = d.decompress(b'') + + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + + # 130_1K full + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C) + + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + + # 130_1K full, limit output + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C, _130_1K) + + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + + # 130_1K, without 4 bytes checksum + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-4]) + + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + + # above, limit output + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-4], _130_1K) + + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + + # full, unused_data + TRAIL = b'89234893abcd' + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C + TRAIL, _130_1K) + + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, TRAIL) + + def test_decompressor_chunks_read_300(self): + TRAIL = b'89234893abcd' + DAT = DAT_130K_C + TRAIL + d = ZstdDecompressor() + + bi = io.BytesIO(DAT) + lst = [] + while True: + if d.needs_input: + dat = bi.read(300) + if not dat: + break + else: + raise Exception('should not get here') + + ret = d.decompress(dat) + lst.append(ret) + if d.eof: + break + + ret = b''.join(lst) + + self.assertEqual(len(ret), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data + bi.read(), TRAIL) + + def test_decompressor_chunks_read_3(self): + TRAIL = b'89234893' + DAT = DAT_130K_C + TRAIL + d = ZstdDecompressor() + + bi = io.BytesIO(DAT) + lst = [] + while True: + if d.needs_input: + dat = bi.read(3) + if not dat: + break + else: + dat = b'' + + ret = d.decompress(dat, 1) + lst.append(ret) + if d.eof: + break + + ret = b''.join(lst) + + self.assertEqual(len(ret), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data + bi.read(), TRAIL) + + + def test_decompress_empty(self): + with self.assertRaises(ZstdError): + decompress(b'') + + d = ZstdDecompressor() + self.assertEqual(d.decompress(b''), b'') + self.assertFalse(d.eof) + + def test_decompress_empty_content_frame(self): + DAT = compress(b'') + # decompress + self.assertGreaterEqual(len(DAT), 4) + self.assertEqual(decompress(DAT), b'') + + with self.assertRaises(ZstdError): + decompress(DAT[:-1]) + + # ZstdDecompressor + d = ZstdDecompressor() + dat = d.decompress(DAT) + self.assertEqual(dat, b'') + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + d = ZstdDecompressor() + dat = d.decompress(DAT[:-1]) + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + +class DecompressorFlagsTestCase(unittest.TestCase): + + @classmethod + def setUpClass(cls): + options = {CompressionParameter.checksum_flag:1} + c = ZstdCompressor(options=options) + + cls.DECOMPRESSED_42 = b'a'*42 + cls.FRAME_42 = c.compress(cls.DECOMPRESSED_42, c.FLUSH_FRAME) + + cls.DECOMPRESSED_60 = b'a'*60 + cls.FRAME_60 = c.compress(cls.DECOMPRESSED_60, c.FLUSH_FRAME) + + cls.FRAME_42_60 = cls.FRAME_42 + cls.FRAME_60 + cls.DECOMPRESSED_42_60 = cls.DECOMPRESSED_42 + cls.DECOMPRESSED_60 + + cls._130_1K = 130*_1K + + c = ZstdCompressor() + cls.UNKNOWN_FRAME_42 = c.compress(cls.DECOMPRESSED_42) + c.flush() + cls.UNKNOWN_FRAME_60 = c.compress(cls.DECOMPRESSED_60) + c.flush() + cls.UNKNOWN_FRAME_42_60 = cls.UNKNOWN_FRAME_42 + cls.UNKNOWN_FRAME_60 + + cls.TRAIL = b'12345678abcdefg!@#$%^&*()_+|' + + def test_function_decompress(self): + + self.assertEqual(len(decompress(COMPRESSED_100_PLUS_32KB)), 100+32*_1K) + + # 1 frame + self.assertEqual(decompress(self.FRAME_42), self.DECOMPRESSED_42) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_42), self.DECOMPRESSED_42) + + pattern = r"Compressed data ended before the end-of-stream marker" + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.FRAME_42[:1]) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.FRAME_42[:-4]) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.FRAME_42[:-1]) + + # 2 frames + self.assertEqual(decompress(self.FRAME_42_60), self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_42_60), self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(self.FRAME_42 + self.UNKNOWN_FRAME_60), + self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_42 + self.FRAME_60), + self.DECOMPRESSED_42_60) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.FRAME_42_60[:-4]) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.UNKNOWN_FRAME_42_60[:-1]) + + # 130_1K + self.assertEqual(decompress(DAT_130K_C), DAT_130K_D) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(DAT_130K_C[:-4]) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(DAT_130K_C[:-1]) + + # Unknown frame descriptor + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(b'aaaaaaaaa') + + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(self.FRAME_42 + b'aaaaaaaaa') + + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(self.UNKNOWN_FRAME_42_60 + b'aaaaaaaaa') + + # doesn't match checksum + checksum = DAT_130K_C[-4:] + if checksum[0] == 255: + wrong_checksum = bytes([254]) + checksum[1:] + else: + wrong_checksum = bytes([checksum[0]+1]) + checksum[1:] + + dat = DAT_130K_C[:-4] + wrong_checksum + + with self.assertRaisesRegex(ZstdError, "doesn't match checksum"): + decompress(dat) + + def test_function_skippable(self): + self.assertEqual(decompress(SKIPPABLE_FRAME), b'') + self.assertEqual(decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME), b'') + + # 1 frame + 2 skippable + self.assertEqual(len(decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + DAT_130K_C)), + self._130_1K) + + self.assertEqual(len(decompress(DAT_130K_C + SKIPPABLE_FRAME + SKIPPABLE_FRAME)), + self._130_1K) + + self.assertEqual(len(decompress(SKIPPABLE_FRAME + DAT_130K_C + SKIPPABLE_FRAME)), + self._130_1K) + + # unknown size + self.assertEqual(decompress(SKIPPABLE_FRAME + self.UNKNOWN_FRAME_60), + self.DECOMPRESSED_60) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_60 + SKIPPABLE_FRAME), + self.DECOMPRESSED_60) + + # 2 frames + 1 skippable + self.assertEqual(decompress(self.FRAME_42 + SKIPPABLE_FRAME + self.FRAME_60), + self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(SKIPPABLE_FRAME + self.FRAME_42_60), + self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_42_60 + SKIPPABLE_FRAME), + self.DECOMPRESSED_42_60) + + # incomplete + with self.assertRaises(ZstdError): + decompress(SKIPPABLE_FRAME[:1]) + + with self.assertRaises(ZstdError): + decompress(SKIPPABLE_FRAME[:-1]) + + with self.assertRaises(ZstdError): + decompress(self.FRAME_42 + SKIPPABLE_FRAME[:-1]) + + # Unknown frame descriptor + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(b'aaaaaaaaa' + SKIPPABLE_FRAME) + + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(SKIPPABLE_FRAME + b'aaaaaaaaa') + + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + b'aaaaaaaaa') + + def test_decompressor_1(self): + # empty 1 + d = ZstdDecompressor() + + dat = d.decompress(b'') + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(b'', 0) + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(COMPRESSED_100_PLUS_32KB + b'a') + self.assertEqual(dat, DECOMPRESSED_100_PLUS_32KB) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'a') + self.assertEqual(d.unused_data, b'a') # twice + + # empty 2 + d = ZstdDecompressor() + + dat = d.decompress(b'', 0) + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(b'') + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(COMPRESSED_100_PLUS_32KB + b'a') + self.assertEqual(dat, DECOMPRESSED_100_PLUS_32KB) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'a') + self.assertEqual(d.unused_data, b'a') # twice + + # 1 frame + d = ZstdDecompressor() + dat = d.decompress(self.FRAME_42) + + self.assertEqual(dat, self.DECOMPRESSED_42) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + with self.assertRaises(EOFError): + d.decompress(b'') + + # 1 frame, trail + d = ZstdDecompressor() + dat = d.decompress(self.FRAME_42 + self.TRAIL) + + self.assertEqual(dat, self.DECOMPRESSED_42) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, self.TRAIL) + self.assertEqual(d.unused_data, self.TRAIL) # twice + + # 1 frame, 32_1K + temp = compress(b'a'*(32*_1K)) + d = ZstdDecompressor() + dat = d.decompress(temp, 32*_1K) + + self.assertEqual(dat, b'a'*(32*_1K)) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + with self.assertRaises(EOFError): + d.decompress(b'') + + # 1 frame, 32_1K+100, trail + d = ZstdDecompressor() + dat = d.decompress(COMPRESSED_100_PLUS_32KB+self.TRAIL, 100) # 100 bytes + + self.assertEqual(len(dat), 100) + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + + dat = d.decompress(b'') # 32_1K + + self.assertEqual(len(dat), 32*_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, self.TRAIL) + self.assertEqual(d.unused_data, self.TRAIL) # twice + + with self.assertRaises(EOFError): + d.decompress(b'') + + # incomplete 1 + d = ZstdDecompressor() + dat = d.decompress(self.FRAME_60[:1]) + + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # incomplete 2 + d = ZstdDecompressor() + + dat = d.decompress(self.FRAME_60[:-4]) + self.assertEqual(dat, self.DECOMPRESSED_60) + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # incomplete 3 + d = ZstdDecompressor() + + dat = d.decompress(self.FRAME_60[:-1]) + self.assertEqual(dat, self.DECOMPRESSED_60) + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + + # incomplete 4 + d = ZstdDecompressor() + + dat = d.decompress(self.FRAME_60[:-4], 60) + self.assertEqual(dat, self.DECOMPRESSED_60) + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(b'') + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # Unknown frame descriptor + d = ZstdDecompressor() + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + d.decompress(b'aaaaaaaaa') + + def test_decompressor_skippable(self): + # 1 skippable + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME) + + self.assertEqual(dat, b'') + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # 1 skippable, max_length=0 + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME, 0) + + self.assertEqual(dat, b'') + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # 1 skippable, trail + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME + self.TRAIL) + + self.assertEqual(dat, b'') + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, self.TRAIL) + self.assertEqual(d.unused_data, self.TRAIL) # twice + + # incomplete + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME[:-1]) + + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # incomplete + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME[:-1], 0) + + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(b'') + + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + + +class ZstdDictTestCase(unittest.TestCase): + + def test_is_raw(self): + # content < 8 + b = b'1234567' + with self.assertRaises(ValueError): + ZstdDict(b) + + # content == 8 + b = b'12345678' + zd = ZstdDict(b, is_raw=True) + self.assertEqual(zd.dict_id, 0) + + temp = compress(b'aaa12345678', level=3, zstd_dict=zd) + self.assertEqual(b'aaa12345678', decompress(temp, zd)) + + # is_raw == False + b = b'12345678abcd' + with self.assertRaises(ValueError): + ZstdDict(b) + + # read only attributes + with self.assertRaises(AttributeError): + zd.dict_content = b + + with self.assertRaises(AttributeError): + zd.dict_id = 10000 + + # ZstdDict arguments + zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=False) + self.assertNotEqual(zd.dict_id, 0) + + zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=True) + self.assertNotEqual(zd.dict_id, 0) # note this assertion + + with self.assertRaises(TypeError): + ZstdDict("12345678abcdef", is_raw=True) + with self.assertRaises(TypeError): + ZstdDict(TRAINED_DICT) + + # invalid parameter + with self.assertRaises(TypeError): + ZstdDict(desk333=345) + + def test_invalid_dict(self): + DICT_MAGIC = 0xEC30A437.to_bytes(4, byteorder='little') + dict_content = DICT_MAGIC + b'abcdefghighlmnopqrstuvwxyz' + + # corrupted + zd = ZstdDict(dict_content, is_raw=False) + with self.assertRaisesRegex(ZstdError, r'ZSTD_CDict.*?corrupted'): + ZstdCompressor(zstd_dict=zd.as_digested_dict) + with self.assertRaisesRegex(ZstdError, r'ZSTD_DDict.*?corrupted'): + ZstdDecompressor(zd) + + # wrong type + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdCompressor(zstd_dict=(zd, b'123')) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdCompressor(zstd_dict=(zd, 1, 2)) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdCompressor(zstd_dict=(zd, -1)) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdCompressor(zstd_dict=(zd, 3)) + + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdDecompressor(zstd_dict=(zd, b'123')) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdDecompressor((zd, 1, 2)) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdDecompressor((zd, -1)) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdDecompressor((zd, 3)) + + def test_train_dict(self): + + + TRAINED_DICT = train_dict(SAMPLES, DICT_SIZE1) + ZstdDict(TRAINED_DICT.dict_content, False) + + self.assertNotEqual(TRAINED_DICT.dict_id, 0) + self.assertGreater(len(TRAINED_DICT.dict_content), 0) + self.assertLessEqual(len(TRAINED_DICT.dict_content), DICT_SIZE1) + self.assertTrue(re.match(r'^<ZstdDict dict_id=\d+ dict_size=\d+>$', str(TRAINED_DICT))) + + # compress/decompress + c = ZstdCompressor(zstd_dict=TRAINED_DICT) + for sample in SAMPLES: + dat1 = compress(sample, zstd_dict=TRAINED_DICT) + dat2 = decompress(dat1, TRAINED_DICT) + self.assertEqual(sample, dat2) + + dat1 = c.compress(sample) + dat1 += c.flush() + dat2 = decompress(dat1, TRAINED_DICT) + self.assertEqual(sample, dat2) + + def test_finalize_dict(self): + DICT_SIZE2 = 200*_1K + C_LEVEL = 6 + + try: + dic2 = finalize_dict(TRAINED_DICT, SAMPLES, DICT_SIZE2, C_LEVEL) + except NotImplementedError: + # < v1.4.5 at compile-time, >= v.1.4.5 at run-time + return + + self.assertNotEqual(dic2.dict_id, 0) + self.assertGreater(len(dic2.dict_content), 0) + self.assertLessEqual(len(dic2.dict_content), DICT_SIZE2) + + # compress/decompress + c = ZstdCompressor(C_LEVEL, zstd_dict=dic2) + for sample in SAMPLES: + dat1 = compress(sample, C_LEVEL, zstd_dict=dic2) + dat2 = decompress(dat1, dic2) + self.assertEqual(sample, dat2) + + dat1 = c.compress(sample) + dat1 += c.flush() + dat2 = decompress(dat1, dic2) + self.assertEqual(sample, dat2) + + # dict mismatch + self.assertNotEqual(TRAINED_DICT.dict_id, dic2.dict_id) + + dat1 = compress(SAMPLES[0], zstd_dict=TRAINED_DICT) + with self.assertRaises(ZstdError): + decompress(dat1, dic2) + + def test_train_dict_arguments(self): + with self.assertRaises(ValueError): + train_dict([], 100*_1K) + + with self.assertRaises(ValueError): + train_dict(SAMPLES, -100) + + with self.assertRaises(ValueError): + train_dict(SAMPLES, 0) + + def test_finalize_dict_arguments(self): + with self.assertRaises(TypeError): + finalize_dict({1:2}, (b'aaa', b'bbb'), 100*_1K, 2) + + with self.assertRaises(ValueError): + finalize_dict(TRAINED_DICT, [], 100*_1K, 2) + + with self.assertRaises(ValueError): + finalize_dict(TRAINED_DICT, SAMPLES, -100, 2) + + with self.assertRaises(ValueError): + finalize_dict(TRAINED_DICT, SAMPLES, 0, 2) + + def test_train_dict_c(self): + # argument wrong type + with self.assertRaises(TypeError): + _zstd._train_dict({}, (), 100) + with self.assertRaises(TypeError): + _zstd._train_dict(b'', 99, 100) + with self.assertRaises(TypeError): + _zstd._train_dict(b'', (), 100.1) + + # size > size_t + with self.assertRaises(ValueError): + _zstd._train_dict(b'', (2**64+1,), 100) + + # dict_size <= 0 + with self.assertRaises(ValueError): + _zstd._train_dict(b'', (), 0) + + def test_finalize_dict_c(self): + with self.assertRaises(TypeError): + _zstd._finalize_dict(1, 2, 3, 4, 5) + + # argument wrong type + with self.assertRaises(TypeError): + _zstd._finalize_dict({}, b'', (), 100, 5) + with self.assertRaises(TypeError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, {}, (), 100, 5) + with self.assertRaises(TypeError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', 99, 100, 5) + with self.assertRaises(TypeError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', (), 100.1, 5) + with self.assertRaises(TypeError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5.1) + + # size > size_t + with self.assertRaises(ValueError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', (2**64+1,), 100, 5) + + # dict_size <= 0 + with self.assertRaises(ValueError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', (), 0, 5) + + def test_train_buffer_protocol_samples(self): + def _nbytes(dat): + if isinstance(dat, (bytes, bytearray)): + return len(dat) + return memoryview(dat).nbytes + + # prepare samples + chunk_lst = [] + wrong_size_lst = [] + correct_size_lst = [] + for _ in range(300): + arr = array.array('Q', [random.randint(0, 20) for i in range(20)]) + chunk_lst.append(arr) + correct_size_lst.append(_nbytes(arr)) + wrong_size_lst.append(len(arr)) + concatenation = b''.join(chunk_lst) + + # wrong size list + with self.assertRaisesRegex(ValueError, + "The samples size tuple doesn't match the concatenation's size"): + _zstd._train_dict(concatenation, tuple(wrong_size_lst), 100*_1K) + + # correct size list + _zstd._train_dict(concatenation, tuple(correct_size_lst), 3*_1K) + + # wrong size list + with self.assertRaisesRegex(ValueError, + "The samples size tuple doesn't match the concatenation's size"): + _zstd._finalize_dict(TRAINED_DICT.dict_content, + concatenation, tuple(wrong_size_lst), 300*_1K, 5) + + # correct size list + _zstd._finalize_dict(TRAINED_DICT.dict_content, + concatenation, tuple(correct_size_lst), 300*_1K, 5) + + def test_as_prefix(self): + # V1 + V1 = THIS_FILE_BYTES + zd = ZstdDict(V1, True) + + # V2 + mid = len(V1) // 2 + V2 = V1[:mid] + \ + (b'a' if V1[mid] != int.from_bytes(b'a') else b'b') + \ + V1[mid+1:] + + # compress + dat = compress(V2, zstd_dict=zd.as_prefix) + self.assertEqual(get_frame_info(dat).dictionary_id, 0) + + # decompress + self.assertEqual(decompress(dat, zd.as_prefix), V2) + + # use wrong prefix + zd2 = ZstdDict(SAMPLES[0], True) + try: + decompressed = decompress(dat, zd2.as_prefix) + except ZstdError: # expected + pass + else: + self.assertNotEqual(decompressed, V2) + + # read only attribute + with self.assertRaises(AttributeError): + zd.as_prefix = b'1234' + + def test_as_digested_dict(self): + zd = TRAINED_DICT + + # test .as_digested_dict + dat = compress(SAMPLES[0], zstd_dict=zd.as_digested_dict) + self.assertEqual(decompress(dat, zd.as_digested_dict), SAMPLES[0]) + with self.assertRaises(AttributeError): + zd.as_digested_dict = b'1234' + + # test .as_undigested_dict + dat = compress(SAMPLES[0], zstd_dict=zd.as_undigested_dict) + self.assertEqual(decompress(dat, zd.as_undigested_dict), SAMPLES[0]) + with self.assertRaises(AttributeError): + zd.as_undigested_dict = b'1234' + + def test_advanced_compression_parameters(self): + options = {CompressionParameter.compression_level: 6, + CompressionParameter.window_log: 20, + CompressionParameter.enable_long_distance_matching: 1} + + # automatically select + dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT) + self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0]) + + # explicitly select + dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT.as_digested_dict) + self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0]) + + def test_len(self): + self.assertEqual(len(TRAINED_DICT), len(TRAINED_DICT.dict_content)) + self.assertIn(str(len(TRAINED_DICT)), str(TRAINED_DICT)) + +class FileTestCase(unittest.TestCase): + def setUp(self): + self.DECOMPRESSED_42 = b'a'*42 + self.FRAME_42 = compress(self.DECOMPRESSED_42) + + def test_init(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + pass + with ZstdFile(io.BytesIO(), "w") as f: + pass + with ZstdFile(io.BytesIO(), "x") as f: + pass + with ZstdFile(io.BytesIO(), "a") as f: + pass + + with ZstdFile(io.BytesIO(), "w", level=12) as f: + pass + with ZstdFile(io.BytesIO(), "w", options={CompressionParameter.checksum_flag:1}) as f: + pass + with ZstdFile(io.BytesIO(), "w", options={}) as f: + pass + with ZstdFile(io.BytesIO(), "w", level=20, zstd_dict=TRAINED_DICT) as f: + pass + + with ZstdFile(io.BytesIO(), "r", options={DecompressionParameter.window_log_max:25}) as f: + pass + with ZstdFile(io.BytesIO(), "r", options={}, zstd_dict=TRAINED_DICT) as f: + pass + + def test_init_with_PathLike_filename(self): + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + filename = pathlib.Path(tmp_f.name) + + with ZstdFile(filename, "a") as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + with ZstdFile(filename) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + + with ZstdFile(filename, "a") as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + with ZstdFile(filename) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB * 2) + + os.remove(filename) + + def test_init_with_filename(self): + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + filename = pathlib.Path(tmp_f.name) + + with ZstdFile(filename) as f: + pass + with ZstdFile(filename, "w") as f: + pass + with ZstdFile(filename, "a") as f: + pass + + os.remove(filename) + + def test_init_mode(self): + bi = io.BytesIO() + + with ZstdFile(bi, "r"): + pass + with ZstdFile(bi, "rb"): + pass + with ZstdFile(bi, "w"): + pass + with ZstdFile(bi, "wb"): + pass + with ZstdFile(bi, "a"): + pass + with ZstdFile(bi, "ab"): + pass + + def test_init_with_x_mode(self): + with tempfile.NamedTemporaryFile() as tmp_f: + filename = pathlib.Path(tmp_f.name) + + for mode in ("x", "xb"): + with ZstdFile(filename, mode): + pass + with self.assertRaises(FileExistsError): + with ZstdFile(filename, mode): + pass + os.remove(filename) + + def test_init_bad_mode(self): + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), (3, "x")) + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "xt") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "x+") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rx") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "wx") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rt") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r+") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "wt") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "w+") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rw") + + with self.assertRaisesRegex(TypeError, r"NOT be CompressionParameter"): + ZstdFile(io.BytesIO(), 'rb', + options={CompressionParameter.compression_level:5}) + with self.assertRaisesRegex(TypeError, + r"NOT be DecompressionParameter"): + ZstdFile(io.BytesIO(), 'wb', + options={DecompressionParameter.window_log_max:21}) + + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", level=12) + + def test_init_bad_check(self): + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(), "w", level='asd') + # CHECK_UNKNOWN and anything above CHECK_ID_MAX should be invalid. + with self.assertRaises(ZstdError): + ZstdFile(io.BytesIO(), "w", options={999:9999}) + with self.assertRaises(ZstdError): + ZstdFile(io.BytesIO(), "w", options={CompressionParameter.window_log:99}) + + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", options=33) + + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), + options={DecompressionParameter.window_log_max:2**31}) + + with self.assertRaises(ZstdError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), + options={444:333}) + + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), zstd_dict={1:2}) + + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), zstd_dict=b'dict123456') + + def test_init_close_fp(self): + # get a temp file name + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + tmp_f.write(DAT_130K_C) + filename = tmp_f.name + + with self.assertRaises(ValueError): + ZstdFile(filename, options={'a':'b'}) + + # for PyPy + gc.collect() + + os.remove(filename) + + def test_close(self): + with io.BytesIO(COMPRESSED_100_PLUS_32KB) as src: + f = ZstdFile(src) + f.close() + # ZstdFile.close() should not close the underlying file object. + self.assertFalse(src.closed) + # Try closing an already-closed ZstdFile. + f.close() + self.assertFalse(src.closed) + + # Test with a real file on disk, opened directly by ZstdFile. + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + filename = pathlib.Path(tmp_f.name) + + f = ZstdFile(filename) + fp = f._fp + f.close() + # Here, ZstdFile.close() *should* close the underlying file object. + self.assertTrue(fp.closed) + # Try closing an already-closed ZstdFile. + f.close() + + os.remove(filename) + + def test_closed(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertFalse(f.closed) + f.read() + self.assertFalse(f.closed) + finally: + f.close() + self.assertTrue(f.closed) + + f = ZstdFile(io.BytesIO(), "w") + try: + self.assertFalse(f.closed) + finally: + f.close() + self.assertTrue(f.closed) + + def test_fileno(self): + # 1 + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertRaises(io.UnsupportedOperation, f.fileno) + finally: + f.close() + self.assertRaises(ValueError, f.fileno) + + # 2 + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + filename = pathlib.Path(tmp_f.name) + + f = ZstdFile(filename) + try: + self.assertEqual(f.fileno(), f._fp.fileno()) + self.assertIsInstance(f.fileno(), int) + finally: + f.close() + self.assertRaises(ValueError, f.fileno) + + os.remove(filename) + + # 3, no .fileno() method + class C: + def read(self, size=-1): + return b'123' + with ZstdFile(C(), 'rb') as f: + with self.assertRaisesRegex(AttributeError, r'fileno'): + f.fileno() + + def test_name(self): + # 1 + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + with self.assertRaises(AttributeError): + f.name + finally: + f.close() + with self.assertRaises(ValueError): + f.name + + # 2 + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + filename = pathlib.Path(tmp_f.name) + + f = ZstdFile(filename) + try: + self.assertEqual(f.name, f._fp.name) + self.assertIsInstance(f.name, str) + finally: + f.close() + with self.assertRaises(ValueError): + f.name + + os.remove(filename) + + # 3, no .filename property + class C: + def read(self, size=-1): + return b'123' + with ZstdFile(C(), 'rb') as f: + with self.assertRaisesRegex(AttributeError, r'name'): + f.name + + def test_seekable(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertTrue(f.seekable()) + f.read() + self.assertTrue(f.seekable()) + finally: + f.close() + self.assertRaises(ValueError, f.seekable) + + f = ZstdFile(io.BytesIO(), "w") + try: + self.assertFalse(f.seekable()) + finally: + f.close() + self.assertRaises(ValueError, f.seekable) + + src = io.BytesIO(COMPRESSED_100_PLUS_32KB) + src.seekable = lambda: False + f = ZstdFile(src) + try: + self.assertFalse(f.seekable()) + finally: + f.close() + self.assertRaises(ValueError, f.seekable) + + def test_readable(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertTrue(f.readable()) + f.read() + self.assertTrue(f.readable()) + finally: + f.close() + self.assertRaises(ValueError, f.readable) + + f = ZstdFile(io.BytesIO(), "w") + try: + self.assertFalse(f.readable()) + finally: + f.close() + self.assertRaises(ValueError, f.readable) + + def test_writable(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertFalse(f.writable()) + f.read() + self.assertFalse(f.writable()) + finally: + f.close() + self.assertRaises(ValueError, f.writable) + + f = ZstdFile(io.BytesIO(), "w") + try: + self.assertTrue(f.writable()) + finally: + f.close() + self.assertRaises(ValueError, f.writable) + + def test_read_0(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + self.assertEqual(f.read(0), b"") + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), + options={DecompressionParameter.window_log_max:20}) as f: + self.assertEqual(f.read(0), b"") + + # empty file + with ZstdFile(io.BytesIO(b'')) as f: + self.assertEqual(f.read(0), b"") + with self.assertRaises(EOFError): + f.read(10) + + with ZstdFile(io.BytesIO(b'')) as f: + with self.assertRaises(EOFError): + f.read(10) + + def test_read_10(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + chunks = [] + while True: + result = f.read(10) + if not result: + break + self.assertLessEqual(len(result), 10) + chunks.append(result) + self.assertEqual(b"".join(chunks), DECOMPRESSED_100_PLUS_32KB) + + def test_read_multistream(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 5)) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB * 5) + + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB + SKIPPABLE_FRAME)) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB + COMPRESSED_DAT)) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB + DECOMPRESSED_DAT) + + def test_read_incomplete(self): + with ZstdFile(io.BytesIO(DAT_130K_C[:-200])) as f: + self.assertRaises(EOFError, f.read) + + # Trailing data isn't a valid compressed stream + with ZstdFile(io.BytesIO(self.FRAME_42 + b'12345')) as f: + self.assertEqual(f.read(), self.DECOMPRESSED_42) + + with ZstdFile(io.BytesIO(SKIPPABLE_FRAME + b'12345')) as f: + self.assertEqual(f.read(), b'') + + def test_read_truncated(self): + # Drop stream epilogue: 4 bytes checksum + truncated = DAT_130K_C[:-4] + with ZstdFile(io.BytesIO(truncated)) as f: + self.assertRaises(EOFError, f.read) + + with ZstdFile(io.BytesIO(truncated)) as f: + # this is an important test, make sure it doesn't raise EOFError. + self.assertEqual(f.read(130*_1K), DAT_130K_D) + with self.assertRaises(EOFError): + f.read(1) + + # Incomplete header + for i in range(1, 20): + with ZstdFile(io.BytesIO(truncated[:i])) as f: + self.assertRaises(EOFError, f.read, 1) + + def test_read_bad_args(self): + f = ZstdFile(io.BytesIO(COMPRESSED_DAT)) + f.close() + self.assertRaises(ValueError, f.read) + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(ValueError, f.read) + with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f: + self.assertRaises(TypeError, f.read, float()) + + def test_read_bad_data(self): + with ZstdFile(io.BytesIO(COMPRESSED_BOGUS)) as f: + self.assertRaises(ZstdError, f.read) + + def test_read_exception(self): + class C: + def read(self, size=-1): + raise OSError + with ZstdFile(C()) as f: + with self.assertRaises(OSError): + f.read(10) + + def test_read1(self): + with ZstdFile(io.BytesIO(DAT_130K_C)) as f: + blocks = [] + while True: + result = f.read1() + if not result: + break + blocks.append(result) + self.assertEqual(b"".join(blocks), DAT_130K_D) + self.assertEqual(f.read1(), b"") + + def test_read1_0(self): + with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f: + self.assertEqual(f.read1(0), b"") + + def test_read1_10(self): + with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f: + blocks = [] + while True: + result = f.read1(10) + if not result: + break + blocks.append(result) + self.assertEqual(b"".join(blocks), DECOMPRESSED_DAT) + self.assertEqual(f.read1(), b"") + + def test_read1_multistream(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 5)) as f: + blocks = [] + while True: + result = f.read1() + if not result: + break + blocks.append(result) + self.assertEqual(b"".join(blocks), DECOMPRESSED_100_PLUS_32KB * 5) + self.assertEqual(f.read1(), b"") + + def test_read1_bad_args(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + f.close() + self.assertRaises(ValueError, f.read1) + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(ValueError, f.read1) + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + self.assertRaises(TypeError, f.read1, None) + + def test_readinto(self): + arr = array.array("I", range(100)) + self.assertEqual(len(arr), 100) + self.assertEqual(len(arr) * arr.itemsize, 400) + ba = bytearray(300) + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + # 0 length output buffer + self.assertEqual(f.readinto(ba[0:0]), 0) + + # use correct length for buffer protocol object + self.assertEqual(f.readinto(arr), 400) + self.assertEqual(arr.tobytes(), DECOMPRESSED_100_PLUS_32KB[:400]) + + # normal readinto + self.assertEqual(f.readinto(ba), 300) + self.assertEqual(ba, DECOMPRESSED_100_PLUS_32KB[400:700]) + + def test_peek(self): + with ZstdFile(io.BytesIO(DAT_130K_C)) as f: + result = f.peek() + self.assertGreater(len(result), 0) + self.assertTrue(DAT_130K_D.startswith(result)) + self.assertEqual(f.read(), DAT_130K_D) + with ZstdFile(io.BytesIO(DAT_130K_C)) as f: + result = f.peek(10) + self.assertGreater(len(result), 0) + self.assertTrue(DAT_130K_D.startswith(result)) + self.assertEqual(f.read(), DAT_130K_D) + + def test_peek_bad_args(self): + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(ValueError, f.peek) + + def test_iterator(self): + with io.BytesIO(THIS_FILE_BYTES) as f: + lines = f.readlines() + compressed = compress(THIS_FILE_BYTES) + + # iter + with ZstdFile(io.BytesIO(compressed)) as f: + self.assertListEqual(list(iter(f)), lines) + + # readline + with ZstdFile(io.BytesIO(compressed)) as f: + for line in lines: + self.assertEqual(f.readline(), line) + self.assertEqual(f.readline(), b'') + self.assertEqual(f.readline(), b'') + + # readlines + with ZstdFile(io.BytesIO(compressed)) as f: + self.assertListEqual(f.readlines(), lines) + + def test_decompress_limited(self): + _ZSTD_DStreamInSize = 128*_1K + 3 + + bomb = compress(b'\0' * int(2e6), level=10) + self.assertLess(len(bomb), _ZSTD_DStreamInSize) + + decomp = ZstdFile(io.BytesIO(bomb)) + self.assertEqual(decomp.read(1), b'\0') + + # BufferedReader uses 128 KiB buffer in __init__.py + max_decomp = 128*_1K + self.assertLessEqual(decomp._buffer.raw.tell(), max_decomp, + "Excessive amount of data was decompressed") + + def test_write(self): + raw_data = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6] + with io.BytesIO() as dst: + with ZstdFile(dst, "w") as f: + f.write(raw_data) + + comp = ZstdCompressor() + expected = comp.compress(raw_data) + comp.flush() + self.assertEqual(dst.getvalue(), expected) + + with io.BytesIO() as dst: + with ZstdFile(dst, "w", level=12) as f: + f.write(raw_data) + + comp = ZstdCompressor(12) + expected = comp.compress(raw_data) + comp.flush() + self.assertEqual(dst.getvalue(), expected) + + with io.BytesIO() as dst: + with ZstdFile(dst, "w", options={CompressionParameter.checksum_flag:1}) as f: + f.write(raw_data) + + comp = ZstdCompressor(options={CompressionParameter.checksum_flag:1}) + expected = comp.compress(raw_data) + comp.flush() + self.assertEqual(dst.getvalue(), expected) + + with io.BytesIO() as dst: + options = {CompressionParameter.compression_level:-5, + CompressionParameter.checksum_flag:1} + with ZstdFile(dst, "w", + options=options) as f: + f.write(raw_data) + + comp = ZstdCompressor(options=options) + expected = comp.compress(raw_data) + comp.flush() + self.assertEqual(dst.getvalue(), expected) + + def test_write_empty_frame(self): + # .FLUSH_FRAME generates an empty content frame + c = ZstdCompressor() + self.assertNotEqual(c.flush(c.FLUSH_FRAME), b'') + self.assertNotEqual(c.flush(c.FLUSH_FRAME), b'') + + # don't generate empty content frame + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + pass + self.assertEqual(bo.getvalue(), b'') + + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.flush(f.FLUSH_FRAME) + self.assertEqual(bo.getvalue(), b'') + + # if .write(b''), generate empty content frame + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.write(b'') + self.assertNotEqual(bo.getvalue(), b'') + + # has an empty content frame + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.flush(f.FLUSH_BLOCK) + self.assertNotEqual(bo.getvalue(), b'') + + def test_write_empty_block(self): + # If no internal data, .FLUSH_BLOCK return b''. + c = ZstdCompressor() + self.assertEqual(c.flush(c.FLUSH_BLOCK), b'') + self.assertNotEqual(c.compress(b'123', c.FLUSH_BLOCK), + b'') + self.assertEqual(c.flush(c.FLUSH_BLOCK), b'') + self.assertEqual(c.compress(b''), b'') + self.assertEqual(c.compress(b''), b'') + self.assertEqual(c.flush(c.FLUSH_BLOCK), b'') + + # mode = .last_mode + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.write(b'123') + f.flush(f.FLUSH_BLOCK) + fp_pos = f._fp.tell() + self.assertNotEqual(fp_pos, 0) + f.flush(f.FLUSH_BLOCK) + self.assertEqual(f._fp.tell(), fp_pos) + + # mode != .last_mode + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.flush(f.FLUSH_BLOCK) + self.assertEqual(f._fp.tell(), 0) + f.write(b'') + f.flush(f.FLUSH_BLOCK) + self.assertEqual(f._fp.tell(), 0) + + def test_write_101(self): + with io.BytesIO() as dst: + with ZstdFile(dst, "w") as f: + for start in range(0, len(THIS_FILE_BYTES), 101): + f.write(THIS_FILE_BYTES[start:start+101]) + + comp = ZstdCompressor() + expected = comp.compress(THIS_FILE_BYTES) + comp.flush() + self.assertEqual(dst.getvalue(), expected) + + def test_write_append(self): + def comp(data): + comp = ZstdCompressor() + return comp.compress(data) + comp.flush() + + part1 = THIS_FILE_BYTES[:_1K] + part2 = THIS_FILE_BYTES[_1K:1536] + part3 = THIS_FILE_BYTES[1536:] + expected = b"".join(comp(x) for x in (part1, part2, part3)) + with io.BytesIO() as dst: + with ZstdFile(dst, "w") as f: + f.write(part1) + with ZstdFile(dst, "a") as f: + f.write(part2) + with ZstdFile(dst, "a") as f: + f.write(part3) + self.assertEqual(dst.getvalue(), expected) + + def test_write_bad_args(self): + f = ZstdFile(io.BytesIO(), "w") + f.close() + self.assertRaises(ValueError, f.write, b"foo") + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r") as f: + self.assertRaises(ValueError, f.write, b"bar") + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(TypeError, f.write, None) + self.assertRaises(TypeError, f.write, "text") + self.assertRaises(TypeError, f.write, 789) + + def test_writelines(self): + def comp(data): + comp = ZstdCompressor() + return comp.compress(data) + comp.flush() + + with io.BytesIO(THIS_FILE_BYTES) as f: + lines = f.readlines() + with io.BytesIO() as dst: + with ZstdFile(dst, "w") as f: + f.writelines(lines) + expected = comp(THIS_FILE_BYTES) + self.assertEqual(dst.getvalue(), expected) + + def test_seek_forward(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(555) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[555:]) + + def test_seek_forward_across_streams(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f: + f.seek(len(DECOMPRESSED_100_PLUS_32KB) + 123) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[123:]) + + def test_seek_forward_relative_to_current(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.read(100) + f.seek(1236, 1) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[1336:]) + + def test_seek_forward_relative_to_end(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(-555, 2) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[-555:]) + + def test_seek_backward(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.read(1001) + f.seek(211) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[211:]) + + def test_seek_backward_across_streams(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f: + f.read(len(DECOMPRESSED_100_PLUS_32KB) + 333) + f.seek(737) + self.assertEqual(f.read(), + DECOMPRESSED_100_PLUS_32KB[737:] + DECOMPRESSED_100_PLUS_32KB) + + def test_seek_backward_relative_to_end(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(-150, 2) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[-150:]) + + def test_seek_past_end(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(len(DECOMPRESSED_100_PLUS_32KB) + 9001) + self.assertEqual(f.tell(), len(DECOMPRESSED_100_PLUS_32KB)) + self.assertEqual(f.read(), b"") + + def test_seek_past_start(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(-88) + self.assertEqual(f.tell(), 0) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + + def test_seek_bad_args(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + f.close() + self.assertRaises(ValueError, f.seek, 0) + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(ValueError, f.seek, 0) + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + self.assertRaises(ValueError, f.seek, 0, 3) + # io.BufferedReader raises TypeError instead of ValueError + self.assertRaises((TypeError, ValueError), f.seek, 9, ()) + self.assertRaises(TypeError, f.seek, None) + self.assertRaises(TypeError, f.seek, b"derp") + + def test_seek_not_seekable(self): + class C(io.BytesIO): + def seekable(self): + return False + obj = C(COMPRESSED_100_PLUS_32KB) + with ZstdFile(obj, 'r') as f: + d = f.read(1) + self.assertFalse(f.seekable()) + with self.assertRaisesRegex(io.UnsupportedOperation, + 'File or stream is not seekable'): + f.seek(0) + d += f.read() + self.assertEqual(d, DECOMPRESSED_100_PLUS_32KB) + + def test_tell(self): + with ZstdFile(io.BytesIO(DAT_130K_C)) as f: + pos = 0 + while True: + self.assertEqual(f.tell(), pos) + result = f.read(random.randint(171, 189)) + if not result: + break + pos += len(result) + self.assertEqual(f.tell(), len(DAT_130K_D)) + with ZstdFile(io.BytesIO(), "w") as f: + for pos in range(0, len(DAT_130K_D), 143): + self.assertEqual(f.tell(), pos) + f.write(DAT_130K_D[pos:pos+143]) + self.assertEqual(f.tell(), len(DAT_130K_D)) + + def test_tell_bad_args(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + f.close() + self.assertRaises(ValueError, f.tell) + + def test_file_dict(self): + # default + bi = io.BytesIO() + with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with ZstdFile(bi, zstd_dict=TRAINED_DICT) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + # .as_(un)digested_dict + bi = io.BytesIO() + with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT.as_digested_dict) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with ZstdFile(bi, zstd_dict=TRAINED_DICT.as_undigested_dict) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + def test_file_prefix(self): + bi = io.BytesIO() + with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT.as_prefix) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with ZstdFile(bi, zstd_dict=TRAINED_DICT.as_prefix) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + def test_UnsupportedOperation(self): + # 1 + with ZstdFile(io.BytesIO(), 'r') as f: + with self.assertRaises(io.UnsupportedOperation): + f.write(b'1234') + + # 2 + class T: + def read(self, size): + return b'a' * size + + with self.assertRaises(TypeError): # on creation + with ZstdFile(T(), 'w') as f: + pass + + # 3 + with ZstdFile(io.BytesIO(), 'w') as f: + with self.assertRaises(io.UnsupportedOperation): + f.read(100) + with self.assertRaises(io.UnsupportedOperation): + f.seek(100) + self.assertEqual(f.closed, True) + with self.assertRaises(ValueError): + f.readable() + with self.assertRaises(ValueError): + f.tell() + with self.assertRaises(ValueError): + f.read(100) + + def test_read_readinto_readinto1(self): + lst = [] + with ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE*5)) as f: + while True: + method = random.randint(0, 2) + size = random.randint(0, 300) + + if method == 0: + dat = f.read(size) + if not dat and size: + break + lst.append(dat) + elif method == 1: + ba = bytearray(size) + read_size = f.readinto(ba) + if read_size == 0 and size: + break + lst.append(bytes(ba[:read_size])) + elif method == 2: + ba = bytearray(size) + read_size = f.readinto1(ba) + if read_size == 0 and size: + break + lst.append(bytes(ba[:read_size])) + self.assertEqual(b''.join(lst), THIS_FILE_BYTES*5) + + def test_zstdfile_flush(self): + # closed + f = ZstdFile(io.BytesIO(), 'w') + f.close() + with self.assertRaises(ValueError): + f.flush() + + # read + with ZstdFile(io.BytesIO(), 'r') as f: + # does nothing for read-only stream + f.flush() + + # write + DAT = b'abcd' + bi = io.BytesIO() + with ZstdFile(bi, 'w') as f: + self.assertEqual(f.write(DAT), len(DAT)) + self.assertEqual(f.tell(), len(DAT)) + self.assertEqual(bi.tell(), 0) # not enough for a block + + self.assertEqual(f.flush(), None) + self.assertEqual(f.tell(), len(DAT)) + self.assertGreater(bi.tell(), 0) # flushed + + # write, no .flush() method + class C: + def write(self, b): + return len(b) + with ZstdFile(C(), 'w') as f: + self.assertEqual(f.write(DAT), len(DAT)) + self.assertEqual(f.tell(), len(DAT)) + + self.assertEqual(f.flush(), None) + self.assertEqual(f.tell(), len(DAT)) + + def test_zstdfile_flush_mode(self): + self.assertEqual(ZstdFile.FLUSH_BLOCK, ZstdCompressor.FLUSH_BLOCK) + self.assertEqual(ZstdFile.FLUSH_FRAME, ZstdCompressor.FLUSH_FRAME) + with self.assertRaises(AttributeError): + ZstdFile.CONTINUE + + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + # flush block + self.assertEqual(f.write(b'123'), 3) + self.assertIsNone(f.flush(f.FLUSH_BLOCK)) + p1 = bo.tell() + # mode == .last_mode, should return + self.assertIsNone(f.flush()) + p2 = bo.tell() + self.assertEqual(p1, p2) + # flush frame + self.assertEqual(f.write(b'456'), 3) + self.assertIsNone(f.flush(mode=f.FLUSH_FRAME)) + # flush frame + self.assertEqual(f.write(b'789'), 3) + self.assertIsNone(f.flush(f.FLUSH_FRAME)) + p1 = bo.tell() + # mode == .last_mode, should return + self.assertIsNone(f.flush(f.FLUSH_FRAME)) + p2 = bo.tell() + self.assertEqual(p1, p2) + self.assertEqual(decompress(bo.getvalue()), b'123456789') + + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.write(b'123') + with self.assertRaisesRegex(ValueError, r'\.FLUSH_.*?\.FLUSH_'): + f.flush(ZstdCompressor.CONTINUE) + with self.assertRaises(ValueError): + f.flush(-1) + with self.assertRaises(ValueError): + f.flush(123456) + with self.assertRaises(TypeError): + f.flush(node=ZstdCompressor.CONTINUE) + with self.assertRaises((TypeError, ValueError)): + f.flush('FLUSH_FRAME') + with self.assertRaises(TypeError): + f.flush(b'456', f.FLUSH_BLOCK) + + def test_zstdfile_truncate(self): + with ZstdFile(io.BytesIO(), 'w') as f: + with self.assertRaises(io.UnsupportedOperation): + f.truncate(200) + + def test_zstdfile_iter_issue45475(self): + lines = [l for l in ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE))] + self.assertGreater(len(lines), 0) + + def test_append_new_file(self): + with tempfile.NamedTemporaryFile(delete=True) as tmp_f: + filename = tmp_f.name + + with ZstdFile(filename, 'a') as f: + pass + self.assertTrue(os.path.isfile(filename)) + + os.remove(filename) + +class OpenTestCase(unittest.TestCase): + + def test_binary_modes(self): + with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb") as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + with io.BytesIO() as bio: + with open(bio, "wb") as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + file_data = decompress(bio.getvalue()) + self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB) + with open(bio, "ab") as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + file_data = decompress(bio.getvalue()) + self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB * 2) + + def test_text_modes(self): + # empty input + with self.assertRaises(EOFError): + with open(io.BytesIO(b''), "rt", encoding="utf-8", newline='\n') as reader: + for _ in reader: + pass + + # read + uncompressed = THIS_FILE_STR.replace(os.linesep, "\n") + with open(io.BytesIO(COMPRESSED_THIS_FILE), "rt", encoding="utf-8") as f: + self.assertEqual(f.read(), uncompressed) + + with io.BytesIO() as bio: + # write + with open(bio, "wt", encoding="utf-8") as f: + f.write(uncompressed) + file_data = decompress(bio.getvalue()).decode("utf-8") + self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed) + # append + with open(bio, "at", encoding="utf-8") as f: + f.write(uncompressed) + file_data = decompress(bio.getvalue()).decode("utf-8") + self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed * 2) + + def test_bad_params(self): + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + TESTFN = pathlib.Path(tmp_f.name) + + with self.assertRaises(ValueError): + open(TESTFN, "") + with self.assertRaises(ValueError): + open(TESTFN, "rbt") + with self.assertRaises(ValueError): + open(TESTFN, "rb", encoding="utf-8") + with self.assertRaises(ValueError): + open(TESTFN, "rb", errors="ignore") + with self.assertRaises(ValueError): + open(TESTFN, "rb", newline="\n") + + os.remove(TESTFN) + + def test_option(self): + options = {DecompressionParameter.window_log_max:25} + with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb", options=options) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + + options = {CompressionParameter.compression_level:12} + with io.BytesIO() as bio: + with open(bio, "wb", options=options) as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + file_data = decompress(bio.getvalue()) + self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB) + + def test_encoding(self): + uncompressed = THIS_FILE_STR.replace(os.linesep, "\n") + + with io.BytesIO() as bio: + with open(bio, "wt", encoding="utf-16-le") as f: + f.write(uncompressed) + file_data = decompress(bio.getvalue()).decode("utf-16-le") + self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed) + bio.seek(0) + with open(bio, "rt", encoding="utf-16-le") as f: + self.assertEqual(f.read().replace(os.linesep, "\n"), uncompressed) + + def test_encoding_error_handler(self): + with io.BytesIO(compress(b"foo\xffbar")) as bio: + with open(bio, "rt", encoding="ascii", errors="ignore") as f: + self.assertEqual(f.read(), "foobar") + + def test_newline(self): + # Test with explicit newline (universal newline mode disabled). + text = THIS_FILE_STR.replace(os.linesep, "\n") + with io.BytesIO() as bio: + with open(bio, "wt", encoding="utf-8", newline="\n") as f: + f.write(text) + bio.seek(0) + with open(bio, "rt", encoding="utf-8", newline="\r") as f: + self.assertEqual(f.readlines(), [text]) + + def test_x_mode(self): + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + TESTFN = pathlib.Path(tmp_f.name) + + for mode in ("x", "xb", "xt"): + os.remove(TESTFN) + + if mode == "xt": + encoding = "utf-8" + else: + encoding = None + with open(TESTFN, mode, encoding=encoding): + pass + with self.assertRaises(FileExistsError): + with open(TESTFN, mode): + pass + + os.remove(TESTFN) + + def test_open_dict(self): + # default + bi = io.BytesIO() + with open(bi, 'w', zstd_dict=TRAINED_DICT) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with open(bi, zstd_dict=TRAINED_DICT) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + # .as_(un)digested_dict + bi = io.BytesIO() + with open(bi, 'w', zstd_dict=TRAINED_DICT.as_digested_dict) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with open(bi, zstd_dict=TRAINED_DICT.as_undigested_dict) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + # invalid dictionary + bi = io.BytesIO() + with self.assertRaisesRegex(TypeError, 'zstd_dict'): + open(bi, 'w', zstd_dict={1:2, 2:3}) + + with self.assertRaisesRegex(TypeError, 'zstd_dict'): + open(bi, 'w', zstd_dict=b'1234567890') + + def test_open_prefix(self): + bi = io.BytesIO() + with open(bi, 'w', zstd_dict=TRAINED_DICT.as_prefix) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with open(bi, zstd_dict=TRAINED_DICT.as_prefix) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + def test_buffer_protocol(self): + # don't use len() for buffer protocol objects + arr = array.array("i", range(1000)) + LENGTH = len(arr) * arr.itemsize + + with open(io.BytesIO(), "wb") as f: + self.assertEqual(f.write(arr), LENGTH) + self.assertEqual(f.tell(), LENGTH) + +class FreeThreadingMethodTests(unittest.TestCase): + + @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled') + @threading_helper.reap_threads + @threading_helper.requires_working_threading() + def test_compress_locking(self): + input = b'a'* (16*_1K) + num_threads = 8 + + comp = ZstdCompressor() + parts = [] + for _ in range(num_threads): + res = comp.compress(input, ZstdCompressor.FLUSH_BLOCK) + if res: + parts.append(res) + rest1 = comp.flush() + expected = b''.join(parts) + rest1 + + comp = ZstdCompressor() + output = [] + def run_method(method, input_data, output_data): + res = method(input_data, ZstdCompressor.FLUSH_BLOCK) + if res: + output_data.append(res) + threads = [] + + for i in range(num_threads): + thread = threading.Thread(target=run_method, args=(comp.compress, input, output)) + + threads.append(thread) + + with threading_helper.start_threads(threads): + pass + + rest2 = comp.flush() + self.assertEqual(rest1, rest2) + actual = b''.join(output) + rest2 + self.assertEqual(expected, actual) + + @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled') + @threading_helper.reap_threads + @threading_helper.requires_working_threading() + def test_decompress_locking(self): + input = compress(b'a'* (16*_1K)) + num_threads = 8 + # to ensure we decompress over multiple calls, set maxsize + window_size = _1K * 16//num_threads + + decomp = ZstdDecompressor() + parts = [] + for _ in range(num_threads): + res = decomp.decompress(input, window_size) + if res: + parts.append(res) + expected = b''.join(parts) + + comp = ZstdDecompressor() + output = [] + def run_method(method, input_data, output_data): + res = method(input_data, window_size) + if res: + output_data.append(res) + threads = [] + + for i in range(num_threads): + thread = threading.Thread(target=run_method, args=(comp.decompress, input, output)) + + threads.append(thread) + + with threading_helper.start_threads(threads): + pass + + actual = b''.join(output) + self.assertEqual(expected, actual) + + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/tokenize.py b/Lib/tokenize.py index 117b485b934..8d01fd7bce4 100644 --- a/Lib/tokenize.py +++ b/Lib/tokenize.py @@ -518,7 +518,7 @@ def _main(args=None): sys.exit(1) # Parse the arguments and options - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(color=True) parser.add_argument(dest='filename', nargs='?', metavar='filename.py', help='the file to tokenize; defaults to stdin') diff --git a/Lib/trace.py b/Lib/trace.py index a87bc6d61a8..cf8817f4383 100644 --- a/Lib/trace.py +++ b/Lib/trace.py @@ -604,7 +604,7 @@ class Trace: def main(): import argparse - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(color=True) parser.add_argument('--version', action='version', version='trace 2.0') grp = parser.add_argument_group('Main options', diff --git a/Lib/traceback.py b/Lib/traceback.py index 16ba7fc2ee8..17b082eced6 100644 --- a/Lib/traceback.py +++ b/Lib/traceback.py @@ -10,9 +10,9 @@ import codeop import keyword import tokenize import io -from contextlib import suppress import _colorize -from _colorize import ANSIColors + +from contextlib import suppress __all__ = ['extract_stack', 'extract_tb', 'format_exception', 'format_exception_only', 'format_list', 'format_stack', @@ -187,15 +187,13 @@ def _format_final_exc_line(etype, value, *, insert_final_newline=True, colorize= valuestr = _safe_string(value, 'exception') end_char = "\n" if insert_final_newline else "" if colorize: - if value is None or not valuestr: - line = f"{ANSIColors.BOLD_MAGENTA}{etype}{ANSIColors.RESET}{end_char}" - else: - line = f"{ANSIColors.BOLD_MAGENTA}{etype}{ANSIColors.RESET}: {ANSIColors.MAGENTA}{valuestr}{ANSIColors.RESET}{end_char}" + theme = _colorize.get_theme(force_color=True).traceback else: - if value is None or not valuestr: - line = f"{etype}{end_char}" - else: - line = f"{etype}: {valuestr}{end_char}" + theme = _colorize.get_theme(force_no_color=True).traceback + if value is None or not valuestr: + line = f"{theme.type}{etype}{theme.reset}{end_char}" + else: + line = f"{theme.type}{etype}{theme.reset}: {theme.message}{valuestr}{theme.reset}{end_char}" return line @@ -539,21 +537,22 @@ class StackSummary(list): if frame_summary.filename.startswith("<stdin>-"): filename = "<stdin>" if colorize: - row.append(' File {}"{}"{}, line {}{}{}, in {}{}{}\n'.format( - ANSIColors.MAGENTA, - filename, - ANSIColors.RESET, - ANSIColors.MAGENTA, - frame_summary.lineno, - ANSIColors.RESET, - ANSIColors.MAGENTA, - frame_summary.name, - ANSIColors.RESET, - ) - ) + theme = _colorize.get_theme(force_color=True).traceback else: - row.append(' File "{}", line {}, in {}\n'.format( - filename, frame_summary.lineno, frame_summary.name)) + theme = _colorize.get_theme(force_no_color=True).traceback + row.append( + ' File {}"{}"{}, line {}{}{}, in {}{}{}\n'.format( + theme.filename, + filename, + theme.reset, + theme.line_no, + frame_summary.lineno, + theme.reset, + theme.frame, + frame_summary.name, + theme.reset, + ) + ) if frame_summary._dedented_lines and frame_summary._dedented_lines.strip(): if ( frame_summary.colno is None or @@ -672,11 +671,11 @@ class StackSummary(list): for color, group in itertools.groupby(itertools.zip_longest(line, carets, fillvalue=""), key=lambda x: x[1]): caret_group = list(group) if color == "^": - colorized_line_parts.append(ANSIColors.BOLD_RED + "".join(char for char, _ in caret_group) + ANSIColors.RESET) - colorized_carets_parts.append(ANSIColors.BOLD_RED + "".join(caret for _, caret in caret_group) + ANSIColors.RESET) + colorized_line_parts.append(theme.error_highlight + "".join(char for char, _ in caret_group) + theme.reset) + colorized_carets_parts.append(theme.error_highlight + "".join(caret for _, caret in caret_group) + theme.reset) elif color == "~": - colorized_line_parts.append(ANSIColors.RED + "".join(char for char, _ in caret_group) + ANSIColors.RESET) - colorized_carets_parts.append(ANSIColors.RED + "".join(caret for _, caret in caret_group) + ANSIColors.RESET) + colorized_line_parts.append(theme.error_range + "".join(char for char, _ in caret_group) + theme.reset) + colorized_carets_parts.append(theme.error_range + "".join(caret for _, caret in caret_group) + theme.reset) else: colorized_line_parts.append("".join(char for char, _ in caret_group)) colorized_carets_parts.append("".join(caret for _, caret in caret_group)) @@ -1378,20 +1377,20 @@ class TracebackException: """Format SyntaxError exceptions (internal helper).""" # Show exactly where the problem was found. colorize = kwargs.get("colorize", False) + if colorize: + theme = _colorize.get_theme(force_color=True).traceback + else: + theme = _colorize.get_theme(force_no_color=True).traceback filename_suffix = '' if self.lineno is not None: - if colorize: - yield ' File {}"{}"{}, line {}{}{}\n'.format( - ANSIColors.MAGENTA, - self.filename or "<string>", - ANSIColors.RESET, - ANSIColors.MAGENTA, - self.lineno, - ANSIColors.RESET, - ) - else: - yield ' File "{}", line {}\n'.format( - self.filename or "<string>", self.lineno) + yield ' File {}"{}"{}, line {}{}{}\n'.format( + theme.filename, + self.filename or "<string>", + theme.reset, + theme.line_no, + self.lineno, + theme.reset, + ) elif self.filename is not None: filename_suffix = ' ({})'.format(self.filename) @@ -1441,11 +1440,11 @@ class TracebackException: # colorize from colno to end_colno ltext = ( ltext[:colno] + - ANSIColors.BOLD_RED + ltext[colno:end_colno] + ANSIColors.RESET + + theme.error_highlight + ltext[colno:end_colno] + theme.reset + ltext[end_colno:] ) - start_color = ANSIColors.BOLD_RED - end_color = ANSIColors.RESET + start_color = theme.error_highlight + end_color = theme.reset yield ' {}\n'.format(ltext) yield ' {}{}{}{}\n'.format( "".join(caretspace), @@ -1456,17 +1455,15 @@ class TracebackException: else: yield ' {}\n'.format(ltext) msg = self.msg or "<no detail available>" - if colorize: - yield "{}{}{}: {}{}{}{}\n".format( - ANSIColors.BOLD_MAGENTA, - stype, - ANSIColors.RESET, - ANSIColors.MAGENTA, - msg, - ANSIColors.RESET, - filename_suffix) - else: - yield "{}: {}{}\n".format(stype, msg, filename_suffix) + yield "{}{}{}: {}{}{}{}\n".format( + theme.type, + stype, + theme.reset, + theme.message, + msg, + theme.reset, + filename_suffix, + ) def format(self, *, chain=True, _ctx=None, **kwargs): """Format the exception. diff --git a/Lib/typing.py b/Lib/typing.py index e019c597580..2baf655256d 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -3477,7 +3477,7 @@ class IO(Generic[AnyStr]): pass @abstractmethod - def readlines(self, hint: int = -1) -> List[AnyStr]: + def readlines(self, hint: int = -1) -> list[AnyStr]: pass @abstractmethod @@ -3493,7 +3493,7 @@ class IO(Generic[AnyStr]): pass @abstractmethod - def truncate(self, size: int = None) -> int: + def truncate(self, size: int | None = None) -> int: pass @abstractmethod @@ -3505,11 +3505,11 @@ class IO(Generic[AnyStr]): pass @abstractmethod - def writelines(self, lines: List[AnyStr]) -> None: + def writelines(self, lines: list[AnyStr]) -> None: pass @abstractmethod - def __enter__(self) -> 'IO[AnyStr]': + def __enter__(self) -> IO[AnyStr]: pass @abstractmethod @@ -3523,11 +3523,11 @@ class BinaryIO(IO[bytes]): __slots__ = () @abstractmethod - def write(self, s: Union[bytes, bytearray]) -> int: + def write(self, s: bytes | bytearray) -> int: pass @abstractmethod - def __enter__(self) -> 'BinaryIO': + def __enter__(self) -> BinaryIO: pass @@ -3548,7 +3548,7 @@ class TextIO(IO[str]): @property @abstractmethod - def errors(self) -> Optional[str]: + def errors(self) -> str | None: pass @property @@ -3562,7 +3562,7 @@ class TextIO(IO[str]): pass @abstractmethod - def __enter__(self) -> 'TextIO': + def __enter__(self) -> TextIO: pass diff --git a/Lib/unittest/main.py b/Lib/unittest/main.py index c3869de3f6f..6fd949581f3 100644 --- a/Lib/unittest/main.py +++ b/Lib/unittest/main.py @@ -197,7 +197,7 @@ class TestProgram(object): return parser def _getMainArgParser(self, parent): - parser = argparse.ArgumentParser(parents=[parent]) + parser = argparse.ArgumentParser(parents=[parent], color=True) parser.prog = self.progName parser.print_help = self._print_help @@ -208,7 +208,7 @@ class TestProgram(object): return parser def _getDiscoveryArgParser(self, parent): - parser = argparse.ArgumentParser(parents=[parent]) + parser = argparse.ArgumentParser(parents=[parent], color=True) parser.prog = '%s discover' % self.progName parser.epilog = ('For test discovery all test modules must be ' 'importable from the top level directory of the ' diff --git a/Lib/unittest/runner.py b/Lib/unittest/runner.py index eb0234a2617..5f22d91aebd 100644 --- a/Lib/unittest/runner.py +++ b/Lib/unittest/runner.py @@ -4,7 +4,7 @@ import sys import time import warnings -from _colorize import get_colors +from _colorize import get_theme from . import result from .case import _SubTest @@ -45,7 +45,7 @@ class TextTestResult(result.TestResult): self.showAll = verbosity > 1 self.dots = verbosity == 1 self.descriptions = descriptions - self._ansi = get_colors(file=stream) + self._theme = get_theme(tty_file=stream).unittest self._newline = True self.durations = durations @@ -79,101 +79,99 @@ class TextTestResult(result.TestResult): def addSubTest(self, test, subtest, err): if err is not None: - red, reset = self._ansi.RED, self._ansi.RESET + t = self._theme if self.showAll: if issubclass(err[0], subtest.failureException): - self._write_status(subtest, f"{red}FAIL{reset}") + self._write_status(subtest, f"{t.fail}FAIL{t.reset}") else: - self._write_status(subtest, f"{red}ERROR{reset}") + self._write_status(subtest, f"{t.fail}ERROR{t.reset}") elif self.dots: if issubclass(err[0], subtest.failureException): - self.stream.write(f"{red}F{reset}") + self.stream.write(f"{t.fail}F{t.reset}") else: - self.stream.write(f"{red}E{reset}") + self.stream.write(f"{t.fail}E{t.reset}") self.stream.flush() super(TextTestResult, self).addSubTest(test, subtest, err) def addSuccess(self, test): super(TextTestResult, self).addSuccess(test) - green, reset = self._ansi.GREEN, self._ansi.RESET + t = self._theme if self.showAll: - self._write_status(test, f"{green}ok{reset}") + self._write_status(test, f"{t.passed}ok{t.reset}") elif self.dots: - self.stream.write(f"{green}.{reset}") + self.stream.write(f"{t.passed}.{t.reset}") self.stream.flush() def addError(self, test, err): super(TextTestResult, self).addError(test, err) - red, reset = self._ansi.RED, self._ansi.RESET + t = self._theme if self.showAll: - self._write_status(test, f"{red}ERROR{reset}") + self._write_status(test, f"{t.fail}ERROR{t.reset}") elif self.dots: - self.stream.write(f"{red}E{reset}") + self.stream.write(f"{t.fail}E{t.reset}") self.stream.flush() def addFailure(self, test, err): super(TextTestResult, self).addFailure(test, err) - red, reset = self._ansi.RED, self._ansi.RESET + t = self._theme if self.showAll: - self._write_status(test, f"{red}FAIL{reset}") + self._write_status(test, f"{t.fail}FAIL{t.reset}") elif self.dots: - self.stream.write(f"{red}F{reset}") + self.stream.write(f"{t.fail}F{t.reset}") self.stream.flush() def addSkip(self, test, reason): super(TextTestResult, self).addSkip(test, reason) - yellow, reset = self._ansi.YELLOW, self._ansi.RESET + t = self._theme if self.showAll: - self._write_status(test, f"{yellow}skipped{reset} {reason!r}") + self._write_status(test, f"{t.warn}skipped{t.reset} {reason!r}") elif self.dots: - self.stream.write(f"{yellow}s{reset}") + self.stream.write(f"{t.warn}s{t.reset}") self.stream.flush() def addExpectedFailure(self, test, err): super(TextTestResult, self).addExpectedFailure(test, err) - yellow, reset = self._ansi.YELLOW, self._ansi.RESET + t = self._theme if self.showAll: - self.stream.writeln(f"{yellow}expected failure{reset}") + self.stream.writeln(f"{t.warn}expected failure{t.reset}") self.stream.flush() elif self.dots: - self.stream.write(f"{yellow}x{reset}") + self.stream.write(f"{t.warn}x{t.reset}") self.stream.flush() def addUnexpectedSuccess(self, test): super(TextTestResult, self).addUnexpectedSuccess(test) - red, reset = self._ansi.RED, self._ansi.RESET + t = self._theme if self.showAll: - self.stream.writeln(f"{red}unexpected success{reset}") + self.stream.writeln(f"{t.fail}unexpected success{t.reset}") self.stream.flush() elif self.dots: - self.stream.write(f"{red}u{reset}") + self.stream.write(f"{t.fail}u{t.reset}") self.stream.flush() def printErrors(self): - bold_red = self._ansi.BOLD_RED - red = self._ansi.RED - reset = self._ansi.RESET + t = self._theme if self.dots or self.showAll: self.stream.writeln() self.stream.flush() - self.printErrorList(f"{red}ERROR{reset}", self.errors) - self.printErrorList(f"{red}FAIL{reset}", self.failures) + self.printErrorList(f"{t.fail}ERROR{t.reset}", self.errors) + self.printErrorList(f"{t.fail}FAIL{t.reset}", self.failures) unexpectedSuccesses = getattr(self, "unexpectedSuccesses", ()) if unexpectedSuccesses: self.stream.writeln(self.separator1) for test in unexpectedSuccesses: self.stream.writeln( - f"{red}UNEXPECTED SUCCESS{bold_red}: " - f"{self.getDescription(test)}{reset}" + f"{t.fail}UNEXPECTED SUCCESS{t.fail_info}: " + f"{self.getDescription(test)}{t.reset}" ) self.stream.flush() def printErrorList(self, flavour, errors): - bold_red, reset = self._ansi.BOLD_RED, self._ansi.RESET + t = self._theme for test, err in errors: self.stream.writeln(self.separator1) self.stream.writeln( - f"{flavour}{bold_red}: {self.getDescription(test)}{reset}" + f"{flavour}{t.fail_info}: {self.getDescription(test)}{t.reset}" ) self.stream.writeln(self.separator2) self.stream.writeln("%s" % err) @@ -286,31 +284,26 @@ class TextTestRunner(object): expected_fails, unexpected_successes, skipped = results infos = [] - ansi = get_colors(file=self.stream) - bold_red = ansi.BOLD_RED - green = ansi.GREEN - red = ansi.RED - reset = ansi.RESET - yellow = ansi.YELLOW + t = get_theme(tty_file=self.stream).unittest if not result.wasSuccessful(): - self.stream.write(f"{bold_red}FAILED{reset}") + self.stream.write(f"{t.fail_info}FAILED{t.reset}") failed, errored = len(result.failures), len(result.errors) if failed: - infos.append(f"{bold_red}failures={failed}{reset}") + infos.append(f"{t.fail_info}failures={failed}{t.reset}") if errored: - infos.append(f"{bold_red}errors={errored}{reset}") + infos.append(f"{t.fail_info}errors={errored}{t.reset}") elif run == 0 and not skipped: - self.stream.write(f"{yellow}NO TESTS RAN{reset}") + self.stream.write(f"{t.warn}NO TESTS RAN{t.reset}") else: - self.stream.write(f"{green}OK{reset}") + self.stream.write(f"{t.passed}OK{t.reset}") if skipped: - infos.append(f"{yellow}skipped={skipped}{reset}") + infos.append(f"{t.warn}skipped={skipped}{t.reset}") if expected_fails: - infos.append(f"{yellow}expected failures={expected_fails}{reset}") + infos.append(f"{t.warn}expected failures={expected_fails}{t.reset}") if unexpected_successes: infos.append( - f"{red}unexpected successes={unexpected_successes}{reset}" + f"{t.fail}unexpected successes={unexpected_successes}{t.reset}" ) if infos: self.stream.writeln(" (%s)" % (", ".join(infos),)) diff --git a/Lib/urllib/request.py b/Lib/urllib/request.py index 9a6b29a90a2..41dc5d7b35d 100644 --- a/Lib/urllib/request.py +++ b/Lib/urllib/request.py @@ -1466,7 +1466,7 @@ class FileHandler(BaseHandler): def open_local_file(self, req): import email.utils import mimetypes - localfile = url2pathname(req.full_url, require_scheme=True) + localfile = url2pathname(req.full_url, require_scheme=True, resolve_host=True) try: stats = os.stat(localfile) size = stats.st_size @@ -1482,7 +1482,7 @@ class FileHandler(BaseHandler): file_open = open_local_file -def _is_local_authority(authority): +def _is_local_authority(authority, resolve): # Compare hostnames if not authority or authority == 'localhost': return True @@ -1494,9 +1494,11 @@ def _is_local_authority(authority): if authority == hostname: return True # Compare IP addresses + if not resolve: + return False try: address = socket.gethostbyname(authority) - except (socket.gaierror, AttributeError): + except (socket.gaierror, AttributeError, UnicodeEncodeError): return False return address in FileHandler().get_names() @@ -1641,13 +1643,16 @@ class DataHandler(BaseHandler): return addinfourl(io.BytesIO(data), headers, url) -# Code move from the old urllib module +# Code moved from the old urllib module -def url2pathname(url, *, require_scheme=False): +def url2pathname(url, *, require_scheme=False, resolve_host=False): """Convert the given file URL to a local file system path. The 'file:' scheme prefix must be omitted unless *require_scheme* is set to true. + + The URL authority may be resolved with gethostbyname() if + *resolve_host* is set to true. """ if require_scheme: scheme, url = _splittype(url) @@ -1655,7 +1660,7 @@ def url2pathname(url, *, require_scheme=False): raise URLError("URL is missing a 'file:' scheme") authority, url = _splithost(url) if os.name == 'nt': - if not _is_local_authority(authority): + if not _is_local_authority(authority, resolve_host): # e.g. file://server/share/file.txt url = '//' + authority + url elif url[:3] == '///': @@ -1669,7 +1674,7 @@ def url2pathname(url, *, require_scheme=False): # Older URLs use a pipe after a drive letter url = url[:1] + ':' + url[2:] url = url.replace('/', '\\') - elif not _is_local_authority(authority): + elif not _is_local_authority(authority, resolve_host): raise URLError("file:// scheme is supported only on localhost") encoding = sys.getfilesystemencoding() errors = sys.getfilesystemencodeerrors() diff --git a/Lib/uuid.py b/Lib/uuid.py index 2c16c3f0f5a..036ffebf67a 100644 --- a/Lib/uuid.py +++ b/Lib/uuid.py @@ -949,7 +949,9 @@ def main(): import argparse parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, - description="Generate a UUID using the selected UUID function.") + description="Generate a UUID using the selected UUID function.", + color=True, + ) parser.add_argument("-u", "--uuid", choices=uuid_funcs.keys(), default="uuid4", diff --git a/Lib/venv/__init__.py b/Lib/venv/__init__.py index dc4c9ef3531..15e15b7a518 100644 --- a/Lib/venv/__init__.py +++ b/Lib/venv/__init__.py @@ -624,7 +624,9 @@ def main(args=None): 'created, you may wish to ' 'activate it, e.g. by ' 'sourcing an activate script ' - 'in its bin directory.') + 'in its bin directory.', + color=True, + ) parser.add_argument('dirs', metavar='ENV_DIR', nargs='+', help='A directory to create the environment in.') parser.add_argument('--system-site-packages', default=False, diff --git a/Lib/webbrowser.py b/Lib/webbrowser.py index ab50ec1ee95..f2e2394089d 100644 --- a/Lib/webbrowser.py +++ b/Lib/webbrowser.py @@ -719,7 +719,9 @@ if sys.platform == "ios": def parse_args(arg_list: list[str] | None): import argparse - parser = argparse.ArgumentParser(description="Open URL in a web browser.") + parser = argparse.ArgumentParser( + description="Open URL in a web browser.", color=True, + ) parser.add_argument("url", help="URL to open") group = parser.add_mutually_exclusive_group() diff --git a/Lib/zipapp.py b/Lib/zipapp.py index 59b444075a6..7a4ef96ea0f 100644 --- a/Lib/zipapp.py +++ b/Lib/zipapp.py @@ -187,7 +187,7 @@ def main(args=None): """ import argparse - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(color=True) parser.add_argument('--output', '-o', default=None, help="The name of the output archive. " "Required if SOURCE is an archive.") diff --git a/Lib/zipfile/__init__.py b/Lib/zipfile/__init__.py index b7840d0f945..88356abe8cb 100644 --- a/Lib/zipfile/__init__.py +++ b/Lib/zipfile/__init__.py @@ -31,6 +31,11 @@ try: except ImportError: lzma = None +try: + from compression import zstd # We may need its compression method +except ImportError: + zstd = None + __all__ = ["BadZipFile", "BadZipfile", "error", "ZIP_STORED", "ZIP_DEFLATED", "ZIP_BZIP2", "ZIP_LZMA", "is_zipfile", "ZipInfo", "ZipFile", "PyZipFile", "LargeZipFile", @@ -58,12 +63,14 @@ ZIP_STORED = 0 ZIP_DEFLATED = 8 ZIP_BZIP2 = 12 ZIP_LZMA = 14 +ZIP_ZSTANDARD = 93 # Other ZIP compression methods not supported DEFAULT_VERSION = 20 ZIP64_VERSION = 45 BZIP2_VERSION = 46 LZMA_VERSION = 63 +ZSTANDARD_VERSION = 63 # we recognize (but not necessarily support) all features up to that version MAX_EXTRACT_VERSION = 63 @@ -505,6 +512,8 @@ class ZipInfo: min_version = max(BZIP2_VERSION, min_version) elif self.compress_type == ZIP_LZMA: min_version = max(LZMA_VERSION, min_version) + elif self.compress_type == ZIP_ZSTANDARD: + min_version = max(ZSTANDARD_VERSION, min_version) self.extract_version = max(min_version, self.extract_version) self.create_version = max(min_version, self.create_version) @@ -766,6 +775,7 @@ compressor_names = { 14: 'lzma', 18: 'terse', 19: 'lz77', + 93: 'zstd', 97: 'wavpack', 98: 'ppmd', } @@ -785,6 +795,10 @@ def _check_compression(compression): if not lzma: raise RuntimeError( "Compression requires the (missing) lzma module") + elif compression == ZIP_ZSTANDARD: + if not zstd: + raise RuntimeError( + "Compression requires the (missing) compression.zstd module") else: raise NotImplementedError("That compression method is not supported") @@ -798,9 +812,11 @@ def _get_compressor(compress_type, compresslevel=None): if compresslevel is not None: return bz2.BZ2Compressor(compresslevel) return bz2.BZ2Compressor() - # compresslevel is ignored for ZIP_LZMA + # compresslevel is ignored for ZIP_LZMA and ZIP_ZSTANDARD elif compress_type == ZIP_LZMA: return LZMACompressor() + elif compress_type == ZIP_ZSTANDARD: + return zstd.ZstdCompressor() else: return None @@ -815,6 +831,8 @@ def _get_decompressor(compress_type): return bz2.BZ2Decompressor() elif compress_type == ZIP_LZMA: return LZMADecompressor() + elif compress_type == ZIP_ZSTANDARD: + return zstd.ZstdDecompressor() else: descr = compressor_names.get(compress_type) if descr: @@ -2317,7 +2335,7 @@ def main(args=None): import argparse description = 'A simple command-line interface for zipfile module.' - parser = argparse.ArgumentParser(description=description) + parser = argparse.ArgumentParser(description=description, color=True) group = parser.add_mutually_exclusive_group(required=True) group.add_argument('-l', '--list', metavar='<zipfile>', help='Show listing of a zipfile') |