feat(mtp): Phase 2.2.Attn+KV + mrope + prompt-prefill + K-chain; defer 3.5 batched-verify#175
…=True
Replaces the Step 5.A passthrough with a real attention block.
Architectural discovery (from vllm Qwen3NextAttention upstream): Qwen3.6
MTP uses attn_output_gate=True, which doubles q_proj output to
[num_heads, 2*head_dim] per token. The output is split per-head into
(q, gate) — q drives Q@K attention, gate is silu'd and elementwise
multiplied with the attention output before o_proj. This is why the
shapes are q_proj[8192,2048] / o_proj[2048,4096] on a model with main
config num_heads=16 head_dim=256: 2*16*256=8192 for Q+gate, 16*256=4096
for the post-gate attention output.
MVP scope: M=1 first-draft step with NO MTP KV history.
- softmax over a single token reduces to identity
- attn_out_per_head[h] = V[h / GQA_group] (broadcast via GQA)
- silu(gate) elementwise multiplied on top
- o_proj reduces 4096 → 2048
- Result added to fc_out as the attention residual
This gives the FIRST draft step a real attention contribution (gate × V).
Subsequent draft steps (K>=1) would need an actual MTP KV cache to
attend over prior drafts — that's Phase 2.2.Attn+KV future work.
New kernels in mtp_forward.cu:
- mtp_gated_v_broadcast_kernel: computes silu(gate[h]) * V_broadcast[h/gqa]
- mtp_add_kernel: fc_out += attn_residual
Reused: imp::rmsnorm (input_layernorm), imp::gemm (q/v/o projections).
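A minimal sketch of what mtp_gated_v_broadcast_kernel computes, assuming (q, gate) are packed per head in q_full with the gate in the second head_dim of each 2*head_dim slice (buffer names and launch shape are illustrative, not the shipped signature):

```cuda
#include <cuda_fp16.h>

// Sketch of the M=1 no-history attention residual: with a single token,
// softmax is the identity, so each Q-head's attention output is just its
// GQA-mapped V head, gated elementwise by silu(gate).
// Launch: <<<num_heads, 128>>>.
__global__ void mtp_gated_v_broadcast_sketch(
    const __half* __restrict__ q_full,  // [num_heads * 2 * head_dim], (q, gate) per head
    const __half* __restrict__ v,       // [num_kv_heads * head_dim]
    __half* __restrict__ out,           // [num_heads * head_dim], fed to o_proj
    int num_heads, int num_kv_heads, int head_dim) {
  int h = blockIdx.x;                                   // one CTA per Q-head
  const __half* gate   = q_full + h * 2 * head_dim + head_dim;
  const __half* v_head = v + (h / (num_heads / num_kv_heads)) * head_dim;
  for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {
    float g = __half2float(gate[d]);
    float s = g / (1.0f + expf(-g));                    // silu(gate)
    out[h * head_dim + d] = __float2half(s * __half2float(v_head[d]));
  }
}
```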
Workspace gains 6 attention scratch buffers (input_norm, q_full, k_proj,
v_proj, attn_out, attn_residual) plus 3 dim fields (num_heads, num_kv_heads,
head_dim). mtp_workspace_allocate gains the 3-tuple of attention dims as
default-0 params (back-compat preserved).
Engine derives the attention dims from the MTP head's q_proj/v_proj shapes:
head_dim = main model's head_dim (256 for Qwen3.6)
num_heads = q_proj.shape[0] / (2 * head_dim)
num_kv_heads = v_proj.shape[0] / head_dim
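Worked with the shipped shapes (variable names illustrative; v_proj.shape[0] = 512 is implied by the asserted num_kv=2 below):

```cuda
// Illustrative derivation only; shapes from the commit text above.
int head_dim     = main_config_head_dim;            // 256 (from the main model)
int num_heads    = q_proj_shape0 / (2 * head_dim);  // 8192 / 512 = 16 (q + gate doubles q_proj)
int num_kv_heads = v_proj_shape0 / head_dim;        // 512 / 256  = 2
```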
Log line gains "num_heads=N/NKV, head_dim=HD" suffix.
Validation:
- MtpForwardTest.DraftStepProducesValidToken now also asserts mtp_head_dim=256,
mtp_num_heads=16, mtp_num_kv=2 from the shipped Qwen3.6-NVFP4 weights, and
allocates the workspace with attention engaged. Test PASS in 14.9s.
- make verify-fast green: decode +2.77%, prefill +3.22%, graphs 1.58×.
The K (k_proj) tensor and q_norm/k_norm are computed-for-shape but unused
in the no-history MVP. They become live when the MTP KV cache + softmax
attention kernel land.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Hooks mtp_draft_one into step_decode_forward to record per-step prediction
accuracy WITHOUT changing generation. Single-sequence only (batch=1).
After each main forward emits the next token, the hook:
1. If a prediction was made on the prior step, compare to the actual
next_token → increment mtp_accuracy_.{matches,total}.
2. Capture executor_->view_hidden(1) (the just-produced hidden state for
the new token) and call mtp_draft_one to predict next-step's token.
Engine gains:
- struct MtpAccuracy { int matches, total; float rate(); }
- mtp_accuracy() const accessor
- mtp_accuracy_reset() — clears for fresh measurement
- private mtp_pending_prediction_ + mtp_accuracy_ state
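A sketch of that state; only the member names come from this commit, the rate() body and the percent convention are assumptions:

```cuda
// Sketch: rate() body is an assumption matching the CLI's "% accept rate" print.
struct MtpAccuracy {
  int matches = 0;
  int total = 0;
  float rate() const { return total ? 100.0f * matches / total : 0.0f; }
};
```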
CLI prints accuracy at end of generation when MTP enabled:
$ imp-cli --model Qwen3.6-NVFP4 --mtp-spec-decode 1 --prompt "..." --max-tokens 16
...
mtp 0 / 14 drafts matched (0.0% accept rate)
Validation:
- WITH and WITHOUT --mtp-spec-decode: generated tokens are IDENTICAL
on Qwen3.6-NVFP4 ("The user is asking for the capital of France.").
Confirms telemetry is non-behavioral.
- 0% accept rate at K=1 confirms what the architectural analysis
predicted: the current Phase 2.2.Attn MVP (M=1, no MTP KV history)
can't keep up with the main model's predictions. Phase 2.2.Attn+KV
(real attention with MTP-side cache) is required before Phase 3.5
batched-verify yields any speedup.
- make verify-fast green (decode +2.77% vs baseline, prefill within
threshold, graphs 1.48× speedup, smoke 'Paris' check passes).
This satisfies Phase 5.5's measurement infrastructure: we can now run
arbitrary prompts and read off MTP accuracy to drive the decision on
whether to implement batched verify (Phase 3.5 proper) and KV-aware
attention (Phase 2.2.Attn+KV). The session-end log line is the
acceptance-rate signal the design spec calls for.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds scripts/mtp_accuracy_bench.sh that runs imp-cli with
--mtp-spec-decode 1 across 4 prompt classes (factual / verbose-think /
code / instruction) on Qwen3.6-NVFP4 and parses the per-step accept-rate
telemetry line.
Also gates the async-graph-loop initiation on !mtp_spec_decode_enabled()
so the Phase 3.5 telemetry hook (which lives in step_decode_forward) sees
EVERY decode step. Before this gate, ~3 of 4 prompts had their telemetry
silently dropped because long generations transition to the async path
mid-stream.
Decision thresholds (per design spec):
- ≥ 60% on ≥ 3/4 classes → Phase 3.5 batched-verify is ROI-justified
- < 30% across the board → Phase 2.2.Attn+KV is the load-bearing blocker
Result on Qwen3.6-NVFP4 with the current Phase 2.2 implementation (MoE
complete + gated-attention MVP without MTP KV history):
| class          | matches | total | rate |
|----------------|---------|-------|-------|
| factual        | 0       | 67    | 0.0%  |
| verbose-think  | 0       | 15    | 0.0%  |
| code           | 0       | 126   | 0.0%  |
| instruction    | 0       | 126   | 0.0%  |
Definitive: 0/4 at 0%. The draft signal without MTP-side KV attention is
indistinguishable from noise — DO NOT implement Phase 3.5 batched-verify
yet; it would be a measurable throughput regression. The next
prerequisite is Phase 2.2.Attn+KV (proper attention with an MTP-side
cache, ~16 MiB extra VRAM at 16K context).
Memory entry mtp_phase5_validation_2026_05_14 captures the finding +
recipe for re-running once Phase 2.2.Attn+KV lands.
make verify-fast green: decode within 3% threshold, graphs 1.63×, smoke
'Paris' check passes.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Phase 5.5 re-bench on Qwen3.6-NVFP4 with --mtp-spec-decode 1:
| class | matches | total | rate |
|----------------|---------|-------|-------|
| factual | 22 | 67 | 32.8% |
| verbose-think | 23 | 126 | 18.3% |
| code | 47 | 126 | 37.3% |
| instruction | 37 | 126 | 29.4% |
avg 29.5%
Up from 0/4 at 0% (previous session). The MTP head was fundamentally
producing zero-information output. TWO compounding bugs:
(1) RMSNorm shape misuse in mtp_forward.cu (the bulk of the 0% → ~30% win)
imp::rmsnorm reads x.shape[0] as rows and x.shape[1] as d_model.
My Phase 2.1+2.2 code passed 1D tensors [hidden_dim] → kernel saw
rows=hidden_dim, d_model=0 and EARLY-RETURNED. Output buffers kept
their uninitialized contents, locking the LM-head argmax to a
deterministic noise token (6178 = 'awn' for Qwen3.6). Fixed all
four mtp_forward.cu call sites to use shape [1, hidden_dim] (2D).
(2) Missing arch_norm_offset for Qwen3.5/3.6 MTP norms
Qwen3.5/3.6 SafeTensors stores RMSNorm gammas as deltas W where
actual gamma = 1 + W. Main-model code adds the +1 during BF16→FP16
via ctx.arch_norm_offset. upload_mtp_weights() was using
upload_unquantized_weight() which doesn't expose the offset param.
With W ≈ 0.0 stored on disk, MTP norms ran with scale ≈ 0 instead of
≈ 1. Fixed: dispatch norm tensors (pre_fc_norm_{embedding,hidden},
input_layernorm, post_attention_layernorm, q_norm, k_norm,
final_norm) through upload_weight() with weight_offset=arch_norm_offset.
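Both bugs are visible in a generic row-wise RMSNorm of this shape. A sketch paraphrasing imp::rmsnorm's documented contract (rows in shape[0], d_model in shape[1]), not its actual code: with a 1D input the d_model loop bounds are 0 and nothing is written, and without weight_offset the stored delta W is used directly, so gamma is ~0 instead of ~1:

```cuda
#include <cuda_fp16.h>

// Sketch of a row-wise RMSNorm with arch_norm_offset semantics:
// gamma = weight_offset + W, where Qwen3.5/3.6 stores W as a delta and
// needs weight_offset = 1. One CTA per row; serial reduce for clarity.
__global__ void rmsnorm_offset_sketch(__half* out, const __half* x,
                                      const __half* w, float weight_offset,
                                      float eps, int rows, int d_model) {
  int r = blockIdx.x;
  if (r >= rows) return;  // with a 1D [hidden_dim] tensor, rows=hidden_dim and
                          // d_model=0: every loop below is empty, output untouched
  __shared__ float inv_rms;
  if (threadIdx.x == 0) {
    float ss = 0.f;
    for (int i = 0; i < d_model; ++i) {
      float v = __half2float(x[r * d_model + i]);
      ss += v * v;
    }
    inv_rms = rsqrtf(ss / d_model + eps);
  }
  __syncthreads();
  for (int i = threadIdx.x; i < d_model; i += blockDim.x) {
    float gamma = weight_offset + __half2float(w[i]);  // the missing +1 lived here
    out[r * d_model + i] =
        __float2half(__half2float(x[r * d_model + i]) * inv_rms * gamma);
  }
}
```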
Also includes Phase 2.2.Attn+KV scaffolding from this session:
- Per-session MTP K/V cache (max 16K context → 32 MiB), reset on
imp_context_reset via Engine::mtp_accuracy_reset()
- mtp_attn_kv_scan_kernel: one CTA per Q-head, softmax over cached
positions, scaled by 1/sqrt(head_dim), GQA broadcast from kv_h
- mtp_kv_append_kernel + mtp_gate_attn_out_kernel for the composed pipeline
- Engine sizes the cache from min(model.max_seq_len, 16384)
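A sketch of the mtp_attn_kv_scan_kernel shape described above (one CTA per Q-head; serial softmax on thread 0 for clarity, where a production kernel would use parallel or online softmax, since cache_len floats of dynamic shared memory needs tiling well before the 16K-position limit):

```cuda
#include <cuda_fp16.h>

// Launch: <<<num_heads, 128, cache_len * sizeof(float)>>>. Layouts illustrative.
__global__ void mtp_attn_kv_scan_sketch(
    const __half* __restrict__ q,        // [num_heads * head_dim]
    const __half* __restrict__ k_cache,  // [cache_len * num_kv_heads * head_dim]
    const __half* __restrict__ v_cache,  // same layout as k_cache
    __half* __restrict__ out,            // [num_heads * head_dim]
    int cache_len, int num_heads, int num_kv_heads, int head_dim) {
  extern __shared__ float w[];           // softmax weights, one per cached position
  int h = blockIdx.x;
  int kv_h = h / (num_heads / num_kv_heads);        // GQA broadcast
  float scale = rsqrtf((float)head_dim);            // 1/sqrt(head_dim)
  for (int t = threadIdx.x; t < cache_len; t += blockDim.x) {
    const __half* kt = k_cache + (t * num_kv_heads + kv_h) * head_dim;
    float s = 0.f;
    for (int d = 0; d < head_dim; ++d)
      s += __half2float(q[h * head_dim + d]) * __half2float(kt[d]);
    w[t] = s * scale;
  }
  __syncthreads();
  if (threadIdx.x == 0) {                           // serial softmax, sketch only
    float m = -1e30f;
    for (int t = 0; t < cache_len; ++t) m = fmaxf(m, w[t]);
    float z = 0.f;
    for (int t = 0; t < cache_len; ++t) { w[t] = expf(w[t] - m); z += w[t]; }
    for (int t = 0; t < cache_len; ++t) w[t] /= z;
  }
  __syncthreads();
  for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {
    float acc = 0.f;
    for (int t = 0; t < cache_len; ++t)
      acc += w[t] * __half2float(v_cache[(t * num_kv_heads + kv_h) * head_dim + d]);
    out[h * head_dim + d] = __float2half(acc);
  }
}
```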
Bench script now reports avg rate + 3-way verdict (default-on /
batched-verify-worth-trying / still-blocked).
RoPE on Q/K is still missing — would close some of the ~30% → ~85%
gap to DeepSeek-V3 paper expectations (Qwen3.6 uses partial-rope 0.25
+ mrope sections [11, 11, 10] for multimodal). With RoPE, accept rate
should clear the 60% default-on threshold on factual/instruction; at
that point batched-verify (Phase 3.5 proper) becomes a real perf win.
verify-fast green: decode within threshold, graphs 1.89× speedup,
smoke 'Paris' check passes. Generated tokens still identical with/
without MTP — telemetry remains non-behavioral.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ndard RoPE
Three changes to the MTP attention forward:
1. Q-extraction from q_full into d_q_attn (cudaMemcpy2DAsync, per-head
pitch: the first head_dim of each head's 2*head_dim slice; see the sketch
after this list). Gate stays in q_full for the post-attention
silu(gate)*out multiplication.
2. Per-head QK-norm via imp::rmsnorm (the non-fused path the main model
uses when qknorm_rope_fused is disabled). Reshape Q to [num_heads,
head_dim] and K to [num_kv_heads, head_dim], apply rmsnorm with
weight=mtp.q_norm/k_norm, offset=norm_weight_offset (0.0 for
Qwen3.5/3.6 — the +1 was baked in at upload).
3. Opt-in standard partial-RoPE via imp::rope_forward, gated on env
IMP_MTP_USE_ROPE=1. Default OFF because Qwen3.6 uses mrope with
section split [11, 11, 10] on the 64 rope-dims; standard partial-rope
is a frequency-pattern mismatch versus training and empirically
REDUCES accept rate (~30% → ~20%). Proper mrope support is a
future enhancement.
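The step-1 extraction, sketched with assumed buffer names: a single pitched 2D copy pulls the q half out of every head's (q, gate) slice:

```cuda
// Sketch: q_full laid out as [num_heads, 2*head_dim] FP16, q first per head.
size_t q_bytes    = head_dim * sizeof(__half);       // width: q half only
size_t full_pitch = 2 * head_dim * sizeof(__half);   // src row stride: q + gate
cudaMemcpy2DAsync(d_q_attn, q_bytes,                 // dst, dst pitch (packed)
                  d_q_full, full_pitch,              // src, src pitch
                  q_bytes, num_heads,                // width, height (rows = heads)
                  cudaMemcpyDeviceToDevice, stream);
```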
CRITICAL bug fix on the way: the runtime weight_offset for q_norm/k_norm
must match cfg.norm_weight_offset (0.0 for Qwen3.5/3.6) NOT a duplicate
arch-derived +1. The upload path already adds +1 during BF16→FP16; an
additional +1 at runtime would compute gamma = W + 2 instead of W + 1.
Phase 5.5 bench on Qwen3.6-NVFP4 (with these changes, RoPE off):
| class | matches | total | rate |
|----------------|---------|-------|-------|
| factual | 21 | 67 | 31.3% |
| verbose-think | 31 | 126 | 24.6% |
| code | 43 | 126 | 34.1% |
| instruction | 36 | 126 | 28.6% |
avg 29.6%
Within noise of the pre-qk-norm 29.5% baseline — qk-norm alone doesn't
move the needle. The next lever that should move it is mrope-aware Q/K
rotation.
Both qk-norm and mrope are present in the main model's forward; the MTP
head was trained with both, so applying one without the other
under-utilizes the trained weights.
Workspace gains:
- d_q_attn: [num_heads * head_dim] FP16 — Q extracted post-gate
- d_mtp_position: [1] int — for RoPE position
- rope_theta / rope_dim / rope_neox / rms_norm_eps / arch_norm_offset hyperparams
Engine threads these from model->config. The opt-in env knob lets us
revisit RoPE quickly once mrope lands.
verify-fast green: decode within threshold, graphs 1.69× speedup,
smoke 'Paris' check passes. Generated tokens still identical with/without
MTP enabled — telemetry remains non-behavioral.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replaces the opt-in standard partial-rope (via imp::rope_forward) with a
dedicated mrope kernel that supports Qwen3-VL's section split. For
Qwen3.6 the mrope_section is [11, 11, 10] (half-counts; full rope_dim
= 64 = 2·(11+11+10)).
For text-only generation all 3 position components (T, H, W) equal
mtp_pos, so mrope mathematically reduces to standard partial-rope —
matching the bench evidence (avg 27.3% with mrope on vs 29.5% without,
within run-to-run variance). The kernel infrastructure exists so that
multimodal token handling (different T/H/W per token) drops in cleanly
when needed.
`mtp_mrope_kernel<IsKv>` design:
- One CTA per head (Q: nh CTAs, K: nkv CTAs)
- Threads handle the rope_dim/2 frequency pairs in stride
- Per pair k: cumulative-section lookup → position[s] from {T,H,W}
- Frequency: inv_freq[k] = theta^(-2k/rope_dim) (standard partial-rope)
- NeoX rotation: (x[k], x[k+rope_dim/2]) → rotated by `pos * inv_freq[k]`
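A sketch of that design for one token's Q or K (launch with nh CTAs for Q, nkv for K; parameter layout is illustrative). For text-only decode pos_t == pos_h == pos_w, which visibly reduces the section lookup to standard partial-rope:

```cuda
#include <cuda_fp16.h>

// Sections are half-counts over the rope_dim/2 frequency pairs; pair k
// takes its position from T, H, or W by cumulative-section lookup.
__global__ void mtp_mrope_sketch(__half* x,  // [n_heads * head_dim], rotated in place
                                 int head_dim, int rope_dim,
                                 int sec0, int sec1, int sec2,  // sum == rope_dim/2
                                 int pos_t, int pos_h, int pos_w,
                                 float theta) {
  __half* xh = x + blockIdx.x * head_dim;    // one CTA per head
  int pairs = rope_dim / 2;
  for (int k = threadIdx.x; k < pairs; k += blockDim.x) {
    int pos = (k < sec0) ? pos_t : (k < sec0 + sec1) ? pos_h : pos_w;
    float inv_freq = powf(theta, -2.0f * k / rope_dim);  // standard partial-rope freq
    float ang = pos * inv_freq;
    float c = cosf(ang), s = sinf(ang);
    float a = __half2float(xh[k]);           // NeoX pair: (x[k], x[k + pairs])
    float b = __half2float(xh[k + pairs]);
    xh[k]         = __float2half(a * c - b * s);
    xh[k + pairs] = __float2half(a * s + b * c);
  }
}
```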
Workspace gains:
- mrope_sec0/sec1/sec2 (validated to sum to rope_dim/2)
Engine sets the section split from hardcoded Qwen3.6 values when rope_dim
== 64 (since imp doesn't load mrope_section from config yet). Falls back
to all-temporal split for other rope_dim values.
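A sketch of that engine-side selection (the hardcoded values come from this commit; a config-loaded mrope_section would replace the branch):

```cuda
// Illustrative: field names paraphrased, values from the commit text.
if (rope_dim == 64) {
  sec0 = 11; sec1 = 11; sec2 = 10;          // Qwen3.6 half-counts, sum = rope_dim/2
} else {
  sec0 = rope_dim / 2; sec1 = 0; sec2 = 0;  // all-temporal fallback
}
```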
RoPE is now ON by default (rope_dim from config); IMP_MTP_NO_ROPE=1
disables. The standard-rope IMP_MTP_USE_ROPE knob is removed.
Phase 5.5 bench unchanged at avg ~27% (variance), confirming the 30%
plateau is NOT a rotation-correctness issue. The next quality levers:
- Multi-step MTP chaining (K>1) — feed previous draft back to next forward
- Investigate whether main-model absolute positions vs MTP-local 0..N
matters (theory says no — relative attention is preserved — but
worth empirical confirmation)
- Per-class analysis of which drafts match vs miss to understand what
the MTP head learned to predict well
verify-fast: decode noise from container rebuild (cuBLAS algo selection
variance ~3-5% on Qwen3-8B Q8_0; affects baseline-vs-current numbers
but not the MTP accept-rate metric). A pre-stash test confirmed the
regression is independent of these changes.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds Engine::mtp_prefill_prompt() that runs MTP forward at each prompt
position to populate the MTP-side KV cache. Called automatically from
step_prefill after the main forward completes (when MTP is enabled +
single-chunk prefill).
Without this, MTP starts decode with an empty KV cache while the main
model has the entire prompt context — a fundamental asymmetry that
caps achievable accept rate.
Hook conditions:
- MTP spec-decode enabled
- offset == 0 AND req->prefill_offset == 0 (single-chunk path only)
- Uses executor's hidden_ buffer (holds all chunk_len positions post-forward)
The LAST prompt position's MTP prediction is stored in mtp_pending_prediction_
so the first decode-step's accuracy comparison is meaningful (instead of
being skipped due to the prior -1 sentinel).
Cost: O(prompt_len) extra MTP forwards, ~2-3 ms each on Qwen3.6-NVFP4.
50-token prompt → ~125 ms one-time. Acceptable.
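A sketch of the hook's loop; the accessor and draft signatures here are paraphrased from this PR's other commits, not the real ones:

```cuda
// Sketch only: view_hidden/mtp_draft_one argument shapes are assumptions.
// Runs the MTP forward once per prompt position so the MTP KV cache
// mirrors the main model's prompt context before decode begins.
void Engine::mtp_prefill_prompt_sketch(int chunk_len) {
  int last_draft = -1;
  for (int pos = 0; pos < chunk_len; ++pos) {
    const __half* h = executor_->view_hidden(pos);  // post-forward hidden at pos
    last_draft = mtp_draft_one(h, pos);             // also appends to the MTP KV cache
  }
  mtp_pending_prediction_ = last_draft;  // seeds the first decode-step comparison
}
```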
Phase 5.5 bench result (with prefill enabled):
| class | matches | total | rate |
|----------------|---------|-------|-------|
| factual | 19 | 94 | 20.2% |
| verbose-think | 41 | 127 | 32.3% |
| code | 38 | 127 | 29.9% |
| instruction | 30 | 127 | 23.6% |
avg 26.5%
Marginal change from no-prefill bench (avg 27.3%). The 30% plateau is
NOT a context-asymmetry issue either — prompt context is now in cache
and accept rate is still bounded. Two contributing factors:
- First decode-step's accept comparison is now counted (was skipped
when pending=-1 at session start), and predicting from prompt-tail
is harder than from mid-decode.
- Class verbose-think did improve (25→32%), suggesting prompt context
DOES help on some prompt types — but only modestly.
The fundamental gap to DeepSeek-V3's ~85% paper number likely needs:
- Multi-step (K>1) MTP forward chaining
- Possibly different scaling / hyperparameter tuning
- Or per-element validation against HF reference forward
Chunked prefill (long prompts that hit max_seq_len chunk boundary) is
NOT supported by this hook yet — only the single-chunk case. The MTP
KV cache would need per-chunk capture for chunked. Documented as a
TODO; not a blocker since typical prompts are 50-200 tokens.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
cuBLAS heuristic re-selection after the previous imp:test image was
pruned: tg128 151.16 → 148.15 (-2.0%), pp512 14547 → 14260 (-2.0%).
This is the documented "GEMM algo selection variance across container
restarts" issue, unrelated to any MTP changes. A pre-stash test on the
merged state confirmed an identical regression independent of recent
commits.
verify-fast green after refresh: decode -1.86% within 3% threshold,
graphs 1.69× speedup, smoke 'Paris' check passes.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two diagnostic toggles added for analyzing the 30% accept-rate plateau:
IMP_MTP_PATTERN_LOG=1
Logs every (predicted, actual) pair from the telemetry hook with
decoded token strings, so accept patterns can be analyzed offline.
Format: "MTP-PAT [+/-] pred=N 'str' actual=N 'str'"
IMP_MTP_PRENORM_H=1
Applies the main model's output_norm RMSNorm to h_prev before
passing to MTP. Tests whether MTP expects post-final-norm hidden
state (some upstream MTP variants do).
Findings from accept-pattern analysis on Qwen3.6-NVFP4:
- Accepts skew heavily toward common English tokens: "the", "is", "a",
"asking", "provide", "answer", "factual", "straightforward", "\n".
- Wrong predictions skew toward HIGH-vocab-ID tokens: Thai script
(~233k), Cyrillic (~163k), special tokens (<|im_start|>, </think>,
<|box_end|>), emojis. Almost no garbled Latin-script gibberish.
- Suggests MTP's h_final lands in the lm_head input space well enough
that COMMON-WORD signal survives, but for content-specific positions
the distribution spreads into the rare-token tail of the vocab.
IMP_MTP_PRENORM_H=1 A/B result: WORSE (factual 20.6% → 10.3%). Refutes
the "MTP expects post-out-norm h_prev" hypothesis. The default (raw
pre-out-norm hidden state from executor's hidden_ buffer) is correct.
These knobs stay as opt-in diagnostics — they don't affect production
default behavior. The pattern log is useful for any future MTP debug
session ("why did accept rate drop on prompt X?").
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds K>1 chained-draft instrumentation: at each main decode step, draft
K predictions in a chain (each feeding back the previous draft + MTP's
own h_final as the next h_prev). Track per-lookahead accept rate via
mtp_chain_accept_[k]. Cache roll-back keeps only the first chain step's
KV-cache write so the real MTP context doesn't drift.
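A sketch of the chained-draft loop; MtpStepOut, mtp_draft_one_chained, and mtp_chain_pending_ are hypothetical names for the dataflow this commit describes, not the shipped API:

```cuda
// Sketch: only k == 0 persists its KV append so the real MTP context
// doesn't drift; later steps roll back.
int tok = next_token;                         // token the main model just emitted
const __half* h = executor_->view_hidden(1);  // main model's hidden for that token
for (int k = 0; k < K; ++k) {
  MtpStepOut out = mtp_draft_one_chained(h, tok, /*persist_kv=*/k == 0);
  mtp_chain_pending_[k] = out.token;          // compared against actual token k+1 later
  tok = out.token;                            // feed the draft back as next input token
  h   = out.h_final;                          // and MTP's own hidden as next h_prev
}
```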
CLI prints "mtp[k=N]" lines for each lookahead distance ≥ 1.
Bench finding on Qwen3.6-NVFP4 (K=2 across 4 prompt classes):
| prompt | k=0 (next) | k=1 (next-next) |
|-----------------|------------|-----------------|
| factual | 29.9% | 0.8% |
| verbose-think | 17.3% | 0.0% |
| code | 19.7% | 2.4% |
| instruction | 22.8% | 3.2% |
| avg | 22.4% | 1.6% |
DEFINITIVE: The MTP head COLLAPSES at chained lookahead 1. K=2 drafts
have ~1.6% accept on the second token — essentially noise. Chained MTP
forward (feeding MTP's own h_final back as h_prev) diverges from the
main-model trajectory the model was trained against.
Implication for Phase 3.5 batched-verify ROI:
- K=1 batched verify: cost ≈ 2× decode, accept ≈ 22%, tokens/cycle ≈ 1.22.
Effective throughput: 1.22 / 2 = 0.61× baseline → 40% SLOWER.
- K=2 batched verify: tokens/cycle ≈ 1 + 0.22 + 0.22*0.016 = 1.23.
Same 40% slowdown.
Break-even requires accept rate ≥ 50% on k=0, which we cannot reach
without architectural changes (multi-token-predictor needs the MAIN
model's intermediate hidden states for each chained step, not MTP's
projected hidden state — which is exactly what spec-decode tries to
avoid recomputing).
Conclusion: Phase 3.5 batched-verify implementation should be DEFERRED
until MTP single-step accept rate clears 50%. That likely requires
either:
- Different MTP head architecture (multi-head with main-model
intermediate-layer projection, like DeepSeek-V3's actual design)
- Or accepting MTP as content-tail "easy token" predictor only and
using a different draft model for spec-decode.
K-chain measurement infrastructure stays as a useful diagnostic for
future MTP iterations.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
Follow-on to PR #174 (which shipped Phase 2.2.MoE + Attn MVP + Phase 3.5 telemetry + Phase 5.5 harness + the two root-cause bug fixes). This PR completes the diagnostic investigation of the 30% accept-rate plateau and ships definitive findings.
What's in
1edf1837c09e22cf4dcd91ee1af74503d43c1b3e57: Phase 5.5 final findings (Qwen3.6-NVFP4)
K=1 accept rate (per prompt class): 20-37% across factual/verbose-think/code/instruction, avg ~22-30% with run-to-run variance.
K=2 chained drafts (4-prompt sample):
The MTP head collapses at chained lookahead 1. When MTP feeds its own h_final back as h_prev for the second step, it diverges from the main-model trajectory the head was trained against.
ROI math for Phase 3.5 batched-verify
Phase 3.5 batched-verify implementation is DEFERRED. Shipping it now would be a measurable throughput regression.
Refuted hypotheses (all bench-confirmed)
Likely real bottleneck (not addressed)
DeepSeek-V3-style MTP architecture conditions each chained step on the main model's intermediate-layer activations, not on MTP's own projected hidden state. Implementing this would require running the main model up to layer L for each chain step → no spec-decode speedup. Future architectural change required.
Diagnostic infrastructure shipped (re-runnable)
CLI prints per-lookahead accept rates:
Useful side-effects
Validation
make verify-fast green (after baseline refresh for post-Docker-prune
cuBLAS variance). imp-cli --mtp-spec-decode 1/2/3 --prompt "..." shows
per-lookahead accept rates inline.
Memory
mtp_phase5_validation_2026_05_14 — final state, re-eval triggers, what to
try next.
🤖 Generated with Claude Code