Skip to content

Add chunk-parallel gated delta ops for training#1389

Open
tsato081 wants to merge 2 commits into
ml-explore:mainfrom
tsato081:chunked-gdn-training
Open

Add chunk-parallel gated delta ops for training#1389
tsato081 wants to merge 2 commits into
ml-explore:mainfrom
tsato081:chunked-gdn-training

Conversation

@tsato081

@tsato081 tsato081 commented Jun 10, 2026

Copy link
Copy Markdown

Problem

In training mode (use_kernel=False), gated_delta_update falls back to a sequential per-timestep loop. Autodiff keeps all T intermediate states alive, so Qwen3.5 / Qwen3-Next fine-tuning aborts with a Metal command-buffer OOM on the first backward pass (#1206, earlier #482). #997 noted the impact on finetuning.

What this does

Replaces that fallback with the standard chunked formulation of the gated delta rule, the same algorithm the CUDA implementations use (flash-linear-attention; Yang et al.; Yang & Kautz):

  • Each 64-step chunk becomes one unit-lower-triangular solve computed with dense matmuls.
  • Each chunk is wrapped in mx.checkpoint, shrinking the autodiff graph from O(T) to O(T / 64) nodes.
  • gated_delta_update's signature, all call sites, and the inference kernel are unchanged. Vectorized gating (Kimi Linear) and single-token steps keep the sequential path.

Why this approach

Correctness

  • The chunked form is exact, not an approximation: the per-chunk system matrix is nilpotent, so the solve is algebraically identical to the recurrence.
  • The solve runs in fp32 with blocked forward substitution; a global Neumann expansion overflows fp32 on repeated keys. Adversarial tests cover this, including a collinear beta -> 1 case that guards the solve block size (independent verification in the Add Metal VJP kernel for gated_delta_update (trainable Qwen3.5 / Qwen3-Next LoRA on Apple Silicon) #1217 thread showed SUB_BLOCK=32 blows up while 16 degrades gracefully).
  • New tests compare outputs, final state, and gradients against gated_delta_ops across multi-chunk, padding, GQA, carried-in state, masks, decay extremes, and repeated keys. Final-state max diff vs the fp32 sequential reference is ~1e-6 at T=256.
  • Known benign deviation: below the 1e-12 decay clamp, the gradient w.r.t. g is zeroed where the sequential reference returns a nonzero value. Harmless in practice (the model's parameterization scales this gradient by g itself); flagged during independent verification in the Add Metal VJP kernel for gated_delta_update (trainable Qwen3.5 / Qwen3-Next LoRA on Apple Silicon) #1217 thread by @SudarkinV.

Performance

Single linear-attention layer, fwd+bwd at T=2048, bf16 inputs / fp32 state, median of 10, M3 Ultra:

Shape ops (main) this PR
Qwen3.5-9B class (Hk16 Hv64 Dk192 Dv128) 7.043 s / 64.0 GB 0.117 s / 3.0 GB
Qwen3.5-35B-A3B (Hk16 Hv32 Dk128 Dv128) 1.149 s / 24.6 GB 0.071 s / 1.2 GB

End-to-end, a 512-step DWQ distillation of Qwen3.5-35B-A3B (seq 2048, 30 GDN layers) ran at 12.6 s/step on a workload where main hits the OOM from #1206.

Limitations

  • Vectorized (per-channel) gating still uses the sequential loop; the chunked math extends to it but is left as a follow-up.
  • Decode (T == 1) and the inference Metal kernel are untouched.

@SudarkinV

Copy link
Copy Markdown

Here's the GQA patch against current #1389 (1edaa42) — a single-file change to gated_delta_ops_chunked / _gated_delta_chunk.

The idea: the k @ k^T Gram and q @ k^T only depend on the Hk query/key heads, so instead of mx.repeat-ing q/k up to Hv up front, the chunk computes those two C×C matmuls on Hk and broadcasts to Hv internally; the per-Hv gating is applied after. Public signature is unchanged — only the private _gated_delta_chunk gains a repeat_factor arg.

On an M4 Max (Qwen3.5-9B linear-attn shape, fwd+bwd, bf16): +20% at T=2048, +34% at T=4096, and peak memory ~37% lower at T=4096 since the repeated q/k no longer sit in the autodiff graph. Stays numerically identical to the current path (rel ~1e-7 across rf=4/2/1, padding, masks, batch, carried state) and all gated_delta_* tests pass.

Patch below — happy to open it as a PR against your branch instead if that's easier to fold.

diff --git a/mlx_lm/models/gated_delta.py b/mlx_lm/models/gated_delta.py
index 4e268cc..06c0802 100644
--- a/mlx_lm/models/gated_delta.py
+++ b/mlx_lm/models/gated_delta.py
@@ -301,18 +301,27 @@ def _solve_strict_lower(A: mx.array, b: mx.array, sb: int = SUB_BLOCK) -> mx.arr
 
 
 def _gated_delta_chunk(
-    state: mx.array,  # [B, H, Dk, Dv]
-    q: mx.array,  # [B, H, C, Dk]
-    k: mx.array,  # [B, H, C, Dk]
-    v: mx.array,  # [B, H, C, Dv]
-    g: mx.array,  # [B, H, C], gating in (0, 1)
-    beta: mx.array,  # [B, H, C]
+    state: mx.array,  # [B, Hv, Dk, Dv]
+    q: mx.array,  # [B, Hk, C, Dk]
+    k: mx.array,  # [B, Hk, C, Dk]
+    v: mx.array,  # [B, Hv, C, Dv]
+    g: mx.array,  # [B, Hv, C], gating in (0, 1)
+    beta: mx.array,  # [B, Hv, C]
+    repeat_factor: int = 1,
 ) -> Tuple[mx.array, mx.array]:
     """Run C timesteps as one triangular solve (gated UT/WY transform).
 
     Exact reformulation of the sequential recurrence; runs in fp32.
+
+    Under grouped-query attention (repeat_factor > 1) the C x C key Gram
+    and q . k^T products depend only on the Hk query/key heads, so they
+    are formed before the GQA broadcast to Hv; the per-Hv gating is
+    applied afterwards. This avoids materializing the repeated q/k for the
+    whole sequence and shrinks those two matmuls by repeat_factor.
     """
     C = q.shape[2]
+    Hk = q.shape[1]
+    Hv = v.shape[1]
     orig_dtype = q.dtype
 
     q = q.astype(mx.float32)
@@ -323,7 +332,7 @@ def _gated_delta_chunk(
     state = state.astype(mx.float32)
 
     # Log-domain cumulative decay; the clamp keeps -inf out of the cumsum.
-    g_log = mx.log(mx.maximum(g, 1e-12))  # [B, H, C]
+    g_log = mx.log(mx.maximum(g, 1e-12))  # [B, Hv, C]
     g_cumlog = mx.cumsum(g_log, axis=-1)
     g_last = g_cumlog[..., -1:]
 
@@ -332,21 +341,40 @@ def _gated_delta_chunk(
     L_diff = (g_cumlog[..., :, None] - g_cumlog[..., None, :]) * tril_ones
     L_mask = mx.exp(L_diff) * tril_ones
 
-    v_beta = v * beta[..., None]  # [B, H, C, Dv]
-    k_beta = k * beta[..., None]  # [B, H, C, Dk]
+    # GQA sharing: the C x C products only need the Hk heads. Form them
+    # first, then broadcast q/k and the products to Hv.
+    kkT = k @ mx.swapaxes(k, -1, -2)  # [B, Hk, C, C], kkT[i, j] = <k_i, k_j>
+    qkT = q @ mx.swapaxes(k, -1, -2)  # [B, Hk, C, C]
+    if repeat_factor > 1:
+        B = q.shape[0]
+
+        def _to_hv(x):  # [B, Hk, C, D] -> [B, Hv, C, D]
+            D = x.shape[-1]
+            return mx.broadcast_to(
+                x[:, :, None], (B, Hk, repeat_factor, C, D)
+            ).reshape(B, Hv, C, D)
+
+        kkT = _to_hv(kkT)
+        qkT = _to_hv(qkT)
+        q = _to_hv(q)
+        k = _to_hv(k)
+
+    v_beta = v * beta[..., None]  # [B, Hv, C, Dv]
+    k_beta = k * beta[..., None]  # [B, Hv, C, Dk]
 
+    # (k_beta @ k^T)[i, j] = beta_i * <k_i, k_j> = beta_i * kkT[i, j]
     strict_lower = mx.tril(mx.ones((C, C), dtype=mx.float32), k=-1)
-    A = -(k_beta @ mx.swapaxes(k, -1, -2)) * L_mask * strict_lower
+    A = -(beta[..., :, None] * kkT) * L_mask * strict_lower
 
-    decay_exp = mx.exp(g_cumlog)[..., None]  # [B, H, C, 1]
+    decay_exp = mx.exp(g_cumlog)[..., None]  # [B, Hv, C, 1]
     rhs = mx.concatenate([v_beta, k_beta * decay_exp], axis=-1)
     sol = _solve_strict_lower(A, rhs)
     v_corrected, k_cumdecay = mx.split(sol, [v.shape[-1]], axis=-1)
 
-    v_new = v_corrected - k_cumdecay @ state  # [B, H, C, Dv]
-    y_inter = (q * decay_exp) @ state  # [B, H, C, Dv]
+    v_new = v_corrected - k_cumdecay @ state  # [B, Hv, C, Dv]
+    y_inter = (q * decay_exp) @ state  # [B, Hv, C, Dv]
 
-    attn = (q @ mx.swapaxes(k, -1, -2)) * L_mask
+    attn = qkT * L_mask
     y = y_inter + attn @ v_new
 
     state_decay = mx.exp(g_last)[..., None]
@@ -390,13 +418,13 @@ def gated_delta_ops_chunked(
     B, T, Hk, Dk = q.shape
     Hv, Dv = v.shape[-2:]
     C = chunk_size or CHUNK_SIZE
+    repeat_factor = Hv // Hk
 
     if state is None:
         state = mx.zeros((B, Hv, Dv, Dk), dtype=mx.float32)
 
-    if (repeat_factor := Hv // Hk) > 1:
-        q = mx.repeat(q, repeat_factor, -2)
-        k = mx.repeat(k, repeat_factor, -2)
+    # q/k stay on the Hk query/key heads; the chunk shares their C x C
+    # products across the GQA group and broadcasts to Hv internally.
 
     # Masked steps are identities on the state (g = 1, beta = 0).
     if mask is not None:
@@ -415,9 +443,9 @@ def gated_delta_ops_chunked(
 
     num_chunks = (T + pad_len) // C
 
-    # [B, T, H, D] -> [B, H, Nc, C, D]
-    q = mx.swapaxes(q, 1, 2).reshape(B, Hv, num_chunks, C, Dk)
-    k = mx.swapaxes(k, 1, 2).reshape(B, Hv, num_chunks, C, Dk)
+    # [B, T, H, D] -> [B, H, Nc, C, D]; q/k keep their Hk heads.
+    q = mx.swapaxes(q, 1, 2).reshape(B, Hk, num_chunks, C, Dk)
+    k = mx.swapaxes(k, 1, 2).reshape(B, Hk, num_chunks, C, Dk)
     v = mx.swapaxes(v, 1, 2).reshape(B, Hv, num_chunks, C, Dv)
     g = mx.swapaxes(g, 1, 2).reshape(B, Hv, num_chunks, C)
     beta = mx.swapaxes(beta, 1, 2).reshape(B, Hv, num_chunks, C)
@@ -434,6 +462,7 @@ def gated_delta_ops_chunked(
             v[:, :, ci],
             g[:, :, ci],
             beta[:, :, ci],
+            repeat_factor,
         )
         ys.append(y_c)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants