From 9adc1ea3d66014ca7a0082ebf1ab94f9e663f255 Mon Sep 17 00:00:00 2001
From: papertager <2567587994@qq.com>
Date: Sun, 7 Jun 2026 16:32:21 +0800
Subject: [PATCH] python: validate FlashMLA wrapper inputs

---
Notes (review):

  NOTE(review): the line breaks and indentation in this patch were
  reconstructed from a copy whose whitespace had been collapsed. Blank
  added lines were still marked by lone "+" signs and the added-line
  counts match the hunk headers (44 + 9 = 53 and 57). Blank context
  lines and the indentation of context lines are inferred - confirm
  against the original commit before applying.

  What the patch does:
  * Adds _check_int32_tensor(name, tensor), which raises TypeError
    unless tensor.dtype is torch.int32.
  * Adds _validate_flash_mla_inputs(...), which raises ValueError when
    q or k_cache is not 4D, block_table is not 2D, cache_seqlens or
    num_splits is not 1D, dim 0 of q differs from dim 0 of block_table
    or cache_seqlens, the last dims of q and k_cache differ, or
    head_dim_v is outside (0, k_cache.shape[-1]]. It then applies the
    int32 check to block_table, cache_seqlens, tile_scheduler_metadata
    and num_splits, in that order.
  * flash_mla_with_kvcache now calls the validator before it defaults
    softmax_scale and before it calls flash_mla.fwd_kvcache_mla.
  * The new test module registers a SimpleNamespace-based stand-in as
    sys.modules["flash_mla_cuda"] with setdefault, so an entry that is
    already present in sys.modules is left untouched. It then checks
    three rejection paths: mismatched batch size, int64 cache_seqlens,
    and head_dim_v=16 against a k_cache whose last dim is 8.

  Points to confirm with the author:
  * tile_scheduler_metadata is only dtype-checked; its rank is not
    validated.
  * The length of num_splits is not compared with the batch size.
  * Every shape check runs before any dtype check, so an input that is
    wrong in both ways reports the ValueError, not the TypeError.
  * The tests cover rejections only; none asserts that valid inputs
    reach fwd_kvcache_mla.

 flash_mla/flash_mla_interface.py             | 53 ++++++++++++++++++
 tests/test_flash_mla_interface_validation.py | 57 ++++++++++++++++++++
 2 files changed, 110 insertions(+)
 create mode 100644 tests/test_flash_mla_interface_validation.py

diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py
index 616469ac..af458a4b 100644
--- a/flash_mla/flash_mla_interface.py
+++ b/flash_mla/flash_mla_interface.py
@@ -7,6 +7,50 @@
 import flash_mla_cuda as flash_mla
 
 
