
[model] fix: qwen3_moe router double-softmax and idempotent _init_weights wrap #715

Open

TimYangst wants to merge 2 commits into main from tingyang/fix/qwen3_moe_router_softmax

Conversation

@TimYangst (Collaborator) commented Apr 30, 2026

What does this PR do?

Two related fixes to the qwen3_moe patches.

  1. Router double-softmax (v4 + v5). Qwen3MoeTopKRouter.forward now returns the raw pre-softmax logits as router_logits. HF's load_balancing_loss_func applies softmax internally, so returning the post-softmax tensor (as both upstream HF v4.57.3 / v5.2.0 and the prior VeOmni patch did) put the aux-loss path on softmax(softmax(x)) whenever output_router_logits=True, which flattens the routing distribution and weakens the load-balancing gradient signal. The post-softmax tensor is now kept locally as routing_weights and used only for top-k selection (first sketch below). Same root cause as HF PR #45131 (issue #45120); we mirror the fix on both VeOmni paths: modeling_qwen3_moe.py (v4 monkey patch) and qwen3_moe_gpu_patch_gen_config.py (v5), with generated/patched_modeling_qwen3_moe_gpu.py regenerated via python -m veomni.patchgen.check_patchgen --fix.

  2. Idempotent _init_weights wrap (v4). apply_veomni_qwen3_moe_patch() now wraps an import-time snapshot of HF's _init_weights (_HF_QWEN3_MOE_INIT_WEIGHTS) instead of whatever callable is currently installed, so repeated patch application stays idempotent (second sketch below).
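
A minimal sketch of the router-side change, assuming HF v4-style shapes. PatchQwen3MoeTopKRouter here is a simplified stand-in for the actual VeOmni class, and the exact return signature is illustrative rather than the patched code verbatim:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class PatchQwen3MoeTopKRouter(nn.Module):
        """Simplified stand-in; only the logits/weights split matters here."""

        def __init__(self, hidden_size: int, num_experts: int, top_k: int):
            super().__init__()
            self.gate = nn.Linear(hidden_size, num_experts, bias=False)
            self.top_k = top_k

        def forward(self, hidden_states: torch.Tensor):
            # (num_tokens, num_experts) raw scores.
            router_logits = self.gate(hidden_states)
            # Post-softmax probabilities stay LOCAL: used only for top-k
            # selection and expert weighting, never returned.
            routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
            routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
            # Return the RAW logits: HF's load_balancing_loss_func applies its
            # own softmax, so returning probabilities here is what produced
            # softmax(softmax(x)) in the aux-loss path.
            return routing_weights, selected_experts, router_logits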

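A sketch of the idempotent wrap. The module path and class name follow upstream transformers; the body of _wrapped_init_weights is elided because the VeOmni-specific expert-tensor init lives in the real patch:

    from transformers.models.qwen3_moe import modeling_qwen3_moe as hf_qwen3_moe

    # Captured ONCE at import time, before any patching can replace it.
    _HF_QWEN3_MOE_INIT_WEIGHTS = hf_qwen3_moe.Qwen3MoePreTrainedModel._init_weights

    def apply_veomni_qwen3_moe_patch():
        def _wrapped_init_weights(self, module):
            # (VeOmni-specific init for Patch* expert tensors goes here.)
            _HF_QWEN3_MOE_INIT_WEIGHTS(self, module)
        # Always wrap the import-time snapshot, never the current attribute,
        # so calling this patch N times never stacks N nested wrappers.
        hf_qwen3_moe.Qwen3MoePreTrainedModel._init_weights = _wrapped_init_weights
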
Checklist Before Starting

  • PR title follows [{modules}] {type}: {description} format

Test

  • Local e2e parallel test passes on 8x A100 in both envs: test_text_parallel_align[qwen3_moe-...] (transformers 4.57.3) and test_text_parallel_align[qwen3_moe_v5] (transformers 5.2.0).

  • The CI toy config has output_router_logits: false, so the router fix is not directly exercised by the existing e2e test. A follow-up could add an aux-loss-enabled subcase to lock this in.

  • End-to-end behavioral check (tingyang/qwen3_moe/verify_router_softmax_fix_e2e.py): builds the toy qwen3_moe model, forces output_router_logits=True, hooks every router, and prints stats of forward()[0], the value HF's load_balancing_loss_func consumes. The script is run twice (FIXED current code vs. monkey-patched buggy router):

    Run A — FIXED router (current code)
      layer 0: range [-6.76, 6.95]   sum(dim=-1) [13.89, 12.57, 6.97]   → LOGITS (raw)
      layer 1: range [-7.43, 9.13]   sum(dim=-1) [7.70, 11.03, 17.38]   → LOGITS (raw)
      layer 2: range [-5.85, 9.43]   sum(dim=-1) [7.38, 5.38, 2.67]     → LOGITS (raw)
      layer 3: range [-7.15, 6.81]   sum(dim=-1) [2.12, -3.66, 0.94]    → LOGITS (raw)
    
    Run B — BUGGY router (pre-fix double-softmax behavior)
      layer 0: range [0.00, 0.94]    sum(dim=-1) [1.0, 1.0, 1.0]        → PROB (BUG)
      layer 1: range [0.00, 0.99]    sum(dim=-1) [1.0, 1.0, 1.0]        → PROB (BUG)
      layer 2: range [0.00, 1.00]    sum(dim=-1) [1.0, 1.0, 1.0]        → PROB (BUG)
      layer 3: range [0.00, 0.96]    sum(dim=-1) [1.0, 1.0, 1.0]        → PROB (BUG)
    

    Decision rule: values ≥ 0 and sum(dim=-1) == 1 → already-softmax (bug); otherwise raw logits (fixed). All 4 MoE layers agree across both runs (a minimal check implementing this rule is sketched after this list).

  • Idempotency verified by a sanity script: 3 repeated apply_veomni_qwen3_moe_patch() calls leave _init_weights consuming exactly 3 RNG samples (one per Patch* expert tensor), not 9; a toy demonstration of the mechanism follows this list.
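
A minimal torch check implementing the decision rule above (the function name is ours, not the verification script's):

    import torch

    def looks_like_softmax(router_out: torch.Tensor, atol: float = 1e-4) -> bool:
        """True -> already-softmax output (bug); False -> raw logits (fixed)."""
        probs = router_out.float()
        nonneg = bool((probs >= 0).all())
        sums_to_one = bool(torch.allclose(
            probs.sum(dim=-1), torch.ones(probs.shape[:-1]), atol=atol
        ))
        return nonneg and sums_to_one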

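A toy demonstration (not VeOmni code) of why wrapping a snapshot, rather than the currently installed attribute, keeps RNG consumption fixed:

    import torch

    class Toy:
        def _init_weights(self, module):
            pass  # stands in for HF's original _init_weights

    _SNAPSHOT = Toy._init_weights  # import-time snapshot

    def apply_patch(idempotent: bool):
        inner = _SNAPSHOT if idempotent else Toy._init_weights
        def wrapped(self, module):
            module.weight.data.normal_()  # patch-specific init: 1 RNG draw
            inner(self, module)
        Toy._init_weights = wrapped

    lin = torch.nn.Linear(4, 4)
    for _ in range(3):
        apply_patch(idempotent=True)
    Toy()._init_weights(lin)  # 1 normal_() draw: wrappers did not stack
    # With idempotent=False the three applications nest, and the same call
    # would run normal_() three times (3 draws per tensor instead of 1).
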
API and Usage Example

N/A — internal patch refactor, no public API change.

Design & Code Changes

  • v4 PatchQwen3MoeTopKRouter.forward returns raw router_logits; rename internal post-softmax tensor to routing_weights.
  • v5 patchgen adds Qwen3MoeTopKRouter.forward override mirroring the v4 fix; generated/patched_modeling_qwen3_moe_gpu.py regenerated.
  • Capture _HF_QWEN3_MOE_INIT_WEIGHTS at import time and always wrap that snapshot.

Checklist Before Submitting

  • Read the Contribute Guide
  • Applied pre-commit checks (make quality clean)
  • Added/updated documentation
  • Added tests to CI workflow (or explained why not feasible)

@github-actions bot added the fix label Apr 30, 2026
@gemini-code-assist bot left a comment
Code Review

This pull request modifies the Qwen3 MoE router's forward method to return raw logits, avoiding double-softmax operations in the auxiliary loss calculation and matching Hugging Face's implementation. It also introduces a wrapper for the weight initialization process to prevent nested patching, which ensures consistent and reproducible initialization across multiple calls. I have no feedback to provide.

@TimYangst TimYangst marked this pull request as ready for review April 30, 2026 20:11
@TimYangst TimYangst marked this pull request as draft April 30, 2026 20:40
@Luosuu (Collaborator) left a comment
so this bug is only triggered when the load-balancing loss func is set?

@TimYangst TimYangst marked this pull request as ready for review May 1, 2026 21:18
@TimYangst TimYangst force-pushed the tingyang/fix/qwen3_moe_router_softmax branch from 2938dd1 to a2af78c on May 1, 2026 21:31