aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/importlib/resources/_common.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/importlib/resources/_common.py')
-rw-r--r--Lib/importlib/resources/_common.py86
1 files changed, 67 insertions, 19 deletions
diff --git a/Lib/importlib/resources/_common.py b/Lib/importlib/resources/_common.py
index 92a37e2c12b..a3902535342 100644
--- a/Lib/importlib/resources/_common.py
+++ b/Lib/importlib/resources/_common.py
@@ -5,25 +5,58 @@ import functools
import contextlib
import types
import importlib
+import inspect
+import warnings
+import itertools
-from typing import Union, Optional
+from typing import Union, Optional, cast
from .abc import ResourceReader, Traversable
from ._adapters import wrap_spec
Package = Union[types.ModuleType, str]
+Anchor = Package
-def files(package):
- # type: (Package) -> Traversable
+def package_to_anchor(func):
"""
- Get a Traversable resource from a package
+ Replace 'package' parameter as 'anchor' and warn about the change.
+
+ Other errors should fall through.
+
+ >>> files('a', 'b')
+ Traceback (most recent call last):
+ TypeError: files() takes from 0 to 1 positional arguments but 2 were given
+ """
+ undefined = object()
+
+ @functools.wraps(func)
+ def wrapper(anchor=undefined, package=undefined):
+ if package is not undefined:
+ if anchor is not undefined:
+ return func(anchor, package)
+ warnings.warn(
+ "First parameter to files is renamed to 'anchor'",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return func(package)
+ elif anchor is undefined:
+ return func()
+ return func(anchor)
+
+ return wrapper
+
+
+@package_to_anchor
+def files(anchor: Optional[Anchor] = None) -> Traversable:
+ """
+ Get a Traversable resource for an anchor.
"""
- return from_package(get_package(package))
+ return from_package(resolve(anchor))
-def get_resource_reader(package):
- # type: (types.ModuleType) -> Optional[ResourceReader]
+def get_resource_reader(package: types.ModuleType) -> Optional[ResourceReader]:
"""
Return the package's loader if it's a ResourceReader.
"""
@@ -39,24 +72,39 @@ def get_resource_reader(package):
return reader(spec.name) # type: ignore
-def resolve(cand):
- # type: (Package) -> types.ModuleType
- return cand if isinstance(cand, types.ModuleType) else importlib.import_module(cand)
+@functools.singledispatch
+def resolve(cand: Optional[Anchor]) -> types.ModuleType:
+ return cast(types.ModuleType, cand)
+
+
+@resolve.register
+def _(cand: str) -> types.ModuleType:
+ return importlib.import_module(cand)
+
+@resolve.register
+def _(cand: None) -> types.ModuleType:
+ return resolve(_infer_caller().f_globals['__name__'])
-def get_package(package):
- # type: (Package) -> types.ModuleType
- """Take a package name or module object and return the module.
- Raise an exception if the resolved module is not a package.
+def _infer_caller():
"""
- resolved = resolve(package)
- if wrap_spec(resolved).submodule_search_locations is None:
- raise TypeError(f'{package!r} is not a package')
- return resolved
+ Walk the stack and find the frame of the first caller not in this module.
+ """
+
+ def is_this_file(frame_info):
+ return frame_info.filename == __file__
+
+ def is_wrapper(frame_info):
+ return frame_info.function == 'wrapper'
+
+ not_this_file = itertools.filterfalse(is_this_file, inspect.stack())
+ # also exclude 'wrapper' due to singledispatch in the call stack
+ callers = itertools.filterfalse(is_wrapper, not_this_file)
+ return next(callers).frame
-def from_package(package):
+def from_package(package: types.ModuleType):
"""
Return a Traversable object for the given package.