99 changes: 99 additions & 0 deletions omlx/oq.py
@@ -1449,6 +1449,104 @@ def _should_quantize_tensor(name: str, shape: tuple) -> bool:
return True


_PER_EXPERT_WEIGHT_RE = re.compile(
r"^(.+\.mlp)\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight$"
)


def _prefuse_moe_experts_for_vlm(weights, config: dict) -> int:
"""Fuse per-expert MoE weights into the stacked layout mlx_vlm expects.

FP8 MoE VLM checkpoints (e.g. Qwen3.5-MoE variants) store experts as
``...mlp.experts.{n}.(gate|up|down)_proj.weight`` after FP8 dequant, but
mlx_vlm's ``Model.sanitize`` pops a single fused ``experts.gate_up_proj``
and crashes with ``KeyError`` when it is not present. This shim rebuilds
the fused tensors in chunks so peak memory stays bounded for models with
hundreds of experts. No-op when no per-expert keys are found (so dense
models and already-fused checkpoints pay nothing).

Returns the number of layers whose experts were fused.
"""
tc = config.get("text_config", {}) or {}
num_experts = (
config.get("num_local_experts")
or tc.get("num_local_experts")
or config.get("num_experts")
or tc.get("num_experts")
or 0
)
if not num_experts:
return 0

groups: dict[tuple[str, str], dict[int, str]] = {}
for key in list(weights.keys()):
m = _PER_EXPERT_WEIGHT_RE.match(key)
if not m:
continue
prefix, idx, proj = m.group(1), int(m.group(2)), m.group(3)
groups.setdefault((prefix, proj), {})[idx] = key

if not groups:
return 0

_STACK_CHUNK = 16

def _stack_chunked(src_keys: list[str]):
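        # Stack at most _STACK_CHUNK source tensors at a time, evaluating and
        # freeing each chunk before the next so peak memory stays bounded.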
partials = []
for base in range(0, len(src_keys), _STACK_CHUNK):
piece = [weights.pop(k) for k in src_keys[base:base + _STACK_CHUNK]]
stk = mx.stack(piece, axis=0)
mx.eval(stk)
del piece
mx.clear_cache()
partials.append(stk)
if len(partials) == 1:
return partials[0]
out = mx.concatenate(partials, axis=0)
mx.eval(out)
del partials
mx.clear_cache()
return out

prefixes = {p for p, _ in groups}
fused_layers = 0
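    # Rebuild the fused tensors one layer prefix at a time; layers that already
    # carry fused tensors are left untouched.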
for prefix in prefixes:
fused_gu = f"{prefix}.experts.gate_up_proj"
fused_dn = f"{prefix}.experts.down_proj"
if fused_gu in weights and fused_dn in weights:
continue

gate_map = groups.get((prefix, "gate_proj"), {})
up_map = groups.get((prefix, "up_proj"), {})
down_map = groups.get((prefix, "down_proj"), {})
if not (len(gate_map) == len(up_map) == len(down_map) == num_experts):
# Partial or mismatched layout — let the downstream sanitize raise
# a clear error instead of silently stacking a broken tensor.
continue

gate_keys = [gate_map[e] for e in range(num_experts)]
up_keys = [up_map[e] for e in range(num_experts)]
down_keys = [down_map[e] for e in range(num_experts)]

gate_stacked = _stack_chunked(gate_keys)
up_stacked = _stack_chunked(up_keys)
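        # gate_stacked and up_stacked are both [E, I, H]; concatenating along
        # axis=-2 yields the fused [E, 2*I, H] layout that sanitize splits back.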
fused_gate_up = mx.concatenate([gate_stacked, up_stacked], axis=-2)
mx.eval(fused_gate_up)
del gate_stacked, up_stacked
mx.clear_cache()

weights[fused_gu] = fused_gate_up
weights[fused_dn] = _stack_chunked(down_keys)
fused_layers += 1

if fused_layers:
logger.info(
f"oQ: pre-fused per-expert MoE weights for {fused_layers} layer(s) "
f"(E={num_experts}) before mlx_vlm sanitize"
)
return fused_layers


def _build_model_sanitizer(config: dict):
"""Build a sanitize function from the model class.

@@ -1481,6 +1579,7 @@ def _build_model_sanitizer(config: dict):
model_config.text_config = text_config

def _vlm_sanitize(weights):
_prefuse_moe_experts_for_vlm(weights, config)
class _Proxy:
audio_tower = None
proxy = _Proxy()
131 changes: 131 additions & 0 deletions tests/test_oq.py
@@ -28,6 +28,7 @@
_is_moe_router,
_LazyTensorIndex,
_normalize_quant_path,
_prefuse_moe_experts_for_vlm,
_quantize_chunked,
_should_quantize_tensor,
estimate_memory,
@@ -1092,5 +1093,135 @@ def drop_sanitize(weights):
assert len(plan) == len(tensors) - 1


# =============================================================================
# Test _prefuse_moe_experts_for_vlm
# =============================================================================


@pytest.mark.skipif(not HAS_MLX, reason="mlx not available")
class TestPrefuseMoeExpertsForVlm:
"""Prefuse shim for per-expert MoE layouts (FP8 Qwen3.5-MoE VLM etc.).

The fused layout expected by mlx_vlm's ``qwen3_5_moe.sanitize``:
* ``...experts.gate_up_proj`` shape [E, 2*I, H]
* ``...experts.down_proj`` shape [E, H, I]
The per-expert layout found in FP8-dequanted checkpoints:
* ``...experts.{e}.gate_proj.weight`` shape [I, H]
* ``...experts.{e}.up_proj.weight`` shape [I, H]
* ``...experts.{e}.down_proj.weight`` shape [H, I]
"""

def _build_per_expert_weights(self, num_layers: int, num_experts: int,
hidden: int = 16, intermediate: int = 32):
weights = {}
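        # Distinct fill values per expert and projection so tests can verify
        # expert ordering survives fusion and the later gate/up split.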
for l in range(num_layers):
prefix = f"model.language_model.layers.{l}.mlp"
for e in range(num_experts):
weights[f"{prefix}.experts.{e}.gate_proj.weight"] = (
mx.ones((intermediate, hidden)) * (e + 1)
)
weights[f"{prefix}.experts.{e}.up_proj.weight"] = (
mx.ones((intermediate, hidden)) * (e + 100)
)
weights[f"{prefix}.experts.{e}.down_proj.weight"] = (
mx.ones((hidden, intermediate)) * (e + 200)
)
return weights

def test_fuses_per_expert_layout(self):
config = {"text_config": {"num_experts": 4}}
weights = self._build_per_expert_weights(num_layers=2, num_experts=4)

fused = _prefuse_moe_experts_for_vlm(weights, config)

assert fused == 2
for l in range(2):
prefix = f"model.language_model.layers.{l}.mlp"
assert f"{prefix}.experts.gate_up_proj" in weights
assert f"{prefix}.experts.down_proj" in weights
# All per-expert source keys should be popped out.
for e in range(4):
assert f"{prefix}.experts.{e}.gate_proj.weight" not in weights
assert f"{prefix}.experts.{e}.up_proj.weight" not in weights
assert f"{prefix}.experts.{e}.down_proj.weight" not in weights
# Shape invariants: fused gate_up_proj is [E, 2*I, H].
assert weights[f"{prefix}.experts.gate_up_proj"].shape == (4, 64, 16)
assert weights[f"{prefix}.experts.down_proj"].shape == (4, 16, 32)

def test_splits_back_to_original_values(self):
"""mlx_vlm's sanitize splits gate_up_proj at axis=-2 to recover
gate / up. Verify we reconstruct the fused tensor in an order that
survives that split."""
config = {"num_local_experts": 3}
weights = self._build_per_expert_weights(num_layers=1, num_experts=3)

_prefuse_moe_experts_for_vlm(weights, config)

prefix = "model.language_model.layers.0.mlp"
fused = weights[f"{prefix}.experts.gate_up_proj"]
gate, up = mx.split(fused, 2, axis=-2)

# gate_proj for expert e was filled with (e + 1); up_proj with (e + 100).
expected_gate = mx.stack([mx.ones((32, 16)) * (e + 1) for e in range(3)], axis=0)
expected_up = mx.stack([mx.ones((32, 16)) * (e + 100) for e in range(3)], axis=0)
assert mx.allclose(gate, expected_gate).item()
assert mx.allclose(up, expected_up).item()

def test_noop_when_already_fused(self):
config = {"num_local_experts": 2}
prefix = "model.language_model.layers.0.mlp"
already_fused = {
f"{prefix}.experts.gate_up_proj": mx.ones((2, 8, 4)),
f"{prefix}.experts.down_proj": mx.ones((2, 4, 4)),
}
before = dict(already_fused)
fused = _prefuse_moe_experts_for_vlm(already_fused, config)
assert fused == 0
assert set(already_fused.keys()) == set(before.keys())

def test_noop_for_dense_model(self):
"""Dense (non-MoE) configs have no num_experts — skip without scanning."""
config = {"hidden_size": 4096, "num_hidden_layers": 32}
weights = {"model.embed_tokens.weight": mx.zeros((100, 16))}
assert _prefuse_moe_experts_for_vlm(weights, config) == 0

def test_skips_partial_layer(self):
"""Partial per-expert layouts (missing experts) are left untouched so
downstream sanitize raises a clear error rather than corrupt shapes."""
config = {"num_local_experts": 4}
weights = self._build_per_expert_weights(num_layers=1, num_experts=4)
# Drop one expert's down_proj — now the layer is incomplete.
del weights["model.language_model.layers.0.mlp.experts.2.down_proj.weight"]

fused = _prefuse_moe_experts_for_vlm(weights, config)

assert fused == 0
# Source keys preserved verbatim for the sanitize caller to inspect.
assert "model.language_model.layers.0.mlp.experts.0.gate_proj.weight" in weights
assert "model.language_model.layers.0.mlp.experts.gate_up_proj" not in weights

def test_config_num_experts_name_fallbacks(self):
"""Config can declare experts via any of four key names."""
for cfg in (
{"num_local_experts": 2},
{"num_experts": 2},
{"text_config": {"num_local_experts": 2}},
{"text_config": {"num_experts": 2}},
):
weights = self._build_per_expert_weights(num_layers=1, num_experts=2)
assert _prefuse_moe_experts_for_vlm(weights, cfg) == 1

def test_chunked_stack_for_many_experts(self):
"""Stack chunk size is 16; verify >16 experts still fuse cleanly."""
config = {"num_local_experts": 20}
weights = self._build_per_expert_weights(num_layers=1, num_experts=20)

_prefuse_moe_experts_for_vlm(weights, config)

prefix = "model.language_model.layers.0.mlp"
assert weights[f"{prefix}.experts.gate_up_proj"].shape == (20, 64, 16)
assert weights[f"{prefix}.experts.down_proj"].shape == (20, 16, 32)


# =============================================================================
# Test GPTQ quantization