Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions tests/test_memory_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import sys
from argparse import Namespace
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "tools"))

from estimate_flash_mla_memory import estimate_bytes # noqa: E402


def test_memory_estimator_counts_k_cache_blocks():
args = Namespace(
dtype="bf16",
batch_size=2,
s_q=1,
mean_sk=17,
h_q=4,
h_kv=1,
d=8,
dv=4,
block_size=16,
)

estimates = estimate_bytes(args)

assert estimates["k_cache"] == 2 * 16 * 16 * 1 * 8 * 2
assert estimates["total"] >= estimates["k_cache"]
Comment on lines +25 to +26

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

建议在测试中增加对 out 估算值的断言,以确保其正确使用了 dtype_bytes(在 bf16 下为 2 字节),从而避免后续引入类似的计算错误。

Suggested change
assert estimates["k_cache"] == 2 * 16 * 16 * 1 * 8 * 2
assert estimates["total"] >= estimates["k_cache"]
assert estimates["k_cache"] == 2 * 16 * 16 * 1 * 8 * 2
assert estimates["out"] == 2 * 1 * 4 * 4 * 2
assert estimates["total"] >= estimates["k_cache"]



if __name__ == "__main__":
test_memory_estimator_counts_k_cache_blocks()
69 changes: 69 additions & 0 deletions tools/estimate_flash_mla_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/usr/bin/env python3
import argparse
import json
import math


DTYPE_BYTES = {
"bf16": 2,
"fp16": 2,
"fp32": 4,
}


def estimate_bytes(args: argparse.Namespace) -> dict[str, int]:
dtype_bytes = DTYPE_BYTES[args.dtype]
max_seqlen_pad = math.ceil(args.mean_sk / 256) * 256
num_blocks = args.batch_size * math.ceil(max_seqlen_pad / args.block_size)

q = args.batch_size * args.s_q * args.h_q * args.d * dtype_bytes
k_cache = num_blocks * args.block_size * args.h_kv * args.d * dtype_bytes
out = args.batch_size * args.s_q * args.h_q * args.dv * 4

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

estimate_bytes 函数中,输出张量 out 的显存大小计算硬编码了 4 字节(即 float32)。然而,注意力机制的输出张量 out 的数据类型通常与输入查询张量 q 保持一致(例如 bf16fp16,它们占用 2 字节)。使用硬编码的 4 会导致在估算 16 位精度(如 bf16/fp16)时的 out 显存偏大一倍。建议将其修改为使用 dtype_bytes

Suggested change
out = args.batch_size * args.s_q * args.h_q * args.dv * 4
out = args.batch_size * args.s_q * args.h_q * args.dv * dtype_bytes

lse = args.batch_size * args.h_q * args.s_q * 4
block_table = args.batch_size * math.ceil(max_seqlen_pad / args.block_size) * 4
cache_seqlens = args.batch_size * 4

total = q + k_cache + out + lse + block_table + cache_seqlens
return {
"q": q,
"k_cache": k_cache,
"out": out,
"lse": lse,
"block_table": block_table,
"cache_seqlens": cache_seqlens,
"total": total,
}


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Estimate FlashMLA test tensor memory.")
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--s-q", type=int, default=1)
parser.add_argument("--mean-sk", type=int, default=4096)
parser.add_argument("--h-q", type=int, default=16)
parser.add_argument("--h-kv", type=int, default=1)
parser.add_argument("--d", type=int, default=576)
parser.add_argument("--dv", type=int, default=512)
parser.add_argument("--block-size", type=int, default=16)
parser.add_argument("--dtype", choices=sorted(DTYPE_BYTES), default="bf16")
parser.add_argument("--json", action="store_true", help="Print JSON instead of text.")
return parser.parse_args()


def main() -> int:
args = parse_args()
estimates = estimate_bytes(args)
gib = estimates["total"] / 1024**3
if args.json:
payload = dict(estimates)
payload["total_gib"] = gib
print(json.dumps(payload, indent=2, sort_keys=True))
else:
for name, value in estimates.items():
print(f"{name}: {value / 1024**2:.2f} MiB")
print(f"total_gib: {gib:.3f}")
return 0


if __name__ == "__main__":
raise SystemExit(main())