aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/_ast_unparse.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/_ast_unparse.py')
-rw-r--r--Lib/_ast_unparse.py96
1 files changed, 69 insertions, 27 deletions
diff --git a/Lib/_ast_unparse.py b/Lib/_ast_unparse.py
index 56d9e935dd9..c25066eb107 100644
--- a/Lib/_ast_unparse.py
+++ b/Lib/_ast_unparse.py
@@ -573,21 +573,11 @@ class Unparser(NodeVisitor):
quote_type = quote_types[0]
self.write(f"{quote_type}{string}{quote_type}")
- def visit_JoinedStr(self, node):
- self.write("f")
-
- fstring_parts = []
- for value in node.values:
- with self.buffered() as buffer:
- self._write_fstring_inner(value)
- fstring_parts.append(
- ("".join(buffer), isinstance(value, Constant))
- )
-
- new_fstring_parts = []
+ def _ftstring_helper(self, parts):
+ new_parts = []
quote_types = list(_ALL_QUOTES)
fallback_to_repr = False
- for value, is_constant in fstring_parts:
+ for value, is_constant in parts:
if is_constant:
value, new_quote_types = self._str_literal_helper(
value,
@@ -606,30 +596,71 @@ class Unparser(NodeVisitor):
new_quote_types = [q for q in quote_types if q not in value]
if new_quote_types:
quote_types = new_quote_types
- new_fstring_parts.append(value)
+ new_parts.append(value)
if fallback_to_repr:
# If we weren't able to find a quote type that works for all parts
# of the JoinedStr, fallback to using repr and triple single quotes.
quote_types = ["'''"]
- new_fstring_parts.clear()
- for value, is_constant in fstring_parts:
+ new_parts.clear()
+ for value, is_constant in parts:
if is_constant:
value = repr('"' + value) # force repr to use single quotes
expected_prefix = "'\""
assert value.startswith(expected_prefix), repr(value)
value = value[len(expected_prefix):-1]
- new_fstring_parts.append(value)
+ new_parts.append(value)
- value = "".join(new_fstring_parts)
+ value = "".join(new_parts)
quote_type = quote_types[0]
self.write(f"{quote_type}{value}{quote_type}")
- def _write_fstring_inner(self, node, is_format_spec=False):
+ def _write_ftstring(self, values, prefix):
+ self.write(prefix)
+ fstring_parts = []
+ for value in values:
+ with self.buffered() as buffer:
+ self._write_ftstring_inner(value)
+ fstring_parts.append(
+ ("".join(buffer), isinstance(value, Constant))
+ )
+ self._ftstring_helper(fstring_parts)
+
+ def _tstring_helper(self, node):
+ if not node.values:
+ self._write_ftstring([], "t")
+ return
+ last_idx = 0
+ for i, value in enumerate(node.values):
+ # This can happen if we have an implicit concat of a t-string
+ # with an f-string
+ if isinstance(value, FormattedValue):
+ if i > last_idx:
+ # Write t-string until here
+ self._write_ftstring(node.values[last_idx:i], "t")
+ self.write(" ")
+ # Write f-string with the current formatted value
+ self._write_ftstring([node.values[i]], "f")
+ if i + 1 < len(node.values):
+ # Only add a space if there are more values after this
+ self.write(" ")
+ last_idx = i + 1
+
+ if last_idx < len(node.values):
+ # Write t-string from last_idx to end
+ self._write_ftstring(node.values[last_idx:], "t")
+
+ def visit_JoinedStr(self, node):
+ self._write_ftstring(node.values, "f")
+
+ def visit_TemplateStr(self, node):
+ self._tstring_helper(node)
+
+ def _write_ftstring_inner(self, node, is_format_spec=False):
if isinstance(node, JoinedStr):
# for both the f-string itself, and format_spec
for value in node.values:
- self._write_fstring_inner(value, is_format_spec=is_format_spec)
+ self._write_ftstring_inner(value, is_format_spec=is_format_spec)
elif isinstance(node, Constant) and isinstance(node.value, str):
value = node.value.replace("{", "{{").replace("}", "}}")
@@ -641,17 +672,22 @@ class Unparser(NodeVisitor):
self.write(value)
elif isinstance(node, FormattedValue):
self.visit_FormattedValue(node)
+ elif isinstance(node, Interpolation):
+ self.visit_Interpolation(node)
else:
raise ValueError(f"Unexpected node inside JoinedStr, {node!r}")
- def visit_FormattedValue(self, node):
- def unparse_inner(inner):
- unparser = type(self)()
- unparser.set_precedence(_Precedence.TEST.next(), inner)
- return unparser.visit(inner)
+ def _unparse_interpolation_value(self, inner):
+ unparser = type(self)()
+ unparser.set_precedence(_Precedence.TEST.next(), inner)
+ return unparser.visit(inner)
+ def _write_interpolation(self, node, is_interpolation=False):
with self.delimit("{", "}"):
- expr = unparse_inner(node.value)
+ if is_interpolation:
+ expr = node.str
+ else:
+ expr = self._unparse_interpolation_value(node.value)
if expr.startswith("{"):
# Separate pair of opening brackets as "{ {"
self.write(" ")
@@ -660,7 +696,13 @@ class Unparser(NodeVisitor):
self.write(f"!{chr(node.conversion)}")
if node.format_spec:
self.write(":")
- self._write_fstring_inner(node.format_spec, is_format_spec=True)
+ self._write_ftstring_inner(node.format_spec, is_format_spec=True)
+
+ def visit_FormattedValue(self, node):
+ self._write_interpolation(node)
+
+ def visit_Interpolation(self, node):
+ self._write_interpolation(node, is_interpolation=True)
def visit_Name(self, node):
self.write(node.id)