diff options
Diffstat (limited to 'Lib/wsgiref/headers.py')
-rw-r--r-- | Lib/wsgiref/headers.py | 39 |
1 files changed, 27 insertions, 12 deletions
diff --git a/Lib/wsgiref/headers.py b/Lib/wsgiref/headers.py index 6c8c60c8913..d93962831ae 100644 --- a/Lib/wsgiref/headers.py +++ b/Lib/wsgiref/headers.py @@ -5,8 +5,6 @@ so portions are Copyright (C) 2001,2002 Python Software Foundation, and were written by Barry Warsaw. """ -from types import ListType, TupleType - # Regular expression that matches `special' characters in parameters, the # existence of which force quoting of the parameter value. import re @@ -32,9 +30,20 @@ class Headers: """Manage a collection of HTTP response headers""" def __init__(self,headers): - if type(headers) is not ListType: + if type(headers) is not list: raise TypeError("Headers must be a list of name/value tuples") self._headers = headers + if __debug__: + for k, v in headers: + self._convert_string_type(k) + self._convert_string_type(v) + + def _convert_string_type(self, value): + """Convert/check value type.""" + if type(value) is str: + return value + raise AssertionError("Header names/values must be" + " of type str (got {0})".format(repr(value))) def __len__(self): """Return the total number of headers, including duplicates.""" @@ -43,14 +52,15 @@ class Headers: def __setitem__(self, name, val): """Set the value of a header.""" del self[name] - self._headers.append((name, val)) + self._headers.append( + (self._convert_string_type(name), self._convert_string_type(val))) def __delitem__(self,name): """Delete all occurrences of a header, if present. Does *not* raise an exception if the header is missing. """ - name = name.lower() + name = self._convert_string_type(name.lower()) self._headers[:] = [kv for kv in self._headers if kv[0].lower() != name] def __getitem__(self,name): @@ -64,12 +74,10 @@ class Headers: """ return self.get(name) - def has_key(self, name): + def __contains__(self, name): """Return true if the message contains the header.""" return self.get(name) is not None - __contains__ = has_key - def get_all(self, name): """Return a list of all the values for the named field. @@ -79,13 +87,13 @@ class Headers: fields deleted and re-inserted are always appended to the header list. If no fields exist with the given name, returns an empty list. """ - name = name.lower() + name = self._convert_string_type(name.lower()) return [kv[1] for kv in self._headers if kv[0].lower()==name] def get(self,name,default=None): """Get the first header value for 'name', or return 'default'""" - name = name.lower() + name = self._convert_string_type(name.lower()) for k,v in self._headers: if k.lower()==name: return v @@ -130,6 +138,9 @@ class Headers: suitable for direct HTTP transmission.""" return '\r\n'.join(["%s: %s" % kv for kv in self._headers]+['','']) + def __bytes__(self): + return str(self).encode('iso-8859-1') + def setdefault(self,name,value): """Return first matching header value for 'name', or 'value' @@ -137,7 +148,8 @@ class Headers: and value 'value'.""" result = self.get(name) if result is None: - self._headers.append((name,value)) + self._headers.append((self._convert_string_type(name), + self._convert_string_type(value))) return value else: return result @@ -160,10 +172,13 @@ class Headers: """ parts = [] if _value is not None: + _value = self._convert_string_type(_value) parts.append(_value) for k, v in _params.items(): + k = self._convert_string_type(k) if v is None: parts.append(k.replace('_', '-')) else: + v = self._convert_string_type(v) parts.append(_formatparam(k.replace('_', '-'), v)) - self._headers.append((_name, "; ".join(parts))) + self._headers.append((self._convert_string_type(_name), "; ".join(parts))) |