Enable 128k context for Gemma4-31B CUDA #20316
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20316
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 1 Cancelled Job, 3 Unrelated Failures, 4 Unclassified FailuresAs of commit 3f42650 with merge base 48ff29e ( NEW FAILURE - The following job has failed:
UNCLASSIFIED FAILURES - DrCI could not classify the following jobs because the workflow did not run on the merge base. The failures may be pre-existing on trunk or introduced by this PR:
CANCELLED JOB - The following job was cancelled. Please retry:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
f6bdb2c to
9390209
Compare
This PR needs a
|
TurboQuantKVCache.update wrote the compressed cache via slice index-assignment (self.k_packed[:, :, input_pos] = ...), which lowers to index_put_ and breaks CUDA-graph capture of the decode step. Switch the 4 cache writes to index_copy_(2, input_pos, ...): a static scatter along the position dim, matching the model's flat global KV cache (Gemma4KVCache.update). Nibble packing is on the last dim (D/2); index_copy_ along dim 2 (positions) leaves the packed layout untouched. MLX overrides update(); qwen shares this path (test_turboquant 5/5).
Builds on the kv_len decode clamp. After that clamp, prefill still scanned ALL KV blocks in [0, chunk_end] for every query row, paying the full causal upper-triangle. Add a per-tile causal upper bound to tq4_sdpa's loop, gated behind a new mask_is_causal op arg (threaded through register_fake, both autotune wrappers, launcher, and body): loop_end = min(kv_len, (kv_len - Lq) + max(seq_pos) + 1). seq_pos is the query-row index; the (kv_len - Lq) offset converts it to an absolute KV position, so it is correct for chunked prefill. gemma's full-attention (global) layers pass mask_is_causal=True; sliding layers are bf16 (untouched) and qwen does not pass the flag (defaults False). Decode-safe: at Lq=1 loop_end == kv_len, so decode is byte-identical. Skipped blocks are fully masked, so prefill is bit-identical to no-skip, just faster.
…ontext) decode/prefill) The fused TQ4 attention kernel iterated the full pre-allocated KV buffer (max_seq_len) every decode/prefill step, masking empty/future blocks only after computing them, making the 10 global layers O(max_seq_len) regardless of actual context. Add an optional kv_len GPU int32 scalar to triton::tq4_sdpa that bounds the inner loop to the valid/filled length. gemma's call site passes kv_len = input_pos[0] + T (decode: pos+1; prefill chunk: chunk_end). kv_len is read on-device via tl.load (no .item(), no host sync) so the bound updates across CUDA-graph replays. When kv_len is None, HAS_KV_LEN is False and the loop runs over full Lk, so the shared qwen path is byte-identical.
…V (cover the 128k decode/prefill path)
…only') TQ4 KV cache long-context is now supported on the CUDA backend, not just MLX. Update the README section: retitle to CUDA + MLX, add a CUDA 128k example, and state that long context requires BOTH --turboquant and a larger --max-seq-len (raising --max-seq-len alone keeps a bf16 KV cache that does not fit at 128k). Note the ~27 GiB runtime peak fits a 32 GB card.
9390209 to
42582d5
Compare
|
Nice job! Shall we add test this in our CI? |
…t) decode/prefill) The kv_len argument to torch.ops.triton.tq4_sdpa was dropped from cuda_source_transformations.py during a rebase, leaving HAS_KV_LEN=False so the TQ4 attention kernel swept the full pre-allocated max_seq_len every step instead of the valid context. At 128k this made decode ~2.7 tok/s (vs ~37) and prefill ~282 (vs ~2230). Restore kv_len = input_pos[0] + input_pos.shape[0] so the kernel's O(context) bound is active again. Verified (128k + turboquant, CUDA graph): prefill 2231 (was 282), decode 46.1 @~128ctx / 36.7 @~2048ctx (was flat 2.67) — matches the pre-rebase reference.
Enable 128k context for Gemma4-31B (CUDA, TurboQuant TQ4 KV)
What & why
This PR enables 128k context end-to-end (export + C++ CUDA runtime) by using TurboQuant TQ4 (4-bit) format and fixing the fused TQ4 attention kernel so decode/prefill scale with the actual context length and are CUDA-graph capturable. The 50 sliding-window layers are unchanged (2,048-entry ring cache).
With these fixes, enabling long context is just
--max-seq-len 131072 --turboquant.Changes
TurboQuantKVCache.update→index_copy_: write the compressed cache via a staticscatter so the decode step is CUDA-graph-capturable (the previous slice-assignment
lowered to
index_put_, which breaks graph capture).tq4_sdpakv_len clamp: an optional on-device int32kv_lenscalar bounds thekernel's KV loop to the filled context instead of the full pre-allocated 131,072 buffer
→ decode/prefill become O(context) instead of O(max_seq_len). It is read on-device
(no host sync) so it stays correct across CUDA-graph replays.
kv_len=Nonefalls backto the original full-loop behavior (the shared Qwen path is byte-identical).
tq4_sdpaprefill causal block-skip: skip fully-masked causal upper-triangle KVblocks during chunked prefill.
test_tq4_sdpa.pyto cover the actual 128k paths — long-KV kv_lenclamp (with a large-magnitude "garbage tail" beyond
kv_lenas a built-in negativecontrol) and bottom-right
mask_is_causalchunked-prefill, at Gemma-global (D=512) andQwen (D=256) shapes, plus a gated 131,072 case. Runs in the existing
unittest-cudaCI.Results
Model: Gemma4-31B (GGUF int6 weights) + TQ4 KV @128k. Hardware: A100 80GB. C++ CUDA
runner,
--cuda_graph, temperature 0.card, and is context-independent (KV buffers are pre-allocated at load).
TQ4-KV vs bf16-KV is ~0.9997 and holds to the full 128k.
Throughput (decode is the instantaneous, windowed rate measured at each KV depth):
Known limitation / follow-up
Decode throughput degrades with context depth: the global-layer TQ4 attention is
O(context) and currently launches few CTAs (no split-K), so deep-context decode is
under-parallelized. Adding split-K / flash-decoding to
tq4_sdpais the naturalfollow-up to speed up decode at long context.
Exportation memory consumption is too large to fit in consumer-based GPU like 5090, make it impossible to export on user server. Should optimize the gguf exportation path for better support.