From cb037cf58cc1154bbcefa02ee5bc1bd013fb6187 Mon Sep 17 00:00:00 2001 From: papertager <2567587994@qq.com> Date: Fri, 5 Jun 2026 11:39:54 +0800 Subject: [PATCH] Create benchmark shapes lazily --- benchmark/bench_flash_mla.py | 28 ++++++++++++++--- tests/test_benchmark_shapes.py | 57 ++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 5 deletions(-) create mode 100644 tests/test_benchmark_shapes.py diff --git a/benchmark/bench_flash_mla.py b/benchmark/bench_flash_mla.py index 2b59f8ed..300087f7 100644 --- a/benchmark/bench_flash_mla.py +++ b/benchmark/bench_flash_mla.py @@ -484,10 +484,27 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): "flash_mla_triton", ] -shape_configs = [ - {"b": batch, "s_q": 1, "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), "h_q": head, "h_kv": 1, "d": 512+64, "dv": 512, "causal": True, "dtype": torch.bfloat16} - for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 8192*2, 8192*4] for head in [128] -] +def build_shape_configs(device="cuda"): + return [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor( + [seqlen + 2 * i for i in range(batch)], + dtype=torch.int32, + device=device, + ), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.bfloat16, + } + for batch in [128] + for seqlen in [1024, 2048, 4096, 8192, 8192 * 2, 8192 * 4] + for head in [128] + ] def get_args(): @@ -504,6 +521,7 @@ def get_args(): if __name__ == "__main__": args = get_args() benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target + shape_configs = build_shape_configs() with open(f"{benchmark_type}_perf.csv", "w") as fout: fout.write("name,batch,seqlen,head,bw\n") for shape in shape_configs: @@ -517,4 +535,4 @@ def get_args(): fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n') elif args.one: perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) - fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n') \ No newline at end of file + fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n') diff --git a/tests/test_benchmark_shapes.py b/tests/test_benchmark_shapes.py new file mode 100644 index 00000000..6a7a3667 --- /dev/null +++ b/tests/test_benchmark_shapes.py @@ -0,0 +1,57 @@ +import ast +import unittest +from pathlib import Path + + +BENCHMARK = Path(__file__).parents[1] / "benchmark" / "bench_flash_mla.py" + + +class BenchmarkShapeConfigTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.tree = ast.parse(BENCHMARK.read_text(encoding="utf-8")) + + def test_shape_configs_are_not_created_at_import_time(self): + top_level_assigns = [ + node + for node in self.tree.body + if isinstance(node, ast.Assign) + and any( + isinstance(target, ast.Name) and target.id == "shape_configs" + for target in node.targets + ) + ] + + self.assertEqual(top_level_assigns, []) + + def test_shape_builder_accepts_device_argument(self): + builder = next( + node + for node in self.tree.body + if isinstance(node, ast.FunctionDef) and node.name == "build_shape_configs" + ) + + self.assertEqual(builder.args.args[0].arg, "device") + self.assertEqual(builder.args.defaults[0].value, "cuda") + + def test_main_creates_shape_configs_lazily(self): + main_block = next( + node + for node in self.tree.body + if isinstance(node, ast.If) + and isinstance(node.test, ast.Compare) + and isinstance(node.test.left, ast.Name) + and node.test.left.id == "__name__" + ) + calls = [node for node in ast.walk(main_block) if isinstance(node, ast.Call)] + + self.assertTrue( + any( + isinstance(call.func, ast.Name) and call.func.id == "build_shape_configs" + for call in calls + ) + ) + + +if __name__ == "__main__": + unittest.main()