Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions flash_mla/flash_mla_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,106 @@
import flash_mla_cuda as flash_mla


def _check_dim(name: str, tensor: torch.Tensor, dim: int) -> None:
if tensor.dim() != dim:
raise ValueError(f"{name} must be a {dim}D tensor, got {tensor.dim()}D")


def _check_dtype(name: str, tensor: torch.Tensor, dtype: torch.dtype) -> None:
if tensor.dtype is not dtype:
raise TypeError(f"{name} must use dtype {dtype}, got {tensor.dtype}")


def _check_same_device(reference_name: str, reference: torch.Tensor, name: str, tensor: torch.Tensor) -> None:
if tensor.device != reference.device:
raise ValueError(
f"{name} must be on the same device as {reference_name}: "
f"got {tensor.device} and {reference.device}"
)


def _validate_metadata_inputs(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
) -> None:
_check_dim("cache_seqlens", cache_seqlens, 1)
_check_dtype("cache_seqlens", cache_seqlens, torch.int32)
if num_heads_per_head_k <= 0:
raise ValueError("num_heads_per_head_k must be positive")
if num_heads_k <= 0:
raise ValueError("num_heads_k must be positive")


def _validate_kvcache_inputs(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
) -> None:
_check_dim("q", q, 4)
_check_dim("k_cache", k_cache, 4)
_check_dim("block_table", block_table, 2)
_check_dim("cache_seqlens", cache_seqlens, 1)
_check_dim("tile_scheduler_metadata", tile_scheduler_metadata, 2)
_check_dim("num_splits", num_splits, 1)

_check_dtype("block_table", block_table, torch.int32)
_check_dtype("cache_seqlens", cache_seqlens, torch.int32)
_check_dtype("tile_scheduler_metadata", tile_scheduler_metadata, torch.int32)
_check_dtype("num_splits", num_splits, torch.int32)
Comment on lines +57 to +60

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

建议增加对 qk_cache 的数据类型(dtype)校验。由于 FlashMLA 核心算子要求输入的 Query 和 Key/Value 缓存具有相同的数据类型(通常为 bfloat16float16),并且必须是浮点类型,在 Python 侧提前校验可以避免底层 C++ 算子因类型不匹配或非浮点类型而导致未定义行为或崩溃。

    _check_dtype("block_table", block_table, torch.int32)
    _check_dtype("cache_seqlens", cache_seqlens, torch.int32)
    _check_dtype("tile_scheduler_metadata", tile_scheduler_metadata, torch.int32)
    _check_dtype("num_splits", num_splits, torch.int32)

    if not q.is_floating_point():
        raise TypeError(f"q must be a floating point tensor, got {q.dtype}")
    if q.dtype != k_cache.dtype:
        raise TypeError(
            f"q and k_cache must have the same dtype, got {q.dtype} and {k_cache.dtype}"
        )


batch_size, _, num_heads_q, head_dim = q.shape
_, _, num_heads_k, cache_head_dim = k_cache.shape

if batch_size != block_table.shape[0]:
raise ValueError(
"block_table batch dimension must match q: "
f"got {block_table.shape[0]} and {batch_size}"
)
if batch_size != cache_seqlens.shape[0]:
raise ValueError(
"cache_seqlens length must match q batch dimension: "
f"got {cache_seqlens.shape[0]} and {batch_size}"
)
if num_splits.shape[0] != batch_size + 1:
raise ValueError(
"num_splits length must be batch_size + 1: "
f"got {num_splits.shape[0]} and {batch_size + 1}"
)
if num_heads_k <= 0:
raise ValueError("k_cache must contain at least one KV head")
if num_heads_q % num_heads_k != 0:
raise ValueError(
"q num_heads must be divisible by k_cache num_heads: "
f"got {num_heads_q} and {num_heads_k}"
)
if head_dim != cache_head_dim:
raise ValueError(
"q head_dim must match k_cache head_dim: "
f"got {head_dim} and {cache_head_dim}"
)
if head_dim_v <= 0:
raise ValueError("head_dim_v must be positive")
if head_dim_v > cache_head_dim:
raise ValueError(
"head_dim_v must not exceed k_cache head_dim: "
f"got {head_dim_v} and {cache_head_dim}"
)

