Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions flash_mla/flash_mla_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,50 @@
import flash_mla_cuda as flash_mla


def _check_int32_tensor(name: str, tensor: torch.Tensor) -> None:
if tensor.dtype != torch.int32:
raise TypeError(f"{name} must use torch.int32, got {tensor.dtype}")


def _validate_flash_mla_inputs(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
) -> None:
if q.dim() != 4:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

PR 描述中提到“对非法 dtype、shape 或设备输入给出明确错误”,但目前的实现中并没有对设备(Device)一致性进行校验。如果用户传入了位于不同设备(例如 CPU 和 CUDA,或者不同的 GPU 卡)上的 Tensor,可能会导致底层算子运行时报错或产生未定义行为。此外,如果传入的参数不是 `torch.Tensor` 类型(例如普通的 list 或 numpy 数组),直接访问 `.dim()` 或 `.dtype` 会抛出令人困惑的 `AttributeError`。

建议在函数开头增加对所有输入 Tensor 的类型和设备一致性校验。

Suggested change
if q.dim() != 4:
for name, t in [("q", q), ("k_cache", k_cache), ("block_table", block_table), ("cache_seqlens", cache_seqlens), ("tile_scheduler_metadata", tile_scheduler_metadata), ("num_splits", num_splits)]:
if not isinstance(t, torch.Tensor):
raise TypeError(f"{name} must be a torch.Tensor, got {type(t)}")
if t.device != q.device:
raise ValueError(f"All tensors must be on the same device, but {name} is on {t.device} while q is on {q.device}")
if q.dim() != 4:

raise ValueError(f"q must be 4D, got shape {tuple(q.shape)}")
if k_cache.dim() != 4:
raise ValueError(f"k_cache must be 4D, got shape {tuple(k_cache.shape)}")
if block_table.dim() != 2:
raise ValueError(f"block_table must be 2D, got shape {tuple(block_table.shape)}")
if cache_seqlens.dim() != 1:
raise ValueError(
f"cache_seqlens must be 1D, got shape {tuple(cache_seqlens.shape)}"
)
if num_splits.dim() != 1:
raise ValueError(f"num_splits must be 1D, got shape {tuple(num_splits.shape)}")
if q.shape[0] != block_table.shape[0] or q.shape[0] != cache_seqlens.shape[0]:
raise ValueError(
"batch size must match across q, block_table, and cache_seqlens"
)
Comment on lines +36 to +39

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

根据 flash_mla_with_kvcache 的文档,num_splits 的形状应为 (batch_size + 1,)。如果用户传入的 num_splits 长度不匹配(例如小于 batch_size + 1),底层 CUDA 算子在执行时可能会发生越界内存访问(Out-of-bounds access),导致 CUDA 报错或非法内存访问。

建议在此处同时校验 `num_splits` 的长度是否等于 `batch_size + 1`。

Suggested change
if q.shape[0] != block_table.shape[0] or q.shape[0] != cache_seqlens.shape[0]:
raise ValueError(
"batch size must match across q, block_table, and cache_seqlens"
)
if q.shape[0] != block_table.shape[0] or q.shape[0] != cache_seqlens.shape[0] or num_splits.shape[0] != q.shape[0] + 1:
raise ValueError(
f"batch size mismatch: q batch_size is {q.shape[0]}, but block_table has {block_table.shape[0]}, "
f"cache_seqlens has {cache_seqlens.shape[0]}, and num_splits must have size {q.shape[0] + 1}, got {num_splits.shape[0]}"
)

if q.shape[-1] != k_cache.shape[-1]:
raise ValueError(
f"q head_dim ({q.shape[-1]}) must match k_cache head_dim ({k_cache.shape[-1]})"
)
if head_dim_v <= 0 or head_dim_v > k_cache.shape[-1]:
raise ValueError(
f"head_dim_v must be in (0, {k_cache.shape[-1]}], got {head_dim_v}"
)
_check_int32_tensor("block_table", block_table)
_check_int32_tensor("cache_seqlens", cache_seqlens)
_check_int32_tensor("tile_scheduler_metadata", tile_scheduler_metadata)
_check_int32_tensor("num_splits", num_splits)


