Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 190 additions & 1 deletion mlx_lm/models/gated_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,192 @@ def gated_delta_ops(
return y, state


CHUNK_SIZE = 64

# Solve block size; bounds fp32 error growth when keys repeat in a chunk.
SUB_BLOCK = 16


def _solve_strict_lower(A: mx.array, b: mx.array, sb: int = SUB_BLOCK) -> mx.array:
"""Solve (I - A) x = b for strictly lower-triangular (nilpotent) A.

Blocked forward substitution; a global Neumann expansion can
overflow fp32 when keys repeat within a chunk.
"""
C = A.shape[-1]

def doubling(Aii, rhs, n):
x = rhs
if n <= 1:
return x
P = Aii
steps = (n - 1).bit_length()
for s in range(steps):
x = x + P @ x
if s != steps - 1:
P = P @ P
return x

if C <= sb:
return doubling(A, b, C)

nb = (C + sb - 1) // sb
blocks = []
for i in range(nb):
lo, hi = i * sb, min((i + 1) * sb, C)
rhs = b[..., lo:hi, :]
if i > 0:
prev = blocks[0] if i == 1 else mx.concatenate(blocks, axis=-2)
rhs = rhs + A[..., lo:hi, :lo] @ prev
blocks.append(doubling(A[..., lo:hi, lo:hi], rhs, hi - lo))
return mx.concatenate(blocks, axis=-2)


def _gated_delta_chunk(
state: mx.array, # [B, H, Dk, Dv]
q: mx.array, # [B, H, C, Dk]
k: mx.array, # [B, H, C, Dk]
v: mx.array, # [B, H, C, Dv]
g: mx.array, # [B, H, C], gating in (0, 1)
beta: mx.array, # [B, H, C]
) -> Tuple[mx.array, mx.array]:
"""Run C timesteps as one triangular solve (gated UT/WY transform).

Exact reformulation of the sequential recurrence; runs in fp32.
"""
C = q.shape[2]
orig_dtype = q.dtype

q = q.astype(mx.float32)
k = k.astype(mx.float32)
v = v.astype(mx.float32)
g = g.astype(mx.float32)
beta = beta.astype(mx.float32)
state = state.astype(mx.float32)

# Log-domain cumulative decay; the clamp keeps -inf out of the cumsum.
g_log = mx.log(mx.maximum(g, 1e-12)) # [B, H, C]
g_cumlog = mx.cumsum(g_log, axis=-1)
g_last = g_cumlog[..., -1:]

# Zero the upper triangle before exp: it overflows, and inf * 0 = NaN.
tril_ones = mx.tril(mx.ones((C, C), dtype=mx.float32))
L_diff = (g_cumlog[..., :, None] - g_cumlog[..., None, :]) * tril_ones
L_mask = mx.exp(L_diff) * tril_ones

v_beta = v * beta[..., None] # [B, H, C, Dv]
k_beta = k * beta[..., None] # [B, H, C, Dk]

strict_lower = mx.tril(mx.ones((C, C), dtype=mx.float32), k=-1)
A = -(k_beta @ mx.swapaxes(k, -1, -2)) * L_mask * strict_lower

decay_exp = mx.exp(g_cumlog)[..., None] # [B, H, C, 1]
rhs = mx.concatenate([v_beta, k_beta * decay_exp], axis=-1)
sol = _solve_strict_lower(A, rhs)
v_corrected, k_cumdecay = mx.split(sol, [v.shape[-1]], axis=-1)

v_new = v_corrected - k_cumdecay @ state # [B, H, C, Dv]
y_inter = (q * decay_exp) @ state # [B, H, C, Dv]

attn = (q @ mx.swapaxes(k, -1, -2)) * L_mask
y = y_inter + attn @ v_new

state_decay = mx.exp(g_last)[..., None]
decay_to_end = mx.exp(g_last - g_cumlog)[..., None]
new_state = state * state_decay + mx.swapaxes(k * decay_to_end, -1, -2) @ v_new

# The state stays fp32 across chunks; casting at boundaries drifts.
return y.astype(orig_dtype), new_state


_gated_delta_chunk_checkpointed = mx.checkpoint(_gated_delta_chunk)


def gated_delta_ops_chunked(
q: mx.array,
k: mx.array,
v: mx.array,
g: mx.array,
beta: mx.array,
state: Optional[mx.array] = None,
mask: Optional[mx.array] = None,
chunk_size: Optional[int] = None,
) -> Tuple[mx.array, mx.array]:
"""
Chunk-parallel implementation of prompt prefill for scalar gating.

Equivalent to gated_delta_ops but processes chunk_size timesteps at a
time with dense matmuls, each chunk wrapped in mx.checkpoint, so the
autodiff graph is O(T / chunk_size) instead of O(T).

Shapes:
- q, k: [B, T, Hk, Dk]
- v: [B, T, Hv, Dv]
- g: [B, T, Hv] (scalar gating only, values in (0, 1))
- beta: [B, T, Hv]
- state: [B, Hv, Dv, Dk]
Returns:
- y: [B, T, Hv, Dv]
- state: [B, Hv, Dv, Dk]
"""
B, T, Hk, Dk = q.shape
Hv, Dv = v.shape[-2:]
C = chunk_size or CHUNK_SIZE

if state is None:
state = mx.zeros((B, Hv, Dv, Dk), dtype=mx.float32)

if (repeat_factor := Hv // Hk) > 1:
q = mx.repeat(q, repeat_factor, -2)
k = mx.repeat(k, repeat_factor, -2)

# Masked steps are identities on the state (g = 1, beta = 0).
if mask is not None:
m = mask[..., None]
g = mx.where(m, g, mx.ones_like(g))
beta = beta * m

# Pad T to a multiple of C with identity steps (g = 1, beta = 0).
pad_len = (C - (T % C)) % C
if pad_len > 0:
q = mx.pad(q, [(0, 0), (0, pad_len), (0, 0), (0, 0)])
k = mx.pad(k, [(0, 0), (0, pad_len), (0, 0), (0, 0)])
v = mx.pad(v, [(0, 0), (0, pad_len), (0, 0), (0, 0)])
g = mx.concatenate([g, mx.ones((B, pad_len, Hv), dtype=g.dtype)], axis=1)
beta = mx.pad(beta, [(0, 0), (0, pad_len), (0, 0)])

num_chunks = (T + pad_len) // C

# [B, T, H, D] -> [B, H, Nc, C, D]
q = mx.swapaxes(q, 1, 2).reshape(B, Hv, num_chunks, C, Dk)
k = mx.swapaxes(k, 1, 2).reshape(B, Hv, num_chunks, C, Dk)
v = mx.swapaxes(v, 1, 2).reshape(B, Hv, num_chunks, C, Dv)
g = mx.swapaxes(g, 1, 2).reshape(B, Hv, num_chunks, C)
beta = mx.swapaxes(beta, 1, 2).reshape(B, Hv, num_chunks, C)

# [B, Hv, Dv, Dk] -> [B, Hv, Dk, Dv]
state = mx.swapaxes(state.astype(mx.float32), -1, -2)

ys = []
for ci in range(num_chunks):
y_c, state = _gated_delta_chunk_checkpointed(
state,
q[:, :, ci],
k[:, :, ci],
v[:, :, ci],
g[:, :, ci],
beta[:, :, ci],
)
ys.append(y_c)

y = mx.concatenate(ys, axis=2)
if pad_len > 0:
y = y[:, :, :T, :]
y = mx.swapaxes(y, 1, 2)

return y, mx.swapaxes(state, -1, -2)


def gated_delta_update(
q: mx.array,
k: mx.array,
Expand All @@ -279,5 +465,8 @@ def gated_delta_update(
state = mx.zeros((B, Hv, Dv, Dk), dtype=mx.float32)

if not use_kernel or mx.default_device() != mx.gpu or not mx.metal.is_available():
return gated_delta_ops(q, k, v, g, beta, state, mask)
if g.ndim == 4 or q.shape[1] == 1:
# Vectorized gating and single-token steps use the sequential path.
return gated_delta_ops(q, k, v, g, beta, state, mask)
return gated_delta_ops_chunked(q, k, v, g, beta, state, mask)
return gated_delta_kernel(q, k, v, g, beta, state, mask)
136 changes: 136 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from mlx_lm.models.base import create_causal_mask, scaled_dot_product_attention
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
from mlx_lm.models.gated_delta import (
compute_g,
gated_delta_kernel,
gated_delta_ops,
gated_delta_ops_chunked,
gated_delta_update,
)
from mlx_lm.models.ssm import ssm_attn, ssm_update
Expand Down Expand Up @@ -3220,6 +3222,140 @@ def test_gated_delta_masked(self):
self.assertTrue(mx.allclose(y, y_gt, rtol=1e-4, atol=1e-4))
self.assertTrue(mx.allclose(st, st_gt, rtol=1e-4, atol=1e-3))

def test_gated_delta_chunked(self):
B, Hk, Hv, Dk, Dv = 1, 2, 4, 32, 32

# (T, chunk_size, repeat_keys, with_state, g_mode, beta_high)
cases = [
(96, 16, False, False, "random", False), # multi-chunk
(96, 16, True, False, "random", False), # repeated keys (adversarial)
(70, 16, False, False, "random", False), # T not multiple of chunk_size
(128, 64, False, False, "random", False), # blocked solve (C > SUB_BLOCK)
(128, 64, True, False, "random", False), # blocked solve, repeated keys
(96, 16, False, True, "random", False), # carried-in state
(96, 16, False, False, "zero", False), # g -> 0 (log-domain hazard)
(96, 16, False, False, "one", False), # g -> 1 (no decay)
(128, 64, True, False, "one", True), # collinear, no decay, beta -> 1
]
for T, chunk_size, repeat_keys, with_state, g_mode, beta_high in cases:
mx.random.seed(0)
q = mx.random.normal(shape=(B, T, Hk, Dk)) * 0.5
k = mx.random.normal(shape=(B, T, Hk, Dk))
k = k / mx.linalg.norm(k, axis=-1, keepdims=True)
if repeat_keys:
k = mx.broadcast_to(k[:, :1], k.shape) * 1.0
v = mx.random.normal(shape=(B, T, Hv, Dv)) * 0.5
if g_mode == "zero":
g = mx.full((B, T, Hv), 1e-30)
elif g_mode == "one":
g = mx.full((B, T, Hv), 1.0 - 1e-7)
else:
g = mx.sigmoid(mx.random.normal(shape=(B, T, Hv)))
if beta_high:
beta = mx.full((B, T, Hv), 0.999)
else:
beta = mx.sigmoid(mx.random.normal(shape=(B, T, Hv)) + 2.0)
state = (
mx.random.normal(shape=(B, Hv, Dv, Dk)) * 0.3 if with_state else None
)

y_ref, st_ref = gated_delta_ops(q, k, v, g, beta, state)
y, st = gated_delta_ops_chunked(
q, k, v, g, beta, state, chunk_size=chunk_size
)
# The collinear beta -> 1 corner is a blow-up guard: the blocked
# solve degrades to ~1e-3 there instead of fp32 noise.
rtol, atol = (1e-2, 1e-2) if beta_high else (1e-3, 1e-4)
self.assertTrue(mx.allclose(y, y_ref, rtol=rtol, atol=atol))
self.assertTrue(mx.allclose(st, st_ref, rtol=rtol, atol=atol))

def test_gated_delta_chunked_masked(self):
B, T, Hk, Hv, Dk, Dv = 1, 8, 2, 4, 32, 32

mx.random.seed(0)
q = mx.random.normal(shape=(B, T, Hk, Dk))
k = mx.random.normal(shape=(B, T, Hk, Dk))
v = mx.random.normal(shape=(B, T, Hv, Dv))
g = mx.sigmoid(mx.random.normal(shape=(B, T, Hv)))
beta = mx.sigmoid(mx.random.normal(shape=(B, T, Hv)))
state = mx.random.normal(shape=(B, Hv, Dv, Dk)) * 0.3

for s, e, mask in [
(3, 8, mx.array([[False] * 3 + [True] * 5])),
(0, 5, mx.array([[True] * 5 + [False] * 3])),
]:
y_gt, st_gt = gated_delta_ops(
q[:, s:e],
k[:, s:e],
v[:, s:e],
g[:, s:e],
beta[:, s:e],
state,
)
y, st = gated_delta_ops_chunked(q, k, v, g, beta, state, mask, chunk_size=4)
self.assertTrue(mx.allclose(y[:, s:e], y_gt, rtol=1e-3, atol=1e-4))
self.assertTrue(mx.allclose(st, st_gt, rtol=1e-3, atol=1e-4))

def test_gated_delta_chunked_grads(self):
B, T, Hk, Hv, Dk, Dv = 1, 64, 2, 4, 32, 32

def loss_ref(q, k, v, g, beta):
y, s = gated_delta_ops(q, k, v, g, beta)
return (y**2).sum() + (s**2).sum()

def loss_chunked(q, k, v, g, beta):
y, s = gated_delta_ops_chunked(q, k, v, g, beta, chunk_size=16)
return (y**2).sum() + (s**2).sum()

for repeat_keys in [False, True]:
mx.random.seed(0)
q = mx.random.normal(shape=(B, T, Hk, Dk)) * 0.5
k = mx.random.normal(shape=(B, T, Hk, Dk))
k = k / mx.linalg.norm(k, axis=-1, keepdims=True)
if repeat_keys:
k = mx.broadcast_to(k[:, :1], k.shape) * 1.0
v = mx.random.normal(shape=(B, T, Hv, Dv)) * 0.5
g = mx.sigmoid(mx.random.normal(shape=(B, T, Hv)))
beta = mx.sigmoid(mx.random.normal(shape=(B, T, Hv)) + 2.0)

grads_ref = mx.grad(loss_ref, argnums=(0, 1, 2, 3, 4))(q, k, v, g, beta)
grads = mx.grad(loss_chunked, argnums=(0, 1, 2, 3, 4))(q, k, v, g, beta)
for g_ref, g_chunked in zip(grads_ref, grads):
self.assertTrue(mx.allclose(g_chunked, g_ref, rtol=5e-3, atol=5e-3))

def test_gated_delta_chunked_update(self):
B, T, Hk, Hv, Dk, Dv = 1, 40, 2, 4, 32, 32

mx.random.seed(0)
q = mx.random.normal(shape=(B, T, Hk, Dk)) * 0.1
k = mx.random.normal(shape=(B, T, Hk, Dk)) * 0.1
v = mx.random.normal(shape=(B, T, Hv, Dv)) * 0.1
a = -5.0 + mx.random.normal(shape=(B, T, Hv)) * 0.1
b = mx.random.normal(shape=(B, T, Hv))
A_log = mx.zeros((Hv,))
dt_bias = mx.ones((Hv,))

# use_kernel=False routes multi-token calls through the chunked path.
y, st = gated_delta_update(q, k, v, a, b, A_log, dt_bias, use_kernel=False)
g = compute_g(A_log, a, dt_bias)
beta = mx.sigmoid(b)
y_ref, st_ref = gated_delta_ops(q, k, v, g, beta)
self.assertTrue(mx.allclose(y, y_ref, rtol=1e-3, atol=1e-4))
self.assertTrue(mx.allclose(st, st_ref, rtol=1e-3, atol=1e-4))

# bf16 training inputs must produce finite gradients.
qb, kb, vb, ab, bb = (t.astype(mx.bfloat16) for t in (q, k, v, a, b))

def loss(q_, k_, v_, a_, b_):
y, s = gated_delta_update(
q_, k_, v_, a_, b_, A_log, dt_bias, use_kernel=False
)
return (y.astype(mx.float32) ** 2).sum() + (s**2).sum()

grads = mx.grad(loss, argnums=(0, 1, 2, 3, 4))(qb, kb, vb, ab, bb)
for grad in grads:
self.assertTrue(bool(mx.isfinite(grad).all()))


if __name__ == "__main__":
unittest.main()