diff --git a/backends/cuda/tests/test_tq4_sdpa.py b/backends/cuda/tests/test_tq4_sdpa.py
index 9cf1e9e2d57..f4cc1d770ef 100644
--- a/backends/cuda/tests/test_tq4_sdpa.py
+++ b/backends/cuda/tests/test_tq4_sdpa.py
@@ -438,6 +438,267 @@ def test_output_shape_and_dtype(self):
         self.assertEqual(out.shape, (1, H_q, Lq, D))
         self.assertEqual(out.dtype, torch.bfloat16)
 
+    # ------------------------------------------------------------------
+    # 128k code path: kv_len clamp (decode) + mask_is_causal (prefill)
+    #
+    # Every test above calls tq4_sdpa WITHOUT kv_len and WITHOUT
+    # mask_is_causal, so they only exercise the kv_len=None fallback
+    # (full-Lk loop) at short KV. The cases below drive the actual
+    # long-context paths used in production by the Gemma-4 31B global
+    # layers (head_dim=512, GQA 8:4) and Qwen 3.5 MoE (head_dim=256,
+    # GQA 16:2):
+    #   * the on-device kv_len scalar that bounds the KV loop to the
+    #     filled context (decode), and
+    #   * the mask_is_causal per-tile causal block-skip (prefill).
+    #
+    # "GARBAGE TAIL": in production the KV cache is a fixed buffer
+    # pre-allocated to max_seq_len (e.g. 131072). At any step only the
+    # first kv_len positions hold real K/V; the rest is stale /
+    # uninitialized memory that attention must ignore. We simulate that
+    # tail by writing large-magnitude (x1000) values into [kv_len:]. If
+    # the clamp / block-skip works the kernel never reads the tail and
+    # the output matches a reference built from [0, kv_len) only; if it
+    # is broken the huge tail values dominate the softmax and the cosine
+    # collapses to ~0. So the garbage tail is a built-in negative control
+    # (verified: dropping kv_len drives the cosine to ~-0.01 and fails).
+    #
+    # CAUSAL ALIGNMENT (top-left vs bottom-right): when L_q < L_kv (a
+    # chunked prefill / decode, where the Lq new queries sit at the END
+    # of a kv_len-long context) there are two ways to place the causal
+    # triangle. PyTorch F.sdpa(is_causal=True) uses TOP-LEFT alignment
+    # (query row i attends to keys [0, i]) -- wrong for a KV cache. This
+    # kernel and gemma4_31b/model.py::_build_masks use BOTTOM-RIGHT
+    # alignment: query row i is absolute position (kv_len - Lq + i) and
+    # attends to keys [0, kv_len - Lq + i]. So the reference below builds
+    # an explicit bottom-right mask (q_pos >= cache_pos) rather than
+    # passing is_causal=True, which would otherwise mismatch the kernel.
+    # ------------------------------------------------------------------
+
+    def _run_long_kv_test(
+        self,
+        *,
+        H_q,
+        H_kv,
+        D,
+        Lq,
+        kv_len,
+        buffer_len,
+        causal=False,
+        garbage=True,
+        pass_kv_len=True,
+        min_cosine=0.99,
+        seed=42,
+    ):
+        """Drive tq4_sdpa over a buffer whose first ``kv_len`` positions are
+        real and whose ``[kv_len:]`` tail is large-magnitude garbage, then
+        compare against an fp32 reference built from the first ``kv_len``
+        positions only.
+
+        The kernel sees the full (garbage-tailed) compressed buffer; the
+        on-device ``kv_len`` scalar (and, for prefill, the bottom-right
+        causal mask) must confine attention to ``[0, kv_len)``.
+
+        ``causal=True`` builds a bottom-right-aligned mask (the Lq queries
+        are the last Lq positions of a kv_len-long context), mirroring the
+        production ``q_pos >= cache_pos`` mask in gemma4_31b/model.py
+        ``_build_masks`` and the kernel's ``(kv_len - Lq) + seq_pos`` block
+        bound. We deliberately do NOT use ``F.sdpa(is_causal=True)`` for the
+        reference: PyTorch aligns is_causal top-left when L_q < L_kv, while
+        this kernel (and the model) align bottom-right.
+
+        Args:
+            H_q: Number of query heads.
+            H_kv: Number of KV heads (``H_q // H_kv`` is the GQA ratio).
+            D: Head dimension.
+            Lq: Number of query positions.
+            kv_len: Number of valid (non-garbage) KV positions.
+            buffer_len: Length of the KV buffer handed to the kernel.
+            causal: Build the bottom-right causal mask and pass
+                ``mask_is_causal=True``.
+            garbage: Overwrite ``[kv_len:]`` with x1000-scaled noise.
+            pass_kv_len: Pass the ``kv_len`` device scalar; when False the
+                kernel receives ``kv_len=None``.
+            min_cosine: Threshold the output/reference cosine must exceed.
+            seed: Seed passed to ``torch.manual_seed``.
+
+        Returns:
+            The cosine similarity between kernel output and reference.
+        """
+        # Fail fast on a mis-specified case: the valid prefix must fit in the
+        # buffer, and the Lq queries must lie inside it (the bottom-right
+        # q_pos range below starts at kv_len - Lq).
+        self.assertLessEqual(kv_len, buffer_len)
+        self.assertLessEqual(Lq, kv_len)
+
+        torch.manual_seed(seed)
+        centroids, boundaries, rotation = _make_codebook_and_rotation(D)
+        centroids = centroids.cuda()
+        boundaries = boundaries.cuda()
+        rotation = rotation.cuda()
+
+        B = 1
+        k = torch.randn(B, H_kv, buffer_len, D, dtype=torch.bfloat16, device="cuda")
+        v = torch.randn(B, H_kv, buffer_len, D, dtype=torch.bfloat16, device="cuda")
+        if garbage and buffer_len > kv_len:
+            g = buffer_len - kv_len
+            k[:, :, kv_len:, :] = (
+                torch.randn(B, H_kv, g, D, dtype=torch.bfloat16, device="cuda") * 1000.0
+            )
+            v[:, :, kv_len:, :] = (
+                torch.randn(B, H_kv, g, D, dtype=torch.bfloat16, device="cuda") * 1000.0
+            )
+
+        q = torch.randn(B, H_q, Lq, D, dtype=torch.bfloat16, device="cuda")
+
+        k_packed, k_norms = _compress(k, boundaries, rotation)
+        v_packed, v_norms = _compress(v, boundaries, rotation)
+
+        attn_mask = None
+        if causal:
+            cache_pos = torch.arange(buffer_len, device="cuda")
+            q_pos = torch.arange(kv_len - Lq, kv_len, device="cuda").unsqueeze(1)
+            attn_mask = (q_pos >= cache_pos.unsqueeze(0)).view(1, 1, Lq, buffer_len)
+
+        kv_len_t = (
+            torch.tensor([kv_len], dtype=torch.int32, device="cuda")
+            if pass_kv_len
+            else None
+        )
+
+        out = self.tq4_sdpa(
+            q,
+            k_packed,
+            k_norms,
+            v_packed,
+            v_norms,
+            centroids,
+            rotation,
+            attn_mask=attn_mask,
+            is_causal=False,
+            scale=None,
+            kv_len=kv_len_t,
+            mask_is_causal=causal,
+        )
+
+        # Reference: the same decompress-then-fp32-SDPA path the other tests
+        # use (_reference_tq4_sdpa), but over ONLY the first kv_len positions
+        # so the garbage tail can never influence it. _compress is per-row,
+        # so compressing the sliced K/V here is bit-identical to the kernel's
+        # view of the full buffer sliced to [:, :, :kv_len]; the helper also
+        # handles the GQA repeat_interleave and mask broadcast internally.
+        ref_mask = attn_mask[:, :, :, :kv_len] if attn_mask is not None else None
+        ref, *_ = _reference_tq4_sdpa(
+            q,
+            k[:, :, :kv_len],
+            v[:, :, :kv_len],
+            centroids,
+            boundaries,
+            rotation,
+            attn_mask=ref_mask,
+        )
+
+        self.assertFalse(torch.isnan(out).any(), "NaN in output")
+        cos = _cosine_sim(out, ref)
+        self.assertGreater(
+            cos,
+            min_cosine,
+            f"Cosine {cos:.5f} < {min_cosine} "
+            f"(H_q={H_q} H_kv={H_kv} D={D} Lq={Lq} kv_len={kv_len} "
+            f"buffer={buffer_len} causal={causal} kv_len_passed={pass_kv_len})",
+        )
+        return cos
+
+    def test_kv_len_clamp_decode_gemma_global(self):
+        """Decode (Lq=1) kv_len clamp at Gemma-4 31B global-layer shape
+        (head_dim=512, GQA 8:4). N=8192 leaves a 24k garbage tail in a 32k
+        buffer (clamp guard); N=32768 fills the buffer (full 32k loop)."""
+        for N in (8192, 32768):
+            with self.subTest(N=N):
+                self._run_long_kv_test(
+                    H_q=8, H_kv=4, D=512, Lq=1, kv_len=N, buffer_len=32768
+                )
+
+    def test_kv_len_clamp_decode_qwen(self):
+        """Decode (Lq=1) kv_len clamp at Qwen 3.5 MoE shape
+        (head_dim=256, GQA 16:2)."""
+        for N in (8192, 32768):
+            with self.subTest(N=N):
+                self._run_long_kv_test(
+                    H_q=16, H_kv=2, D=256, Lq=1, kv_len=N, buffer_len=32768
+                )
+
+    def test_mask_is_causal_prefill_gemma_global(self):
+        """Chunked prefill (Lq>1) with mask_is_causal at Gemma global shape.
+        The Lq queries are the last Lq of a kv_len-long context; the
+        per-tile causal block-skip plus bottom-right mask must match the
+        fp32 causal reference over the first kv_len positions. A garbage
+        tail beyond kv_len also exercises the clamp."""
+        for Lq, kv_len, buf in ((256, 4096, 8192), (2048, 8192, 16384)):
+            with self.subTest(Lq=Lq, kv_len=kv_len):
+                self._run_long_kv_test(
+                    H_q=8,
+                    H_kv=4,
+                    D=512,
+                    Lq=Lq,
+                    kv_len=kv_len,
+                    buffer_len=buf,
+                    causal=True,
+                )
+
+    def test_mask_is_causal_prefill_qwen(self):
+        """Chunked prefill (Lq>1) with mask_is_causal at Qwen shape."""
+        for Lq, kv_len, buf in ((256, 4096, 8192), (2048, 8192, 16384)):
+            with self.subTest(Lq=Lq, kv_len=kv_len):
+                self._run_long_kv_test(
+                    H_q=16,
+                    H_kv=2,
+                    D=256,
+                    Lq=Lq,
+                    kv_len=kv_len,
+                    buffer_len=buf,
+                    causal=True,
+                )
+
+    def test_kv_len_none_fallback_qwen(self):
+        """Regression: the kv_len=None fallback (HAS_KV_LEN False, full-Lk
+        loop) that the Qwen path relies on still matches the fp32 reference.
+        This guards the original behavior the kv_len feature must preserve
+        for callers that pass neither kv_len nor mask_is_causal."""
+        self._run_long_kv_test(
+            H_q=16,
+            H_kv=2,
+            D=256,
+            Lq=1,
+            kv_len=256,
+            buffer_len=256,
+            garbage=False,
+            pass_kv_len=False,
+        )
+
+    @unittest.skipUnless(
+        os.environ.get("TQ4_RUN_128K") == "1",
+        "128k case is heavy for the 24GB CI runner; set TQ4_RUN_128K=1 to run",
+    )
+    def test_kv_len_clamp_128k(self):
+        """Full 131072-entry buffer (Qwen shape). (a) kv_len=8192 with a
+        ~123k garbage tail — the clamp keeps decode O(context) and never
+        touches the tail; (b) kv_len=131072 — correctness at true 128k
+        scale. Gated behind TQ4_RUN_128K because the fp32 reference for (b)
+        needs >~6GB and CI runs on a 24GB A10G."""
+        # subTest per case (like the sibling tests) so a failure in (a) is
+        # reported separately and does not stop (b) from running.
+        for kv_len, garbage in ((8192, True), (131072, False)):
+            with self.subTest(kv_len=kv_len):
+                self._run_long_kv_test(
+                    H_q=16,
+                    H_kv=2,
+                    D=256,
+                    Lq=1,
+                    kv_len=kv_len,
+                    buffer_len=131072,
+                    garbage=garbage,
+                )
+
     # ------------------------------------------------------------------
     # Validation errors
     # ------------------------------------------------------------------
