Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions mlx_vlm/models/qwen3_5/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ def __init__(
max_position_embeddings=2048,
base=10000,
mrope_section=[11, 11, 0],
rope_parameters=None,
):
super().__init__(
dim,
max_position_embeddings=max_position_embeddings,
base=base,
rope_parameters=rope_parameters,
mrope_section=mrope_section,
style="interleaved",
)
Expand Down Expand Up @@ -258,11 +260,7 @@ def _target_verify_qlinear_header(bits: int, group_size: int) -> str:
}
return scale * accum + sum * bias;
}
""".replace(
"__BITS__", str(bits)
).replace(
"__GS__", str(group_size)
)
""".replace("__BITS__", str(bits)).replace("__GS__", str(group_size))


_TARGET_VERIFY_QMV_SOURCE = r"""
Expand Down Expand Up @@ -1391,6 +1389,7 @@ def __init__(self, args: TextConfig):
max_position_embeddings=args.max_position_embeddings,
base=args.rope_parameters["rope_theta"],
mrope_section=args.rope_parameters["mrope_section"],
rope_parameters=args.rope_parameters,
)

def __call__(
Expand Down
45 changes: 45 additions & 0 deletions mlx_vlm/models/rope_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from functools import lru_cache
from typing import Optional, Sequence

Expand Down Expand Up @@ -397,6 +398,41 @@ def compute_inv_freq(dim: int, base: float):
return 1.0 / (base ** (mx.arange(0, dim, 2).astype(mx.float32) / dim))


def apply_cope_clip(
inv_freq,
original_max_position_embeddings: Optional[int] = None,
clip_n: Optional[int] = None,
):
"""Soft-clip out-of-distribution low frequencies (CoPE, arXiv:2602.05258).

Frequency components whose rotation period exceeds the pre-training
context window never complete a full rotation during training and go
out-of-distribution when extrapolating past it. CoPE attenuates them
with a raised-cosine (Hann) taper: the boundary component is left
untouched and the lowest-frequency component is frozen (inv_freq=0),
avoiding the spectral leakage of a hard cutoff.

When ``clip_n`` is not given it is derived from
``original_max_position_embeddings``: every component with period
``2*pi/inv_freq`` greater than the window is clipped.
"""
n = inv_freq.shape[0]
if clip_n is None:
if original_max_position_embeddings is None:
raise ValueError(
"CoPE requires either clip_n or "
"original_max_position_embeddings to size the clip."
)
periods = 2 * math.pi / inv_freq
clip_n = int((periods > original_max_position_embeddings).sum())
clip_n = min(int(clip_n), n)
if clip_n == 0:
return inv_freq
theta = mx.linspace(0, math.pi, num=clip_n)
mask = 0.5 * (1.0 + mx.cos(theta))
return mx.concatenate([inv_freq[: n - clip_n], inv_freq[n - clip_n :] * mask])


@mx.compile
def _apply_selected_mrope_frequency_layout(freqs, position_selector):
indices = mx.broadcast_to(
Expand Down Expand Up @@ -535,6 +571,15 @@ def __init__(
self.base = base
self.style = style
self._inv_freq = compute_inv_freq(dim, base)
rope_params = rope_parameters or rope_scaling or {}
if (rope_params.get("rope_type") or rope_params.get("type")) == "cope":
self._inv_freq = apply_cope_clip(
self._inv_freq,
rope_params.get(
"original_max_position_embeddings", max_position_embeddings
),
rope_params.get("clip_n"),
)
self.attention_scaling = attention_scaling
self.cast_output = cast_output
self._mrope_section = list(
Expand Down
104 changes: 104 additions & 0 deletions mlx_vlm/tests/test_cope_rope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Tests for CoPE (Clipped RoPE) soft-clipping in rope_utils.

CoPE: Clipped RoPE as A Scalable Free Lunch for Long Context LLMs
(arXiv:2602.05258).
"""

import math

import mlx.core as mx
import pytest

from mlx_vlm.models.rope_utils import (
MRoPERotaryEmbedding,
apply_cope_clip,
compute_inv_freq,
)

# Qwen3.5/3.6-family rope geometry: theta=10M, 64 rotary dims, native 262144.
DIM, BASE, NATIVE = 64, 10_000_000, 262_144


def test_auto_clip_sizing():
inv_freq = compute_inv_freq(DIM, BASE)
clipped = apply_cope_clip(inv_freq, original_max_position_embeddings=NATIVE)

# Components with period 2*pi/inv_freq > 262144: i >= 22, i.e. 10 of 32.
periods = 2 * math.pi / inv_freq
expected_n = int((periods > NATIVE).sum())
assert expected_n == 10

# Unclipped head is bit-identical; boundary component preserved (mask=1).
assert mx.allclose(clipped[:22], inv_freq[:22])
assert mx.allclose(clipped[22], inv_freq[22], rtol=1e-5)
# Lowest-frequency component is frozen entirely.
assert clipped[-1] == 0.0
# Taper is monotone non-increasing across the clipped range.
tail = clipped[22:]
assert mx.all(tail[:-1] >= tail[1:])


def test_explicit_clip_n_and_noop():
inv_freq = compute_inv_freq(DIM, BASE)

clipped = apply_cope_clip(inv_freq, clip_n=4)
assert mx.allclose(clipped[:28], inv_freq[:28])
assert clipped[-1] == 0.0

# clip_n=0 is a no-op.
assert mx.allclose(apply_cope_clip(inv_freq, clip_n=0), inv_freq)

# clip_n is bounded by the number of components.
clipped = apply_cope_clip(inv_freq, clip_n=999)
assert clipped.shape == inv_freq.shape


def test_requires_sizing_information():
inv_freq = compute_inv_freq(DIM, BASE)
with pytest.raises(ValueError):
apply_cope_clip(inv_freq)


def test_mrope_embedding_integration():
rope = MRoPERotaryEmbedding(
DIM,
max_position_embeddings=NATIVE,
base=BASE,
rope_parameters={
"rope_type": "cope",
"original_max_position_embeddings": NATIVE,
"mrope_section": [11, 11, 10],
},
)
raw = compute_inv_freq(DIM, BASE)
assert rope.inv_freq[-1] == 0.0
assert mx.allclose(rope.inv_freq[:22], raw[:22])

# Without rope_type=cope the frequencies are untouched.
rope = MRoPERotaryEmbedding(
DIM,
max_position_embeddings=NATIVE,
base=BASE,
rope_parameters={"mrope_section": [11, 11, 10]},
)
assert mx.allclose(rope.inv_freq, raw)


def test_qwen3_5_rotary_embedding_passthrough():
from mlx_vlm.models.qwen3_5.language import Qwen3_5RotaryEmbedding

rope = Qwen3_5RotaryEmbedding(
DIM,
max_position_embeddings=NATIVE,
base=BASE,
mrope_section=[11, 11, 10],
rope_parameters={
"rope_type": "cope",
"original_max_position_embeddings": NATIVE,
"rope_theta": BASE,
"partial_rotary_factor": 0.25,
"mrope_section": [11, 11, 10],
},
)
assert rope.inv_freq[-1] == 0.0
assert rope.inv_freq.shape[0] == DIM // 2