def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
Expand Down Expand Up @@ -52,6 +96,15 @@ def flash_mla_with_kvcache(
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
_validate_flash_mla_inputs(
q,
k_cache,
block_table,
cache_seqlens,
head_dim_v,
tile_scheduler_metadata,
num_splits,
)
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse = flash_mla.fwd_kvcache_mla(
Expand Down
57 changes: 57 additions & 0 deletions tests/test_flash_mla_interface_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import sys
import types

import pytest
import torch


class _FakeFlashMla(types.SimpleNamespace):
def get_mla_metadata(self, *args, **kwargs):
return None

def fwd_kvcache_mla(self, *args, **kwargs):
return torch.empty(1), torch.empty(1)


sys.modules.setdefault("flash_mla_cuda", _FakeFlashMla())

from flash_mla.flash_mla_interface import flash_mla_with_kvcache # noqa: E402


def _valid_inputs():
q = torch.empty(2, 1, 4, 8)
k_cache = torch.empty(8, 16, 1, 8)
block_table = torch.zeros(2, 4, dtype=torch.int32)
cache_seqlens = torch.ones(2, dtype=torch.int32)
metadata = torch.zeros(1, 8, dtype=torch.int32)
num_splits = torch.zeros(3, dtype=torch.int32)
return q, k_cache, block_table, cache_seqlens, metadata, num_splits


def test_flash_mla_rejects_mismatched_batch_size():
    """A ``cache_seqlens`` of a different length than ``q``'s batch is rejected."""
    q, k_cache, block_table, _, metadata, num_splits = _valid_inputs()
    # Three entries, while q and block_table from _valid_inputs() have two rows.
    oversized_seqlens = torch.ones(3, dtype=torch.int32)

    with pytest.raises(ValueError, match="batch size"):
        flash_mla_with_kvcache(
            q, k_cache, block_table, oversized_seqlens, 4, metadata, num_splits
        )


def test_flash_mla_rejects_non_int32_cache_lengths():
    """An int64 ``cache_seqlens`` is rejected with a ``TypeError`` naming it."""
    q, k_cache, block_table, seqlens_i32, metadata, num_splits = _valid_inputs()
    # Same values and shape as the baseline, only the dtype differs.
    seqlens_i64 = seqlens_i32.to(torch.int64)

    with pytest.raises(TypeError, match="cache_seqlens"):
        flash_mla_with_kvcache(
            q, k_cache, block_table, seqlens_i64, 4, metadata, num_splits
        )


def test_flash_mla_rejects_invalid_value_head_dim():
    """A ``head_dim_v`` larger than the cache's last dimension is rejected."""
    baseline = _valid_inputs()
    q, k_cache, block_table, cache_seqlens = baseline[:4]
    metadata, num_splits = baseline[4:]
    # k_cache from _valid_inputs() has a last dimension of 8, so 16 is too large.
    too_large_head_dim_v = 16

    with pytest.raises(ValueError, match="head_dim_v"):
        flash_mla_with_kvcache(
            q,
            k_cache,
            block_table,
            cache_seqlens,
            too_large_head_dim_v,
            metadata,
            num_splits,
        )
Comment on lines +51 to +57

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

为了确保新增的 num_splits 长度校验逻辑正确工作,建议在测试文件中增加对应的单元测试。

Suggested change
def test_flash_mla_rejects_invalid_value_head_dim():
q, k_cache, block_table, cache_seqlens, metadata, num_splits = _valid_inputs()
with pytest.raises(ValueError, match="head_dim_v"):
flash_mla_with_kvcache(
q, k_cache, block_table, cache_seqlens, 16, metadata, num_splits
)
def test_flash_mla_rejects_invalid_value_head_dim():
q, k_cache, block_table, cache_seqlens, metadata, num_splits = _valid_inputs()
with pytest.raises(ValueError, match="head_dim_v"):
flash_mla_with_kvcache(
q, k_cache, block_table, cache_seqlens, 16, metadata, num_splits
)
def test_flash_mla_rejects_mismatched_num_splits():
q, k_cache, block_table, cache_seqlens, metadata, num_splits = _valid_inputs()
num_splits = torch.zeros(4, dtype=torch.int32)
with pytest.raises(ValueError, match="batch size mismatch"):
flash_mla_with_kvcache(
q, k_cache, block_table, cache_seqlens, 4, metadata, num_splits
)