Skip to content

Commit 73f61cf

Browse files
authored
[Cuda] enable turboquant on gemma4 (#19891)
1 parent 2b41021 commit 73f61cf

3 files changed

Lines changed: 160 additions & 14 deletions

File tree

backends/cuda/triton/kernels/tq4_sdpa.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,7 @@ def tq4_sdpa(
640640
rotation: torch.Tensor,
641641
attn_mask: Optional[torch.Tensor] = None,
642642
is_causal: bool = False,
643+
scale: Optional[float] = None,
643644
) -> torch.Tensor:
644645
"""Fused TQ4 SDPA over nibble-packed compressed K/V cache.
645646
@@ -660,6 +661,10 @@ def tq4_sdpa(
660661
rotation: [D, D] orthogonal rotation matrix
661662
attn_mask: Optional [B, 1, L_Q, L_KV] bool mask
662663
is_causal: apply causal masking (requires L_Q == L_KV)
664+
scale: softmax scale applied to ``Q @ K^T``. Defaults to
665+
``1/sqrt(HEAD_DIM)`` when ``None``. Models that handle their
666+
own normalization (e.g. Gemma 4 with QK-norm uses ``1.0``)
667+
should pass an explicit value.
663668
664669
Returns:
665670
[B, H_Q, L_Q, D] bf16 attention output
@@ -671,7 +676,7 @@ def tq4_sdpa(
671676

672677
_validate_tq4_mask(attn_mask, B, N_Q, N_KV)
673678

674-
sm_scale = 1.0 / math.sqrt(D)
679+
sm_scale = float(1.0 / math.sqrt(D)) if scale is None else float(scale)
675680
num_groups = H_Q // H_KV
676681

677682
# Build [256] bf16 lookup tables from [16] centroids.
@@ -752,5 +757,6 @@ def _tq4_sdpa_fake(
752757
rotation: torch.Tensor,
753758
attn_mask: Optional[torch.Tensor] = None,
754759
is_causal: bool = False,
760+
scale: Optional[float] = None,
755761
) -> torch.Tensor:
756762
return torch.empty_like(query)
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""CUDA source transformations for Gemma 4 31B-IT.
8+
9+
Currently only adds optional TurboQuant TQ4 KV cache compression for
10+
full-attention layers, leaving sliding-window layers untouched. When
11+
``use_turboquant=True`` is passed:
12+
13+
- ``Gemma4Attention.kv_cache`` is replaced with
14+
``extension.llm.modules.turboquant.TurboQuantKVCache`` on every
15+
full-attention layer (sliding layers keep their ``RingKVCache``).
16+
- The attention forward is monkey-patched to call
17+
``torch.ops.triton.tq4_sdpa`` (the fused TQ4 attention kernel) instead
18+
of ``F.scaled_dot_product_attention``.
19+
20+
The model file (``model.py``) stays backend-agnostic — all CUDA
21+
TurboQuant specifics live here.
22+
"""
23+
24+
from __future__ import annotations
25+
26+
import types
27+
28+
# Importing this module registers ``torch.ops.triton.tq4_sdpa``.
29+
import executorch.backends.cuda.triton.kernels.tq4_sdpa # noqa: F401
30+
31+
import torch
32+
import torch.nn as nn
33+
34+
from executorch.examples.models.gemma4.text_decoder import apply_rotary_emb
35+
from executorch.extension.llm.modules.turboquant import TurboQuantKVCache
36+
37+
38+
def _turboquant_attention_forward(
    self,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    attn_mask: torch.Tensor,
) -> torch.Tensor:
    """Attention forward that attends over a TQ4-compressed KV cache.

    Intended to be bound onto an attention module in place of its own
    ``forward``. It projects and normalizes Q/K/V, applies RoPE to Q and K,
    writes K/V into ``self.kv_cache`` and runs
    ``torch.ops.triton.tq4_sdpa`` on the compressed tensors the cache
    returns.

    Args:
        x: Hidden states whose shape unpacks as ``(B, T, _)``.
        input_pos: Positions used both to build the RoPE angles and as the
            write index handed to ``self.kv_cache.update``.
        attn_mask: Mask forwarded unchanged to ``tq4_sdpa``.

    Returns:
        ``self.o_proj`` applied to the attention output reshaped to
        ``(B, T, n_heads * head_dim)``.
    """
    bsz, seq_len, _ = x.shape
    n_heads, n_kv_heads, head_dim = self.n_heads, self.n_kv_heads, self.head_dim

    # Per-head projections, laid out as (B, T, H, D).
    queries = self.q_proj(x).view(bsz, seq_len, n_heads, head_dim)
    keys_raw = self.k_proj(x).view(bsz, seq_len, n_kv_heads, head_dim)
    # When ``k_eq_v`` is set the value path reuses the un-normalized key
    # projection instead of running ``v_proj``.
    values_raw = (
        keys_raw
        if self.k_eq_v
        else self.v_proj(x).view(bsz, seq_len, n_kv_heads, head_dim)
    )

    # Normalize, then switch to (B, H, T, D) for SDPA / the KV cache.
    queries = self.q_norm(queries).transpose(1, 2)
    keys = self.k_norm(keys_raw).transpose(1, 2)
    values = self.v_norm(values_raw).transpose(1, 2)

    # RoPE: one angle per (position, frequency), duplicated along the last
    # dimension before taking cos/sin.
    angles = torch.outer(input_pos.float(), self.inv_freq)
    angles = torch.cat((angles, angles), dim=-1)
    queries, keys = apply_rotary_emb(
        queries, keys, torch.cos(angles), torch.sin(angles)
    )

    # Compress + write. The cache hands back its full compressed tensors;
    # tq4_sdpa decompresses per tile in its inner loop, so the full
    # uncompressed K/V is never materialized.
    cache = self.kv_cache
    packed_k, norms_k, packed_v, norms_v = cache.update(input_pos, keys, values)

    # ``self.scaling`` (= 1.0 for Gemma 4) is passed explicitly to override
    # tq4_sdpa's default ``1/sqrt(D)``, because Gemma's QK-norm has absorbed
    # the 1/sqrt(d) factor into trained weights.
    attn_out = torch.ops.triton.tq4_sdpa(
        queries,
        packed_k,
        norms_k,
        packed_v,
        norms_v,
        cache.centroids,
        cache.rotation,
        attn_mask,
        False,  # is_causal — attn_mask already encodes causal masking
        self.scaling,
    )

    # Back to (B, T, H * D) for the output projection.
    merged = attn_out.transpose(1, 2).contiguous()
    merged = merged.view(bsz, seq_len, n_heads * head_dim)
    return self.o_proj(merged)
98+
99+
100+
def cuda_source_transformations(
    model: nn.Module,
    *,
    use_turboquant: bool = False,
) -> None:
    """Rewrite a Gemma 4 31B model in place for the CUDA backend.

    The only transformation available is the optional TurboQuant (TQ4) KV
    cache swap; with ``use_turboquant=False`` the call is a no-op.

    Args:
        model: ``Gemma4_31B`` instance to transform. Its ``config`` and
            ``layers`` attributes are read, and each layer's ``self_attn``
            may be mutated.
        use_turboquant: When True, every attention module whose
            ``is_sliding`` flag is falsy gets a fresh ``TurboQuantKVCache``
            (~3.8× cache memory savings) as its ``kv_cache`` and has its
            ``forward`` rebound to ``_turboquant_attention_forward``, which
            routes SDPA through ``torch.ops.triton.tq4_sdpa``.
            Sliding-window layers are left untouched.
    """
    if not use_turboquant:
        return

    cfg = model.config
    swapped_count = 0
    for attn_module in (block.self_attn for block in model.layers):
        # Sliding-window layers keep their existing cache and forward.
        if attn_module.is_sliding:
            continue

        attn_module.kv_cache = TurboQuantKVCache(
            n_heads=attn_module.n_kv_heads,
            head_dim=attn_module.head_dim,
            max_seq_len=cfg.max_seq_len,
        )
        # Bind per instance so only this module's forward is replaced.
        attn_module.forward = types.MethodType(
            _turboquant_attention_forward, attn_module
        )
        swapped_count += 1

    summary = (
        f"[gemma4_31b cuda] TurboQuant: swapped {swapped_count} full-attention "
        f"KV caches with TurboQuantKVCache (TQ4)"
    )
    print(summary)

examples/models/gemma4_31b/export.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -144,13 +144,7 @@ def export_and_lower(
144144
) -> None:
145145
"""Export and lower the model to ExecuTorch for the given backend."""
146146
if backend == "cuda":
147-
if use_turboquant:
148-
raise ValueError(
149-
"--turboquant is only supported with --backend mlx "
150-
"(the CUDA path here uses a different TurboQuant integration; "
151-
"see examples/models/qwen3_5_moe/export.py)."
152-
)
153-
_export_cuda(model, config, output_dir)
147+
_export_cuda(model, config, output_dir, use_turboquant=use_turboquant)
154148
elif backend == "mlx":
155149
_export_mlx(model, config, output_dir, use_turboquant=use_turboquant)
156150
else:
@@ -159,7 +153,12 @@ def export_and_lower(
159153
)
160154

161155

162-
def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None:
156+
def _export_cuda(
157+
model: Gemma4_31B,
158+
config: Gemma4_31BConfig,
159+
output_dir: str,
160+
use_turboquant: bool = False,
161+
) -> None:
163162
import gc
164163

165164
import torch._inductor.config as inductor_config
@@ -182,6 +181,13 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -
182181

183182
materialize_runtime_buffers(model, dtype=torch.bfloat16)
184183

184+
if use_turboquant:
185+
from executorch.examples.models.gemma4_31b.cuda_source_transformations import (
186+
cuda_source_transformations,
187+
)
188+
189+
cuda_source_transformations(model, use_turboquant=True)
190+
185191
# Int4Tensor weights are used directly — no format conversion.
186192
# F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim).
187193
# Both decode and prefill share the same nibble-packed weights.
@@ -443,14 +449,13 @@ def main() -> None:
443449
parser.add_argument(
444450
"--turboquant",
445451
action="store_true",
446-
help="Use TurboQuant TQ4 KV cache compression (MLX backend only). "
447-
"~3.8× cache memory savings; applies only to full-attention "
448-
"(non-sliding) layers — sliding layers keep RingBufferKVCache.",
452+
help="Use TurboQuant TQ4 KV cache compression. ~3.8× cache memory "
453+
"savings; applies only to full-attention (non-sliding) layers — "
454+
"sliding layers keep their default cache. Supported on both "
455+
"--backend mlx and --backend cuda.",
449456
)
450457
args = parser.parse_args()
451458

452-
if args.turboquant and args.backend != "mlx":
453-
parser.error("--turboquant requires --backend mlx.")
454459
if args.backend == "cuda" and not torch.cuda.is_available():
455460
parser.error("CUDA is required for the cuda backend.")
456461

0 commit comments

Comments
 (0)