diff options
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r-- | Lib/ssl.py | 193 |
1 files changed, 159 insertions, 34 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py index f98bd73428f..e7c175f063a 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -55,13 +55,16 @@ PROTOCOL_TLSv1 """ import textwrap +import re import _ssl # if we can't import it, let the error propagate -from _ssl import SSLError +from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION +from _ssl import _SSLContext, SSLError from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED from _ssl import (PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1) +from _ssl import OP_ALL, OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_TLSv1 from _ssl import RAND_status, RAND_egd, RAND_add from _ssl import ( SSL_ERROR_ZERO_RETURN, @@ -74,17 +77,99 @@ from _ssl import ( SSL_ERROR_EOF, SSL_ERROR_INVALID_ERROR_CODE, ) +from _ssl import HAS_SNI from socket import getnameinfo as _getnameinfo from socket import error as socket_error -from socket import dup as _dup from socket import socket, AF_INET, SOCK_STREAM import base64 # for DER-to-PEM translation import traceback import errno -class SSLSocket(socket): +class CertificateError(ValueError): + pass + + +def _dnsname_to_pat(dn): + pats = [] + for frag in dn.split(r'.'): + if frag == '*': + # When '*' is a fragment by itself, it matches a non-empty dotless + # fragment. + pats.append('[^.]+') + else: + # Otherwise, '*' matches any dotless fragment. + frag = re.escape(frag) + pats.append(frag.replace(r'\*', '[^.]*')) + return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) + + +def match_hostname(cert, hostname): + """Verify that *cert* (in decoded format as returned by + SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 rules + are mostly followed, but IP addresses are not accepted for *hostname*. + + CertificateError is raised on failure. On success, the function + returns nothing. + """ + if not cert: + raise ValueError("empty or no certificate") + dnsnames = [] + san = cert.get('subjectAltName', ()) + for key, value in san: + if key == 'DNS': + if _dnsname_to_pat(value).match(hostname): + return + dnsnames.append(value) + if not dnsnames: + # The subject is only checked when there is no dNSName entry + # in subjectAltName + for sub in cert.get('subject', ()): + for key, value in sub: + # XXX according to RFC 2818, the most specific Common Name + # must be used. + if key == 'commonName': + if _dnsname_to_pat(value).match(hostname): + return + dnsnames.append(value) + if len(dnsnames) > 1: + raise CertificateError("hostname %r " + "doesn't match either of %s" + % (hostname, ', '.join(map(repr, dnsnames)))) + elif len(dnsnames) == 1: + raise CertificateError("hostname %r " + "doesn't match %r" + % (hostname, dnsnames[0])) + else: + raise CertificateError("no appropriate commonName or " + "subjectAltName fields were found") + + +class SSLContext(_SSLContext): + """An SSLContext holds various SSL-related configuration options and + data, such as certificates and possibly a private key.""" + + __slots__ = ('protocol',) + + def __new__(cls, protocol, *args, **kwargs): + return _SSLContext.__new__(cls, protocol) + + def __init__(self, protocol): + self.protocol = protocol + + def wrap_socket(self, sock, server_side=False, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, + server_hostname=None): + return SSLSocket(sock=sock, server_side=server_side, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs, + server_hostname=server_hostname, + _context=self) + + +class SSLSocket(socket): """This class implements a subtype of socket.socket that wraps the underlying OS socket in an SSL context when necessary, and provides read and write methods over that channel.""" @@ -94,15 +179,48 @@ class SSLSocket(socket): ssl_version=PROTOCOL_SSLv23, ca_certs=None, do_handshake_on_connect=True, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, - suppress_ragged_eofs=True): + suppress_ragged_eofs=True, ciphers=None, + server_hostname=None, + _context=None): + if _context: + self.context = _context + else: + if server_side and not certfile: + raise ValueError("certfile must be specified for server-side " + "operations") + if keyfile and not certfile: + raise ValueError("certfile must be specified") + if certfile and not keyfile: + keyfile = certfile + self.context = SSLContext(ssl_version) + self.context.verify_mode = cert_reqs + if ca_certs: + self.context.load_verify_locations(ca_certs) + if certfile: + self.context.load_cert_chain(certfile, keyfile) + if ciphers: + self.context.set_ciphers(ciphers) + self.keyfile = keyfile + self.certfile = certfile + self.cert_reqs = cert_reqs + self.ssl_version = ssl_version + self.ca_certs = ca_certs + self.ciphers = ciphers + if server_side and server_hostname: + raise ValueError("server_hostname can only be specified " + "in client mode") + self.server_side = server_side + self.server_hostname = server_hostname + self.do_handshake_on_connect = do_handshake_on_connect + self.suppress_ragged_eofs = suppress_ragged_eofs connected = False if sock is not None: socket.__init__(self, family=sock.family, type=sock.type, proto=sock.proto, - fileno=_dup(sock.fileno())) + fileno=sock.fileno()) self.settimeout(sock.gettimeout()) # see if it's connected try: @@ -112,23 +230,20 @@ class SSLSocket(socket): raise else: connected = True - sock.close() + sock.detach() elif fileno is not None: socket.__init__(self, fileno=fileno) else: socket.__init__(self, family=family, type=type, proto=proto) - if certfile and not keyfile: - keyfile = certfile - self._closed = False self._sslobj = None + self._connected = connected if connected: # create the SSL object try: - self._sslobj = _ssl.sslwrap(self, server_side, - keyfile, certfile, - cert_reqs, ssl_version, ca_certs) + self._sslobj = self.context._wrap_socket(self, server_side, + server_hostname) if do_handshake_on_connect: timeout = self.gettimeout() if timeout == 0.0: @@ -140,14 +255,6 @@ class SSLSocket(socket): self.close() raise x - self.keyfile = keyfile - self.certfile = certfile - self.cert_reqs = cert_reqs - self.ssl_version = ssl_version - self.ca_certs = ca_certs - self.do_handshake_on_connect = do_handshake_on_connect - self.suppress_ragged_eofs = suppress_ragged_eofs - def dup(self): raise NotImplemented("Can't dup() %s instances" % self.__class__.__name__) @@ -234,6 +341,10 @@ class SSLSocket(socket): def sendall(self, data, flags=0): self._checkClosed() if self._sslobj: + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to sendall() on %s" % + self.__class__) amount = len(data) count = 0 while (count < amount): @@ -321,24 +432,36 @@ class SSLSocket(socket): finally: self.settimeout(timeout) - def connect(self, addr): - """Connects to remote ADDR, and then wraps the connection in - an SSL channel.""" - + def _real_connect(self, addr, return_errno): + if self.server_side: + raise ValueError("can't connect in server-side mode") # Here we assume that the socket is client-side, and not # connected at the time of the call. We connect it, then wrap it. - if self._sslobj: + if self._connected: raise ValueError("attempt to connect already-connected SSLSocket!") - socket.connect(self, addr) - self._sslobj = _ssl.sslwrap(self, False, self.keyfile, self.certfile, - self.cert_reqs, self.ssl_version, - self.ca_certs) + self._sslobj = self.context._wrap_socket(self, False, self.server_hostname) try: + socket.connect(self, addr) if self.do_handshake_on_connect: self.do_handshake() - except: - self._sslobj = None - raise + except socket_error as e: + if return_errno: + return e.errno + else: + self._sslobj = None + raise e + self._connected = True + return 0 + + def connect(self, addr): + """Connects to remote ADDR, and then wraps the connection in + an SSL channel.""" + self._real_connect(addr, False) + + def connect_ex(self, addr): + """Connects to remote ADDR, and then wraps the connection in + an SSL channel.""" + return self._real_connect(addr, True) def accept(self): """Accepts a new connection from a remote client, and returns @@ -352,6 +475,7 @@ class SSLSocket(socket): cert_reqs=self.cert_reqs, ssl_version=self.ssl_version, ca_certs=self.ca_certs, + ciphers=self.ciphers, do_handshake_on_connect= self.do_handshake_on_connect), addr) @@ -365,13 +489,14 @@ def wrap_socket(sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version=PROTOCOL_SSLv23, ca_certs=None, do_handshake_on_connect=True, - suppress_ragged_eofs=True): + suppress_ragged_eofs=True, ciphers=None): return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=server_side, cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs, do_handshake_on_connect=do_handshake_on_connect, - suppress_ragged_eofs=suppress_ragged_eofs) + suppress_ragged_eofs=suppress_ragged_eofs, + ciphers=ciphers) # some utility functions |