Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlx_lm/fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def main() -> None:

if args.export_gguf:
model_type = config["model_type"]
if model_type not in ["llama", "mixtral", "mistral"]:
if model_type not in ["llama", "mixtral", "mistral", "qwen3_5_moe"]:
raise ValueError(
f"Model type {model_type} not supported for GGUF conversion."
)
Expand Down
45 changes: 45 additions & 0 deletions mlx_lm/gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def load(path: Path) -> "HfVocab":


def translate_weight_names(name):
# Strip language_model. prefix (Qwen3.6 ConditionalGeneration wrapper)
name = name.replace("language_model.", "")

name = name.replace("model.layers.", "blk.")
# for mixtral gate
name = name.replace("block_sparse_moe.gate", "ffn_gate_inp")
Expand All @@ -115,6 +118,27 @@ def translate_weight_names(name):
replacement = r"ffn_up.\1.weight"
name = re.sub(pattern, replacement, name)

# for Qwen3.6 MoE (switch_mlp merged expert tensors)
name = name.replace("mlp.switch_mlp.gate_up_proj", "ffn_gate_up_exps")
name = name.replace("mlp.switch_mlp.down_proj", "ffn_down_exps")
# Qwen3.6 shared experts
name = name.replace("mlp.shared_expert.gate_proj", "ffn_gate_shexp")
name = name.replace("mlp.shared_expert.down_proj", "ffn_down_shexp")
name = name.replace("mlp.shared_expert.up_proj", "ffn_up_shexp")
name = name.replace("mlp.shared_expert_gate", "ffn_gate_inp_shexp")
# Qwen3.6 MoE router
name = name.replace("mlp.gate", "ffn_gate_inp")
# Qwen3.6 linear attention (Mamba-style SSM)
name = name.replace("linear_attn.A_log", "ssm_a")
name = name.replace("linear_attn.conv1d", "ssm_conv1d")
name = name.replace("linear_attn.dt_bias", "ssm_dt.bias")
name = name.replace("linear_attn.in_proj_a", "ssm_alpha")
name = name.replace("linear_attn.in_proj_b", "ssm_beta")
name = name.replace("linear_attn.in_proj_qkv", "attn_qkv")
name = name.replace("linear_attn.in_proj_z", "attn_gate")
name = name.replace("linear_attn.norm", "ssm_norm")
name = name.replace("linear_attn.out_proj", "ssm_out")

name = name.replace("mlp.gate_proj", "ffn_gate")
name = name.replace("mlp.down_proj", "ffn_down")
name = name.replace("mlp.up_proj", "ffn_up")
Expand Down Expand Up @@ -291,6 +315,27 @@ def convert_to_gguf(
for k, v in weights.items()
}

# Pre-process Qwen3.6 MoE: fuse gate_proj + up_proj → gate_up_proj
# switch_mlp stores gate and up projections as separate tensors,
# but GGUF expects them concatenated as gate_up_proj
fused_weights = {}
skip_keys = set()
for k, v in weights.items():
if "switch_mlp.gate_proj" in k:
up_key = k.replace("gate_proj", "up_proj")
if up_key in weights:
cat_dim = 1 if v.ndim == 3 else 0
fused = mx.concatenate([v, weights[up_key]], axis=cat_dim)
fused_key = k.replace("gate_proj", "gate_up_proj")
fused_weights[fused_key] = fused
skip_keys.add(k)
skip_keys.add(up_key)
if fused_weights:
weights = {
**(fused_weights),
**{k: v for k, v in weights.items() if k not in skip_keys},
}

# rename weights for gguf format
weights = {translate_weight_names(k): v for k, v in weights.items()}

