Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 62 additions & 10 deletions benchmark/bench_flash_mla.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a
import argparse
import csv
import math
import random

Expand Down Expand Up @@ -484,10 +485,47 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
"flash_mla_triton",
]

# Hard-coded sweep: one config dict per (batch, seqlen, head) combination with
# s_q=1, h_kv=1, d=512+64, dv=512, causal=True and bfloat16 fixed. Evaluating
# this statement allocates every cache_seqlens tensor on the "cuda" device.
# NOTE(review): this looks like the pre-change side of the diff, superseded by
# make_shape_configs(args) -- confirm it is absent from the final file.
shape_configs = [
{"b": batch, "s_q": 1, "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), "h_q": head, "h_kv": 1, "d": 512+64, "dv": 512, "causal": True, "dtype": torch.bfloat16}
for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 8192*2, 8192*4] for head in [128]
]
# Defaults for the --batches / --seqlens / --heads command-line options; they
# reproduce the values of the hard-coded sweep above.
DEFAULT_BATCHES = [128]
DEFAULT_SEQLENS = [1024, 2048, 4096, 8192, 8192 * 2, 8192 * 4]
DEFAULT_HEADS = [128]


def _parse_int_list(value):
    """Parse a comma-separated list of integers (argparse ``type=`` callable).

    Blank items are ignored, so a trailing comma or spaces around the
    separators are tolerated (``"1, 2,"`` -> ``[1, 2]``).

    Args:
        value: Raw command-line string, e.g. ``"1024,2048"``.

    Returns:
        A non-empty list of ints in the order given.

    Raises:
        argparse.ArgumentTypeError: If an item is not an integer, or if the
            value contains no integers at all.
    """
    try:
        numbers = [int(item) for item in value.split(",") if item.strip()]
    except ValueError as err:
        raise argparse.ArgumentTypeError("expected comma-separated integers") from err
    # An empty list would produce an empty sweep and a benchmark run that
    # silently does nothing, so reject it at parse time instead.
    if not numbers:
        raise argparse.ArgumentTypeError("expected at least one integer")
    return numbers
Comment on lines +493 to +497

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

如果用户传入了空字符串(例如 --batches ""),_parse_int_list 会返回一个空列表 []。这会导致后续的 make_shape_configs 生成空的配置列表,脚本将静默结束且不执行任何评测。建议在解析结果为空时抛出 argparse.ArgumentTypeError 异常,以提供更友好的错误提示。

Suggested change
def _parse_int_list(value):
try:
return [int(item) for item in value.split(",") if item.strip()]
except ValueError as err:
raise argparse.ArgumentTypeError("expected comma-separated integers") from err
def _parse_int_list(value):
try:
result = [int(item) for item in value.split(",") if item.strip()]
if not result:
raise argparse.ArgumentTypeError("expected at least one integer")
return result
except ValueError as err:
raise argparse.ArgumentTypeError("expected comma-separated integers") from err



def _parse_dtype(value):
    """Map a dtype name from the command line to a torch dtype.

    Matching ignores case and surrounding whitespace, and accepts both the
    short names and the full torch names.

    Args:
        value: One of ``bf16`` / ``bfloat16`` / ``fp16`` / ``float16``.

    Returns:
        ``torch.bfloat16`` or ``torch.float16``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` names no supported dtype.
    """
    dtype_by_name = {
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp16": torch.float16,
        "float16": torch.float16,
    }
    try:
        return dtype_by_name[value.strip().lower()]
    except KeyError:
        # The KeyError adds nothing for the user; surface only the CLI message.
        raise argparse.ArgumentTypeError("dtype must be bf16 or fp16") from None
Comment on lines +500 to +505

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

目前的 _parse_dtype 仅支持精确匹配小写的 "bf16" 和 "fp16"。如果用户输入了大写(如 "BF16")或完整名称(如 "bfloat16"),会直接报错。建议将其转换为小写并支持常见的别名,以提高参数解析的容错性和易用性。

Suggested change
def _parse_dtype(value):
if value == "bf16":
return torch.bfloat16
if value == "fp16":
return torch.float16
raise argparse.ArgumentTypeError("dtype must be bf16 or fp16")
def _parse_dtype(value):
val = value.lower()
if val in ("bf16", "bfloat16"):
return torch.bfloat16
if val in ("fp16", "float16"):
return torch.float16
raise argparse.ArgumentTypeError("dtype must be bf16 or fp16")



def make_shape_configs(args):
    """Expand the parsed CLI options into one shape dict per sweep point.

    The sweep is the product of ``args.batches`` x ``args.seqlens`` x
    ``args.heads``, iterated in that nesting order (heads vary fastest).

    Each dict carries the keys ``b``, ``s_q``, ``cache_seqlens``, ``h_q``,
    ``h_kv``, ``d``, ``dv``, ``causal`` and ``dtype``. ``cache_seqlens`` is an
    int32 tensor on the "cuda" device with ``batch`` entries, where entry
    ``i`` equals ``seqlen + args.varlen_step * i``.

    Args:
        args: Namespace providing ``batches``, ``seqlens``, ``heads``,
            ``s_q``, ``h_kv``, ``d``, ``dv``, ``dtype``, ``varlen_step`` and
            ``non_causal``.

    Returns:
        A list of shape dicts; empty if any of the three sweep lists is empty.
    """
    configs = []
    for batch in args.batches:
        for seqlen in args.seqlens:
            for head in args.heads:
                config = {"b": batch, "s_q": args.s_q}
                # A fresh tensor per config: per-row lengths grow linearly
                # from the base seqlen by varlen_step.
                lengths = [seqlen + args.varlen_step * row for row in range(batch)]
                config["cache_seqlens"] = torch.tensor(
                    lengths, dtype=torch.int32, device="cuda"
                )
                config["h_q"] = head
                config["h_kv"] = args.h_kv
                config["d"] = args.d
                config["dv"] = args.dv
                # The CLI exposes the negated flag; store the positive form.
                config["causal"] = not args.non_causal
                config["dtype"] = args.dtype
                configs.append(config)
    return configs


