aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/test/test_external_inspection.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_external_inspection.py')
-rw-r--r--Lib/test/test_external_inspection.py82
1 files changed, 63 insertions, 19 deletions
diff --git a/Lib/test/test_external_inspection.py b/Lib/test/test_external_inspection.py
index ad3f669a030..2b4b63a030b 100644
--- a/Lib/test/test_external_inspection.py
+++ b/Lib/test/test_external_inspection.py
@@ -4,6 +4,7 @@ import textwrap
import importlib
import sys
import socket
+import threading
from asyncio import staggered, taskgroups
from unittest.mock import ANY
from test.support import os_helper, SHORT_TIMEOUT, busy_retry
@@ -16,9 +17,7 @@ PROCESS_VM_READV_SUPPORTED = False
try:
from _remote_debugging import PROCESS_VM_READV_SUPPORTED
- from _remote_debugging import get_stack_trace
- from _remote_debugging import get_async_stack_trace
- from _remote_debugging import get_all_awaited_by
+ from _remote_debugging import RemoteUnwinder
except ImportError:
raise unittest.SkipTest("Test only runs when _remote_debugging is available")
@@ -34,7 +33,23 @@ skip_if_not_supported = unittest.skipIf(
)
+def get_stack_trace(pid):
+ unwinder = RemoteUnwinder(pid, all_threads=True, debug=True)
+ return unwinder.get_stack_trace()
+
+
+def get_async_stack_trace(pid):
+ unwinder = RemoteUnwinder(pid, debug=True)
+ return unwinder.get_async_stack_trace()
+
+
+def get_all_awaited_by(pid):
+ unwinder = RemoteUnwinder(pid, debug=True)
+ return unwinder.get_all_awaited_by()
+
+
class TestGetStackTrace(unittest.TestCase):
+ maxDiff = None
@skip_if_not_supported
@unittest.skipIf(
@@ -46,7 +61,7 @@ class TestGetStackTrace(unittest.TestCase):
port = find_unused_port()
script = textwrap.dedent(
f"""\
- import time, sys, socket
+ import time, sys, socket, threading
# Connect to the test process
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('localhost', {port}))
@@ -55,13 +70,16 @@ class TestGetStackTrace(unittest.TestCase):
for x in range(100):
if x == 50:
baz()
+
def baz():
foo()
def foo():
- sock.sendall(b"ready"); time.sleep(10_000) # same line number
+ sock.sendall(b"ready:thread\\n"); time.sleep(10_000) # same line number
- bar()
+ t = threading.Thread(target=bar)
+ t.start()
+ sock.sendall(b"ready:main\\n"); t.join() # same line number
"""
)
stack_trace = None
@@ -82,8 +100,9 @@ class TestGetStackTrace(unittest.TestCase):
p = subprocess.Popen([sys.executable, script_name])
client_socket, _ = server_socket.accept()
server_socket.close()
- response = client_socket.recv(1024)
- self.assertEqual(response, b"ready")
+ response = b""
+ while b"ready:main" not in response or b"ready:thread" not in response:
+ response += client_socket.recv(1024)
stack_trace = get_stack_trace(p.pid)
except PermissionError:
self.skipTest("Insufficient permissions to read the stack trace")
@@ -94,13 +113,23 @@ class TestGetStackTrace(unittest.TestCase):
p.terminate()
p.wait(timeout=SHORT_TIMEOUT)
- expected_stack_trace = [
- ("foo", script_name, 14),
- ("baz", script_name, 11),
+ thread_expected_stack_trace = [
+ ("foo", script_name, 15),
+ ("baz", script_name, 12),
("bar", script_name, 9),
- ("<module>", script_name, 16),
+ ('Thread.run', threading.__file__, ANY)
]
- self.assertEqual(stack_trace, expected_stack_trace)
+ # Is possible that there are more threads, so we check that the
+ # expected stack traces are in the result (looking at you Windows!)
+ self.assertIn((ANY, thread_expected_stack_trace), stack_trace)
+
+ # Check that the main thread stack trace is in the result
+ frame = ("<module>", script_name, 19)
+ for _, stack in stack_trace:
+ if frame in stack:
+ break
+ else:
+ self.fail("Main thread stack trace not found in result")
@skip_if_not_supported
@unittest.skipIf(
@@ -700,13 +729,28 @@ class TestGetStackTrace(unittest.TestCase):
)
def test_self_trace(self):
stack_trace = get_stack_trace(os.getpid())
+ # Is possible that there are more threads, so we check that the
+ # expected stack traces are in the result (looking at you Windows!)
+ this_tread_stack = None
+ for thread_id, stack in stack_trace:
+ if thread_id == threading.get_native_id():
+ this_tread_stack = stack
+ break
+ self.assertIsNotNone(this_tread_stack)
self.assertEqual(
- stack_trace[0],
- (
- "TestGetStackTrace.test_self_trace",
- __file__,
- self.test_self_trace.__code__.co_firstlineno + 6,
- ),
+ stack[:2],
+ [
+ (
+ "get_stack_trace",
+ __file__,
+ get_stack_trace.__code__.co_firstlineno + 2,
+ ),
+ (
+ "TestGetStackTrace.test_self_trace",
+ __file__,
+ self.test_self_trace.__code__.co_firstlineno + 6,
+ ),
+ ]
)