aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/asyncio/base_events.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/asyncio/base_events.py')
-rw-r--r--Lib/asyncio/base_events.py43
1 files changed, 34 insertions, 9 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index 56ea7ba44e2..703c8a4ce24 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -269,7 +269,7 @@ class _SendfileFallbackProtocol(protocols.Protocol):
class Server(events.AbstractServer):
def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
- ssl_handshake_timeout):
+ ssl_handshake_timeout, ssl_shutdown_timeout=None):
self._loop = loop
self._sockets = sockets
self._active_count = 0
@@ -278,6 +278,7 @@ class Server(events.AbstractServer):
self._backlog = backlog
self._ssl_context = ssl_context
self._ssl_handshake_timeout = ssl_handshake_timeout
+ self._ssl_shutdown_timeout = ssl_shutdown_timeout
self._serving = False
self._serving_forever_fut = None
@@ -309,7 +310,8 @@ class Server(events.AbstractServer):
sock.listen(self._backlog)
self._loop._start_serving(
self._protocol_factory, sock, self._ssl_context,
- self, self._backlog, self._ssl_handshake_timeout)
+ self, self._backlog, self._ssl_handshake_timeout,
+ self._ssl_shutdown_timeout)
def get_loop(self):
return self._loop
@@ -463,6 +465,7 @@ class BaseEventLoop(events.AbstractEventLoop):
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None,
call_connection_made=True):
"""Create SSL transport."""
raise NotImplementedError
@@ -965,6 +968,7 @@ class BaseEventLoop(events.AbstractEventLoop):
proto=0, flags=0, sock=None,
local_addr=None, server_hostname=None,
ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None,
happy_eyeballs_delay=None, interleave=None):
"""Connect to a TCP server.
@@ -1000,6 +1004,10 @@ class BaseEventLoop(events.AbstractEventLoop):
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
+ if ssl_shutdown_timeout is not None and not ssl:
+ raise ValueError(
+ 'ssl_shutdown_timeout is only meaningful with ssl')
+
if happy_eyeballs_delay is not None and interleave is None:
# If using happy eyeballs, default to interleave addresses by family
interleave = 1
@@ -1075,7 +1083,8 @@ class BaseEventLoop(events.AbstractEventLoop):
transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname,
- ssl_handshake_timeout=ssl_handshake_timeout)
+ ssl_handshake_timeout=ssl_handshake_timeout,
+ ssl_shutdown_timeout=ssl_shutdown_timeout)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket
@@ -1087,7 +1096,8 @@ class BaseEventLoop(events.AbstractEventLoop):
async def _create_connection_transport(
self, sock, protocol_factory, ssl,
server_hostname, server_side=False,
- ssl_handshake_timeout=None):
+ ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None):
sock.setblocking(False)
@@ -1098,7 +1108,8 @@ class BaseEventLoop(events.AbstractEventLoop):
transport = self._make_ssl_transport(
sock, protocol, sslcontext, waiter,
server_side=server_side, server_hostname=server_hostname,
- ssl_handshake_timeout=ssl_handshake_timeout)
+ ssl_handshake_timeout=ssl_handshake_timeout,
+ ssl_shutdown_timeout=ssl_shutdown_timeout)
else:
transport = self._make_socket_transport(sock, protocol, waiter)
@@ -1189,7 +1200,8 @@ class BaseEventLoop(events.AbstractEventLoop):
async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False,
server_hostname=None,
- ssl_handshake_timeout=None):
+ ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None):
"""Upgrade transport to TLS.
Return a new transport that *protocol* should start using
@@ -1212,6 +1224,7 @@ class BaseEventLoop(events.AbstractEventLoop):
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
+ ssl_shutdown_timeout=ssl_shutdown_timeout,
call_connection_made=False)
# Pause early so that "ssl_protocol.data_received()" doesn't
@@ -1397,6 +1410,7 @@ class BaseEventLoop(events.AbstractEventLoop):
reuse_address=None,
reuse_port=None,
ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None,
start_serving=True):
"""Create a TCP server.
@@ -1420,6 +1434,10 @@ class BaseEventLoop(events.AbstractEventLoop):
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
+ if ssl_shutdown_timeout is not None and ssl is None:
+ raise ValueError(
+ 'ssl_shutdown_timeout is only meaningful with ssl')
+
if host is not None or port is not None:
if sock is not None:
raise ValueError(
@@ -1492,7 +1510,8 @@ class BaseEventLoop(events.AbstractEventLoop):
sock.setblocking(False)
server = Server(self, sockets, protocol_factory,
- ssl, backlog, ssl_handshake_timeout)
+ ssl, backlog, ssl_handshake_timeout,
+ ssl_shutdown_timeout)
if start_serving:
server._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
@@ -1506,7 +1525,8 @@ class BaseEventLoop(events.AbstractEventLoop):
async def connect_accepted_socket(
self, protocol_factory, sock,
*, ssl=None,
- ssl_handshake_timeout=None):
+ ssl_handshake_timeout=None,
+ ssl_shutdown_timeout=None):
if sock.type != socket.SOCK_STREAM:
raise ValueError(f'A Stream Socket was expected, got {sock!r}')
@@ -1514,9 +1534,14 @@ class BaseEventLoop(events.AbstractEventLoop):
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
+ if ssl_shutdown_timeout is not None and not ssl:
+ raise ValueError(
+ 'ssl_shutdown_timeout is only meaningful with ssl')
+
transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, '', server_side=True,
- ssl_handshake_timeout=ssl_handshake_timeout)
+ ssl_handshake_timeout=ssl_handshake_timeout,
+ ssl_shutdown_timeout=ssl_shutdown_timeout)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket