aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/annotationlib.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/annotationlib.py')
-rw-r--r--Lib/annotationlib.py379
1 files changed, 272 insertions, 107 deletions
diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py
index 971f636f971..5ad0893106a 100644
--- a/Lib/annotationlib.py
+++ b/Lib/annotationlib.py
@@ -12,7 +12,7 @@ __all__ = [
"ForwardRef",
"call_annotate_function",
"call_evaluate_function",
- "get_annotate_function",
+ "get_annotate_from_class_namespace",
"get_annotations",
"annotations_to_string",
"type_repr",
@@ -38,6 +38,7 @@ _SLOTS = (
"__weakref__",
"__arg__",
"__globals__",
+ "__extra_names__",
"__code__",
"__ast_node__",
"__cell__",
@@ -82,6 +83,7 @@ class ForwardRef:
# is created through __class__ assignment on a _Stringifier object.
self.__globals__ = None
self.__cell__ = None
+ self.__extra_names__ = None
# These are initially None but serve as a cache and may be set to a non-None
# value later.
self.__code__ = None
@@ -90,11 +92,28 @@ class ForwardRef:
def __init_subclass__(cls, /, *args, **kwds):
raise TypeError("Cannot subclass ForwardRef")
- def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
+ def evaluate(
+ self,
+ *,
+ globals=None,
+ locals=None,
+ type_params=None,
+ owner=None,
+ format=Format.VALUE,
+ ):
"""Evaluate the forward reference and return the value.
If the forward reference cannot be evaluated, raise an exception.
"""
+ match format:
+ case Format.STRING:
+ return self.__forward_arg__
+ case Format.VALUE:
+ is_forwardref_format = False
+ case Format.FORWARDREF:
+ is_forwardref_format = True
+ case _:
+ raise NotImplementedError(format)
if self.__cell__ is not None:
try:
return self.__cell__.cell_contents
@@ -151,21 +170,42 @@ class ForwardRef:
if not self.__forward_is_class__ or param_name not in globals:
globals[param_name] = param
locals.pop(param_name, None)
+ if self.__extra_names__:
+ locals = {**locals, **self.__extra_names__}
arg = self.__forward_arg__
if arg.isidentifier() and not keyword.iskeyword(arg):
if arg in locals:
- value = locals[arg]
+ return locals[arg]
elif arg in globals:
- value = globals[arg]
+ return globals[arg]
elif hasattr(builtins, arg):
return getattr(builtins, arg)
+ elif is_forwardref_format:
+ return self
else:
raise NameError(arg)
else:
code = self.__forward_code__
- value = eval(code, globals=globals, locals=locals)
- return value
+ try:
+ return eval(code, globals=globals, locals=locals)
+ except Exception:
+ if not is_forwardref_format:
+ raise
+ new_locals = _StringifierDict(
+ {**builtins.__dict__, **locals},
+ globals=globals,
+ owner=owner,
+ is_class=self.__forward_is_class__,
+ format=format,
+ )
+ try:
+ result = eval(code, globals=globals, locals=new_locals)
+ except Exception:
+ return self
+ else:
+ new_locals.transmogrify()
+ return result
def _evaluate(self, globalns, localns, type_params=_sentinel, *, recursive_guard):
import typing
@@ -231,6 +271,10 @@ class ForwardRef:
and self.__forward_is_class__ == other.__forward_is_class__
and self.__cell__ == other.__cell__
and self.__owner__ == other.__owner__
+ and (
+ (tuple(sorted(self.__extra_names__.items())) if self.__extra_names__ else None) ==
+ (tuple(sorted(other.__extra_names__.items())) if other.__extra_names__ else None)
+ )
)
def __hash__(self):
@@ -241,6 +285,7 @@ class ForwardRef:
self.__forward_is_class__,
self.__cell__,
self.__owner__,
+ tuple(sorted(self.__extra_names__.items())) if self.__extra_names__ else None,
))
def __or__(self, other):
@@ -274,6 +319,7 @@ class _Stringifier:
cell=None,
*,
stringifier_dict,
+ extra_names=None,
):
# Either an AST node or a simple str (for the common case where a ForwardRef
# represent a single name).
@@ -285,6 +331,7 @@ class _Stringifier:
self.__code__ = None
self.__ast_node__ = node
self.__globals__ = globals
+ self.__extra_names__ = extra_names
self.__cell__ = cell
self.__owner__ = owner
self.__stringifier_dict__ = stringifier_dict
@@ -292,28 +339,63 @@ class _Stringifier:
def __convert_to_ast(self, other):
if isinstance(other, _Stringifier):
if isinstance(other.__ast_node__, str):
- return ast.Name(id=other.__ast_node__)
- return other.__ast_node__
- elif isinstance(other, slice):
+ return ast.Name(id=other.__ast_node__), other.__extra_names__
+ return other.__ast_node__, other.__extra_names__
+ elif (
+ # In STRING format we don't bother with the create_unique_name() dance;
+ # it's better to emit the repr() of the object instead of an opaque name.
+ self.__stringifier_dict__.format == Format.STRING
+ or other is None
+ or type(other) in (str, int, float, bool, complex)
+ ):
+ return ast.Constant(value=other), None
+ elif type(other) is dict:
+ extra_names = {}
+ keys = []
+ values = []
+ for key, value in other.items():
+ new_key, new_extra_names = self.__convert_to_ast(key)
+ if new_extra_names is not None:
+ extra_names.update(new_extra_names)
+ keys.append(new_key)
+ new_value, new_extra_names = self.__convert_to_ast(value)
+ if new_extra_names is not None:
+ extra_names.update(new_extra_names)
+ values.append(new_value)
+ return ast.Dict(keys, values), extra_names
+ elif type(other) in (list, tuple, set):
+ extra_names = {}
+ elts = []
+ for elt in other:
+ new_elt, new_extra_names = self.__convert_to_ast(elt)
+ if new_extra_names is not None:
+ extra_names.update(new_extra_names)
+ elts.append(new_elt)
+ ast_class = {list: ast.List, tuple: ast.Tuple, set: ast.Set}[type(other)]
+ return ast_class(elts), extra_names
+ else:
+ name = self.__stringifier_dict__.create_unique_name()
+ return ast.Name(id=name), {name: other}
+
+ def __convert_to_ast_getitem(self, other):
+ if isinstance(other, slice):
+ extra_names = {}
+
+ def conv(obj):
+ if obj is None:
+ return None
+ new_obj, new_extra_names = self.__convert_to_ast(obj)
+ if new_extra_names is not None:
+ extra_names.update(new_extra_names)
+ return new_obj
+
return ast.Slice(
- lower=(
- self.__convert_to_ast(other.start)
- if other.start is not None
- else None
- ),
- upper=(
- self.__convert_to_ast(other.stop)
- if other.stop is not None
- else None
- ),
- step=(
- self.__convert_to_ast(other.step)
- if other.step is not None
- else None
- ),
- )
+ lower=conv(other.start),
+ upper=conv(other.stop),
+ step=conv(other.step),
+ ), extra_names
else:
- return ast.Constant(value=other)
+ return self.__convert_to_ast(other)
def __get_ast(self):
node = self.__ast_node__
@@ -321,13 +403,19 @@ class _Stringifier:
return ast.Name(id=node)
return node
- def __make_new(self, node):
+ def __make_new(self, node, extra_names=None):
+ new_extra_names = {}
+ if self.__extra_names__ is not None:
+ new_extra_names.update(self.__extra_names__)
+ if extra_names is not None:
+ new_extra_names.update(extra_names)
stringifier = _Stringifier(
node,
self.__globals__,
self.__owner__,
self.__forward_is_class__,
stringifier_dict=self.__stringifier_dict__,
+ extra_names=new_extra_names or None,
)
self.__stringifier_dict__.stringifiers.append(stringifier)
return stringifier
@@ -343,27 +431,37 @@ class _Stringifier:
if self.__ast_node__ == "__classdict__":
raise KeyError
if isinstance(other, tuple):
- elts = [self.__convert_to_ast(elt) for elt in other]
+ extra_names = {}
+ elts = []
+ for elt in other:
+ new_elt, new_extra_names = self.__convert_to_ast_getitem(elt)
+ if new_extra_names is not None:
+ extra_names.update(new_extra_names)
+ elts.append(new_elt)
other = ast.Tuple(elts)
else:
- other = self.__convert_to_ast(other)
+ other, extra_names = self.__convert_to_ast_getitem(other)
assert isinstance(other, ast.AST), repr(other)
- return self.__make_new(ast.Subscript(self.__get_ast(), other))
+ return self.__make_new(ast.Subscript(self.__get_ast(), other), extra_names)
def __getattr__(self, attr):
return self.__make_new(ast.Attribute(self.__get_ast(), attr))
def __call__(self, *args, **kwargs):
- return self.__make_new(
- ast.Call(
- self.__get_ast(),
- [self.__convert_to_ast(arg) for arg in args],
- [
- ast.keyword(key, self.__convert_to_ast(value))
- for key, value in kwargs.items()
- ],
- )
- )
+ extra_names = {}
+ ast_args = []
+ for arg in args:
+ new_arg, new_extra_names = self.__convert_to_ast(arg)
+ if new_extra_names is not None:
+ extra_names.update(new_extra_names)
+ ast_args.append(new_arg)
+ ast_kwargs = []
+ for key, value in kwargs.items():
+ new_value, new_extra_names = self.__convert_to_ast(value)
+ if new_extra_names is not None:
+ extra_names.update(new_extra_names)
+ ast_kwargs.append(ast.keyword(key, new_value))
+ return self.__make_new(ast.Call(self.__get_ast(), ast_args, ast_kwargs), extra_names)
def __iter__(self):
yield self.__make_new(ast.Starred(self.__get_ast()))
@@ -378,8 +476,9 @@ class _Stringifier:
def _make_binop(op: ast.AST):
def binop(self, other):
+ rhs, extra_names = self.__convert_to_ast(other)
return self.__make_new(
- ast.BinOp(self.__get_ast(), op, self.__convert_to_ast(other))
+ ast.BinOp(self.__get_ast(), op, rhs), extra_names
)
return binop
@@ -402,8 +501,9 @@ class _Stringifier:
def _make_rbinop(op: ast.AST):
def rbinop(self, other):
+ new_other, extra_names = self.__convert_to_ast(other)
return self.__make_new(
- ast.BinOp(self.__convert_to_ast(other), op, self.__get_ast())
+ ast.BinOp(new_other, op, self.__get_ast()), extra_names
)
return rbinop
@@ -426,12 +526,14 @@ class _Stringifier:
def _make_compare(op):
def compare(self, other):
+ rhs, extra_names = self.__convert_to_ast(other)
return self.__make_new(
ast.Compare(
left=self.__get_ast(),
ops=[op],
- comparators=[self.__convert_to_ast(other)],
- )
+ comparators=[rhs],
+ ),
+ extra_names,
)
return compare
@@ -459,13 +561,15 @@ class _Stringifier:
class _StringifierDict(dict):
- def __init__(self, namespace, globals=None, owner=None, is_class=False):
+ def __init__(self, namespace, *, globals=None, owner=None, is_class=False, format):
super().__init__(namespace)
self.namespace = namespace
self.globals = globals
self.owner = owner
self.is_class = is_class
self.stringifiers = []
+ self.next_id = 1
+ self.format = format
def __missing__(self, key):
fwdref = _Stringifier(
@@ -478,6 +582,19 @@ class _StringifierDict(dict):
self.stringifiers.append(fwdref)
return fwdref
+ def transmogrify(self):
+ for obj in self.stringifiers:
+ obj.__class__ = ForwardRef
+ obj.__stringifier_dict__ = None # not needed for ForwardRef
+ if isinstance(obj.__ast_node__, str):
+ obj.__arg__ = obj.__ast_node__
+ obj.__ast_node__ = None
+
+ def create_unique_name(self):
+ name = f"__annotationlib_name_{self.next_id}__"
+ self.next_id += 1
+ return name
+
def call_evaluate_function(evaluate, format, *, owner=None):
"""Call an evaluate function. Evaluate functions are normally generated for
@@ -521,20 +638,11 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
# possibly constants if the annotate function uses them directly). We then
# convert each of those into a string to get an approximation of the
# original source.
- globals = _StringifierDict({})
- if annotate.__closure__:
- freevars = annotate.__code__.co_freevars
- new_closure = []
- for i, cell in enumerate(annotate.__closure__):
- if i < len(freevars):
- name = freevars[i]
- else:
- name = "__cell__"
- fwdref = _Stringifier(name, stringifier_dict=globals)
- new_closure.append(types.CellType(fwdref))
- closure = tuple(new_closure)
- else:
- closure = None
+ globals = _StringifierDict({}, format=format)
+ is_class = isinstance(owner, type)
+ closure = _build_closure(
+ annotate, owner, is_class, globals, allow_evaluation=False
+ )
func = types.FunctionType(
annotate.__code__,
globals,
@@ -544,9 +652,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
)
annos = func(Format.VALUE_WITH_FAKE_GLOBALS)
if _is_evaluate:
- return annos if isinstance(annos, str) else repr(annos)
+ return _stringify_single(annos)
return {
- key: val if isinstance(val, str) else repr(val)
+ key: _stringify_single(val)
for key, val in annos.items()
}
elif format == Format.FORWARDREF:
@@ -569,33 +677,43 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
# that returns a bool and an defined set of attributes.
namespace = {**annotate.__builtins__, **annotate.__globals__}
is_class = isinstance(owner, type)
- globals = _StringifierDict(namespace, annotate.__globals__, owner, is_class)
- if annotate.__closure__:
- freevars = annotate.__code__.co_freevars
- new_closure = []
- for i, cell in enumerate(annotate.__closure__):
- try:
- cell.cell_contents
- except ValueError:
- if i < len(freevars):
- name = freevars[i]
- else:
- name = "__cell__"
- fwdref = _Stringifier(
- name,
- cell=cell,
- owner=owner,
- globals=annotate.__globals__,
- is_class=is_class,
- stringifier_dict=globals,
- )
- globals.stringifiers.append(fwdref)
- new_closure.append(types.CellType(fwdref))
- else:
- new_closure.append(cell)
- closure = tuple(new_closure)
+ globals = _StringifierDict(
+ namespace,
+ globals=annotate.__globals__,
+ owner=owner,
+ is_class=is_class,
+ format=format,
+ )
+ closure = _build_closure(
+ annotate, owner, is_class, globals, allow_evaluation=True
+ )
+ func = types.FunctionType(
+ annotate.__code__,
+ globals,
+ closure=closure,
+ argdefs=annotate.__defaults__,
+ kwdefaults=annotate.__kwdefaults__,
+ )
+ try:
+ result = func(Format.VALUE_WITH_FAKE_GLOBALS)
+ except Exception:
+ pass
else:
- closure = None
+ globals.transmogrify()
+ return result
+
+ # Try again, but do not provide any globals. This allows us to return
+ # a value in certain cases where an exception gets raised during evaluation.
+ globals = _StringifierDict(
+ {},
+ globals=annotate.__globals__,
+ owner=owner,
+ is_class=is_class,
+ format=format,
+ )
+ closure = _build_closure(
+ annotate, owner, is_class, globals, allow_evaluation=False
+ )
func = types.FunctionType(
annotate.__code__,
globals,
@@ -604,13 +722,21 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
kwdefaults=annotate.__kwdefaults__,
)
result = func(Format.VALUE_WITH_FAKE_GLOBALS)
- for obj in globals.stringifiers:
- obj.__class__ = ForwardRef
- obj.__stringifier_dict__ = None # not needed for ForwardRef
- if isinstance(obj.__ast_node__, str):
- obj.__arg__ = obj.__ast_node__
- obj.__ast_node__ = None
- return result
+ globals.transmogrify()
+ if _is_evaluate:
+ if isinstance(result, ForwardRef):
+ return result.evaluate(format=Format.FORWARDREF)
+ else:
+ return result
+ else:
+ return {
+ key: (
+ val.evaluate(format=Format.FORWARDREF)
+ if isinstance(val, ForwardRef)
+ else val
+ )
+ for key, val in result.items()
+ }
elif format == Format.VALUE:
# Should be impossible because __annotate__ functions must not raise
# NotImplementedError for this format.
@@ -619,20 +745,59 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
raise ValueError(f"Invalid format: {format!r}")
-def get_annotate_function(obj):
- """Get the __annotate__ function for an object.
+def _build_closure(annotate, owner, is_class, stringifier_dict, *, allow_evaluation):
+ if not annotate.__closure__:
+ return None
+ freevars = annotate.__code__.co_freevars
+ new_closure = []
+ for i, cell in enumerate(annotate.__closure__):
+ if i < len(freevars):
+ name = freevars[i]
+ else:
+ name = "__cell__"
+ new_cell = None
+ if allow_evaluation:
+ try:
+ cell.cell_contents
+ except ValueError:
+ pass
+ else:
+ new_cell = cell
+ if new_cell is None:
+ fwdref = _Stringifier(
+ name,
+ cell=cell,
+ owner=owner,
+ globals=annotate.__globals__,
+ is_class=is_class,
+ stringifier_dict=stringifier_dict,
+ )
+ stringifier_dict.stringifiers.append(fwdref)
+ new_cell = types.CellType(fwdref)
+ new_closure.append(new_cell)
+ return tuple(new_closure)
+
+
+def _stringify_single(anno):
+ if anno is ...:
+ return "..."
+ # We have to handle str specially to support PEP 563 stringified annotations.
+ elif isinstance(anno, str):
+ return anno
+ else:
+ return repr(anno)
+
- obj may be a function, class, or module, or a user-defined type with
- an `__annotate__` attribute.
+def get_annotate_from_class_namespace(obj):
+ """Retrieve the annotate function from a class namespace dictionary.
- Returns the __annotate__ function or None.
+ Return None if the namespace does not contain an annotate function.
+ This is useful in metaclass ``__new__`` methods to retrieve the annotate function.
"""
- if isinstance(obj, dict):
- try:
- return obj["__annotate__"]
- except KeyError:
- return obj.get("__annotate_func__", None)
- return getattr(obj, "__annotate__", None)
+ try:
+ return obj["__annotate__"]
+ except KeyError:
+ return obj.get("__annotate_func__", None)
def get_annotations(
@@ -724,7 +889,7 @@ def get_annotations(
# But if we didn't get it, we use __annotations__ instead.
ann = _get_dunder_annotations(obj)
if ann is not None:
- return annotations_to_string(ann)
+ return annotations_to_string(ann)
case Format.VALUE_WITH_FAKE_GLOBALS:
raise ValueError("The VALUE_WITH_FAKE_GLOBALS format is for internal use only")
case _:
@@ -832,7 +997,7 @@ def _get_and_call_annotate(obj, format):
May not return a fresh dictionary.
"""
- annotate = get_annotate_function(obj)
+ annotate = getattr(obj, "__annotate__", None)
if annotate is not None:
ann = call_annotate_function(annotate, format, owner=obj)
if not isinstance(ann, dict):