diff --git a/backends/cuda/triton/kernels/tq4_sdpa.py b/backends/cuda/triton/kernels/tq4_sdpa.py
index c68ea086940..e1576f7e446 100644
--- a/backends/cuda/triton/kernels/tq4_sdpa.py
+++ b/backends/cuda/triton/kernels/tq4_sdpa.py
@@ -77,6 +77,7 @@ def _tq4_sdpa_fwd_kernel_body(
     LUT_lo_ptr,
     Mask_ptr,
     O_ptr,
+    KV_LEN_ptr,
     B,
     H_grid,
     Lq,
@@ -109,6 +110,8 @@ def _tq4_sdpa_fwd_kernel_body(
     sm_scale: tl.float32,
     HAS_MASK: tl.constexpr,
     IS_CAUSAL: tl.constexpr,
+    HAS_KV_LEN: tl.constexpr,
+    MASK_IS_CAUSAL: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     HEAD_DIM: tl.constexpr,
@@ -167,9 +170,37 @@ def _tq4_sdpa_fwd_kernel_body(
 
     offs_n_init = tl.arange(0, BLOCK_N)
 
-    for start_n in tl.range(0, Lk, BLOCK_N):
+    # Bound the KV loop to the number of valid (filled) positions instead of the
+    # full pre-allocated buffer Lk. For decode this is input_pos+1; for a prefill
+    # chunk it is chunk_end. This makes the global-layer attention O(context)
+    # rather than O(max_seq_len) — the empty tail of the cache is never touched.
+    # kv_len is read from a GPU scalar so the bound updates across CUDA-graph
+    # replays (decode is graph-captured). When not provided (HAS_KV_LEN False,
+    # e.g. qwen) it falls back to Lk, preserving the original behavior exactly.
+    if HAS_KV_LEN:
+        kv_len = tl.load(KV_LEN_ptr)
+    else:
+        kv_len = Lk
+
+    # Per-tile causal upper bound (prefill).
With a causal attn_mask, the rows in + # this tile attend only up to their own absolute position; the largest such + # position is (kv_len - Lq) + max(seq_pos) — seq_pos is the query-row index, + # and the (kv_len - Lq) offset converts it to an absolute KV position (so it + # is correct for chunked prefill, not just the first chunk). KV blocks that + # start beyond it are fully masked, so we stop the loop there. This is the + # prefill analogue of the kv_len decode clamp and ~halves the causal-triangle + # work. For decode (Lq=1, max(seq_pos)=0) this evaluates to kv_len, so decode + # is byte-identical. Applied only when MASK_IS_CAUSAL (the caller guarantees a + # causal mask, e.g. Gemma's full-attention layers); otherwise the full kv_len + # bound is kept, which is safe for an arbitrary mask. + loop_end = kv_len + if MASK_IS_CAUSAL: + max_q_pos = (kv_len - Lq) + tl.max(seq_pos) + loop_end = tl.minimum(kv_len, max_q_pos + 1) + + for start_n in tl.range(0, loop_end, BLOCK_N): offs_n = start_n + offs_n_init - kv_valid = offs_n < Lk + kv_valid = offs_n < kv_len # -- K decompression (LUT, no norm multiply on [N,D] tile) -- kp_ptrs = ( @@ -277,6 +308,7 @@ def _tq4_sdpa_fwd_kernel_m64( LUT_lo_ptr, Mask_ptr, O_ptr, + KV_LEN_ptr, B, H_grid, Lq, @@ -309,6 +341,8 @@ def _tq4_sdpa_fwd_kernel_m64( sm_scale: tl.float32, HAS_MASK: tl.constexpr, IS_CAUSAL: tl.constexpr, + HAS_KV_LEN: tl.constexpr, + MASK_IS_CAUSAL: tl.constexpr, HEAD_DIM: tl.constexpr, HALF_D: tl.constexpr, NUM_GROUPS: tl.constexpr, @@ -326,6 +360,7 @@ def _tq4_sdpa_fwd_kernel_m64( LUT_lo_ptr, Mask_ptr, O_ptr, + KV_LEN_ptr, B, H_grid, Lq, @@ -358,6 +393,8 @@ def _tq4_sdpa_fwd_kernel_m64( sm_scale, HAS_MASK=HAS_MASK, IS_CAUSAL=IS_CAUSAL, + HAS_KV_LEN=HAS_KV_LEN, + MASK_IS_CAUSAL=MASK_IS_CAUSAL, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM, @@ -387,6 +424,7 @@ def _tq4_sdpa_fwd_kernel_m32( LUT_lo_ptr, Mask_ptr, O_ptr, + KV_LEN_ptr, B, H_grid, Lq, @@ -419,6 +457,8 @@ def _tq4_sdpa_fwd_kernel_m32( sm_scale: 
tl.float32, HAS_MASK: tl.constexpr, IS_CAUSAL: tl.constexpr, + HAS_KV_LEN: tl.constexpr, + MASK_IS_CAUSAL: tl.constexpr, HEAD_DIM: tl.constexpr, HALF_D: tl.constexpr, NUM_GROUPS: tl.constexpr, @@ -436,6 +476,7 @@ def _tq4_sdpa_fwd_kernel_m32( LUT_lo_ptr, Mask_ptr, O_ptr, + KV_LEN_ptr, B, H_grid, Lq, @@ -468,6 +509,8 @@ def _tq4_sdpa_fwd_kernel_m32( sm_scale, HAS_MASK=HAS_MASK, IS_CAUSAL=IS_CAUSAL, + HAS_KV_LEN=HAS_KV_LEN, + MASK_IS_CAUSAL=MASK_IS_CAUSAL, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM, @@ -492,6 +535,7 @@ def _launch_tq4_kernel( lut_lo, mask_ptr, out_rot, + kv_len_ptr, B, H_Q, H_KV, @@ -500,6 +544,8 @@ def _launch_tq4_kernel( D, sm_scale, HAS_MASK, + HAS_KV_LEN, + MASK_IS_CAUSAL, stride_mb, stride_mq, stride_mk, @@ -537,6 +583,7 @@ def grid(meta): lut_lo, mask_ptr if HAS_MASK else 0, out_rot, + kv_len_ptr if HAS_KV_LEN else 0, B, H_grid, L_Q, @@ -553,6 +600,8 @@ def grid(meta): sm_scale, HAS_MASK=HAS_MASK, IS_CAUSAL=is_causal, + HAS_KV_LEN=HAS_KV_LEN, + MASK_IS_CAUSAL=MASK_IS_CAUSAL, HEAD_DIM=D, HALF_D=HALF_D, NUM_GROUPS=num_groups, @@ -641,6 +690,8 @@ def tq4_sdpa( attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, scale: Optional[float] = None, + kv_len: Optional[torch.Tensor] = None, + mask_is_causal: bool = False, ) -> torch.Tensor: """Fused TQ4 SDPA over nibble-packed compressed K/V cache. @@ -665,6 +716,22 @@ def tq4_sdpa( ``1/sqrt(HEAD_DIM)`` when ``None``. Models that handle their own normalization (e.g. Gemma 4 with QK-norm uses ``1.0``) should pass an explicit value. + kv_len: Optional GPU int scalar = number of valid (filled) KV + positions. When provided, the inner KV loop is bounded to + ``kv_len`` instead of the full pre-allocated ``L_KV``, making + attention O(context) instead of O(max_seq_len). It is read on + the device (no host sync) so the bound updates correctly under + CUDA-graph replay (decode). For decode pass ``input_pos + 1``; + for a prefill chunk pass ``chunk_end``. 
When ``None`` the loop
+            runs over the full ``L_KV`` (original behavior).
+        mask_is_causal: Set True only when ``attn_mask`` is a standard causal mask
+            (row at absolute position p attends to [0, p]). Enables a per-tile causal
+            upper bound that skips KV blocks past the tile's last query position —
+            the prefill analogue of the ``kv_len`` clamp, ~halving prefill
+            (causal-triangle) work. It is a no-op for decode (L_Q=1) and
+            byte-identical there. Ignored unless ``kv_len`` is also passed, because
+            the bound is derived from it. Leave False (default) for any non-causal
+            mask; the kernel keeps the full ``kv_len`` bound, valid for any mask.
 
     Returns:
         [B, H_Q, L_Q, D] bf16 attention output
@@ -679,6 +746,21 @@ def tq4_sdpa(
     sm_scale = float(1.0 / math.sqrt(D)) if scale is None else float(scale)
     num_groups = H_Q // H_KV
 
+    HAS_KV_LEN = kv_len is not None
+    if HAS_KV_LEN:
+        # Device int32 scalar, clamped to the buffer size for OOB safety.
+        # Reshaped to [1] so the kernel can ``tl.load`` element 0. No
+        # ``.item()`` — keeps it CUDA-graph-safe (value updates on replay).
+        kv_len_t = torch.clamp(
+            kv_len.reshape(1).to(torch.int32), max=int(N_KV)
+        ).contiguous()
+    else:
+        kv_len_t = None
+
+    # The per-tile causal upper bound needs kv_len to convert query-row indices
+    # to absolute KV positions, so it is only meaningful when kv_len is supplied.
+    MASK_IS_CAUSAL = bool(mask_is_causal) and HAS_KV_LEN
+
     # Build [256] bf16 lookup tables from [16] centroids.
     # In the export path, inductor fuses this into the compiled graph.
all_bytes = torch.arange(256, device=centroids.device) @@ -726,6 +808,7 @@ def tq4_sdpa( lut_lo, mask_ptr, out_rot, + kv_len_t, B, H_Q, H_KV, @@ -734,6 +817,8 @@ def tq4_sdpa( D, sm_scale, HAS_MASK, + HAS_KV_LEN, + MASK_IS_CAUSAL, stride_mb, stride_mq, stride_mk, @@ -758,5 +843,7 @@ def _tq4_sdpa_fake( attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, scale: Optional[float] = None, + kv_len: Optional[torch.Tensor] = None, + mask_is_causal: bool = False, ) -> torch.Tensor: return torch.empty_like(query) diff --git a/examples/models/gemma4_31b/README.md b/examples/models/gemma4_31b/README.md index ae3bcb24c19..482f64083a0 100644 --- a/examples/models/gemma4_31b/README.md +++ b/examples/models/gemma4_31b/README.md @@ -93,14 +93,31 @@ method with dynamic sequence length and host-side sampling. Writes `model.pte` (and optionally `model.ptd`) into `--output-dir`. -#### TurboQuant KV cache (long context, MLX only) +#### TurboQuant KV cache (long context, CUDA + MLX) For long-context inference, add `--turboquant` to swap the full-attention layers' KV cache for a TurboQuant TQ4 cache (4-bit codebook + nibble pack). This gives ~3.8× cache memory savings on the full-attention layers and lets -you fit context lengths that wouldn't fit in bf16. Sliding-window layers are unaffected. +you fit context lengths that wouldn't fit in bf16. Sliding-window layers are +unaffected. Supported on both the CUDA and MLX backends. + +**Long context requires BOTH flags**: `--turboquant` *and* a larger +`--max-seq-len`. Raising `--max-seq-len` alone keeps a bf16 KV cache, which does +not fit at long context. On CUDA, `--turboquant` is what enables 128k: Gemma4-31B +at `--max-seq-len 131072` runs within ~27 GiB at runtime (fits a 32 GB card). 
+
+```bash
+# CUDA — 128k context (TQ4 KV)
+python examples/models/gemma4_31b/export.py \
+  --gguf ./gemma-4-31B-it-Q4_K_M.gguf \
+  --output-dir ./gemma4_31b_exports_128k \
+  --max-seq-len 131072 \
+  --backend cuda \
+  --turboquant
+```
 
 ```bash
+# MLX (Apple Silicon)
 python examples/models/gemma4_31b/export.py \
   --prequantized ./gemma4_31b_int4 \
   --output-dir ./gemma4_31b_exports_mlx_tq \
diff --git a/examples/models/gemma4_31b/cuda_source_transformations.py b/examples/models/gemma4_31b/cuda_source_transformations.py
index aeafd97f74e..c16498d02b7 100644
--- a/examples/models/gemma4_31b/cuda_source_transformations.py
+++ b/examples/models/gemma4_31b/cuda_source_transformations.py
@@ -77,6 +77,17 @@ def _turboquant_attention_forward(
     # uncompressed K/V is never materialized.
     k_packed, k_norms, v_packed, v_norms = self.kv_cache.update(input_pos, k, v)
 
+    # Number of valid (filled) KV positions = input_pos[0] + T. Passing this to
+    # tq4_sdpa bounds its KV loop to the actual context instead of the full
+    # pre-allocated buffer (max_seq_len for global layers), making attention
+    # O(context) instead of O(max_seq_len). Kept as a GPU scalar (no ``.item()``)
+    # so the bound is captured correctly by the decode CUDA graph. Decode: T=1 ->
+    # input_pos+1; prefill chunk: T -> chunk_end.
+    # NOTE: keep this call-site argument. It was once dropped during a rebase,
+    # which silently disabled the O(context) bound and forced a full max_seq_len
+    # sweep every step (catastrophic at 128k: ~2.7 tok/s decode vs ~37+ when bounded).
+    kv_len = input_pos[0] + input_pos.shape[0]
+
     # ``scale=self.scaling`` (= 1.0 for Gemma 4) — overrides tq4_sdpa's
     # default ``1/sqrt(D)`` because Gemma's QK-norm has absorbed the
     # 1/sqrt(d) factor into trained weights.
@@ -91,6 +102,7 @@ def _turboquant_attention_forward( attn_mask, False, # is_causal — attn_mask already encodes causal masking self.scaling, + kv_len, ) y = y.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) diff --git a/extension/llm/modules/turboquant/kv_cache.py b/extension/llm/modules/turboquant/kv_cache.py index 12c01721a15..684f763b44e 100644 --- a/extension/llm/modules/turboquant/kv_cache.py +++ b/extension/llm/modules/turboquant/kv_cache.py @@ -158,9 +158,13 @@ def update(self, input_pos, k_val, v_val): k_packed, k_norms = self._compress(k_val) v_packed, v_norms = self._compress(v_val) - self.k_packed[:, :, input_pos] = k_packed - self.k_norms[:, :, input_pos] = k_norms - self.v_packed[:, :, input_pos] = v_packed - self.v_norms[:, :, input_pos] = v_norms + # index_copy_ (not self.x[:, :, input_pos] = ...) keeps the decode + # write CUDA-graph-capturable: a static scatter along the position + # dim, matching the model's flat global KV cache. Plain index + # assignment lowers to index_put_, which breaks cuda_graph capture. + self.k_packed.index_copy_(2, input_pos, k_packed) + self.k_norms.index_copy_(2, input_pos, k_norms) + self.v_packed.index_copy_(2, input_pos, v_packed) + self.v_norms.index_copy_(2, input_pos, v_norms) return self.k_packed, self.k_norms, self.v_packed, self.v_norms