diff options
Diffstat (limited to 'Lib/asyncio/base_events.py')
-rw-r--r-- | Lib/asyncio/base_events.py | 43 |
1 files changed, 34 insertions, 9 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 56ea7ba44e2..703c8a4ce24 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -269,7 +269,7 @@ class _SendfileFallbackProtocol(protocols.Protocol): class Server(events.AbstractServer): def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, - ssl_handshake_timeout): + ssl_handshake_timeout, ssl_shutdown_timeout=None): self._loop = loop self._sockets = sockets self._active_count = 0 @@ -278,6 +278,7 @@ class Server(events.AbstractServer): self._backlog = backlog self._ssl_context = ssl_context self._ssl_handshake_timeout = ssl_handshake_timeout + self._ssl_shutdown_timeout = ssl_shutdown_timeout self._serving = False self._serving_forever_fut = None @@ -309,7 +310,8 @@ class Server(events.AbstractServer): sock.listen(self._backlog) self._loop._start_serving( self._protocol_factory, sock, self._ssl_context, - self, self._backlog, self._ssl_handshake_timeout) + self, self._backlog, self._ssl_handshake_timeout, + self._ssl_shutdown_timeout) def get_loop(self): return self._loop @@ -463,6 +465,7 @@ class BaseEventLoop(events.AbstractEventLoop): *, server_side=False, server_hostname=None, extra=None, server=None, ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, call_connection_made=True): """Create SSL transport.""" raise NotImplementedError @@ -965,6 +968,7 @@ class BaseEventLoop(events.AbstractEventLoop): proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, happy_eyeballs_delay=None, interleave=None): """Connect to a TCP server. @@ -1000,6 +1004,10 @@ class BaseEventLoop(events.AbstractEventLoop): raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None and not ssl: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') + if happy_eyeballs_delay is not None and interleave is None: # If using happy eyeballs, default to interleave addresses by family interleave = 1 @@ -1075,7 +1083,8 @@ class BaseEventLoop(events.AbstractEventLoop): transport, protocol = await self._create_connection_transport( sock, protocol_factory, ssl, server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) if self._debug: # Get the socket from the transport because SSL transport closes # the old socket and creates a new SSL socket @@ -1087,7 +1096,8 @@ class BaseEventLoop(events.AbstractEventLoop): async def _create_connection_transport( self, sock, protocol_factory, ssl, server_hostname, server_side=False, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): sock.setblocking(False) @@ -1098,7 +1108,8 @@ class BaseEventLoop(events.AbstractEventLoop): transport = self._make_ssl_transport( sock, protocol, sslcontext, waiter, server_side=server_side, server_hostname=server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) else: transport = self._make_socket_transport(sock, protocol, waiter) @@ -1189,7 +1200,8 @@ class BaseEventLoop(events.AbstractEventLoop): async def start_tls(self, transport, protocol, sslcontext, *, server_side=False, server_hostname=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): """Upgrade transport to TLS. Return a new transport that *protocol* should start using @@ -1212,6 +1224,7 @@ class BaseEventLoop(events.AbstractEventLoop): self, protocol, sslcontext, waiter, server_side, server_hostname, ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout, call_connection_made=False) # Pause early so that "ssl_protocol.data_received()" doesn't @@ -1397,6 +1410,7 @@ class BaseEventLoop(events.AbstractEventLoop): reuse_address=None, reuse_port=None, ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, start_serving=True): """Create a TCP server. @@ -1420,6 +1434,10 @@ class BaseEventLoop(events.AbstractEventLoop): raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None and ssl is None: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') + if host is not None or port is not None: if sock is not None: raise ValueError( @@ -1492,7 +1510,8 @@ class BaseEventLoop(events.AbstractEventLoop): sock.setblocking(False) server = Server(self, sockets, protocol_factory, - ssl, backlog, ssl_handshake_timeout) + ssl, backlog, ssl_handshake_timeout, + ssl_shutdown_timeout) if start_serving: server._start_serving() # Skip one loop iteration so that all 'loop.add_reader' @@ -1506,7 +1525,8 @@ class BaseEventLoop(events.AbstractEventLoop): async def connect_accepted_socket( self, protocol_factory, sock, *, ssl=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): if sock.type != socket.SOCK_STREAM: raise ValueError(f'A Stream Socket was expected, got {sock!r}') @@ -1514,9 +1534,14 @@ class BaseEventLoop(events.AbstractEventLoop): raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None and not ssl: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') + transport, protocol = await self._create_connection_transport( sock, protocol_factory, ssl, '', server_side=True, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) if self._debug: # Get the socket from the transport because SSL transport closes # the old socket and creates a new SSL socket |