diff options
Diffstat (limited to 'Lib/test/test_httpservers.py')
-rw-r--r-- | Lib/test/test_httpservers.py | 254 |
1 files changed, 248 insertions, 6 deletions
diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py index 11c74a02bf2..2548a7c5f29 100644 --- a/Lib/test/test_httpservers.py +++ b/Lib/test/test_httpservers.py @@ -8,6 +8,7 @@ from http.server import BaseHTTPRequestHandler, HTTPServer, HTTPSServer, \ SimpleHTTPRequestHandler from http import server, HTTPStatus +import contextlib import os import socket import sys @@ -20,6 +21,7 @@ import email.utils import html import http, http.client import urllib.parse +import urllib.request import tempfile import time import datetime @@ -32,6 +34,8 @@ from test import support from test.support import ( is_apple, import_helper, os_helper, threading_helper ) +from test.support.script_helper import kill_python, spawn_python +from test.support.socket_helper import find_unused_port try: import ssl @@ -627,13 +631,14 @@ class SimpleHTTPServerTestCase(BaseTestCase): self.check_list_dir_filename(filename) os_helper.unlink(os.path.join(self.tempdir, filename)) - def test_undecodable_parameter(self): - # sanity check using a valid parameter + def test_list_dir_with_query_and_fragment(self): + prefix = f'listing for {self.base_url}/</'.encode('latin1') + response = self.request(self.base_url + '/#123').read() + self.assertIn(prefix + b'title>', response) + self.assertIn(prefix + b'h1>', response) response = self.request(self.base_url + '/?x=123').read() - self.assertRegex(response, rf'listing for {self.base_url}/\?x=123'.encode('latin1')) - # now the bogus encoding - response = self.request(self.base_url + '/?x=%bb').read() - self.assertRegex(response, rf'listing for {self.base_url}/\?x=\xef\xbf\xbd'.encode('latin1')) + self.assertIn(prefix + b'title>', response) + self.assertIn(prefix + b'h1>', response) def test_get_dir_redirect_location_domain_injection_bug(self): """Ensure //evil.co/..%2f../../X does not put //evil.co/ in Location. @@ -1280,6 +1285,243 @@ class ScriptTestCase(unittest.TestCase): self.assertEqual(mock_server.address_family, socket.AF_INET) +class CommandLineTestCase(unittest.TestCase): + default_port = 8000 + default_bind = None + default_protocol = 'HTTP/1.0' + default_handler = SimpleHTTPRequestHandler + default_server = unittest.mock.ANY + tls_cert = certdata_file('ssl_cert.pem') + tls_key = certdata_file('ssl_key.pem') + tls_password = 'somepass' + tls_cert_options = ['--tls-cert'] + tls_key_options = ['--tls-key'] + tls_password_options = ['--tls-password-file'] + args = { + 'HandlerClass': default_handler, + 'ServerClass': default_server, + 'protocol': default_protocol, + 'port': default_port, + 'bind': default_bind, + 'tls_cert': None, + 'tls_key': None, + 'tls_password': None, + } + + def setUp(self): + super().setUp() + self.tls_password_file = tempfile.mktemp() + with open(self.tls_password_file, 'wb') as f: + f.write(self.tls_password.encode()) + self.addCleanup(os_helper.unlink, self.tls_password_file) + + def invoke_httpd(self, *args, stdout=None, stderr=None): + stdout = StringIO() if stdout is None else stdout + stderr = StringIO() if stderr is None else stderr + with contextlib.redirect_stdout(stdout), \ + contextlib.redirect_stderr(stderr): + server._main(args) + return stdout.getvalue(), stderr.getvalue() + + @mock.patch('http.server.test') + def test_port_flag(self, mock_func): + ports = [8000, 65535] + for port in ports: + with self.subTest(port=port): + self.invoke_httpd(str(port)) + call_args = self.args | dict(port=port) + mock_func.assert_called_once_with(**call_args) + mock_func.reset_mock() + + @mock.patch('http.server.test') + def test_directory_flag(self, mock_func): + options = ['-d', '--directory'] + directories = ['.', '/foo', '\\bar', '/', + 'C:\\', 'C:\\foo', 'C:\\bar', + '/home/user', './foo/foo2', 'D:\\foo\\bar'] + for flag in options: + for directory in directories: + with self.subTest(flag=flag, directory=directory): + self.invoke_httpd(flag, directory) + mock_func.assert_called_once_with(**self.args) + mock_func.reset_mock() + + @mock.patch('http.server.test') + def test_bind_flag(self, mock_func): + options = ['-b', '--bind'] + bind_addresses = ['localhost', '127.0.0.1', '::1', + '0.0.0.0', '8.8.8.8'] + for flag in options: + for bind_address in bind_addresses: + with self.subTest(flag=flag, bind_address=bind_address): + self.invoke_httpd(flag, bind_address) + call_args = self.args | dict(bind=bind_address) + mock_func.assert_called_once_with(**call_args) + mock_func.reset_mock() + + @mock.patch('http.server.test') + def test_protocol_flag(self, mock_func): + options = ['-p', '--protocol'] + protocols = ['HTTP/1.0', 'HTTP/1.1', 'HTTP/2.0', 'HTTP/3.0'] + for flag in options: + for protocol in protocols: + with self.subTest(flag=flag, protocol=protocol): + self.invoke_httpd(flag, protocol) + call_args = self.args | dict(protocol=protocol) + mock_func.assert_called_once_with(**call_args) + mock_func.reset_mock() + + @unittest.skipIf(ssl is None, "requires ssl") + @mock.patch('http.server.test') + def test_tls_cert_and_key_flags(self, mock_func): + for tls_cert_option in self.tls_cert_options: + for tls_key_option in self.tls_key_options: + self.invoke_httpd(tls_cert_option, self.tls_cert, + tls_key_option, self.tls_key) + call_args = self.args | { + 'tls_cert': self.tls_cert, + 'tls_key': self.tls_key, + } + mock_func.assert_called_once_with(**call_args) + mock_func.reset_mock() + + @unittest.skipIf(ssl is None, "requires ssl") + @mock.patch('http.server.test') + def test_tls_cert_and_key_and_password_flags(self, mock_func): + for tls_cert_option in self.tls_cert_options: + for tls_key_option in self.tls_key_options: + for tls_password_option in self.tls_password_options: + self.invoke_httpd(tls_cert_option, + self.tls_cert, + tls_key_option, + self.tls_key, + tls_password_option, + self.tls_password_file) + call_args = self.args | { + 'tls_cert': self.tls_cert, + 'tls_key': self.tls_key, + 'tls_password': self.tls_password, + } + mock_func.assert_called_once_with(**call_args) + mock_func.reset_mock() + + @unittest.skipIf(ssl is None, "requires ssl") + @mock.patch('http.server.test') + def test_missing_tls_cert_flag(self, mock_func): + for tls_key_option in self.tls_key_options: + with self.assertRaises(SystemExit): + self.invoke_httpd(tls_key_option, self.tls_key) + mock_func.reset_mock() + + for tls_password_option in self.tls_password_options: + with self.assertRaises(SystemExit): + self.invoke_httpd(tls_password_option, self.tls_password) + mock_func.reset_mock() + + @unittest.skipIf(ssl is None, "requires ssl") + @mock.patch('http.server.test') + def test_invalid_password_file(self, mock_func): + non_existent_file = 'non_existent_file' + for tls_password_option in self.tls_password_options: + for tls_cert_option in self.tls_cert_options: + with self.assertRaises(SystemExit): + self.invoke_httpd(tls_cert_option, + self.tls_cert, + tls_password_option, + non_existent_file) + + @mock.patch('http.server.test') + def test_no_arguments(self, mock_func): + self.invoke_httpd() + mock_func.assert_called_once_with(**self.args) + mock_func.reset_mock() + + @mock.patch('http.server.test') + def test_help_flag(self, _): + options = ['-h', '--help'] + for option in options: + stdout, stderr = StringIO(), StringIO() + with self.assertRaises(SystemExit): + self.invoke_httpd(option, stdout=stdout, stderr=stderr) + self.assertIn('usage', stdout.getvalue()) + self.assertEqual(stderr.getvalue(), '') + + @mock.patch('http.server.test') + def test_unknown_flag(self, _): + stdout, stderr = StringIO(), StringIO() + with self.assertRaises(SystemExit): + self.invoke_httpd('--unknown-flag', stdout=stdout, stderr=stderr) + self.assertEqual(stdout.getvalue(), '') + self.assertIn('error', stderr.getvalue()) + + +class CommandLineRunTimeTestCase(unittest.TestCase): + served_data = os.urandom(32) + served_filename = 'served_filename' + tls_cert = certdata_file('ssl_cert.pem') + tls_key = certdata_file('ssl_key.pem') + tls_password = b'somepass' + tls_password_file = 'ssl_key_password' + + def setUp(self): + super().setUp() + server_dir_context = os_helper.temp_cwd() + server_dir = self.enterContext(server_dir_context) + with open(self.served_filename, 'wb') as f: + f.write(self.served_data) + with open(self.tls_password_file, 'wb') as f: + f.write(self.tls_password) + + def fetch_file(self, path, context=None): + req = urllib.request.Request(path, method='GET') + with urllib.request.urlopen(req, context=context) as res: + return res.read() + + def parse_cli_output(self, output): + match = re.search(r'Serving (HTTP|HTTPS) on (.+) port (\d+)', output) + if match is None: + return None, None, None + return match.group(1).lower(), match.group(2), int(match.group(3)) + + def wait_for_server(self, proc, protocol, bind, port): + """Check that the server has been successfully started.""" + line = proc.stdout.readline().strip() + if support.verbose: + print() + print('python -m http.server: ', line) + return self.parse_cli_output(line) == (protocol, bind, port) + + def test_http_client(self): + bind, port = '127.0.0.1', find_unused_port() + proc = spawn_python('-u', '-m', 'http.server', str(port), '-b', bind, + bufsize=1, text=True) + self.addCleanup(kill_python, proc) + self.addCleanup(proc.terminate) + self.assertTrue(self.wait_for_server(proc, 'http', bind, port)) + res = self.fetch_file(f'http://{bind}:{port}/{self.served_filename}') + self.assertEqual(res, self.served_data) + + @unittest.skipIf(ssl is None, "requires ssl") + def test_https_client(self): + context = ssl.create_default_context() + # allow self-signed certificates + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + bind, port = '127.0.0.1', find_unused_port() + proc = spawn_python('-u', '-m', 'http.server', str(port), '-b', bind, + '--tls-cert', self.tls_cert, + '--tls-key', self.tls_key, + '--tls-password-file', self.tls_password_file, + bufsize=1, text=True) + self.addCleanup(kill_python, proc) + self.addCleanup(proc.terminate) + self.assertTrue(self.wait_for_server(proc, 'https', bind, port)) + url = f'https://{bind}:{port}/{self.served_filename}' + res = self.fetch_file(url, context=context) + self.assertEqual(res, self.served_data) + + def setUpModule(): unittest.addModuleCleanup(os.chdir, os.getcwd()) |