Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions example/qwen3_5/test_mtp_logits.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,33 +33,33 @@
# Test result on Qwen3.5-35B-A3B model
# ================================================================
# MTP HEAD 0 — logits[i] predicts token[i+2]
# Top-1 accuracy : 43.08% (28/65 valid positions)
# Top-1 accuracy : 66.15% (43/65 valid positions)
# ================================================================

# --- Spot-check: first 12 positions ---
# pos= 0 | pred= 314 ' of' | gt= 314 ' of' |
# --- Spot-check: first 12 positions ---
# pos= 0 | pred= 11 ',' | gt= 314 ' of' |
# pos= 1 | pred= 279 ' the' | gt= 9338 ' France' | ✗
# pos= 2 | pred= 369 ' is' | gt= 369 ' is' | ✓
# pos= 3 | pred= 11751 ' Paris' | gt= 11751 ' Paris' | ✓
# pos= 4 | pred= 13 '.' | gt= 13 '.' | ✓
# pos= 5 | pred= 198 '\n' | gt= 561 ' The' | ✗
# pos= 6 | pred= 6511 ' capital' | gt= 242476 ' Eiff' | ✗
# pos= 7 | pred= 684 'so' | gt= 300 'el' |
# pos= 7 | pred= 300 'el' | gt= 300 'el' |
# pos= 8 | pred= 21262 ' Tower' | gt= 21262 ' Tower' | ✓
# pos= 9 | pred= 369 ' is' | gt= 557 ' was' | ✗
# pos= 10 | pred= 5617 ' built' | gt= 5617 ' built' | ✓
# pos= 11 | pred= 303 ' in' | gt= 303 ' in' | ✓

# ================================================================
# MAIN HEAD — logits[i] predicts token[i+1]
# Top-1 accuracy : 68.18% (45/66 valid positions)
# Top-1 accuracy : 69.70% (46/66 valid positions)
# ================================================================

# ================================================================
# SUMMARY
# ================================================================
# [PASS] MTP head 0 (predicts token[i+2]): top-1 acc = 43.08% (28/65)
# [PASS] Main head (predicts token[i+1]): top-1 acc = 68.18% (45/66)
# [PASS] MTP head 0 (predicts token[i+2]): top-1 acc = 66.15% (43/65)
# [PASS] Main head (predicts token[i+1]): top-1 acc = 69.70% (46/66)
# ================================================================


Expand Down
34 changes: 28 additions & 6 deletions mbridge/models/qwen3_5/base_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

class Qwen3_5VlBaseBridge(VLMBridge):

mtp_fused_experts: bool = False

def _handle_hf_config(self):
self.hf_text_config = getattr(self.hf_config, "text_config", self.hf_config)
self.hf_vision_config = getattr(self.hf_config, "vision_config", self.hf_config)
Expand Down Expand Up @@ -134,10 +136,13 @@ def _get_mcore_config_by_name(self, mcore_weights_name: str):
return self.config

def _get_safetensor_io(self, weights_path: str):
# TODO: MTP layers are not handled yet
return Qwen3_5SafeTensorIO(
self._get_actual_hf_path(weights_path), ignore_mtp=False
mtp_num_layers = getattr(self.config, "mtp_num_layers", None)

sio = Qwen3_5SafeTensorIO(
self._get_actual_hf_path(weights_path), ignore_mtp=(mtp_num_layers is None)
)
self.mtp_fused_experts = sio.mtp_fused_experts
return sio

def _weight_name_mapping_mcore_local_to_global(
self, model: torch.nn.Module, consider_ep: bool = True
Expand Down Expand Up @@ -305,6 +310,15 @@ def _weight_name_mapping_visual(self, name: str) -> list[str]:
],
}

# Name mapping for the experts of MTP layer 0 when the HF checkpoint is
# flagged as using fused MTP experts (`mtp_fused_experts` is true).
# Keys are mcore parameter names in which the numeric suffix after ".weight"
# has been replaced by the literal placeholder "{expert_index}"; values are
# the HF tensor names to read from / write to.
# Unlike the per-expert mapping, the HF names here carry no expert index, so
# every expert index resolves to the same HF tensor name.
MTP_FUSED_EXPERTS_MAPPING = {
"language_model.mtp.layers.0.transformer_layer.mlp.experts.linear_fc1.weight{expert_index}": [
"mtp.layers.0.mlp.experts.gate_up_proj",
],
"language_model.mtp.layers.0.transformer_layer.mlp.experts.linear_fc2.weight{expert_index}": [
"mtp.layers.0.mlp.experts.down_proj",
],
}

def _convert_mtp_param(self, name: str) -> list[str]:
assert self.config.mtp_num_layers == 1, "only support one mtp layer for now"

Expand All @@ -319,10 +333,13 @@ def _convert_mtp_param(self, name: str) -> list[str]:
# e.g. language_model.mtp.layers.0.transformer_layer.mlp.experts.linear_fc1.weight3
# -> key = "...linear_fc1.weight{expert_index}", expert_index = 3
if ".mlp.experts.linear_fc" in name:
# split off the numeric expert_index suffix after ".weight"
prefix, expert_index_str = name.split(".weight", 1)
expert_index = int(expert_index_str)
key = prefix + ".weight{expert_index}"

if self.mtp_fused_experts:
return self.MTP_FUSED_EXPERTS_MAPPING[key]

mapping_names = self._MTP_MAPPING[key]
return [x.format(expert_index=expert_index) for x in mapping_names]

Expand Down Expand Up @@ -405,7 +422,12 @@ def _weight_to_hf_format(

return [hf_names[0]], [mcore_weights[: self.vocab_size]]

if "mtp" in mcore_weights_name:
is_mtp_fused_expert = (
"mtp" in mcore_weights_name
and ".mlp.experts.linear_fc" in mcore_weights_name
and self.mtp_fused_experts
)
if "mtp" in mcore_weights_name and not is_mtp_fused_expert:
return [hf_names[0]], [mcore_weights]

# moe
Expand Down Expand Up @@ -591,7 +613,7 @@ def _weight_to_mcore_format(

# moe
if ".mlp.experts.linear_fc" in mcore_weights_name:
if "mtp" in mcore_weights_name:
if "mtp" in mcore_weights_name and not self.mtp_fused_experts:
return hf_weights[0]
# get expert index
local_experts_idx = int(mcore_weights_name.split(".weight")[-1])
Expand Down
5 changes: 5 additions & 0 deletions mbridge/models/qwen3_5/qwen3_5_safetensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,8 @@ def __init__(self, hf_dir: str, ignore_mtp: bool = False):
self.index[key] = filename

self.hf_dir = hf_dir

has_mtp = any(k.startswith("mtp.") for k in self.index)
self.mtp_fused_experts = (
has_mtp and "mtp.layers.0.mlp.experts.gate_up_proj" in self.index
)
Loading