diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py
index 74f672984..061477ec1 100644
--- a/mlx_vlm/models/qwen3_5/language.py
+++ b/mlx_vlm/models/qwen3_5/language.py
@@ -29,11 +29,13 @@ def __init__(
         max_position_embeddings=2048,
         base=10000,
         mrope_section=[11, 11, 0],
+        rope_parameters=None,
     ):
         super().__init__(
             dim,
             max_position_embeddings=max_position_embeddings,
             base=base,
+            rope_parameters=rope_parameters,
             mrope_section=mrope_section,
             style="interleaved",
         )
@@ -258,11 +260,7 @@ def _target_verify_qlinear_header(bits: int, group_size: int) -> str:
     }
     return scale * accum + sum * bias;
 }
-""".replace(
-        "__BITS__", str(bits)
-    ).replace(
-        "__GS__", str(group_size)
-    )
+""".replace("__BITS__", str(bits)).replace("__GS__", str(group_size))
 
 
 _TARGET_VERIFY_QMV_SOURCE = r"""
@@ -1391,6 +1389,7 @@ def __init__(self, args: TextConfig):
             max_position_embeddings=args.max_position_embeddings,
             base=args.rope_parameters["rope_theta"],
             mrope_section=args.rope_parameters["mrope_section"],
+            rope_parameters=args.rope_parameters,
         )
 
     def __call__(
diff --git a/mlx_vlm/models/rope_utils.py b/mlx_vlm/models/rope_utils.py
index 6fb292ae6..af7cbee73 100644
--- a/mlx_vlm/models/rope_utils.py
+++ b/mlx_vlm/models/rope_utils.py
@@ -1,3 +1,4 @@
+import math
 from functools import lru_cache
 from typing import Optional, Sequence
 
@@ -397,6 +398,64 @@ def compute_inv_freq(dim: int, base: float):
     return 1.0 / (base ** (mx.arange(0, dim, 2).astype(mx.float32) / dim))
 
 
+def apply_cope_clip(
+    inv_freq,
+    original_max_position_embeddings: Optional[int] = None,
+    clip_n: Optional[int] = None,
+):
+    """Soft-clip out-of-distribution low frequencies (CoPE, arXiv:2602.05258).
+
+    Frequency components whose rotation period exceeds the pre-training
+    context window never complete a full rotation during training and go
+    out-of-distribution when extrapolating past it. CoPE attenuates them
+    with a raised-cosine (Hann) taper: the boundary component is left
+    untouched and the lowest-frequency component is frozen (inv_freq=0),
+    avoiding the spectral leakage of a hard cutoff.
+
+    When ``clip_n`` is not given it is derived from
+    ``original_max_position_embeddings``: every component with period
+    ``2*pi/inv_freq`` greater than the window is clipped.
+
+    Args:
+        inv_freq: 1-D array of inverse frequencies. The taper is applied
+            to its trailing entries, so it is expected to be ordered from
+            highest to lowest frequency (as ``compute_inv_freq`` returns).
+        original_max_position_embeddings: Pre-training context window,
+            used to size the clip when ``clip_n`` is ``None``.
+        clip_n: Number of trailing components to taper. ``0`` is a no-op
+            and values above ``len(inv_freq)`` are clamped to it.
+
+    Returns:
+        ``inv_freq`` itself when nothing is clipped, otherwise a new
+        array of the same length with the tail attenuated.
+
+    Raises:
+        ValueError: If neither ``clip_n`` nor
+            ``original_max_position_embeddings`` is given, or if
+            ``clip_n`` is negative.
+    """
+    n = inv_freq.shape[0]
+    if clip_n is None:
+        if original_max_position_embeddings is None:
+            raise ValueError(
+                "CoPE requires either clip_n or "
+                "original_max_position_embeddings to size the clip."
+            )
+        periods = 2 * math.pi / inv_freq
+        clip_n = int((periods > original_max_position_embeddings).sum())
+    clip_n = int(clip_n)
+    if clip_n < 0:
+        # A negative count is meaningless; reject it before it reaches linspace.
+        raise ValueError(f"CoPE clip_n must be non-negative, got {clip_n}.")
+    clip_n = min(clip_n, n)
+    if clip_n == 0:
+        return inv_freq
+    # Hann taper over [0, pi]: 1 at the boundary component, 0 at the last.
+    theta = mx.linspace(0, math.pi, num=clip_n)
+    mask = 0.5 * (1.0 + mx.cos(theta))
+    return mx.concatenate([inv_freq[: n - clip_n], inv_freq[n - clip_n :] * mask])
+
+
 @mx.compile
 def _apply_selected_mrope_frequency_layout(freqs, position_selector):
     indices = mx.broadcast_to(
@@ -535,6 +594,15 @@ def __init__(
         self.base = base
         self.style = style
         self._inv_freq = compute_inv_freq(dim, base)
+        rope_params = rope_parameters or rope_scaling or {}
+        if (rope_params.get("rope_type") or rope_params.get("type")) == "cope":
+            self._inv_freq = apply_cope_clip(
+                self._inv_freq,
+                # A missing or null window falls back to the model's own.
+                rope_params.get("original_max_position_embeddings")
+                or max_position_embeddings,
+                rope_params.get("clip_n"),
+            )
         self.attention_scaling = attention_scaling
         self.cast_output = cast_output
         self._mrope_section = list(
diff --git a/mlx_vlm/tests/test_cope_rope.py b/mlx_vlm/tests/test_cope_rope.py
new file mode 100644
index 000000000..2a2780ad7
--- /dev/null
+++ b/mlx_vlm/tests/test_cope_rope.py
@@ -0,0 +1,125 @@
+"""Tests for CoPE (Clipped RoPE) soft-clipping in rope_utils.
+
+CoPE: Clipped RoPE as A Scalable Free Lunch for Long Context LLMs
+(arXiv:2602.05258).
+"""
+
+import math
+
+import mlx.core as mx
+import pytest
+
+from mlx_vlm.models.rope_utils import (
+    MRoPERotaryEmbedding,
+    apply_cope_clip,
+    compute_inv_freq,
+)
+
+# Qwen3.5/3.6-family rope geometry: theta=10M, 64 rotary dims, native 262144.
+DIM, BASE, NATIVE = 64, 10_000_000, 262_144
+
+
+def test_auto_clip_sizing():
+    inv_freq = compute_inv_freq(DIM, BASE)
+    clipped = apply_cope_clip(inv_freq, original_max_position_embeddings=NATIVE)
+
+    # Components with period 2*pi/inv_freq > 262144: i >= 22, i.e. 10 of 32.
+    periods = 2 * math.pi / inv_freq
+    expected_n = int((periods > NATIVE).sum())
+    assert expected_n == 10
+
+    # Unclipped head is bit-identical; boundary component preserved (mask=1).
+    assert mx.allclose(clipped[:22], inv_freq[:22])
+    assert mx.allclose(clipped[22], inv_freq[22], rtol=1e-5)
+    # Lowest-frequency component is frozen entirely.
+    assert clipped[-1] == 0.0
+    # Taper is monotone non-increasing across the clipped range.
+    tail = clipped[22:]
+    assert mx.all(tail[:-1] >= tail[1:])
+
+
+def test_explicit_clip_n_and_noop():
+    inv_freq = compute_inv_freq(DIM, BASE)
+
+    clipped = apply_cope_clip(inv_freq, clip_n=4)
+    assert mx.allclose(clipped[:28], inv_freq[:28])
+    assert clipped[-1] == 0.0
+
+    # clip_n=0 is a no-op.
+    assert mx.allclose(apply_cope_clip(inv_freq, clip_n=0), inv_freq)
+
+    # clip_n is bounded by the number of components.
+    clipped = apply_cope_clip(inv_freq, clip_n=999)
+    assert clipped.shape == inv_freq.shape
+
+
+def test_requires_sizing_information():
+    inv_freq = compute_inv_freq(DIM, BASE)
+    with pytest.raises(ValueError):
+        apply_cope_clip(inv_freq)
+
+
+def test_negative_clip_n_rejected():
+    inv_freq = compute_inv_freq(DIM, BASE)
+    with pytest.raises(ValueError):
+        apply_cope_clip(inv_freq, clip_n=-1)
+
+
+def test_mrope_embedding_integration():
+    rope = MRoPERotaryEmbedding(
+        DIM,
+        max_position_embeddings=NATIVE,
+        base=BASE,
+        rope_parameters={
+            "rope_type": "cope",
+            "original_max_position_embeddings": NATIVE,
+            "mrope_section": [11, 11, 10],
+        },
+    )
+    raw = compute_inv_freq(DIM, BASE)
+    assert rope.inv_freq[-1] == 0.0
+    assert mx.allclose(rope.inv_freq[:22], raw[:22])
+
+    # Without rope_type=cope the frequencies are untouched.
+    rope = MRoPERotaryEmbedding(
+        DIM,
+        max_position_embeddings=NATIVE,
+        base=BASE,
+        rope_parameters={"mrope_section": [11, 11, 10]},
+    )
+    assert mx.allclose(rope.inv_freq, raw)
+
+
+def test_mrope_embedding_null_window_falls_back():
+    rope = MRoPERotaryEmbedding(
+        DIM,
+        max_position_embeddings=NATIVE,
+        base=BASE,
+        rope_parameters={
+            "rope_type": "cope",
+            "original_max_position_embeddings": None,
+            "mrope_section": [11, 11, 10],
+        },
+    )
+    assert rope.inv_freq[-1] == 0.0
+    assert mx.allclose(rope.inv_freq[:22], compute_inv_freq(DIM, BASE)[:22])
+
+
+def test_qwen3_5_rotary_embedding_passthrough():
+    from mlx_vlm.models.qwen3_5.language import Qwen3_5RotaryEmbedding
+
+    rope = Qwen3_5RotaryEmbedding(
+        DIM,
+        max_position_embeddings=NATIVE,
+        base=BASE,
+        mrope_section=[11, 11, 10],
+        rope_parameters={
+            "rope_type": "cope",
+            "original_max_position_embeddings": NATIVE,
+            "rope_theta": BASE,
+            "partial_rotary_factor": 0.25,
+            "mrope_section": [11, 11, 10],
+        },
+    )
+    assert rope.inv_freq[-1] == 0.0
+    assert rope.inv_freq.shape[0] == DIM // 2