diff options
Diffstat (limited to 'Lib/test/test_external_inspection.py')
-rw-r--r-- | Lib/test/test_external_inspection.py | 980 |
1 files changed, 853 insertions, 127 deletions
diff --git a/Lib/test/test_external_inspection.py b/Lib/test/test_external_inspection.py index aa05db972f0..354a82a800f 100644 --- a/Lib/test/test_external_inspection.py +++ b/Lib/test/test_external_inspection.py @@ -4,7 +4,16 @@ import textwrap import importlib import sys import socket -from test.support import os_helper, SHORT_TIMEOUT, busy_retry +import threading +import time +from asyncio import staggered, taskgroups, base_events, tasks +from unittest.mock import ANY +from test.support import ( + os_helper, + SHORT_TIMEOUT, + busy_retry, + requires_gil_enabled, +) from test.support.script_helper import make_script from test.support.socket_helper import find_unused_port @@ -13,33 +22,60 @@ import subprocess PROCESS_VM_READV_SUPPORTED = False try: - from _testexternalinspection import PROCESS_VM_READV_SUPPORTED - from _testexternalinspection import get_stack_trace - from _testexternalinspection import get_async_stack_trace - from _testexternalinspection import get_all_awaited_by + from _remote_debugging import PROCESS_VM_READV_SUPPORTED + from _remote_debugging import RemoteUnwinder + from _remote_debugging import FrameInfo, CoroInfo, TaskInfo except ImportError: raise unittest.SkipTest( - "Test only runs when _testexternalinspection is available") + "Test only runs when _remote_debugging is available" + ) + def _make_test_script(script_dir, script_basename, source): to_return = make_script(script_dir, script_basename, source) importlib.invalidate_caches() return to_return -skip_if_not_supported = unittest.skipIf((sys.platform != "darwin" - and sys.platform != "linux" - and sys.platform != "win32"), - "Test only runs on Linux, Windows and MacOS") + +skip_if_not_supported = unittest.skipIf( + ( + sys.platform != "darwin" + and sys.platform != "linux" + and sys.platform != "win32" + ), + "Test only runs on Linux, Windows and MacOS", +) + + +def get_stack_trace(pid): + unwinder = RemoteUnwinder(pid, all_threads=True, debug=True) + return unwinder.get_stack_trace() + + +def get_async_stack_trace(pid): + unwinder = RemoteUnwinder(pid, debug=True) + return unwinder.get_async_stack_trace() + + +def get_all_awaited_by(pid): + unwinder = RemoteUnwinder(pid, debug=True) + return unwinder.get_all_awaited_by() + + class TestGetStackTrace(unittest.TestCase): + maxDiff = None @skip_if_not_supported - @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, - "Test only runs on Linux with process_vm_readv support") + @unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", + ) def test_remote_stack_trace(self): # Spawn a process with some realistic Python code port = find_unused_port() - script = textwrap.dedent(f"""\ - import time, sys, socket + script = textwrap.dedent( + f"""\ + import time, sys, socket, threading # Connect to the test process sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) @@ -48,15 +84,18 @@ class TestGetStackTrace(unittest.TestCase): for x in range(100): if x == 50: baz() + def baz(): foo() def foo(): - sock.sendall(b"ready") - time.sleep(1000) + sock.sendall(b"ready:thread\\n"); time.sleep(10_000) # same line number - bar() - """) + t = threading.Thread(target=bar) + t.start() + sock.sendall(b"ready:main\\n"); t.join() # same line number + """ + ) stack_trace = None with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") @@ -65,21 +104,27 @@ class TestGetStackTrace(unittest.TestCase): # Create a socket server to communicate with the target process server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(('localhost', port)) + server_socket.bind(("localhost", port)) server_socket.settimeout(SHORT_TIMEOUT) server_socket.listen(1) - script_name = _make_test_script(script_dir, 'script', script) + script_name = _make_test_script(script_dir, "script", script) client_socket = None try: p = subprocess.Popen([sys.executable, script_name]) client_socket, _ = server_socket.accept() server_socket.close() - response = client_socket.recv(1024) - self.assertEqual(response, b"ready") + response = b"" + while ( + b"ready:main" not in response + or b"ready:thread" not in response + ): + response += client_socket.recv(1024) stack_trace = get_stack_trace(p.pid) except PermissionError: - self.skipTest("Insufficient permissions to read the stack trace") + self.skipTest( + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -87,22 +132,34 @@ class TestGetStackTrace(unittest.TestCase): p.terminate() p.wait(timeout=SHORT_TIMEOUT) - - expected_stack_trace = [ - 'foo', - 'baz', - 'bar', - '<module>' + thread_expected_stack_trace = [ + FrameInfo([script_name, 15, "foo"]), + FrameInfo([script_name, 12, "baz"]), + FrameInfo([script_name, 9, "bar"]), + FrameInfo([threading.__file__, ANY, "Thread.run"]), ] - self.assertEqual(stack_trace, expected_stack_trace) + # Is possible that there are more threads, so we check that the + # expected stack traces are in the result (looking at you Windows!) + self.assertIn((ANY, thread_expected_stack_trace), stack_trace) + + # Check that the main thread stack trace is in the result + frame = FrameInfo([script_name, 19, "<module>"]) + for _, stack in stack_trace: + if frame in stack: + break + else: + self.fail("Main thread stack trace not found in result") @skip_if_not_supported - @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, - "Test only runs on Linux with process_vm_readv support") + @unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", + ) def test_async_remote_stack_trace(self): # Spawn a process with some realistic Python code port = find_unused_port() - script = textwrap.dedent(f"""\ + script = textwrap.dedent( + f"""\ import asyncio import time import sys @@ -112,8 +169,7 @@ class TestGetStackTrace(unittest.TestCase): sock.connect(('localhost', {port})) def c5(): - sock.sendall(b"ready") - time.sleep(10000) + sock.sendall(b"ready"); time.sleep(10_000) # same line number async def c4(): await asyncio.sleep(0) @@ -142,7 +198,8 @@ class TestGetStackTrace(unittest.TestCase): return loop asyncio.run(main(), loop_factory={{TASK_FACTORY}}) - """) + """ + ) stack_trace = None for task_factory_variant in "asyncio.new_event_loop", "new_eager_loop": with ( @@ -151,19 +208,23 @@ class TestGetStackTrace(unittest.TestCase): ): script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(('localhost', port)) + server_socket = socket.socket( + socket.AF_INET, socket.SOCK_STREAM + ) + server_socket.setsockopt( + socket.SOL_SOCKET, socket.SO_REUSEADDR, 1 + ) + server_socket.bind(("localhost", port)) server_socket.settimeout(SHORT_TIMEOUT) server_socket.listen(1) script_name = _make_test_script( - script_dir, 'script', - script.format(TASK_FACTORY=task_factory_variant)) + script_dir, + "script", + script.format(TASK_FACTORY=task_factory_variant), + ) client_socket = None try: - p = subprocess.Popen( - [sys.executable, script_name] - ) + p = subprocess.Popen([sys.executable, script_name]) client_socket, _ = server_socket.accept() server_socket.close() response = client_socket.recv(1024) @@ -171,7 +232,8 @@ class TestGetStackTrace(unittest.TestCase): stack_trace = get_async_stack_trace(p.pid) except PermissionError: self.skipTest( - "Insufficient permissions to read the stack trace") + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -179,28 +241,173 @@ class TestGetStackTrace(unittest.TestCase): p.terminate() p.wait(timeout=SHORT_TIMEOUT) - # sets are unordered, so we want to sort "awaited_by"s - stack_trace[2].sort(key=lambda x: x[1]) - - root_task = "Task-1" - expected_stack_trace = [ - ["c5", "c4", "c3", "c2"], - "c2_root", - [ - [["main"], root_task, []], - [["c1"], "sub_main_1", [[["main"], root_task, []]]], - [["c1"], "sub_main_2", [[["main"], root_task, []]]], - ], + # First check all the tasks are present + tasks_names = [ + task.task_name for task in stack_trace[0].awaited_by ] - self.assertEqual(stack_trace, expected_stack_trace) + for task_name in ["c2_root", "sub_main_1", "sub_main_2"]: + self.assertIn(task_name, tasks_names) + + # Now ensure that the awaited_by_relationships are correct + id_to_task = { + task.task_id: task for task in stack_trace[0].awaited_by + } + task_name_to_awaited_by = { + task.task_name: set( + id_to_task[awaited.task_name].task_name + for awaited in task.awaited_by + ) + for task in stack_trace[0].awaited_by + } + self.assertEqual( + task_name_to_awaited_by, + { + "c2_root": {"Task-1", "sub_main_1", "sub_main_2"}, + "Task-1": set(), + "sub_main_1": {"Task-1"}, + "sub_main_2": {"Task-1"}, + }, + ) + + # Now ensure that the coroutine stacks are correct + coroutine_stacks = { + task.task_name: sorted( + tuple(tuple(frame) for frame in coro.call_stack) + for coro in task.coroutine_stack + ) + for task in stack_trace[0].awaited_by + } + self.assertEqual( + coroutine_stacks, + { + "Task-1": [ + ( + tuple( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + tuple( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + tuple([script_name, 26, "main"]), + ) + ], + "c2_root": [ + ( + tuple([script_name, 10, "c5"]), + tuple([script_name, 14, "c4"]), + tuple([script_name, 17, "c3"]), + tuple([script_name, 20, "c2"]), + ) + ], + "sub_main_1": [(tuple([script_name, 23, "c1"]),)], + "sub_main_2": [(tuple([script_name, 23, "c1"]),)], + }, + ) + + # Now ensure the coroutine stacks for the awaited_by relationships are correct. + awaited_by_coroutine_stacks = { + task.task_name: sorted( + ( + id_to_task[coro.task_name].task_name, + tuple(tuple(frame) for frame in coro.call_stack), + ) + for coro in task.awaited_by + ) + for task in stack_trace[0].awaited_by + } + self.assertEqual( + awaited_by_coroutine_stacks, + { + "Task-1": [], + "c2_root": [ + ( + "Task-1", + ( + tuple( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + tuple( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + tuple([script_name, 26, "main"]), + ), + ), + ("sub_main_1", (tuple([script_name, 23, "c1"]),)), + ("sub_main_2", (tuple([script_name, 23, "c1"]),)), + ], + "sub_main_1": [ + ( + "Task-1", + ( + tuple( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + tuple( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + tuple([script_name, 26, "main"]), + ), + ) + ], + "sub_main_2": [ + ( + "Task-1", + ( + tuple( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + tuple( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + tuple([script_name, 26, "main"]), + ), + ) + ], + }, + ) @skip_if_not_supported - @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, - "Test only runs on Linux with process_vm_readv support") + @unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", + ) def test_asyncgen_remote_stack_trace(self): # Spawn a process with some realistic Python code port = find_unused_port() - script = textwrap.dedent(f"""\ + script = textwrap.dedent( + f"""\ import asyncio import time import sys @@ -210,8 +417,7 @@ class TestGetStackTrace(unittest.TestCase): sock.connect(('localhost', {port})) async def gen_nested_call(): - sock.sendall(b"ready") - time.sleep(10000) + sock.sendall(b"ready"); time.sleep(10_000) # same line number async def gen(): for num in range(2): @@ -224,7 +430,8 @@ class TestGetStackTrace(unittest.TestCase): pass asyncio.run(main()) - """) + """ + ) stack_trace = None with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") @@ -232,10 +439,10 @@ class TestGetStackTrace(unittest.TestCase): # Create a socket server to communicate with the target process server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(('localhost', port)) + server_socket.bind(("localhost", port)) server_socket.settimeout(SHORT_TIMEOUT) server_socket.listen(1) - script_name = _make_test_script(script_dir, 'script', script) + script_name = _make_test_script(script_dir, "script", script) client_socket = None try: p = subprocess.Popen([sys.executable, script_name]) @@ -245,7 +452,9 @@ class TestGetStackTrace(unittest.TestCase): self.assertEqual(response, b"ready") stack_trace = get_async_stack_trace(p.pid) except PermissionError: - self.skipTest("Insufficient permissions to read the stack trace") + self.skipTest( + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -253,21 +462,40 @@ class TestGetStackTrace(unittest.TestCase): p.terminate() p.wait(timeout=SHORT_TIMEOUT) - # sets are unordered, so we want to sort "awaited_by"s - stack_trace[2].sort(key=lambda x: x[1]) + # For this simple asyncgen test, we only expect one task with the full coroutine stack + self.assertEqual(len(stack_trace[0].awaited_by), 1) + task = stack_trace[0].awaited_by[0] + self.assertEqual(task.task_name, "Task-1") - expected_stack_trace = [ - ['gen_nested_call', 'gen', 'main'], 'Task-1', [] - ] - self.assertEqual(stack_trace, expected_stack_trace) + # Check the coroutine stack - based on actual output, only shows main + coroutine_stack = sorted( + tuple(tuple(frame) for frame in coro.call_stack) + for coro in task.coroutine_stack + ) + self.assertEqual( + coroutine_stack, + [ + ( + tuple([script_name, 10, "gen_nested_call"]), + tuple([script_name, 16, "gen"]), + tuple([script_name, 19, "main"]), + ) + ], + ) + + # No awaited_by relationships expected for this simple case + self.assertEqual(task.awaited_by, []) @skip_if_not_supported - @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, - "Test only runs on Linux with process_vm_readv support") + @unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", + ) def test_async_gather_remote_stack_trace(self): # Spawn a process with some realistic Python code port = find_unused_port() - script = textwrap.dedent(f"""\ + script = textwrap.dedent( + f"""\ import asyncio import time import sys @@ -278,8 +506,7 @@ class TestGetStackTrace(unittest.TestCase): async def deep(): await asyncio.sleep(0) - sock.sendall(b"ready") - time.sleep(10000) + sock.sendall(b"ready"); time.sleep(10_000) # same line number async def c1(): await asyncio.sleep(0) @@ -292,7 +519,8 @@ class TestGetStackTrace(unittest.TestCase): await asyncio.gather(c1(), c2()) asyncio.run(main()) - """) + """ + ) stack_trace = None with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") @@ -300,10 +528,10 @@ class TestGetStackTrace(unittest.TestCase): # Create a socket server to communicate with the target process server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(('localhost', port)) + server_socket.bind(("localhost", port)) server_socket.settimeout(SHORT_TIMEOUT) server_socket.listen(1) - script_name = _make_test_script(script_dir, 'script', script) + script_name = _make_test_script(script_dir, "script", script) client_socket = None try: p = subprocess.Popen([sys.executable, script_name]) @@ -314,7 +542,8 @@ class TestGetStackTrace(unittest.TestCase): stack_trace = get_async_stack_trace(p.pid) except PermissionError: self.skipTest( - "Insufficient permissions to read the stack trace") + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -322,21 +551,84 @@ class TestGetStackTrace(unittest.TestCase): p.terminate() p.wait(timeout=SHORT_TIMEOUT) - # sets are unordered, so we want to sort "awaited_by"s - stack_trace[2].sort(key=lambda x: x[1]) - - expected_stack_trace = [ - ['deep', 'c1'], 'Task-2', [[['main'], 'Task-1', []]] + # First check all the tasks are present + tasks_names = [ + task.task_name for task in stack_trace[0].awaited_by ] - self.assertEqual(stack_trace, expected_stack_trace) + for task_name in ["Task-1", "Task-2"]: + self.assertIn(task_name, tasks_names) + + # Now ensure that the awaited_by_relationships are correct + id_to_task = { + task.task_id: task for task in stack_trace[0].awaited_by + } + task_name_to_awaited_by = { + task.task_name: set( + id_to_task[awaited.task_name].task_name + for awaited in task.awaited_by + ) + for task in stack_trace[0].awaited_by + } + self.assertEqual( + task_name_to_awaited_by, + { + "Task-1": set(), + "Task-2": {"Task-1"}, + }, + ) + + # Now ensure that the coroutine stacks are correct + coroutine_stacks = { + task.task_name: sorted( + tuple(tuple(frame) for frame in coro.call_stack) + for coro in task.coroutine_stack + ) + for task in stack_trace[0].awaited_by + } + self.assertEqual( + coroutine_stacks, + { + "Task-1": [(tuple([script_name, 21, "main"]),)], + "Task-2": [ + ( + tuple([script_name, 11, "deep"]), + tuple([script_name, 15, "c1"]), + ) + ], + }, + ) + + # Now ensure the coroutine stacks for the awaited_by relationships are correct. + awaited_by_coroutine_stacks = { + task.task_name: sorted( + ( + id_to_task[coro.task_name].task_name, + tuple(tuple(frame) for frame in coro.call_stack), + ) + for coro in task.awaited_by + ) + for task in stack_trace[0].awaited_by + } + self.assertEqual( + awaited_by_coroutine_stacks, + { + "Task-1": [], + "Task-2": [ + ("Task-1", (tuple([script_name, 21, "main"]),)) + ], + }, + ) @skip_if_not_supported - @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, - "Test only runs on Linux with process_vm_readv support") + @unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", + ) def test_async_staggered_race_remote_stack_trace(self): # Spawn a process with some realistic Python code port = find_unused_port() - script = textwrap.dedent(f"""\ + script = textwrap.dedent( + f"""\ import asyncio.staggered import time import sys @@ -347,15 +639,14 @@ class TestGetStackTrace(unittest.TestCase): async def deep(): await asyncio.sleep(0) - sock.sendall(b"ready") - time.sleep(10000) + sock.sendall(b"ready"); time.sleep(10_000) # same line number async def c1(): await asyncio.sleep(0) await deep() async def c2(): - await asyncio.sleep(10000) + await asyncio.sleep(10_000) async def main(): await asyncio.staggered.staggered_race( @@ -364,7 +655,8 @@ class TestGetStackTrace(unittest.TestCase): ) asyncio.run(main()) - """) + """ + ) stack_trace = None with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") @@ -372,10 +664,10 @@ class TestGetStackTrace(unittest.TestCase): # Create a socket server to communicate with the target process server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(('localhost', port)) + server_socket.bind(("localhost", port)) server_socket.settimeout(SHORT_TIMEOUT) server_socket.listen(1) - script_name = _make_test_script(script_dir, 'script', script) + script_name = _make_test_script(script_dir, "script", script) client_socket = None try: p = subprocess.Popen([sys.executable, script_name]) @@ -386,7 +678,8 @@ class TestGetStackTrace(unittest.TestCase): stack_trace = get_async_stack_trace(p.pid) except PermissionError: self.skipTest( - "Insufficient permissions to read the stack trace") + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -394,20 +687,103 @@ class TestGetStackTrace(unittest.TestCase): p.terminate() p.wait(timeout=SHORT_TIMEOUT) - # sets are unordered, so we want to sort "awaited_by"s - stack_trace[2].sort(key=lambda x: x[1]) - - expected_stack_trace = [ - ['deep', 'c1', 'run_one_coro'], 'Task-2', [[['main'], 'Task-1', []]] + # First check all the tasks are present + tasks_names = [ + task.task_name for task in stack_trace[0].awaited_by ] - self.assertEqual(stack_trace, expected_stack_trace) + for task_name in ["Task-1", "Task-2"]: + self.assertIn(task_name, tasks_names) + + # Now ensure that the awaited_by_relationships are correct + id_to_task = { + task.task_id: task for task in stack_trace[0].awaited_by + } + task_name_to_awaited_by = { + task.task_name: set( + id_to_task[awaited.task_name].task_name + for awaited in task.awaited_by + ) + for task in stack_trace[0].awaited_by + } + self.assertEqual( + task_name_to_awaited_by, + { + "Task-1": set(), + "Task-2": {"Task-1"}, + }, + ) + + # Now ensure that the coroutine stacks are correct + coroutine_stacks = { + task.task_name: sorted( + tuple(tuple(frame) for frame in coro.call_stack) + for coro in task.coroutine_stack + ) + for task in stack_trace[0].awaited_by + } + self.assertEqual( + coroutine_stacks, + { + "Task-1": [ + ( + tuple([staggered.__file__, ANY, "staggered_race"]), + tuple([script_name, 21, "main"]), + ) + ], + "Task-2": [ + ( + tuple([script_name, 11, "deep"]), + tuple([script_name, 15, "c1"]), + tuple( + [ + staggered.__file__, + ANY, + "staggered_race.<locals>.run_one_coro", + ] + ), + ) + ], + }, + ) + + # Now ensure the coroutine stacks for the awaited_by relationships are correct. + awaited_by_coroutine_stacks = { + task.task_name: sorted( + ( + id_to_task[coro.task_name].task_name, + tuple(tuple(frame) for frame in coro.call_stack), + ) + for coro in task.awaited_by + ) + for task in stack_trace[0].awaited_by + } + self.assertEqual( + awaited_by_coroutine_stacks, + { + "Task-1": [], + "Task-2": [ + ( + "Task-1", + ( + tuple( + [staggered.__file__, ANY, "staggered_race"] + ), + tuple([script_name, 21, "main"]), + ), + ) + ], + }, + ) @skip_if_not_supported - @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, - "Test only runs on Linux with process_vm_readv support") + @unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", + ) def test_async_global_awaited_by(self): port = find_unused_port() - script = textwrap.dedent(f"""\ + script = textwrap.dedent( + f"""\ import asyncio import os import random @@ -443,6 +819,8 @@ class TestGetStackTrace(unittest.TestCase): assert message == data.decode() writer.close() await writer.wait_closed() + # Signal we are ready to sleep + sock.sendall(b"ready") await asyncio.sleep(SHORT_TIMEOUT) async def echo_client_spam(server): @@ -452,8 +830,10 @@ class TestGetStackTrace(unittest.TestCase): random.shuffle(msg) tg.create_task(echo_client("".join(msg))) await asyncio.sleep(0) - # at least a 1000 tasks created - sock.sendall(b"ready") + # at least a 1000 tasks created. Each task will signal + # when is ready to avoid the race caused by the fact that + # tasks are waited on tg.__exit__ and we cannot signal when + # that happens otherwise # at this point all client tasks completed without assertion errors # let's wrap up the test server.close() @@ -468,7 +848,8 @@ class TestGetStackTrace(unittest.TestCase): tg.create_task(echo_client_spam(server), name="echo client spam") asyncio.run(main()) - """) + """ + ) stack_trace = None with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") @@ -476,17 +857,19 @@ class TestGetStackTrace(unittest.TestCase): # Create a socket server to communicate with the target process server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(('localhost', port)) + server_socket.bind(("localhost", port)) server_socket.settimeout(SHORT_TIMEOUT) server_socket.listen(1) - script_name = _make_test_script(script_dir, 'script', script) + script_name = _make_test_script(script_dir, "script", script) client_socket = None try: p = subprocess.Popen([sys.executable, script_name]) client_socket, _ = server_socket.accept() server_socket.close() - response = client_socket.recv(1024) - self.assertEqual(response, b"ready") + for _ in range(1000): + expected_response = b"ready" + response = client_socket.recv(len(expected_response)) + self.assertEqual(response, expected_response) for _ in busy_retry(SHORT_TIMEOUT): try: all_awaited_by = get_all_awaited_by(p.pid) @@ -497,7 +880,9 @@ class TestGetStackTrace(unittest.TestCase): msg = str(re) if msg.startswith("Task list appears corrupted"): continue - elif msg.startswith("Invalid linked list structure reading remote memory"): + elif msg.startswith( + "Invalid linked list structure reading remote memory" + ): continue elif msg.startswith("Unknown error reading memory"): continue @@ -516,22 +901,174 @@ class TestGetStackTrace(unittest.TestCase): # expected: at least 1000 pending tasks self.assertGreaterEqual(len(entries), 1000) # the first three tasks stem from the code structure - self.assertIn(('Task-1', []), entries) - self.assertIn(('server task', [[['main'], 'Task-1', []]]), entries) - self.assertIn(('echo client spam', [[['main'], 'Task-1', []]]), entries) + main_stack = [ + FrameInfo([taskgroups.__file__, ANY, "TaskGroup._aexit"]), + FrameInfo( + [taskgroups.__file__, ANY, "TaskGroup.__aexit__"] + ), + FrameInfo([script_name, 60, "main"]), + ] + self.assertIn( + TaskInfo( + [ANY, "Task-1", [CoroInfo([main_stack, ANY])], []] + ), + entries, + ) + self.assertIn( + TaskInfo( + [ + ANY, + "server task", + [ + CoroInfo( + [ + [ + FrameInfo( + [ + base_events.__file__, + ANY, + "Server.serve_forever", + ] + ) + ], + ANY, + ] + ) + ], + [ + CoroInfo( + [ + [ + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + FrameInfo( + [script_name, ANY, "main"] + ), + ], + ANY, + ] + ) + ], + ] + ), + entries, + ) + self.assertIn( + TaskInfo( + [ + ANY, + "Task-4", + [ + CoroInfo( + [ + [ + FrameInfo( + [tasks.__file__, ANY, "sleep"] + ), + FrameInfo( + [ + script_name, + 38, + "echo_client", + ] + ), + ], + ANY, + ] + ) + ], + [ + CoroInfo( + [ + [ + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + FrameInfo( + [ + script_name, + 41, + "echo_client_spam", + ] + ), + ], + ANY, + ] + ) + ], + ] + ), + entries, + ) - expected_stack = [[['echo_client_spam'], 'echo client spam', [[['main'], 'Task-1', []]]]] - tasks_with_stack = [task for task in entries if task[1] == expected_stack] - self.assertGreaterEqual(len(tasks_with_stack), 1000) + expected_awaited_by = [ + CoroInfo( + [ + [ + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + FrameInfo( + [script_name, 41, "echo_client_spam"] + ), + ], + ANY, + ] + ) + ] + tasks_with_awaited = [ + task + for task in entries + if task.awaited_by == expected_awaited_by + ] + self.assertGreaterEqual(len(tasks_with_awaited), 1000) # the final task will have some random number, but it should for # sure be one of the echo client spam horde (In windows this is not true # for some reason) if sys.platform != "win32": - self.assertEqual([[['echo_client_spam'], 'echo client spam', [[['main'], 'Task-1', []]]]], entries[-1][1]) + self.assertEqual( + tasks_with_awaited[-1].awaited_by, + entries[-1].awaited_by, + ) except PermissionError: self.skipTest( - "Insufficient permissions to read the stack trace") + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -540,12 +1077,201 @@ class TestGetStackTrace(unittest.TestCase): p.wait(timeout=SHORT_TIMEOUT) @skip_if_not_supported - @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, - "Test only runs on Linux with process_vm_readv support") + @unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", + ) def test_self_trace(self): stack_trace = get_stack_trace(os.getpid()) - print(stack_trace) - self.assertEqual(stack_trace[0], "test_self_trace") + # Is possible that there are more threads, so we check that the + # expected stack traces are in the result (looking at you Windows!) + this_tread_stack = None + for thread_id, stack in stack_trace: + if thread_id == threading.get_native_id(): + this_tread_stack = stack + break + self.assertIsNotNone(this_tread_stack) + self.assertEqual( + stack[:2], + [ + FrameInfo( + [ + __file__, + get_stack_trace.__code__.co_firstlineno + 2, + "get_stack_trace", + ] + ), + FrameInfo( + [ + __file__, + self.test_self_trace.__code__.co_firstlineno + 6, + "TestGetStackTrace.test_self_trace", + ] + ), + ], + ) + + @skip_if_not_supported + @unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", + ) + @requires_gil_enabled("Free threaded builds don't have an 'active thread'") + def test_only_active_thread(self): + # Test that only_active_thread parameter works correctly + port = find_unused_port() + script = textwrap.dedent( + f"""\ + import time, sys, socket, threading + + # Connect to the test process + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect(('localhost', {port})) + + def worker_thread(name, barrier, ready_event): + barrier.wait() # Synchronize thread start + ready_event.wait() # Wait for main thread signal + # Sleep to keep thread alive + time.sleep(10_000) + + def main_work(): + # Do busy work to hold the GIL + sock.sendall(b"working\\n") + count = 0 + while count < 100000000: + count += 1 + if count % 10000000 == 0: + pass # Keep main thread busy + sock.sendall(b"done\\n") + + # Create synchronization primitives + num_threads = 3 + barrier = threading.Barrier(num_threads + 1) # +1 for main thread + ready_event = threading.Event() + + # Start worker threads + threads = [] + for i in range(num_threads): + t = threading.Thread(target=worker_thread, args=(f"Worker-{{i}}", barrier, ready_event)) + t.start() + threads.append(t) + + # Wait for all threads to be ready + barrier.wait() + + # Signal ready to parent process + sock.sendall(b"ready\\n") + + # Signal threads to start waiting + ready_event.set() + + # Now do busy work to hold the GIL + main_work() + """ + ) + + with os_helper.temp_dir() as work_dir: + script_dir = os.path.join(work_dir, "script_pkg") + os.mkdir(script_dir) + + # Create a socket server to communicate with the target process + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server_socket.bind(("localhost", port)) + server_socket.settimeout(SHORT_TIMEOUT) + server_socket.listen(1) + + script_name = _make_test_script(script_dir, "script", script) + client_socket = None + try: + p = subprocess.Popen([sys.executable, script_name]) + client_socket, _ = server_socket.accept() + server_socket.close() + + # Wait for ready signal + response = b"" + while b"ready" not in response: + response += client_socket.recv(1024) + + # Wait for the main thread to start its busy work + while b"working" not in response: + response += client_socket.recv(1024) + + # Get stack trace with all threads + unwinder_all = RemoteUnwinder(p.pid, all_threads=True) + for _ in range(10): + # Wait for the main thread to start its busy work + all_traces = unwinder_all.get_stack_trace() + found = False + for thread_id, stack in all_traces: + if not stack: + continue + current_frame = stack[0] + if ( + current_frame.funcname == "main_work" + and current_frame.lineno > 15 + ): + found = True + + if found: + break + # Give a bit of time to take the next sample + time.sleep(0.1) + else: + self.fail( + "Main thread did not start its busy work on time" + ) + + # Get stack trace with only GIL holder + unwinder_gil = RemoteUnwinder(p.pid, only_active_thread=True) + gil_traces = unwinder_gil.get_stack_trace() + + except PermissionError: + self.skipTest( + "Insufficient permissions to read the stack trace" + ) + finally: + if client_socket is not None: + client_socket.close() + p.kill() + p.terminate() + p.wait(timeout=SHORT_TIMEOUT) + + # Verify we got multiple threads in all_traces + self.assertGreater( + len(all_traces), 1, "Should have multiple threads" + ) + + # Verify we got exactly one thread in gil_traces + self.assertEqual( + len(gil_traces), 1, "Should have exactly one GIL holder" + ) + + # The GIL holder should be in the all_traces list + gil_thread_id = gil_traces[0][0] + all_thread_ids = [trace[0] for trace in all_traces] + self.assertIn( + gil_thread_id, + all_thread_ids, + "GIL holder should be among all threads", + ) + + +class TestUnsupportedPlatformHandling(unittest.TestCase): + @unittest.skipIf( + sys.platform in ("linux", "darwin", "win32"), + "Test only runs on unsupported platforms (not Linux, macOS, or Windows)", + ) + @unittest.skipIf(sys.platform == "android", "Android raises Linux-specific exception") + def test_unsupported_platform_error(self): + with self.assertRaises(RuntimeError) as cm: + RemoteUnwinder(os.getpid()) + + self.assertIn( + "Reading the PyRuntime section is not supported on this platform", + str(cm.exception) + ) + if __name__ == "__main__": unittest.main() |