Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions mlx_lm/models/rope_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,92 @@ def __call__(self, x, offset=0):
)


class CoPERoPE(nn.Module):
"""Clipped RoPE (CoPE) with a raised-cosine taper on low frequencies.

Implements the soft-clipping strategy from `CoPE: Clipped RoPE as A
Scalable Free Lunch for Long Context LLMs <https://arxiv.org/abs/2602.05258>`_.
The lowest-frequency components — whose rotation periods exceed the
pre-training context window and therefore go out-of-distribution when
extrapolating — are attenuated by a raised-cosine (Hann) mask. The
boundary component is left untouched and the lowest-frequency component
is frozen entirely, avoiding the spectral leakage of a hard cutoff.

Args:
dims (int): The feature dimensions to be rotated.
base (float): Base for the exponential scaling.
original_max_position_embeddings (int, optional): The context window
the model was pre-trained with. When ``clip_n`` is not given, the
number of clipped components is derived from it: every component
whose period exceeds this length is clipped.
traditional (bool, optional): Unused legacy rotation order flag,
kept for parity with the other RoPE classes. Default: ``False``.
clip_n (int, optional): Explicit number of low-frequency components
to clip. Overrides the derivation from
``original_max_position_embeddings``.
"""

def __init__(
self,
dims: int,
base: float = 10000.0,
original_max_position_embeddings: Optional[int] = None,
traditional: bool = False,
clip_n: Optional[int] = None,
):
super().__init__()
self.dims = dims
self.traditional = traditional

n = dims // 2
freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)

if clip_n is None:
if original_max_position_embeddings is None:
raise ValueError(
"CoPE requires either clip_n or "
"original_max_position_embeddings to size the clip."
)
# A component is out-of-distribution if its period (the number
# of positions per full rotation, 2*pi*freqs) exceeds the
# pre-training context window.
periods = 2 * math.pi * freqs
clip_n = int((periods > original_max_position_embeddings).sum())
clip_n = min(clip_n, n)
self.clip_n = clip_n

if clip_n > 0:
# Raised-cosine mask going 1 -> 0 across the clip_n lowest
# frequencies. In the freqs convention of mx.fast.rope the
# rotation angle is position / freqs, so attenuating a
# frequency means dividing freqs by the mask; a fully masked
# component becomes inf (identity rotation).
theta = mx.linspace(0, math.pi, num=clip_n)
mask = 0.5 * (1.0 + mx.cos(theta))
tail = mx.where(
mask > 1e-6,
freqs[n - clip_n :] / mx.maximum(mask, 1e-6),
mx.inf,
)
freqs = mx.concatenate([freqs[: n - clip_n], tail])

self._freqs = freqs

def extra_repr(self):
return f"{self.dims}, clip_n={self.clip_n}/{self.dims // 2}"

def __call__(self, x, offset=0):
return mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)


def initialize_rope(
dims,
base,
Expand Down Expand Up @@ -298,6 +384,16 @@ def initialize_rope(
base=base,
factor=scaling_config.get("factor", 1.0),
)
elif rope_type == "cope":
return CoPERoPE(
dims=dims,
base=base,
traditional=traditional,
original_max_position_embeddings=scaling_config.get(
"original_max_position_embeddings", max_position_embeddings
),
clip_n=scaling_config.get("clip_n"),
)
elif rope_type == "mrope":
mrope_section = scaling_config.get("mrope_section", [])
assert (
Expand Down
71 changes: 71 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,77 @@ def test_rope(self):
mx.eval(y, expected)
self.assertTrue(mx.allclose(y, expected))

def test_cope_rope(self):
# Auto-sized clip: components whose period exceeds the original
# context window are clipped. base=10M, dims=64, context=262144
# gives 10 of 32 clipped components.
rope = rope_utils.initialize_rope(
64,
base=10_000_000,
traditional=False,
scaling_config={
"rope_type": "cope",
"original_max_position_embeddings": 262144,
},
)
self.assertTrue(isinstance(rope, rope_utils.CoPERoPE))
self.assertEqual(rope.clip_n, 10)

raw = 10_000_000 ** (mx.arange(0, 64, 2, dtype=mx.float32) / 64)
# Unclipped head is untouched, boundary component is preserved
# (mask=1), and the lowest-frequency component is frozen (inf).
self.assertTrue(mx.allclose(rope._freqs[:22], raw[:22]))
self.assertTrue(mx.allclose(rope._freqs[22], raw[22], rtol=1e-5))
self.assertTrue(mx.isinf(rope._freqs[-1]))
# Effective rotation speed decreases monotonically across the taper
inv = 1.0 / rope._freqs[22:]
self.assertTrue(mx.all(inv[:-1] >= inv[1:]))

# Explicit clip_n overrides the derivation
rope = rope_utils.initialize_rope(
64,
base=10_000_000,
traditional=False,
scaling_config={"rope_type": "cope", "clip_n": 4},
)
self.assertEqual(rope.clip_n, 4)

# original_max_position_embeddings falls back to the model's
# max_position_embeddings when not set in the scaling config.
rope = rope_utils.initialize_rope(
64,
base=10_000_000,
traditional=False,
scaling_config={"rope_type": "cope"},
max_position_embeddings=262144,
)
self.assertEqual(rope.clip_n, 10)

# clip_n=0 (nothing out-of-distribution) matches default RoPE
rope = rope_utils.initialize_rope(
8,
base=100.0,
traditional=False,
scaling_config={"rope_type": "cope", "clip_n": 0},
)
x = mx.arange(16, dtype=mx.float32).reshape(1, 1, 2, 8)
expected = mx.fast.rope(
x, 8, traditional=False, base=100.0, scale=1.0, offset=3
)
mx.eval(rope(x, offset=3), expected)
self.assertTrue(mx.allclose(rope(x, offset=3), expected))

def test_cope_rope_no_mutation(self):
rope = rope_utils.CoPERoPE(
dims=8,
base=10000.0,
original_max_position_embeddings=128,
)
x = mx.ones((1, 2, 4, 8))
rope(x)
mx.eval(x)
self.assertTrue((x == 1).all())

def test_su_scaled_rope_no_mutation(self):
rope = rope_utils.SuScaledRoPE(
dims=8,
Expand Down