aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/annotationlib.py
diff options
context:
space:
mode:
authorJelle Zijlstra <jelle.zijlstra@gmail.com>2024-07-23 14:16:50 -0700
committerGitHub <noreply@github.com>2024-07-23 21:16:50 +0000
commit7b7b90d1ce5116f29ad6c8120c0490824baa54e0 (patch)
treeb273afb5767b7a55fb52a277cc9512ebfc4ac12d /Lib/annotationlib.py
parent64e221d7ada8f6c20189035c7e81503f4c914f04 (diff)
downloadcpython-7b7b90d1ce5116f29ad6c8120c0490824baa54e0.tar.gz
cpython-7b7b90d1ce5116f29ad6c8120c0490824baa54e0.zip
gh-119180: Add `annotationlib` module to support PEP 649 (#119891)
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Diffstat (limited to 'Lib/annotationlib.py')
-rw-r--r--Lib/annotationlib.py655
1 files changed, 655 insertions, 0 deletions
diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py
new file mode 100644
index 00000000000..b4036ffb189
--- /dev/null
+++ b/Lib/annotationlib.py
@@ -0,0 +1,655 @@
+"""Helpers for introspecting and wrapping annotations."""
+
+import ast
+import enum
+import functools
+import sys
+import types
+
+__all__ = ["Format", "ForwardRef", "call_annotate_function", "get_annotations"]
+
+
+class Format(enum.IntEnum):
+ VALUE = 1
+ FORWARDREF = 2
+ SOURCE = 3
+
+
+_Union = None
+_sentinel = object()
+
+# Slots shared by ForwardRef and _Stringifier. The __forward__ names must be
+# preserved for compatibility with the old typing.ForwardRef class. The remaining
+# names are private.
+_SLOTS = (
+ "__forward_evaluated__",
+ "__forward_value__",
+ "__forward_is_argument__",
+ "__forward_is_class__",
+ "__forward_module__",
+ "__weakref__",
+ "__arg__",
+ "__ast_node__",
+ "__code__",
+ "__globals__",
+ "__owner__",
+ "__cell__",
+)
+
+
+class ForwardRef:
+ """Wrapper that holds a forward reference."""
+
+ __slots__ = _SLOTS
+
+ def __init__(
+ self,
+ arg,
+ *,
+ module=None,
+ owner=None,
+ is_argument=True,
+ is_class=False,
+ _globals=None,
+ _cell=None,
+ ):
+ if not isinstance(arg, str):
+ raise TypeError(f"Forward reference must be a string -- got {arg!r}")
+
+ self.__arg__ = arg
+ self.__forward_evaluated__ = False
+ self.__forward_value__ = None
+ self.__forward_is_argument__ = is_argument
+ self.__forward_is_class__ = is_class
+ self.__forward_module__ = module
+ self.__code__ = None
+ self.__ast_node__ = None
+ self.__globals__ = _globals
+ self.__cell__ = _cell
+ self.__owner__ = owner
+
+ def __init_subclass__(cls, /, *args, **kwds):
+ raise TypeError("Cannot subclass ForwardRef")
+
+ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
+ """Evaluate the forward reference and return the value.
+
+ If the forward reference is not evaluatable, raise an exception.
+ """
+ if self.__forward_evaluated__:
+ return self.__forward_value__
+ if self.__cell__ is not None:
+ try:
+ value = self.__cell__.cell_contents
+ except ValueError:
+ pass
+ else:
+ self.__forward_evaluated__ = True
+ self.__forward_value__ = value
+ return value
+ if owner is None:
+ owner = self.__owner__
+ if type_params is None and owner is None:
+ raise TypeError("Either 'type_params' or 'owner' must be provided")
+
+ if self.__forward_module__ is not None:
+ globals = getattr(
+ sys.modules.get(self.__forward_module__, None), "__dict__", globals
+ )
+ if globals is None:
+ globals = self.__globals__
+ if globals is None:
+ if isinstance(owner, type):
+ module_name = getattr(owner, "__module__", None)
+ if module_name:
+ module = sys.modules.get(module_name, None)
+ if module:
+ globals = getattr(module, "__dict__", None)
+ elif isinstance(owner, types.ModuleType):
+ globals = getattr(owner, "__dict__", None)
+ elif callable(owner):
+ globals = getattr(owner, "__globals__", None)
+
+ if locals is None:
+ locals = {}
+ if isinstance(self.__owner__, type):
+ locals.update(vars(self.__owner__))
+
+ if type_params is None and self.__owner__ is not None:
+ # "Inject" type parameters into the local namespace
+ # (unless they are shadowed by assignments *in* the local namespace),
+ # as a way of emulating annotation scopes when calling `eval()`
+ type_params = getattr(self.__owner__, "__type_params__", None)
+
+ # type parameters require some special handling,
+ # as they exist in their own scope
+ # but `eval()` does not have a dedicated parameter for that scope.
+ # For classes, names in type parameter scopes should override
+ # names in the global scope (which here are called `localns`!),
+ # but should in turn be overridden by names in the class scope
+ # (which here are called `globalns`!)
+ if type_params is not None:
+ globals, locals = dict(globals), dict(locals)
+ for param in type_params:
+ param_name = param.__name__
+ if not self.__forward_is_class__ or param_name not in globals:
+ globals[param_name] = param
+ locals.pop(param_name, None)
+
+ code = self.__forward_code__
+ value = eval(code, globals=globals, locals=locals)
+ self.__forward_evaluated__ = True
+ self.__forward_value__ = value
+ return value
+
+ def _evaluate(self, globalns, localns, type_params=_sentinel, *, recursive_guard):
+ import typing
+ import warnings
+
+ if type_params is _sentinel:
+ typing._deprecation_warning_for_no_type_params_passed(
+ "typing.ForwardRef._evaluate"
+ )
+ type_params = ()
+ warnings._deprecated(
+ "ForwardRef._evaluate",
+ "{name} is a private API and is retained for compatibility, but will be removed"
+ " in Python 3.16. Use ForwardRef.evaluate() or typing.evaluate_forward_ref() instead.",
+ remove=(3, 16),
+ )
+ return typing.evaluate_forward_ref(
+ self,
+ globals=globalns,
+ locals=localns,
+ type_params=type_params,
+ _recursive_guard=recursive_guard,
+ )
+
+ @property
+ def __forward_arg__(self):
+ if self.__arg__ is not None:
+ return self.__arg__
+ if self.__ast_node__ is not None:
+ self.__arg__ = ast.unparse(self.__ast_node__)
+ return self.__arg__
+ raise AssertionError(
+ "Attempted to access '__forward_arg__' on an uninitialized ForwardRef"
+ )
+
+ @property
+ def __forward_code__(self):
+ if self.__code__ is not None:
+ return self.__code__
+ arg = self.__forward_arg__
+ # If we do `def f(*args: *Ts)`, then we'll have `arg = '*Ts'`.
+ # Unfortunately, this isn't a valid expression on its own, so we
+ # do the unpacking manually.
+ if arg.startswith("*"):
+ arg_to_compile = f"({arg},)[0]" # E.g. (*Ts,)[0] or (*tuple[int, int],)[0]
+ else:
+ arg_to_compile = arg
+ try:
+ self.__code__ = compile(arg_to_compile, "<string>", "eval")
+ except SyntaxError:
+ raise SyntaxError(f"Forward reference must be an expression -- got {arg!r}")
+ return self.__code__
+
+ def __eq__(self, other):
+ if not isinstance(other, ForwardRef):
+ return NotImplemented
+ if self.__forward_evaluated__ and other.__forward_evaluated__:
+ return (
+ self.__forward_arg__ == other.__forward_arg__
+ and self.__forward_value__ == other.__forward_value__
+ )
+ return (
+ self.__forward_arg__ == other.__forward_arg__
+ and self.__forward_module__ == other.__forward_module__
+ )
+
+ def __hash__(self):
+ return hash((self.__forward_arg__, self.__forward_module__))
+
+ def __or__(self, other):
+ global _Union
+ if _Union is None:
+ from typing import Union as _Union
+ return _Union[self, other]
+
+ def __ror__(self, other):
+ global _Union
+ if _Union is None:
+ from typing import Union as _Union
+ return _Union[other, self]
+
+ def __repr__(self):
+ if self.__forward_module__ is None:
+ module_repr = ""
+ else:
+ module_repr = f", module={self.__forward_module__!r}"
+ return f"ForwardRef({self.__forward_arg__!r}{module_repr})"
+
+
+class _Stringifier:
+ # Must match the slots on ForwardRef, so we can turn an instance of one into an
+ # instance of the other in place.
+ __slots__ = _SLOTS
+
+ def __init__(self, node, globals=None, owner=None, is_class=False, cell=None):
+ assert isinstance(node, ast.AST)
+ self.__arg__ = None
+ self.__forward_evaluated__ = False
+ self.__forward_value__ = None
+ self.__forward_is_argument__ = False
+ self.__forward_is_class__ = is_class
+ self.__forward_module__ = None
+ self.__code__ = None
+ self.__ast_node__ = node
+ self.__globals__ = globals
+ self.__cell__ = cell
+ self.__owner__ = owner
+
+ def __convert(self, other):
+ if isinstance(other, _Stringifier):
+ return other.__ast_node__
+ elif isinstance(other, slice):
+ return ast.Slice(
+ lower=self.__convert(other.start) if other.start is not None else None,
+ upper=self.__convert(other.stop) if other.stop is not None else None,
+ step=self.__convert(other.step) if other.step is not None else None,
+ )
+ else:
+ return ast.Constant(value=other)
+
+ def __make_new(self, node):
+ return _Stringifier(
+ node, self.__globals__, self.__owner__, self.__forward_is_class__
+ )
+
+ # Must implement this since we set __eq__. We hash by identity so that
+ # stringifiers in dict keys are kept separate.
+ def __hash__(self):
+ return id(self)
+
+ def __getitem__(self, other):
+ # Special case, to avoid stringifying references to class-scoped variables
+ # as '__classdict__["x"]'.
+ if (
+ isinstance(self.__ast_node__, ast.Name)
+ and self.__ast_node__.id == "__classdict__"
+ ):
+ raise KeyError
+ if isinstance(other, tuple):
+ elts = [self.__convert(elt) for elt in other]
+ other = ast.Tuple(elts)
+ else:
+ other = self.__convert(other)
+ assert isinstance(other, ast.AST), repr(other)
+ return self.__make_new(ast.Subscript(self.__ast_node__, other))
+
+ def __getattr__(self, attr):
+ return self.__make_new(ast.Attribute(self.__ast_node__, attr))
+
+ def __call__(self, *args, **kwargs):
+ return self.__make_new(
+ ast.Call(
+ self.__ast_node__,
+ [self.__convert(arg) for arg in args],
+ [
+ ast.keyword(key, self.__convert(value))
+ for key, value in kwargs.items()
+ ],
+ )
+ )
+
+ def __iter__(self):
+ yield self.__make_new(ast.Starred(self.__ast_node__))
+
+ def __repr__(self):
+ return ast.unparse(self.__ast_node__)
+
+ def __format__(self, format_spec):
+ raise TypeError("Cannot stringify annotation containing string formatting")
+
+ def _make_binop(op: ast.AST):
+ def binop(self, other):
+ return self.__make_new(
+ ast.BinOp(self.__ast_node__, op, self.__convert(other))
+ )
+
+ return binop
+
+ __add__ = _make_binop(ast.Add())
+ __sub__ = _make_binop(ast.Sub())
+ __mul__ = _make_binop(ast.Mult())
+ __matmul__ = _make_binop(ast.MatMult())
+ __truediv__ = _make_binop(ast.Div())
+ __mod__ = _make_binop(ast.Mod())
+ __lshift__ = _make_binop(ast.LShift())
+ __rshift__ = _make_binop(ast.RShift())
+ __or__ = _make_binop(ast.BitOr())
+ __xor__ = _make_binop(ast.BitXor())
+ __and__ = _make_binop(ast.BitAnd())
+ __floordiv__ = _make_binop(ast.FloorDiv())
+ __pow__ = _make_binop(ast.Pow())
+
+ del _make_binop
+
+ def _make_rbinop(op: ast.AST):
+ def rbinop(self, other):
+ return self.__make_new(
+ ast.BinOp(self.__convert(other), op, self.__ast_node__)
+ )
+
+ return rbinop
+
+ __radd__ = _make_rbinop(ast.Add())
+ __rsub__ = _make_rbinop(ast.Sub())
+ __rmul__ = _make_rbinop(ast.Mult())
+ __rmatmul__ = _make_rbinop(ast.MatMult())
+ __rtruediv__ = _make_rbinop(ast.Div())
+ __rmod__ = _make_rbinop(ast.Mod())
+ __rlshift__ = _make_rbinop(ast.LShift())
+ __rrshift__ = _make_rbinop(ast.RShift())
+ __ror__ = _make_rbinop(ast.BitOr())
+ __rxor__ = _make_rbinop(ast.BitXor())
+ __rand__ = _make_rbinop(ast.BitAnd())
+ __rfloordiv__ = _make_rbinop(ast.FloorDiv())
+ __rpow__ = _make_rbinop(ast.Pow())
+
+ del _make_rbinop
+
+ def _make_compare(op):
+ def compare(self, other):
+ return self.__make_new(
+ ast.Compare(
+ left=self.__ast_node__,
+ ops=[op],
+ comparators=[self.__convert(other)],
+ )
+ )
+
+ return compare
+
+ __lt__ = _make_compare(ast.Lt())
+ __le__ = _make_compare(ast.LtE())
+ __eq__ = _make_compare(ast.Eq())
+ __ne__ = _make_compare(ast.NotEq())
+ __gt__ = _make_compare(ast.Gt())
+ __ge__ = _make_compare(ast.GtE())
+
+ del _make_compare
+
+ def _make_unary_op(op):
+ def unary_op(self):
+ return self.__make_new(ast.UnaryOp(op, self.__ast_node__))
+
+ return unary_op
+
+ __invert__ = _make_unary_op(ast.Invert())
+ __pos__ = _make_unary_op(ast.UAdd())
+ __neg__ = _make_unary_op(ast.USub())
+
+ del _make_unary_op
+
+
+class _StringifierDict(dict):
+ def __init__(self, namespace, globals=None, owner=None, is_class=False):
+ super().__init__(namespace)
+ self.namespace = namespace
+ self.globals = globals
+ self.owner = owner
+ self.is_class = is_class
+ self.stringifiers = []
+
+ def __missing__(self, key):
+ fwdref = _Stringifier(
+ ast.Name(id=key),
+ globals=self.globals,
+ owner=self.owner,
+ is_class=self.is_class,
+ )
+ self.stringifiers.append(fwdref)
+ return fwdref
+
+
+def call_annotate_function(annotate, format, owner=None):
+ """Call an __annotate__ function. __annotate__ functions are normally
+ generated by the compiler to defer the evaluation of annotations. They
+ can be called with any of the format arguments in the Format enum, but
+ compiler-generated __annotate__ functions only support the VALUE format.
+ This function provides additional functionality to call __annotate__
+ functions with the FORWARDREF and SOURCE formats.
+
+ *annotate* must be an __annotate__ function, which takes a single argument
+ and returns a dict of annotations.
+
+ *format* must be a member of the Format enum or one of the corresponding
+ integer values.
+
+ *owner* can be the object that owns the annotations (i.e., the module,
+ class, or function that the __annotate__ function derives from). With the
+ FORWARDREF format, it is used to provide better evaluation capabilities
+ on the generated ForwardRef objects.
+
+ """
+ try:
+ return annotate(format)
+ except NotImplementedError:
+ pass
+ if format == Format.SOURCE:
+ # SOURCE is implemented by calling the annotate function in a special
+ # environment where every name lookup results in an instance of _Stringifier.
+ # _Stringifier supports every dunder operation and returns a new _Stringifier.
+ # At the end, we get a dictionary that mostly contains _Stringifier objects (or
+ # possibly constants if the annotate function uses them directly). We then
+ # convert each of those into a string to get an approximation of the
+ # original source.
+ globals = _StringifierDict({})
+ if annotate.__closure__:
+ freevars = annotate.__code__.co_freevars
+ new_closure = []
+ for i, cell in enumerate(annotate.__closure__):
+ if i < len(freevars):
+ name = freevars[i]
+ else:
+ name = "__cell__"
+ fwdref = _Stringifier(ast.Name(id=name))
+ new_closure.append(types.CellType(fwdref))
+ closure = tuple(new_closure)
+ else:
+ closure = None
+ func = types.FunctionType(annotate.__code__, globals, closure=closure)
+ annos = func(Format.VALUE)
+ return {
+ key: val if isinstance(val, str) else repr(val)
+ for key, val in annos.items()
+ }
+ elif format == Format.FORWARDREF:
+ # FORWARDREF is implemented similarly to SOURCE, but there are two changes,
+ # at the beginning and the end of the process.
+ # First, while SOURCE uses an empty dictionary as the namespace, so that all
+ # name lookups result in _Stringifier objects, FORWARDREF uses the globals
+ # and builtins, so that defined names map to their real values.
+ # Second, instead of returning strings, we want to return either real values
+ # or ForwardRef objects. To do this, we keep track of all _Stringifier objects
+ # created while the annotation is being evaluated, and at the end we convert
+ # them all to ForwardRef objects by assigning to __class__. To make this
+ # technique work, we have to ensure that the _Stringifier and ForwardRef
+ # classes share the same attributes.
+ # We use this technique because while the annotations are being evaluated,
+ # we want to support all operations that the language allows, including even
+ # __getattr__ and __eq__, and return new _Stringifier objects so we can accurately
+ # reconstruct the source. But in the dictionary that we eventually return, we
+ # want to return objects with more user-friendly behavior, such as an __eq__
+ # that returns a bool and an defined set of attributes.
+ namespace = {**annotate.__builtins__, **annotate.__globals__}
+ is_class = isinstance(owner, type)
+ globals = _StringifierDict(namespace, annotate.__globals__, owner, is_class)
+ if annotate.__closure__:
+ freevars = annotate.__code__.co_freevars
+ new_closure = []
+ for i, cell in enumerate(annotate.__closure__):
+ try:
+ cell.cell_contents
+ except ValueError:
+ if i < len(freevars):
+ name = freevars[i]
+ else:
+ name = "__cell__"
+ fwdref = _Stringifier(
+ ast.Name(id=name),
+ cell=cell,
+ owner=owner,
+ globals=annotate.__globals__,
+ is_class=is_class,
+ )
+ globals.stringifiers.append(fwdref)
+ new_closure.append(types.CellType(fwdref))
+ else:
+ new_closure.append(cell)
+ closure = tuple(new_closure)
+ else:
+ closure = None
+ func = types.FunctionType(annotate.__code__, globals, closure=closure)
+ result = func(Format.VALUE)
+ for obj in globals.stringifiers:
+ obj.__class__ = ForwardRef
+ return result
+ elif format == Format.VALUE:
+ # Should be impossible because __annotate__ functions must not raise
+ # NotImplementedError for this format.
+ raise RuntimeError("annotate function does not support VALUE format")
+ else:
+ raise ValueError(f"Invalid format: {format!r}")
+
+
+def get_annotations(
+ obj, *, globals=None, locals=None, eval_str=False, format=Format.VALUE
+):
+ """Compute the annotations dict for an object.
+
+ obj may be a callable, class, or module.
+ Passing in an object of any other type raises TypeError.
+
+ Returns a dict. get_annotations() returns a new dict every time
+ it's called; calling it twice on the same object will return two
+ different but equivalent dicts.
+
+ This function handles several details for you:
+
+ * If eval_str is true, values of type str will
+ be un-stringized using eval(). This is intended
+ for use with stringized annotations
+ ("from __future__ import annotations").
+ * If obj doesn't have an annotations dict, returns an
+ empty dict. (Functions and methods always have an
+ annotations dict; classes, modules, and other types of
+ callables may not.)
+ * Ignores inherited annotations on classes. If a class
+ doesn't have its own annotations dict, returns an empty dict.
+ * All accesses to object members and dict values are done
+ using getattr() and dict.get() for safety.
+ * Always, always, always returns a freshly-created dict.
+
+ eval_str controls whether or not values of type str are replaced
+ with the result of calling eval() on those values:
+
+ * If eval_str is true, eval() is called on values of type str.
+ * If eval_str is false (the default), values of type str are unchanged.
+
+ globals and locals are passed in to eval(); see the documentation
+ for eval() for more information. If either globals or locals is
+ None, this function may replace that value with a context-specific
+ default, contingent on type(obj):
+
+ * If obj is a module, globals defaults to obj.__dict__.
+ * If obj is a class, globals defaults to
+ sys.modules[obj.__module__].__dict__ and locals
+ defaults to the obj class namespace.
+ * If obj is a callable, globals defaults to obj.__globals__,
+ although if obj is a wrapped function (using
+ functools.update_wrapper()) it is first unwrapped.
+ """
+ if eval_str and format != Format.VALUE:
+ raise ValueError("eval_str=True is only supported with format=Format.VALUE")
+
+ # For VALUE format, we look at __annotations__ directly.
+ if format != Format.VALUE:
+ annotate = getattr(obj, "__annotate__", None)
+ if annotate is not None:
+ ann = call_annotate_function(annotate, format, owner=obj)
+ if not isinstance(ann, dict):
+ raise ValueError(f"{obj!r}.__annotate__ returned a non-dict")
+ return dict(ann)
+
+ ann = getattr(obj, "__annotations__", None)
+ if ann is None:
+ return {}
+
+ if not isinstance(ann, dict):
+ raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")
+
+ if not ann:
+ return {}
+
+ if not eval_str:
+ return dict(ann)
+
+ if isinstance(obj, type):
+ # class
+ obj_globals = None
+ module_name = getattr(obj, "__module__", None)
+ if module_name:
+ module = sys.modules.get(module_name, None)
+ if module:
+ obj_globals = getattr(module, "__dict__", None)
+ obj_locals = dict(vars(obj))
+ unwrap = obj
+ elif isinstance(obj, types.ModuleType):
+ # module
+ obj_globals = getattr(obj, "__dict__")
+ obj_locals = None
+ unwrap = None
+ elif callable(obj):
+ # this includes types.Function, types.BuiltinFunctionType,
+ # types.BuiltinMethodType, functools.partial, functools.singledispatch,
+ # "class funclike" from Lib/test/test_inspect... on and on it goes.
+ obj_globals = getattr(obj, "__globals__", None)
+ obj_locals = None
+ unwrap = obj
+ elif ann is not None:
+ obj_globals = obj_locals = unwrap = None
+ else:
+ raise TypeError(f"{obj!r} is not a module, class, or callable.")
+
+ if unwrap is not None:
+ while True:
+ if hasattr(unwrap, "__wrapped__"):
+ unwrap = unwrap.__wrapped__
+ continue
+ if isinstance(unwrap, functools.partial):
+ unwrap = unwrap.func
+ continue
+ break
+ if hasattr(unwrap, "__globals__"):
+ obj_globals = unwrap.__globals__
+
+ if globals is None:
+ globals = obj_globals
+ if locals is None:
+ locals = obj_locals
+
+ # "Inject" type parameters into the local namespace
+ # (unless they are shadowed by assignments *in* the local namespace),
+ # as a way of emulating annotation scopes when calling `eval()`
+ if type_params := getattr(obj, "__type_params__", ()):
+ if locals is None:
+ locals = {}
+ locals = {param.__name__: param for param in type_params} | locals
+
+ return_value = {
+ key: value if not isinstance(value, str) else eval(value, globals, locals)
+ for key, value in ann.items()
+ }
+ return return_value