feat(mtp): Phases 2.2 + 3.5 + 5.5 — MoE, gated attn, KV cache + softmax scan; bugs fixed → 0% → ~30% accept rate#174
Merged
Conversation
PR #172 shipped end-to-end MTP scaffolding (load + reduced FC-only forward + engine API + CLI). Three open work items remain for "MTP fully": Phase 2.2 — full transformer block in mtp_forward.cu (currently a no-op passthrough at line 186-190). Design fork documented: Path A (TransformerLayer view-adapter, reuse existing run_attention + run_moe_ffn) vs Path B (from-scratch fused kernels). Path A recommended. Phase 3.5 — auto-invoke mtp_draft_one + verify forward + accept-prefix from the decode loop. Currently mtp_draft_one exists but nothing in step_decode calls it. Phase 5.5 — A/B matrix to decide default-on/off. Task-by-task breakdown for each phase. Cross-references the memory entry mtp_phase2_open_2026_05_14 capturing what's shipped vs open. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…igmoid gate
Replaces the no-op Step 5 placeholder in mtp_forward.cu:186-190 with the
full MoE branch of the MTP transformer block:
Step 5.B.1 post_attention_layernorm(fc_out) → d_post_norm
Step 5.B.2 moe_gate_topk_fused: router @ post_norm, softmax, top-k=8
Step 5.B.3 D2H sync of routing indices+weights for host-side dispatch
Step 5.B.4 Per chosen expert (k ∈ [0, 8)):
gate_up = experts_gate_up_packed[idx] @ post_norm → [1024]
act = silu(gate) * up → [512]
down = experts_down_packed[idx] @ act → [2048]
store into d_expert_outputs[k * hidden]
Step 5.B.5 moe_weighted_sum_residual: fc_out += Σ w[k] * out[k]
Step 5.B.6 shared expert: silu(gate_proj·x) * (up_proj·x) → down_proj
scaled by sigmoid(shared_expert_gate_inp · x), added to fc_out
All compute reuses existing imp primitives:
- imp::rmsnorm
- imp::moe_gate_topk_fused (fused gate-GEMV + softmax + top-k for M=1)
- imp::gemm (M=1 GEMV for per-expert weights and shared expert projections)
- imp::swiglu (silu(gate) * up)
- imp::moe_weighted_sum_residual (Σ + residual)
- imp::shared_expert_gate_scale (sigmoid scalar gate in-place)
+ one tiny new kernel: mtp_add_shared_kernel to fold shared_out into fc_out
Per-expert weight handling: experts_gate_up_packed is [256, 1024, 2048] and
experts_down_packed is [256, 2048, 512] FP16. For each chosen expert, we
build a 2D Tensor view at the expert's slice offset (no extra copies). The
3D packed layout sticks with the shipped MtpHead design.
Workspace gains MoE scratch buffers (post_norm, gate_up scratch, act,
per-expert outputs, moe_out, shared_*) plus a MoeRoutingBuffers pool and
pinned host buffers for the routing D2H. mtp_workspace_allocate gains
n_experts / top_k / expert_d_ff / shared_d_ff params so the Engine sizes
correctly. The 2-arg form is retained for back-compat.
Engine threads model config (256 / 8 / 512 / 512 for Qwen3.6) into the
workspace allocator.
Also fixes hf_config_loader to read Qwen3.5/3.6's shared_expert_intermediate_size
(previously only read DeepSeek's moe_shared_expert_intermediate_size) so
expert_shared_d_ff = 512 lands on the config for Qwen3.6-NVFP4. Without this,
the MTP shared expert block silently disabled itself.
Attention block remains a passthrough (Step 5.A) — Qwen3.6 MTP has unusual
attention shapes (q_proj [8192,2048] but o_proj input is 4096) that need
upstream-reference investigation. Documented in the header.
Smoke test on Qwen3.6-NVFP4 with --mtp-spec-decode 2: workspace allocates
cleanly (d_ff_shared=512), main-model decode produces coherent output
("The capital of France is Paris"), verify-fast green (decode +3.23%,
prefill +2.31%, graphs 1.72×).
The MoE block only RUNS when mtp_draft_one is invoked, which is still
manual (Phase 3.5 auto-invoke not yet wired).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
MtpForwardTest.DraftStepProducesValidToken: - Loads Qwen3.6-NVFP4 + MTP sidecar end-to-end - Allocates MTP workspace with full MoE config (256 experts / top-8 / expert_d_ff=512 / shared_d_ff=512) - Calls mtp_draft_step with a random FP16 hidden state + arbitrary token id - Asserts out_token_id ∈ [0, vocab_size) PASSES on RTX 5090 (14.4s including 1.57 GiB MTP upload), exercising: - router GEMV + top-8 selection - per-expert gate_up + swiglu + down (8 experts dispatched) - moe_weighted_sum_residual - shared expert gate_proj/up_proj/down_proj - sigmoid scalar gate This is the first test that actually invokes the MoE block; existing E2E paths don't auto-call mtp_draft_one (Phase 3.5 deferred). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Completes the MTP transformer block (Phase 2.2 = MoE + Attn-with-KV-cache + gated-V-output), wires auto-invoke telemetry (Phase 3.5), and validates end-to-end (Phase 5.5). TWO root-cause bugs found and fixed — accept rate jumped from 0/4-at-0% to 4/4-with-signal (avg 29.5%).
Phase 5.5 validation (Qwen3.6-NVFP4, --mtp-spec-decode 1, max_tokens=128)
All 4 classes show real signal. Below the ≥ 60%-on-3/4 default-on threshold but far above noise. RoPE on Q/K is the next quality lever.
Root-cause bugs (latest commit
166f3fa)Bug 1: RMSNorm 1D-shape early-return
imp::rmsnormreadsx.shape[0]as rows andx.shape[1]as d_model. mtp_forward.cu (since PR #172 Phase 2.1) passed 1D tensors[hidden_dim]→ kernel sawrows=hidden_dim, d_model=0and early-returned without writing output. The MTP forward's RMSNorm outputs were uninitialized FP16 buffer contents (often saturated ~22000). LM-head argmax locked deterministically to token 6178 ('awn') regardless of input.Fix: 4 sites in mtp_forward.cu changed from
[hidden_dim]→[1, hidden_dim].Bug 2: Missing arch_norm_offset on MTP norms
Qwen3.5/3.6 SafeTensors stores RMSNorm gammas as deltas
Wwhere actualgamma = 1 + W. Main-model loader applies the +1 viactx.arch_norm_offset.upload_mtp_weights()usedupload_unquantized_weight()which doesn't expose the offset → MTP norms ran with scale ≈ 0 (raw W with mean near zero).Fix: dispatch the 7 norm tensors (pre_fc_norm_{embedding,hidden}, input_layernorm, post_attention_layernorm, q_norm, k_norm, final_norm) through
upload_weight(..., weight_offset=ctx.arch_norm_offset).Per-phase shipped
attn_output_gate=Truegated V-broadcast (no KV)Phase 2.2.Attn+KV details
MtpDraftWorkspace::d_k_cache+d_v_cache) up to 16K context — 16 MiB each for Qwen3.6 dims (max_seq_len × num_kv_heads × head_dim × 2 bytes).mtp_kv_append_kernel: appends current step's k/v at positionmtp_pos.mtp_attn_kv_scan_kernel: one CTA per Q-head, softmax over [0, mtp_pos+1) with shared-mem max-reduce, GQA broadcast from kv_h, scaled by 1/√head_dim.mtp_gate_attn_out_kernel: appliessilu(gate)elementwise to attention output.Engine::mtp_accuracy_reset()resetsmtp_posonimp_context_resetfor clean new-session state.imp::qknorm_rope_fusedalready exists for main model with partial-rope + mrope) should close roughly half the gap to DeepSeek-V3 paper's ~85% expectations.What's still placeholder
Validation
MtpForwardTest.DraftStepProducesValidToken: PASS (full MoE + Attn+KV path engaged).make verify-fast: green (decode +1.89× graph speedup, smoke 'Paris' check passes).imp-cli --mtp-spec-decode 1produces identical tokens with/without MTP (telemetry remains non-behavioral).scripts/mtp_accuracy_bench.sh: 4/4 classes with signal (above).Files changed
Memory:
mtp_phase5_validation_2026_05_14— re-run recipe + next-step recommendations.🤖 Generated with Claude Code