diff options
Diffstat (limited to 'Lib/annotationlib.py')
-rw-r--r-- | Lib/annotationlib.py | 379 |
1 files changed, 272 insertions, 107 deletions
diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py index 971f636f971..5ad0893106a 100644 --- a/Lib/annotationlib.py +++ b/Lib/annotationlib.py @@ -12,7 +12,7 @@ __all__ = [ "ForwardRef", "call_annotate_function", "call_evaluate_function", - "get_annotate_function", + "get_annotate_from_class_namespace", "get_annotations", "annotations_to_string", "type_repr", @@ -38,6 +38,7 @@ _SLOTS = ( "__weakref__", "__arg__", "__globals__", + "__extra_names__", "__code__", "__ast_node__", "__cell__", @@ -82,6 +83,7 @@ class ForwardRef: # is created through __class__ assignment on a _Stringifier object. self.__globals__ = None self.__cell__ = None + self.__extra_names__ = None # These are initially None but serve as a cache and may be set to a non-None # value later. self.__code__ = None @@ -90,11 +92,28 @@ class ForwardRef: def __init_subclass__(cls, /, *args, **kwds): raise TypeError("Cannot subclass ForwardRef") - def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None): + def evaluate( + self, + *, + globals=None, + locals=None, + type_params=None, + owner=None, + format=Format.VALUE, + ): """Evaluate the forward reference and return the value. If the forward reference cannot be evaluated, raise an exception. """ + match format: + case Format.STRING: + return self.__forward_arg__ + case Format.VALUE: + is_forwardref_format = False + case Format.FORWARDREF: + is_forwardref_format = True + case _: + raise NotImplementedError(format) if self.__cell__ is not None: try: return self.__cell__.cell_contents @@ -151,21 +170,42 @@ class ForwardRef: if not self.__forward_is_class__ or param_name not in globals: globals[param_name] = param locals.pop(param_name, None) + if self.__extra_names__: + locals = {**locals, **self.__extra_names__} arg = self.__forward_arg__ if arg.isidentifier() and not keyword.iskeyword(arg): if arg in locals: - value = locals[arg] + return locals[arg] elif arg in globals: - value = globals[arg] + return globals[arg] elif hasattr(builtins, arg): return getattr(builtins, arg) + elif is_forwardref_format: + return self else: raise NameError(arg) else: code = self.__forward_code__ - value = eval(code, globals=globals, locals=locals) - return value + try: + return eval(code, globals=globals, locals=locals) + except Exception: + if not is_forwardref_format: + raise + new_locals = _StringifierDict( + {**builtins.__dict__, **locals}, + globals=globals, + owner=owner, + is_class=self.__forward_is_class__, + format=format, + ) + try: + result = eval(code, globals=globals, locals=new_locals) + except Exception: + return self + else: + new_locals.transmogrify() + return result def _evaluate(self, globalns, localns, type_params=_sentinel, *, recursive_guard): import typing @@ -231,6 +271,10 @@ class ForwardRef: and self.__forward_is_class__ == other.__forward_is_class__ and self.__cell__ == other.__cell__ and self.__owner__ == other.__owner__ + and ( + (tuple(sorted(self.__extra_names__.items())) if self.__extra_names__ else None) == + (tuple(sorted(other.__extra_names__.items())) if other.__extra_names__ else None) + ) ) def __hash__(self): @@ -241,6 +285,7 @@ class ForwardRef: self.__forward_is_class__, self.__cell__, self.__owner__, + tuple(sorted(self.__extra_names__.items())) if self.__extra_names__ else None, )) def __or__(self, other): @@ -274,6 +319,7 @@ class _Stringifier: cell=None, *, stringifier_dict, + extra_names=None, ): # Either an AST node or a simple str (for the common case where a ForwardRef # represent a single name). @@ -285,6 +331,7 @@ class _Stringifier: self.__code__ = None self.__ast_node__ = node self.__globals__ = globals + self.__extra_names__ = extra_names self.__cell__ = cell self.__owner__ = owner self.__stringifier_dict__ = stringifier_dict @@ -292,28 +339,63 @@ class _Stringifier: def __convert_to_ast(self, other): if isinstance(other, _Stringifier): if isinstance(other.__ast_node__, str): - return ast.Name(id=other.__ast_node__) - return other.__ast_node__ - elif isinstance(other, slice): + return ast.Name(id=other.__ast_node__), other.__extra_names__ + return other.__ast_node__, other.__extra_names__ + elif ( + # In STRING format we don't bother with the create_unique_name() dance; + # it's better to emit the repr() of the object instead of an opaque name. + self.__stringifier_dict__.format == Format.STRING + or other is None + or type(other) in (str, int, float, bool, complex) + ): + return ast.Constant(value=other), None + elif type(other) is dict: + extra_names = {} + keys = [] + values = [] + for key, value in other.items(): + new_key, new_extra_names = self.__convert_to_ast(key) + if new_extra_names is not None: + extra_names.update(new_extra_names) + keys.append(new_key) + new_value, new_extra_names = self.__convert_to_ast(value) + if new_extra_names is not None: + extra_names.update(new_extra_names) + values.append(new_value) + return ast.Dict(keys, values), extra_names + elif type(other) in (list, tuple, set): + extra_names = {} + elts = [] + for elt in other: + new_elt, new_extra_names = self.__convert_to_ast(elt) + if new_extra_names is not None: + extra_names.update(new_extra_names) + elts.append(new_elt) + ast_class = {list: ast.List, tuple: ast.Tuple, set: ast.Set}[type(other)] + return ast_class(elts), extra_names + else: + name = self.__stringifier_dict__.create_unique_name() + return ast.Name(id=name), {name: other} + + def __convert_to_ast_getitem(self, other): + if isinstance(other, slice): + extra_names = {} + + def conv(obj): + if obj is None: + return None + new_obj, new_extra_names = self.__convert_to_ast(obj) + if new_extra_names is not None: + extra_names.update(new_extra_names) + return new_obj + return ast.Slice( - lower=( - self.__convert_to_ast(other.start) - if other.start is not None - else None - ), - upper=( - self.__convert_to_ast(other.stop) - if other.stop is not None - else None - ), - step=( - self.__convert_to_ast(other.step) - if other.step is not None - else None - ), - ) + lower=conv(other.start), + upper=conv(other.stop), + step=conv(other.step), + ), extra_names else: - return ast.Constant(value=other) + return self.__convert_to_ast(other) def __get_ast(self): node = self.__ast_node__ @@ -321,13 +403,19 @@ class _Stringifier: return ast.Name(id=node) return node - def __make_new(self, node): + def __make_new(self, node, extra_names=None): + new_extra_names = {} + if self.__extra_names__ is not None: + new_extra_names.update(self.__extra_names__) + if extra_names is not None: + new_extra_names.update(extra_names) stringifier = _Stringifier( node, self.__globals__, self.__owner__, self.__forward_is_class__, stringifier_dict=self.__stringifier_dict__, + extra_names=new_extra_names or None, ) self.__stringifier_dict__.stringifiers.append(stringifier) return stringifier @@ -343,27 +431,37 @@ class _Stringifier: if self.__ast_node__ == "__classdict__": raise KeyError if isinstance(other, tuple): - elts = [self.__convert_to_ast(elt) for elt in other] + extra_names = {} + elts = [] + for elt in other: + new_elt, new_extra_names = self.__convert_to_ast_getitem(elt) + if new_extra_names is not None: + extra_names.update(new_extra_names) + elts.append(new_elt) other = ast.Tuple(elts) else: - other = self.__convert_to_ast(other) + other, extra_names = self.__convert_to_ast_getitem(other) assert isinstance(other, ast.AST), repr(other) - return self.__make_new(ast.Subscript(self.__get_ast(), other)) + return self.__make_new(ast.Subscript(self.__get_ast(), other), extra_names) def __getattr__(self, attr): return self.__make_new(ast.Attribute(self.__get_ast(), attr)) def __call__(self, *args, **kwargs): - return self.__make_new( - ast.Call( - self.__get_ast(), - [self.__convert_to_ast(arg) for arg in args], - [ - ast.keyword(key, self.__convert_to_ast(value)) - for key, value in kwargs.items() - ], - ) - ) + extra_names = {} + ast_args = [] + for arg in args: + new_arg, new_extra_names = self.__convert_to_ast(arg) + if new_extra_names is not None: + extra_names.update(new_extra_names) + ast_args.append(new_arg) + ast_kwargs = [] + for key, value in kwargs.items(): + new_value, new_extra_names = self.__convert_to_ast(value) + if new_extra_names is not None: + extra_names.update(new_extra_names) + ast_kwargs.append(ast.keyword(key, new_value)) + return self.__make_new(ast.Call(self.__get_ast(), ast_args, ast_kwargs), extra_names) def __iter__(self): yield self.__make_new(ast.Starred(self.__get_ast())) @@ -378,8 +476,9 @@ class _Stringifier: def _make_binop(op: ast.AST): def binop(self, other): + rhs, extra_names = self.__convert_to_ast(other) return self.__make_new( - ast.BinOp(self.__get_ast(), op, self.__convert_to_ast(other)) + ast.BinOp(self.__get_ast(), op, rhs), extra_names ) return binop @@ -402,8 +501,9 @@ class _Stringifier: def _make_rbinop(op: ast.AST): def rbinop(self, other): + new_other, extra_names = self.__convert_to_ast(other) return self.__make_new( - ast.BinOp(self.__convert_to_ast(other), op, self.__get_ast()) + ast.BinOp(new_other, op, self.__get_ast()), extra_names ) return rbinop @@ -426,12 +526,14 @@ class _Stringifier: def _make_compare(op): def compare(self, other): + rhs, extra_names = self.__convert_to_ast(other) return self.__make_new( ast.Compare( left=self.__get_ast(), ops=[op], - comparators=[self.__convert_to_ast(other)], - ) + comparators=[rhs], + ), + extra_names, ) return compare @@ -459,13 +561,15 @@ class _Stringifier: class _StringifierDict(dict): - def __init__(self, namespace, globals=None, owner=None, is_class=False): + def __init__(self, namespace, *, globals=None, owner=None, is_class=False, format): super().__init__(namespace) self.namespace = namespace self.globals = globals self.owner = owner self.is_class = is_class self.stringifiers = [] + self.next_id = 1 + self.format = format def __missing__(self, key): fwdref = _Stringifier( @@ -478,6 +582,19 @@ class _StringifierDict(dict): self.stringifiers.append(fwdref) return fwdref + def transmogrify(self): + for obj in self.stringifiers: + obj.__class__ = ForwardRef + obj.__stringifier_dict__ = None # not needed for ForwardRef + if isinstance(obj.__ast_node__, str): + obj.__arg__ = obj.__ast_node__ + obj.__ast_node__ = None + + def create_unique_name(self): + name = f"__annotationlib_name_{self.next_id}__" + self.next_id += 1 + return name + def call_evaluate_function(evaluate, format, *, owner=None): """Call an evaluate function. Evaluate functions are normally generated for @@ -521,20 +638,11 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): # possibly constants if the annotate function uses them directly). We then # convert each of those into a string to get an approximation of the # original source. - globals = _StringifierDict({}) - if annotate.__closure__: - freevars = annotate.__code__.co_freevars - new_closure = [] - for i, cell in enumerate(annotate.__closure__): - if i < len(freevars): - name = freevars[i] - else: - name = "__cell__" - fwdref = _Stringifier(name, stringifier_dict=globals) - new_closure.append(types.CellType(fwdref)) - closure = tuple(new_closure) - else: - closure = None + globals = _StringifierDict({}, format=format) + is_class = isinstance(owner, type) + closure = _build_closure( + annotate, owner, is_class, globals, allow_evaluation=False + ) func = types.FunctionType( annotate.__code__, globals, @@ -544,9 +652,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): ) annos = func(Format.VALUE_WITH_FAKE_GLOBALS) if _is_evaluate: - return annos if isinstance(annos, str) else repr(annos) + return _stringify_single(annos) return { - key: val if isinstance(val, str) else repr(val) + key: _stringify_single(val) for key, val in annos.items() } elif format == Format.FORWARDREF: @@ -569,33 +677,43 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): # that returns a bool and an defined set of attributes. namespace = {**annotate.__builtins__, **annotate.__globals__} is_class = isinstance(owner, type) - globals = _StringifierDict(namespace, annotate.__globals__, owner, is_class) - if annotate.__closure__: - freevars = annotate.__code__.co_freevars - new_closure = [] - for i, cell in enumerate(annotate.__closure__): - try: - cell.cell_contents - except ValueError: - if i < len(freevars): - name = freevars[i] - else: - name = "__cell__" - fwdref = _Stringifier( - name, - cell=cell, - owner=owner, - globals=annotate.__globals__, - is_class=is_class, - stringifier_dict=globals, - ) - globals.stringifiers.append(fwdref) - new_closure.append(types.CellType(fwdref)) - else: - new_closure.append(cell) - closure = tuple(new_closure) + globals = _StringifierDict( + namespace, + globals=annotate.__globals__, + owner=owner, + is_class=is_class, + format=format, + ) + closure = _build_closure( + annotate, owner, is_class, globals, allow_evaluation=True + ) + func = types.FunctionType( + annotate.__code__, + globals, + closure=closure, + argdefs=annotate.__defaults__, + kwdefaults=annotate.__kwdefaults__, + ) + try: + result = func(Format.VALUE_WITH_FAKE_GLOBALS) + except Exception: + pass else: - closure = None + globals.transmogrify() + return result + + # Try again, but do not provide any globals. This allows us to return + # a value in certain cases where an exception gets raised during evaluation. + globals = _StringifierDict( + {}, + globals=annotate.__globals__, + owner=owner, + is_class=is_class, + format=format, + ) + closure = _build_closure( + annotate, owner, is_class, globals, allow_evaluation=False + ) func = types.FunctionType( annotate.__code__, globals, @@ -604,13 +722,21 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): kwdefaults=annotate.__kwdefaults__, ) result = func(Format.VALUE_WITH_FAKE_GLOBALS) - for obj in globals.stringifiers: - obj.__class__ = ForwardRef - obj.__stringifier_dict__ = None # not needed for ForwardRef - if isinstance(obj.__ast_node__, str): - obj.__arg__ = obj.__ast_node__ - obj.__ast_node__ = None - return result + globals.transmogrify() + if _is_evaluate: + if isinstance(result, ForwardRef): + return result.evaluate(format=Format.FORWARDREF) + else: + return result + else: + return { + key: ( + val.evaluate(format=Format.FORWARDREF) + if isinstance(val, ForwardRef) + else val + ) + for key, val in result.items() + } elif format == Format.VALUE: # Should be impossible because __annotate__ functions must not raise # NotImplementedError for this format. @@ -619,20 +745,59 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): raise ValueError(f"Invalid format: {format!r}") -def get_annotate_function(obj): - """Get the __annotate__ function for an object. +def _build_closure(annotate, owner, is_class, stringifier_dict, *, allow_evaluation): + if not annotate.__closure__: + return None + freevars = annotate.__code__.co_freevars + new_closure = [] + for i, cell in enumerate(annotate.__closure__): + if i < len(freevars): + name = freevars[i] + else: + name = "__cell__" + new_cell = None + if allow_evaluation: + try: + cell.cell_contents + except ValueError: + pass + else: + new_cell = cell + if new_cell is None: + fwdref = _Stringifier( + name, + cell=cell, + owner=owner, + globals=annotate.__globals__, + is_class=is_class, + stringifier_dict=stringifier_dict, + ) + stringifier_dict.stringifiers.append(fwdref) + new_cell = types.CellType(fwdref) + new_closure.append(new_cell) + return tuple(new_closure) + + +def _stringify_single(anno): + if anno is ...: + return "..." + # We have to handle str specially to support PEP 563 stringified annotations. + elif isinstance(anno, str): + return anno + else: + return repr(anno) + - obj may be a function, class, or module, or a user-defined type with - an `__annotate__` attribute. +def get_annotate_from_class_namespace(obj): + """Retrieve the annotate function from a class namespace dictionary. - Returns the __annotate__ function or None. + Return None if the namespace does not contain an annotate function. + This is useful in metaclass ``__new__`` methods to retrieve the annotate function. """ - if isinstance(obj, dict): - try: - return obj["__annotate__"] - except KeyError: - return obj.get("__annotate_func__", None) - return getattr(obj, "__annotate__", None) + try: + return obj["__annotate__"] + except KeyError: + return obj.get("__annotate_func__", None) def get_annotations( @@ -724,7 +889,7 @@ def get_annotations( # But if we didn't get it, we use __annotations__ instead. ann = _get_dunder_annotations(obj) if ann is not None: - return annotations_to_string(ann) + return annotations_to_string(ann) case Format.VALUE_WITH_FAKE_GLOBALS: raise ValueError("The VALUE_WITH_FAKE_GLOBALS format is for internal use only") case _: @@ -832,7 +997,7 @@ def _get_and_call_annotate(obj, format): May not return a fresh dictionary. """ - annotate = get_annotate_function(obj) + annotate = getattr(obj, "__annotate__", None) if annotate is not None: ann = call_annotate_function(annotate, format, owner=obj) if not isinstance(ann, dict): |