Skip to content

增加 MLA 显存估算工具#13

Open
ghangz wants to merge 1 commit into
MetaX-MACA:mainfrom
ghangz:mengz/add-memory-estimator
Open

增加 MLA 显存估算工具#13
ghangz wants to merge 1 commit into
MetaX-MACA:mainfrom
ghangz:mengz/add-memory-estimator

Conversation

@ghangz

@ghangz ghangz commented Jun 8, 2026

Copy link
Copy Markdown

该 PR 增加面向 MLA 典型输入形状的显存估算能力,帮助用户在沐曦显存资源有限的情况下提前选择可运行的 batch、head 和 sequence 配置。

这个修改面向沐曦 GPU 适配场景中比较容易影响开发、构建或验证稳定性的环节,把原来需要人工排查的问题前移到工具链、运行前检查或基准脚本中处理。实现上保持对现有默认行为的兼容,只在检测到明确配置、输入或环境异常时给出更直接的诊断,避免引入额外运行依赖,也方便维护者独立审阅该分支。

已在沐曦算力环境中完成对应分支验证,验证记录包含真实运行日志、命令输出和失败路径检查,本地归档目录为:E:/Documents/muxi/测试报告/FlashMLA_real_maca_validation_20260608。提交分支:mengz/add-memory-estimator,目标仓库:MetaX-MACA/FlashMLA

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a tool (tools/estimate_flash_mla_memory.py) to estimate FlashMLA test tensor memory, along with a corresponding test suite (tests/test_memory_estimator.py). The reviewer identified that the output tensor (out) memory calculation is hardcoded to 4 bytes, which overestimates memory for 16-bit precision types like bf16 or fp16. It is recommended to use dtype_bytes for this calculation and to add a corresponding assertion in the unit tests to verify the output tensor's estimated size.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.


q = args.batch_size * args.s_q * args.h_q * args.d * dtype_bytes
k_cache = num_blocks * args.block_size * args.h_kv * args.d * dtype_bytes
out = args.batch_size * args.s_q * args.h_q * args.dv * 4

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

estimate_bytes 函数中,输出张量 out 的显存大小计算硬编码了 4 字节(即 float32)。然而,注意力机制的输出张量 out 的数据类型通常与输入查询张量 q 保持一致(例如 bf16fp16,它们占用 2 字节)。使用硬编码的 4 会导致在估算 16 位精度(如 bf16/fp16)时的 out 显存偏大一倍。建议将其修改为使用 dtype_bytes

Suggested change
out = args.batch_size * args.s_q * args.h_q * args.dv * 4
out = args.batch_size * args.s_q * args.h_q * args.dv * dtype_bytes

Comment on lines +25 to +26
assert estimates["k_cache"] == 2 * 16 * 16 * 1 * 8 * 2
assert estimates["total"] >= estimates["k_cache"]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

建议在测试中增加对 out 估算值的断言,以确保其正确使用了 dtype_bytes(在 bf16 下为 2 字节),从而避免后续引入类似的计算错误。

Suggested change
assert estimates["k_cache"] == 2 * 16 * 16 * 1 * 8 * 2
assert estimates["total"] >= estimates["k_cache"]
assert estimates["k_cache"] == 2 * 16 * 16 * 1 * 8 * 2
assert estimates["out"] == 2 * 1 * 4 * 4 * 2
assert estimates["total"] >= estimates["k_cache"]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant