Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ python setup.py install
python tests/test_flash_mla.py
```

### Smoke test

After building the extension, run a small correctness case before launching the
full benchmark suite:

```bash
python tools/run_flash_mla_smoke.py
```

The command prints the detected torch version and MACA device name, then runs a
single bf16 FlashMLA case against the PyTorch reference implementation. Use
`--dtype fp16` or the shape flags in `--help` to cover additional cases.

### Usage

```python
Expand Down
77 changes: 77 additions & 0 deletions tools/run_flash_mla_smoke.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#!/usr/bin/env python3
import argparse
import random
from pathlib import Path
import sys

import torch


# This script lives one directory below the repository root (tools/), so the
# root is two levels up from the resolved file path. Putting it on sys.path
# lets the `tests` package be imported when the script is run directly.
REPO_ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))

from tests.test_flash_mla import test_flash_mla # noqa: E402


def _dtype(value: str) -> torch.dtype:
    """Convert a ``--dtype`` command-line token into a torch dtype.

    Args:
        value: The raw token from the command line; matching is exact and
            case-sensitive.

    Returns:
        ``torch.bfloat16`` for ``"bf16"`` and ``torch.float16`` for ``"fp16"``.

    Raises:
        argparse.ArgumentTypeError: For any other token, so argparse can
            report it as a usage error.
    """
    accepted = (
        ("bf16", torch.bfloat16),
        ("fp16", torch.float16),
    )
    for token, resolved in accepted:
        if value == token:
            return resolved
    raise argparse.ArgumentTypeError("dtype must be bf16 or fp16")


def parse_args() -> argparse.Namespace:
    """Define the smoke-test command line and parse the process arguments.

    Returns:
        The parsed namespace with ``device``, ``dtype``, the integer shape
        options (``batch_size``, ``s_q``, ``mean_sk``, ``h_q``, ``h_kv``,
        ``d``, ``dv``, ``block_size``) and the ``varlen`` / ``non_causal``
        boolean switches.
    """
    cli = argparse.ArgumentParser(
        description="Run a small FlashMLA correctness smoke test on one MACA device."
    )
    cli.add_argument("--device", default="cuda:0", help="Torch device to run on.")
    cli.add_argument("--dtype", type=_dtype, default=torch.bfloat16, help="bf16 or fp16.")

    # Integer shape options and their defaults, registered in the order they
    # should appear in --help.
    int_options = (
        ("--batch-size", 128),
        ("--s-q", 1),
        ("--mean-sk", 4096),
        ("--h-q", 16),
        ("--h-kv", 1),
        ("--d", 576),
        ("--dv", 512),
        ("--block-size", 16),
    )
    for flag, fallback in int_options:
        cli.add_argument(flag, type=int, default=fallback)

    # Boolean switches; both are off unless given on the command line.
    for switch in ("--varlen", "--non-causal"):
        cli.add_argument(switch, action="store_true")

    return cli.parse_args()


def main() -> int:
    """Run one FlashMLA correctness case on the device chosen via ``--device``.

    The requested device is validated before any CUDA state is touched, so a
    misconfigured environment (driver not loaded, torch built without CUDA
    support, wrong device index) fails with a clear message rather than an
    opaque error from the runtime.

    Returns:
        0 once the smoke case has completed.

    Raises:
        ValueError: If ``--device`` is not a ``cuda`` device, or names a
            device index that is not present on this machine.
        RuntimeError: If torch reports that no CUDA-compatible device is
            available.
    """
    args = parse_args()
    device = torch.device(args.device)
    if device.type != "cuda":
        raise ValueError("FlashMLA smoke test requires a CUDA-compatible MACA device.")
    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA is not available. Please check your MACA driver and PyTorch installation."
        )
    # Only an explicit index can be range-checked; a bare "cuda" has index None.
    device_count = torch.cuda.device_count()
    if device.index is not None and device.index >= device_count:
        raise ValueError(
            f"Device index {device.index} is out of range. "
            f"Total available devices: {device_count}"
        )
Comment on lines +46 to +48

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

在初始化 CUDA 设备之前,建议先检查 torch.cuda.is_available(),并验证指定的设备索引是否在可用设备范围内(device.index < torch.cuda.device_count())。如果环境未正确配置(例如 MACA 驱动未加载或 PyTorch 未编译 CUDA 支持),直接调用 torch.cuda.set_device 会抛出难以理解的底层错误。增加这些防御性检查可以提供更友好的错误提示。

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Please check your MACA driver and PyTorch installation.")
    device = torch.device(args.device)
    if device.type != "cuda":
        raise ValueError("FlashMLA smoke test requires a CUDA-compatible MACA device.")
    if device.index is not None and device.index >= torch.cuda.device_count():
        raise ValueError(f"Device index {device.index} is out of range. Total available devices: {torch.cuda.device_count()}")


# Remainder of main(): configure torch, then run the single smoke case.
# New tensors default to the requested dtype and device from here on.
torch.set_default_dtype(args.dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
# Seed both torch's and Python's RNG with a fixed value
# (presumably so the generated inputs are repeatable between runs).
torch.manual_seed(0)
random.seed(0)

# Report the environment first so it is visible even if the case fails.
print(f"torch={torch.__version__}")
print(f"device={torch.cuda.get_device_name(device)}")
print(f"dtype={args.dtype}")

# Delegate to the test-suite entry point imported from tests.test_flash_mla.
# NOTE(review): presumably it raises on a mismatch against the reference
# implementation -- confirm in tests/test_flash_mla.py.
test_flash_mla(
b=args.batch_size,
s_q=args.s_q,
mean_sk=args.mean_sk,
h_q=args.h_q,
h_kv=args.h_kv,
d=args.d,
dv=args.dv,
# Causal is the default; --non-causal turns it off.
causal=not args.non_causal,
varlen=args.varlen,
block_size=args.block_size,
)
# Reached only when the call above returned without raising.
print("flash_mla_smoke_ok")
return 0


if __name__ == "__main__":
    # Use main()'s integer result as the process exit status.
    sys.exit(main())