for name, tensor in (
("k_cache", k_cache),
("block_table", block_table),
("cache_seqlens", cache_seqlens),
("tile_scheduler_metadata", tile_scheduler_metadata),
("num_splits", num_splits),
):
_check_same_device("q", q, name, tensor)


def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
Expand All @@ -22,6 +122,7 @@ def get_mla_metadata(
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
_validate_metadata_inputs(cache_seqlens, num_heads_per_head_k, num_heads_k)
return flash_mla.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)


Expand Down Expand Up @@ -52,6 +153,15 @@ def flash_mla_with_kvcache(
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
_validate_kvcache_inputs(
q,
k_cache,
block_table,
cache_seqlens,
head_dim_v,
tile_scheduler_metadata,
num_splits,
)
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse = flash_mla.fwd_kvcache_mla(
Expand Down
120 changes: 120 additions & 0 deletions tests/test_python_input_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import importlib
import sys
import types
import unittest

import torch


class FakeFlashMla(types.SimpleNamespace):
def __init__(self):
super().__init__(
get_mla_metadata=self.get_mla_metadata,
fwd_kvcache_mla=self.fwd_kvcache_mla,
)
self.metadata_calls = 0
self.kvcache_calls = 0

def get_mla_metadata(self, cache_seqlens, num_heads_per_head_k, num_heads_k):
self.metadata_calls += 1
metadata = torch.empty((1, 16), dtype=torch.int32, device=cache_seqlens.device)
num_splits = torch.empty((cache_seqlens.shape[0] + 1,), dtype=torch.int32, device=cache_seqlens.device)
return metadata, num_splits

def fwd_kvcache_mla(
self,
q,
k_cache,
_v_cache,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
):
self.kvcache_calls += 1
out = torch.empty((*q.shape[:-1], head_dim_v), dtype=q.dtype, device=q.device)
lse = torch.empty((q.shape[0], q.shape[2], q.shape[1]), dtype=torch.float32, device=q.device)
return out, lse


class PythonInputValidationTest(unittest.TestCase):
def setUp(self):
self.fake_extension = FakeFlashMla()
sys.modules["flash_mla_cuda"] = self.fake_extension
sys.modules.pop("flash_mla", None)
sys.modules.pop("flash_mla.flash_mla_interface", None)
self.interface = importlib.import_module("flash_mla.flash_mla_interface")

def tearDown(self):
sys.modules.pop("flash_mla", None)
sys.modules.pop("flash_mla.flash_mla_interface", None)
sys.modules.pop("flash_mla_cuda", None)
Comment on lines +44 to +54

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

setUptearDown 中直接 popsys.modules 中的模块可能会破坏测试套件中其他测试的隔离性。如果真实的 flash_mla_cudaflash_mla 模块在运行此测试之前已经被加载,直接 pop 会导致后续需要真实模块的测试因找不到模块而失败。建议在 setUp 中备份这些模块,并在 tearDown 中进行恢复。

Suggested change
def setUp(self):
self.fake_extension = FakeFlashMla()
sys.modules["flash_mla_cuda"] = self.fake_extension
sys.modules.pop("flash_mla", None)
sys.modules.pop("flash_mla.flash_mla_interface", None)
self.interface = importlib.import_module("flash_mla.flash_mla_interface")
def tearDown(self):
sys.modules.pop("flash_mla", None)
sys.modules.pop("flash_mla.flash_mla_interface", None)
sys.modules.pop("flash_mla_cuda", None)
def setUp(self):
self.fake_extension = FakeFlashMla()
self._saved_modules = {}
for mod in ["flash_mla_cuda", "flash_mla", "flash_mla.flash_mla_interface"]:
if mod in sys.modules:
self._saved_modules[mod] = sys.modules[mod]
del sys.modules[mod]
sys.modules["flash_mla_cuda"] = self.fake_extension
self.interface = importlib.import_module("flash_mla.flash_mla_interface")
def tearDown(self):
for mod in ["flash_mla", "flash_mla.flash_mla_interface", "flash_mla_cuda"]:
sys.modules.pop(mod, None)
for mod, val in self._saved_modules.items():
sys.modules[mod] = val


def _valid_kvcache_inputs(self):
batch_size = 2
q = torch.randn(batch_size, 1, 4, 8)
k_cache = torch.randn(3, 16, 2, 8)
block_table = torch.zeros((batch_size, 1), dtype=torch.int32)
cache_seqlens = torch.full((batch_size,), 8, dtype=torch.int32)
tile_scheduler_metadata = torch.zeros((1, 16), dtype=torch.int32)
num_splits = torch.zeros((batch_size + 1,), dtype=torch.int32)
return q, k_cache, block_table, cache_seqlens, 4, tile_scheduler_metadata, num_splits

def test_metadata_rejects_wrong_cache_seqlens_dtype_before_extension(self):
cache_seqlens = torch.ones((2,), dtype=torch.int64)

with self.assertRaisesRegex(TypeError, "cache_seqlens"):
self.interface.get_mla_metadata(cache_seqlens, 4, 2)

self.assertEqual(self.fake_extension.metadata_calls, 0)

def test_metadata_rejects_non_positive_head_counts_before_extension(self):
cache_seqlens = torch.ones((2,), dtype=torch.int32)

with self.assertRaisesRegex(ValueError, "num_heads_per_head_k"):
self.interface.get_mla_metadata(cache_seqlens, 0, 2)

with self.assertRaisesRegex(ValueError, "num_heads_k"):
self.interface.get_mla_metadata(cache_seqlens, 4, 0)

self.assertEqual(self.fake_extension.metadata_calls, 0)

def test_kvcache_rejects_incompatible_heads_before_extension(self):
q, k_cache, block_table, cache_seqlens, head_dim_v, metadata, num_splits = self._valid_kvcache_inputs()
q = torch.randn(q.shape[0], q.shape[1], 3, q.shape[3])

with self.assertRaisesRegex(ValueError, "divisible"):
self.interface.flash_mla_with_kvcache(
q, k_cache, block_table, cache_seqlens, head_dim_v, metadata, num_splits
)

self.assertEqual(self.fake_extension.kvcache_calls, 0)

Comment on lines +94 to +95

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

在添加了 qk_cache 的 dtype 一致性校验后,建议在测试文件中补充对应的单元测试,以确保该校验逻辑正确生效且未来不会被意外破坏。

Suggested change
self.assertEqual(self.fake_extension.kvcache_calls, 0)
self.assertEqual(self.fake_extension.kvcache_calls, 0)
def test_kvcache_rejects_dtype_mismatch_before_extension(self):
q, k_cache, block_table, cache_seqlens, head_dim_v, metadata, num_splits = self._valid_kvcache_inputs()
k_cache = k_cache.to(torch.float16)
with self.assertRaisesRegex(TypeError, "same dtype"):
self.interface.flash_mla_with_kvcache(
q, k_cache, block_table, cache_seqlens, head_dim_v, metadata, num_splits
)
self.assertEqual(self.fake_extension.kvcache_calls, 0)

def test_kvcache_rejects_bad_num_splits_length_before_extension(self):
q, k_cache, block_table, cache_seqlens, head_dim_v, metadata, _num_splits = self._valid_kvcache_inputs()
num_splits = torch.zeros((q.shape[0],), dtype=torch.int32)

with self.assertRaisesRegex(ValueError, "batch_size \\+ 1"):
self.interface.flash_mla_with_kvcache(
q, k_cache, block_table, cache_seqlens, head_dim_v, metadata, num_splits
)

self.assertEqual(self.fake_extension.kvcache_calls, 0)

def test_kvcache_accepts_valid_inputs_and_uses_default_scale(self):
q, k_cache, block_table, cache_seqlens, head_dim_v, metadata, num_splits = self._valid_kvcache_inputs()

out, lse = self.interface.flash_mla_with_kvcache(
q, k_cache, block_table, cache_seqlens, head_dim_v, metadata, num_splits
)

self.assertEqual(out.shape, (2, 1, 4, head_dim_v))
self.assertEqual(lse.shape, (2, 4, 1))
self.assertEqual(self.fake_extension.kvcache_calls, 1)


if __name__ == "__main__":
unittest.main()