Validate Python API input parameters #22

Open
ghangz wants to merge 1 commit into MetaX-MACA:main from ghangz:mengz/validate-python-api-inputs

ghangz commented Jun 8, 2026

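This PR adds input-parameter validation to the Python API. It reports clear errors for invalid dtype, shape, or device inputs, so that failures no longer surface as hard-to-trace errors from the underlying operators.

The change targets the parts of the MetaX GPU adaptation workflow that most often affect development, build, or validation stability. Problems that previously required manual troubleshooting are moved forward into the toolchain, pre-run checks, or benchmark scripts. The implementation stays compatible with the existing default behavior and only emits a more direct diagnostic when it detects a clear configuration, input, or environment anomaly. It adds no extra runtime dependencies, and maintainers can review this branch independently.

The branch has been validated in a MetaX compute environment. The validation record includes real run logs, command output, and failure-path checks, and is archived locally at E:/Documents/muxi/测试报告/FlashMLA_real_maca_validation_20260608. Source branch: mengz/validate-python-api-inputs. Target repository: MetaX-MACA/FlashMLA.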

gemini-code-assist (bot) left a comment

Code Review

This pull request introduces input validation for the flash_mla_with_kvcache function, checking dimensions, shapes, and data types of various input tensors, and adds corresponding unit tests using a mocked CUDA module. The reviewer suggested several improvements: validating that num_splits has a size of batch_size + 1 to prevent potential out-of-bounds memory access in CUDA, ensuring all input tensors are actual torch.Tensor instances on the same device, and adding a unit test to verify the num_splits size validation.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +36 to +39
    if q.shape[0] != block_table.shape[0] or q.shape[0] != cache_seqlens.shape[0]:
        raise ValueError(
            "batch size must match across q, block_table, and cache_seqlens"
        )

high

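According to the documentation of flash_mla_with_kvcache, num_splits should have shape (batch_size + 1,). If the caller passes a num_splits whose length does not match (for example, shorter than batch_size + 1), the underlying CUDA operator may perform an out-of-bounds memory access at run time, resulting in a CUDA error or an illegal memory access.

I suggest also validating here that the length of num_splits equals batch_size + 1.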

Suggested change
-    if q.shape[0] != block_table.shape[0] or q.shape[0] != cache_seqlens.shape[0]:
-        raise ValueError(
-            "batch size must match across q, block_table, and cache_seqlens"
-        )
+    if q.shape[0] != block_table.shape[0] or q.shape[0] != cache_seqlens.shape[0] or num_splits.shape[0] != q.shape[0] + 1:
+        raise ValueError(
+            f"batch size mismatch: q batch_size is {q.shape[0]}, but block_table has {block_table.shape[0]}, "
+            f"cache_seqlens has {cache_seqlens.shape[0]}, and num_splits must have size {q.shape[0] + 1}, got {num_splits.shape[0]}"
+        )
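
As a rough illustration of the check suggested above, the sketch below validates the batch dimension and the num_splits length as separate steps, so each error names the one argument that is wrong. It works on plain shape tuples rather than tensors so that it runs without PyTorch. The helper name check_batch_shapes and its parameters are hypothetical and are not part of FlashMLA.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def check_batch_shapes(
    q_shape: Shape,
    block_table_shape: Shape,
    cache_seqlens_shape: Shape,
    num_splits_shape: Shape,
) -> None:
    """Raise ValueError if the leading dimensions are inconsistent.

    Hypothetical helper: takes shape tuples (e.g. tuple(t.shape)) instead of
    tensors, and assumes every shape has at least one dimension.
    """
    batch_size = q_shape[0]

    # block_table and cache_seqlens must share q's batch dimension.
    for name, shape in (
        ("block_table", block_table_shape),
        ("cache_seqlens", cache_seqlens_shape),
    ):
        if shape[0] != batch_size:
            raise ValueError(
                f"batch size mismatch: q has {batch_size}, {name} has {shape[0]}"
            )

    # Per the review comment, num_splits is documented as (batch_size + 1,).
    expected = batch_size + 1
    if num_splits_shape != (expected,):
        raise ValueError(
            f"num_splits must have shape ({expected},), got {num_splits_shape}"
        )
```

Splitting the checks is only one possible design. The single combined condition in the suggestion is shorter, at the cost of an error message that lists every size even when only one is wrong.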

    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
) -> None:
    if q.dim() != 4:

medium

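The PR description mentions "clear errors for invalid dtype, shape, or device inputs", but the current implementation does not check device consistency. If the caller passes tensors that live on different devices (for example CPU and CUDA, or different GPUs), the underlying operator may fail at run time or exhibit undefined behavior. In addition, if an argument is not a torch.Tensor (for example a plain list or a NumPy array), accessing .dim() or .dtype directly raises a confusing AttributeError.

I suggest adding type and device-consistency checks for all input tensors at the start of the function.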

Suggested change
-    if q.dim() != 4:
+    for name, t in [("q", q), ("k_cache", k_cache), ("block_table", block_table), ("cache_seqlens", cache_seqlens), ("tile_scheduler_metadata", tile_scheduler_metadata), ("num_splits", num_splits)]:
+        if not isinstance(t, torch.Tensor):
+            raise TypeError(f"{name} must be a torch.Tensor, got {type(t)}")
+        if t.device != q.device:
+            raise ValueError(f"All tensors must be on the same device, but {name} is on {t.device} while q is on {q.device}")
+    if q.dim() != 4:
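
The type and device checks suggested above can be tried out without a GPU or PyTorch. The sketch below is a minimal stand-in under stated assumptions: FakeTensor is a hypothetical placeholder for torch.Tensor that only carries a device string, and check_same_device is a hypothetical helper name, not a FlashMLA function. It verifies the type of each argument before reading its device, and compares every device against the first argument in the list.

```python
from dataclasses import dataclass
from typing import Any, Iterable, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; only carries a device string."""

    device: str


def check_same_device(
    named: Iterable[Tuple[str, Any]],
    tensor_type: type = FakeTensor,
) -> None:
    """Raise TypeError for non-tensors and ValueError for mixed devices."""
    reference_name = None
    reference_device = None
    for name, t in named:
        # Check the type first so that .device is never read from a list etc.
        if not isinstance(t, tensor_type):
            raise TypeError(
                f"{name} must be a {tensor_type.__name__}, got {type(t).__name__}"
            )
        if reference_name is None:
            # The first argument defines the device all others must match.
            reference_name, reference_device = name, t.device
        elif t.device != reference_device:
            raise ValueError(
                f"all tensors must be on the same device, but {name} is on "
                f"{t.device} while {reference_name} is on {reference_device}"
            )
```

With PyTorch installed, the same helper could be called with tensor_type=torch.Tensor and real tensors, since torch.Tensor exposes a comparable .device attribute.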

Comment on lines +51 to +57
def test_flash_mla_rejects_invalid_value_head_dim():
    q, k_cache, block_table, cache_seqlens, metadata, num_splits = _valid_inputs()

    with pytest.raises(ValueError, match="head_dim_v"):
        flash_mla_with_kvcache(
            q, k_cache, block_table, cache_seqlens, 16, metadata, num_splits
        )

medium

To make sure the new num_splits length check works correctly, I suggest adding a corresponding unit test to the test file.

Suggested change
-def test_flash_mla_rejects_invalid_value_head_dim():
-    q, k_cache, block_table, cache_seqlens, metadata, num_splits = _valid_inputs()
-    with pytest.raises(ValueError, match="head_dim_v"):
-        flash_mla_with_kvcache(
-            q, k_cache, block_table, cache_seqlens, 16, metadata, num_splits
-        )
+def test_flash_mla_rejects_invalid_value_head_dim():
+    q, k_cache, block_table, cache_seqlens, metadata, num_splits = _valid_inputs()
+    with pytest.raises(ValueError, match="head_dim_v"):
+        flash_mla_with_kvcache(
+            q, k_cache, block_table, cache_seqlens, 16, metadata, num_splits
+        )
+
+
+def test_flash_mla_rejects_mismatched_num_splits():
+    q, k_cache, block_table, cache_seqlens, metadata, num_splits = _valid_inputs()
+    num_splits = torch.zeros(4, dtype=torch.int32)
+    with pytest.raises(ValueError, match="batch size mismatch"):
+        flash_mla_with_kvcache(
+            q, k_cache, block_table, cache_seqlens, 4, metadata, num_splits
+        )
