diff options
Diffstat (limited to 'Lib/test/test_external_inspection.py')
-rw-r--r-- | Lib/test/test_external_inspection.py | 664 |
1 files changed, 557 insertions, 107 deletions
diff --git a/Lib/test/test_external_inspection.py b/Lib/test/test_external_inspection.py index aa05db972f0..0f31c225e68 100644 --- a/Lib/test/test_external_inspection.py +++ b/Lib/test/test_external_inspection.py @@ -4,7 +4,10 @@ import textwrap import importlib import sys import socket -from test.support import os_helper, SHORT_TIMEOUT, busy_retry +import threading +from asyncio import staggered, taskgroups, base_events, tasks +from unittest.mock import ANY +from test.support import os_helper, SHORT_TIMEOUT, busy_retry, requires_gil_enabled from test.support.script_helper import make_script from test.support.socket_helper import find_unused_port @@ -13,33 +16,60 @@ import subprocess PROCESS_VM_READV_SUPPORTED = False try: - from _testexternalinspection import PROCESS_VM_READV_SUPPORTED - from _testexternalinspection import get_stack_trace - from _testexternalinspection import get_async_stack_trace - from _testexternalinspection import get_all_awaited_by + from _remote_debugging import PROCESS_VM_READV_SUPPORTED + from _remote_debugging import RemoteUnwinder + from _remote_debugging import FrameInfo, CoroInfo, TaskInfo except ImportError: raise unittest.SkipTest( - "Test only runs when _testexternalinspection is available") + "Test only runs when _remote_debugging is available" + ) + def _make_test_script(script_dir, script_basename, source): to_return = make_script(script_dir, script_basename, source) importlib.invalidate_caches() return to_return -skip_if_not_supported = unittest.skipIf((sys.platform != "darwin" - and sys.platform != "linux" - and sys.platform != "win32"), - "Test only runs on Linux, Windows and MacOS") + +skip_if_not_supported = unittest.skipIf( + ( + sys.platform != "darwin" + and sys.platform != "linux" + and sys.platform != "win32" + ), + "Test only runs on Linux, Windows and MacOS", +) + + +def get_stack_trace(pid): + unwinder = RemoteUnwinder(pid, all_threads=True, debug=True) + return unwinder.get_stack_trace() + + +def get_async_stack_trace(pid): + unwinder = RemoteUnwinder(pid, debug=True) + return unwinder.get_async_stack_trace() + + +def get_all_awaited_by(pid): + unwinder = RemoteUnwinder(pid, debug=True) + return unwinder.get_all_awaited_by() + + class TestGetStackTrace(unittest.TestCase): + maxDiff = None @skip_if_not_supported - @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, - "Test only runs on Linux with process_vm_readv support") + @unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", + ) def test_remote_stack_trace(self): # Spawn a process with some realistic Python code port = find_unused_port() - script = textwrap.dedent(f"""\ - import time, sys, socket + script = textwrap.dedent( + f"""\ + import time, sys, socket, threading # Connect to the test process sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', {port})) @@ -48,15 +78,18 @@ class TestGetStackTrace(unittest.TestCase): for x in range(100): if x == 50: baz() + def baz(): foo() def foo(): - sock.sendall(b"ready") - time.sleep(1000) + sock.sendall(b"ready:thread\\n"); time.sleep(10_000) # same line number - bar() - """) + t = threading.Thread(target=bar) + t.start() + sock.sendall(b"ready:main\\n"); t.join() # same line number + """ + ) stack_trace = None with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") @@ -65,21 +98,27 @@ class TestGetStackTrace(unittest.TestCase): # Create a socket server to communicate with the target process server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(('localhost', port)) + server_socket.bind(("localhost", port)) server_socket.settimeout(SHORT_TIMEOUT) server_socket.listen(1) - script_name = _make_test_script(script_dir, 'script', script) + script_name = _make_test_script(script_dir, "script", script) client_socket = None try: p = subprocess.Popen([sys.executable, script_name]) client_socket, _ = server_socket.accept() server_socket.close() - response = client_socket.recv(1024) - self.assertEqual(response, b"ready") + response = b"" + while ( + b"ready:main" not in response + or b"ready:thread" not in response + ): + response += client_socket.recv(1024) stack_trace = get_stack_trace(p.pid) except PermissionError: - self.skipTest("Insufficient permissions to read the stack trace") + self.skipTest( + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -87,22 +126,34 @@ class TestGetStackTrace(unittest.TestCase): p.terminate() p.wait(timeout=SHORT_TIMEOUT) - - expected_stack_trace = [ - 'foo', - 'baz', - 'bar', - '<module>' + thread_expected_stack_trace = [ + FrameInfo([script_name, 15, "foo"]), + FrameInfo([script_name, 12, "baz"]), + FrameInfo([script_name, 9, "bar"]), + FrameInfo([threading.__file__, ANY, "Thread.run"]), ] - self.assertEqual(stack_trace, expected_stack_trace) + # Is possible that there are more threads, so we check that the + # expected stack traces are in the result (looking at you Windows!) + self.assertIn((ANY, thread_expected_stack_trace), stack_trace) + + # Check that the main thread stack trace is in the result + frame = FrameInfo([script_name, 19, "<module>"]) + for _, stack in stack_trace: + if frame in stack: + break + else: + self.fail("Main thread stack trace not found in result") @skip_if_not_supported - @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, - "Test only runs on Linux with process_vm_readv support") + @unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", + ) def test_async_remote_stack_trace(self): # Spawn a process with some realistic Python code port = find_unused_port() - script = textwrap.dedent(f"""\ + script = textwrap.dedent( + f"""\ import asyncio import time import sys @@ -112,8 +163,7 @@ class TestGetStackTrace(unittest.TestCase): sock.connect(('localhost', {port})) def c5(): - sock.sendall(b"ready") - time.sleep(10000) + sock.sendall(b"ready"); time.sleep(10_000) # same line number async def c4(): await asyncio.sleep(0) @@ -142,7 +192,8 @@ class TestGetStackTrace(unittest.TestCase): return loop asyncio.run(main(), loop_factory={{TASK_FACTORY}}) - """) + """ + ) stack_trace = None for task_factory_variant in "asyncio.new_event_loop", "new_eager_loop": with ( @@ -151,19 +202,23 @@ class TestGetStackTrace(unittest.TestCase): ): script_dir = os.path.join(work_dir, "script_pkg") os.mkdir(script_dir) - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(('localhost', port)) + server_socket = socket.socket( + socket.AF_INET, socket.SOCK_STREAM + ) + server_socket.setsockopt( + socket.SOL_SOCKET, socket.SO_REUSEADDR, 1 + ) + server_socket.bind(("localhost", port)) server_socket.settimeout(SHORT_TIMEOUT) server_socket.listen(1) script_name = _make_test_script( - script_dir, 'script', - script.format(TASK_FACTORY=task_factory_variant)) + script_dir, + "script", + script.format(TASK_FACTORY=task_factory_variant), + ) client_socket = None try: - p = subprocess.Popen( - [sys.executable, script_name] - ) + p = subprocess.Popen([sys.executable, script_name]) client_socket, _ = server_socket.accept() server_socket.close() response = client_socket.recv(1024) @@ -171,7 +226,8 @@ class TestGetStackTrace(unittest.TestCase): stack_trace = get_async_stack_trace(p.pid) except PermissionError: self.skipTest( - "Insufficient permissions to read the stack trace") + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -182,25 +238,63 @@ class TestGetStackTrace(unittest.TestCase): # sets are unordered, so we want to sort "awaited_by"s stack_trace[2].sort(key=lambda x: x[1]) - root_task = "Task-1" expected_stack_trace = [ - ["c5", "c4", "c3", "c2"], + [ + FrameInfo([script_name, 10, "c5"]), + FrameInfo([script_name, 14, "c4"]), + FrameInfo([script_name, 17, "c3"]), + FrameInfo([script_name, 20, "c2"]), + ], "c2_root", [ - [["main"], root_task, []], - [["c1"], "sub_main_1", [[["main"], root_task, []]]], - [["c1"], "sub_main_2", [[["main"], root_task, []]]], + CoroInfo( + [ + [ + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + FrameInfo([script_name, 26, "main"]), + ], + "Task-1", + ] + ), + CoroInfo( + [ + [FrameInfo([script_name, 23, "c1"])], + "sub_main_1", + ] + ), + CoroInfo( + [ + [FrameInfo([script_name, 23, "c1"])], + "sub_main_2", + ] + ), ], ] self.assertEqual(stack_trace, expected_stack_trace) @skip_if_not_supported - @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, - "Test only runs on Linux with process_vm_readv support") + @unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", + ) def test_asyncgen_remote_stack_trace(self): # Spawn a process with some realistic Python code port = find_unused_port() - script = textwrap.dedent(f"""\ + script = textwrap.dedent( + f"""\ import asyncio import time import sys @@ -210,8 +304,7 @@ class TestGetStackTrace(unittest.TestCase): sock.connect(('localhost', {port})) async def gen_nested_call(): - sock.sendall(b"ready") - time.sleep(10000) + sock.sendall(b"ready"); time.sleep(10_000) # same line number async def gen(): for num in range(2): @@ -224,7 +317,8 @@ class TestGetStackTrace(unittest.TestCase): pass asyncio.run(main()) - """) + """ + ) stack_trace = None with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") @@ -232,10 +326,10 @@ class TestGetStackTrace(unittest.TestCase): # Create a socket server to communicate with the target process server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(('localhost', port)) + server_socket.bind(("localhost", port)) server_socket.settimeout(SHORT_TIMEOUT) server_socket.listen(1) - script_name = _make_test_script(script_dir, 'script', script) + script_name = _make_test_script(script_dir, "script", script) client_socket = None try: p = subprocess.Popen([sys.executable, script_name]) @@ -245,7 +339,9 @@ class TestGetStackTrace(unittest.TestCase): self.assertEqual(response, b"ready") stack_trace = get_async_stack_trace(p.pid) except PermissionError: - self.skipTest("Insufficient permissions to read the stack trace") + self.skipTest( + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -257,17 +353,26 @@ class TestGetStackTrace(unittest.TestCase): stack_trace[2].sort(key=lambda x: x[1]) expected_stack_trace = [ - ['gen_nested_call', 'gen', 'main'], 'Task-1', [] + [ + FrameInfo([script_name, 10, "gen_nested_call"]), + FrameInfo([script_name, 16, "gen"]), + FrameInfo([script_name, 19, "main"]), + ], + "Task-1", + [], ] self.assertEqual(stack_trace, expected_stack_trace) @skip_if_not_supported - @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, - "Test only runs on Linux with process_vm_readv support") + @unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", + ) def test_async_gather_remote_stack_trace(self): # Spawn a process with some realistic Python code port = find_unused_port() - script = textwrap.dedent(f"""\ + script = textwrap.dedent( + f"""\ import asyncio import time import sys @@ -278,8 +383,7 @@ class TestGetStackTrace(unittest.TestCase): async def deep(): await asyncio.sleep(0) - sock.sendall(b"ready") - time.sleep(10000) + sock.sendall(b"ready"); time.sleep(10_000) # same line number async def c1(): await asyncio.sleep(0) @@ -292,7 +396,8 @@ class TestGetStackTrace(unittest.TestCase): await asyncio.gather(c1(), c2()) asyncio.run(main()) - """) + """ + ) stack_trace = None with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") @@ -300,10 +405,10 @@ class TestGetStackTrace(unittest.TestCase): # Create a socket server to communicate with the target process server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(('localhost', port)) + server_socket.bind(("localhost", port)) server_socket.settimeout(SHORT_TIMEOUT) server_socket.listen(1) - script_name = _make_test_script(script_dir, 'script', script) + script_name = _make_test_script(script_dir, "script", script) client_socket = None try: p = subprocess.Popen([sys.executable, script_name]) @@ -314,7 +419,8 @@ class TestGetStackTrace(unittest.TestCase): stack_trace = get_async_stack_trace(p.pid) except PermissionError: self.skipTest( - "Insufficient permissions to read the stack trace") + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -325,18 +431,26 @@ class TestGetStackTrace(unittest.TestCase): # sets are unordered, so we want to sort "awaited_by"s stack_trace[2].sort(key=lambda x: x[1]) - expected_stack_trace = [ - ['deep', 'c1'], 'Task-2', [[['main'], 'Task-1', []]] + expected_stack_trace = [ + [ + FrameInfo([script_name, 11, "deep"]), + FrameInfo([script_name, 15, "c1"]), + ], + "Task-2", + [CoroInfo([[FrameInfo([script_name, 21, "main"])], "Task-1"])], ] self.assertEqual(stack_trace, expected_stack_trace) @skip_if_not_supported - @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, - "Test only runs on Linux with process_vm_readv support") + @unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", + ) def test_async_staggered_race_remote_stack_trace(self): # Spawn a process with some realistic Python code port = find_unused_port() - script = textwrap.dedent(f"""\ + script = textwrap.dedent( + f"""\ import asyncio.staggered import time import sys @@ -347,15 +461,14 @@ class TestGetStackTrace(unittest.TestCase): async def deep(): await asyncio.sleep(0) - sock.sendall(b"ready") - time.sleep(10000) + sock.sendall(b"ready"); time.sleep(10_000) # same line number async def c1(): await asyncio.sleep(0) await deep() async def c2(): - await asyncio.sleep(10000) + await asyncio.sleep(10_000) async def main(): await asyncio.staggered.staggered_race( @@ -364,7 +477,8 @@ class TestGetStackTrace(unittest.TestCase): ) asyncio.run(main()) - """) + """ + ) stack_trace = None with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") @@ -372,10 +486,10 @@ class TestGetStackTrace(unittest.TestCase): # Create a socket server to communicate with the target process server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(('localhost', port)) + server_socket.bind(("localhost", port)) server_socket.settimeout(SHORT_TIMEOUT) server_socket.listen(1) - script_name = _make_test_script(script_dir, 'script', script) + script_name = _make_test_script(script_dir, "script", script) client_socket = None try: p = subprocess.Popen([sys.executable, script_name]) @@ -386,7 +500,8 @@ class TestGetStackTrace(unittest.TestCase): stack_trace = get_async_stack_trace(p.pid) except PermissionError: self.skipTest( - "Insufficient permissions to read the stack trace") + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -396,18 +511,44 @@ class TestGetStackTrace(unittest.TestCase): # sets are unordered, so we want to sort "awaited_by"s stack_trace[2].sort(key=lambda x: x[1]) - - expected_stack_trace = [ - ['deep', 'c1', 'run_one_coro'], 'Task-2', [[['main'], 'Task-1', []]] + expected_stack_trace = [ + [ + FrameInfo([script_name, 11, "deep"]), + FrameInfo([script_name, 15, "c1"]), + FrameInfo( + [ + staggered.__file__, + ANY, + "staggered_race.<locals>.run_one_coro", + ] + ), + ], + "Task-2", + [ + CoroInfo( + [ + [ + FrameInfo( + [staggered.__file__, ANY, "staggered_race"] + ), + FrameInfo([script_name, 21, "main"]), + ], + "Task-1", + ] + ) + ], ] self.assertEqual(stack_trace, expected_stack_trace) @skip_if_not_supported - @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, - "Test only runs on Linux with process_vm_readv support") + @unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", + ) def test_async_global_awaited_by(self): port = find_unused_port() - script = textwrap.dedent(f"""\ + script = textwrap.dedent( + f"""\ import asyncio import os import random @@ -443,6 +584,8 @@ class TestGetStackTrace(unittest.TestCase): assert message == data.decode() writer.close() await writer.wait_closed() + # Signal we are ready to sleep + sock.sendall(b"ready") await asyncio.sleep(SHORT_TIMEOUT) async def echo_client_spam(server): @@ -452,8 +595,10 @@ class TestGetStackTrace(unittest.TestCase): random.shuffle(msg) tg.create_task(echo_client("".join(msg))) await asyncio.sleep(0) - # at least a 1000 tasks created - sock.sendall(b"ready") + # at least a 1000 tasks created. Each task will signal + # when is ready to avoid the race caused by the fact that + # tasks are waited on tg.__exit__ and we cannot signal when + # that happens otherwise # at this point all client tasks completed without assertion errors # let's wrap up the test server.close() @@ -468,7 +613,8 @@ class TestGetStackTrace(unittest.TestCase): tg.create_task(echo_client_spam(server), name="echo client spam") asyncio.run(main()) - """) + """ + ) stack_trace = None with os_helper.temp_dir() as work_dir: script_dir = os.path.join(work_dir, "script_pkg") @@ -476,17 +622,19 @@ class TestGetStackTrace(unittest.TestCase): # Create a socket server to communicate with the target process server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(('localhost', port)) + server_socket.bind(("localhost", port)) server_socket.settimeout(SHORT_TIMEOUT) server_socket.listen(1) - script_name = _make_test_script(script_dir, 'script', script) + script_name = _make_test_script(script_dir, "script", script) client_socket = None try: p = subprocess.Popen([sys.executable, script_name]) client_socket, _ = server_socket.accept() server_socket.close() - response = client_socket.recv(1024) - self.assertEqual(response, b"ready") + for _ in range(1000): + expected_response = b"ready" + response = client_socket.recv(len(expected_response)) + self.assertEqual(response, expected_response) for _ in busy_retry(SHORT_TIMEOUT): try: all_awaited_by = get_all_awaited_by(p.pid) @@ -497,7 +645,9 @@ class TestGetStackTrace(unittest.TestCase): msg = str(re) if msg.startswith("Task list appears corrupted"): continue - elif msg.startswith("Invalid linked list structure reading remote memory"): + elif msg.startswith( + "Invalid linked list structure reading remote memory" + ): continue elif msg.startswith("Unknown error reading memory"): continue @@ -516,22 +666,174 @@ class TestGetStackTrace(unittest.TestCase): # expected: at least 1000 pending tasks self.assertGreaterEqual(len(entries), 1000) # the first three tasks stem from the code structure - self.assertIn(('Task-1', []), entries) - self.assertIn(('server task', [[['main'], 'Task-1', []]]), entries) - self.assertIn(('echo client spam', [[['main'], 'Task-1', []]]), entries) + main_stack = [ + FrameInfo([taskgroups.__file__, ANY, "TaskGroup._aexit"]), + FrameInfo( + [taskgroups.__file__, ANY, "TaskGroup.__aexit__"] + ), + FrameInfo([script_name, 60, "main"]), + ] + self.assertIn( + TaskInfo( + [ANY, "Task-1", [CoroInfo([main_stack, ANY])], []] + ), + entries, + ) + self.assertIn( + TaskInfo( + [ + ANY, + "server task", + [ + CoroInfo( + [ + [ + FrameInfo( + [ + base_events.__file__, + ANY, + "Server.serve_forever", + ] + ) + ], + ANY, + ] + ) + ], + [ + CoroInfo( + [ + [ + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + FrameInfo( + [script_name, ANY, "main"] + ), + ], + ANY, + ] + ) + ], + ] + ), + entries, + ) + self.assertIn( + TaskInfo( + [ + ANY, + "Task-4", + [ + CoroInfo( + [ + [ + FrameInfo( + [tasks.__file__, ANY, "sleep"] + ), + FrameInfo( + [ + script_name, + 38, + "echo_client", + ] + ), + ], + ANY, + ] + ) + ], + [ + CoroInfo( + [ + [ + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + FrameInfo( + [ + script_name, + 41, + "echo_client_spam", + ] + ), + ], + ANY, + ] + ) + ], + ] + ), + entries, + ) - expected_stack = [[['echo_client_spam'], 'echo client spam', [[['main'], 'Task-1', []]]]] - tasks_with_stack = [task for task in entries if task[1] == expected_stack] - self.assertGreaterEqual(len(tasks_with_stack), 1000) + expected_awaited_by = [ + CoroInfo( + [ + [ + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup._aexit", + ] + ), + FrameInfo( + [ + taskgroups.__file__, + ANY, + "TaskGroup.__aexit__", + ] + ), + FrameInfo( + [script_name, 41, "echo_client_spam"] + ), + ], + ANY, + ] + ) + ] + tasks_with_awaited = [ + task + for task in entries + if task.awaited_by == expected_awaited_by + ] + self.assertGreaterEqual(len(tasks_with_awaited), 1000) # the final task will have some random number, but it should for # sure be one of the echo client spam horde (In windows this is not true # for some reason) if sys.platform != "win32": - self.assertEqual([[['echo_client_spam'], 'echo client spam', [[['main'], 'Task-1', []]]]], entries[-1][1]) + self.assertEqual( + tasks_with_awaited[-1].awaited_by, + entries[-1].awaited_by, + ) except PermissionError: self.skipTest( - "Insufficient permissions to read the stack trace") + "Insufficient permissions to read the stack trace" + ) finally: if client_socket is not None: client_socket.close() @@ -540,12 +842,160 @@ class TestGetStackTrace(unittest.TestCase): p.wait(timeout=SHORT_TIMEOUT) @skip_if_not_supported - @unittest.skipIf(sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, - "Test only runs on Linux with process_vm_readv support") + @unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", + ) def test_self_trace(self): stack_trace = get_stack_trace(os.getpid()) - print(stack_trace) - self.assertEqual(stack_trace[0], "test_self_trace") + # Is possible that there are more threads, so we check that the + # expected stack traces are in the result (looking at you Windows!) + this_tread_stack = None + for thread_id, stack in stack_trace: + if thread_id == threading.get_native_id(): + this_tread_stack = stack + break + self.assertIsNotNone(this_tread_stack) + self.assertEqual( + stack[:2], + [ + FrameInfo( + [ + __file__, + get_stack_trace.__code__.co_firstlineno + 2, + "get_stack_trace", + ] + ), + FrameInfo( + [ + __file__, + self.test_self_trace.__code__.co_firstlineno + 6, + "TestGetStackTrace.test_self_trace", + ] + ), + ], + ) + + @skip_if_not_supported + @unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", + ) + @requires_gil_enabled("Free threaded builds don't have an 'active thread'") + def test_only_active_thread(self): + # Test that only_active_thread parameter works correctly + port = find_unused_port() + script = textwrap.dedent( + f"""\ + import time, sys, socket, threading + + # Connect to the test process + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect(('localhost', {port})) + + def worker_thread(name, barrier, ready_event): + barrier.wait() # Synchronize thread start + ready_event.wait() # Wait for main thread signal + # Sleep to keep thread alive + time.sleep(10_000) + + def main_work(): + # Do busy work to hold the GIL + sock.sendall(b"working\\n") + count = 0 + while count < 100000000: + count += 1 + if count % 10000000 == 0: + pass # Keep main thread busy + sock.sendall(b"done\\n") + + # Create synchronization primitives + num_threads = 3 + barrier = threading.Barrier(num_threads + 1) # +1 for main thread + ready_event = threading.Event() + + # Start worker threads + threads = [] + for i in range(num_threads): + t = threading.Thread(target=worker_thread, args=(f"Worker-{{i}}", barrier, ready_event)) + t.start() + threads.append(t) + + # Wait for all threads to be ready + barrier.wait() + + # Signal ready to parent process + sock.sendall(b"ready\\n") + + # Signal threads to start waiting + ready_event.set() + + # Give threads time to start sleeping + time.sleep(0.1) + + # Now do busy work to hold the GIL + main_work() + """ + ) + + with os_helper.temp_dir() as work_dir: + script_dir = os.path.join(work_dir, "script_pkg") + os.mkdir(script_dir) + + # Create a socket server to communicate with the target process + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server_socket.bind(("localhost", port)) + server_socket.settimeout(SHORT_TIMEOUT) + server_socket.listen(1) + + script_name = _make_test_script(script_dir, "script", script) + client_socket = None + try: + p = subprocess.Popen([sys.executable, script_name]) + client_socket, _ = server_socket.accept() + server_socket.close() + + # Wait for ready signal + response = b"" + while b"ready" not in response: + response += client_socket.recv(1024) + + # Wait for the main thread to start its busy work + while b"working" not in response: + response += client_socket.recv(1024) + + # Get stack trace with all threads + unwinder_all = RemoteUnwinder(p.pid, all_threads=True) + all_traces = unwinder_all.get_stack_trace() + + # Get stack trace with only GIL holder + unwinder_gil = RemoteUnwinder(p.pid, only_active_thread=True) + gil_traces = unwinder_gil.get_stack_trace() + + except PermissionError: + self.skipTest( + "Insufficient permissions to read the stack trace" + ) + finally: + if client_socket is not None: + client_socket.close() + p.kill() + p.terminate() + p.wait(timeout=SHORT_TIMEOUT) + + # Verify we got multiple threads in all_traces + self.assertGreater(len(all_traces), 1, "Should have multiple threads") + + # Verify we got exactly one thread in gil_traces + self.assertEqual(len(gil_traces), 1, "Should have exactly one GIL holder") + + # The GIL holder should be in the all_traces list + gil_thread_id = gil_traces[0][0] + all_thread_ids = [trace[0] for trace in all_traces] + self.assertIn(gil_thread_id, all_thread_ids, + "GIL holder should be among all threads") + if __name__ == "__main__": unittest.main() |