Add chunk-parallel gated delta ops for training #1389
Here's the GQA patch against current #1389 (1edaa42): a single-file change to
The idea: the
On an M4 Max (Qwen3.5-9B linear-attn shape, fwd+bwd, bf16): +20% at T=2048, +34% at T=4096, and peak memory ~37% lower at T=4096 since the repeated q/k no longer sit in the autodiff graph.
Stays numerically identical to the current path (rel ~1e-7 across rf=4/2/1, padding, masks, batch, carried state) and all
Patch below; happy to open it as a PR against your branch instead if that's easier to fold.
diff --git a/mlx_lm/models/gated_delta.py b/mlx_lm/models/gated_delta.py
index 4e268cc..06c0802 100644
--- a/mlx_lm/models/gated_delta.py
+++ b/mlx_lm/models/gated_delta.py
@@ -301,18 +301,27 @@ def _solve_strict_lower(A: mx.array, b: mx.array, sb: int = SUB_BLOCK) -> mx.arr
 def _gated_delta_chunk(
-    state: mx.array,  # [B, H, Dk, Dv]
-    q: mx.array,  # [B, H, C, Dk]
-    k: mx.array,  # [B, H, C, Dk]
-    v: mx.array,  # [B, H, C, Dv]
-    g: mx.array,  # [B, H, C], gating in (0, 1)
-    beta: mx.array,  # [B, H, C]
+    state: mx.array,  # [B, Hv, Dk, Dv]
+    q: mx.array,  # [B, Hk, C, Dk]
+    k: mx.array,  # [B, Hk, C, Dk]
+    v: mx.array,  # [B, Hv, C, Dv]
+    g: mx.array,  # [B, Hv, C], gating in (0, 1)
+    beta: mx.array,  # [B, Hv, C]
+    repeat_factor: int = 1,
 ) -> Tuple[mx.array, mx.array]:
     """Run C timesteps as one triangular solve (gated UT/WY transform).
     Exact reformulation of the sequential recurrence; runs in fp32.
+
+    Under grouped-query attention (repeat_factor > 1) the C x C key Gram
+    and q . k^T products depend only on the Hk query/key heads, so they
+    are formed before the GQA broadcast to Hv; the per-Hv gating is
+    applied afterwards. This avoids materializing the repeated q/k for the
+    whole sequence and shrinks those two matmuls by repeat_factor.
     """
     C = q.shape[2]
+    Hk = q.shape[1]
+    Hv = v.shape[1]
     orig_dtype = q.dtype
     q = q.astype(mx.float32)
@@ -323,7 +332,7 @@ def _gated_delta_chunk(
     state = state.astype(mx.float32)
     # Log-domain cumulative decay; the clamp keeps -inf out of the cumsum.
-    g_log = mx.log(mx.maximum(g, 1e-12))  # [B, H, C]
+    g_log = mx.log(mx.maximum(g, 1e-12))  # [B, Hv, C]
     g_cumlog = mx.cumsum(g_log, axis=-1)
     g_last = g_cumlog[..., -1:]
@@ -332,21 +341,40 @@ def _gated_delta_chunk(
     L_diff = (g_cumlog[..., :, None] - g_cumlog[..., None, :]) * tril_ones
     L_mask = mx.exp(L_diff) * tril_ones
-    v_beta = v * beta[..., None]  # [B, H, C, Dv]
-    k_beta = k * beta[..., None]  # [B, H, C, Dk]
+    # GQA sharing: the C x C products only need the Hk heads. Form them
+    # first, then broadcast q/k and the products to Hv.
+    kkT = k @ mx.swapaxes(k, -1, -2)  # [B, Hk, C, C], kkT[i, j] = <k_i, k_j>
+    qkT = q @ mx.swapaxes(k, -1, -2)  # [B, Hk, C, C]
+    if repeat_factor > 1:
+        B = q.shape[0]
+
+        def _to_hv(x):  # [B, Hk, C, D] -> [B, Hv, C, D]
+            D = x.shape[-1]
+            return mx.broadcast_to(
+                x[:, :, None], (B, Hk, repeat_factor, C, D)
+            ).reshape(B, Hv, C, D)
+
+        kkT = _to_hv(kkT)
+        qkT = _to_hv(qkT)
+        q = _to_hv(q)
+        k = _to_hv(k)
+
+    v_beta = v * beta[..., None]  # [B, Hv, C, Dv]
+    k_beta = k * beta[..., None]  # [B, Hv, C, Dk]
+    # (k_beta @ k^T)[i, j] = beta_i * <k_i, k_j> = beta_i * kkT[i, j]
     strict_lower = mx.tril(mx.ones((C, C), dtype=mx.float32), k=-1)
-    A = -(k_beta @ mx.swapaxes(k, -1, -2)) * L_mask * strict_lower
+    A = -(beta[..., :, None] * kkT) * L_mask * strict_lower
-    decay_exp = mx.exp(g_cumlog)[..., None]  # [B, H, C, 1]
+    decay_exp = mx.exp(g_cumlog)[..., None]  # [B, Hv, C, 1]
     rhs = mx.concatenate([v_beta, k_beta * decay_exp], axis=-1)
     sol = _solve_strict_lower(A, rhs)
     v_corrected, k_cumdecay = mx.split(sol, [v.shape[-1]], axis=-1)
-    v_new = v_corrected - k_cumdecay @ state  # [B, H, C, Dv]
-    y_inter = (q * decay_exp) @ state  # [B, H, C, Dv]
+    v_new = v_corrected - k_cumdecay @ state  # [B, Hv, C, Dv]
+    y_inter = (q * decay_exp) @ state  # [B, Hv, C, Dv]
-    attn = (q @ mx.swapaxes(k, -1, -2)) * L_mask
+    attn = qkT * L_mask
     y = y_inter + attn @ v_new
     state_decay = mx.exp(g_last)[..., None]
@@ -390,13 +418,13 @@ def gated_delta_ops_chunked(
     B, T, Hk, Dk = q.shape
     Hv, Dv = v.shape[-2:]
     C = chunk_size or CHUNK_SIZE
+    repeat_factor = Hv // Hk
     if state is None:
         state = mx.zeros((B, Hv, Dv, Dk), dtype=mx.float32)
-    if (repeat_factor := Hv // Hk) > 1:
-        q = mx.repeat(q, repeat_factor, -2)
-        k = mx.repeat(k, repeat_factor, -2)
+    # q/k stay on the Hk query/key heads; the chunk shares their C x C
+    # products across the GQA group and broadcasts to Hv internally.
     # Masked steps are identities on the state (g = 1, beta = 0).
     if mask is not None:
@@ -415,9 +443,9 @@ def gated_delta_ops_chunked(
     num_chunks = (T + pad_len) // C
-    # [B, T, H, D] -> [B, H, Nc, C, D]
-    q = mx.swapaxes(q, 1, 2).reshape(B, Hv, num_chunks, C, Dk)
-    k = mx.swapaxes(k, 1, 2).reshape(B, Hv, num_chunks, C, Dk)
+    # [B, T, H, D] -> [B, H, Nc, C, D]; q/k keep their Hk heads.
+    q = mx.swapaxes(q, 1, 2).reshape(B, Hk, num_chunks, C, Dk)
+    k = mx.swapaxes(k, 1, 2).reshape(B, Hk, num_chunks, C, Dk)
     v = mx.swapaxes(v, 1, 2).reshape(B, Hv, num_chunks, C, Dv)
     g = mx.swapaxes(g, 1, 2).reshape(B, Hv, num_chunks, C)
     beta = mx.swapaxes(beta, 1, 2).reshape(B, Hv, num_chunks, C)
@@ -434,6 +462,7 @@ def gated_delta_ops_chunked(
             v[:, :, ci],
             g[:, :, ci],
             beta[:, :, ci],
+            repeat_factor,
         )
         ys.append(y_c)
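
The identity the patch relies on can be checked in isolation. The sketch below is a NumPy stand-in, not the MLX code from the diff: the sizes are illustrative, and `to_hv` is a hypothetical mirror of the patch's `_to_hv` helper. It compares the old order (repeat `k` to Hv heads, then form the beta-scaled Gram) with the new order (Gram on the Hk heads, broadcast, then scale rows by beta).

```python
import numpy as np

rng = np.random.default_rng(0)
B, Hk, rf, C, Dk = 2, 2, 4, 8, 16  # illustrative sizes, not taken from the PR
Hv = Hk * rf

k = rng.standard_normal((B, Hk, C, Dk)).astype(np.float32)
beta = rng.uniform(0.0, 1.0, (B, Hv, C)).astype(np.float32)


def to_hv(x):
    """[B, Hk, C, D] -> [B, Hv, C, D]; each Hk head appears rf times in a row."""
    D = x.shape[-1]
    return np.broadcast_to(x[:, :, None], (B, Hk, rf, C, D)).reshape(B, Hv, C, D)


# Old order: repeat k to Hv heads first, then form the beta-scaled Gram matrix.
k_rep = np.repeat(k, rf, axis=1)
ref = (k_rep * beta[..., None]) @ np.swapaxes(k_rep, -1, -2)

# New order: Gram on the Hk heads, broadcast to Hv, then scale row i by beta_i.
kkT = k @ np.swapaxes(k, -1, -2)
shared = beta[..., :, None] * to_hv(kkT)

same_layout = np.array_equal(to_hv(k), k_rep)  # broadcast+reshape matches repeat
max_abs_diff = float(np.abs(ref - shared).max())
```

The two orders differ only by floating-point rounding, since scaling row i of `k` by `beta_i` before the matmul is algebraically the same as scaling row i of the product afterwards.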
Problem
In training mode (use_kernel=False), gated_delta_update falls back to a sequential per-timestep loop. Autodiff keeps all T intermediate states alive, so Qwen3.5 / Qwen3-Next fine-tuning aborts with a Metal command-buffer OOM on the first backward pass (#1206, earlier #482). #997 noted the impact on fine-tuning.
What this does
Replaces that fallback with the standard chunked formulation of the gated delta rule, the same algorithm the CUDA implementations use (flash-linear-attention; Yang et al.; Yang & Kautz):
mx.checkpoint, shrinking the autodiff graph from O(T) to O(T / 64) nodes.
gated_delta_update's signature, all call sites, and the inference kernel are unchanged. Vectorized gating (Kimi Linear) and single-token steps keep the sequential path.
Why this approach
Dk % 32 == 0) at ~2.8x higher peak memory; its own routing falls back to a sequential path everywhere else, and this PR is a strictly better fallback beneath it if it lands. This PR fixes the default path on its own either way.
Correctness
beta -> 1 case that guards the solve block size (independent verification in the Add Metal VJP kernel for gated_delta_update (trainable Qwen3.5 / Qwen3-Next LoRA on Apple Silicon) #1217 thread showed SUB_BLOCK=32 blows up while 16 degrades gracefully).
gated_delta_ops across multi-chunk, padding, GQA, carried-in state, masks, decay extremes, and repeated keys. Final-state max diff vs the fp32 sequential reference is ~1e-6 at T=256.
1e-12 decay clamp, the gradient w.r.t. g is zeroed where the sequential reference returns a nonzero value. Harmless in practice (the model's parameterization scales this gradient by g itself); flagged during independent verification in the Add Metal VJP kernel for gated_delta_update (trainable Qwen3.5 / Qwen3-Next LoRA on Apple Silicon) #1217 thread by @SudarkinV.
Performance
Single linear-attention layer, fwd+bwd at T=2048, bf16 inputs / fp32 state, median of 10, M3 Ultra:
End-to-end, a 512-step DWQ distillation of Qwen3.5-35B-A3B (seq 2048, 30 GDN layers) ran at 12.6 s/step on a workload where main hits the OOM from #1206.
Limitations
T == 1) and the inference Metal kernel are untouched.
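
For readers who want to see why the chunked formulation described under "What this does" is an exact rewrite of the per-timestep loop, here is a minimal single-head sketch. It is written in NumPy rather than MLX and makes several assumptions: the usual gated delta rule recurrence with a scalar gate per step, a state laid out as [Dk, Dv], hypothetical function names `sequential` and `chunk`, and `np.linalg.solve` as a stand-in for the PR's `_solve_strict_lower`. The final-state update is derived from the recurrence, since the diff above does not show that line.

```python
import numpy as np


def sequential(S, q, k, v, g, beta):
    """Per-timestep gated delta rule for one head. S has shape [Dk, Dv]."""
    ys = []
    for t in range(q.shape[0]):
        S = g[t] * S                       # scalar decay of the state
        u = beta[t] * (v[t] - k[t] @ S)    # delta-rule correction, shape [Dv]
        S = S + np.outer(k[t], u)          # rank-1 write
        ys.append(q[t] @ S)                # read-out, shape [Dv]
    return np.stack(ys), S


def chunk(S, q, k, v, g, beta):
    """Same C timesteps expressed as one unit-lower-triangular solve."""
    C, Dv = v.shape
    gc = np.cumsum(np.log(np.maximum(g, 1e-12)))           # log cumulative decay
    decay = np.exp(gc)[:, None]                            # [C, 1]
    L = np.tril(np.exp(np.tril(gc[:, None] - gc[None, :])))  # decay ratios, i >= j
    M = beta[:, None] * (k @ k.T) * np.tril(L, k=-1)       # strictly lower part
    rhs = np.concatenate([v * beta[:, None], k * beta[:, None] * decay], axis=-1)
    sol = np.linalg.solve(np.eye(C) + M, rhs)              # generic solver stand-in
    u = sol[:, :Dv] - sol[:, Dv:] @ S                      # corrected values
    y = (q * decay) @ S + ((q @ k.T) * L) @ u              # inter- plus intra-chunk
    S_new = np.exp(gc[-1]) * S + (k * np.exp(gc[-1] - gc)[:, None]).T @ u
    return y, S_new


rng = np.random.default_rng(0)
T, C, Dk, Dv = 32, 8, 4, 3                                 # illustrative sizes
q = rng.standard_normal((T, Dk))
k = rng.standard_normal((T, Dk))
k /= np.linalg.norm(k, axis=-1, keepdims=True)             # unit-norm keys
v = rng.standard_normal((T, Dv))
g = rng.uniform(0.8, 1.0, T)
beta = rng.uniform(0.0, 1.0, T)
S0 = rng.standard_normal((Dk, Dv))

y_seq, S_seq = sequential(S0, q, k, v, g, beta)

S_chk, ys = S0, []
for start in range(0, T, C):                               # carry state across chunks
    sl = slice(start, start + C)
    y_c, S_chk = chunk(S_chk, q[sl], k[sl], v[sl], g[sl], beta[sl])
    ys.append(y_c)
y_chk = np.concatenate(ys)
```

In float64 the two paths agree to rounding error. The sequential loop creates one graph node per timestep, while the chunked loop creates one per chunk, which is where the graph-size reduction comes from once each chunk is wrapped in a checkpoint.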