perf(deepseek-v4): vectorize read_deepseek_v4_indexer_fp8_cache #238
yuanqingz wants to merge 2 commits into
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: efa0d11208
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
    slots = slot_mapping.to(torch.int64)
    valid = slots >= 0
    safe_slots = torch.where(valid, slots, torch.zeros_like(slots))
    pages = torch.div(safe_slots, block_size, rounding_mode="floor")
Move `slot_mapping` to cache device before offset math
This change keeps `slots` on whatever device `slot_mapping` already uses (`slot_mapping.to(torch.int64)`), but later combines `value_base`/`scale_base` with `torch.arange(..., device=cache_2d.device)`. When `cache_2d` is on CUDA and `slot_mapping` is on CPU, this now raises a cross-device tensor error during offset construction, whereas the previous `.tolist()` loop accepted CPU mappings. Converting `slot_mapping` to `cache_2d.device` before computing `pages`/`pos` would preserve prior behavior.
Good catch, fixed in 6bdc106. The original `.tolist()` loop was implicitly device-agnostic: it turned every slot into a Python int, and those ints were then used to slice the cache regardless of which device `slot_mapping` lived on. The vectorized version composes offsets with `torch.arange(device=cache_2d.device)`, so I now explicitly call `slot_mapping.to(device=cache_2d.device, dtype=torch.int64)` to preserve the prior contract.
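The fix discussed in this thread can be sketched as follows. This is a minimal illustration, not the code from the PR: the helper name `offsets_on_cache_device` and the byte layout (token `pos` starting at `pos * head_dim` within its page) are assumptions made for the example. Only the `slot_mapping.to(device=cache_2d.device, dtype=torch.int64)` call and the `valid`/`safe_slots`/`pages` lines come from the thread.

```python
import torch


def offsets_on_cache_device(slot_mapping, cache_2d, block_size, head_dim):
    """Hypothetical helper: build per-token value offsets on the cache's device."""
    device = cache_2d.device
    # Move and widen in one call so later arithmetic never mixes devices.
    slots = slot_mapping.to(device=device, dtype=torch.int64)
    valid = slots >= 0
    # Clamp invalid slots (< 0) to slot 0 so the index math stays in range.
    safe_slots = torch.where(valid, slots, torch.zeros_like(slots))
    pages = torch.div(safe_slots, block_size, rounding_mode="floor")
    pos = safe_slots - pages * block_size
    # Assumed layout: token `pos` starts at byte `pos * head_dim` in its page.
    value_base = pos * head_dim
    offsets = value_base[:, None] + torch.arange(head_dim, device=device)
    return pages, offsets, valid


# Small CPU-only demo; on a GPU machine cache_2d could live on CUDA while
# slot_mapping stays on CPU, which is the case the review comment describes.
cache_2d = torch.zeros(3, 48, dtype=torch.uint8)
slot_mapping = torch.tensor([0, 5, -1], dtype=torch.int32)
pages, offsets, valid = offsets_on_cache_device(
    slot_mapping, cache_2d, block_size=4, head_dim=8
)
```

Because every tensor created after the first line inherits `cache_2d.device`, the addition with `torch.arange(...)` cannot hit a cross-device error, whichever device the caller's `slot_mapping` started on.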
The original implementation iterated over `slot_mapping.tolist()` in
Python and performed GPU slicing + dtype-view + multiply per token. For
a 16-req x 1024-token prefill batch (~14338 tokens) across ~30 sparse
attention layers this is ~430K Python iterations per forward pass, each
with several GPU ops. The CPU sync from `.tolist()` also blocks any
hope of CUDA graph capture for the indexer path.
Replace with a batched torch-op implementation following the same
pattern already used by `read_deepseek_v4_indexer_mxfp4_cache` (same
file): one `gather` per dimension, dequantize on device. Output is
bit-identical to the reference loop for valid slots, zero for invalid
slots (slot < 0).
Measured impact on DeepSeek-V4-Flash with H20-3e TP=4, FP8 KV cache,
random ISL=1024 OSL=4 c=16:
TTFT (ms): 823,467 -> 18,197 (45x)
TPOT (ms): 2,067 -> 298 (7x)
16/16 bench duration: 1350s -> 19s (70x)
The vectorized implementation is also CUDA-graph-safe (no Python
branches, no `.tolist()` CPU sync), unblocking `--enforce-eager`
removal for V4-Flash's sparse indexer path.
Existing test `test_csa_indexer_cache_insert_fp8_path` continues to
pass; numerical equivalence with the original reference loop was
verified against the DeepSeek-V4-Flash bring-up smoke ('The capital of
France is Paris.') and a successful 16/16 random-prompt bench run.
Signed-off-by: Yuanqing Zhao <yuanqingz@nvidia.com>
efa0d11 to 6bdc106
@codex review
Codex Review: Didn't find any major issues. Can't wait for the next one!
Summary
`read_deepseek_v4_indexer_fp8_cache` currently iterates over `slot_mapping.tolist()` in Python and performs GPU slicing + dtype-view + multiply per token. `.tolist()` is also a CPU sync that blocks CUDA graph capture of the entire indexer path.
Replace with a batched torch-op implementation following the exact pattern already used by `read_deepseek_v4_indexer_mxfp4_cache` in the same file (lines 773–833): one `gather` per dimension, dequantize on device.
Measured impact
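DeepSeek-V4-Flash, H20-3e TP=4, FP8 KV cache, random ISL=1024 OSL=4 c=16:
TTFT (ms): 823,467 -> 18,197 (45x)
TPOT (ms): 2,067 -> 298 (7x)
16/16 bench duration: 1350s -> 19s (70x)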
For a 16-req × 1024-token prefill batch (~14338 tokens) × ~30 sparse
attention layers, the original loop ran ~430K Python iterations per
forward pass.
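The iteration count and the speedup factors quoted for this PR can be checked with a few lines of arithmetic. The inputs are the figures from the commit message; the layer count of 30 is the approximate value given there, so the product is an estimate rather than an exact count.

```python
# Figures quoted in the commit message; "~30" layers is approximate.
tokens_per_forward = 14338
sparse_layers = 30
iterations = tokens_per_forward * sparse_layers  # Python-loop iterations per forward pass

# (before, after) pairs for each reported metric.
measurements = {
    "TTFT (ms)": (823_467, 18_197),
    "TPOT (ms)": (2_067, 298),
    "bench duration (s)": (1350, 19),
}
speedups = {name: before / after for name, (before, after) in measurements.items()}

print(f"iterations per forward pass: {iterations:,}")
for name, factor in speedups.items():
    print(f"{name}: {factor:.1f}x")
```

The product is 430,140, which matches the "~430K" figure. The ratios come out at roughly 45x, 7x, and 71x, in line with the rounded factors reported.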
Correctness
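- Output is bit-identical to the reference loop for valid slots and zero for invalid slots (`slot < 0`).
- Numerical equivalence was verified against the DeepSeek-V4-Flash bring-up smoke ("The capital of France is Paris.") and a successful 16/16 random-prompt bench run.
- The existing `test_csa_indexer_cache_insert_fp8_path` test (`test/runtime/test_deepseek_v4_attention_ops.py`) continues to pass.
Side benefit
The vectorized implementation is CUDA-graph-safe (no Python branches, no CPU sync), which unblocks `--enforce-eager` removal for V4-Flash's sparse indexer path in follow-up work.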