Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions tools/op_complexity_report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
"""Summarize CUDA/MACA operator source complexity for validation planning."""

from __future__ import annotations

import argparse
import json
from pathlib import Path


SOURCE_SUFFIXES = {".cu", ".cuh", ".cpp", ".h"}


def analyze_file(path: Path, root: Path) -> dict[str, object]:
text = path.read_text(encoding="utf-8", errors="replace")
return {
"path": path.relative_to(root).as_posix(),
"lines": len(text.splitlines()),
"kernel_launches": text.count("<<<"),
"templates": text.count("template"),
"torch_bindings": text.count("PYBIND11_MODULE"),
}


def build_report(root: Path) -> dict[str, object]:
op_dir = root / "op"
if not op_dir.is_dir():
return {"file_count": 0, "total_lines": 0, "top_by_lines": []}

files = [
analyze_file(path, root)
for path in sorted(op_dir.rglob("*"))
if path.is_file() and path.suffix in SOURCE_SUFFIXES
]
return {
"file_count": len(files),
"total_lines": sum(item["lines"] for item in files),
"top_by_lines": sorted(files, key=lambda item: item["lines"], reverse=True)[:20],
}
Comment on lines +25 to +39

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

当指定的 root 目录下不存在 op 子目录时,对 rglob 结果进行排序和迭代会抛出 FileNotFoundError 异常。为了提高脚本的健壮性,建议在构建报告前先检查 op 目录是否存在。如果不存在,可以直接返回一个空的报告结构,避免程序崩溃。此外,item["lines"] 本身已经是整型,无需在 sumsorted 中重复调用 int() 进行类型转换。

Suggested change
def build_report(root: Path) -> dict[str, object]:
files = [
analyze_file(path, root)
for path in sorted((root / "op").rglob("*"))
if path.is_file() and path.suffix in SOURCE_SUFFIXES
]
return {
"file_count": len(files),
"total_lines": sum(int(item["lines"]) for item in files),
"top_by_lines": sorted(files, key=lambda item: int(item["lines"]), reverse=True)[:20],
}
def build_report(root: Path) -> dict[str, object]:
op_dir = root / "op"
if not op_dir.is_dir():
return {
"file_count": 0,
"total_lines": 0,
"top_by_lines": [],
}
files = [
analyze_file(path, root)
for path in sorted(op_dir.rglob("*"))
if path.is_file() and path.suffix in SOURCE_SUFFIXES
]
return {
"file_count": len(files),
"total_lines": sum(item["lines"] for item in files),
"top_by_lines": sorted(files, key=lambda item: item["lines"], reverse=True)[:20],
}



def main() -> int:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--root", type=Path, default=Path.cwd(), help="repository root")
parser.add_argument("--output", type=Path, help="write JSON report to this path")
args = parser.parse_args()

text = json.dumps(build_report(args.root), indent=2, ensure_ascii=False)
if args.output:
args.output.write_text(text + "\n", encoding="utf-8")
else:
print(text)
return 0


if __name__ == "__main__":
raise SystemExit(main())
29 changes: 29 additions & 0 deletions unit_test/test_op_complexity_report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import tempfile
import unittest
from pathlib import Path

from tools.op_complexity_report import build_report


class OpComplexityReportTest(unittest.TestCase):
def test_counts_operator_sources(self):
with tempfile.TemporaryDirectory() as tmpdir:
root = Path(tmpdir)
op = root / "op"
op.mkdir()
(op / "kernel.cu").write_text("template <typename T>\nvoid f(){ k<<<1,1>>>(); }\n", encoding="utf-8")

report = build_report(root)

self.assertEqual(report["file_count"], 1)
self.assertEqual(report["top_by_lines"][0]["kernel_launches"], 1)

def test_returns_empty_report_when_op_directory_is_missing(self):
with tempfile.TemporaryDirectory() as tmpdir:
report = build_report(Path(tmpdir))

self.assertEqual(report, {"file_count": 0, "total_lines": 0, "top_by_lines": []})


if __name__ == "__main__":
unittest.main()