diff options
Diffstat (limited to 'Lib/test/test_external_inspection.py')
-rw-r--r-- | Lib/test/test_external_inspection.py | 569 |
1 files changed, 428 insertions, 141 deletions
diff --git a/Lib/test/test_external_inspection.py b/Lib/test/test_external_inspection.py index ad3f669a030..0f31c225e68 100644 --- a/Lib/test/test_external_inspection.py +++ b/Lib/test/test_external_inspection.py @@ -4,9 +4,10 @@ import textwrap import importlib import sys import socket -from asyncio import staggered, taskgroups +import threading +from asyncio import staggered, taskgroups, base_events, tasks from unittest.mock import ANY -from test.support import os_helper, SHORT_TIMEOUT, busy_retry +from test.support import os_helper, SHORT_TIMEOUT, busy_retry, requires_gil_enabled from test.support.script_helper import make_script from test.support.socket_helper import find_unused_port @@ -16,11 +17,13 @@ PROCESS_VM_READV_SUPPORTED = False try: from _remote_debugging import PROCESS_VM_READV_SUPPORTED - from _remote_debugging import get_stack_trace - from _remote_debugging import get_async_stack_trace - from _remote_debugging import get_all_awaited_by + from _remote_debugging import RemoteUnwinder + from _remote_debugging import FrameInfo, CoroInfo, TaskInfo except ImportError: - raise unittest.SkipTest("Test only runs when _remote_debugging is available") + raise unittest.SkipTest( + "Test only runs when _remote_debugging is available" + ) + def _make_test_script(script_dir, script_basename, source): to_return = make_script(script_dir, script_basename, source) @@ -29,12 +32,32 @@ def _make_test_script(script_dir, script_basename, source): skip_if_not_supported = unittest.skipIf( - (sys.platform != "darwin" and sys.platform != "linux" and sys.platform != "win32"), + ( + sys.platform != "darwin" + and sys.platform != "linux" + and sys.platform != "win32" + ), "Test only runs on Linux, Windows and MacOS", ) +def get_stack_trace(pid): + unwinder = RemoteUnwinder(pid, all_threads=True, debug=True) + return unwinder.get_stack_trace() + + +def get_async_stack_trace(pid): + unwinder = RemoteUnwinder(pid, debug=True) + return unwinder.get_async_stack_trace() + + +def get_all_awaited_by(pid): + unwinder = RemoteUnwinder(pid, debug=True) + return unwinder.get_all_awaited_by() + + class TestGetStackTrace(unittest.TestCase): + maxDiff = None @skip_if_not_supported @unittest.skipIf( @@ -46,7 +69,7 @@ class TestGetStackTrace(unittest.TestCase): port = find_unused_port() script = textwrap.dedent( f"""\ - import time, sys, socket + import time, sys, socket, threading # Connect to the test process sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) @@ -55,13 +78,16 @@ class TestGetStackTrace(unittest.TestCase): for x in range(100): if x == 50: baz() + def baz(): foo() def foo(): - sock.sendall(b"ready"); time.sleep(10_000) # same line number + sock.sendall(b"ready:thread\\n"); time.sleep(10_000) # same line number - bar() + t = threading.Thread(target=bar) + t.start() + sock.sendall(b"ready:main\\n"); t.join() # same line number """ ) stack_trace = None @@ -82,11 +108,17 @@ class TestGetStackTrace(unittest.TestCase): p = subprocess.Popen([sys.executable, script_name]) client_socket, _ = server_socket.accept() server_socket.close() - response = client_socket.recv(1024) - self.assertEqual(response, b"ready") + response = b"" + while ( + b"ready:main" not in response + or b"ready:thread" not in response + ): + response += client_socket.recv(1024) stack_trace = get_stack_trace(p.pid) except PermissionError: - self.skipTest("Insufficient permissions to read the stack trace") + self.skipTest( + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -94,13 +126,23 @@ class TestGetStackTrace(unittest.TestCase): p.terminate() p.wait(timeout=SHORT_TIMEOUT) - expected_stack_trace = [ - ("foo", script_name, 14), - ("baz", script_name, 11), - ("bar", script_name, 9), - ("<module>", script_name, 16), + thread_expected_stack_trace = [ + FrameInfo([script_name, 15, "foo"]), + FrameInfo([script_name, 12, "baz"]), + FrameInfo([script_name, 9, "bar"]), + FrameInfo([threading.__file__, ANY, "Thread.run"]), ] - self.assertEqual(stack_trace, expected_stack_trace) + # Is possible that there are more threads, so we check that the + # expected stack traces are in the result (looking at you Windows!) + self.assertIn((ANY, thread_expected_stack_trace), stack_trace) + + # Check that the main thread stack trace is in the result + frame = FrameInfo([script_name, 19, "<module>"]) + for _, stack in stack_trace: + if frame in stack: + break + else: + self.fail("Main thread stack trace not found in result") @skip_if_not_supported @unittest.skipIf( @@ -160,8 +202,12 @@ class TestGetStackTrace(unittest.TestCase): ): script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server_socket = socket.socket( + socket.AF_INET, socket.SOCK_STREAM + ) + server_socket.setsockopt( + socket.SOL_SOCKET, socket.SO_REUSEADDR, 1 + ) server_socket.bind(("localhost", port)) server_socket.settimeout(SHORT_TIMEOUT) server_socket.listen(1) @@ -179,7 +225,9 @@ class TestGetStackTrace(unittest.TestCase): self.assertEqual(response, b"ready") stack_trace = get_async_stack_trace(p.pid) except PermissionError: - self.skipTest("Insufficient permissions to read the stack trace") + self.skipTest( + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -190,79 +238,49 @@ class TestGetStackTrace(unittest.TestCase): # sets are unordered, so we want to sort "awaited_by"s stack_trace[2].sort(key=lambda x: x[1]) - root_task = "Task-1" expected_stack_trace = [ [ - ("c5", script_name, 10), - ("c4", script_name, 14), - ("c3", script_name, 17), - ("c2", script_name, 20), + FrameInfo([script_name, 10, "c5"]), + FrameInfo([script_name, 14, "c4"]), + FrameInfo([script_name, 17, "c3"]), + FrameInfo([script_name, 20, "c2"]), ], "c2_root", [ - [ - [ - ( - "TaskGroup._aexit", - taskgroups.__file__, - ANY, - ), - ( - "TaskGroup.__aexit__", - taskgroups.__file__, - ANY, - ), - ("main", script_name, 26), - ], - "Task-1", - [], - ], - [ - [("c1", script_name, 23)], - "sub_main_1", + CoroInfo( [ [ - [ - ( - "TaskGroup._aexit", + FrameInfo( + [ taskgroups.__file__, ANY, - ), - ( - "TaskGroup.__aexit__", - taskgroups.__file__, - ANY, - ), - ("main", script_name, 26), - ], - "Task-1", - [], - ] - ], - ], - [ - [("c1", script_name, 23)], - "sub_main_2", - [ - [ - [ - ( "TaskGroup._aexit", + ] + ), + FrameInfo( + [ taskgroups.__file__, ANY, - ), - ( "TaskGroup.__aexit__", - taskgroups.__file__, - ANY, - ), - ("main", script_name, 26), - ], - "Task-1", - [], - ] - ], - ], + ] + ), + FrameInfo([script_name, 26, "main"]), + ], + "Task-1", + ] + ), + CoroInfo( + [ + [FrameInfo([script_name, 23, "c1"])], + "sub_main_1", + ] + ), + CoroInfo( + [ + [FrameInfo([script_name, 23, "c1"])], + "sub_main_2", + ] + ), ], ] self.assertEqual(stack_trace, expected_stack_trace) @@ -321,7 +339,9 @@ class TestGetStackTrace(unittest.TestCase): self.assertEqual(response, b"ready") stack_trace = get_async_stack_trace(p.pid) except PermissionError: - self.skipTest("Insufficient permissions to read the stack trace") + self.skipTest( + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -334,9 +354,9 @@ class TestGetStackTrace(unittest.TestCase): expected_stack_trace = [ [ - ("gen_nested_call", script_name, 10), - ("gen", script_name, 16), - ("main", script_name, 19), + FrameInfo([script_name, 10, "gen_nested_call"]), + FrameInfo([script_name, 16, "gen"]), + FrameInfo([script_name, 19, "main"]), ], "Task-1", [], @@ -398,7 +418,9 @@ class TestGetStackTrace(unittest.TestCase): self.assertEqual(response, b"ready") stack_trace = get_async_stack_trace(p.pid) except PermissionError: - self.skipTest("Insufficient permissions to read the stack trace") + self.skipTest( + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -410,9 +432,12 @@ class TestGetStackTrace(unittest.TestCase): stack_trace[2].sort(key=lambda x: x[1]) expected_stack_trace = [ - [("deep", script_name, 11), ("c1", script_name, 15)], + [ + FrameInfo([script_name, 11, "deep"]), + FrameInfo([script_name, 15, "c1"]), + ], "Task-2", - [[[("main", script_name, 21)], "Task-1", []]], + [CoroInfo([[FrameInfo([script_name, 21, "main"])], "Task-1"])], ] self.assertEqual(stack_trace, expected_stack_trace) @@ -474,7 +499,9 @@ class TestGetStackTrace(unittest.TestCase): self.assertEqual(response, b"ready") stack_trace = get_async_stack_trace(p.pid) except PermissionError: - self.skipTest("Insufficient permissions to read the stack trace") + self.skipTest( + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -486,20 +513,29 @@ class TestGetStackTrace(unittest.TestCase): stack_trace[2].sort(key=lambda x: x[1]) expected_stack_trace = [ [ - ("deep", script_name, 11), - ("c1", script_name, 15), - ("staggered_race.<locals>.run_one_coro", staggered.__file__, ANY), + FrameInfo([script_name, 11, "deep"]), + FrameInfo([script_name, 15, "c1"]), + FrameInfo( + [ + staggered.__file__, + ANY, + "staggered_race.<locals>.run_one_coro", + ] + ), ], "Task-2", [ - [ + CoroInfo( [ - ("staggered_race", staggered.__file__, ANY), - ("main", script_name, 21), - ], - "Task-1", - [], - ] + [ + FrameInfo( + [staggered.__file__, ANY, "staggered_race"] + ), + FrameInfo([script_name, 21, "main"]), + ], + "Task-1", + ] + ) ], ] self.assertEqual(stack_trace, expected_stack_trace) @@ -630,62 +666,174 @@ class TestGetStackTrace(unittest.TestCase): # expected: at least 1000 pending tasks self.assertGreaterEqual(len(entries), 1000) # the first three tasks stem from the code structure - self.assertIn((ANY, "Task-1", []), entries) main_stack = [ - ( - "TaskGroup._aexit", - taskgroups.__file__, - ANY, - ), - ( - "TaskGroup.__aexit__", - taskgroups.__file__, - ANY, + FrameInfo([taskgroups.__file__, ANY, "TaskGroup._aexit"]), + FrameInfo( + [taskgroups.__file__, ANY, "TaskGroup.__aexit__"] ), - ("main", script_name, 60), + FrameInfo([script_name, 60, "main"]), ] self.assertIn( - (ANY, "server task", [[main_stack, ANY]]), + TaskInfo( + [ANY, "Task-1", [CoroInfo([main_stack, ANY])], []] + ), entries, ) self.assertIn( - (ANY, "echo client spam", [[main_stack, ANY]]), + TaskInfo( + [ + ANY, + "server task", + [ + CoroInfo( + [ + [ + FrameInfo( + [ + base_events.__file__, + ANY, + "Server.serve_forever", + ] + ) + ], + ANY, + ] + ) + ], + [ + CoroInfo( + [ + [ + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + FrameInfo( + [script_name, ANY, "main"] + ), + ], + ANY, + ] + ) + ], + ] + ), + entries, + ) + self.assertIn( + TaskInfo( + [ + ANY, + "Task-4", + [ + CoroInfo( + [ + [ + FrameInfo( + [tasks.__file__, ANY, "sleep"] + ), + FrameInfo( + [ + script_name, + 38, + "echo_client", + ] + ), + ], + ANY, + ] + ) + ], + [ + CoroInfo( + [ + [ + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + FrameInfo( + [ + script_name, + 41, + "echo_client_spam", + ] + ), + ], + ANY, + ] + ) + ], + ] + ), entries, ) - expected_stack = [ - [ + expected_awaited_by = [ + CoroInfo( [ - ( - "TaskGroup._aexit", - taskgroups.__file__, - ANY, - ), - ( - "TaskGroup.__aexit__", - taskgroups.__file__, - ANY, - ), - ("echo_client_spam", script_name, 41), - ], - ANY, - ] + [ + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + FrameInfo( + [script_name, 41, "echo_client_spam"] + ), + ], + ANY, + ] + ) ] - tasks_with_stack = [ - task for task in entries if task[2] == expected_stack + tasks_with_awaited = [ + task + for task in entries + if task.awaited_by == expected_awaited_by ] - self.assertGreaterEqual(len(tasks_with_stack), 1000) + self.assertGreaterEqual(len(tasks_with_awaited), 1000) # the final task will have some random number, but it should for # sure be one of the echo client spam horde (In windows this is not true # for some reason) if sys.platform != "win32": self.assertEqual( - expected_stack, - entries[-1][2], + tasks_with_awaited[-1].awaited_by, + entries[-1].awaited_by, ) except PermissionError: - self.skipTest("Insufficient permissions to read the stack trace") + self.skipTest( + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -700,15 +848,154 @@ class TestGetStackTrace(unittest.TestCase): ) def test_self_trace(self): stack_trace = get_stack_trace(os.getpid()) + # Is possible that there are more threads, so we check that the + # expected stack traces are in the result (looking at you Windows!) + this_tread_stack = None + for thread_id, stack in stack_trace: + if thread_id == threading.get_native_id(): + this_tread_stack = stack + break + self.assertIsNotNone(this_tread_stack) self.assertEqual( - stack_trace[0], - ( - "TestGetStackTrace.test_self_trace", - __file__, - self.test_self_trace.__code__.co_firstlineno + 6, - ), + stack[:2], + [ + FrameInfo( + [ + __file__, + get_stack_trace.__code__.co_firstlineno + 2, + "get_stack_trace", + ] + ), + FrameInfo( + [ + __file__, + self.test_self_trace.__code__.co_firstlineno + 6, + "TestGetStackTrace.test_self_trace", + ] + ), + ], + ) + + @skip_if_not_supported + @unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", + ) + @requires_gil_enabled("Free threaded builds don't have an 'active thread'") + def test_only_active_thread(self): + # Test that only_active_thread parameter works correctly + port = find_unused_port() + script = textwrap.dedent( + f"""\ + import time, sys, socket, threading + + # Connect to the test process + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect(('localhost', {port})) + + def worker_thread(name, barrier, ready_event): + barrier.wait() # Synchronize thread start + ready_event.wait() # Wait for main thread signal + # Sleep to keep thread alive + time.sleep(10_000) + + def main_work(): + # Do busy work to hold the GIL + sock.sendall(b"working\\n") + count = 0 + while count < 100000000: + count += 1 + if count % 10000000 == 0: + pass # Keep main thread busy + sock.sendall(b"done\\n") + + # Create synchronization primitives + num_threads = 3 + barrier = threading.Barrier(num_threads + 1) # +1 for main thread + ready_event = threading.Event() + + # Start worker threads + threads = [] + for i in range(num_threads): + t = threading.Thread(target=worker_thread, args=(f"Worker-{{i}}", barrier, ready_event)) + t.start() + threads.append(t) + + # Wait for all threads to be ready + barrier.wait() + + # Signal ready to parent process + sock.sendall(b"ready\\n") + + # Signal threads to start waiting + ready_event.set() + + # Give threads time to start sleeping + time.sleep(0.1) + + # Now do busy work to hold the GIL + main_work() + """ ) + with os_helper.temp_dir() as work_dir: + script_dir = os.path.join(work_dir, "script_pkg") + os.mkdir(script_dir) + + # Create a socket server to communicate with the target process + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server_socket.bind(("localhost", port)) + server_socket.settimeout(SHORT_TIMEOUT) + server_socket.listen(1) + + script_name = _make_test_script(script_dir, "script", script) + client_socket = None + try: + p = subprocess.Popen([sys.executable, script_name]) + client_socket, _ = server_socket.accept() + server_socket.close() + + # Wait for ready signal + response = b"" + while b"ready" not in response: + response += client_socket.recv(1024) + + # Wait for the main thread to start its busy work + while b"working" not in response: + response += client_socket.recv(1024) + + # Get stack trace with all threads + unwinder_all = RemoteUnwinder(p.pid, all_threads=True) + all_traces = unwinder_all.get_stack_trace() + + # Get stack trace with only GIL holder + unwinder_gil = RemoteUnwinder(p.pid, only_active_thread=True) + gil_traces = unwinder_gil.get_stack_trace() + + except PermissionError: + self.skipTest( + "Insufficient permissions to read the stack trace" + ) + finally: + if client_socket is not None: + client_socket.close() + p.kill() + p.terminate() + p.wait(timeout=SHORT_TIMEOUT) + + # Verify we got multiple threads in all_traces + self.assertGreater(len(all_traces), 1, "Should have multiple threads") + + # Verify we got exactly one thread in gil_traces + self.assertEqual(len(gil_traces), 1, "Should have exactly one GIL holder") + + # The GIL holder should be in the all_traces list + gil_thread_id = gil_traces[0][0] + all_thread_ids = [trace[0] for trace in all_traces] + self.assertIn(gil_thread_id, all_thread_ids, + "GIL holder should be among all threads") + if __name__ == "__main__": unittest.main() |