增强 Python 输入校验#23
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces Python-side input validation for the flash_mla interface, checking tensor dimensions, data types, and device consistency before calling the underlying C++ extension. It also adds a comprehensive test suite to verify these validation rules. The review feedback suggests adding data type validation for q and k_cache to ensure they are floating-point and match each other, improving test isolation by backing up and restoring sys.modules in the unit tests, and adding a test case for the new data type validation.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| _check_dtype("block_table", block_table, torch.int32) | ||
| _check_dtype("cache_seqlens", cache_seqlens, torch.int32) | ||
| _check_dtype("tile_scheduler_metadata", tile_scheduler_metadata, torch.int32) | ||
| _check_dtype("num_splits", num_splits, torch.int32) |
There was a problem hiding this comment.
建议增加对 q 和 k_cache 的数据类型(dtype)校验。由于 FlashMLA 核心算子要求输入的 Query 和 Key/Value 缓存具有相同的数据类型(通常为 bfloat16 或 float16),并且必须是浮点类型,在 Python 侧提前校验可以避免底层 C++ 算子因类型不匹配或非浮点类型而导致未定义行为或崩溃。
_check_dtype("block_table", block_table, torch.int32)
_check_dtype("cache_seqlens", cache_seqlens, torch.int32)
_check_dtype("tile_scheduler_metadata", tile_scheduler_metadata, torch.int32)
_check_dtype("num_splits", num_splits, torch.int32)
if not q.is_floating_point():
raise TypeError(f"q must be a floating point tensor, got {q.dtype}")
if q.dtype != k_cache.dtype:
raise TypeError(
f"q and k_cache must have the same dtype, got {q.dtype} and {k_cache.dtype}"
)| def setUp(self): | ||
| self.fake_extension = FakeFlashMla() | ||
| sys.modules["flash_mla_cuda"] = self.fake_extension | ||
| sys.modules.pop("flash_mla", None) | ||
| sys.modules.pop("flash_mla.flash_mla_interface", None) | ||
| self.interface = importlib.import_module("flash_mla.flash_mla_interface") | ||
|
|
||
| def tearDown(self): | ||
| sys.modules.pop("flash_mla", None) | ||
| sys.modules.pop("flash_mla.flash_mla_interface", None) | ||
| sys.modules.pop("flash_mla_cuda", None) |
There was a problem hiding this comment.
在 setUp 和 tearDown 中直接 pop 掉 sys.modules 中的模块可能会破坏测试套件中其他测试的隔离性。如果真实的 flash_mla_cuda 或 flash_mla 模块在运行此测试之前已经被加载,直接 pop 会导致后续需要真实模块的测试因找不到模块而失败。建议在 setUp 中备份这些模块,并在 tearDown 中进行恢复。
| def setUp(self): | |
| self.fake_extension = FakeFlashMla() | |
| sys.modules["flash_mla_cuda"] = self.fake_extension | |
| sys.modules.pop("flash_mla", None) | |
| sys.modules.pop("flash_mla.flash_mla_interface", None) | |
| self.interface = importlib.import_module("flash_mla.flash_mla_interface") | |
| def tearDown(self): | |
| sys.modules.pop("flash_mla", None) | |
| sys.modules.pop("flash_mla.flash_mla_interface", None) | |
| sys.modules.pop("flash_mla_cuda", None) | |
| def setUp(self): | |
| self.fake_extension = FakeFlashMla() | |
| self._saved_modules = {} | |
| for mod in ["flash_mla_cuda", "flash_mla", "flash_mla.flash_mla_interface"]: | |
| if mod in sys.modules: | |
| self._saved_modules[mod] = sys.modules[mod] | |
| del sys.modules[mod] | |
| sys.modules["flash_mla_cuda"] = self.fake_extension | |
| self.interface = importlib.import_module("flash_mla.flash_mla_interface") | |
| def tearDown(self): | |
| for mod in ["flash_mla", "flash_mla.flash_mla_interface", "flash_mla_cuda"]: | |
| sys.modules.pop(mod, None) | |
| for mod, val in self._saved_modules.items(): | |
| sys.modules[mod] = val |
| self.assertEqual(self.fake_extension.kvcache_calls, 0) | ||
|
|
There was a problem hiding this comment.
在添加了 q 和 k_cache 的 dtype 一致性校验后,建议在测试文件中补充对应的单元测试,以确保该校验逻辑正确生效且未来不会被意外破坏。
| self.assertEqual(self.fake_extension.kvcache_calls, 0) | |
| self.assertEqual(self.fake_extension.kvcache_calls, 0) | |
| def test_kvcache_rejects_dtype_mismatch_before_extension(self): | |
| q, k_cache, block_table, cache_seqlens, head_dim_v, metadata, num_splits = self._valid_kvcache_inputs() | |
| k_cache = k_cache.to(torch.float16) | |
| with self.assertRaisesRegex(TypeError, "same dtype"): | |
| self.interface.flash_mla_with_kvcache( | |
| q, k_cache, block_table, cache_seqlens, head_dim_v, metadata, num_splits | |
| ) | |
| self.assertEqual(self.fake_extension.kvcache_calls, 0) |
该 PR 补充 Python 侧输入检查,覆盖常见张量布局和参数边界问题,让 FlashMLA 的错误信息更贴近用户调用位置。
这个修改面向沐曦 GPU 适配场景中比较容易影响开发、构建或验证稳定性的环节,把原来需要人工排查的问题前移到工具链、运行前检查或基准脚本中处理。实现上保持对现有默认行为的兼容,只在检测到明确配置、输入或环境异常时给出更直接的诊断,避免引入额外运行依赖,也方便维护者独立审阅该分支。
已在沐曦算力环境中完成对应分支验证,验证记录包含真实运行日志、命令输出和失败路径检查,本地归档目录为:E:/Documents/muxi/测试报告/FlashMLA_real_maca_validation_20260608。提交分支:
mengz/validate-python-inputs,目标仓库:MetaX-MACA/FlashMLA。