From 6c759363658fa1e1464ee0a3573ad6177b22ba94 Mon Sep 17 00:00:00 2001
From: papertager <2567587994@qq.com>
Date: Fri, 5 Jun 2026 11:30:16 +0800
Subject: [PATCH] Add Python input validation for FlashMLA wrappers

Validate tensor ranks, index dtypes, batch/head/head_dim consistency and
device placement in get_mla_metadata and flash_mla_with_kvcache before
calling into the flash_mla_cuda extension, raising TypeError/ValueError
with the offending argument named.

The tests run against a fake extension. They restore any previously
imported flash_mla modules afterwards and assert the default softmax
scale that the wrapper forwards.
---
Reviewer notes (text in this section is not part of the commit message):
* Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch; confirm the context lines (including the blank ones)
  against the target file before applying.
* The index hashes below are carried over unchanged and were not recomputed.

 flash_mla/flash_mla_interface.py      | 119 ++++++++++++++++++++++
 tests/test_python_input_validation.py | 135 +++++++++++++++++++++++++
 2 files changed, 254 insertions(+)
 create mode 100644 tests/test_python_input_validation.py

diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py
index 616469ac..c7a4650c 100644
--- a/flash_mla/flash_mla_interface.py
+++ b/flash_mla/flash_mla_interface.py
@@ -7,6 +7,115 @@
 import flash_mla_cuda as flash_mla
 
 