+def _check_int32_tensor(name: str, tensor: torch.Tensor) -> None:
+    if tensor.dtype != torch.int32:
+        raise TypeError(f"{name} must use torch.int32, got {tensor.dtype}")
+
+
+def _validate_flash_mla_inputs(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    block_table: torch.Tensor,
+    cache_seqlens: torch.Tensor,
+    head_dim_v: int,
+    tile_scheduler_metadata: torch.Tensor,
+    num_splits: torch.Tensor,
+) -> None:
+    if q.dim() != 4:
+        raise ValueError(f"q must be 4D, got shape {tuple(q.shape)}")
+    if k_cache.dim() != 4:
+        raise ValueError(f"k_cache must be 4D, got shape {tuple(k_cache.shape)}")
+    if block_table.dim() != 2:
+        raise ValueError(f"block_table must be 2D, got shape {tuple(block_table.shape)}")
+    if cache_seqlens.dim() != 1:
+        raise ValueError(
+            f"cache_seqlens must be 1D, got shape {tuple(cache_seqlens.shape)}"
+        )
+    if num_splits.dim() != 1:
+        raise ValueError(f"num_splits must be 1D, got shape {tuple(num_splits.shape)}")
+    if q.shape[0] != block_table.shape[0] or q.shape[0] != cache_seqlens.shape[0]:
+        raise ValueError(
+            "batch size must match across q, block_table, and cache_seqlens"
+        )
+    if q.shape[-1] != k_cache.shape[-1]:
+        raise ValueError(
+            f"q head_dim ({q.shape[-1]}) must match k_cache head_dim ({k_cache.shape[-1]})"
+        )
+    if head_dim_v <= 0 or head_dim_v > k_cache.shape[-1]:
+        raise ValueError(
+            f"head_dim_v must be in (0, {k_cache.shape[-1]}], got {head_dim_v}"
+        )
+    _check_int32_tensor("block_table", block_table)
+    _check_int32_tensor("cache_seqlens", cache_seqlens)
+    _check_int32_tensor("tile_scheduler_metadata", tile_scheduler_metadata)
+    _check_int32_tensor("num_splits", num_splits)
+
+
 def get_mla_metadata(
     cache_seqlens: torch.Tensor,
     num_heads_per_head_k: int,
@@ -52,6 +96,15 @@ def flash_mla_with_kvcache(
         out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
         softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
     """
+    _validate_flash_mla_inputs(
+        q,
+        k_cache,
+        block_table,
+        cache_seqlens,
+        head_dim_v,
+        tile_scheduler_metadata,
+        num_splits,
+    )
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
     out, softmax_lse = flash_mla.fwd_kvcache_mla(
diff --git a/tests/test_flash_mla_interface_validation.py b/tests/test_flash_mla_interface_validation.py
new file mode 100644
index 00000000..90553118
--- /dev/null
+++ b/tests/test_flash_mla_interface_validation.py
@@ -0,0 +1,57 @@
+import sys
+import types
+
+import pytest
+import torch
+
+
+class _FakeFlashMla(types.SimpleNamespace):
+    def get_mla_metadata(self, *args, **kwargs):
+        return None
+
+    def fwd_kvcache_mla(self, *args, **kwargs):
+        return torch.empty(1), torch.empty(1)
+
+
+sys.modules.setdefault("flash_mla_cuda", _FakeFlashMla())
+
+from flash_mla.flash_mla_interface import flash_mla_with_kvcache  # noqa: E402
+
+
+def _valid_inputs():
+    q = torch.empty(2, 1, 4, 8)
+    k_cache = torch.empty(8, 16, 1, 8)
+    block_table = torch.zeros(2, 4, dtype=torch.int32)
+    cache_seqlens = torch.ones(2, dtype=torch.int32)
+    metadata = torch.zeros(1, 8, dtype=torch.int32)
+    num_splits = torch.zeros(3, dtype=torch.int32)
+    return q, k_cache, block_table, cache_seqlens, metadata, num_splits
+
+
+def test_flash_mla_rejects_mismatched_batch_size():
+    q, k_cache, block_table, cache_seqlens, metadata, num_splits = _valid_inputs()
+    cache_seqlens = torch.ones(3, dtype=torch.int32)
+
+    with pytest.raises(ValueError, match="batch size"):
+        flash_mla_with_kvcache(
+            q, k_cache, block_table, cache_seqlens, 4, metadata, num_splits
+        )
+
+
+def test_flash_mla_rejects_non_int32_cache_lengths():
+    q, k_cache, block_table, cache_seqlens, metadata, num_splits = _valid_inputs()
+    cache_seqlens = cache_seqlens.to(torch.int64)
+
+    with pytest.raises(TypeError, match="cache_seqlens"):
+        flash_mla_with_kvcache(
+            q, k_cache, block_table, cache_seqlens, 4, metadata, num_splits
+        )
+
+
+def test_flash_mla_rejects_invalid_value_head_dim():
+    q, k_cache, block_table, cache_seqlens, metadata, num_splits = _valid_inputs()
+
+    with pytest.raises(ValueError, match="head_dim_v"):
+        flash_mla_with_kvcache(
+            q, k_cache, block_table, cache_seqlens, 16, metadata, num_splits
+        )