Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions mcoplib/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ def _timestamp() -> str:
return datetime.now().strftime("%Y%m%dT%H%M%S")


def _normalize_profile_count(name, value, minimum):
try:
count = int(value)
except (TypeError, ValueError) as exc:
raise ValueError(f"{name} must be an integer, got {value!r}") from exc
if count < minimum:
raise ValueError(f"{name} must be >= {minimum}, got {count}")
return count


def _track_handler(prof, output_dir, func_name):
"""
Track handler implementation that matches test.py track_handler format.
Expand Down Expand Up @@ -81,6 +91,9 @@ def f(...): ...
warmup: number of warmup calls (not profiled)
repeat: number of times to call function inside a profiler run
"""
warmup_count = _normalize_profile_count("warmup", warmup, 0)
repeat_count = _normalize_profile_count("repeat", repeat, 1)

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
Expand All @@ -98,7 +111,7 @@ def wrapper(*args, **kwargs):

# Warm-up (no profiling) to stabilize JIT / caches
try:
for _ in range(max(0, int(warmup))):
for _ in range(warmup_count):
func(*args, **kwargs)
except Exception:
# keep raising actual function exceptions
Expand All @@ -116,9 +129,9 @@ def wrapper(*args, **kwargs):
activities=activities,
schedule=torch.profiler.schedule(
wait=0,
warmup=warmup,
warmup=0,
active=1,
repeat=repeat
repeat=repeat_count
),
Comment on lines 130 to 135

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

问题分析

这里存在一个关于 PyTorch Profiler 调度器(schedule)与实际循环次数不匹配的严重逻辑问题。

在代码的前半部分(第 113-115 行),已经通过手动循环完成了 warmup_count 次的预热(Warm-up):

for _ in range(warmup_count):
    func(*args, **kwargs)

而在 torch.profiler.profile 中,设置了 warmup=warmup_count。这意味着 PyTorch Profiler 内部也期望在开始记录(ACTIVE)之前,先经历 warmup_countprof.step()

然而,在实际执行的 profiling 循环中(第 144-146 行),循环只执行了 repeat_count 次:

for _ in range(repeat_count):
    result = func(*args, **kwargs)
    prof.step()

这会导致以下问题:

  1. 无法记录数据:如果 warmup_count >= repeat_count,由于 prof.step() 只被调用了 repeat_count 次,Profiler 将永远无法到达 ACTIVE 状态,从而导致导出的 trace 没有任何性能数据。
  2. 记录次数不符:如果 repeat_count > warmup_count,前 warmup_count 次迭代在 Profiler 内部仍被视为 WARMUP 阶段,只有剩余 of repeat_count - warmup_count 次迭代才会被真正记录,这与预期的 repeat_count 次记录不符。

解决方案

既然已经在外部手动执行了预热,Profiler 内部的 schedule 应该将 warmup 设为 0。这样,Profiler 启动后的每一步都是 ACTIVE 状态,正好与 repeat_count 次循环完美匹配。

Suggested change
schedule=torch.profiler.schedule(
wait=0,
warmup=warmup,
warmup=warmup_count,
active=1,
repeat=repeat
repeat=repeat_count
),
schedule=torch.profiler.schedule(
wait=0,
warmup=0,
active=1,
repeat=repeat_count
),

on_trace_ready=lambda prof: _track_handler(prof, output_dir, func.__name__),
with_modules=True,
Expand All @@ -128,7 +141,7 @@ def wrapper(*args, **kwargs):

# Run the function the specified number of times
result = None
for _ in range(max(1, int(repeat))):
for _ in range(repeat_count):
result = func(*args, **kwargs)
prof.step() # Step the profiler

Expand Down
87 changes: 87 additions & 0 deletions unit_test/test_profiler_count_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import types
import sys
import unittest
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from mcoplib.profiler import _normalize_profile_count, profiler


class ProfilerCountValidationTest(unittest.TestCase):
def test_normalize_profile_count_accepts_numeric_strings(self):
self.assertEqual(_normalize_profile_count("repeat", "3", 1), 3)

def test_normalize_profile_count_rejects_invalid_values(self):
with self.assertRaisesRegex(ValueError, "warmup must be an integer"):
_normalize_profile_count("warmup", "bad", 0)

def test_normalize_profile_count_rejects_values_below_minimum(self):
with self.assertRaisesRegex(ValueError, "repeat must be >= 1"):
_normalize_profile_count("repeat", 0, 1)

def test_profiler_validates_counts_when_decorator_is_created(self):
with self.assertRaisesRegex(ValueError, "repeat must be >= 1"):
profiler(repeat=0)

def test_profiler_schedule_uses_zero_internal_warmup(self):
schedule_kwargs = {}
fake_torch = types.ModuleType("torch")
fake_torch.cuda = types.SimpleNamespace(is_available=lambda: False)

profiler_module = types.ModuleType("torch.profiler")

def fake_schedule(**kwargs):
schedule_kwargs.update(kwargs)
return kwargs

class FakeProfile:
def __init__(self, **kwargs):
self.kwargs = kwargs

def __enter__(self):
return self

def __exit__(self, exc_type, exc, tb):
return False

def step(self):
return None

profiler_module.profile = lambda **kwargs: FakeProfile(**kwargs)
profiler_module.schedule = fake_schedule
profiler_module.ProfilerActivity = types.SimpleNamespace(CPU="cpu", CUDA="cuda")
fake_torch.profiler = profiler_module

previous_torch = sys.modules.get("torch")
previous_profiler = sys.modules.get("torch.profiler")
sys.modules["torch"] = fake_torch
sys.modules["torch.profiler"] = profiler_module
try:
calls = []

@profiler(warmup=2, repeat=3)
def sample():
calls.append("run")
return "ok"

self.assertEqual(sample(), "ok")
finally:
if previous_torch is None:
sys.modules.pop("torch", None)
else:
sys.modules["torch"] = previous_torch
if previous_profiler is None:
sys.modules.pop("torch.profiler", None)
else:
sys.modules["torch.profiler"] = previous_profiler

self.assertEqual(schedule_kwargs["wait"], 0)
self.assertEqual(schedule_kwargs["warmup"], 0)
self.assertEqual(schedule_kwargs["active"], 1)
self.assertEqual(schedule_kwargs["repeat"], 3)
self.assertEqual(len(calls), 5)


if __name__ == "__main__":
unittest.main()