校验 Python API 输入参数#22
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces input validation for the flash_mla_with_kvcache function, checking dimensions, shapes, and data types of various input tensors, and adds corresponding unit tests using a mocked CUDA module. The reviewer suggested several improvements: validating that num_splits has a size of batch_size + 1 to prevent potential out-of-bounds memory access in CUDA, ensuring all input tensors are actual torch.Tensor instances on the same device, and adding a unit test to verify the num_splits size validation.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if q.shape[0] != block_table.shape[0] or q.shape[0] != cache_seqlens.shape[0]: | ||
| raise ValueError( | ||
| "batch size must match across q, block_table, and cache_seqlens" | ||
| ) |
There was a problem hiding this comment.
根据 flash_mla_with_kvcache 的文档,num_splits 的形状应为 (batch_size + 1,)。如果用户传入的 num_splits 长度不匹配(例如小于 batch_size + 1),底层 CUDA 算子在执行时可能会发生越界内存访问(Out-of-bounds access),导致 CUDA 报错或非法内存访问。
建议在此处同时校验 num_splits 的长度是否等于 batch_size + 1。
| if q.shape[0] != block_table.shape[0] or q.shape[0] != cache_seqlens.shape[0]: | |
| raise ValueError( | |
| "batch size must match across q, block_table, and cache_seqlens" | |
| ) | |
| if q.shape[0] != block_table.shape[0] or q.shape[0] != cache_seqlens.shape[0] or num_splits.shape[0] != q.shape[0] + 1: | |
| raise ValueError( | |
| f"batch size mismatch: q batch_size is {q.shape[0]}, but block_table has {block_table.shape[0]}, " | |
| f"cache_seqlens has {cache_seqlens.shape[0]}, and num_splits must have size {q.shape[0] + 1}, got {num_splits.shape[0]}" | |
| ) |
| tile_scheduler_metadata: torch.Tensor, | ||
| num_splits: torch.Tensor, | ||
| ) -> None: | ||
| if q.dim() != 4: |
There was a problem hiding this comment.
PR 描述中提到“对非法 dtype、shape 或设备输入给出明确错误”,但目前的实现中并没有对设备(Device)一致性进行校验。如果用户传入了位于不同设备(例如 CPU 和 CUDA,或者不同的 GPU 卡)上的 Tensor,可能会导致底层算子运行时报错或产生未定义行为。此外,如果传入的参数不是 torch.Tensor 类型(例如普通的 list 或 numpy 数组),直接访问 .dim() 或 .dtype 会抛出混淆的 AttributeError。
建议在函数开头增加对所有输入 Tensor 的类型和设备一致性校验。
| if q.dim() != 4: | |
| for name, t in [("q", q), ("k_cache", k_cache), ("block_table", block_table), ("cache_seqlens", cache_seqlens), ("tile_scheduler_metadata", tile_scheduler_metadata), ("num_splits", num_splits)]: | |
| if not isinstance(t, torch.Tensor): | |
| raise TypeError(f"{name} must be a torch.Tensor, got {type(t)}") | |
| if t.device != q.device: | |
| raise ValueError(f"All tensors must be on the same device, but {name} is on {t.device} while q is on {q.device}") | |
| if q.dim() != 4: |
| def test_flash_mla_rejects_invalid_value_head_dim(): | ||
| q, k_cache, block_table, cache_seqlens, metadata, num_splits = _valid_inputs() | ||
|
|
||
| with pytest.raises(ValueError, match="head_dim_v"): | ||
| flash_mla_with_kvcache( | ||
| q, k_cache, block_table, cache_seqlens, 16, metadata, num_splits | ||
| ) |
There was a problem hiding this comment.
为了确保新增的 num_splits 长度校验逻辑正确工作,建议在测试文件中增加对应的单元测试。
| def test_flash_mla_rejects_invalid_value_head_dim(): | |
| q, k_cache, block_table, cache_seqlens, metadata, num_splits = _valid_inputs() | |
| with pytest.raises(ValueError, match="head_dim_v"): | |
| flash_mla_with_kvcache( | |
| q, k_cache, block_table, cache_seqlens, 16, metadata, num_splits | |
| ) | |
| def test_flash_mla_rejects_invalid_value_head_dim(): | |
| q, k_cache, block_table, cache_seqlens, metadata, num_splits = _valid_inputs() | |
| with pytest.raises(ValueError, match="head_dim_v"): | |
| flash_mla_with_kvcache( | |
| q, k_cache, block_table, cache_seqlens, 16, metadata, num_splits | |
| ) | |
| def test_flash_mla_rejects_mismatched_num_splits(): | |
| q, k_cache, block_table, cache_seqlens, metadata, num_splits = _valid_inputs() | |
| num_splits = torch.zeros(4, dtype=torch.int32) | |
| with pytest.raises(ValueError, match="batch size mismatch"): | |
| flash_mla_with_kvcache( | |
| q, k_cache, block_table, cache_seqlens, 4, metadata, num_splits | |
| ) |
该 PR 为 Python API 增加输入参数校验,对非法 dtype、shape 或设备输入给出明确错误,避免底层算子报错难以定位。
这个修改面向沐曦 GPU 适配场景中比较容易影响开发、构建或验证稳定性的环节,把原来需要人工排查的问题前移到工具链、运行前检查或基准脚本中处理。实现上保持对现有默认行为的兼容,只在检测到明确配置、输入或环境异常时给出更直接的诊断,避免引入额外运行依赖,也方便维护者独立审阅该分支。
已在沐曦算力环境中完成对应分支验证,验证记录包含真实运行日志、命令输出和失败路径检查,本地归档目录为:E:/Documents/muxi/测试报告/FlashMLA_real_maca_validation_20260608。提交分支:
mengz/validate-python-api-inputs,目标仓库:MetaX-MACA/FlashMLA。