From f661f267d7254e05fc98cb8ab613e08fe9bf2f05 Mon Sep 17 00:00:00 2001
From: papertager <2567587994@qq.com>
Date: Sun, 7 Jun 2026 16:28:01 +0800
Subject: [PATCH] benchmark: make FlashMLA shapes configurable

Replace the hard-coded module-level shape_configs list with
make_shape_configs(args), driven by new command-line options:
--batches, --seqlens, --heads (comma-separated integer lists), --s-q,
--h-kv, --d, --dv, --dtype (bf16 or fp16), --varlen-step and
--non-causal. The defaults reproduce the previous hard-coded sweep.

Add --output to choose the CSV path (default: <benchmark_type>_perf.csv)
and write rows through csv.writer, keeping "\n" row terminators so the
file format matches what the script produced before.

---
Review notes (ignored by git am):
 - Line structure restored; context-line indentation is reconstructed
   as 4 spaces.
 - _parse_int_list now rejects an empty list instead of silently
   producing a header-only CSV.
 - csv.writer is pinned to lineterminator="\n"; its default is "\r\n".
 - The post-image blob id on the index line predates these revisions.

 benchmark/bench_flash_mla.py | 79 +++++++++++++++++++++++++++++++-----
 1 file changed, 69 insertions(+), 10 deletions(-)

diff --git a/benchmark/bench_flash_mla.py b/benchmark/bench_flash_mla.py
index 2b59f8ed..0562f89a 100644
--- a/benchmark/bench_flash_mla.py
+++ b/benchmark/bench_flash_mla.py
@@ -1,5 +1,6 @@
 # MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a
 import argparse
+import csv
 import math
 import random
 
@@ -484,10 +485,54 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
     "flash_mla_triton",
 ]
 
-shape_configs = [
-    {"b": batch, "s_q": 1, "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), "h_q": head, "h_kv": 1, "d": 512+64, "dv": 512, "causal": True, "dtype": torch.bfloat16}
-    for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 8192*2, 8192*4] for head in [128]
-]
+DEFAULT_BATCHES = [128]
+DEFAULT_SEQLENS = [1024, 2048, 4096, 8192, 8192 * 2, 8192 * 4]
+DEFAULT_HEADS = [128]
+
+
+def _parse_int_list(value):
+    """Parse a comma-separated string into a non-empty list of ints (argparse type)."""
+    try:
+        items = [int(item) for item in value.split(",") if item.strip()]
+    except ValueError as err:
+        raise argparse.ArgumentTypeError("expected comma-separated integers") from err
+    if not items:
+        # An empty sweep would run nothing and leave a header-only CSV behind.
+        raise argparse.ArgumentTypeError("expected at least one integer")
+    return items
+
+
+def _parse_dtype(value):
+    """Map the CLI dtype name ("bf16" or "fp16") to the matching torch dtype."""
+    if value == "bf16":
+        return torch.bfloat16
+    if value == "fp16":
+        return torch.float16
+    raise argparse.ArgumentTypeError("dtype must be bf16 or fp16")
+
+
+def make_shape_configs(args):
+    """Build one shape dict per (batch, seqlen, head) combination from the parsed args."""
+    return [
+        {
+            "b": batch,
+            "s_q": args.s_q,
+            "cache_seqlens": torch.tensor(
+                [seqlen + args.varlen_step * i for i in range(batch)],
+                dtype=torch.int32,
+                device="cuda",
+            ),
+            "h_q": head,
+            "h_kv": args.h_kv,
+            "d": args.d,
+            "dv": args.dv,
+            "causal": not args.non_causal,
+            "dtype": args.dtype,
+        }
+        for batch in args.batches
+        for seqlen in args.seqlens
+        for head in args.heads
+    ]
 
 
 def get_args():
@@ -497,6 +542,17 @@ def get_args():
     parser.add_argument("--all", action="store_true")
     parser.add_argument("--one", action="store_true")
     parser.add_argument("--compare", action="store_true")
+    parser.add_argument("--output", type=str, default=None, help="CSV output path.")
+    parser.add_argument("--batches", type=_parse_int_list, default=DEFAULT_BATCHES)
+    parser.add_argument("--seqlens", type=_parse_int_list, default=DEFAULT_SEQLENS)
+    parser.add_argument("--heads", type=_parse_int_list, default=DEFAULT_HEADS)
+    parser.add_argument("--s-q", type=int, default=1)
+    parser.add_argument("--h-kv", type=int, default=1)
+    parser.add_argument("--d", type=int, default=512 + 64)
+    parser.add_argument("--dv", type=int, default=512)
+    parser.add_argument("--dtype", type=_parse_dtype, default=torch.bfloat16)
+    parser.add_argument("--varlen-step", type=int, default=2)
+    parser.add_argument("--non-causal", action="store_true")
     args = parser.parse_args()
     return args
 
@@ -504,17 +560,20 @@ def get_args():
 if __name__ == "__main__":
     args = get_args()
     benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
-    with open(f"{benchmark_type}_perf.csv", "w") as fout:
-        fout.write("name,batch,seqlen,head,bw\n")
+    output_path = args.output or f"{benchmark_type}_perf.csv"
+    shape_configs = make_shape_configs(args)
+    with open(output_path, "w", newline="") as fout:
+        writer = csv.writer(fout, lineterminator="\n")
+        writer.writerow(["name", "batch", "seqlen", "head", "bw"])
         for shape in shape_configs:
             if args.all:
                 for target in available_targets:
                     perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
-                    fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
+                    writer.writerow([target, shape["b"], f'{shape["cache_seqlens"].float().mean().cpu().item():.0f}', shape["h_q"], f"{perf:.0f}"])
             elif args.compare:
                 perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
-                fout.write(f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n')
-                fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n')
+                writer.writerow([args.baseline, shape["b"], f'{shape["cache_seqlens"].float().mean().cpu().item():.0f}', shape["h_q"], f"{perfa:.0f}"])
+                writer.writerow([args.target, shape["b"], f'{shape["cache_seqlens"].float().mean().cpu().item():.0f}', shape["h_q"], f"{prefb:.0f}"])
             elif args.one:
                 perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
-                fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
\ No newline at end of file
+                writer.writerow([args.target, shape["b"], f'{shape["cache_seqlens"].float().mean().cpu().item():.0f}', shape["h_q"], f"{perf:.0f}"])