def get_args():
Expand All @@ -497,24 +535,38 @@ def get_args():
parser.add_argument("--all", action="store_true")
parser.add_argument("--one", action="store_true")
parser.add_argument("--compare", action="store_true")
parser.add_argument("--output", type=str, default=None, help="CSV output path.")
parser.add_argument("--batches", type=_parse_int_list, default=DEFAULT_BATCHES)
parser.add_argument("--seqlens", type=_parse_int_list, default=DEFAULT_SEQLENS)
parser.add_argument("--heads", type=_parse_int_list, default=DEFAULT_HEADS)
parser.add_argument("--s-q", type=int, default=1)
parser.add_argument("--h-kv", type=int, default=1)
parser.add_argument("--d", type=int, default=512 + 64)
parser.add_argument("--dv", type=int, default=512)
parser.add_argument("--dtype", type=_parse_dtype, default=torch.bfloat16)
parser.add_argument("--varlen-step", type=int, default=2)
parser.add_argument("--non-causal", action="store_true")
Comment on lines +539 to +548

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

新增的命令行参数(如 --batches、--seqlens、--heads 等)没有提供 help 描述。为了提高脚本的可维护性和易用性,建议为这些参数添加详细的 help 说明,以便用户通过 python bench_flash_mla.py --help 查看参数含义和默认值。

Suggested change
parser.add_argument("--batches", type=_parse_int_list, default=DEFAULT_BATCHES)
parser.add_argument("--seqlens", type=_parse_int_list, default=DEFAULT_SEQLENS)
parser.add_argument("--heads", type=_parse_int_list, default=DEFAULT_HEADS)
parser.add_argument("--s-q", type=int, default=1)
parser.add_argument("--h-kv", type=int, default=1)
parser.add_argument("--d", type=int, default=512 + 64)
parser.add_argument("--dv", type=int, default=512)
parser.add_argument("--dtype", type=_parse_dtype, default=torch.bfloat16)
parser.add_argument("--varlen-step", type=int, default=2)
parser.add_argument("--non-causal", action="store_true")
parser.add_argument("--batches", type=_parse_int_list, default=DEFAULT_BATCHES, help="Comma-separated list of batch sizes.")
parser.add_argument("--seqlens", type=_parse_int_list, default=DEFAULT_SEQLENS, help="Comma-separated list of sequence lengths.")
parser.add_argument("--heads", type=_parse_int_list, default=DEFAULT_HEADS, help="Comma-separated list of head numbers.")
parser.add_argument("--s-q", type=int, default=1, help="Query sequence length.")
parser.add_argument("--h-kv", type=int, default=1, help="Number of KV heads.")
parser.add_argument("--d", type=int, default=512 + 64, help="Head dimension d.")
parser.add_argument("--dv", type=int, default=512, help="Head dimension dv.")
parser.add_argument("--dtype", type=_parse_dtype, default=torch.bfloat16, help="Data type (bf16 or fp16).")
parser.add_argument("--varlen-step", type=int, default=2, help="Step size for variable sequence lengths.")
parser.add_argument("--non-causal", action="store_true", help="Disable causal attention mask.")

args = parser.parse_args()
return args


if __name__ == "__main__":
args = get_args()
benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
Comment on lines 554 to 555

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

如果用户在运行脚本时没有指定 --all、--compare、--one 中的任何一个参数(例如直接运行 python bench_flash_mla.py),脚本会默认静默运行,只生成一个仅包含表头的空 CSV 文件。这会给用户带来困惑。建议在未指定任何模式时,默认启用 --one 模式,以运行默认的 target 评测。

Suggested change
args = get_args()
benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
args = get_args()
if not (args.all or args.compare or args.one):
args.one = True
benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target

with open(f"{benchmark_type}_perf.csv", "w") as fout:
fout.write("name,batch,seqlen,head,bw\n")
output_path = args.output or f"{benchmark_type}_perf.csv"
shape_configs = make_shape_configs(args)
with open(output_path, "w", newline="") as fout:
writer = csv.writer(fout)
writer.writerow(["name", "batch", "seqlen", "head", "bw"])
for shape in shape_configs:
if args.all:
for target in available_targets:
perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
writer.writerow([target, shape["b"], f'{shape["cache_seqlens"].float().mean().cpu().item():.0f}', shape["h_q"], f"{perf:.0f}"])
elif args.compare:
perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
fout.write(f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n')
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n')
writer.writerow([args.baseline, shape["b"], f'{shape["cache_seqlens"].float().mean().cpu().item():.0f}', shape["h_q"], f"{perfa:.0f}"])
writer.writerow([args.target, shape["b"], f'{shape["cache_seqlens"].float().mean().cpu().item():.0f}', shape["h_q"], f"{prefb:.0f}"])
elif args.one:
perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
writer.writerow([args.target, shape["b"], f'{shape["cache_seqlens"].float().mean().cpu().item():.0f}', shape["h_q"], f"{perf:.0f}"])