diff options
Diffstat (limited to 'Lib/test/test_sample_profiler.py')
-rw-r--r-- | Lib/test/test_sample_profiler.py | 1877 |
1 files changed, 1877 insertions, 0 deletions
diff --git a/Lib/test/test_sample_profiler.py b/Lib/test/test_sample_profiler.py new file mode 100644 index 00000000000..2c7fa1cba71 --- /dev/null +++ b/Lib/test/test_sample_profiler.py @@ -0,0 +1,1877 @@ +"""Tests for the sampling profiler (profile.sample).""" + +import contextlib +import io +import marshal +import os +import socket +import subprocess +import sys +import tempfile +import unittest +from unittest import mock + +from profile.pstats_collector import PstatsCollector +from profile.stack_collector import ( + CollapsedStackCollector, +) + +from test.support.os_helper import unlink +from test.support import force_not_colorized_test_class, SHORT_TIMEOUT +from test.support.socket_helper import find_unused_port +from test.support import requires_subprocess + +PROCESS_VM_READV_SUPPORTED = False + +try: + from _remote_debugging import PROCESS_VM_READV_SUPPORTED + import _remote_debugging +except ImportError: + raise unittest.SkipTest( + "Test only runs when _remote_debugging is available" + ) +else: + import profile.sample + from profile.sample import SampleProfiler + + + +class MockFrameInfo: + """Mock FrameInfo for testing since the real one isn't accessible.""" + + def __init__(self, filename, lineno, funcname): + self.filename = filename + self.lineno = lineno + self.funcname = funcname + + def __repr__(self): + return f"MockFrameInfo(filename='{self.filename}', lineno={self.lineno}, funcname='{self.funcname}')" + + +skip_if_not_supported = unittest.skipIf( + ( + sys.platform != "darwin" + and sys.platform != "linux" + and sys.platform != "win32" + ), + "Test only runs on Linux, Windows and MacOS", +) + + +@contextlib.contextmanager +def test_subprocess(script): + # Find an unused port for socket communication + port = find_unused_port() + + # Inject socket connection code at the beginning of the script + socket_code = f''' +import socket +_test_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +_test_sock.connect(('localhost', {port})) +_test_sock.sendall(b"ready") +''' + + # Combine socket code with user script + full_script = socket_code + script + + # Create server socket to wait for process to be ready + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server_socket.bind(("localhost", port)) + server_socket.settimeout(SHORT_TIMEOUT) + server_socket.listen(1) + + proc = subprocess.Popen( + [sys.executable, "-c", full_script], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + client_socket = None + try: + # Wait for process to connect and send ready signal + client_socket, _ = server_socket.accept() + server_socket.close() + response = client_socket.recv(1024) + if response != b"ready": + raise RuntimeError(f"Unexpected response from subprocess: {response}") + + yield proc + finally: + if client_socket is not None: + client_socket.close() + if proc.poll() is None: + proc.kill() + proc.wait() + + +def close_and_unlink(file): + file.close() + unlink(file.name) + + +class TestSampleProfilerComponents(unittest.TestCase): + """Unit tests for individual profiler components.""" + + def test_mock_frame_info_with_empty_and_unicode_values(self): + """Test MockFrameInfo handles empty strings, unicode characters, and very long names correctly.""" + # Test with empty strings + frame = MockFrameInfo("", 0, "") + self.assertEqual(frame.filename, "") + self.assertEqual(frame.lineno, 0) + self.assertEqual(frame.funcname, "") + self.assertIn("filename=''", repr(frame)) + + # Test with unicode characters + frame = MockFrameInfo("文件.py", 42, "函数名") + self.assertEqual(frame.filename, "文件.py") + self.assertEqual(frame.funcname, "函数名") + + # Test with very long names + long_filename = "x" * 1000 + ".py" + long_funcname = "func_" + "x" * 1000 + frame = MockFrameInfo(long_filename, 999999, long_funcname) + self.assertEqual(frame.filename, long_filename) + self.assertEqual(frame.lineno, 999999) + self.assertEqual(frame.funcname, long_funcname) + + def test_pstats_collector_with_extreme_intervals_and_empty_data(self): + """Test PstatsCollector handles zero/large intervals, empty frames, None thread IDs, and duplicate frames.""" + # Test with zero interval + collector = PstatsCollector(sample_interval_usec=0) + self.assertEqual(collector.sample_interval_usec, 0) + + # Test with very large interval + collector = PstatsCollector(sample_interval_usec=1000000000) + self.assertEqual(collector.sample_interval_usec, 1000000000) + + # Test collecting empty frames list + collector = PstatsCollector(sample_interval_usec=1000) + collector.collect([]) + self.assertEqual(len(collector.result), 0) + + # Test collecting frames with None thread id + test_frames = [(None, [MockFrameInfo("file.py", 10, "func")])] + collector.collect(test_frames) + # Should still process the frames + self.assertEqual(len(collector.result), 1) + + # Test collecting duplicate frames in same sample + test_frames = [ + ( + 1, + [ + MockFrameInfo("file.py", 10, "func1"), + MockFrameInfo("file.py", 10, "func1"), # Duplicate + ], + ) + ] + collector = PstatsCollector(sample_interval_usec=1000) + collector.collect(test_frames) + # Should count both occurrences + self.assertEqual( + collector.result[("file.py", 10, "func1")]["cumulative_calls"], 2 + ) + + def test_pstats_collector_single_frame_stacks(self): + """Test PstatsCollector with single-frame call stacks to trigger len(frames) <= 1 branch.""" + collector = PstatsCollector(sample_interval_usec=1000) + + # Test with exactly one frame (should trigger the <= 1 condition) + single_frame = [(1, [MockFrameInfo("single.py", 10, "single_func")])] + collector.collect(single_frame) + + # Should record the single frame with inline call + self.assertEqual(len(collector.result), 1) + single_key = ("single.py", 10, "single_func") + self.assertIn(single_key, collector.result) + self.assertEqual(collector.result[single_key]["direct_calls"], 1) + self.assertEqual(collector.result[single_key]["cumulative_calls"], 1) + + # Test with empty frames (should also trigger <= 1 condition) + empty_frames = [(1, [])] + collector.collect(empty_frames) + + # Should not add any new entries + self.assertEqual( + len(collector.result), 1 + ) # Still just the single frame + + # Test mixed single and multi-frame stacks + mixed_frames = [ + ( + 1, + [MockFrameInfo("single2.py", 20, "single_func2")], + ), # Single frame + ( + 2, + [ # Multi-frame stack + MockFrameInfo("multi.py", 30, "multi_func1"), + MockFrameInfo("multi.py", 40, "multi_func2"), + ], + ), + ] + collector.collect(mixed_frames) + + # Should have recorded all functions + self.assertEqual( + len(collector.result), 4 + ) # single + single2 + multi1 + multi2 + + # Verify single frame handling + single2_key = ("single2.py", 20, "single_func2") + self.assertIn(single2_key, collector.result) + self.assertEqual(collector.result[single2_key]["direct_calls"], 1) + self.assertEqual(collector.result[single2_key]["cumulative_calls"], 1) + + # Verify multi-frame handling still works + multi1_key = ("multi.py", 30, "multi_func1") + multi2_key = ("multi.py", 40, "multi_func2") + self.assertIn(multi1_key, collector.result) + self.assertIn(multi2_key, collector.result) + self.assertEqual(collector.result[multi1_key]["direct_calls"], 1) + self.assertEqual( + collector.result[multi2_key]["cumulative_calls"], 1 + ) # Called from multi1 + + def test_collapsed_stack_collector_with_empty_and_deep_stacks(self): + """Test CollapsedStackCollector handles empty frames, single-frame stacks, and very deep call stacks.""" + collector = CollapsedStackCollector() + + # Test with empty frames + collector.collect([]) + self.assertEqual(len(collector.call_trees), 0) + + # Test with single frame stack + test_frames = [(1, [("file.py", 10, "func")])] + collector.collect(test_frames) + self.assertEqual(len(collector.call_trees), 1) + self.assertEqual(collector.call_trees[0], [("file.py", 10, "func")]) + + # Test with very deep stack + deep_stack = [(f"file{i}.py", i, f"func{i}") for i in range(100)] + test_frames = [(1, deep_stack)] + collector = CollapsedStackCollector() + collector.collect(test_frames) + self.assertEqual(len(collector.call_trees[0]), 100) + # Check it's properly reversed + self.assertEqual( + collector.call_trees[0][0], ("file99.py", 99, "func99") + ) + self.assertEqual(collector.call_trees[0][-1], ("file0.py", 0, "func0")) + + def test_pstats_collector_basic(self): + """Test basic PstatsCollector functionality.""" + collector = PstatsCollector(sample_interval_usec=1000) + + # Test empty state + self.assertEqual(len(collector.result), 0) + self.assertEqual(len(collector.stats), 0) + + # Test collecting sample data + test_frames = [ + ( + 1, + [ + MockFrameInfo("file.py", 10, "func1"), + MockFrameInfo("file.py", 20, "func2"), + ], + ) + ] + collector.collect(test_frames) + + # Should have recorded calls for both functions + self.assertEqual(len(collector.result), 2) + self.assertIn(("file.py", 10, "func1"), collector.result) + self.assertIn(("file.py", 20, "func2"), collector.result) + + # Top-level function should have direct call + self.assertEqual( + collector.result[("file.py", 10, "func1")]["direct_calls"], 1 + ) + self.assertEqual( + collector.result[("file.py", 10, "func1")]["cumulative_calls"], 1 + ) + + # Calling function should have cumulative call but no direct calls + self.assertEqual( + collector.result[("file.py", 20, "func2")]["cumulative_calls"], 1 + ) + self.assertEqual( + collector.result[("file.py", 20, "func2")]["direct_calls"], 0 + ) + + def test_pstats_collector_create_stats(self): + """Test PstatsCollector stats creation.""" + collector = PstatsCollector( + sample_interval_usec=1000000 + ) # 1 second intervals + + test_frames = [ + ( + 1, + [ + MockFrameInfo("file.py", 10, "func1"), + MockFrameInfo("file.py", 20, "func2"), + ], + ) + ] + collector.collect(test_frames) + collector.collect(test_frames) # Collect twice + + collector.create_stats() + + # Check stats format: (direct_calls, cumulative_calls, tt, ct, callers) + func1_stats = collector.stats[("file.py", 10, "func1")] + self.assertEqual(func1_stats[0], 2) # direct_calls (top of stack) + self.assertEqual(func1_stats[1], 2) # cumulative_calls + self.assertEqual( + func1_stats[2], 2.0 + ) # tt (total time - 2 samples * 1 sec) + self.assertEqual(func1_stats[3], 2.0) # ct (cumulative time) + + func2_stats = collector.stats[("file.py", 20, "func2")] + self.assertEqual( + func2_stats[0], 0 + ) # direct_calls (never top of stack) + self.assertEqual( + func2_stats[1], 2 + ) # cumulative_calls (appears in stack) + self.assertEqual(func2_stats[2], 0.0) # tt (no direct calls) + self.assertEqual(func2_stats[3], 2.0) # ct (cumulative time) + + def test_collapsed_stack_collector_basic(self): + collector = CollapsedStackCollector() + + # Test empty state + self.assertEqual(len(collector.call_trees), 0) + self.assertEqual(len(collector.function_samples), 0) + + # Test collecting sample data + test_frames = [ + (1, [("file.py", 10, "func1"), ("file.py", 20, "func2")]) + ] + collector.collect(test_frames) + + # Should store call tree (reversed) + self.assertEqual(len(collector.call_trees), 1) + expected_tree = [("file.py", 20, "func2"), ("file.py", 10, "func1")] + self.assertEqual(collector.call_trees[0], expected_tree) + + # Should count function samples + self.assertEqual( + collector.function_samples[("file.py", 10, "func1")], 1 + ) + self.assertEqual( + collector.function_samples[("file.py", 20, "func2")], 1 + ) + + def test_collapsed_stack_collector_export(self): + collapsed_out = tempfile.NamedTemporaryFile(delete=False) + self.addCleanup(close_and_unlink, collapsed_out) + + collector = CollapsedStackCollector() + + test_frames1 = [ + (1, [("file.py", 10, "func1"), ("file.py", 20, "func2")]) + ] + test_frames2 = [ + (1, [("file.py", 10, "func1"), ("file.py", 20, "func2")]) + ] # Same stack + test_frames3 = [(1, [("other.py", 5, "other_func")])] + + collector.collect(test_frames1) + collector.collect(test_frames2) + collector.collect(test_frames3) + + collector.export(collapsed_out.name) + # Check file contents + with open(collapsed_out.name, "r") as f: + content = f.read() + + lines = content.strip().split("\n") + self.assertEqual(len(lines), 2) # Two unique stacks + + # Check collapsed format: file:func:line;file:func:line count + stack1_expected = "file.py:func2:20;file.py:func1:10 2" + stack2_expected = "other.py:other_func:5 1" + + self.assertIn(stack1_expected, lines) + self.assertIn(stack2_expected, lines) + + def test_pstats_collector_export(self): + collector = PstatsCollector( + sample_interval_usec=1000000 + ) # 1 second intervals + + test_frames1 = [ + ( + 1, + [ + MockFrameInfo("file.py", 10, "func1"), + MockFrameInfo("file.py", 20, "func2"), + ], + ) + ] + test_frames2 = [ + ( + 1, + [ + MockFrameInfo("file.py", 10, "func1"), + MockFrameInfo("file.py", 20, "func2"), + ], + ) + ] # Same stack + test_frames3 = [(1, [MockFrameInfo("other.py", 5, "other_func")])] + + collector.collect(test_frames1) + collector.collect(test_frames2) + collector.collect(test_frames3) + + pstats_out = tempfile.NamedTemporaryFile( + suffix=".pstats", delete=False + ) + self.addCleanup(close_and_unlink, pstats_out) + collector.export(pstats_out.name) + + # Check file can be loaded with marshal + with open(pstats_out.name, "rb") as f: + stats_data = marshal.load(f) + + # Should be a dictionary with the sampled marker + self.assertIsInstance(stats_data, dict) + self.assertIn(("__sampled__",), stats_data) + self.assertTrue(stats_data[("__sampled__",)]) + + # Should have function data + function_entries = [ + k for k in stats_data.keys() if k != ("__sampled__",) + ] + self.assertGreater(len(function_entries), 0) + + # Check specific function stats format: (cc, nc, tt, ct, callers) + func1_key = ("file.py", 10, "func1") + func2_key = ("file.py", 20, "func2") + other_key = ("other.py", 5, "other_func") + + self.assertIn(func1_key, stats_data) + self.assertIn(func2_key, stats_data) + self.assertIn(other_key, stats_data) + + # Check func1 stats (should have 2 samples) + func1_stats = stats_data[func1_key] + self.assertEqual(func1_stats[0], 2) # total_calls + self.assertEqual(func1_stats[1], 2) # nc (non-recursive calls) + self.assertEqual(func1_stats[2], 2.0) # tt (total time) + self.assertEqual(func1_stats[3], 2.0) # ct (cumulative time) + + +class TestSampleProfiler(unittest.TestCase): + """Test the SampleProfiler class.""" + + def test_sample_profiler_initialization(self): + """Test SampleProfiler initialization with various parameters.""" + from profile.sample import SampleProfiler + + # Mock RemoteUnwinder to avoid permission issues + with mock.patch( + "_remote_debugging.RemoteUnwinder" + ) as mock_unwinder_class: + mock_unwinder_class.return_value = mock.MagicMock() + + # Test basic initialization + profiler = SampleProfiler( + pid=12345, sample_interval_usec=1000, all_threads=False + ) + self.assertEqual(profiler.pid, 12345) + self.assertEqual(profiler.sample_interval_usec, 1000) + self.assertEqual(profiler.all_threads, False) + + # Test with all_threads=True + profiler = SampleProfiler( + pid=54321, sample_interval_usec=5000, all_threads=True + ) + self.assertEqual(profiler.pid, 54321) + self.assertEqual(profiler.sample_interval_usec, 5000) + self.assertEqual(profiler.all_threads, True) + + def test_sample_profiler_sample_method_timing(self): + """Test that the sample method respects duration and handles timing correctly.""" + from profile.sample import SampleProfiler + + # Mock the unwinder to avoid needing a real process + mock_unwinder = mock.MagicMock() + mock_unwinder.get_stack_trace.return_value = [ + ( + 1, + [ + mock.MagicMock( + filename="test.py", lineno=10, funcname="test_func" + ) + ], + ) + ] + + with mock.patch( + "_remote_debugging.RemoteUnwinder" + ) as mock_unwinder_class: + mock_unwinder_class.return_value = mock_unwinder + + profiler = SampleProfiler( + pid=12345, sample_interval_usec=100000, all_threads=False + ) # 100ms interval + + # Mock collector + mock_collector = mock.MagicMock() + + # Mock time to control the sampling loop + start_time = 1000.0 + times = [ + start_time + i * 0.1 for i in range(12) + ] # 0, 0.1, 0.2, ..., 1.1 seconds + + with mock.patch("time.perf_counter", side_effect=times): + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + profiler.sample(mock_collector, duration_sec=1) + + result = output.getvalue() + + # Should have captured approximately 10 samples (1 second / 0.1 second interval) + self.assertIn("Captured", result) + self.assertIn("samples", result) + + # Verify collector was called multiple times + self.assertGreaterEqual(mock_collector.collect.call_count, 5) + self.assertLessEqual(mock_collector.collect.call_count, 11) + + def test_sample_profiler_error_handling(self): + """Test that the sample method handles errors gracefully.""" + from profile.sample import SampleProfiler + + # Mock unwinder that raises errors + mock_unwinder = mock.MagicMock() + error_sequence = [ + RuntimeError("Process died"), + [ + ( + 1, + [ + mock.MagicMock( + filename="test.py", lineno=10, funcname="test_func" + ) + ], + ) + ], + UnicodeDecodeError("utf-8", b"", 0, 1, "invalid"), + [ + ( + 1, + [ + mock.MagicMock( + filename="test.py", + lineno=20, + funcname="test_func2", + ) + ], + ) + ], + OSError("Permission denied"), + ] + mock_unwinder.get_stack_trace.side_effect = error_sequence + + with mock.patch( + "_remote_debugging.RemoteUnwinder" + ) as mock_unwinder_class: + mock_unwinder_class.return_value = mock_unwinder + + profiler = SampleProfiler( + pid=12345, sample_interval_usec=10000, all_threads=False + ) + + mock_collector = mock.MagicMock() + + # Control timing to run exactly 5 samples + times = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06] + + with mock.patch("time.perf_counter", side_effect=times): + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + profiler.sample(mock_collector, duration_sec=0.05) + + result = output.getvalue() + + # Should report error rate + self.assertIn("Error rate:", result) + self.assertIn("%", result) + + # Collector should have been called only for successful samples (should be > 0) + self.assertGreater(mock_collector.collect.call_count, 0) + self.assertLessEqual(mock_collector.collect.call_count, 3) + + def test_sample_profiler_missed_samples_warning(self): + """Test that the profiler warns about missed samples when sampling is too slow.""" + from profile.sample import SampleProfiler + + mock_unwinder = mock.MagicMock() + mock_unwinder.get_stack_trace.return_value = [ + ( + 1, + [ + mock.MagicMock( + filename="test.py", lineno=10, funcname="test_func" + ) + ], + ) + ] + + with mock.patch( + "_remote_debugging.RemoteUnwinder" + ) as mock_unwinder_class: + mock_unwinder_class.return_value = mock_unwinder + + # Use very short interval that we'll miss + profiler = SampleProfiler( + pid=12345, sample_interval_usec=1000, all_threads=False + ) # 1ms interval + + mock_collector = mock.MagicMock() + + # Simulate slow sampling where we miss many samples + times = [ + 0.0, + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.6, + 0.7, + ] # Extra time points to avoid StopIteration + + with mock.patch("time.perf_counter", side_effect=times): + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + profiler.sample(mock_collector, duration_sec=0.5) + + result = output.getvalue() + + # Should warn about missed samples + self.assertIn("Warning: missed", result) + self.assertIn("samples from the expected total", result) + + +@force_not_colorized_test_class +class TestPrintSampledStats(unittest.TestCase): + """Test the print_sampled_stats function.""" + + def setUp(self): + """Set up test data.""" + # Mock stats data + self.mock_stats = mock.MagicMock() + self.mock_stats.stats = { + ("file1.py", 10, "func1"): ( + 100, + 100, + 0.5, + 0.5, + {}, + ), # cc, nc, tt, ct, callers + ("file2.py", 20, "func2"): (50, 50, 0.25, 0.3, {}), + ("file3.py", 30, "func3"): (200, 200, 1.5, 2.0, {}), + ("file4.py", 40, "func4"): ( + 10, + 10, + 0.001, + 0.001, + {}, + ), # millisecond range + ("file5.py", 50, "func5"): ( + 5, + 5, + 0.000001, + 0.000002, + {}, + ), # microsecond range + } + + def test_print_sampled_stats_basic(self): + """Test basic print_sampled_stats functionality.""" + from profile.sample import print_sampled_stats + + # Capture output + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + print_sampled_stats(self.mock_stats, sample_interval_usec=100) + + result = output.getvalue() + + # Check header is present + self.assertIn("Profile Stats:", result) + self.assertIn("nsamples", result) + self.assertIn("tottime", result) + self.assertIn("cumtime", result) + + # Check functions are present + self.assertIn("func1", result) + self.assertIn("func2", result) + self.assertIn("func3", result) + + def test_print_sampled_stats_sorting(self): + """Test different sorting options.""" + from profile.sample import print_sampled_stats + + # Test sort by calls + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + print_sampled_stats( + self.mock_stats, sort=0, sample_interval_usec=100 + ) + + result = output.getvalue() + lines = result.strip().split("\n") + + # Find the data lines (skip header) + data_lines = [l for l in lines if "file" in l and ".py" in l] + # func3 should be first (200 calls) + self.assertIn("func3", data_lines[0]) + + # Test sort by time + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + print_sampled_stats( + self.mock_stats, sort=1, sample_interval_usec=100 + ) + + result = output.getvalue() + lines = result.strip().split("\n") + + data_lines = [l for l in lines if "file" in l and ".py" in l] + # func3 should be first (1.5s time) + self.assertIn("func3", data_lines[0]) + + def test_print_sampled_stats_limit(self): + """Test limiting output rows.""" + from profile.sample import print_sampled_stats + + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + print_sampled_stats( + self.mock_stats, limit=2, sample_interval_usec=100 + ) + + result = output.getvalue() + + # Count function entries in the main stats section (not in summary) + lines = result.split("\n") + # Find where the main stats section ends (before summary) + main_section_lines = [] + for line in lines: + if "Summary of Interesting Functions:" in line: + break + main_section_lines.append(line) + + # Count function entries only in main section + func_count = sum( + 1 + for line in main_section_lines + if "func" in line and ".py" in line + ) + self.assertEqual(func_count, 2) + + def test_print_sampled_stats_time_units(self): + """Test proper time unit selection.""" + from profile.sample import print_sampled_stats + + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + print_sampled_stats(self.mock_stats, sample_interval_usec=100) + + result = output.getvalue() + + # Should use seconds for the header since max time is > 1s + self.assertIn("tottime (s)", result) + self.assertIn("cumtime (s)", result) + + # Test with only microsecond-range times + micro_stats = mock.MagicMock() + micro_stats.stats = { + ("file1.py", 10, "func1"): (100, 100, 0.000005, 0.000010, {}), + } + + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + print_sampled_stats(micro_stats, sample_interval_usec=100) + + result = output.getvalue() + + # Should use microseconds + self.assertIn("tottime (μs)", result) + self.assertIn("cumtime (μs)", result) + + def test_print_sampled_stats_summary(self): + """Test summary section generation.""" + from profile.sample import print_sampled_stats + + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + print_sampled_stats( + self.mock_stats, + show_summary=True, + sample_interval_usec=100, + ) + + result = output.getvalue() + + # Check summary sections are present + self.assertIn("Summary of Interesting Functions:", result) + self.assertIn( + "Functions with Highest Direct/Cumulative Ratio (Hot Spots):", + result, + ) + self.assertIn( + "Functions with Highest Call Frequency (Indirect Calls):", result + ) + self.assertIn( + "Functions with Highest Call Magnification (Cumulative/Direct):", + result, + ) + + def test_print_sampled_stats_no_summary(self): + """Test disabling summary output.""" + from profile.sample import print_sampled_stats + + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + print_sampled_stats( + self.mock_stats, + show_summary=False, + sample_interval_usec=100, + ) + + result = output.getvalue() + + # Summary should not be present + self.assertNotIn("Summary of Interesting Functions:", result) + + def test_print_sampled_stats_empty_stats(self): + """Test with empty stats.""" + from profile.sample import print_sampled_stats + + empty_stats = mock.MagicMock() + empty_stats.stats = {} + + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + print_sampled_stats(empty_stats, sample_interval_usec=100) + + result = output.getvalue() + + # Should still print header + self.assertIn("Profile Stats:", result) + + def test_print_sampled_stats_sample_percentage_sorting(self): + """Test sample percentage sorting options.""" + from profile.sample import print_sampled_stats + + # Add a function with high sample percentage (more direct calls than func3's 200) + self.mock_stats.stats[("expensive.py", 60, "expensive_func")] = ( + 300, # direct calls (higher than func3's 200) + 300, # cumulative calls + 1.0, # total time + 1.0, # cumulative time + {}, + ) + + # Test sort by sample percentage + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + print_sampled_stats( + self.mock_stats, sort=3, sample_interval_usec=100 + ) # sample percentage + + result = output.getvalue() + lines = result.strip().split("\n") + + data_lines = [l for l in lines if ".py" in l and "func" in l] + # expensive_func should be first (highest sample percentage) + self.assertIn("expensive_func", data_lines[0]) + + def test_print_sampled_stats_with_recursive_calls(self): + """Test print_sampled_stats with recursive calls where nc != cc.""" + from profile.sample import print_sampled_stats + + # Create stats with recursive calls (nc != cc) + recursive_stats = mock.MagicMock() + recursive_stats.stats = { + # (direct_calls, cumulative_calls, tt, ct, callers) - recursive function + ("recursive.py", 10, "factorial"): ( + 5, # direct_calls + 10, # cumulative_calls (appears more times in stack due to recursion) + 0.5, + 0.6, + {}, + ), + ("normal.py", 20, "normal_func"): ( + 3, # direct_calls + 3, # cumulative_calls (same as direct for non-recursive) + 0.2, + 0.2, + {}, + ), + } + + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + print_sampled_stats(recursive_stats, sample_interval_usec=100) + + result = output.getvalue() + + # Should display recursive calls as "5/10" format + self.assertIn("5/10", result) # nc/cc format for recursive calls + self.assertIn("3", result) # just nc for non-recursive calls + self.assertIn("factorial", result) + self.assertIn("normal_func", result) + + def test_print_sampled_stats_with_zero_call_counts(self): + """Test print_sampled_stats with zero call counts to trigger division protection.""" + from profile.sample import print_sampled_stats + + # Create stats with zero call counts + zero_stats = mock.MagicMock() + zero_stats.stats = { + ("file.py", 10, "zero_calls"): (0, 0, 0.0, 0.0, {}), # Zero calls + ("file.py", 20, "normal_func"): ( + 5, + 5, + 0.1, + 0.1, + {}, + ), # Normal function + } + + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + print_sampled_stats(zero_stats, sample_interval_usec=100) + + result = output.getvalue() + + # Should handle zero call counts gracefully + self.assertIn("zero_calls", result) + self.assertIn("zero_calls", result) + self.assertIn("normal_func", result) + + def test_print_sampled_stats_sort_by_name(self): + """Test sort by function name option.""" + from profile.sample import print_sampled_stats + + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + print_sampled_stats( + self.mock_stats, sort=-1, sample_interval_usec=100 + ) # sort by name + + result = output.getvalue() + lines = result.strip().split("\n") + + # Find the data lines (skip header and summary) + # Data lines start with whitespace and numbers, and contain filename:lineno(function) + data_lines = [] + for line in lines: + # Skip header lines and summary sections + if ( + line.startswith(" ") + and "(" in line + and ")" in line + and not line.startswith( + " 1." + ) # Skip summary lines that start with times + and not line.startswith( + " 0." + ) # Skip summary lines that start with times + and not "per call" in line # Skip summary lines + and not "calls" in line # Skip summary lines + and not "total time" in line # Skip summary lines + and not "cumulative time" in line + ): # Skip summary lines + data_lines.append(line) + + # Extract just the function names for comparison + func_names = [] + import re + + for line in data_lines: + # Function name is between the last ( and ), accounting for ANSI color codes + match = re.search(r"\(([^)]+)\)$", line) + if match: + func_name = match.group(1) + # Remove ANSI color codes + func_name = re.sub(r"\x1b\[[0-9;]*m", "", func_name) + func_names.append(func_name) + + # Verify we extracted function names and they are sorted + self.assertGreater( + len(func_names), 0, "Should have extracted some function names" + ) + self.assertEqual( + func_names, + sorted(func_names), + f"Function names {func_names} should be sorted alphabetically", + ) + + def test_print_sampled_stats_with_zero_time_functions(self): + """Test summary sections with functions that have zero time.""" + from profile.sample import print_sampled_stats + + # Create stats with zero-time functions + zero_time_stats = mock.MagicMock() + zero_time_stats.stats = { + ("file1.py", 10, "zero_time_func"): ( + 5, + 5, + 0.0, + 0.0, + {}, + ), # Zero time + ("file2.py", 20, "normal_func"): ( + 3, + 3, + 0.1, + 0.1, + {}, + ), # Normal time + } + + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + print_sampled_stats( + zero_time_stats, + show_summary=True, + sample_interval_usec=100, + ) + + result = output.getvalue() + + # Should handle zero-time functions gracefully in summary + self.assertIn("Summary of Interesting Functions:", result) + self.assertIn("zero_time_func", result) + self.assertIn("normal_func", result) + + def test_print_sampled_stats_with_malformed_qualified_names(self): + """Test summary generation with function names that don't contain colons.""" + from profile.sample import print_sampled_stats + + # Create stats with function names that would create malformed qualified names + malformed_stats = mock.MagicMock() + malformed_stats.stats = { + # Function name without clear module separation + ("no_colon_func", 10, "func"): (3, 3, 0.1, 0.1, {}), + ("", 20, "empty_filename_func"): (2, 2, 0.05, 0.05, {}), + ("normal.py", 30, "normal_func"): (5, 5, 0.2, 0.2, {}), + } + + with io.StringIO() as output: + with mock.patch("sys.stdout", output): + print_sampled_stats( + malformed_stats, + show_summary=True, + sample_interval_usec=100, + ) + + result = output.getvalue() + + # Should handle malformed names gracefully in summary aggregation + self.assertIn("Summary of Interesting Functions:", result) + # All function names should appear somewhere in the output + self.assertIn("func", result) + self.assertIn("empty_filename_func", result) + self.assertIn("normal_func", result) + + def test_print_sampled_stats_with_recursive_call_stats_creation(self): + """Test create_stats with recursive call data to trigger total_rec_calls branch.""" + collector = PstatsCollector(sample_interval_usec=1000000) # 1 second + + # Simulate recursive function data where total_rec_calls would be set + # We need to manually manipulate the collector result to test this branch + collector.result = { + ("recursive.py", 10, "factorial"): { + "total_rec_calls": 3, # Non-zero recursive calls + "direct_calls": 5, + "cumulative_calls": 10, + }, + ("normal.py", 20, "normal_func"): { + "total_rec_calls": 0, # Zero recursive calls + "direct_calls": 2, + "cumulative_calls": 5, + }, + } + + collector.create_stats() + + # Check that recursive calls are handled differently from non-recursive + factorial_stats = collector.stats[("recursive.py", 10, "factorial")] + normal_stats = collector.stats[("normal.py", 20, "normal_func")] + + # factorial should use cumulative_calls (10) as nc + self.assertEqual( + factorial_stats[1], 10 + ) # nc should be cumulative_calls + self.assertEqual(factorial_stats[0], 5) # cc should be direct_calls + + # normal_func should use cumulative_calls as nc + self.assertEqual(normal_stats[1], 5) # nc should be cumulative_calls + self.assertEqual(normal_stats[0], 2) # cc should be direct_calls + + +@skip_if_not_supported +@unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", +) +class TestRecursiveFunctionProfiling(unittest.TestCase): + """Test profiling of recursive functions and complex call patterns.""" + + def test_recursive_function_call_counting(self): + """Test that recursive function calls are counted correctly.""" + collector = PstatsCollector(sample_interval_usec=1000) + + # Simulate a recursive call pattern: fibonacci(5) calling itself + recursive_frames = [ + ( + 1, + [ # First sample: deep in recursion + MockFrameInfo("fib.py", 10, "fibonacci"), + MockFrameInfo("fib.py", 10, "fibonacci"), # recursive call + MockFrameInfo( + "fib.py", 10, "fibonacci" + ), # deeper recursion + MockFrameInfo("fib.py", 10, "fibonacci"), # even deeper + MockFrameInfo("main.py", 5, "main"), # main caller + ], + ), + ( + 1, + [ # Second sample: different recursion depth + MockFrameInfo("fib.py", 10, "fibonacci"), + MockFrameInfo("fib.py", 10, "fibonacci"), # recursive call + MockFrameInfo("main.py", 5, "main"), # main caller + ], + ), + ( + 1, + [ # Third sample: back to deeper recursion + MockFrameInfo("fib.py", 10, "fibonacci"), + MockFrameInfo("fib.py", 10, "fibonacci"), + MockFrameInfo("fib.py", 10, "fibonacci"), + MockFrameInfo("main.py", 5, "main"), + ], + ), + ] + + for frames in recursive_frames: + collector.collect([frames]) + + collector.create_stats() + + # Check that recursive calls are counted properly + fib_key = ("fib.py", 10, "fibonacci") + main_key = ("main.py", 5, "main") + + self.assertIn(fib_key, collector.stats) + self.assertIn(main_key, collector.stats) + + # Fibonacci should have many calls due to recursion + fib_stats = collector.stats[fib_key] + direct_calls, cumulative_calls, tt, ct, callers = fib_stats + + # Should have recorded multiple calls (9 total appearances in samples) + self.assertEqual(cumulative_calls, 9) + self.assertGreater(tt, 0) # Should have some total time + self.assertGreater(ct, 0) # Should have some cumulative time + + # Main should have fewer calls + main_stats = collector.stats[main_key] + main_direct_calls, main_cumulative_calls = main_stats[0], main_stats[1] + self.assertEqual(main_direct_calls, 0) # Never directly executing + self.assertEqual(main_cumulative_calls, 3) # Appears in all 3 samples + + def test_nested_function_hierarchy(self): + """Test profiling of deeply nested function calls.""" + collector = PstatsCollector(sample_interval_usec=1000) + + # Simulate a deep call hierarchy + deep_call_frames = [ + ( + 1, + [ + MockFrameInfo("level1.py", 10, "level1_func"), + MockFrameInfo("level2.py", 20, "level2_func"), + MockFrameInfo("level3.py", 30, "level3_func"), + MockFrameInfo("level4.py", 40, "level4_func"), + MockFrameInfo("level5.py", 50, "level5_func"), + MockFrameInfo("main.py", 5, "main"), + ], + ), + ( + 1, + [ # Same hierarchy sampled again + MockFrameInfo("level1.py", 10, "level1_func"), + MockFrameInfo("level2.py", 20, "level2_func"), + MockFrameInfo("level3.py", 30, "level3_func"), + MockFrameInfo("level4.py", 40, "level4_func"), + MockFrameInfo("level5.py", 50, "level5_func"), + MockFrameInfo("main.py", 5, "main"), + ], + ), + ] + + for frames in deep_call_frames: + collector.collect([frames]) + + collector.create_stats() + + # All levels should be recorded + for level in range(1, 6): + key = (f"level{level}.py", level * 10, f"level{level}_func") + self.assertIn(key, collector.stats) + + stats = collector.stats[key] + direct_calls, cumulative_calls, tt, ct, callers = stats + + # Each level should appear in stack twice (2 samples) + self.assertEqual(cumulative_calls, 2) + + # Only level1 (deepest) should have direct calls + if level == 1: + self.assertEqual(direct_calls, 2) + else: + self.assertEqual(direct_calls, 0) + + # Deeper levels should have lower cumulative time than higher levels + # (since they don't include time from functions they call) + if level == 1: # Deepest level with most time + self.assertGreater(ct, 0) + + def test_alternating_call_patterns(self): + """Test profiling with alternating call patterns.""" + collector = PstatsCollector(sample_interval_usec=1000) + + # Simulate alternating execution paths + pattern_frames = [ + # Pattern A: path through func_a + ( + 1, + [ + MockFrameInfo("module.py", 10, "func_a"), + MockFrameInfo("module.py", 30, "shared_func"), + MockFrameInfo("main.py", 5, "main"), + ], + ), + # Pattern B: path through func_b + ( + 1, + [ + MockFrameInfo("module.py", 20, "func_b"), + MockFrameInfo("module.py", 30, "shared_func"), + MockFrameInfo("main.py", 5, "main"), + ], + ), + # Pattern A again + ( + 1, + [ + MockFrameInfo("module.py", 10, "func_a"), + MockFrameInfo("module.py", 30, "shared_func"), + MockFrameInfo("main.py", 5, "main"), + ], + ), + # Pattern B again + ( + 1, + [ + MockFrameInfo("module.py", 20, "func_b"), + MockFrameInfo("module.py", 30, "shared_func"), + MockFrameInfo("main.py", 5, "main"), + ], + ), + ] + + for frames in pattern_frames: + collector.collect([frames]) + + collector.create_stats() + + # Check that both paths are recorded equally + func_a_key = ("module.py", 10, "func_a") + func_b_key = ("module.py", 20, "func_b") + shared_key = ("module.py", 30, "shared_func") + main_key = ("main.py", 5, "main") + + # func_a and func_b should each be directly executing twice + self.assertEqual(collector.stats[func_a_key][0], 2) # direct_calls + self.assertEqual(collector.stats[func_a_key][1], 2) # cumulative_calls + self.assertEqual(collector.stats[func_b_key][0], 2) # direct_calls + self.assertEqual(collector.stats[func_b_key][1], 2) # cumulative_calls + + # shared_func should appear in all samples (4 times) but never directly executing + self.assertEqual(collector.stats[shared_key][0], 0) # direct_calls + self.assertEqual(collector.stats[shared_key][1], 4) # cumulative_calls + + # main should appear in all samples but never directly executing + self.assertEqual(collector.stats[main_key][0], 0) # direct_calls + self.assertEqual(collector.stats[main_key][1], 4) # cumulative_calls + + def test_collapsed_stack_with_recursion(self): + """Test collapsed stack collector with recursive patterns.""" + collector = CollapsedStackCollector() + + # Recursive call pattern + recursive_frames = [ + ( + 1, + [ + ("factorial.py", 10, "factorial"), + ("factorial.py", 10, "factorial"), # recursive + ("factorial.py", 10, "factorial"), # deeper + ("main.py", 5, "main"), + ], + ), + ( + 1, + [ + ("factorial.py", 10, "factorial"), + ("factorial.py", 10, "factorial"), # different depth + ("main.py", 5, "main"), + ], + ), + ] + + for frames in recursive_frames: + collector.collect([frames]) + + # Should capture both call trees + self.assertEqual(len(collector.call_trees), 2) + + # First tree should be longer (deeper recursion) + tree1 = collector.call_trees[0] + tree2 = collector.call_trees[1] + + # Trees should be different lengths due to different recursion depths + self.assertNotEqual(len(tree1), len(tree2)) + + # Both should contain factorial calls + self.assertTrue(any("factorial" in str(frame) for frame in tree1)) + self.assertTrue(any("factorial" in str(frame) for frame in tree2)) + + # Function samples should count all occurrences + factorial_key = ("factorial.py", 10, "factorial") + main_key = ("main.py", 5, "main") + + # factorial appears 5 times total (3 + 2) + self.assertEqual(collector.function_samples[factorial_key], 5) + # main appears 2 times total + self.assertEqual(collector.function_samples[main_key], 2) + + +@requires_subprocess() +@skip_if_not_supported +class TestSampleProfilerIntegration(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.test_script = ''' +import time +import os + +def slow_fibonacci(n): + """Recursive fibonacci - should show up prominently in profiler.""" + if n <= 1: + return n + return slow_fibonacci(n-1) + slow_fibonacci(n-2) + +def cpu_intensive_work(): + """CPU intensive work that should show in profiler.""" + result = 0 + for i in range(10000): + result += i * i + if i % 100 == 0: + result = result % 1000000 + return result + +def medium_computation(): + """Medium complexity function.""" + result = 0 + for i in range(100): + result += i * i + return result + +def fast_loop(): + """Fast simple loop.""" + total = 0 + for i in range(50): + total += i + return total + +def nested_calls(): + """Test nested function calls.""" + def level1(): + def level2(): + return medium_computation() + return level2() + return level1() + +def main_loop(): + """Main test loop with different execution paths.""" + iteration = 0 + + while True: + iteration += 1 + + # Different execution paths - focus on CPU intensive work + if iteration % 3 == 0: + # Very CPU intensive + result = cpu_intensive_work() + elif iteration % 5 == 0: + # Expensive recursive operation + result = slow_fibonacci(12) + else: + # Medium operation + result = nested_calls() + + # No sleep - keep CPU busy + +if __name__ == "__main__": + main_loop() +''' + + def test_sampling_basic_functionality(self): + with ( + test_subprocess(self.test_script) as proc, + io.StringIO() as captured_output, + mock.patch("sys.stdout", captured_output), + ): + try: + profile.sample.sample( + proc.pid, + duration_sec=2, + sample_interval_usec=1000, # 1ms + show_summary=False, + ) + except PermissionError: + self.skipTest("Insufficient permissions for remote profiling") + + output = captured_output.getvalue() + + # Basic checks on output + self.assertIn("Captured", output) + self.assertIn("samples", output) + self.assertIn("Profile Stats", output) + + # Should see some of our test functions + self.assertIn("slow_fibonacci", output) + + def test_sampling_with_pstats_export(self): + pstats_out = tempfile.NamedTemporaryFile( + suffix=".pstats", delete=False + ) + self.addCleanup(close_and_unlink, pstats_out) + + with test_subprocess(self.test_script) as proc: + # Suppress profiler output when testing file export + with ( + io.StringIO() as captured_output, + mock.patch("sys.stdout", captured_output), + ): + try: + profile.sample.sample( + proc.pid, + duration_sec=1, + filename=pstats_out.name, + sample_interval_usec=10000, + ) + except PermissionError: + self.skipTest( + "Insufficient permissions for remote profiling" + ) + + # Verify file was created and contains valid data + self.assertTrue(os.path.exists(pstats_out.name)) + self.assertGreater(os.path.getsize(pstats_out.name), 0) + + # Try to load the stats file + with open(pstats_out.name, "rb") as f: + stats_data = marshal.load(f) + + # Should be a dictionary with the sampled marker + self.assertIsInstance(stats_data, dict) + self.assertIn(("__sampled__",), stats_data) + self.assertTrue(stats_data[("__sampled__",)]) + + # Should have some function data + function_entries = [ + k for k in stats_data.keys() if k != ("__sampled__",) + ] + self.assertGreater(len(function_entries), 0) + + def test_sampling_with_collapsed_export(self): + collapsed_file = tempfile.NamedTemporaryFile( + suffix=".txt", delete=False + ) + self.addCleanup(close_and_unlink, collapsed_file) + + with ( + test_subprocess(self.test_script) as proc, + ): + # Suppress profiler output when testing file export + with ( + io.StringIO() as captured_output, + mock.patch("sys.stdout", captured_output), + ): + try: + profile.sample.sample( + proc.pid, + duration_sec=1, + filename=collapsed_file.name, + output_format="collapsed", + sample_interval_usec=10000, + ) + except PermissionError: + self.skipTest( + "Insufficient permissions for remote profiling" + ) + + # Verify file was created and contains valid data + self.assertTrue(os.path.exists(collapsed_file.name)) + self.assertGreater(os.path.getsize(collapsed_file.name), 0) + + # Check file format + with open(collapsed_file.name, "r") as f: + content = f.read() + + lines = content.strip().split("\n") + self.assertGreater(len(lines), 0) + + # Each line should have format: stack_trace count + for line in lines: + parts = line.rsplit(" ", 1) + self.assertEqual(len(parts), 2) + + stack_trace, count_str = parts + self.assertGreater(len(stack_trace), 0) + self.assertTrue(count_str.isdigit()) + self.assertGreater(int(count_str), 0) + + # Stack trace should contain semicolon-separated entries + if ";" in stack_trace: + stack_parts = stack_trace.split(";") + for part in stack_parts: + # Each part should be file:function:line + self.assertIn(":", part) + + def test_sampling_all_threads(self): + with ( + test_subprocess(self.test_script) as proc, + # Suppress profiler output + io.StringIO() as captured_output, + mock.patch("sys.stdout", captured_output), + ): + try: + profile.sample.sample( + proc.pid, + duration_sec=1, + all_threads=True, + sample_interval_usec=10000, + show_summary=False, + ) + except PermissionError: + self.skipTest("Insufficient permissions for remote profiling") + + # Just verify that sampling completed without error + # We're not testing output format here + + +@skip_if_not_supported +@unittest.skipIf( + sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, + "Test only runs on Linux with process_vm_readv support", +) +class TestSampleProfilerErrorHandling(unittest.TestCase): + def test_invalid_pid(self): + with self.assertRaises((OSError, RuntimeError)): + profile.sample.sample(-1, duration_sec=1) + + def test_process_dies_during_sampling(self): + with test_subprocess("import time; time.sleep(0.5); exit()") as proc: + with ( + io.StringIO() as captured_output, + mock.patch("sys.stdout", captured_output), + ): + try: + profile.sample.sample( + proc.pid, + duration_sec=2, # Longer than process lifetime + sample_interval_usec=50000, + ) + except PermissionError: + self.skipTest( + "Insufficient permissions for remote profiling" + ) + + output = captured_output.getvalue() + + self.assertIn("Error rate", output) + + def test_invalid_output_format(self): + with self.assertRaises(ValueError): + profile.sample.sample( + os.getpid(), + duration_sec=1, + output_format="invalid_format", + ) + + def test_invalid_output_format_with_mocked_profiler(self): + """Test invalid output format with proper mocking to avoid permission issues.""" + with mock.patch( + "profile.sample.SampleProfiler" + ) as mock_profiler_class: + mock_profiler = mock.MagicMock() + mock_profiler_class.return_value = mock_profiler + + with self.assertRaises(ValueError) as cm: + profile.sample.sample( + 12345, + duration_sec=1, + output_format="unknown_format", + ) + + # Should raise ValueError with the invalid format name + self.assertIn( + "Invalid output format: unknown_format", str(cm.exception) + ) + + def test_is_process_running(self): + with test_subprocess("import time; time.sleep(1000)") as proc: + try: + profiler = SampleProfiler(pid=proc.pid, sample_interval_usec=1000, all_threads=False) + except PermissionError: + self.skipTest( + "Insufficient permissions to read the stack trace" + ) + self.assertTrue(profiler._is_process_running()) + self.assertIsNotNone(profiler.unwinder.get_stack_trace()) + proc.kill() + proc.wait() + # ValueError on MacOS (yeah I know), ProcessLookupError on Linux and Windows + self.assertRaises((ValueError, ProcessLookupError), profiler.unwinder.get_stack_trace) + + # Exit the context manager to ensure the process is terminated + self.assertFalse(profiler._is_process_running()) + self.assertRaises((ValueError, ProcessLookupError), profiler.unwinder.get_stack_trace) + + @unittest.skipUnless(sys.platform == "linux", "Only valid on Linux") + def test_esrch_signal_handling(self): + with test_subprocess("import time; time.sleep(1000)") as proc: + try: + unwinder = _remote_debugging.RemoteUnwinder(proc.pid) + except PermissionError: + self.skipTest( + "Insufficient permissions to read the stack trace" + ) + initial_trace = unwinder.get_stack_trace() + self.assertIsNotNone(initial_trace) + + proc.kill() + + # Wait for the process to die and try to get another trace + proc.wait() + + with self.assertRaises(ProcessLookupError): + unwinder.get_stack_trace() + + + +class TestSampleProfilerCLI(unittest.TestCase): + def test_cli_collapsed_format_validation(self): + """Test that CLI properly validates incompatible options with collapsed format.""" + test_cases = [ + # Test sort options are invalid with collapsed + ( + ["profile.sample", "--collapsed", "--sort-nsamples", "12345"], + "sort", + ), + ( + ["profile.sample", "--collapsed", "--sort-tottime", "12345"], + "sort", + ), + ( + [ + "profile.sample", + "--collapsed", + "--sort-cumtime", + "12345", + ], + "sort", + ), + ( + [ + "profile.sample", + "--collapsed", + "--sort-sample-pct", + "12345", + ], + "sort", + ), + ( + [ + "profile.sample", + "--collapsed", + "--sort-cumul-pct", + "12345", + ], + "sort", + ), + ( + ["profile.sample", "--collapsed", "--sort-name", "12345"], + "sort", + ), + # Test limit option is invalid with collapsed + (["profile.sample", "--collapsed", "-l", "20", "12345"], "limit"), + ( + ["profile.sample", "--collapsed", "--limit", "20", "12345"], + "limit", + ), + # Test no-summary option is invalid with collapsed + ( + ["profile.sample", "--collapsed", "--no-summary", "12345"], + "summary", + ), + ] + + for test_args, expected_error_keyword in test_cases: + with ( + mock.patch("sys.argv", test_args), + mock.patch("sys.stderr", io.StringIO()) as mock_stderr, + self.assertRaises(SystemExit) as cm, + ): + profile.sample.main() + + self.assertEqual(cm.exception.code, 2) # argparse error code + error_msg = mock_stderr.getvalue() + self.assertIn("error:", error_msg) + self.assertIn("--pstats format", error_msg) + + def test_cli_default_collapsed_filename(self): + """Test that collapsed format gets a default filename when not specified.""" + test_args = ["profile.sample", "--collapsed", "12345"] + + with ( + mock.patch("sys.argv", test_args), + mock.patch("profile.sample.sample") as mock_sample, + ): + profile.sample.main() + + # Check that filename was set to default collapsed format + mock_sample.assert_called_once() + call_args = mock_sample.call_args[1] + self.assertEqual(call_args["output_format"], "collapsed") + self.assertEqual(call_args["filename"], "collapsed.12345.txt") + + def test_cli_custom_output_filenames(self): + """Test custom output filenames for both formats.""" + test_cases = [ + ( + ["profile.sample", "--pstats", "-o", "custom.pstats", "12345"], + "custom.pstats", + "pstats", + ), + ( + ["profile.sample", "--collapsed", "-o", "custom.txt", "12345"], + "custom.txt", + "collapsed", + ), + ] + + for test_args, expected_filename, expected_format in test_cases: + with ( + mock.patch("sys.argv", test_args), + mock.patch("profile.sample.sample") as mock_sample, + ): + profile.sample.main() + + mock_sample.assert_called_once() + call_args = mock_sample.call_args[1] + self.assertEqual(call_args["filename"], expected_filename) + self.assertEqual(call_args["output_format"], expected_format) + + def test_cli_missing_required_arguments(self): + """Test that CLI requires PID argument.""" + with ( + mock.patch("sys.argv", ["profile.sample"]), + mock.patch("sys.stderr", io.StringIO()), + ): + with self.assertRaises(SystemExit): + profile.sample.main() + + def test_cli_mutually_exclusive_format_options(self): + """Test that pstats and collapsed options are mutually exclusive.""" + with ( + mock.patch( + "sys.argv", + ["profile.sample", "--pstats", "--collapsed", "12345"], + ), + mock.patch("sys.stderr", io.StringIO()), + ): + with self.assertRaises(SystemExit): + profile.sample.main() + + def test_argument_parsing_basic(self): + test_args = ["profile.sample", "12345"] + + with ( + mock.patch("sys.argv", test_args), + mock.patch("profile.sample.sample") as mock_sample, + ): + profile.sample.main() + + mock_sample.assert_called_once_with( + 12345, + sample_interval_usec=100, + duration_sec=10, + filename=None, + all_threads=False, + limit=15, + sort=2, + show_summary=True, + output_format="pstats", + realtime_stats=False, + ) + + def test_sort_options(self): + sort_options = [ + ("--sort-nsamples", 0), + ("--sort-tottime", 1), + ("--sort-cumtime", 2), + ("--sort-sample-pct", 3), + ("--sort-cumul-pct", 4), + ("--sort-name", -1), + ] + + for option, expected_sort_value in sort_options: + test_args = ["profile.sample", option, "12345"] + + with ( + mock.patch("sys.argv", test_args), + mock.patch("profile.sample.sample") as mock_sample, + ): + profile.sample.main() + + mock_sample.assert_called_once() + call_args = mock_sample.call_args[1] + self.assertEqual( + call_args["sort"], + expected_sort_value, + ) + mock_sample.reset_mock() + + +if __name__ == "__main__": + unittest.main() |