Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion benchmark/bench_flash_mla.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a
import argparse
import json
import math
import os
import random
import subprocess
import sys
from pathlib import Path

import flashinfer
import torch
Expand Down Expand Up @@ -497,13 +502,65 @@ def get_args():
parser.add_argument("--all", action="store_true")
parser.add_argument("--one", action="store_true")
parser.add_argument("--compare", action="store_true")
parser.add_argument(
"--metadata-json",
type=str,
default=None,
help="Write benchmark environment metadata to this JSON file.",
)
args = parser.parse_args()
return args


def _run_command(command):
try:
return subprocess.check_output(
command, stderr=subprocess.STDOUT, text=True
).strip()
except (OSError, subprocess.CalledProcessError) as err:
return str(err)
Comment on lines +515 to +521

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

在 `_run_command` 中，当命令执行失败（例如未安装 git，或者当前目录不是 git 仓库）时，返回原始的异常字符串（如 `Command '['git', 'rev-parse', 'HEAD']' returned non-zero exit status 128.`）会导致输出的 JSON 元数据显得杂乱且不易解析。建议在发生异常时直接返回 `None`，这样在 JSON 中会呈现为 `null`，更加整洁规范。

Suggested change
def _run_command(command):
try:
return subprocess.check_output(
command, stderr=subprocess.STDOUT, text=True
).strip()
except (OSError, subprocess.CalledProcessError) as err:
return str(err)
def _run_command(command):
try:
return subprocess.check_output(
command, stderr=subprocess.STDOUT, text=True
).strip()
except (OSError, subprocess.CalledProcessError):
return None



def collect_benchmark_metadata():
    """Gather a dictionary describing the environment the benchmark runs in.

    Returns:
        A dict holding the Python, torch, torch CUDA and triton versions,
        whether CUDA is available, the non-empty values of a fixed set of
        environment variables, and the result of ``git rev-parse HEAD``.
        When CUDA is available, the name of device ``cuda:0`` and the device
        count are added. When ``MACA_PATH`` is set and points at a directory
        containing ``Version.txt``, that file's stripped text is added too.
    """
    tracked_vars = (
        "MACA_PATH",
        "CUDA_PATH",
        "MACA_CLANG_PATH",
        "LD_LIBRARY_PATH",
        "PYTORCH_CUDA_ALLOC_CONF",
    )
    # Only variables that are set to a non-empty value are recorded.
    env_snapshot = {}
    for name in tracked_vars:
        value = os.environ.get(name)
        if value:
            env_snapshot[name] = value

    info = {}
    info["python"] = sys.version.replace("\n", " ")
    info["torch"] = torch.__version__
    info["torch_cuda"] = torch.version.cuda
    info["triton"] = getattr(triton, "__version__", "unknown")
    info["cuda_available"] = torch.cuda.is_available()
    info["env"] = env_snapshot
    info["git_commit"] = _run_command(["git", "rev-parse", "HEAD"])

    if torch.cuda.is_available():
        first_gpu = torch.device("cuda:0")
        info["device_name"] = torch.cuda.get_device_name(first_gpu)
        info["device_count"] = torch.cuda.device_count()

    maca_root = os.environ.get("MACA_PATH")
    version_path = Path(maca_root) / "Version.txt" if maca_root else None
    if version_path is not None and version_path.is_file():
        info["maca_version_txt"] = version_path.read_text(
            encoding="utf-8"
        ).strip()

    return info


def write_benchmark_metadata(path):
    """Write the collected benchmark metadata to ``path`` as JSON.

    The output is indented by two spaces with sorted keys and ends with a
    trailing newline. Missing parent directories of ``path`` are created.

    Args:
        path: Destination file path (string or path-like).
    """
    target = Path(path)
    # Collect before opening the file so a failure while gathering metadata
    # does not leave an empty or truncated file behind.
    metadata = collect_benchmark_metadata()
    # A nested destination such as "output/metadata.json" would otherwise
    # make open() raise FileNotFoundError when "output" does not exist yet.
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, "w", encoding="utf-8") as fout:
        json.dump(metadata, fout, indent=2, sort_keys=True)
        fout.write("\n")
Comment on lines +553 to +556

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

在 `write_benchmark_metadata` 中，如果用户为 `--metadata-json` 指定了一个包含多级目录的路径（例如 `output/metadata.json`），且该父目录不存在，`open(path, "w")` 将会抛出 `FileNotFoundError` 导致基准测试中断。建议在写入前使用 `path.parent.mkdir(parents=True, exist_ok=True)` 自动创建父目录，以提高脚本的健壮性。

Suggested change
def write_benchmark_metadata(path):
with open(path, "w", encoding="utf-8") as fout:
json.dump(collect_benchmark_metadata(), fout, indent=2, sort_keys=True)
fout.write("\n")
def write_benchmark_metadata(path):
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w", encoding="utf-8") as fout:
json.dump(collect_benchmark_metadata(), fout, indent=2, sort_keys=True)
fout.write("\n")



if __name__ == "__main__":
args = get_args()
benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
if args.metadata_json:
write_benchmark_metadata(args.metadata_json)
with open(f"{benchmark_type}_perf.csv", "w") as fout:
fout.write("name,batch,seqlen,head,bw\n")
for shape in shape_configs:
Expand All @@ -517,4 +574,4 @@ def get_args():
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n')
elif args.one:
perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')