+def _check_dim(name: str, tensor: torch.Tensor, dim: int) -> None:
+    """Raise ValueError unless ``tensor`` has exactly ``dim`` dimensions."""
+    if tensor.dim() != dim:
+        raise ValueError(f"{name} must be a {dim}D tensor, got {tensor.dim()}D")
+
+
+def _check_dtype(name: str, tensor: torch.Tensor, dtype: torch.dtype) -> None:
+    """Raise TypeError unless ``tensor.dtype`` equals ``dtype``."""
+    if tensor.dtype != dtype:
+        raise TypeError(f"{name} must use dtype {dtype}, got {tensor.dtype}")
+
+
+def _check_same_device(reference_name: str, reference: torch.Tensor, name: str, tensor: torch.Tensor) -> None:
+    """Raise ValueError unless ``tensor`` is on the same device as ``reference``."""
+    if tensor.device != reference.device:
+        raise ValueError(
+            f"{name} must be on the same device as {reference_name}: "
+            f"got {tensor.device} and {reference.device}"
+        )
+
+
+def _validate_metadata_inputs(
+    cache_seqlens: torch.Tensor,
+    num_heads_per_head_k: int,
+    num_heads_k: int,
+) -> None:
+    """Validate the arguments of ``get_mla_metadata`` before calling the extension."""
+    _check_dim("cache_seqlens", cache_seqlens, 1)
+    _check_dtype("cache_seqlens", cache_seqlens, torch.int32)
+    if num_heads_per_head_k <= 0:
+        raise ValueError("num_heads_per_head_k must be positive")
+    if num_heads_k <= 0:
+        raise ValueError("num_heads_k must be positive")
+
+
+def _validate_kvcache_inputs(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    block_table: torch.Tensor,
+    cache_seqlens: torch.Tensor,
+    head_dim_v: int,
+    tile_scheduler_metadata: torch.Tensor,
+    num_splits: torch.Tensor,
+) -> None:
+    """Validate the arguments of ``flash_mla_with_kvcache`` before calling the extension.
+
+    Ranks and index dtypes are checked first so that unpacking ``q.shape`` and
+    ``k_cache.shape`` below cannot fail; device placement is checked last.
+    """
+    _check_dim("q", q, 4)
+    _check_dim("k_cache", k_cache, 4)
+    _check_dim("block_table", block_table, 2)
+    _check_dim("cache_seqlens", cache_seqlens, 1)
+    _check_dim("tile_scheduler_metadata", tile_scheduler_metadata, 2)
+    _check_dim("num_splits", num_splits, 1)
+
+    _check_dtype("block_table", block_table, torch.int32)
+    _check_dtype("cache_seqlens", cache_seqlens, torch.int32)
+    _check_dtype("tile_scheduler_metadata", tile_scheduler_metadata, torch.int32)
+    _check_dtype("num_splits", num_splits, torch.int32)
+
+    batch_size, _, num_heads_q, head_dim = q.shape
+    _, _, num_heads_k, cache_head_dim = k_cache.shape
+
+    if batch_size != block_table.shape[0]:
+        raise ValueError(
+            "block_table batch dimension must match q: "
+            f"got {block_table.shape[0]} and {batch_size}"
+        )
+    if batch_size != cache_seqlens.shape[0]:
+        raise ValueError(
+            "cache_seqlens length must match q batch dimension: "
+            f"got {cache_seqlens.shape[0]} and {batch_size}"
+        )
+    if num_splits.shape[0] != batch_size + 1:
+        raise ValueError(
+            "num_splits length must be batch_size + 1: "
+            f"got {num_splits.shape[0]} and {batch_size + 1}"
+        )
+    if num_heads_k <= 0:
+        raise ValueError("k_cache must contain at least one KV head")
+    if num_heads_q % num_heads_k != 0:
+        raise ValueError(
+            "q num_heads must be divisible by k_cache num_heads: "
+            f"got {num_heads_q} and {num_heads_k}"
+        )
+    if head_dim != cache_head_dim:
+        raise ValueError(
+            "q head_dim must match k_cache head_dim: "
+            f"got {head_dim} and {cache_head_dim}"
+        )
+    if head_dim_v <= 0:
+        raise ValueError("head_dim_v must be positive")
+    if head_dim_v > cache_head_dim:
+        raise ValueError(
+            "head_dim_v must not exceed k_cache head_dim: "
+            f"got {head_dim_v} and {cache_head_dim}"
+        )
+
+    for name, tensor in (
+        ("k_cache", k_cache),
+        ("block_table", block_table),
+        ("cache_seqlens", cache_seqlens),
+        ("tile_scheduler_metadata", tile_scheduler_metadata),
+        ("num_splits", num_splits),
+    ):
+        _check_same_device("q", q, name, tensor)
+
+
 def get_mla_metadata(
     cache_seqlens: torch.Tensor,
     num_heads_per_head_k: int,
@@ -22,6 +131,7 @@ def get_mla_metadata(
         tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
         num_splits: (batch_size + 1), dtype torch.int32.
     """
+    _validate_metadata_inputs(cache_seqlens, num_heads_per_head_k, num_heads_k)
     return flash_mla.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
 
 
@@ -52,6 +162,15 @@ def flash_mla_with_kvcache(
         out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
         softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
     """
+    _validate_kvcache_inputs(
+        q,
+        k_cache,
+        block_table,
+        cache_seqlens,
+        head_dim_v,
+        tile_scheduler_metadata,
+        num_splits,
+    )
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
     out, softmax_lse = flash_mla.fwd_kvcache_mla(
diff --git a/tests/test_python_input_validation.py b/tests/test_python_input_validation.py
new file mode 100644
index 00000000..28b2de1d
--- /dev/null
+++ b/tests/test_python_input_validation.py
@@ -0,0 +1,135 @@
+import importlib
+import sys
+import types
+import unittest
+
+import torch
+
+
+# Modules swapped out while a test runs; the originals are restored afterwards.
+_STUBBED_MODULES = ("flash_mla", "flash_mla.flash_mla_interface", "flash_mla_cuda")
+
+
+class FakeFlashMla(types.SimpleNamespace):
+    """Stand-in for the compiled ``flash_mla_cuda`` extension.
+
+    Counts calls and records the last ``softmax_scale`` so the tests can check
+    what the Python wrappers forward to the extension.
+    """
+
+    def __init__(self):
+        super().__init__(metadata_calls=0, kvcache_calls=0, last_softmax_scale=None)
+
+    def get_mla_metadata(self, cache_seqlens, num_heads_per_head_k, num_heads_k):
+        self.metadata_calls += 1
+        metadata = torch.empty((1, 16), dtype=torch.int32, device=cache_seqlens.device)
+        num_splits = torch.empty((cache_seqlens.shape[0] + 1,), dtype=torch.int32, device=cache_seqlens.device)
+        return metadata, num_splits
+
+    def fwd_kvcache_mla(
+        self,
+        q,
+        k_cache,
+        _v_cache,
+        head_dim_v,
+        cache_seqlens,
+        block_table,
+        softmax_scale,
+        causal,
+        tile_scheduler_metadata,
+        num_splits,
+    ):
+        self.kvcache_calls += 1
+        self.last_softmax_scale = softmax_scale
+        out = torch.empty((*q.shape[:-1], head_dim_v), dtype=q.dtype, device=q.device)
+        lse = torch.empty((q.shape[0], q.shape[2], q.shape[1]), dtype=torch.float32, device=q.device)
+        return out, lse
+
+
+class PythonInputValidationTest(unittest.TestCase):
+    """Checks that invalid inputs are rejected before the extension is called."""
+
+    def setUp(self):
+        self.fake_extension = FakeFlashMla()
+        # Remember whatever was imported before so cleanup can put it back;
+        # addCleanup also runs when the import below fails inside setUp.
+        saved = {name: sys.modules.pop(name, None) for name in _STUBBED_MODULES}
+        self.addCleanup(self._restore_modules, saved)
+        sys.modules["flash_mla_cuda"] = self.fake_extension
+        self.interface = importlib.import_module("flash_mla.flash_mla_interface")
+
+    @staticmethod
+    def _restore_modules(saved):
+        for name, module in saved.items():
+            if module is None:
+                sys.modules.pop(name, None)
+            else:
+                sys.modules[name] = module
+
+    def _valid_kvcache_inputs(self):
+        batch_size = 2
+        q = torch.randn(batch_size, 1, 4, 8)
+        k_cache = torch.randn(3, 16, 2, 8)
+        block_table = torch.zeros((batch_size, 1), dtype=torch.int32)
+        cache_seqlens = torch.full((batch_size,), 8, dtype=torch.int32)
+        tile_scheduler_metadata = torch.zeros((1, 16), dtype=torch.int32)
+        num_splits = torch.zeros((batch_size + 1,), dtype=torch.int32)
+        return q, k_cache, block_table, cache_seqlens, 4, tile_scheduler_metadata, num_splits
+
+    def test_metadata_rejects_wrong_cache_seqlens_dtype_before_extension(self):
+        cache_seqlens = torch.ones((2,), dtype=torch.int64)
+
+        with self.assertRaisesRegex(TypeError, "cache_seqlens"):
+            self.interface.get_mla_metadata(cache_seqlens, 4, 2)
+
+        self.assertEqual(self.fake_extension.metadata_calls, 0)
+
+    def test_metadata_rejects_non_positive_head_counts_before_extension(self):
+        cache_seqlens = torch.ones((2,), dtype=torch.int32)
+
+        with self.assertRaisesRegex(ValueError, "num_heads_per_head_k"):
+            self.interface.get_mla_metadata(cache_seqlens, 0, 2)
+
+        with self.assertRaisesRegex(ValueError, "num_heads_k"):
+            self.interface.get_mla_metadata(cache_seqlens, 4, 0)
+
+        self.assertEqual(self.fake_extension.metadata_calls, 0)
+
+    def test_kvcache_rejects_incompatible_heads_before_extension(self):
+        q, k_cache, block_table, cache_seqlens, head_dim_v, metadata, num_splits = self._valid_kvcache_inputs()
+        q = torch.randn(q.shape[0], q.shape[1], 3, q.shape[3])
+
+        with self.assertRaisesRegex(ValueError, "divisible"):
+            self.interface.flash_mla_with_kvcache(
+                q, k_cache, block_table, cache_seqlens, head_dim_v, metadata, num_splits
+            )
+
+        self.assertEqual(self.fake_extension.kvcache_calls, 0)
+
+    def test_kvcache_rejects_bad_num_splits_length_before_extension(self):
+        q, k_cache, block_table, cache_seqlens, head_dim_v, metadata, _num_splits = self._valid_kvcache_inputs()
+        num_splits = torch.zeros((q.shape[0],), dtype=torch.int32)
+
+        with self.assertRaisesRegex(ValueError, r"batch_size \+ 1"):
+            self.interface.flash_mla_with_kvcache(
+                q, k_cache, block_table, cache_seqlens, head_dim_v, metadata, num_splits
+            )
+
+        self.assertEqual(self.fake_extension.kvcache_calls, 0)
+
+    def test_kvcache_accepts_valid_inputs_and_uses_default_scale(self):
+        q, k_cache, block_table, cache_seqlens, head_dim_v, metadata, num_splits = self._valid_kvcache_inputs()
+
+        out, lse = self.interface.flash_mla_with_kvcache(
+            q, k_cache, block_table, cache_seqlens, head_dim_v, metadata, num_splits
+        )
+
+        self.assertEqual(out.shape, (2, 1, 4, head_dim_v))
+        self.assertEqual(lse.shape, (2, 4, 1))
+        self.assertEqual(self.fake_extension.kvcache_calls, 1)
+        # softmax_scale was not passed, so the wrapper must supply head_dim ** -0.5.
+        self.assertAlmostEqual(self.fake_extension.last_softmax_scale, q.shape[-1] ** -0.5)
+
+
+if __name__ == "__main__":
+    unittest.main()