From 7e67447979ab799aea0913bcb77ef26dcd21fbac Mon Sep 17 00:00:00 2001
From: papertager <2567587994@qq.com>
Date: Sun, 7 Jun 2026 16:29:35 +0800
Subject: [PATCH] benchmark: record environment metadata

Add a --metadata-json option to bench_flash_mla.py. When given, the
script writes the Python, torch and triton versions, selected
environment variables, the git commit of the checkout and CUDA device
details to the named JSON file before the benchmark runs.

---
 benchmark/bench_flash_mla.py | 75 +++++++++++++++++++++++++++++++++++-
 1 file changed, 74 insertions(+), 1 deletion(-)

diff --git a/benchmark/bench_flash_mla.py b/benchmark/bench_flash_mla.py
index 2b59f8ed..cb407ae6 100644
--- a/benchmark/bench_flash_mla.py
+++ b/benchmark/bench_flash_mla.py
@@ -1,7 +1,12 @@
 # MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a
 import argparse
+import json
 import math
+import os
 import random
+import subprocess
+import sys
+from pathlib import Path
 
 import flashinfer
 import torch
@@ -497,13 +502,81 @@ def get_args():
     parser.add_argument("--all", action="store_true")
     parser.add_argument("--one", action="store_true")
     parser.add_argument("--compare", action="store_true")
+    parser.add_argument(
+        "--metadata-json",
+        type=str,
+        default=None,
+        help="Write benchmark environment metadata to this JSON file.",
+    )
     args = parser.parse_args()
     return args
 
 
+def _run_command(command, cwd=None):
+    """Run ``command`` and return its stripped stdout, or None on failure.
+
+    Metadata collection is best-effort: a missing executable or a non-zero
+    exit status yields None rather than aborting the benchmark.
+    """
+    try:
+        return subprocess.check_output(
+            command, cwd=cwd, stderr=subprocess.DEVNULL, text=True
+        ).strip()
+    except (OSError, subprocess.CalledProcessError):
+        return None
+
+
+def collect_benchmark_metadata():
+    """Return a JSON-serializable dict describing the benchmark environment.
+
+    Covers the Python, torch and triton versions, the non-empty values of a
+    fixed set of environment variables, the git commit of the checkout that
+    contains this script, CUDA device details when a device is available, and
+    the contents of ``$MACA_PATH/Version.txt`` when that file exists.
+    """
+    env_keys = [
+        "MACA_PATH",
+        "CUDA_PATH",
+        "MACA_CLANG_PATH",
+        "LD_LIBRARY_PATH",
+        "PYTORCH_CUDA_ALLOC_CONF",
+    ]
+    # Query git from this file's directory rather than the caller's working
+    # directory, so the commit is correct wherever the script is launched.
+    script_dir = Path(__file__).resolve().parent
+    metadata = {
+        "python": sys.version.replace("\n", " "),
+        "torch": torch.__version__,
+        "torch_cuda": torch.version.cuda,
+        "triton": getattr(triton, "__version__", "unknown"),
+        "cuda_available": torch.cuda.is_available(),
+        "env": {key: os.environ.get(key) for key in env_keys if os.environ.get(key)},
+        "git_commit": _run_command(["git", "rev-parse", "HEAD"], cwd=script_dir),
+    }
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+        metadata["device_name"] = torch.cuda.get_device_name(device)
+        metadata["device_count"] = torch.cuda.device_count()
+    maca_path = os.environ.get("MACA_PATH")
+    if maca_path:
+        version_file = Path(maca_path) / "Version.txt"
+        if version_file.is_file():
+            metadata["maca_version_txt"] = version_file.read_text(encoding="utf-8").strip()
+    return metadata
+
+
+def write_benchmark_metadata(path):
+    """Write the collected environment metadata to ``path`` as JSON."""
+    with open(path, "w", encoding="utf-8") as fout:
+        json.dump(collect_benchmark_metadata(), fout, indent=2, sort_keys=True)
+        fout.write("\n")
+
+
 if __name__ == "__main__":
     args = get_args()
     benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
+    if args.metadata_json:
+        write_benchmark_metadata(args.metadata_json)
     with open(f"{benchmark_type}_perf.csv", "w") as fout:
         fout.write("name,batch,seqlen,head,bw\n")
         for shape in shape_configs:
@@ -517,4 +590,4 @@ def get_args():
                 fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n')
             elif args.one:
                 perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
-                fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
\ No newline at end of file
+                fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')