diff options
Diffstat (limited to 'Tools/jit')
-rw-r--r-- | Tools/jit/_optimizers.py | 319 | ||||
-rw-r--r-- | Tools/jit/_stencils.py | 67 | ||||
-rw-r--r-- | Tools/jit/_targets.py | 58 |
3 files changed, 350 insertions, 94 deletions
diff --git a/Tools/jit/_optimizers.py b/Tools/jit/_optimizers.py new file mode 100644 index 00000000000..1077e4106fd --- /dev/null +++ b/Tools/jit/_optimizers.py @@ -0,0 +1,319 @@ +"""Low-level optimization of textual assembly.""" + +import dataclasses +import pathlib +import re +import typing + +# Same as saying "not string.startswith('')": +_RE_NEVER_MATCH = re.compile(r"(?!)") +# Dictionary mapping branch instructions to their inverted branch instructions. +# If a branch cannot be inverted, the value is None: +_X86_BRANCHES = { + # https://www.felixcloutier.com/x86/jcc + "ja": "jna", + "jae": "jnae", + "jb": "jnb", + "jbe": "jnbe", + "jc": "jnc", + "jcxz": None, + "je": "jne", + "jecxz": None, + "jg": "jng", + "jge": "jnge", + "jl": "jnl", + "jle": "jnle", + "jo": "jno", + "jp": "jnp", + "jpe": "jpo", + "jrcxz": None, + "js": "jns", + "jz": "jnz", + # https://www.felixcloutier.com/x86/loop:loopcc + "loop": None, + "loope": None, + "loopne": None, + "loopnz": None, + "loopz": None, +} +# Update with all of the inverted branches, too: +_X86_BRANCHES |= {v: k for k, v in _X86_BRANCHES.items() if v} + + +@dataclasses.dataclass +class _Block: + label: str | None = None + # Non-instruction lines like labels, directives, and comments: + noninstructions: list[str] = dataclasses.field(default_factory=list) + # Instruction lines: + instructions: list[str] = dataclasses.field(default_factory=list) + # If this block ends in a jump, where to? + target: typing.Self | None = None + # The next block in the linked list: + link: typing.Self | None = None + # Whether control flow can fall through to the linked block above: + fallthrough: bool = True + # Whether this block can eventually reach the next uop (_JIT_CONTINUE): + hot: bool = False + + def resolve(self) -> typing.Self: + """Find the first non-empty block reachable from this one.""" + block = self + while block.link and not block.instructions: + block = block.link + return block + + +@dataclasses.dataclass +class Optimizer: + """Several passes of analysis and optimization for textual assembly.""" + + path: pathlib.Path + _: dataclasses.KW_ONLY + # prefix used to mangle symbols on some platforms: + prefix: str = "" + # The first block in the linked list: + _root: _Block = dataclasses.field(init=False, default_factory=_Block) + _labels: dict[str, _Block] = dataclasses.field(init=False, default_factory=dict) + # No groups: + _re_noninstructions: typing.ClassVar[re.Pattern[str]] = re.compile( + r"\s*(?:\.|#|//|$)" + ) + # One group (label): + _re_label: typing.ClassVar[re.Pattern[str]] = re.compile( + r'\s*(?P<label>[\w."$?@]+):' + ) + # Override everything that follows in subclasses: + _alignment: typing.ClassVar[int] = 1 + _branches: typing.ClassVar[dict[str, str | None]] = {} + # Two groups (instruction and target): + _re_branch: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH + # One group (target): + _re_jump: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH + # No groups: + _re_return: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH + + def __post_init__(self) -> None: + # Split the code into a linked list of basic blocks. A basic block is an + # optional label, followed by zero or more non-instruction lines, + # followed by zero or more instruction lines (only the last of which may + # be a branch, jump, or return): + text = self._preprocess(self.path.read_text()) + block = self._root + for line in text.splitlines(): + # See if we need to start a new block: + if match := self._re_label.match(line): + # Label. New block: + block.link = block = self._lookup_label(match["label"]) + block.noninstructions.append(line) + continue + if self._re_noninstructions.match(line): + if block.instructions: + # Non-instruction lines. New block: + block.link = block = _Block() + block.noninstructions.append(line) + continue + if block.target or not block.fallthrough: + # Current block ends with a branch, jump, or return. New block: + block.link = block = _Block() + block.instructions.append(line) + if match := self._re_branch.match(line): + # A block ending in a branch has a target and fallthrough: + block.target = self._lookup_label(match["target"]) + assert block.fallthrough + elif match := self._re_jump.match(line): + # A block ending in a jump has a target and no fallthrough: + block.target = self._lookup_label(match["target"]) + block.fallthrough = False + elif self._re_return.match(line): + # A block ending in a return has no target and fallthrough: + assert not block.target + block.fallthrough = False + + def _preprocess(self, text: str) -> str: + # Override this method to do preprocessing of the textual assembly: + return text + + @classmethod + def _invert_branch(cls, line: str, target: str) -> str | None: + match = cls._re_branch.match(line) + assert match + inverted = cls._branches.get(match["instruction"]) + if not inverted: + return None + (a, b), (c, d) = match.span("instruction"), match.span("target") + # Before: + # je FOO + # After: + # jne BAR + return "".join([line[:a], inverted, line[b:c], target, line[d:]]) + + @classmethod + def _update_jump(cls, line: str, target: str) -> str: + match = cls._re_jump.match(line) + assert match + a, b = match.span("target") + # Before: + # jmp FOO + # After: + # jmp BAR + return "".join([line[:a], target, line[b:]]) + + def _lookup_label(self, label: str) -> _Block: + if label not in self._labels: + self._labels[label] = _Block(label) + return self._labels[label] + + def _blocks(self) -> typing.Generator[_Block, None, None]: + block: _Block | None = self._root + while block: + yield block + block = block.link + + def _body(self) -> str: + lines = [] + hot = True + for block in self._blocks(): + if hot != block.hot: + hot = block.hot + # Make it easy to tell at a glance where cold code is: + lines.append(f"# JIT: {'HOT' if hot else 'COLD'} ".ljust(80, "#")) + lines.extend(block.noninstructions) + lines.extend(block.instructions) + return "\n".join(lines) + + def _predecessors(self, block: _Block) -> typing.Generator[_Block, None, None]: + # This is inefficient, but it's never wrong: + for pre in self._blocks(): + if pre.target is block or pre.fallthrough and pre.link is block: + yield pre + + def _insert_continue_label(self) -> None: + # Find the block with the last instruction: + for end in reversed(list(self._blocks())): + if end.instructions: + break + # Before: + # jmp FOO + # After: + # jmp FOO + # .balign 8 + # _JIT_CONTINUE: + # This lets the assembler encode _JIT_CONTINUE jumps at build time! + align = _Block() + align.noninstructions.append(f"\t.balign\t{self._alignment}") + continuation = self._lookup_label(f"{self.prefix}_JIT_CONTINUE") + assert continuation.label + continuation.noninstructions.append(f"{continuation.label}:") + end.link, align.link, continuation.link = align, continuation, end.link + + def _mark_hot_blocks(self) -> None: + # Start with the last block, and perform a DFS to find all blocks that + # can eventually reach it: + todo = list(self._blocks())[-1:] + while todo: + block = todo.pop() + block.hot = True + todo.extend(pre for pre in self._predecessors(block) if not pre.hot) + + def _invert_hot_branches(self) -> None: + for branch in self._blocks(): + link = branch.link + if link is None: + continue + jump = link.resolve() + # Before: + # je HOT + # jmp COLD + # After: + # jne COLD + # jmp HOT + if ( + # block ends with a branch to hot code... + branch.target + and branch.fallthrough + and branch.target.hot + # ...followed by a jump to cold code with no other predecessors: + and jump.target + and not jump.fallthrough + and not jump.target.hot + and len(jump.instructions) == 1 + and list(self._predecessors(jump)) == [branch] + ): + assert jump.target.label + assert branch.target.label + inverted = self._invert_branch( + branch.instructions[-1], jump.target.label + ) + # Check to see if the branch can even be inverted: + if inverted is None: + continue + branch.instructions[-1] = inverted + jump.instructions[-1] = self._update_jump( + jump.instructions[-1], branch.target.label + ) + branch.target, jump.target = jump.target, branch.target + jump.hot = True + + def _remove_redundant_jumps(self) -> None: + # Zero-length jumps can be introduced by _insert_continue_label and + # _invert_hot_branches: + for block in self._blocks(): + # Before: + # jmp FOO + # FOO: + # After: + # FOO: + if ( + block.target + and block.link + and block.target.resolve() is block.link.resolve() + ): + block.target = None + block.fallthrough = True + block.instructions.pop() + + def run(self) -> None: + """Run this optimizer.""" + self._insert_continue_label() + self._mark_hot_blocks() + self._invert_hot_branches() + self._remove_redundant_jumps() + self.path.write_text(self._body()) + + +class OptimizerAArch64(Optimizer): # pylint: disable = too-few-public-methods + """aarch64-apple-darwin/aarch64-pc-windows-msvc/aarch64-unknown-linux-gnu""" + + # TODO: @diegorusso + _alignment = 8 + # https://developer.arm.com/documentation/ddi0602/2025-03/Base-Instructions/B--Branch- + _re_jump = re.compile(r"\s*b\s+(?P<target>[\w.]+)") + + +class OptimizerX86(Optimizer): # pylint: disable = too-few-public-methods + """i686-pc-windows-msvc/x86_64-apple-darwin/x86_64-unknown-linux-gnu""" + + _branches = _X86_BRANCHES + _re_branch = re.compile( + rf"\s*(?P<instruction>{'|'.join(_X86_BRANCHES)})\s+(?P<target>[\w.]+)" + ) + # https://www.felixcloutier.com/x86/jmp + _re_jump = re.compile(r"\s*jmp\s+(?P<target>[\w.]+)") + # https://www.felixcloutier.com/x86/ret + _re_return = re.compile(r"\s*ret\b") + + +class OptimizerX8664Windows(OptimizerX86): # pylint: disable = too-few-public-methods + """x86_64-pc-windows-msvc""" + + def _preprocess(self, text: str) -> str: + text = super()._preprocess(text) + # Before: + # rex64 jmpq *__imp__JIT_CONTINUE(%rip) + # After: + # jmp _JIT_CONTINUE + far_indirect_jump = ( + rf"rex64\s+jmpq\s+\*__imp_(?P<target>{self.prefix}_JIT_\w+)\(%rip\)" + ) + return re.sub(far_indirect_jump, r"jmp\t\g<target>", text) diff --git a/Tools/jit/_stencils.py b/Tools/jit/_stencils.py index 03b0ba647b0..1d82f5366f6 100644 --- a/Tools/jit/_stencils.py +++ b/Tools/jit/_stencils.py @@ -17,8 +17,6 @@ class HoleValue(enum.Enum): # The base address of the machine code for the current uop (exposed as _JIT_ENTRY): CODE = enum.auto() - # The base address of the machine code for the next uop (exposed as _JIT_CONTINUE): - CONTINUE = enum.auto() # The base address of the read-only data for this uop: DATA = enum.auto() # The address of the current executor (exposed as _JIT_EXECUTOR): @@ -97,7 +95,6 @@ _PATCH_FUNCS = { # Translate HoleValues to C expressions: _HOLE_EXPRS = { HoleValue.CODE: "(uintptr_t)code", - HoleValue.CONTINUE: "(uintptr_t)code + sizeof(code_body)", HoleValue.DATA: "(uintptr_t)data", HoleValue.EXECUTOR: "(uintptr_t)executor", # These should all have been turned into DATA values by process_relocations: @@ -209,64 +206,6 @@ class Stencil: self.disassembly.append(f"{offset:x}: {' '.join(['00'] * padding)}") self.body.extend([0] * padding) - def add_nops(self, nop: bytes, alignment: int) -> None: - """Add NOPs until there is alignment. Fail if it is not possible.""" - offset = len(self.body) - nop_size = len(nop) - - # Calculate the gap to the next multiple of alignment. - gap = -offset % alignment - if gap: - if gap % nop_size == 0: - count = gap // nop_size - self.body.extend(nop * count) - else: - raise ValueError( - f"Cannot add nops of size '{nop_size}' to a body with " - f"offset '{offset}' to align with '{alignment}'" - ) - - def remove_jump(self) -> None: - """Remove a zero-length continuation jump, if it exists.""" - hole = max(self.holes, key=lambda hole: hole.offset) - match hole: - case Hole( - offset=offset, - kind="IMAGE_REL_AMD64_REL32", - value=HoleValue.GOT, - symbol="_JIT_CONTINUE", - addend=-4, - ) as hole: - # jmp qword ptr [rip] - jump = b"\x48\xff\x25\x00\x00\x00\x00" - offset -= 3 - case Hole( - offset=offset, - kind="IMAGE_REL_I386_REL32" | "R_X86_64_PLT32" | "X86_64_RELOC_BRANCH", - value=HoleValue.CONTINUE, - symbol=None, - addend=addend, - ) as hole if ( - _signed(addend) == -4 - ): - # jmp 5 - jump = b"\xe9\x00\x00\x00\x00" - offset -= 1 - case Hole( - offset=offset, - kind="R_AARCH64_JUMP26", - value=HoleValue.CONTINUE, - symbol=None, - addend=0, - ) as hole: - # b #4 - jump = b"\x00\x00\x00\x14" - case _: - return - if self.body[offset:] == jump: - self.body = self.body[:offset] - self.holes.remove(hole) - @dataclasses.dataclass class StencilGroup: @@ -284,9 +223,7 @@ class StencilGroup: _got: dict[str, int] = dataclasses.field(default_factory=dict, init=False) _trampolines: set[int] = dataclasses.field(default_factory=set, init=False) - def process_relocations( - self, known_symbols: dict[str, int], *, alignment: int = 1, nop: bytes = b"" - ) -> None: + def process_relocations(self, known_symbols: dict[str, int]) -> None: """Fix up all GOT and internal relocations for this stencil group.""" for hole in self.code.holes.copy(): if ( @@ -306,8 +243,6 @@ class StencilGroup: self._trampolines.add(ordinal) hole.addend = ordinal hole.symbol = None - self.code.remove_jump() - self.code.add_nops(nop=nop, alignment=alignment) self.data.pad(8) for stencil in [self.code, self.data]: for hole in stencil.holes: diff --git a/Tools/jit/_targets.py b/Tools/jit/_targets.py index b383e39da19..ed10329d25d 100644 --- a/Tools/jit/_targets.py +++ b/Tools/jit/_targets.py @@ -13,6 +13,7 @@ import typing import shlex import _llvm +import _optimizers import _schema import _stencils import _writer @@ -41,8 +42,8 @@ class _Target(typing.Generic[_S, _R]): triple: str condition: str _: dataclasses.KW_ONLY - alignment: int = 1 args: typing.Sequence[str] = () + optimizer: type[_optimizers.Optimizer] = _optimizers.Optimizer prefix: str = "" stable: bool = False debug: bool = False @@ -121,8 +122,9 @@ class _Target(typing.Generic[_S, _R]): async def _compile( self, opname: str, c: pathlib.Path, tempdir: pathlib.Path ) -> _stencils.StencilGroup: + s = tempdir / f"{opname}.s" o = tempdir / f"{opname}.o" - args = [ + args_s = [ f"--target={self.triple}", "-DPy_BUILD_CORE_MODULE", "-D_DEBUG" if self.debug else "-DNDEBUG", @@ -136,7 +138,7 @@ class _Target(typing.Generic[_S, _R]): f"-I{CPYTHON / 'Python'}", f"-I{CPYTHON / 'Tools' / 'jit'}", "-O3", - "-c", + "-S", # Shorten full absolute file paths in the generated code (like the # __FILE__ macro and assert failure messages) for reproducibility: f"-ffile-prefix-map={CPYTHON}=.", @@ -155,13 +157,16 @@ class _Target(typing.Generic[_S, _R]): "-fno-stack-protector", "-std=c11", "-o", - f"{o}", + f"{s}", f"{c}", *self.args, # Allow user-provided CFLAGS to override any defaults *shlex.split(self.cflags), ] - await _llvm.run("clang", args, echo=self.verbose) + await _llvm.run("clang", args_s, echo=self.verbose) + self.optimizer(s, prefix=self.prefix).run() + args_o = [f"--target={self.triple}", "-c", "-o", f"{o}", f"{s}"] + await _llvm.run("clang", args_o, echo=self.verbose) return await self._parse(o) async def _build_stencils(self) -> dict[str, _stencils.StencilGroup]: @@ -190,11 +195,7 @@ class _Target(typing.Generic[_S, _R]): tasks.append(group.create_task(coro, name=opname)) stencil_groups = {task.get_name(): task.result() for task in tasks} for stencil_group in stencil_groups.values(): - stencil_group.process_relocations( - known_symbols=self.known_symbols, - alignment=self.alignment, - nop=self._get_nop(), - ) + stencil_group.process_relocations(self.known_symbols) return stencil_groups def build( @@ -524,42 +525,43 @@ class _MachO( def get_target(host: str) -> _COFF | _ELF | _MachO: """Build a _Target for the given host "triple" and options.""" + optimizer: type[_optimizers.Optimizer] target: _COFF | _ELF | _MachO if re.fullmatch(r"aarch64-apple-darwin.*", host): condition = "defined(__aarch64__) && defined(__APPLE__)" - target = _MachO(host, condition, alignment=8, prefix="_") + optimizer = _optimizers.OptimizerAArch64 + target = _MachO(host, condition, optimizer=optimizer, prefix="_") elif re.fullmatch(r"aarch64-pc-windows-msvc", host): args = ["-fms-runtime-lib=dll", "-fplt"] condition = "defined(_M_ARM64)" - target = _COFF(host, condition, alignment=8, args=args) + optimizer = _optimizers.OptimizerAArch64 + target = _COFF(host, condition, args=args, optimizer=optimizer) elif re.fullmatch(r"aarch64-.*-linux-gnu", host): - args = [ - "-fpic", - # On aarch64 Linux, intrinsics were being emitted and this flag - # was required to disable them. - "-mno-outline-atomics", - ] + # -mno-outline-atomics: Keep intrinsics from being emitted. + args = ["-fpic", "-mno-outline-atomics"] condition = "defined(__aarch64__) && defined(__linux__)" - target = _ELF(host, condition, alignment=8, args=args) + optimizer = _optimizers.OptimizerAArch64 + target = _ELF(host, condition, args=args, optimizer=optimizer) elif re.fullmatch(r"i686-pc-windows-msvc", host): - args = [ - "-DPy_NO_ENABLE_SHARED", - # __attribute__((preserve_none)) is not supported - "-Wno-ignored-attributes", - ] + # -Wno-ignored-attributes: __attribute__((preserve_none)) is not supported here. + args = ["-DPy_NO_ENABLE_SHARED", "-Wno-ignored-attributes"] + optimizer = _optimizers.OptimizerX86 condition = "defined(_M_IX86)" - target = _COFF(host, condition, args=args, prefix="_") + target = _COFF(host, condition, args=args, optimizer=optimizer, prefix="_") elif re.fullmatch(r"x86_64-apple-darwin.*", host): condition = "defined(__x86_64__) && defined(__APPLE__)" - target = _MachO(host, condition, prefix="_") + optimizer = _optimizers.OptimizerX86 + target = _MachO(host, condition, optimizer=optimizer, prefix="_") elif re.fullmatch(r"x86_64-pc-windows-msvc", host): args = ["-fms-runtime-lib=dll"] condition = "defined(_M_X64)" - target = _COFF(host, condition, args=args) + optimizer = _optimizers.OptimizerX8664Windows + target = _COFF(host, condition, args=args, optimizer=optimizer) elif re.fullmatch(r"x86_64-.*-linux-gnu", host): args = ["-fno-pic", "-mcmodel=medium", "-mlarge-data-threshold=0"] condition = "defined(__x86_64__) && defined(__linux__)" - target = _ELF(host, condition, args=args) + optimizer = _optimizers.OptimizerX86 + target = _ELF(host, condition, args=args, optimizer=optimizer) else: raise ValueError(host) return target |