diff options
Diffstat (limited to 'Lib/sre_parse.py')
-rw-r--r-- | Lib/sre_parse.py | 160 |
1 files changed, 93 insertions, 67 deletions
diff --git a/Lib/sre_parse.py b/Lib/sre_parse.py index 7149dca491b..9aea56a825b 100644 --- a/Lib/sre_parse.py +++ b/Lib/sre_parse.py @@ -58,6 +58,7 @@ FLAGS = { "s": SRE_FLAG_DOTALL, "x": SRE_FLAG_VERBOSE, # extensions + "a": SRE_FLAG_ASCII, "t": SRE_FLAG_TEMPLATE, "u": SRE_FLAG_UNICODE, } @@ -75,8 +76,8 @@ class Pattern: if name is not None: ogid = self.groupdict.get(name, None) if ogid is not None: - raise error, ("redefinition of group name %s as group %d; " - "was group %d" % (repr(name), gid, ogid)) + raise error("redefinition of group name %s as group %d; " + "was group %d" % (repr(name), gid, ogid)) self.groupdict[name] = gid self.open.append(gid) return gid @@ -95,32 +96,32 @@ class SubPattern: self.width = None def dump(self, level=0): nl = 1 - seqtypes = type(()), type([]) + seqtypes = (tuple, list) for op, av in self.data: - print level*" " + op,; nl = 0 + print(level*" " + op, end=' '); nl = 0 if op == "in": # member sublanguage - print; nl = 1 + print(); nl = 1 for op, a in av: - print (level+1)*" " + op, a + print((level+1)*" " + op, a) elif op == "branch": - print; nl = 1 + print(); nl = 1 i = 0 for a in av[1]: if i > 0: - print level*" " + "or" + print(level*" " + "or") a.dump(level+1); nl = 1 i = i + 1 - elif type(av) in seqtypes: + elif isinstance(av, seqtypes): for a in av: if isinstance(a, SubPattern): - if not nl: print + if not nl: print() a.dump(level+1); nl = 1 else: - print a, ; nl = 0 + print(a, end=' ') ; nl = 0 else: - print av, ; nl = 0 - if not nl: print + print(av, end=' ') ; nl = 0 + if not nl: print() def __repr__(self): return repr(self.data) def __len__(self): @@ -141,12 +142,12 @@ class SubPattern: # determine the width (min, max) for this subpattern if self.width: return self.width - lo = hi = 0L + lo = hi = 0 UNITCODES = (ANY, RANGE, IN, LITERAL, NOT_LITERAL, CATEGORY) REPEATCODES = (MIN_REPEAT, MAX_REPEAT) for op, av in self.data: if op is BRANCH: - i = sys.maxint + i = sys.maxsize j = 0 for av in av[1]: l, h = av.getwidth() @@ -164,14 +165,14 @@ class SubPattern: hi = hi + j elif op in REPEATCODES: i, j = av[2].getwidth() - lo = lo + long(i) * av[0] - hi = hi + long(j) * av[1] + lo = lo + int(i) * av[0] + hi = hi + int(j) * av[1] elif op in UNITCODES: lo = lo + 1 hi = hi + 1 elif op == SUCCESS: break - self.width = int(min(lo, sys.maxint)), int(min(hi, sys.maxint)) + self.width = int(min(lo, sys.maxsize)), int(min(hi, sys.maxsize)) return self.width class Tokenizer: @@ -183,12 +184,18 @@ class Tokenizer: if self.index >= len(self.string): self.next = None return - char = self.string[self.index] - if char[0] == "\\": + char = self.string[self.index:self.index+1] + # Special case for the str8, since indexing returns a integer + # XXX This is only needed for test_bug_926075 in test_re.py + if char and isinstance(char, bytes): + char = chr(char[0]) + if char == "\\": try: c = self.string[self.index + 1] except IndexError: - raise error, "bogus escape (end of line)" + raise error("bogus escape (end of line)") + if isinstance(self.string, bytes): + c = chr(c) char = char + c self.index = self.index + len(char) self.next = char @@ -238,7 +245,7 @@ def _class_escape(source, escape): escape = escape + source.get() escape = escape[2:] if len(escape) != 2: - raise error, "bogus escape: %s" % repr("\\" + escape) + raise error("bogus escape: %s" % repr("\\" + escape)) return LITERAL, int(escape, 16) & 0xff elif c in OCTDIGITS: # octal escape (up to three digits) @@ -247,12 +254,12 @@ def _class_escape(source, escape): escape = escape[1:] return LITERAL, int(escape, 8) & 0xff elif c in DIGITS: - raise error, "bogus escape: %s" % repr(escape) + raise error("bogus escape: %s" % repr(escape)) if len(escape) == 2: return LITERAL, ord(escape[1]) except ValueError: pass - raise error, "bogus escape: %s" % repr(escape) + raise error("bogus escape: %s" % repr(escape)) def _escape(source, escape, state): # handle escape code in expression @@ -289,14 +296,14 @@ def _escape(source, escape, state): group = int(escape[1:]) if group < state.groups: if not state.checkgroup(group): - raise error, "cannot refer to open group" + raise error("cannot refer to open group") return GROUPREF, group raise ValueError if len(escape) == 2: return LITERAL, ord(escape[1]) except ValueError: pass - raise error, "bogus escape: %s" % repr(escape) + raise error("bogus escape: %s" % repr(escape)) def _parse_sub(source, state, nested=1): # parse an alternation: a|b|c @@ -313,7 +320,7 @@ def _parse_sub(source, state, nested=1): if not source.next or sourcematch(")", 0): break else: - raise error, "pattern not properly closed" + raise error("pattern not properly closed") if len(items) == 1: return items[0] @@ -362,11 +369,11 @@ def _parse_sub_cond(source, state, condgroup): if source.match("|"): item_no = _parse(source, state) if source.match("|"): - raise error, "conditional backref with more than two branches" + raise error("conditional backref with more than two branches") else: item_no = None if source.next and not source.match(")", 0): - raise error, "pattern not properly closed" + raise error("pattern not properly closed") subpattern = SubPattern(state) subpattern.append((GROUPREF_EXISTS, (condgroup, item_yes, item_no))) return subpattern @@ -431,7 +438,7 @@ def _parse(source, state): elif this: code1 = LITERAL, ord(this) else: - raise error, "unexpected end of regular expression" + raise error("unexpected end of regular expression") if sourcematch("-"): # potential range this = sourceget() @@ -447,14 +454,14 @@ def _parse(source, state): else: code2 = LITERAL, ord(this) if code1[0] != LITERAL or code2[0] != LITERAL: - raise error, "bad character range" + raise error("bad character range") lo = code1[1] hi = code2[1] if hi < lo: - raise error, "bad character range" + raise error("bad character range") setappend((RANGE, (lo, hi))) else: - raise error, "unexpected end of regular expression" + raise error("unexpected end of regular expression") else: if code1[0] is IN: code1 = code1[1][0] @@ -501,18 +508,18 @@ def _parse(source, state): if hi: max = int(hi) if max < min: - raise error, "bad repeat interval" + raise error("bad repeat interval") else: - raise error, "not supported" + raise error("not supported") # figure out which item to repeat if subpattern: item = subpattern[-1:] else: item = None if not item or (_len(item) == 1 and item[0][0] == AT): - raise error, "nothing to repeat" + raise error("nothing to repeat") if item[0][0] in REPEATCODES: - raise error, "multiple repeat" + raise error("multiple repeat") if sourcematch("?"): subpattern[-1] = (MIN_REPEAT, (min, max, item)) else: @@ -536,7 +543,7 @@ def _parse(source, state): while 1: char = sourceget() if char is None: - raise error, "unterminated name" + raise error("unterminated name") if char == ">": break name = name + char @@ -544,31 +551,31 @@ def _parse(source, state): if not name: raise error("missing group name") if not isname(name): - raise error, "bad character in group name" + raise error("bad character in group name") elif sourcematch("="): # named backreference name = "" while 1: char = sourceget() if char is None: - raise error, "unterminated name" + raise error("unterminated name") if char == ")": break name = name + char if not name: raise error("missing group name") if not isname(name): - raise error, "bad character in group name" + raise error("bad character in group name") gid = state.groupdict.get(name) if gid is None: - raise error, "unknown group name" + raise error("unknown group name") subpatternappend((GROUPREF, gid)) continue else: char = sourceget() if char is None: - raise error, "unexpected end of pattern" - raise error, "unknown specifier: ?P%s" % char + raise error("unexpected end of pattern") + raise error("unknown specifier: ?P%s" % char) elif sourcematch(":"): # non-capturing group group = 2 @@ -579,7 +586,7 @@ def _parse(source, state): break sourceget() if not sourcematch(")"): - raise error, "unbalanced parenthesis" + raise error("unbalanced parenthesis") continue elif source.next in ASSERTCHARS: # lookahead assertions @@ -587,12 +594,12 @@ def _parse(source, state): dir = 1 if char == "<": if source.next not in LOOKBEHINDASSERTCHARS: - raise error, "syntax error" + raise error("syntax error") dir = -1 # lookbehind char = sourceget() p = _parse_sub(source, state) if not sourcematch(")"): - raise error, "unbalanced parenthesis" + raise error("unbalanced parenthesis") if char == "=": subpatternappend((ASSERT, (dir, p))) else: @@ -604,7 +611,7 @@ def _parse(source, state): while 1: char = sourceget() if char is None: - raise error, "unterminated name" + raise error("unterminated name") if char == ")": break condname = condname + char @@ -614,16 +621,16 @@ def _parse(source, state): if isname(condname): condgroup = state.groupdict.get(condname) if condgroup is None: - raise error, "unknown group name" + raise error("unknown group name") else: try: condgroup = int(condname) except ValueError: - raise error, "bad character in group name" + raise error("bad character in group name") else: # flags if not source.next in FLAGS: - raise error, "unexpected end of pattern" + raise error("unexpected end of pattern") while source.next in FLAGS: state.flags = state.flags | FLAGS[sourceget()] if group: @@ -638,7 +645,7 @@ def _parse(source, state): else: p = _parse_sub(source, state) if not sourcematch(")"): - raise error, "unbalanced parenthesis" + raise error("unbalanced parenthesis") if group is not None: state.closegroup(group) subpatternappend((SUBPATTERN, (group, p))) @@ -646,10 +653,10 @@ def _parse(source, state): while 1: char = sourceget() if char is None: - raise error, "unexpected end of pattern" + raise error("unexpected end of pattern") if char == ")": break - raise error, "unknown extension" + raise error("unknown extension") elif this == "^": subpatternappend((AT, AT_BEGINNING)) @@ -662,10 +669,22 @@ def _parse(source, state): subpatternappend(code) else: - raise error, "parser error" + raise error("parser error") return subpattern +def fix_flags(src, flags): + # Check and fix flags according to the type of pattern (str or bytes) + if isinstance(src, str): + if not flags & SRE_FLAG_ASCII: + flags |= SRE_FLAG_UNICODE + elif flags & SRE_FLAG_UNICODE: + raise ValueError("ASCII and UNICODE flags are incompatible") + else: + if flags & SRE_FLAG_UNICODE: + raise ValueError("can't use UNICODE flag with a bytes pattern") + return flags + def parse(str, flags=0, pattern=None): # parse 're' pattern into list of (opcode, argument) tuples @@ -677,12 +696,13 @@ def parse(str, flags=0, pattern=None): pattern.str = str p = _parse_sub(source, pattern, 0) + p.pattern.flags = fix_flags(str, p.pattern.flags) tail = source.get() if tail == ")": - raise error, "unbalanced parenthesis" + raise error("unbalanced parenthesis") elif tail: - raise error, "bogus characters at end of regular expression" + raise error("bogus characters at end of regular expression") if flags & SRE_FLAG_DEBUG: p.dump() @@ -707,10 +727,10 @@ def parse_template(source, pattern): else: pappend((LITERAL, literal)) sep = source[:0] - if type(sep) is type(""): + if isinstance(sep, str): makechar = chr else: - makechar = unichr + makechar = chr while 1: this = sget() if this is None: @@ -724,23 +744,23 @@ def parse_template(source, pattern): while 1: char = sget() if char is None: - raise error, "unterminated group name" + raise error("unterminated group name") if char == ">": break name = name + char if not name: - raise error, "missing group name" + raise error("missing group name") try: index = int(name) if index < 0: - raise error, "negative group number" + raise error("negative group number") except ValueError: if not isname(name): - raise error, "bad character in group name" + raise error("bad character in group name") try: index = pattern.groupindex[name] except KeyError: - raise IndexError, "unknown group name" + raise IndexError("unknown group name") a((MARK, index)) elif c == "0": if s.next in OCTDIGITS: @@ -772,12 +792,18 @@ def parse_template(source, pattern): groups = [] groupsappend = groups.append literals = [None] * len(p) + if isinstance(source, str): + encode = lambda x: x + else: + # The tokenizer implicitly decodes bytes objects as latin-1, we must + # therefore re-encode the final representation. + encode = lambda x: x.encode('latin1') for c, s in p: if c is MARK: groupsappend((i, s)) # literal[i] is already None else: - literals[i] = s + literals[i] = encode(s) i = i + 1 return groups, literals @@ -790,7 +816,7 @@ def expand_template(template, match): for index, group in groups: literals[index] = s = g(group) if s is None: - raise error, "unmatched group" + raise error("unmatched group") except IndexError: - raise error, "invalid group reference" + raise error("invalid group reference") return sep.join(literals) |