diff options
Diffstat (limited to 'Lib/_pyrepl/_module_completer.py')
-rw-r--r-- | Lib/_pyrepl/_module_completer.py | 33 |
1 files changed, 22 insertions, 11 deletions
diff --git a/Lib/_pyrepl/_module_completer.py b/Lib/_pyrepl/_module_completer.py index 1fb043e0b70..1e9462a4215 100644 --- a/Lib/_pyrepl/_module_completer.py +++ b/Lib/_pyrepl/_module_completer.py @@ -2,6 +2,7 @@ from __future__ import annotations import pkgutil import sys +import token import tokenize from io import StringIO from contextlib import contextmanager @@ -16,8 +17,8 @@ if TYPE_CHECKING: def make_default_module_completer() -> ModuleCompleter: - # Inside pyrepl, __package__ is set to '_pyrepl' - return ModuleCompleter(namespace={'__package__': '_pyrepl'}) + # Inside pyrepl, __package__ is set to None by default + return ModuleCompleter(namespace={'__package__': None}) class ModuleCompleter: @@ -41,11 +42,11 @@ class ModuleCompleter: self._global_cache: list[pkgutil.ModuleInfo] = [] self._curr_sys_path: list[str] = sys.path[:] - def get_completions(self, line: str) -> list[str]: + def get_completions(self, line: str) -> list[str] | None: """Return the next possible import completions for 'line'.""" result = ImportParser(line).parse() if not result: - return [] + return None try: return self.complete(*result) except Exception: @@ -80,8 +81,11 @@ class ModuleCompleter: def _find_modules(self, path: str, prefix: str) -> list[str]: if not path: # Top-level import (e.g. `import foo<tab>`` or `from foo<tab>`)` - return [name for _, name, _ in self.global_cache - if name.startswith(prefix)] + builtin_modules = [name for name in sys.builtin_module_names + if self.is_suggestion_match(name, prefix)] + third_party_modules = [module.name for module in self.global_cache + if self.is_suggestion_match(module.name, prefix)] + return sorted(builtin_modules + third_party_modules) if path.startswith('.'): # Convert relative path to absolute path @@ -96,7 +100,14 @@ class ModuleCompleter: if mod_info.ispkg and mod_info.name == segment] modules = self.iter_submodules(modules) return [module.name for module in modules - if module.name.startswith(prefix)] + if self.is_suggestion_match(module.name, prefix)] + + def is_suggestion_match(self, module_name: str, prefix: str) -> bool: + if prefix: + return module_name.startswith(prefix) + # For consistency with attribute completion, which + # does not suggest private attributes unless requested. + return not module_name.startswith("_") def iter_submodules(self, parent_modules: list[pkgutil.ModuleInfo]) -> Iterator[pkgutil.ModuleInfo]: """Iterate over all submodules of the given parent modules.""" @@ -180,8 +191,8 @@ class ImportParser: when parsing multiple statements. """ _ignored_tokens = { - tokenize.INDENT, tokenize.DEDENT, tokenize.COMMENT, - tokenize.NL, tokenize.NEWLINE, tokenize.ENDMARKER + token.INDENT, token.DEDENT, token.COMMENT, + token.NL, token.NEWLINE, token.ENDMARKER } _keywords = {'import', 'from', 'as'} @@ -350,11 +361,11 @@ class TokenQueue: def peek_name(self) -> bool: if not (tok := self.peek()): return False - return tok.type == tokenize.NAME + return tok.type == token.NAME def pop_name(self) -> str: tok = self.pop() - if tok.type != tokenize.NAME: + if tok.type != token.NAME: raise ParseError('pop_name') return tok.string |