Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions tools/op_source_inventory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#!/usr/bin/env python3
"""Build a JSON inventory of mcoplib operator source groups."""

from __future__ import annotations

import argparse
import json
from pathlib import Path

GROUPS = {
"vllm": "op/vllm",
"sglang": "op/sglang",
"lmdeploy": "op/lmdeploy",
"cv": "op/cv",
"native": "op",
}

SOURCE_SUFFIXES = {".cu", ".cuh", ".cpp", ".cc", ".h", ".hpp", ".py"}


def _sources(root: Path, relative_dir: str) -> list[str]:
base = root / relative_dir
if not base.exists():
return []
return sorted(
path.relative_to(root).as_posix()
for path in base.rglob("*")
if path.is_file() and path.suffix in SOURCE_SUFFIXES
)


def build_inventory(root: Path) -> dict[str, object]:
raw_files: dict[str, list[str]] = {}
for name, relative_dir in GROUPS.items():
raw_files[name] = _sources(root, relative_dir)

other_files: set[str] = set()
for name in GROUPS:
if name != "native":
other_files.update(raw_files[name])
if "native" in raw_files:
raw_files["native"] = [path for path in raw_files["native"] if path not in other_files]

groups: dict[str, object] = {}
for name, relative_dir in GROUPS.items():
files = raw_files[name]
groups[name] = {"root": relative_dir, "count": len(files), "files": files}
return {"root": str(root), "groups": groups}
Comment on lines +32 to +48

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

在当前的实现中,GROUPS 中的 "native" 对应的路径是 "op"。由于 _sources 函数内部使用了 rglob("*") 进行递归搜索,这会导致 "op/vllm""op/sglang" 等子目录下的所有源文件也被重复统计到 "native" 分组中。这不仅导致数据冗余,也使得 "native" 分组的计数和文件列表不准确。

建议在构建清单时,从 "native" 分组中排除已被其他更具体的分组(如 vllm, sglang, lmdeploy, cv)包含的文件。

def build_inventory(root: Path) -> dict[str, object]:
    raw_files: dict[str, list[str]] = {}
    for name, relative_dir in GROUPS.items():
        raw_files[name] = _sources(root, relative_dir)

    # 排除其他特定分组中已包含的文件,避免在 "native" 中重复统计
    other_files = set()
    for name in GROUPS:
        if name != "native":
            other_files.update(raw_files[name])

    if "native" in raw_files:
        raw_files["native"] = [f for f in raw_files["native"] if f not in other_files]

    groups: dict[str, object] = {}
    for name, relative_dir in GROUPS.items():
        files = raw_files[name]
        groups[name] = {"root": relative_dir, "count": len(files), "files": files}
    return {"root": str(root), "groups": groups}



def main() -> int:
parser = argparse.ArgumentParser(description="Create mcoplib operator source inventory.")
parser.add_argument("--root", type=Path, default=Path(__file__).resolve().parents[1])
parser.add_argument("--output", type=Path)
args = parser.parse_args()

payload = build_inventory(args.root.resolve())
text = json.dumps(payload, indent=2, sort_keys=True)
if args.output:
args.output.parent.mkdir(parents=True, exist_ok=True)
args.output.write_text(text + "\n", encoding="utf-8")
else:
print(text)
return 0


if __name__ == "__main__":
raise SystemExit(main())
30 changes: 30 additions & 0 deletions unit_test/test_op_source_inventory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import tempfile
import unittest
from pathlib import Path

import sys

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from tools.op_source_inventory import build_inventory


class OpSourceInventoryTest(unittest.TestCase):
def test_inventory_counts_group_sources(self):
with tempfile.TemporaryDirectory() as tmp:
root = Path(tmp)
(root / "op" / "vllm").mkdir(parents=True)
(root / "op" / "vllm" / "kernel.cu").write_text("", encoding="utf-8")
(root / "op" / "vllm" / "README.md").write_text("", encoding="utf-8")
(root / "op" / "native_kernel.cu").write_text("", encoding="utf-8")

inventory = build_inventory(root)

self.assertEqual(inventory["groups"]["vllm"]["count"], 1)
self.assertEqual(inventory["groups"]["vllm"]["files"], ["op/vllm/kernel.cu"])
Comment on lines +13 to +24

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

为了配合对 native 分组排重逻辑的修改,建议在单元测试中增加对 native 分组的断言,确保其不会错误地包含其他子分组(如 vllm)的文件,并且能正确统计属于 native 自身的源文件。

Suggested change
def test_inventory_counts_group_sources(self):
with tempfile.TemporaryDirectory() as tmp:
root = Path(tmp)
(root / "op" / "vllm").mkdir(parents=True)
(root / "op" / "vllm" / "kernel.cu").write_text("", encoding="utf-8")
(root / "op" / "vllm" / "README.md").write_text("", encoding="utf-8")
inventory = build_inventory(root)
self.assertEqual(inventory["groups"]["vllm"]["count"], 1)
self.assertEqual(inventory["groups"]["vllm"]["files"], ["op/vllm/kernel.cu"])
def test_inventory_counts_group_sources(self):
with tempfile.TemporaryDirectory() as tmp:
root = Path(tmp)
(root / "op" / "vllm").mkdir(parents=True)
(root / "op" / "vllm" / "kernel.cu").write_text("", encoding="utf-8")
(root / "op" / "vllm" / "README.md").write_text("", encoding="utf-8")
(root / "op" / "native_kernel.cu").write_text("", encoding="utf-8")
inventory = build_inventory(root)
self.assertEqual(inventory["groups"]["vllm"]["count"], 1)
self.assertEqual(inventory["groups"]["vllm"]["files"], ["op/vllm/kernel.cu"])
self.assertEqual(inventory["groups"]["native"]["count"], 1)
self.assertEqual(inventory["groups"]["native"]["files"], ["op/native_kernel.cu"])

self.assertEqual(inventory["groups"]["native"]["count"], 1)
self.assertEqual(inventory["groups"]["native"]["files"], ["op/native_kernel.cu"])


if __name__ == "__main__":
unittest.main()