aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/test/test_httpservers.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_httpservers.py')
-rw-r--r--Lib/test/test_httpservers.py254
1 files changed, 248 insertions, 6 deletions
diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py
index 11c74a02bf2..2548a7c5f29 100644
--- a/Lib/test/test_httpservers.py
+++ b/Lib/test/test_httpservers.py
@@ -8,6 +8,7 @@ from http.server import BaseHTTPRequestHandler, HTTPServer, HTTPSServer, \
SimpleHTTPRequestHandler
from http import server, HTTPStatus
+import contextlib
import os
import socket
import sys
@@ -20,6 +21,7 @@ import email.utils
import html
import http, http.client
import urllib.parse
+import urllib.request
import tempfile
import time
import datetime
@@ -32,6 +34,8 @@ from test import support
from test.support import (
is_apple, import_helper, os_helper, threading_helper
)
+from test.support.script_helper import kill_python, spawn_python
+from test.support.socket_helper import find_unused_port
try:
import ssl
@@ -627,13 +631,14 @@ class SimpleHTTPServerTestCase(BaseTestCase):
self.check_list_dir_filename(filename)
os_helper.unlink(os.path.join(self.tempdir, filename))
- def test_undecodable_parameter(self):
- # sanity check using a valid parameter
+ def test_list_dir_with_query_and_fragment(self):
+ prefix = f'listing for {self.base_url}/</'.encode('latin1')
+ response = self.request(self.base_url + '/#123').read()
+ self.assertIn(prefix + b'title>', response)
+ self.assertIn(prefix + b'h1>', response)
response = self.request(self.base_url + '/?x=123').read()
- self.assertRegex(response, rf'listing for {self.base_url}/\?x=123'.encode('latin1'))
- # now the bogus encoding
- response = self.request(self.base_url + '/?x=%bb').read()
- self.assertRegex(response, rf'listing for {self.base_url}/\?x=\xef\xbf\xbd'.encode('latin1'))
+ self.assertIn(prefix + b'title>', response)
+ self.assertIn(prefix + b'h1>', response)
def test_get_dir_redirect_location_domain_injection_bug(self):
"""Ensure //evil.co/..%2f../../X does not put //evil.co/ in Location.
@@ -1280,6 +1285,243 @@ class ScriptTestCase(unittest.TestCase):
self.assertEqual(mock_server.address_family, socket.AF_INET)
+class CommandLineTestCase(unittest.TestCase):
+ default_port = 8000
+ default_bind = None
+ default_protocol = 'HTTP/1.0'
+ default_handler = SimpleHTTPRequestHandler
+ default_server = unittest.mock.ANY
+ tls_cert = certdata_file('ssl_cert.pem')
+ tls_key = certdata_file('ssl_key.pem')
+ tls_password = 'somepass'
+ tls_cert_options = ['--tls-cert']
+ tls_key_options = ['--tls-key']
+ tls_password_options = ['--tls-password-file']
+ args = {
+ 'HandlerClass': default_handler,
+ 'ServerClass': default_server,
+ 'protocol': default_protocol,
+ 'port': default_port,
+ 'bind': default_bind,
+ 'tls_cert': None,
+ 'tls_key': None,
+ 'tls_password': None,
+ }
+
+ def setUp(self):
+ super().setUp()
+ self.tls_password_file = tempfile.mktemp()
+ with open(self.tls_password_file, 'wb') as f:
+ f.write(self.tls_password.encode())
+ self.addCleanup(os_helper.unlink, self.tls_password_file)
+
+ def invoke_httpd(self, *args, stdout=None, stderr=None):
+ stdout = StringIO() if stdout is None else stdout
+ stderr = StringIO() if stderr is None else stderr
+ with contextlib.redirect_stdout(stdout), \
+ contextlib.redirect_stderr(stderr):
+ server._main(args)
+ return stdout.getvalue(), stderr.getvalue()
+
+ @mock.patch('http.server.test')
+ def test_port_flag(self, mock_func):
+ ports = [8000, 65535]
+ for port in ports:
+ with self.subTest(port=port):
+ self.invoke_httpd(str(port))
+ call_args = self.args | dict(port=port)
+ mock_func.assert_called_once_with(**call_args)
+ mock_func.reset_mock()
+
+ @mock.patch('http.server.test')
+ def test_directory_flag(self, mock_func):
+ options = ['-d', '--directory']
+ directories = ['.', '/foo', '\\bar', '/',
+ 'C:\\', 'C:\\foo', 'C:\\bar',
+ '/home/user', './foo/foo2', 'D:\\foo\\bar']
+ for flag in options:
+ for directory in directories:
+ with self.subTest(flag=flag, directory=directory):
+ self.invoke_httpd(flag, directory)
+ mock_func.assert_called_once_with(**self.args)
+ mock_func.reset_mock()
+
+ @mock.patch('http.server.test')
+ def test_bind_flag(self, mock_func):
+ options = ['-b', '--bind']
+ bind_addresses = ['localhost', '127.0.0.1', '::1',
+ '0.0.0.0', '8.8.8.8']
+ for flag in options:
+ for bind_address in bind_addresses:
+ with self.subTest(flag=flag, bind_address=bind_address):
+ self.invoke_httpd(flag, bind_address)
+ call_args = self.args | dict(bind=bind_address)
+ mock_func.assert_called_once_with(**call_args)
+ mock_func.reset_mock()
+
+ @mock.patch('http.server.test')
+ def test_protocol_flag(self, mock_func):
+ options = ['-p', '--protocol']
+ protocols = ['HTTP/1.0', 'HTTP/1.1', 'HTTP/2.0', 'HTTP/3.0']
+ for flag in options:
+ for protocol in protocols:
+ with self.subTest(flag=flag, protocol=protocol):
+ self.invoke_httpd(flag, protocol)
+ call_args = self.args | dict(protocol=protocol)
+ mock_func.assert_called_once_with(**call_args)
+ mock_func.reset_mock()
+
+ @unittest.skipIf(ssl is None, "requires ssl")
+ @mock.patch('http.server.test')
+ def test_tls_cert_and_key_flags(self, mock_func):
+ for tls_cert_option in self.tls_cert_options:
+ for tls_key_option in self.tls_key_options:
+ self.invoke_httpd(tls_cert_option, self.tls_cert,
+ tls_key_option, self.tls_key)
+ call_args = self.args | {
+ 'tls_cert': self.tls_cert,
+ 'tls_key': self.tls_key,
+ }
+ mock_func.assert_called_once_with(**call_args)
+ mock_func.reset_mock()
+
+ @unittest.skipIf(ssl is None, "requires ssl")
+ @mock.patch('http.server.test')
+ def test_tls_cert_and_key_and_password_flags(self, mock_func):
+ for tls_cert_option in self.tls_cert_options:
+ for tls_key_option in self.tls_key_options:
+ for tls_password_option in self.tls_password_options:
+ self.invoke_httpd(tls_cert_option,
+ self.tls_cert,
+ tls_key_option,
+ self.tls_key,
+ tls_password_option,
+ self.tls_password_file)
+ call_args = self.args | {
+ 'tls_cert': self.tls_cert,
+ 'tls_key': self.tls_key,
+ 'tls_password': self.tls_password,
+ }
+ mock_func.assert_called_once_with(**call_args)
+ mock_func.reset_mock()
+
+ @unittest.skipIf(ssl is None, "requires ssl")
+ @mock.patch('http.server.test')
+ def test_missing_tls_cert_flag(self, mock_func):
+ for tls_key_option in self.tls_key_options:
+ with self.assertRaises(SystemExit):
+ self.invoke_httpd(tls_key_option, self.tls_key)
+ mock_func.reset_mock()
+
+ for tls_password_option in self.tls_password_options:
+ with self.assertRaises(SystemExit):
+ self.invoke_httpd(tls_password_option, self.tls_password)
+ mock_func.reset_mock()
+
+ @unittest.skipIf(ssl is None, "requires ssl")
+ @mock.patch('http.server.test')
+ def test_invalid_password_file(self, mock_func):
+ non_existent_file = 'non_existent_file'
+ for tls_password_option in self.tls_password_options:
+ for tls_cert_option in self.tls_cert_options:
+ with self.assertRaises(SystemExit):
+ self.invoke_httpd(tls_cert_option,
+ self.tls_cert,
+ tls_password_option,
+ non_existent_file)
+
+ @mock.patch('http.server.test')
+ def test_no_arguments(self, mock_func):
+ self.invoke_httpd()
+ mock_func.assert_called_once_with(**self.args)
+ mock_func.reset_mock()
+
+ @mock.patch('http.server.test')
+ def test_help_flag(self, _):
+ options = ['-h', '--help']
+ for option in options:
+ stdout, stderr = StringIO(), StringIO()
+ with self.assertRaises(SystemExit):
+ self.invoke_httpd(option, stdout=stdout, stderr=stderr)
+ self.assertIn('usage', stdout.getvalue())
+ self.assertEqual(stderr.getvalue(), '')
+
+ @mock.patch('http.server.test')
+ def test_unknown_flag(self, _):
+ stdout, stderr = StringIO(), StringIO()
+ with self.assertRaises(SystemExit):
+ self.invoke_httpd('--unknown-flag', stdout=stdout, stderr=stderr)
+ self.assertEqual(stdout.getvalue(), '')
+ self.assertIn('error', stderr.getvalue())
+
+
+class CommandLineRunTimeTestCase(unittest.TestCase):
+ served_data = os.urandom(32)
+ served_filename = 'served_filename'
+ tls_cert = certdata_file('ssl_cert.pem')
+ tls_key = certdata_file('ssl_key.pem')
+ tls_password = b'somepass'
+ tls_password_file = 'ssl_key_password'
+
+ def setUp(self):
+ super().setUp()
+ server_dir_context = os_helper.temp_cwd()
+ server_dir = self.enterContext(server_dir_context)
+ with open(self.served_filename, 'wb') as f:
+ f.write(self.served_data)
+ with open(self.tls_password_file, 'wb') as f:
+ f.write(self.tls_password)
+
+ def fetch_file(self, path, context=None):
+ req = urllib.request.Request(path, method='GET')
+ with urllib.request.urlopen(req, context=context) as res:
+ return res.read()
+
+ def parse_cli_output(self, output):
+ match = re.search(r'Serving (HTTP|HTTPS) on (.+) port (\d+)', output)
+ if match is None:
+ return None, None, None
+ return match.group(1).lower(), match.group(2), int(match.group(3))
+
+ def wait_for_server(self, proc, protocol, bind, port):
+ """Check that the server has been successfully started."""
+ line = proc.stdout.readline().strip()
+ if support.verbose:
+ print()
+ print('python -m http.server: ', line)
+ return self.parse_cli_output(line) == (protocol, bind, port)
+
+ def test_http_client(self):
+ bind, port = '127.0.0.1', find_unused_port()
+ proc = spawn_python('-u', '-m', 'http.server', str(port), '-b', bind,
+ bufsize=1, text=True)
+ self.addCleanup(kill_python, proc)
+ self.addCleanup(proc.terminate)
+ self.assertTrue(self.wait_for_server(proc, 'http', bind, port))
+ res = self.fetch_file(f'http://{bind}:{port}/{self.served_filename}')
+ self.assertEqual(res, self.served_data)
+
+ @unittest.skipIf(ssl is None, "requires ssl")
+ def test_https_client(self):
+ context = ssl.create_default_context()
+ # allow self-signed certificates
+ context.check_hostname = False
+ context.verify_mode = ssl.CERT_NONE
+
+ bind, port = '127.0.0.1', find_unused_port()
+ proc = spawn_python('-u', '-m', 'http.server', str(port), '-b', bind,
+ '--tls-cert', self.tls_cert,
+ '--tls-key', self.tls_key,
+ '--tls-password-file', self.tls_password_file,
+ bufsize=1, text=True)
+ self.addCleanup(kill_python, proc)
+ self.addCleanup(proc.terminate)
+ self.assertTrue(self.wait_for_server(proc, 'https', bind, port))
+ url = f'https://{bind}:{port}/{self.served_filename}'
+ res = self.fetch_file(url, context=context)
+ self.assertEqual(res, self.served_data)
+
+
def setUpModule():
unittest.addModuleCleanup(os.chdir, os.getcwd())