-
Notifications
You must be signed in to change notification settings - Fork 3
校验 Python API 输入参数 #22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -7,6 +7,50 @@ | |||||||||||||||||||
| import flash_mla_cuda as flash_mla | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def _check_int32_tensor(name: str, tensor: torch.Tensor) -> None: | ||||||||||||||||||||
| if tensor.dtype != torch.int32: | ||||||||||||||||||||
| raise TypeError(f"{name} must use torch.int32, got {tensor.dtype}") | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def _validate_flash_mla_inputs( | ||||||||||||||||||||
| q: torch.Tensor, | ||||||||||||||||||||
| k_cache: torch.Tensor, | ||||||||||||||||||||
| block_table: torch.Tensor, | ||||||||||||||||||||
| cache_seqlens: torch.Tensor, | ||||||||||||||||||||
| head_dim_v: int, | ||||||||||||||||||||
| tile_scheduler_metadata: torch.Tensor, | ||||||||||||||||||||
| num_splits: torch.Tensor, | ||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||
| if q.dim() != 4: | ||||||||||||||||||||
| raise ValueError(f"q must be 4D, got shape {tuple(q.shape)}") | ||||||||||||||||||||
| if k_cache.dim() != 4: | ||||||||||||||||||||
| raise ValueError(f"k_cache must be 4D, got shape {tuple(k_cache.shape)}") | ||||||||||||||||||||
| if block_table.dim() != 2: | ||||||||||||||||||||
| raise ValueError(f"block_table must be 2D, got shape {tuple(block_table.shape)}") | ||||||||||||||||||||
| if cache_seqlens.dim() != 1: | ||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||
| f"cache_seqlens must be 1D, got shape {tuple(cache_seqlens.shape)}" | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| if num_splits.dim() != 1: | ||||||||||||||||||||
| raise ValueError(f"num_splits must be 1D, got shape {tuple(num_splits.shape)}") | ||||||||||||||||||||
| if q.shape[0] != block_table.shape[0] or q.shape[0] != cache_seqlens.shape[0]: | ||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||
| "batch size must match across q, block_table, and cache_seqlens" | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
Comment on lines
+36
to
+39
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 根据 建议在此处同时校验
Suggested change
|
||||||||||||||||||||
| if q.shape[-1] != k_cache.shape[-1]: | ||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||
| f"q head_dim ({q.shape[-1]}) must match k_cache head_dim ({k_cache.shape[-1]})" | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| if head_dim_v <= 0 or head_dim_v > k_cache.shape[-1]: | ||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||
| f"head_dim_v must be in (0, {k_cache.shape[-1]}], got {head_dim_v}" | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| _check_int32_tensor("block_table", block_table) | ||||||||||||||||||||
| _check_int32_tensor("cache_seqlens", cache_seqlens) | ||||||||||||||||||||
| _check_int32_tensor("tile_scheduler_metadata", tile_scheduler_metadata) | ||||||||||||||||||||
| _check_int32_tensor("num_splits", num_splits) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def get_mla_metadata( | ||||||||||||||||||||
| cache_seqlens: torch.Tensor, | ||||||||||||||||||||
| num_heads_per_head_k: int, | ||||||||||||||||||||
|
|
@@ -52,6 +96,15 @@ def flash_mla_with_kvcache( | |||||||||||||||||||
| out: (batch_size, seq_len_q, num_heads_q, head_dim_v). | ||||||||||||||||||||
| softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| _validate_flash_mla_inputs( | ||||||||||||||||||||
| q, | ||||||||||||||||||||
| k_cache, | ||||||||||||||||||||
| block_table, | ||||||||||||||||||||
| cache_seqlens, | ||||||||||||||||||||
| head_dim_v, | ||||||||||||||||||||
| tile_scheduler_metadata, | ||||||||||||||||||||
| num_splits, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| if softmax_scale is None: | ||||||||||||||||||||
| softmax_scale = q.shape[-1] ** (-0.5) | ||||||||||||||||||||
| out, softmax_lse = flash_mla.fwd_kvcache_mla( | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,57 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| import sys | ||||||||||||||||||||||||||||||||||||||||||||||||||
| import types | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| class _FakeFlashMla(types.SimpleNamespace): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_mla_metadata(self, *args, **kwargs): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def fwd_kvcache_mla(self, *args, **kwargs): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return torch.empty(1), torch.empty(1) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| sys.modules.setdefault("flash_mla_cuda", _FakeFlashMla()) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| from flash_mla.flash_mla_interface import flash_mla_with_kvcache # noqa: E402 | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def _valid_inputs(): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| q = torch.empty(2, 1, 4, 8) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| k_cache = torch.empty(8, 16, 1, 8) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| block_table = torch.zeros(2, 4, dtype=torch.int32) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| cache_seqlens = torch.ones(2, dtype=torch.int32) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| metadata = torch.zeros(1, 8, dtype=torch.int32) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| num_splits = torch.zeros(3, dtype=torch.int32) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return q, k_cache, block_table, cache_seqlens, metadata, num_splits | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_flash_mla_rejects_mismatched_batch_size(): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| q, k_cache, block_table, cache_seqlens, metadata, num_splits = _valid_inputs() | ||||||||||||||||||||||||||||||||||||||||||||||||||
| cache_seqlens = torch.ones(3, dtype=torch.int32) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| with pytest.raises(ValueError, match="batch size"): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| flash_mla_with_kvcache( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| q, k_cache, block_table, cache_seqlens, 4, metadata, num_splits | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_flash_mla_rejects_non_int32_cache_lengths(): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| q, k_cache, block_table, cache_seqlens, metadata, num_splits = _valid_inputs() | ||||||||||||||||||||||||||||||||||||||||||||||||||
| cache_seqlens = cache_seqlens.to(torch.int64) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| with pytest.raises(TypeError, match="cache_seqlens"): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| flash_mla_with_kvcache( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| q, k_cache, block_table, cache_seqlens, 4, metadata, num_splits | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_flash_mla_rejects_invalid_value_head_dim(): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| q, k_cache, block_table, cache_seqlens, metadata, num_splits = _valid_inputs() | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| with pytest.raises(ValueError, match="head_dim_v"): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| flash_mla_with_kvcache( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| q, k_cache, block_table, cache_seqlens, 16, metadata, num_splits | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+51
to
+57
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 为了确保新增的
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR 描述中提到“对非法 dtype、shape 或设备输入给出明确错误”,但目前的实现中并没有对设备(Device)一致性进行校验。如果用户传入了位于不同设备(例如 CPU 和 CUDA,或者不同的 GPU 卡)上的 Tensor,可能会导致底层算子运行时报错或产生未定义行为。此外,如果传入的参数不是
torch.Tensor类型(例如普通的 list 或 numpy 数组),直接访问.dim()或.dtype会抛出混淆的AttributeError。建议在函数开头增加对所有输入 Tensor 的类型和设备一致性校验。