Skip to content

Enable 128k context for Gemma4-31B CUDA #20316

Merged
Gasoonjia merged 7 commits into
mainfrom
g4-128k-context
Jun 18, 2026
Merged

Enable 128k context for Gemma4-31B CUDA #20316
Gasoonjia merged 7 commits into
mainfrom
g4-128k-context

Conversation

@Gasoonjia

@Gasoonjia Gasoonjia commented Jun 17, 2026

Copy link
Copy Markdown
Contributor

Enable 128k context for Gemma4-31B (CUDA, TurboQuant TQ4 KV)

What & why

This PR enables 128k context end-to-end (export + C++ CUDA runtime) by using TurboQuant TQ4 (4-bit) format and fixing the fused TQ4 attention kernel so decode/prefill scale with the actual context length and are CUDA-graph capturable. The 50 sliding-window layers are unchanged (2,048-entry ring cache).

With these fixes, enabling long context is just --max-seq-len 131072 --turboquant.

Changes

  • TurboQuantKVCache.updateindex_copy_: write the compressed cache via a static
    scatter so the decode step is CUDA-graph-capturable (the previous slice-assignment
    lowered to index_put_, which breaks graph capture).
  • tq4_sdpa kv_len clamp: an optional on-device int32 kv_len scalar bounds the
    kernel's KV loop to the filled context instead of the full pre-allocated 131,072 buffer
    → decode/prefill become O(context) instead of O(max_seq_len). It is read on-device
    (no host sync) so it stays correct across CUDA-graph replays. kv_len=None falls back
    to the original full-loop behavior (the shared Qwen path is byte-identical).
  • tq4_sdpa prefill causal block-skip: skip fully-masked causal upper-triangle KV
    blocks during chunked prefill.
  • Tests: extend test_tq4_sdpa.py to cover the actual 128k paths — long-KV kv_len
    clamp (with a large-magnitude "garbage tail" beyond kv_len as a built-in negative
    control) and bottom-right mask_is_causal chunked-prefill, at Gemma-global (D=512) and
    Qwen (D=256) shapes, plus a gated 131,072 case. Runs in the existing unittest-cuda CI.

Results

Model: Gemma4-31B (GGUF int6 weights) + TQ4 KV @128k. Hardware: A100 80GB. C++ CUDA
runner, --cuda_graph, temperature 0.

  • Works e2e: 128k export + C++ CUDA runtime produce coherent output.
  • Memory: runtime peak ~27.0 GiB (export peak 32.05 GiB) — runtime fits a 32 GB
    card, and is context-independent (KV buffers are pre-allocated at load).
  • Quality: needle-in-haystack retrieval is exact; per-position logit cosine of
    TQ4-KV vs bf16-KV is ~0.9997 and holds to the full 128k.

Throughput (decode is the instantaneous, windowed rate measured at each KV depth):

Phase Context length Throughput (tok/s)
Prefill 2,048 2,247
Decode 128 45.6
Decode 512 43.4
Decode 2,048 36.5
Decode 8,192 22.3
Decode 32,768 8.7
Decode 131,072 N/A (too long)

Known limitation / follow-up

  1. Decode throughput degrades with context depth: the global-layer TQ4 attention is
    O(context) and currently launches few CTAs (no split-K), so deep-context decode is
    under-parallelized. Adding split-K / flash-decoding to tq4_sdpa is the natural
    follow-up to speed up decode at long context.

  2. Exportation memory consumption is too large to fit in consumer-based GPU like 5090, make it impossible to export on user server. Should optimize the gguf exportation path for better support.

@pytorch-bot

pytorch-bot Bot commented Jun 17, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20316

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Cancelled Job, 3 Unrelated Failures, 4 Unclassified Failures

As of commit 3f42650 with merge base 48ff29e (image):

NEW FAILURE - The following job has failed:

UNCLASSIFIED FAILURES - DrCI could not classify the following jobs because the workflow did not run on the merge base. The failures may be pre-existing on trunk or introduced by this PR:

