diff --git a/mlx_lm/fuse.py b/mlx_lm/fuse.py
index 87f667752..75fdabadf 100644
--- a/mlx_lm/fuse.py
+++ b/mlx_lm/fuse.py
@@ -92,7 +92,7 @@ def main() -> None:
 
     if args.export_gguf:
         model_type = config["model_type"]
-        if model_type not in ["llama", "mixtral", "mistral"]:
+        if model_type not in ["llama", "mixtral", "mistral", "qwen3_5_moe"]:
             raise ValueError(
                 f"Model type {model_type} not supported for GGUF conversion."
             )
diff --git a/mlx_lm/gguf.py b/mlx_lm/gguf.py
index 241ac35a1..37e477e7b 100644
--- a/mlx_lm/gguf.py
+++ b/mlx_lm/gguf.py
@@ -101,6 +101,43 @@ def load(path: Path) -> "HfVocab":
 
 
+def _fuse_switch_mlp_gate_up(weights):
+    """Fuse ``switch_mlp`` gate/up projections for Qwen3.6 MoE export.
+
+    Every ``switch_mlp.gate_proj`` entry that has a matching ``up_proj``
+    entry is replaced by one ``gate_up_proj`` entry holding the two
+    tensors concatenated, gate first. All other entries are kept as-is.
+
+    Args:
+        weights: Mapping of weight name to ``mx.array``.
+
+    Returns:
+        ``weights`` itself when nothing was fused, otherwise a new dict
+        with the fused entries listed first.
+    """
+    fused = {}
+    consumed = set()
+    for key, gate in weights.items():
+        if "switch_mlp.gate_proj" not in key:
+            continue
+        up_key = key.replace("gate_proj", "up_proj")
+        if up_key not in weights:
+            continue
+        # 3D tensors are taken to be stacked per-expert matrices, so the
+        # projections are joined along axis 1 rather than the expert axis.
+        axis = 1 if gate.ndim == 3 else 0
+        fused_key = key.replace("gate_proj", "gate_up_proj")
+        fused[fused_key] = mx.concatenate([gate, weights[up_key]], axis=axis)
+        consumed.update((key, up_key))
+    if not fused:
+        return weights
+    remaining = {k: v for k, v in weights.items() if k not in consumed}
+    return {**fused, **remaining}
+
+
 def translate_weight_names(name):
+    # Strip language_model. prefix (Qwen3.6 ConditionalGeneration wrapper)
+    name = name.replace("language_model.", "")
+
     name = name.replace("model.layers.", "blk.")
     # for mixtral gate
     name = name.replace("block_sparse_moe.gate", "ffn_gate_inp")
@@ -115,6 +152,28 @@ def translate_weight_names(name):
     replacement = r"ffn_up.\1.weight"
     name = re.sub(pattern, replacement, name)
 
+    # for Qwen3.6 MoE (switch_mlp merged expert tensors)
+    name = name.replace("mlp.switch_mlp.gate_up_proj", "ffn_gate_up_exps")
+    name = name.replace("mlp.switch_mlp.down_proj", "ffn_down_exps")
+    # Qwen3.6 shared experts
+    name = name.replace("mlp.shared_expert.gate_proj", "ffn_gate_shexp")
+    name = name.replace("mlp.shared_expert.down_proj", "ffn_down_shexp")
+    name = name.replace("mlp.shared_expert.up_proj", "ffn_up_shexp")
+    name = name.replace("mlp.shared_expert_gate", "ffn_gate_inp_shexp")
+    # Qwen3.6 MoE router. Match only the whole "gate" component so the dense
+    # "mlp.gate_proj" handled further down is not rewritten here.
+    name = re.sub(r"mlp\.gate(?=\.|$)", "ffn_gate_inp", name)
+    # Qwen3.6 linear attention (Mamba-style SSM)
+    name = name.replace("linear_attn.A_log", "ssm_a")
+    name = name.replace("linear_attn.conv1d", "ssm_conv1d")
+    name = name.replace("linear_attn.dt_bias", "ssm_dt.bias")
+    name = name.replace("linear_attn.in_proj_a", "ssm_alpha")
+    name = name.replace("linear_attn.in_proj_b", "ssm_beta")
+    name = name.replace("linear_attn.in_proj_qkv", "attn_qkv")
+    name = name.replace("linear_attn.in_proj_z", "attn_gate")
+    name = name.replace("linear_attn.norm", "ssm_norm")
+    name = name.replace("linear_attn.out_proj", "ssm_out")
+
     name = name.replace("mlp.gate_proj", "ffn_gate")
     name = name.replace("mlp.down_proj", "ffn_down")
     name = name.replace("mlp.up_proj", "ffn_up")
@@ -291,6 +350,10 @@ def convert_to_gguf(
         for k, v in weights.items()
     }
 