Expand Down
105 changes: 105 additions & 0 deletions tests/test_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,108 @@ def test_convert_to_gguf(

if __name__ == "__main__":
unittest.main()


class TestQwen36MoETensorMapping(unittest.TestCase):
"""Tests for Qwen3.6 MoE (qwen3_5_moe) GGUF conversion support."""

def test_translate_weight_names_strips_language_model_prefix(self):
from mlx_lm.gguf import translate_weight_names

name = "language_model.model.layers.0.mlp.gate_proj.weight"
result = translate_weight_names(name)
self.assertNotIn("language_model", result)
self.assertIn("blk.0", result)

def test_translate_weight_names_maps_switch_mlp(self):
from mlx_lm.gguf import translate_weight_names

name = "model.layers.0.mlp.switch_mlp.down_proj.weight"
result = translate_weight_names(name)
self.assertIn("ffn_down_exps", result)
self.assertNotIn("switch_mlp", result)

def test_translate_weight_names_maps_switch_mlp_gate_up(self):
from mlx_lm.gguf import translate_weight_names

name = "model.layers.0.mlp.switch_mlp.gate_up_proj.weight"
result = translate_weight_names(name)
self.assertIn("ffn_gate_up_exps", result)

def test_translate_weight_names_maps_shared_expert(self):
from mlx_lm.gguf import translate_weight_names

gate = translate_weight_names(
"model.layers.0.mlp.shared_expert.gate_proj.weight"
)
down = translate_weight_names(
"model.layers.0.mlp.shared_expert.down_proj.weight"
)
up = translate_weight_names(
"model.layers.0.mlp.shared_expert.up_proj.weight"
)
self.assertIn("ffn_gate_shexp", gate)
self.assertIn("ffn_down_shexp", down)
self.assertIn("ffn_up_shexp", up)

def test_translate_weight_names_maps_moe_router(self):
from mlx_lm.gguf import translate_weight_names

name = "model.layers.0.mlp.gate.weight"
result = translate_weight_names(name)
self.assertIn("ffn_gate_inp", result)

def test_translate_weight_names_maps_linear_attn(self):
from mlx_lm.gguf import translate_weight_names

a_log = translate_weight_names("model.layers.0.linear_attn.A_log")
conv1d = translate_weight_names(
"model.layers.0.linear_attn.conv1d.weight"
)
self.assertIn("ssm_a", a_log)
self.assertIn("ssm_conv1d", conv1d)

def test_gate_up_proj_fusion_in_convert(self):
"""Test that switch_mlp.gate_proj + up_proj are fused before name translation."""
# Simulate 3D expert tensors [n_experts, intermediate, hidden]
gate = mx.random.uniform(shape=[4, 512, 2048])
up = mx.random.uniform(shape=[4, 512, 2048])

# Simulate the fusion logic from convert_to_gguf
weights = {
"model.layers.0.mlp.switch_mlp.gate_proj.weight": gate,
"model.layers.0.mlp.switch_mlp.up_proj.weight": up,
"model.layers.0.mlp.switch_mlp.down_proj.weight": mx.random.uniform(
shape=[4, 2048, 512]
),
}

# Apply fusion (same logic as in convert_to_gguf)
fused_weights = {}
skip_keys = set()
for k, v in weights.items():
if "switch_mlp.gate_proj" in k:
up_key = k.replace("gate_proj", "up_proj")
if up_key in weights:
cat_dim = 1 if v.ndim == 3 else 0
fused = mx.concatenate([v, weights[up_key]], axis=cat_dim)
fused_key = k.replace("gate_proj", "gate_up_proj")
fused_weights[fused_key] = fused
skip_keys.add(k)
skip_keys.add(up_key)
if fused_weights:
weights = {
**(fused_weights),
**{k: v for k, v in weights.items() if k not in skip_keys},
}

# Verify fusion result
fused_key = "model.layers.0.mlp.switch_mlp.gate_up_proj.weight"
self.assertIn(fused_key, weights)
self.assertEqual(weights[fused_key].shape, [4, 1024, 2048])
self.assertNotIn(
"model.layers.0.mlp.switch_mlp.gate_proj.weight", weights
)
self.assertNotIn(
"model.layers.0.mlp.switch_mlp.up_proj.weight", weights
)