CANCELLED JOB - The following job was cancelled. Please retry:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 17, 2026
@Gasoonjia Gasoonjia changed the title G4 128k context Enable 128k context for Gemma4-31B CUDA Jun 17, 2026
Base automatically changed from g4-int6-gguf to main June 17, 2026 05:39
@Gasoonjia Gasoonjia requested a review from kirklandsign as a code owner June 17, 2026 05:39
@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

TurboQuantKVCache.update wrote the compressed cache via slice index-assignment (self.k_packed[:, :, input_pos] = ...), which lowers to index_put_ and breaks CUDA-graph capture of the decode step. Switch the 4 cache writes to index_copy_(2, input_pos, ...): a static scatter along the position dim, matching the model's flat global KV cache (Gemma4KVCache.update). Nibble packing is on the last dim (D/2); index_copy_ along dim 2 (positions) leaves the packed layout untouched. MLX overrides update(); qwen shares this path (test_turboquant 5/5).
Builds on the kv_len decode clamp. After that clamp, prefill still scanned ALL KV blocks in [0, chunk_end] for every query row, paying the full causal upper-triangle. Add a per-tile causal upper bound to tq4_sdpa's loop, gated behind a new mask_is_causal op arg (threaded through register_fake, both autotune wrappers, launcher, and body): loop_end = min(kv_len, (kv_len - Lq) + max(seq_pos) + 1). seq_pos is the query-row index; the (kv_len - Lq) offset converts it to an absolute KV position, so it is correct for chunked prefill. gemma's full-attention (global) layers pass mask_is_causal=True; sliding layers are bf16 (untouched) and qwen does not pass the flag (defaults False). Decode-safe: at Lq=1 loop_end == kv_len, so decode is byte-identical. Skipped blocks are fully masked, so prefill is bit-identical to no-skip, just faster.
…ontext) decode/prefill)

The fused TQ4 attention kernel iterated the full pre-allocated KV buffer (max_seq_len) every decode/prefill step, masking empty/future blocks only after computing them, making the 10 global layers O(max_seq_len) regardless of actual context. Add an optional kv_len GPU int32 scalar to triton::tq4_sdpa that bounds the inner loop to the valid/filled length. gemma's call site passes kv_len = input_pos[0] + T (decode: pos+1; prefill chunk: chunk_end). kv_len is read on-device via tl.load (no .item(), no host sync) so the bound updates across CUDA-graph replays. When kv_len is None, HAS_KV_LEN is False and the loop runs over full Lk, so the shared qwen path is byte-identical.
…only')

TQ4 KV cache long-context is now supported on the CUDA backend, not just MLX. Update the README section: retitle to CUDA + MLX, add a CUDA 128k example, and state that long context requires BOTH --turboquant and a larger --max-seq-len (raising --max-seq-len alone keeps a bf16 KV cache that does not fit at 128k). Note the ~27 GiB runtime peak fits a 32 GB card.
@mergennachin

Copy link
Copy Markdown
Contributor

Nice job! Shall we add test this in our CI?

…t) decode/prefill)

The kv_len argument to torch.ops.triton.tq4_sdpa was dropped from
cuda_source_transformations.py during a rebase, leaving HAS_KV_LEN=False so the
TQ4 attention kernel swept the full pre-allocated max_seq_len every step instead
of the valid context. At 128k this made decode ~2.7 tok/s (vs ~37) and prefill
~282 (vs ~2230). Restore kv_len = input_pos[0] + input_pos.shape[0] so the
kernel's O(context) bound is active again.

Verified (128k + turboquant, CUDA graph): prefill 2231 (was 282), decode 46.1
@~128ctx / 36.7 @~2048ctx (was flat 2.67) — matches the pre-rebase reference.
@Gasoonjia Gasoonjia merged commit 63b4c4d into main Jun 18, 2026
270 of 279 checks passed
@Gasoonjia Gasoonjia deleted the g4-128k-context branch June 18, 2026 21:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants