aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/ssl.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r--Lib/ssl.py193
1 files changed, 159 insertions, 34 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py
index f98bd73428f..e7c175f063a 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -55,13 +55,16 @@ PROTOCOL_TLSv1
"""
import textwrap
+import re
import _ssl # if we can't import it, let the error propagate
-from _ssl import SSLError
+from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
+from _ssl import _SSLContext, SSLError
from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
from _ssl import (PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23,
PROTOCOL_TLSv1)
+from _ssl import OP_ALL, OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_TLSv1
from _ssl import RAND_status, RAND_egd, RAND_add
from _ssl import (
SSL_ERROR_ZERO_RETURN,
@@ -74,17 +77,99 @@ from _ssl import (
SSL_ERROR_EOF,
SSL_ERROR_INVALID_ERROR_CODE,
)
+from _ssl import HAS_SNI
from socket import getnameinfo as _getnameinfo
from socket import error as socket_error
-from socket import dup as _dup
from socket import socket, AF_INET, SOCK_STREAM
import base64 # for DER-to-PEM translation
import traceback
import errno
-class SSLSocket(socket):
+class CertificateError(ValueError):
+ pass
+
+
+def _dnsname_to_pat(dn):
+ pats = []
+ for frag in dn.split(r'.'):
+ if frag == '*':
+ # When '*' is a fragment by itself, it matches a non-empty dotless
+ # fragment.
+ pats.append('[^.]+')
+ else:
+ # Otherwise, '*' matches any dotless fragment.
+ frag = re.escape(frag)
+ pats.append(frag.replace(r'\*', '[^.]*'))
+ return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
+
+
+def match_hostname(cert, hostname):
+ """Verify that *cert* (in decoded format as returned by
+ SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 rules
+ are mostly followed, but IP addresses are not accepted for *hostname*.
+
+ CertificateError is raised on failure. On success, the function
+ returns nothing.
+ """
+ if not cert:
+ raise ValueError("empty or no certificate")
+ dnsnames = []
+ san = cert.get('subjectAltName', ())
+ for key, value in san:
+ if key == 'DNS':
+ if _dnsname_to_pat(value).match(hostname):
+ return
+ dnsnames.append(value)
+ if not dnsnames:
+ # The subject is only checked when there is no dNSName entry
+ # in subjectAltName
+ for sub in cert.get('subject', ()):
+ for key, value in sub:
+ # XXX according to RFC 2818, the most specific Common Name
+ # must be used.
+ if key == 'commonName':
+ if _dnsname_to_pat(value).match(hostname):
+ return
+ dnsnames.append(value)
+ if len(dnsnames) > 1:
+ raise CertificateError("hostname %r "
+ "doesn't match either of %s"
+ % (hostname, ', '.join(map(repr, dnsnames))))
+ elif len(dnsnames) == 1:
+ raise CertificateError("hostname %r "
+ "doesn't match %r"
+ % (hostname, dnsnames[0]))
+ else:
+ raise CertificateError("no appropriate commonName or "
+ "subjectAltName fields were found")
+
+
+class SSLContext(_SSLContext):
+ """An SSLContext holds various SSL-related configuration options and
+ data, such as certificates and possibly a private key."""
+
+ __slots__ = ('protocol',)
+
+ def __new__(cls, protocol, *args, **kwargs):
+ return _SSLContext.__new__(cls, protocol)
+
+ def __init__(self, protocol):
+ self.protocol = protocol
+
+ def wrap_socket(self, sock, server_side=False,
+ do_handshake_on_connect=True,
+ suppress_ragged_eofs=True,
+ server_hostname=None):
+ return SSLSocket(sock=sock, server_side=server_side,
+ do_handshake_on_connect=do_handshake_on_connect,
+ suppress_ragged_eofs=suppress_ragged_eofs,
+ server_hostname=server_hostname,
+ _context=self)
+
+
+class SSLSocket(socket):
"""This class implements a subtype of socket.socket that wraps
the underlying OS socket in an SSL context when necessary, and
provides read and write methods over that channel."""
@@ -94,15 +179,48 @@ class SSLSocket(socket):
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True,
family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
- suppress_ragged_eofs=True):
+ suppress_ragged_eofs=True, ciphers=None,
+ server_hostname=None,
+ _context=None):
+ if _context:
+ self.context = _context
+ else:
+ if server_side and not certfile:
+ raise ValueError("certfile must be specified for server-side "
+ "operations")
+ if keyfile and not certfile:
+ raise ValueError("certfile must be specified")
+ if certfile and not keyfile:
+ keyfile = certfile
+ self.context = SSLContext(ssl_version)
+ self.context.verify_mode = cert_reqs
+ if ca_certs:
+ self.context.load_verify_locations(ca_certs)
+ if certfile:
+ self.context.load_cert_chain(certfile, keyfile)
+ if ciphers:
+ self.context.set_ciphers(ciphers)
+ self.keyfile = keyfile
+ self.certfile = certfile
+ self.cert_reqs = cert_reqs
+ self.ssl_version = ssl_version
+ self.ca_certs = ca_certs
+ self.ciphers = ciphers
+ if server_side and server_hostname:
+ raise ValueError("server_hostname can only be specified "
+ "in client mode")
+ self.server_side = server_side
+ self.server_hostname = server_hostname
+ self.do_handshake_on_connect = do_handshake_on_connect
+ self.suppress_ragged_eofs = suppress_ragged_eofs
connected = False
if sock is not None:
socket.__init__(self,
family=sock.family,
type=sock.type,
proto=sock.proto,
- fileno=_dup(sock.fileno()))
+ fileno=sock.fileno())
self.settimeout(sock.gettimeout())
# see if it's connected
try:
@@ -112,23 +230,20 @@ class SSLSocket(socket):
raise
else:
connected = True
- sock.close()
+ sock.detach()
elif fileno is not None:
socket.__init__(self, fileno=fileno)
else:
socket.__init__(self, family=family, type=type, proto=proto)
- if certfile and not keyfile:
- keyfile = certfile
-
self._closed = False
self._sslobj = None
+ self._connected = connected
if connected:
# create the SSL object
try:
- self._sslobj = _ssl.sslwrap(self, server_side,
- keyfile, certfile,
- cert_reqs, ssl_version, ca_certs)
+ self._sslobj = self.context._wrap_socket(self, server_side,
+ server_hostname)
if do_handshake_on_connect:
timeout = self.gettimeout()
if timeout == 0.0:
@@ -140,14 +255,6 @@ class SSLSocket(socket):
self.close()
raise x
- self.keyfile = keyfile
- self.certfile = certfile
- self.cert_reqs = cert_reqs
- self.ssl_version = ssl_version
- self.ca_certs = ca_certs
- self.do_handshake_on_connect = do_handshake_on_connect
- self.suppress_ragged_eofs = suppress_ragged_eofs
-
def dup(self):
raise NotImplemented("Can't dup() %s instances" %
self.__class__.__name__)
@@ -234,6 +341,10 @@ class SSLSocket(socket):
def sendall(self, data, flags=0):
self._checkClosed()
if self._sslobj:
+ if flags != 0:
+ raise ValueError(
+ "non-zero flags not allowed in calls to sendall() on %s" %
+ self.__class__)
amount = len(data)
count = 0
while (count < amount):
@@ -321,24 +432,36 @@ class SSLSocket(socket):
finally:
self.settimeout(timeout)
- def connect(self, addr):
- """Connects to remote ADDR, and then wraps the connection in
- an SSL channel."""
-
+ def _real_connect(self, addr, return_errno):
+ if self.server_side:
+ raise ValueError("can't connect in server-side mode")
# Here we assume that the socket is client-side, and not
# connected at the time of the call. We connect it, then wrap it.
- if self._sslobj:
+ if self._connected:
raise ValueError("attempt to connect already-connected SSLSocket!")
- socket.connect(self, addr)
- self._sslobj = _ssl.sslwrap(self, False, self.keyfile, self.certfile,
- self.cert_reqs, self.ssl_version,
- self.ca_certs)
+ self._sslobj = self.context._wrap_socket(self, False, self.server_hostname)
try:
+ socket.connect(self, addr)
if self.do_handshake_on_connect:
self.do_handshake()
- except:
- self._sslobj = None
- raise
+ except socket_error as e:
+ if return_errno:
+ return e.errno
+ else:
+ self._sslobj = None
+ raise e
+ self._connected = True
+ return 0
+
+ def connect(self, addr):
+ """Connects to remote ADDR, and then wraps the connection in
+ an SSL channel."""
+ self._real_connect(addr, False)
+
+ def connect_ex(self, addr):
+ """Connects to remote ADDR, and then wraps the connection in
+ an SSL channel."""
+ return self._real_connect(addr, True)
def accept(self):
"""Accepts a new connection from a remote client, and returns
@@ -352,6 +475,7 @@ class SSLSocket(socket):
cert_reqs=self.cert_reqs,
ssl_version=self.ssl_version,
ca_certs=self.ca_certs,
+ ciphers=self.ciphers,
do_handshake_on_connect=
self.do_handshake_on_connect),
addr)
@@ -365,13 +489,14 @@ def wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True,
- suppress_ragged_eofs=True):
+ suppress_ragged_eofs=True, ciphers=None):
return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,
server_side=server_side, cert_reqs=cert_reqs,
ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
- suppress_ragged_eofs=suppress_ragged_eofs)
+ suppress_ragged_eofs=suppress_ragged_eofs,
+ ciphers=ciphers)
# some utility functions