-
Notifications
You must be signed in to change notification settings - Fork 3
支持配置化 benchmark 形状 #16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,6 @@ | ||||||||||||||||||||||||||||||||||||||||||
| # MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a | ||||||||||||||||||||||||||||||||||||||||||
| import argparse | ||||||||||||||||||||||||||||||||||||||||||
| import csv | ||||||||||||||||||||||||||||||||||||||||||
| import math | ||||||||||||||||||||||||||||||||||||||||||
| import random | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -484,10 +485,47 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): | |||||||||||||||||||||||||||||||||||||||||
| "flash_mla_triton", | ||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| shape_configs = [ | ||||||||||||||||||||||||||||||||||||||||||
| {"b": batch, "s_q": 1, "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), "h_q": head, "h_kv": 1, "d": 512+64, "dv": 512, "causal": True, "dtype": torch.bfloat16} | ||||||||||||||||||||||||||||||||||||||||||
| for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 8192*2, 8192*4] for head in [128] | ||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||
| DEFAULT_BATCHES = [128] | ||||||||||||||||||||||||||||||||||||||||||
| DEFAULT_SEQLENS = [1024, 2048, 4096, 8192, 8192 * 2, 8192 * 4] | ||||||||||||||||||||||||||||||||||||||||||
| DEFAULT_HEADS = [128] | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def _parse_int_list(value): | ||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||
| return [int(item) for item in value.split(",") if item.strip()] | ||||||||||||||||||||||||||||||||||||||||||
| except ValueError as err: | ||||||||||||||||||||||||||||||||||||||||||
| raise argparse.ArgumentTypeError("expected comma-separated integers") from err | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def _parse_dtype(value): | ||||||||||||||||||||||||||||||||||||||||||
| if value == "bf16": | ||||||||||||||||||||||||||||||||||||||||||
| return torch.bfloat16 | ||||||||||||||||||||||||||||||||||||||||||
| if value == "fp16": | ||||||||||||||||||||||||||||||||||||||||||
| return torch.float16 | ||||||||||||||||||||||||||||||||||||||||||
| raise argparse.ArgumentTypeError("dtype must be bf16 or fp16") | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+500
to
+505
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 目前的
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def make_shape_configs(args): | ||||||||||||||||||||||||||||||||||||||||||
| return [ | ||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||
| "b": batch, | ||||||||||||||||||||||||||||||||||||||||||
| "s_q": args.s_q, | ||||||||||||||||||||||||||||||||||||||||||
| "cache_seqlens": torch.tensor( | ||||||||||||||||||||||||||||||||||||||||||
| [seqlen + args.varlen_step * i for i in range(batch)], | ||||||||||||||||||||||||||||||||||||||||||
| dtype=torch.int32, | ||||||||||||||||||||||||||||||||||||||||||
| device="cuda", | ||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||
| "h_q": head, | ||||||||||||||||||||||||||||||||||||||||||
| "h_kv": args.h_kv, | ||||||||||||||||||||||||||||||||||||||||||
| "d": args.d, | ||||||||||||||||||||||||||||||||||||||||||
| "dv": args.dv, | ||||||||||||||||||||||||||||||||||||||||||
| "causal": not args.non_causal, | ||||||||||||||||||||||||||||||||||||||||||
| "dtype": args.dtype, | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| for batch in args.batches | ||||||||||||||||||||||||||||||||||||||||||
| for seqlen in args.seqlens | ||||||||||||||||||||||||||||||||||||||||||
| for head in args.heads | ||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def get_args(): | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -497,24 +535,38 @@ def get_args(): | |||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--all", action="store_true") | ||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--one", action="store_true") | ||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--compare", action="store_true") | ||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--output", type=str, default=None, help="CSV output path.") | ||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--batches", type=_parse_int_list, default=DEFAULT_BATCHES) | ||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--seqlens", type=_parse_int_list, default=DEFAULT_SEQLENS) | ||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--heads", type=_parse_int_list, default=DEFAULT_HEADS) | ||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--s-q", type=int, default=1) | ||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--h-kv", type=int, default=1) | ||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--d", type=int, default=512 + 64) | ||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--dv", type=int, default=512) | ||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--dtype", type=_parse_dtype, default=torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--varlen-step", type=int, default=2) | ||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--non-causal", action="store_true") | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+539
to
+548
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 新增的命令行参数(如
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
| args = parser.parse_args() | ||||||||||||||||||||||||||||||||||||||||||
| return args | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||||||||||||||||||||||
| args = get_args() | ||||||||||||||||||||||||||||||||||||||||||
| benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
554
to
555
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如果用户在运行脚本时没有指定
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
| with open(f"{benchmark_type}_perf.csv", "w") as fout: | ||||||||||||||||||||||||||||||||||||||||||
| fout.write("name,batch,seqlen,head,bw\n") | ||||||||||||||||||||||||||||||||||||||||||
| output_path = args.output or f"{benchmark_type}_perf.csv" | ||||||||||||||||||||||||||||||||||||||||||
| shape_configs = make_shape_configs(args) | ||||||||||||||||||||||||||||||||||||||||||
| with open(output_path, "w", newline="") as fout: | ||||||||||||||||||||||||||||||||||||||||||
| writer = csv.writer(fout) | ||||||||||||||||||||||||||||||||||||||||||
| writer.writerow(["name", "batch", "seqlen", "head", "bw"]) | ||||||||||||||||||||||||||||||||||||||||||
| for shape in shape_configs: | ||||||||||||||||||||||||||||||||||||||||||
| if args.all: | ||||||||||||||||||||||||||||||||||||||||||
| for target in available_targets: | ||||||||||||||||||||||||||||||||||||||||||
| perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) | ||||||||||||||||||||||||||||||||||||||||||
| fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n') | ||||||||||||||||||||||||||||||||||||||||||
| writer.writerow([target, shape["b"], f'{shape["cache_seqlens"].float().mean().cpu().item():.0f}', shape["h_q"], f"{perf:.0f}"]) | ||||||||||||||||||||||||||||||||||||||||||
| elif args.compare: | ||||||||||||||||||||||||||||||||||||||||||
| perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) | ||||||||||||||||||||||||||||||||||||||||||
| fout.write(f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n') | ||||||||||||||||||||||||||||||||||||||||||
| fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n') | ||||||||||||||||||||||||||||||||||||||||||
| writer.writerow([args.baseline, shape["b"], f'{shape["cache_seqlens"].float().mean().cpu().item():.0f}', shape["h_q"], f"{perfa:.0f}"]) | ||||||||||||||||||||||||||||||||||||||||||
| writer.writerow([args.target, shape["b"], f'{shape["cache_seqlens"].float().mean().cpu().item():.0f}', shape["h_q"], f"{prefb:.0f}"]) | ||||||||||||||||||||||||||||||||||||||||||
| elif args.one: | ||||||||||||||||||||||||||||||||||||||||||
| perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) | ||||||||||||||||||||||||||||||||||||||||||
| fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n') | ||||||||||||||||||||||||||||||||||||||||||
| writer.writerow([args.target, shape["b"], f'{shape["cache_seqlens"].float().mean().cpu().item():.0f}', shape["h_q"], f"{perf:.0f}"]) | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果用户传入了空字符串(例如
--batches ""),_parse_int_list会返回一个空列表[]。这会导致后续的make_shape_configs生成空的配置列表,脚本将静默结束且不执行任何评测。建议在解析结果为空时抛出argparse.ArgumentTypeError异常,以提供更友好的错误提示。