Support configurable benchmark shapes #16
Code Review
This pull request refactors the benchmarking script benchmark/bench_flash_mla.py to support dynamic shape configurations and robust CSV output writing. It introduces several command-line arguments (such as --batches, --seqlens, --heads, and --dtype) along with custom parsing helpers, and replaces manual CSV string formatting with Python's standard csv module. The review feedback suggests several improvements to enhance robustness and usability: defaulting to --one mode if no mode is specified to avoid empty outputs, raising errors for empty list inputs, supporting case-insensitive and alias-based dtype parsing, and adding descriptive help messages to the new command-line arguments.
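The switch from hand-formatted strings to the standard `csv` module matters mostly for quoting and header handling. A minimal sketch follows, assuming hypothetical column names (`benchmark_type`, `batch`, `seqlen`, `heads`, `latency_ms`); the actual columns written by `bench_flash_mla.py` are not shown in this review.

```python
import csv
import io

# Hypothetical column names, for illustration only.
FIELDNAMES = ["benchmark_type", "batch", "seqlen", "heads", "latency_ms"]


def rows_to_csv(rows):
    """Serialize a list of dict rows to CSV text, header row first."""
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=FIELDNAMES, lineterminator="\n")
    writer.writeheader()
    writer.writerows(rows)
    return buffer.getvalue()


# A value containing a comma is quoted by the writer instead of
# silently shifting the remaining columns.
csv_text = rows_to_csv(
    [{"benchmark_type": "a,b", "batch": 128, "seqlen": 4096, "heads": 16, "latency_ms": 0.5}]
)
```

When writing to a real file rather than a string buffer, the `csv` documentation recommends opening the file with `newline=""` so that the writer controls line endings.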
    args = get_args()
    benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
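If the user runs the script without any of --all, --compare, or --one (for example, plain python bench_flash_mla.py), the script runs silently and produces a CSV file that contains only the header row. This is confusing for users. Consider enabling --one mode by default when no mode is specified, so that the default target benchmark runs.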
Suggested change:

    -args = get_args()
    -benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
    +args = get_args()
    +if not (args.all or args.compare or args.one):
    +    args.one = True
    +benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
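The fallback can be tried in isolation. The sketch below assumes the three mode flags are plain `store_true` options; how `get_args()` actually declares them is not shown in this review, so `build_parser` and `parse_with_default_mode` are hypothetical names.

```python
import argparse


def build_parser():
    """Hypothetical stand-in for the mode flags declared in get_args()."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--all", action="store_true")
    parser.add_argument("--compare", action="store_true")
    parser.add_argument("--one", action="store_true")
    return parser


def parse_with_default_mode(argv):
    """Parse argv and fall back to --one when no mode flag was given."""
    args = build_parser().parse_args(argv)
    if not (args.all or args.compare or args.one):
        args.one = True
    return args


no_flags = parse_with_default_mode([])
explicit_all = parse_with_default_mode(["--all"])
```

With no flags, `no_flags.one` is set, so the single-target benchmark runs; an explicit `--all` is left untouched.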
    def _parse_int_list(value):
        try:
            return [int(item) for item in value.split(",") if item.strip()]
        except ValueError as err:
            raise argparse.ArgumentTypeError("expected comma-separated integers") from err
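If the user passes an empty string (for example, --batches ""), _parse_int_list returns an empty list []. As a result, make_shape_configs produces an empty configuration list, and the script exits silently without running any benchmark. Consider raising argparse.ArgumentTypeError when the parsed result is empty, to give a friendlier error message.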
Suggested change:

    -def _parse_int_list(value):
    -    try:
    -        return [int(item) for item in value.split(",") if item.strip()]
    -    except ValueError as err:
    -        raise argparse.ArgumentTypeError("expected comma-separated integers") from err
    +def _parse_int_list(value):
    +    try:
    +        result = [int(item) for item in value.split(",") if item.strip()]
    +        if not result:
    +            raise argparse.ArgumentTypeError("expected at least one integer")
    +        return result
    +    except ValueError as err:
    +        raise argparse.ArgumentTypeError("expected comma-separated integers") from err
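The empty-input check can be exercised without running the benchmark. The sketch below is a standalone variant (named `parse_int_list` here to keep it distinct from the script's helper) that moves the check outside the `try` block. The behavior should be the same either way, because `argparse.ArgumentTypeError` derives from `Exception` rather than `ValueError`, so the `except ValueError` clause does not intercept it.

```python
import argparse


def parse_int_list(value):
    """Parse '1,2,3' into [1, 2, 3]; reject non-integers and empty input."""
    try:
        # Blank items (from '1,,2' or a trailing comma) are skipped;
        # int() itself tolerates surrounding whitespace.
        result = [int(item) for item in value.split(",") if item.strip()]
    except ValueError as err:
        raise argparse.ArgumentTypeError("expected comma-separated integers") from err
    if not result:
        raise argparse.ArgumentTypeError("expected at least one integer")
    return result


parsed = parse_int_list("64, 128,,256")
```

When a function like this is passed as `type=` to `add_argument`, argparse turns the `ArgumentTypeError` into a usage error and exits with a non-zero status instead of continuing with an empty list.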
    def _parse_dtype(value):
        if value == "bf16":
            return torch.bfloat16
        if value == "fp16":
            return torch.float16
        raise argparse.ArgumentTypeError("dtype must be bf16 or fp16")
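_parse_dtype currently accepts only exact lowercase matches for "bf16" and "fp16". Uppercase input (such as "BF16") or a full name (such as "bfloat16") is rejected with an error. Consider lowercasing the value and accepting common aliases, to make argument parsing more forgiving and easier to use.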
Suggested change:

    -def _parse_dtype(value):
    -    if value == "bf16":
    -        return torch.bfloat16
    -    if value == "fp16":
    -        return torch.float16
    -    raise argparse.ArgumentTypeError("dtype must be bf16 or fp16")
    +def _parse_dtype(value):
    +    val = value.lower()
    +    if val in ("bf16", "bfloat16"):
    +        return torch.bfloat16
    +    if val in ("fp16", "float16"):
    +        return torch.float16
    +    raise argparse.ArgumentTypeError("dtype must be bf16 or fp16")
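If more aliases are added later, a lookup table may be easier to maintain than a chain of `if` statements. The sketch below is one possible shape, not the script's code: `canonical_dtype_name` is a hypothetical helper that returns the canonical attribute name as a string so the example runs without PyTorch installed.

```python
import argparse

# Alias -> canonical torch attribute name. Only the four spellings
# mentioned in the review comment are listed.
_DTYPE_ALIASES = {
    "bf16": "bfloat16",
    "bfloat16": "bfloat16",
    "fp16": "float16",
    "float16": "float16",
}


def canonical_dtype_name(value):
    """Map a user-supplied dtype string to 'bfloat16' or 'float16'."""
    key = value.strip().lower()
    try:
        return _DTYPE_ALIASES[key]
    except KeyError:
        accepted = ", ".join(sorted(_DTYPE_ALIASES))
        raise argparse.ArgumentTypeError(f"dtype must be one of: {accepted}") from None


name = canonical_dtype_name("BF16")
```

In the benchmark script the result could then be resolved with `getattr(torch, name)`, which yields `torch.bfloat16` or `torch.float16`. Building the error message from the table also keeps it in sync with the accepted spellings.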
    parser.add_argument("--batches", type=_parse_int_list, default=DEFAULT_BATCHES)
    parser.add_argument("--seqlens", type=_parse_int_list, default=DEFAULT_SEQLENS)
    parser.add_argument("--heads", type=_parse_int_list, default=DEFAULT_HEADS)
    parser.add_argument("--s-q", type=int, default=1)
    parser.add_argument("--h-kv", type=int, default=1)
    parser.add_argument("--d", type=int, default=512 + 64)
    parser.add_argument("--dv", type=int, default=512)
    parser.add_argument("--dtype", type=_parse_dtype, default=torch.bfloat16)
    parser.add_argument("--varlen-step", type=int, default=2)
    parser.add_argument("--non-causal", action="store_true")
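The new command-line arguments (--batches, --seqlens, --heads, and so on) have no help text. To improve the script's maintainability and usability, consider adding descriptive help strings so that users can see each argument's meaning and default value via python bench_flash_mla.py --help.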
Suggested change:

    -parser.add_argument("--batches", type=_parse_int_list, default=DEFAULT_BATCHES)
    -parser.add_argument("--seqlens", type=_parse_int_list, default=DEFAULT_SEQLENS)
    -parser.add_argument("--heads", type=_parse_int_list, default=DEFAULT_HEADS)
    -parser.add_argument("--s-q", type=int, default=1)
    -parser.add_argument("--h-kv", type=int, default=1)
    -parser.add_argument("--d", type=int, default=512 + 64)
    -parser.add_argument("--dv", type=int, default=512)
    -parser.add_argument("--dtype", type=_parse_dtype, default=torch.bfloat16)
    -parser.add_argument("--varlen-step", type=int, default=2)
    -parser.add_argument("--non-causal", action="store_true")
    +parser.add_argument("--batches", type=_parse_int_list, default=DEFAULT_BATCHES, help="Comma-separated list of batch sizes.")
    +parser.add_argument("--seqlens", type=_parse_int_list, default=DEFAULT_SEQLENS, help="Comma-separated list of sequence lengths.")
    +parser.add_argument("--heads", type=_parse_int_list, default=DEFAULT_HEADS, help="Comma-separated list of head numbers.")
    +parser.add_argument("--s-q", type=int, default=1, help="Query sequence length.")
    +parser.add_argument("--h-kv", type=int, default=1, help="Number of KV heads.")
    +parser.add_argument("--d", type=int, default=512 + 64, help="Head dimension d.")
    +parser.add_argument("--dv", type=int, default=512, help="Head dimension dv.")
    +parser.add_argument("--dtype", type=_parse_dtype, default=torch.bfloat16, help="Data type (bf16 or fp16).")
    +parser.add_argument("--varlen-step", type=int, default=2, help="Step size for variable sequence lengths.")
    +parser.add_argument("--non-causal", action="store_true", help="Disable causal attention mask.")
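Note that a `help` string alone does not make argparse print the default value. One way to surface defaults in `--help`, sketched here on two of the arguments above, is the standard `argparse.ArgumentDefaultsHelpFormatter`; whether the script's parser can adopt it without other changes is an assumption.

```python
import argparse

# ArgumentDefaultsHelpFormatter appends "(default: ...)" to every
# argument that has a help string.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("--s-q", type=int, default=1, help="Query sequence length.")
parser.add_argument("--varlen-step", type=int, default=2,
                    help="Step size for variable sequence lengths.")

help_text = parser.format_help()
```

An alternative that leaves the formatter alone is to write `%(default)s` inside individual help strings, which argparse expands when rendering help.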
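This PR extends the benchmark's input shapes from a fixed set of combinations to configurable parameters, which makes it easier to run coverage tests across different GPU memory sizes and model scales.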
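To make "configurable shapes" concrete, the sketch below expands three comma-separated lists into one configuration per combination. It assumes a Cartesian product over `--batches`, `--seqlens`, and `--heads`; the real `make_shape_configs` is not shown on this page, so `expand_shapes` and its dictionary keys are hypothetical.

```python
from itertools import product


def expand_shapes(batches, seqlens, heads):
    """Return one config dict per (batch, seqlen, heads) combination."""
    return [
        {"batch": b, "seqlen": s, "heads": h}
        for b, s, h in product(batches, seqlens, heads)
    ]


configs = expand_shapes([64, 128], [4096], [16, 32])
```

Under this assumption, two batch sizes, one sequence length, and two head counts yield four configurations, and an empty list for any axis yields no configurations at all, which is the silent no-op case the review asks the argument parser to reject.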
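This change targets the parts of the MetaX (沐曦) GPU adaptation work that most easily affect development, build, or validation stability. It moves problems that previously required manual investigation earlier, into the toolchain, pre-run checks, or the benchmark script. The implementation stays compatible with the existing default behavior and only emits a more direct diagnostic when it detects a clearly invalid configuration, input, or environment. It avoids introducing additional runtime dependencies and makes the branch easier for maintainers to review independently.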
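The branch has been validated in a MetaX compute environment. The validation record includes real run logs, command output, and failure-path checks; the local archive directory is E:/Documents/muxi/测试报告/FlashMLA_real_maca_validation_20260608. Submitted branch: mengz/configurable-benchmark-shapes; target repository: MetaX-MACA/FlashMLA.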