Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions benchmark/bench_flash_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,10 +484,27 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
"flash_mla_triton",
]

shape_configs = [
{"b": batch, "s_q": 1, "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), "h_q": head, "h_kv": 1, "d": 512+64, "dv": 512, "causal": True, "dtype": torch.bfloat16}
for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 8192*2, 8192*4] for head in [128]
]
def build_shape_configs(
    device="cuda",
    *,
    batches=(128,),
    seqlens=(1024, 2048, 4096, 8192, 8192 * 2, 8192 * 4),
    heads=(128,),
):
    """Build the list of shape configurations swept by the benchmark.

    The ``cache_seqlens`` tensors are allocated only when this function is
    called, so importing the module does not create tensors on ``device``.

    Args:
        device: Device on which each ``cache_seqlens`` tensor is created.
        batches: Batch sizes to sweep (keyword-only).
        seqlens: Base cache sequence lengths to sweep (keyword-only).
        heads: Values used for ``h_q`` in the sweep (keyword-only).

    Returns:
        A list with one dict per ``(batch, seqlen, head)`` combination,
        ordered with ``batch`` as the outermost loop and ``head`` as the
        innermost. Each dict has the keys ``b``, ``s_q``, ``cache_seqlens``,
        ``h_q``, ``h_kv``, ``d``, ``dv``, ``causal`` and ``dtype``.
    """
    return [
        {
            "b": batch,
            "s_q": 1,
            # Entry i is 2 * i longer than the base length, so the lengths
            # within one batch are not all identical.
            "cache_seqlens": torch.tensor(
                [seqlen + 2 * i for i in range(batch)],
                dtype=torch.int32,
                device=device,
            ),
            "h_q": head,
            "h_kv": 1,
            "d": 512 + 64,
            "dv": 512,
            "causal": True,
            "dtype": torch.bfloat16,
        }
        for batch in batches
        for seqlen in seqlens
        for head in heads
    ]


def get_args():
Expand All @@ -504,6 +521,7 @@ def get_args():
if __name__ == "__main__":
args = get_args()
benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
shape_configs = build_shape_configs()
with open(f"{benchmark_type}_perf.csv", "w") as fout:
fout.write("name,batch,seqlen,head,bw\n")
for shape in shape_configs:
Expand All @@ -517,4 +535,4 @@ def get_args():
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n')
elif args.one:
perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
57 changes: 57 additions & 0 deletions tests/test_benchmark_shapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import ast
import unittest
from pathlib import Path


# Location of the benchmark script inspected by these tests, resolved from
# the directory one level above the one containing this file.
BENCHMARK = Path(__file__).parents[1] / "benchmark" / "bench_flash_mla.py"


class BenchmarkShapeConfigTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
    """Parse the benchmark script once and share its AST with every test."""
    source = BENCHMARK.read_text(encoding="utf-8")
    cls.tree = ast.parse(source)

def test_shape_configs_are_not_created_at_import_time(self):
    """Ensure no module-level statement assigns to ``shape_configs``."""

    def binds_shape_configs(stmt):
        # Only plain assignments are inspected; a statement matches when
        # any of its targets is the bare name ``shape_configs``.
        if not isinstance(stmt, ast.Assign):
            return False
        for target in stmt.targets:
            if isinstance(target, ast.Name) and target.id == "shape_configs":
                return True
        return False

    offenders = list(filter(binds_shape_configs, self.tree.body))

    self.assertEqual(offenders, [])

def test_shape_builder_accepts_device_argument(self):
    """Check ``build_shape_configs`` takes ``device`` first, defaulting to "cuda"."""
    builder = next(
        (
            node
            for node in self.tree.body
            if isinstance(node, ast.FunctionDef)
            and node.name == "build_shape_configs"
        ),
        None,
    )
    # Fail with a readable assertion instead of a bare StopIteration error.
    self.assertIsNotNone(
        builder, "build_shape_configs is not defined at module level"
    )

    positional = builder.args.args
    self.assertTrue(
        positional, "build_shape_configs declares no positional parameters"
    )
    self.assertEqual(positional[0].arg, "device")

    defaults = builder.args.defaults
    # ``defaults`` lines up with the trailing positional parameters, so
    # ``defaults[0]`` belongs to ``device`` only when every positional
    # parameter has a default.
    self.assertEqual(
        len(defaults),
        len(positional),
        "the device parameter must have a default value",
    )
    # literal_eval evaluates the literal node itself rather than relying on
    # the ``.value`` attribute of a specific node class.
    self.assertEqual(ast.literal_eval(defaults[0]), "cuda")

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

在测试中,直接使用 builder.args.defaults[0].value 来获取默认值虽然在 Python 3.8+ 中可行,但使用 ast.literal_eval 是一种更通用、更安全且兼容性更好的方式来解析 AST 中的字面量值。建议将其替换为 ast.literal_eval(builder.args.defaults[0])

Suggested change
self.assertEqual(builder.args.defaults[0].value, "cuda")
self.assertEqual(ast.literal_eval(builder.args.defaults[0]), "cuda")


def test_main_creates_shape_configs_lazily(self):
    """Check the ``__name__`` guard block calls ``build_shape_configs``."""
    main_block = next(
        (
            node
            for node in self.tree.body
            if isinstance(node, ast.If)
            and isinstance(node.test, ast.Compare)
            and isinstance(node.test.left, ast.Name)
            and node.test.left.id == "__name__"
        ),
        None,
    )
    # Fail with a readable assertion instead of a bare StopIteration error.
    self.assertIsNotNone(
        main_block, "no top-level `if __name__ ...` block was found"
    )

    # Walk the whole guard block so calls nested in loops or ``with``
    # statements are found too.
    calls_builder = any(
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == "build_shape_configs"
        for node in ast.walk(main_block)
    )

    self.assertTrue(
        calls_builder,
        "build_shape_configs() is not called inside the __name__ guard block",
    )


# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
unittest.main()