diff options
author | Andrew Svetlov <andrew.svetlov@gmail.com> | 2021-05-03 00:34:15 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-05-03 00:34:15 +0300 |
commit | 5fb06edbbb769561e245d0fe13002bab50e2ae60 (patch) | |
tree | a6341e32a1140447b2d37a3a47fedb9d5043c75d /Lib/asyncio/base_events.py | |
parent | c96cc089f60d2bf7e003c27413c3239ee9de2990 (diff) | |
download | cpython-5fb06edbbb769561e245d0fe13002bab50e2ae60.tar.gz cpython-5fb06edbbb769561e245d0fe13002bab50e2ae60.zip |
bpo-44011: New asyncio ssl implementation (#17975)
Diffstat (limited to 'Lib/asyncio/base_events.py')
-rw-r--r-- | Lib/asyncio/base_events.py | 43 |
1 files changed, 34 insertions, 9 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index f789635e0f8..e54ee309e42 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -273,7 +273,7 @@ class _SendfileFallbackProtocol(protocols.Protocol): class Server(events.AbstractServer): def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, - ssl_handshake_timeout): + ssl_handshake_timeout, ssl_shutdown_timeout=None): self._loop = loop self._sockets = sockets self._active_count = 0 @@ -282,6 +282,7 @@ class Server(events.AbstractServer): self._backlog = backlog self._ssl_context = ssl_context self._ssl_handshake_timeout = ssl_handshake_timeout + self._ssl_shutdown_timeout = ssl_shutdown_timeout self._serving = False self._serving_forever_fut = None @@ -313,7 +314,8 @@ class Server(events.AbstractServer): sock.listen(self._backlog) self._loop._start_serving( self._protocol_factory, sock, self._ssl_context, - self, self._backlog, self._ssl_handshake_timeout) + self, self._backlog, self._ssl_handshake_timeout, + self._ssl_shutdown_timeout) def get_loop(self): return self._loop @@ -467,6 +469,7 @@ class BaseEventLoop(events.AbstractEventLoop): *, server_side=False, server_hostname=None, extra=None, server=None, ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, call_connection_made=True): """Create SSL transport.""" raise NotImplementedError @@ -969,6 +972,7 @@ class BaseEventLoop(events.AbstractEventLoop): proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, happy_eyeballs_delay=None, interleave=None): """Connect to a TCP server. @@ -1004,6 +1008,10 @@ class BaseEventLoop(events.AbstractEventLoop): raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None and not ssl: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') + if happy_eyeballs_delay is not None and interleave is None: # If using happy eyeballs, default to interleave addresses by family interleave = 1 @@ -1079,7 +1087,8 @@ class BaseEventLoop(events.AbstractEventLoop): transport, protocol = await self._create_connection_transport( sock, protocol_factory, ssl, server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) if self._debug: # Get the socket from the transport because SSL transport closes # the old socket and creates a new SSL socket @@ -1091,7 +1100,8 @@ class BaseEventLoop(events.AbstractEventLoop): async def _create_connection_transport( self, sock, protocol_factory, ssl, server_hostname, server_side=False, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): sock.setblocking(False) @@ -1102,7 +1112,8 @@ class BaseEventLoop(events.AbstractEventLoop): transport = self._make_ssl_transport( sock, protocol, sslcontext, waiter, server_side=server_side, server_hostname=server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) else: transport = self._make_socket_transport(sock, protocol, waiter) @@ -1193,7 +1204,8 @@ class BaseEventLoop(events.AbstractEventLoop): async def start_tls(self, transport, protocol, sslcontext, *, server_side=False, server_hostname=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): """Upgrade transport to TLS. Return a new transport that *protocol* should start using @@ -1216,6 +1228,7 @@ class BaseEventLoop(events.AbstractEventLoop): self, protocol, sslcontext, waiter, server_side, server_hostname, ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout, call_connection_made=False) # Pause early so that "ssl_protocol.data_received()" doesn't @@ -1414,6 +1427,7 @@ class BaseEventLoop(events.AbstractEventLoop): reuse_address=None, reuse_port=None, ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, start_serving=True): """Create a TCP server. @@ -1437,6 +1451,10 @@ class BaseEventLoop(events.AbstractEventLoop): raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None and ssl is None: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') + if host is not None or port is not None: if sock is not None: raise ValueError( @@ -1509,7 +1527,8 @@ class BaseEventLoop(events.AbstractEventLoop): sock.setblocking(False) server = Server(self, sockets, protocol_factory, - ssl, backlog, ssl_handshake_timeout) + ssl, backlog, ssl_handshake_timeout, + ssl_shutdown_timeout) if start_serving: server._start_serving() # Skip one loop iteration so that all 'loop.add_reader' @@ -1523,7 +1542,8 @@ class BaseEventLoop(events.AbstractEventLoop): async def connect_accepted_socket( self, protocol_factory, sock, *, ssl=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): if sock.type != socket.SOCK_STREAM: raise ValueError(f'A Stream Socket was expected, got {sock!r}') @@ -1531,9 +1551,14 @@ class BaseEventLoop(events.AbstractEventLoop): raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None and not ssl: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') + transport, protocol = await self._create_connection_transport( sock, protocol_factory, ssl, '', server_side=True, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) if self._debug: # Get the socket from the transport because SSL transport closes # the old socket and creates a new SSL socket |