diff --git a/mcoplib/profiler.py b/mcoplib/profiler.py index adfba41..aca076a 100644 --- a/mcoplib/profiler.py +++ b/mcoplib/profiler.py @@ -16,7 +16,16 @@ def _is_profiler_enabled() -> bool: def _timestamp() -> str: - return datetime.now().strftime("%Y%m%dT%H%M%S") + return datetime.now().strftime("%Y%m%dT%H%M%S%f") + + +def _trace_file_path(output_dir, func_name, rank): + safe_name = "".join(ch if ch.isalnum() or ch in "._-" else "_" for ch in func_name)[:128] + filename = ( + f"{safe_name}_trace_rank_{rank}_" + f"{_timestamp()}_pid_{os.getpid()}_tid_{threading.get_ident()}.json" + ) + return os.path.join(output_dir, filename) def _track_handler(prof, output_dir, func_name): @@ -48,8 +57,8 @@ def _track_handler(prof, output_dir, func_name): # If distributed environment is not initialized, use default value 0 rank = 0 - # Export trace to local directory - trace_path = os.path.join(output_dir, f"{func_name}_trace_rank_{rank}.json") + # Export trace to a unique file so repeated benchmark runs do not overwrite evidence. + trace_path = _trace_file_path(output_dir, func_name, rank) prof.export_chrome_trace(trace_path) print(f"Chrome trace exported to: {trace_path}") diff --git a/unit_test/test_profiler_trace_path.py b/unit_test/test_profiler_trace_path.py new file mode 100644 index 0000000..3c6518a --- /dev/null +++ b/unit_test/test_profiler_trace_path.py @@ -0,0 +1,44 @@ +import os +import sys +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from mcoplib.profiler import _trace_file_path + + +class ProfilerTracePathTest(unittest.TestCase): + @patch("mcoplib.profiler._timestamp") + def test_trace_file_path_is_unique(self, mock_timestamp): + mock_timestamp.side_effect = ["20260610T120000000001", "20260610T120000000002"] + with tempfile.TemporaryDirectory() as tmp_path: + first = _trace_file_path(tmp_path, "fused_mla", 0) + second = _trace_file_path(tmp_path, "fused_mla", 0) + + self.assertNotEqual(first, second) + self.assertTrue(first.endswith(".json")) + self.assertEqual(os.path.dirname(first), tmp_path) + + def test_trace_file_path_sanitizes_function_name(self): + with tempfile.TemporaryDirectory() as tmp_path: + path = _trace_file_path(tmp_path, "op/name with spaces", 1) + + filename = os.path.basename(path) + self.assertTrue(filename.startswith("op_name_with_spaces_trace_rank_1_")) + self.assertNotIn("/", filename) + self.assertNotIn(" ", filename) + + def test_trace_file_path_truncates_long_function_name(self): + with tempfile.TemporaryDirectory() as tmp_path: + path = _trace_file_path(tmp_path, "x" * 300, 0) + + filename = os.path.basename(path) + prefix = filename.split("_trace_rank_", 1)[0] + self.assertEqual(len(prefix), 128) + + +if __name__ == "__main__": + unittest.main()