+    # Pre-process Qwen3.6 MoE: switch_mlp stores the gate and up projections
+    # as separate tensors, but GGUF expects them concatenated as gate_up_proj
+    weights = _fuse_switch_mlp_gate_up(weights)
+
     # rename weights for gguf format
     weights = {translate_weight_names(k): v for k, v in weights.items()}
 
diff --git a/tests/test_gguf.py b/tests/test_gguf.py
index f7e789a00..21d220996 100644
--- a/tests/test_gguf.py
+++ b/tests/test_gguf.py
@@ -60,3 +60,95 @@ def test_convert_to_gguf(
 
+class TestQwen36MoETensorMapping(unittest.TestCase):
+    """Tests for Qwen3.6 MoE (qwen3_5_moe) GGUF conversion support."""
+
+    def test_translate_weight_names_strips_language_model_prefix(self):
+        from mlx_lm.gguf import translate_weight_names
+
+        name = "language_model.model.layers.0.mlp.gate_proj.weight"
+        result = translate_weight_names(name)
+        self.assertNotIn("language_model", result)
+        self.assertIn("blk.0", result)
+
+    def test_translate_weight_names_keeps_dense_gate_proj(self):
+        """The MoE router rule must not rewrite the dense mlp.gate_proj."""
+        from mlx_lm.gguf import translate_weight_names
+
+        name = "model.layers.0.mlp.gate_proj.weight"
+        result = translate_weight_names(name)
+        self.assertIn("ffn_gate.", result)
+        self.assertNotIn("ffn_gate_inp", result)
+
+    def test_translate_weight_names_maps_switch_mlp(self):
+        from mlx_lm.gguf import translate_weight_names
+
+        name = "model.layers.0.mlp.switch_mlp.down_proj.weight"
+        result = translate_weight_names(name)
+        self.assertIn("ffn_down_exps", result)
+        self.assertNotIn("switch_mlp", result)
+
+    def test_translate_weight_names_maps_switch_mlp_gate_up(self):
+        from mlx_lm.gguf import translate_weight_names
+
+        name = "model.layers.0.mlp.switch_mlp.gate_up_proj.weight"
+        result = translate_weight_names(name)
+        self.assertIn("ffn_gate_up_exps", result)
+
+    def test_translate_weight_names_maps_shared_expert(self):
+        from mlx_lm.gguf import translate_weight_names
+
+        gate = translate_weight_names(
+            "model.layers.0.mlp.shared_expert.gate_proj.weight"
+        )
+        down = translate_weight_names(
+            "model.layers.0.mlp.shared_expert.down_proj.weight"
+        )
+        up = translate_weight_names(
+            "model.layers.0.mlp.shared_expert.up_proj.weight"
+        )
+        self.assertIn("ffn_gate_shexp", gate)
+        self.assertIn("ffn_down_shexp", down)
+        self.assertIn("ffn_up_shexp", up)
+
+    def test_translate_weight_names_maps_moe_router(self):
+        from mlx_lm.gguf import translate_weight_names
+
+        name = "model.layers.0.mlp.gate.weight"
+        result = translate_weight_names(name)
+        self.assertIn("ffn_gate_inp", result)
+
+    def test_translate_weight_names_maps_linear_attn(self):
+        from mlx_lm.gguf import translate_weight_names
+
+        a_log = translate_weight_names("model.layers.0.linear_attn.A_log")
+        conv1d = translate_weight_names(
+            "model.layers.0.linear_attn.conv1d.weight"
+        )
+        self.assertIn("ssm_a", a_log)
+        self.assertIn("ssm_conv1d", conv1d)
+
+    def test_gate_up_proj_fusion_in_convert(self):
+        """switch_mlp gate_proj and up_proj are fused into gate_up_proj."""
+        from mlx_lm.gguf import _fuse_switch_mlp_gate_up
+
+        prefix = "model.layers.0.mlp.switch_mlp"
+        # Simulate 3D expert tensors [n_experts, intermediate, hidden]
+        weights = {
+            f"{prefix}.gate_proj.weight": mx.random.uniform(shape=[4, 512, 2048]),
+            f"{prefix}.up_proj.weight": mx.random.uniform(shape=[4, 512, 2048]),
+            f"{prefix}.down_proj.weight": mx.random.uniform(shape=[4, 2048, 512]),
+        }
+
+        fused = _fuse_switch_mlp_gate_up(weights)
+
+        # Verify fusion result
+        fused_key = f"{prefix}.gate_up_proj.weight"
+        self.assertIn(fused_key, fused)
+        # tuple() keeps the check independent of the container type of .shape
+        self.assertEqual(tuple(fused[fused_key].shape), (4, 1024, 2048))
+        self.assertIn(f"{prefix}.down_proj.weight", fused)
+        self.assertNotIn(f"{prefix}.gate_proj.weight", fused)
+        self.assertNotIn(f"{prefix}.up_proj.weight", fused)
+
+
 if __name__ == "__main__":
     unittest.main()