aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/test/test_asyncio/test_streams.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_asyncio/test_streams.py')
-rw-r--r--Lib/test/test_asyncio/test_streams.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py
index 227b2279e17..a7d17894e1c 100644
--- a/Lib/test/test_asyncio/test_streams.py
+++ b/Lib/test/test_asyncio/test_streams.py
@@ -706,6 +706,69 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(messages, [])
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_start_tls(self):
+
+ class MyServer:
+
+ def __init__(self, loop):
+ self.server = None
+ self.loop = loop
+
+ async def handle_client(self, client_reader, client_writer):
+ data1 = await client_reader.readline()
+ client_writer.write(data1)
+ await client_writer.drain()
+ assert client_writer.get_extra_info('sslcontext') is None
+ await client_writer.start_tls(
+ test_utils.simple_server_sslcontext())
+ assert client_writer.get_extra_info('sslcontext') is not None
+ data2 = await client_reader.readline()
+ client_writer.write(data2)
+ await client_writer.drain()
+ client_writer.close()
+ await client_writer.wait_closed()
+
+ def start(self):
+ sock = socket.create_server(('127.0.0.1', 0))
+ self.server = self.loop.run_until_complete(
+ asyncio.start_server(self.handle_client,
+ sock=sock))
+ return sock.getsockname()
+
+ def stop(self):
+ if self.server is not None:
+ self.server.close()
+ self.loop.run_until_complete(self.server.wait_closed())
+ self.server = None
+
+ async def client(addr):
+ reader, writer = await asyncio.open_connection(*addr)
+ writer.write(b"hello world 1!\n")
+ await writer.drain()
+ msgback1 = await reader.readline()
+ assert writer.get_extra_info('sslcontext') is None
+ await writer.start_tls(test_utils.simple_client_sslcontext())
+ assert writer.get_extra_info('sslcontext') is not None
+ writer.write(b"hello world 2!\n")
+ await writer.drain()
+ msgback2 = await reader.readline()
+ writer.close()
+ await writer.wait_closed()
+ return msgback1, msgback2
+
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
+ server = MyServer(self.loop)
+ addr = server.start()
+ msg1, msg2 = self.loop.run_until_complete(client(addr))
+ server.stop()
+
+ self.assertEqual(messages, [])
+ self.assertEqual(msg1, b"hello world 1!\n")
+ self.assertEqual(msg2, b"hello world 2!\n")
+
@unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
def test_read_all_from_pipe_reader(self):
# See asyncio issue 168. This test is derived from the example