diff options
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r-- | Lib/ssl.py | 411 |
1 files changed, 254 insertions, 157 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py index 88296358a02..e901b640a62 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -1,8 +1,7 @@ # Wrapper module for _ssl, providing some additional facilities # implemented in Python. Written by Bill Janssen. -"""\ -This module provides some more Pythonic support for SSL. +"""This module provides some more Pythonic support for SSL. Object types: @@ -56,24 +55,30 @@ PROTOCOL_TLSv1 """ import textwrap +import re import _ssl # if we can't import it, let the error propagate from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION -from _ssl import SSLError +from _ssl import _SSLContext, SSLError from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED +from _ssl import OP_ALL, OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_TLSv1 from _ssl import RAND_status, RAND_egd, RAND_add -from _ssl import \ - SSL_ERROR_ZERO_RETURN, \ - SSL_ERROR_WANT_READ, \ - SSL_ERROR_WANT_WRITE, \ - SSL_ERROR_WANT_X509_LOOKUP, \ - SSL_ERROR_SYSCALL, \ - SSL_ERROR_SSL, \ - SSL_ERROR_WANT_CONNECT, \ - SSL_ERROR_EOF, \ - SSL_ERROR_INVALID_ERROR_CODE +from _ssl import ( + SSL_ERROR_ZERO_RETURN, + SSL_ERROR_WANT_READ, + SSL_ERROR_WANT_WRITE, + SSL_ERROR_WANT_X509_LOOKUP, + SSL_ERROR_SYSCALL, + SSL_ERROR_SSL, + SSL_ERROR_WANT_CONNECT, + SSL_ERROR_EOF, + SSL_ERROR_INVALID_ERROR_CODE, + ) +from _ssl import HAS_SNI from _ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 +from _ssl import _OPENSSL_API_VERSION + _PROTOCOL_NAMES = { PROTOCOL_TLSv1: "TLSv1", PROTOCOL_SSLv23: "SSLv23", @@ -87,9 +92,11 @@ except ImportError: else: _PROTOCOL_NAMES[PROTOCOL_SSLv2] = "SSLv2" -from socket import socket, _fileobject, _delegate_methods, error as socket_error from socket import getnameinfo as _getnameinfo +from socket import error as socket_error +from socket import socket, AF_INET, SOCK_STREAM import base64 # for DER-to-PEM translation +import traceback import errno # Disable weak or insecure ciphers by default @@ -97,97 +104,230 @@ import errno _DEFAULT_CIPHERS = 'DEFAULT:!aNULL:!eNULL:!LOW:!EXPORT:!SSLv2' -class SSLSocket(socket): +class CertificateError(ValueError): + pass + + +def _dnsname_to_pat(dn): + pats = [] + for frag in dn.split(r'.'): + if frag == '*': + # When '*' is a fragment by itself, it matches a non-empty dotless + # fragment. + pats.append('[^.]+') + else: + # Otherwise, '*' matches any dotless fragment. + frag = re.escape(frag) + pats.append(frag.replace(r'\*', '[^.]*')) + return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) + + +def match_hostname(cert, hostname): + """Verify that *cert* (in decoded format as returned by + SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 rules + are mostly followed, but IP addresses are not accepted for *hostname*. + + CertificateError is raised on failure. On success, the function + returns nothing. + """ + if not cert: + raise ValueError("empty or no certificate") + dnsnames = [] + san = cert.get('subjectAltName', ()) + for key, value in san: + if key == 'DNS': + if _dnsname_to_pat(value).match(hostname): + return + dnsnames.append(value) + if not dnsnames: + # The subject is only checked when there is no dNSName entry + # in subjectAltName + for sub in cert.get('subject', ()): + for key, value in sub: + # XXX according to RFC 2818, the most specific Common Name + # must be used. + if key == 'commonName': + if _dnsname_to_pat(value).match(hostname): + return + dnsnames.append(value) + if len(dnsnames) > 1: + raise CertificateError("hostname %r " + "doesn't match either of %s" + % (hostname, ', '.join(map(repr, dnsnames)))) + elif len(dnsnames) == 1: + raise CertificateError("hostname %r " + "doesn't match %r" + % (hostname, dnsnames[0])) + else: + raise CertificateError("no appropriate commonName or " + "subjectAltName fields were found") + + +class SSLContext(_SSLContext): + """An SSLContext holds various SSL-related configuration options and + data, such as certificates and possibly a private key.""" + + __slots__ = ('protocol',) + def __new__(cls, protocol, *args, **kwargs): + self = _SSLContext.__new__(cls, protocol) + if protocol != _SSLv2_IF_EXISTS: + self.set_ciphers(_DEFAULT_CIPHERS) + return self + + def __init__(self, protocol): + self.protocol = protocol + + def wrap_socket(self, sock, server_side=False, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, + server_hostname=None): + return SSLSocket(sock=sock, server_side=server_side, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs, + server_hostname=server_hostname, + _context=self) + + +class SSLSocket(socket): """This class implements a subtype of socket.socket that wraps the underlying OS socket in an SSL context when necessary, and provides read and write methods over that channel.""" - def __init__(self, sock, keyfile=None, certfile=None, + def __init__(self, sock=None, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version=PROTOCOL_SSLv23, ca_certs=None, do_handshake_on_connect=True, - suppress_ragged_eofs=True, ciphers=None): - socket.__init__(self, _sock=sock._sock) - # The initializer for socket overrides the methods send(), recv(), etc. - # in the instancce, which we don't need -- but we want to provide the - # methods defined in SSLSocket. - for attr in _delegate_methods: - try: - delattr(self, attr) - except AttributeError: - pass + family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, + suppress_ragged_eofs=True, ciphers=None, + server_hostname=None, + _context=None): - if ciphers is None and ssl_version != _SSLv2_IF_EXISTS: - ciphers = _DEFAULT_CIPHERS - - if certfile and not keyfile: - keyfile = certfile - # see if it's connected - try: - socket.getpeername(self) - except socket_error, e: - if e.errno != errno.ENOTCONN: - raise - # no, no connection yet - self._connected = False - self._sslobj = None + if _context: + self.context = _context else: - # yes, create the SSL object - self._connected = True - self._sslobj = _ssl.sslwrap(self._sock, server_side, - keyfile, certfile, - cert_reqs, ssl_version, ca_certs, - ciphers) - if do_handshake_on_connect: - self.do_handshake() - self.keyfile = keyfile - self.certfile = certfile - self.cert_reqs = cert_reqs - self.ssl_version = ssl_version - self.ca_certs = ca_certs - self.ciphers = ciphers + if server_side and not certfile: + raise ValueError("certfile must be specified for server-side " + "operations") + if keyfile and not certfile: + raise ValueError("certfile must be specified") + if certfile and not keyfile: + keyfile = certfile + self.context = SSLContext(ssl_version) + self.context.verify_mode = cert_reqs + if ca_certs: + self.context.load_verify_locations(ca_certs) + if certfile: + self.context.load_cert_chain(certfile, keyfile) + if ciphers: + self.context.set_ciphers(ciphers) + self.keyfile = keyfile + self.certfile = certfile + self.cert_reqs = cert_reqs + self.ssl_version = ssl_version + self.ca_certs = ca_certs + self.ciphers = ciphers + if server_side and server_hostname: + raise ValueError("server_hostname can only be specified " + "in client mode") + self.server_side = server_side + self.server_hostname = server_hostname self.do_handshake_on_connect = do_handshake_on_connect self.suppress_ragged_eofs = suppress_ragged_eofs - self._makefile_refs = 0 + connected = False + if sock is not None: + socket.__init__(self, + family=sock.family, + type=sock.type, + proto=sock.proto, + fileno=sock.fileno()) + self.settimeout(sock.gettimeout()) + # see if it's connected + try: + sock.getpeername() + except socket_error as e: + if e.errno != errno.ENOTCONN: + raise + else: + connected = True + sock.detach() + elif fileno is not None: + socket.__init__(self, fileno=fileno) + else: + socket.__init__(self, family=family, type=type, proto=proto) + + self._closed = False + self._sslobj = None + self._connected = connected + if connected: + # create the SSL object + try: + self._sslobj = self.context._wrap_socket(self, server_side, + server_hostname) + if do_handshake_on_connect: + timeout = self.gettimeout() + if timeout == 0.0: + # non-blocking + raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets") + self.do_handshake() - def read(self, len=1024): + except socket_error as x: + self.close() + raise x + def dup(self): + raise NotImplemented("Can't dup() %s instances" % + self.__class__.__name__) + + def _checkClosed(self, msg=None): + # raise an exception here if you wish to check for spurious closes + pass + + def read(self, len=0, buffer=None): """Read up to LEN bytes and return them. Return zero-length string on EOF.""" + self._checkClosed() try: - return self._sslobj.read(len) - except SSLError, x: + if buffer is not None: + v = self._sslobj.read(len, buffer) + else: + v = self._sslobj.read(len or 1024) + return v + except SSLError as x: if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: - return '' + if buffer is not None: + return 0 + else: + return b'' else: raise def write(self, data): - """Write DATA to the underlying SSL channel. Returns number of bytes of DATA actually transmitted.""" + self._checkClosed() return self._sslobj.write(data) def getpeercert(self, binary_form=False): - """Returns a formatted version of the data in the certificate provided by the other end of the SSL channel. Return None if no certificate was provided, {} if a certificate was provided, but not validated.""" + self._checkClosed() return self._sslobj.peer_certificate(binary_form) def cipher(self): - + self._checkClosed() if not self._sslobj: return None else: return self._sslobj.cipher() def send(self, data, flags=0): + self._checkClosed() if self._sslobj: if flags != 0: raise ValueError( @@ -196,7 +336,7 @@ class SSLSocket(socket): while True: try: v = self._sslobj.write(data) - except SSLError, x: + except SSLError as x: if x.args[0] == SSL_ERROR_WANT_READ: return 0 elif x.args[0] == SSL_ERROR_WANT_WRITE: @@ -206,18 +346,20 @@ class SSLSocket(socket): else: return v else: - return self._sock.send(data, flags) + return socket.send(self, data, flags) def sendto(self, data, flags_or_addr, addr=None): + self._checkClosed() if self._sslobj: raise ValueError("sendto not allowed on instances of %s" % self.__class__) elif addr is None: - return self._sock.sendto(data, flags_or_addr) + return socket.sendto(self, data, flags_or_addr) else: - return self._sock.sendto(data, flags_or_addr, addr) + return socket.sendto(self, data, flags_or_addr, addr) def sendall(self, data, flags=0): + self._checkClosed() if self._sslobj: if flags != 0: raise ValueError( @@ -233,6 +375,7 @@ class SSLSocket(socket): return socket.sendall(self, data, flags) def recv(self, buflen=1024, flags=0): + self._checkClosed() if self._sslobj: if flags != 0: raise ValueError( @@ -240,9 +383,10 @@ class SSLSocket(socket): self.__class__) return self.read(buflen) else: - return self._sock.recv(buflen, flags) + return socket.recv(self, buflen, flags) def recv_into(self, buffer, nbytes=None, flags=0): + self._checkClosed() if buffer and (nbytes is None): nbytes = len(buffer) elif nbytes is None: @@ -252,33 +396,38 @@ class SSLSocket(socket): raise ValueError( "non-zero flags not allowed in calls to recv_into() on %s" % self.__class__) - tmp_buffer = self.read(nbytes) - v = len(tmp_buffer) - buffer[:v] = tmp_buffer - return v + return self.read(nbytes, buffer) else: - return self._sock.recv_into(buffer, nbytes, flags) + return socket.recv_into(self, buffer, nbytes, flags) def recvfrom(self, buflen=1024, flags=0): + self._checkClosed() if self._sslobj: raise ValueError("recvfrom not allowed on instances of %s" % self.__class__) else: - return self._sock.recvfrom(buflen, flags) + return socket.recvfrom(self, buflen, flags) def recvfrom_into(self, buffer, nbytes=None, flags=0): + self._checkClosed() if self._sslobj: raise ValueError("recvfrom_into not allowed on instances of %s" % self.__class__) else: - return self._sock.recvfrom_into(buffer, nbytes, flags) + return socket.recvfrom_into(self, buffer, nbytes, flags) def pending(self): + self._checkClosed() if self._sslobj: return self._sslobj.pending() else: return 0 + def shutdown(self, how): + self._checkClosed() + self._sslobj = None + socket.shutdown(self, how) + def unwrap(self): if self._sslobj: s = self._sslobj.shutdown() @@ -287,33 +436,32 @@ class SSLSocket(socket): else: raise ValueError("No SSL wrapper around " + str(self)) - def shutdown(self, how): + def _real_close(self): self._sslobj = None - socket.shutdown(self, how) - - def close(self): - if self._makefile_refs < 1: - self._sslobj = None - socket.close(self) - else: - self._makefile_refs -= 1 - - def do_handshake(self): + # self._closed = True + socket._real_close(self) + def do_handshake(self, block=False): """Perform a TLS/SSL handshake.""" - self._sslobj.do_handshake() - - def _real_connect(self, addr, return_errno): + timeout = self.gettimeout() + try: + if timeout == 0.0 and block: + self.settimeout(None) + self._sslobj.do_handshake() + finally: + self.settimeout(timeout) + + def _real_connect(self, addr, connect_ex): + if self.server_side: + raise ValueError("can't connect in server-side mode") # Here we assume that the socket is client-side, and not # connected at the time of the call. We connect it, then wrap it. if self._connected: raise ValueError("attempt to connect already-connected SSLSocket!") - self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile, - self.cert_reqs, self.ssl_version, - self.ca_certs, self.ciphers) + self._sslobj = self.context._wrap_socket(self, False, self.server_hostname) try: - if return_errno: + if connect_ex: rc = socket.connect_ex(self, addr) else: rc = None @@ -338,35 +486,20 @@ class SSLSocket(socket): return self._real_connect(addr, True) def accept(self): - """Accepts a new connection from a remote client, and returns a tuple containing that new connection wrapped with a server-side SSL channel, and the address of the remote client.""" newsock, addr = socket.accept(self) - return (SSLSocket(newsock, - keyfile=self.keyfile, - certfile=self.certfile, - server_side=True, - cert_reqs=self.cert_reqs, - ssl_version=self.ssl_version, - ca_certs=self.ca_certs, - ciphers=self.ciphers, - do_handshake_on_connect=self.do_handshake_on_connect, - suppress_ragged_eofs=self.suppress_ragged_eofs), - addr) - - def makefile(self, mode='r', bufsize=-1): - - """Make and return a file-like object that - works with the SSL connection. Just use the code - from the socket module.""" - - self._makefile_refs += 1 - # close=True so as to decrement the reference count when done with - # the file-like object. - return _fileobject(self, mode, bufsize, close=True) + newsock = self.context.wrap_socket(newsock, + do_handshake_on_connect=self.do_handshake_on_connect, + suppress_ragged_eofs=self.suppress_ragged_eofs, + server_side=True) + return newsock, addr + def __del__(self): + # sys.stderr.write("__del__ on %s\n" % repr(self)) + self._real_close() def wrap_socket(sock, keyfile=None, certfile=None, @@ -375,18 +508,16 @@ def wrap_socket(sock, keyfile=None, certfile=None, do_handshake_on_connect=True, suppress_ragged_eofs=True, ciphers=None): - return SSLSocket(sock, keyfile=keyfile, certfile=certfile, + return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=server_side, cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs, do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs, ciphers=ciphers) - # some utility functions def cert_time_to_seconds(cert_time): - """Takes a date-time string in standard ASN1_print form ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return a Python time value in seconds past the epoch.""" @@ -398,23 +529,15 @@ PEM_HEADER = "-----BEGIN CERTIFICATE-----" PEM_FOOTER = "-----END CERTIFICATE-----" def DER_cert_to_PEM_cert(der_cert_bytes): - """Takes a certificate in binary DER format and returns the PEM version of it as a string.""" - if hasattr(base64, 'standard_b64encode'): - # preferred because older API gets line-length wrong - f = base64.standard_b64encode(der_cert_bytes) - return (PEM_HEADER + '\n' + - textwrap.fill(f, 64) + '\n' + - PEM_FOOTER + '\n') - else: - return (PEM_HEADER + '\n' + - base64.encodestring(der_cert_bytes) + - PEM_FOOTER + '\n') + f = str(base64.standard_b64encode(der_cert_bytes), 'ASCII', 'strict') + return (PEM_HEADER + '\n' + + textwrap.fill(f, 64) + '\n' + + PEM_FOOTER + '\n') def PEM_cert_to_DER_cert(pem_cert_string): - """Takes a certificate in ASCII PEM format and returns the DER-encoded version of it as a byte sequence""" @@ -425,10 +548,9 @@ def PEM_cert_to_DER_cert(pem_cert_string): raise ValueError("Invalid PEM encoding; must end with %s" % PEM_FOOTER) d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)] - return base64.decodestring(d) + return base64.decodebytes(d.encode('ASCII', 'strict')) def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): - """Retrieve the certificate from the server at the specified address, and return it as a PEM-encoded string. If 'ca_certs' is specified, validate the server cert against it. @@ -448,28 +570,3 @@ def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): def get_protocol_name(protocol_code): return _PROTOCOL_NAMES.get(protocol_code, '<unknown>') - - -# a replacement for the old socket.ssl function - -def sslwrap_simple(sock, keyfile=None, certfile=None): - - """A replacement for the old socket.ssl function. Designed - for compability with Python 2.5 and earlier. Will disappear in - Python 3.0.""" - - if hasattr(sock, "_sock"): - sock = sock._sock - - ssl_sock = _ssl.sslwrap(sock, 0, keyfile, certfile, CERT_NONE, - PROTOCOL_SSLv23, None) - try: - sock.getpeername() - except socket_error: - # no, no connection yet - pass - else: - # yes, do the handshake - ssl_sock.do_handshake() - - return ssl_sock |