From c11f0fde7469b34a88da8e702b9d4e362efc5509 Mon Sep 17 00:00:00 2001 From: papertager <2567587994@qq.com> Date: Tue, 9 Jun 2026 11:12:37 +0800 Subject: [PATCH 1/2] Validate profiler warmup and repeat counts --- mcoplib/profiler.py | 21 +++++++++++++--- unit_test/test_profiler_count_validation.py | 28 +++++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) create mode 100644 unit_test/test_profiler_count_validation.py diff --git a/mcoplib/profiler.py b/mcoplib/profiler.py index adfba41..44d0482 100644 --- a/mcoplib/profiler.py +++ b/mcoplib/profiler.py @@ -19,6 +19,16 @@ def _timestamp() -> str: return datetime.now().strftime("%Y%m%dT%H%M%S") +def _normalize_profile_count(name, value, minimum): + try: + count = int(value) + except (TypeError, ValueError) as exc: + raise ValueError(f"{name} must be an integer, got {value!r}") from exc + if count < minimum: + raise ValueError(f"{name} must be >= {minimum}, got {count}") + return count + + def _track_handler(prof, output_dir, func_name): """ Track handler implementation that matches test.py track_handler format. @@ -81,6 +91,9 @@ def f(...): ... warmup: number of warmup calls (not profiled) repeat: number of times to call function inside a profiler run """ + warmup_count = _normalize_profile_count("warmup", warmup, 0) + repeat_count = _normalize_profile_count("repeat", repeat, 1) + def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): @@ -98,7 +111,7 @@ def wrapper(*args, **kwargs): # Warm-up (no profiling) to stabilize JIT / caches try: - for _ in range(max(0, int(warmup))): + for _ in range(warmup_count): func(*args, **kwargs) except Exception: # keep raising actual function exceptions @@ -116,9 +129,9 @@ def wrapper(*args, **kwargs): activities=activities, schedule=torch.profiler.schedule( wait=0, - warmup=warmup, + warmup=warmup_count, active=1, - repeat=repeat + repeat=repeat_count ), on_trace_ready=lambda prof: _track_handler(prof, output_dir, func.__name__), with_modules=True, @@ -128,7 +141,7 @@ def wrapper(*args, **kwargs): # Run the function the specified number of times result = None - for _ in range(max(1, int(repeat))): + for _ in range(repeat_count): result = func(*args, **kwargs) prof.step() # Step the profiler diff --git a/unit_test/test_profiler_count_validation.py b/unit_test/test_profiler_count_validation.py new file mode 100644 index 0000000..62f9546 --- /dev/null +++ b/unit_test/test_profiler_count_validation.py @@ -0,0 +1,28 @@ +import sys +import unittest +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from mcoplib.profiler import _normalize_profile_count, profiler + + +class ProfilerCountValidationTest(unittest.TestCase): + def test_normalize_profile_count_accepts_numeric_strings(self): + self.assertEqual(_normalize_profile_count("repeat", "3", 1), 3) + + def test_normalize_profile_count_rejects_invalid_values(self): + with self.assertRaisesRegex(ValueError, "warmup must be an integer"): + _normalize_profile_count("warmup", "bad", 0) + + def test_normalize_profile_count_rejects_values_below_minimum(self): + with self.assertRaisesRegex(ValueError, "repeat must be >= 1"): + _normalize_profile_count("repeat", 0, 1) + + def test_profiler_validates_counts_when_decorator_is_created(self): + with self.assertRaisesRegex(ValueError, "repeat must be >= 1"): + profiler(repeat=0) + + +if __name__ == "__main__": + unittest.main() From 0e2e01c816d2f7f989736f84aed466a5a5aff0b4 Mon Sep 17 00:00:00 2001 From: papertager <2567587994@qq.com> Date: Thu, 11 Jun 2026 00:32:50 +0800 Subject: [PATCH 2/2] Align profiler schedule with manual warmup --- mcoplib/profiler.py | 2 +- unit_test/test_profiler_count_validation.py | 59 +++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/mcoplib/profiler.py b/mcoplib/profiler.py index 44d0482..62649d9 100644 --- a/mcoplib/profiler.py +++ b/mcoplib/profiler.py @@ -129,7 +129,7 @@ def wrapper(*args, **kwargs): activities=activities, schedule=torch.profiler.schedule( wait=0, - warmup=warmup_count, + warmup=0, active=1, repeat=repeat_count ), diff --git a/unit_test/test_profiler_count_validation.py b/unit_test/test_profiler_count_validation.py index 62f9546..d38d17a 100644 --- a/unit_test/test_profiler_count_validation.py +++ b/unit_test/test_profiler_count_validation.py @@ -1,3 +1,4 @@ +import types import sys import unittest from pathlib import Path @@ -23,6 +24,64 @@ def test_profiler_validates_counts_when_decorator_is_created(self): with self.assertRaisesRegex(ValueError, "repeat must be >= 1"): profiler(repeat=0) + def test_profiler_schedule_uses_zero_internal_warmup(self): + schedule_kwargs = {} + fake_torch = types.ModuleType("torch") + fake_torch.cuda = types.SimpleNamespace(is_available=lambda: False) + + profiler_module = types.ModuleType("torch.profiler") + + def fake_schedule(**kwargs): + schedule_kwargs.update(kwargs) + return kwargs + + class FakeProfile: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def step(self): + return None + + profiler_module.profile = lambda **kwargs: FakeProfile(**kwargs) + profiler_module.schedule = fake_schedule + profiler_module.ProfilerActivity = types.SimpleNamespace(CPU="cpu", CUDA="cuda") + fake_torch.profiler = profiler_module + + previous_torch = sys.modules.get("torch") + previous_profiler = sys.modules.get("torch.profiler") + sys.modules["torch"] = fake_torch + sys.modules["torch.profiler"] = profiler_module + try: + calls = [] + + @profiler(warmup=2, repeat=3) + def sample(): + calls.append("run") + return "ok" + + self.assertEqual(sample(), "ok") + finally: + if previous_torch is None: + sys.modules.pop("torch", None) + else: + sys.modules["torch"] = previous_torch + if previous_profiler is None: + sys.modules.pop("torch.profiler", None) + else: + sys.modules["torch.profiler"] = previous_profiler + + self.assertEqual(schedule_kwargs["wait"], 0) + self.assertEqual(schedule_kwargs["warmup"], 0) + self.assertEqual(schedule_kwargs["active"], 1) + self.assertEqual(schedule_kwargs["repeat"], 3) + self.assertEqual(len(calls), 5) + if __name__ == "__main__": unittest.main()