增加 MLA 显存估算工具#13
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a tool (tools/estimate_flash_mla_memory.py) to estimate FlashMLA test tensor memory, along with a corresponding test suite (tests/test_memory_estimator.py). The reviewer identified that the output tensor (out) memory calculation is hardcoded to 4 bytes, which overestimates memory for 16-bit precision types like bf16 or fp16. It is recommended to use dtype_bytes for this calculation and to add a corresponding assertion in the unit tests to verify the output tensor's estimated size.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
|
|
||
| q = args.batch_size * args.s_q * args.h_q * args.d * dtype_bytes | ||
| k_cache = num_blocks * args.block_size * args.h_kv * args.d * dtype_bytes | ||
| out = args.batch_size * args.s_q * args.h_q * args.dv * 4 |
There was a problem hiding this comment.
在 estimate_bytes 函数中,输出张量 out 的显存大小计算硬编码了 4 字节(即 float32)。然而,注意力机制的输出张量 out 的数据类型通常与输入查询张量 q 保持一致(例如 bf16 或 fp16,它们占用 2 字节)。使用硬编码的 4 会导致在估算 16 位精度(如 bf16/fp16)时的 out 显存偏大一倍。建议将其修改为使用 dtype_bytes。
| out = args.batch_size * args.s_q * args.h_q * args.dv * 4 | |
| out = args.batch_size * args.s_q * args.h_q * args.dv * dtype_bytes |
| assert estimates["k_cache"] == 2 * 16 * 16 * 1 * 8 * 2 | ||
| assert estimates["total"] >= estimates["k_cache"] |
There was a problem hiding this comment.
建议在测试中增加对 out 估算值的断言,以确保其正确使用了 dtype_bytes(在 bf16 下为 2 字节),从而避免后续引入类似的计算错误。
| assert estimates["k_cache"] == 2 * 16 * 16 * 1 * 8 * 2 | |
| assert estimates["total"] >= estimates["k_cache"] | |
| assert estimates["k_cache"] == 2 * 16 * 16 * 1 * 8 * 2 | |
| assert estimates["out"] == 2 * 1 * 4 * 4 * 2 | |
| assert estimates["total"] >= estimates["k_cache"] |
该 PR 增加面向 MLA 典型输入形状的显存估算能力,帮助用户在沐曦显存资源有限的情况下提前选择可运行的 batch、head 和 sequence 配置。
这个修改面向沐曦 GPU 适配场景中比较容易影响开发、构建或验证稳定性的环节,把原来需要人工排查的问题前移到工具链、运行前检查或基准脚本中处理。实现上保持对现有默认行为的兼容,只在检测到明确配置、输入或环境异常时给出更直接的诊断,避免引入额外运行依赖,也方便维护者独立审阅该分支。
已在沐曦算力环境中完成对应分支验证,验证记录包含真实运行日志、命令输出和失败路径检查,本地归档目录为:E:/Documents/muxi/测试报告/FlashMLA_real_maca_validation_20260608。提交分支:
mengz/add-memory-estimator,目标仓库:MetaX-MACA/FlashMLA。