aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/test/test_asyncio/test_streams.py
diff options
context:
space:
mode:
authorOleg Iarygin <oleg@arhadthedev.net>2022-04-15 15:23:14 +0300
committerGitHub <noreply@github.com>2022-04-15 14:23:14 +0200
commit6217864fe5f6855f59d608733ce83fd4466e1b8c (patch)
tree3d852fadd0e29891d382ed9f41f161b237b3e703 /Lib/test/test_asyncio/test_streams.py
parentbd26ef5e9e701d2ab3509a49d9351259a3670772 (diff)
downloadcpython-6217864fe5f6855f59d608733ce83fd4466e1b8c.tar.gz
cpython-6217864fe5f6855f59d608733ce83fd4466e1b8c.zip
gh-79156: Add start_tls() method to streams API (#91453)
The existing event loop `start_tls()` method is not sufficient for connections using the streams API. The existing StreamReader works because the new transport passes received data to the original protocol. The StreamWriter must then write data to the new transport, and the StreamReaderProtocol must be updated to close the new transport correctly. The new StreamWriter `start_tls()` updates itself and the reader protocol to the new SSL transport. Co-authored-by: Ian Good <icgood@gmail.com>
Diffstat (limited to 'Lib/test/test_asyncio/test_streams.py')
-rw-r--r--Lib/test/test_asyncio/test_streams.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py
index 227b2279e17..a7d17894e1c 100644
--- a/Lib/test/test_asyncio/test_streams.py
+++ b/Lib/test/test_asyncio/test_streams.py
@@ -706,6 +706,69 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(messages, [])
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_start_tls(self):
+
+ class MyServer:
+
+ def __init__(self, loop):
+ self.server = None
+ self.loop = loop
+
+ async def handle_client(self, client_reader, client_writer):
+ data1 = await client_reader.readline()
+ client_writer.write(data1)
+ await client_writer.drain()
+ assert client_writer.get_extra_info('sslcontext') is None
+ await client_writer.start_tls(
+ test_utils.simple_server_sslcontext())
+ assert client_writer.get_extra_info('sslcontext') is not None
+ data2 = await client_reader.readline()
+ client_writer.write(data2)
+ await client_writer.drain()
+ client_writer.close()
+ await client_writer.wait_closed()
+
+ def start(self):
+ sock = socket.create_server(('127.0.0.1', 0))
+ self.server = self.loop.run_until_complete(
+ asyncio.start_server(self.handle_client,
+ sock=sock))
+ return sock.getsockname()
+
+ def stop(self):
+ if self.server is not None:
+ self.server.close()
+ self.loop.run_until_complete(self.server.wait_closed())
+ self.server = None
+
+ async def client(addr):
+ reader, writer = await asyncio.open_connection(*addr)
+ writer.write(b"hello world 1!\n")
+ await writer.drain()
+ msgback1 = await reader.readline()
+ assert writer.get_extra_info('sslcontext') is None
+ await writer.start_tls(test_utils.simple_client_sslcontext())
+ assert writer.get_extra_info('sslcontext') is not None
+ writer.write(b"hello world 2!\n")
+ await writer.drain()
+ msgback2 = await reader.readline()
+ writer.close()
+ await writer.wait_closed()
+ return msgback1, msgback2
+
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
+ server = MyServer(self.loop)
+ addr = server.start()
+ msg1, msg2 = self.loop.run_until_complete(client(addr))
+ server.stop()
+
+ self.assertEqual(messages, [])
+ self.assertEqual(msg1, b"hello world 1!\n")
+ self.assertEqual(msg2, b"hello world 2!\n")
+
@unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
def test_read_all_from_pipe_reader(self):
# See asyncio issue 168. This test is derived from the example