From 546741a44aeb624f936bd76c6cff936ce54c52d1 Mon Sep 17 00:00:00 2001 From: Raphael Friedmann Date: Thu, 14 May 2026 15:52:33 +0200 Subject: [PATCH 1/3] docs(mtp): plan for remaining Phase 2.2 + 3.5 + 5.5 work after PR #172 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #172 shipped end-to-end MTP scaffolding (load + reduced FC-only forward + engine API + CLI). Three open work items remain for "MTP fully": Phase 2.2 — full transformer block in mtp_forward.cu (currently a no-op passthrough at line 186-190). Design fork documented: Path A (TransformerLayer view-adapter, reuse existing run_attention + run_moe_ffn) vs Path B (from-scratch fused kernels). Path A recommended. Phase 3.5 — auto-invoke mtp_draft_one + verify forward + accept-prefix from the decode loop. Currently mtp_draft_one exists but nothing in step_decode calls it. Phase 5.5 — A/B matrix to decide default-on/off. Task-by-task breakdown for each phase. Cross-references the memory entry mtp_phase2_open_2026_05_14 capturing what's shipped vs open. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../plans/2026-05-14-mtp-phase2-onwards.md | 241 ++++++++++++++++++ 1 file changed, 241 insertions(+) create mode 100644 docs/superpowers/plans/2026-05-14-mtp-phase2-onwards.md diff --git a/docs/superpowers/plans/2026-05-14-mtp-phase2-onwards.md b/docs/superpowers/plans/2026-05-14-mtp-phase2-onwards.md new file mode 100644 index 0000000..73352a4 --- /dev/null +++ b/docs/superpowers/plans/2026-05-14-mtp-phase2-onwards.md @@ -0,0 +1,241 @@ +# MTP Phase 2.2 + 3.5 + 5.5 — Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Take MTP from "end-to-end scaffolding works" (post PR #172) to "production-grade speculative decoding with measured net win" — implement the full MTP transformer block, auto-invoke draft+verify from the decode loop, validate acceptance rate. + +**Architecture:** PR #172 shipped weight loading + reduced FC-only forward + engine API + CLI flag + smoke. Three concrete open work items remain. Phase 2.2 fills in the transformer block in `mtp_forward.cu`. Phase 3.5 wires drafts into the decode loop with verify-accept logic. Phase 5.5 measures. + +**Tech Stack:** Same as before — C++20 + CUDA sm_120a, reuses existing attention + MoE kernels where feasible, no new third-party deps. + +**Spec:** `docs/superpowers/specs/2026-05-14-mtp-wiring-design.md` + +**Predecessor work (shipped on main):** +- PR #171 (`536af79`) — Phase 1.A: detection + scaffolding +- PR #172 (`b1f74e0`) — Phases 1.B + 1.C + 2.1 + 3 + 4 + 5: full scaffolding, reduced forward + +**Memory:** [`mtp_phase2_open_2026_05_14`](../../../memory/mtp_phase2_open_2026_05_14.md) + +--- + +## Phase 2.2 — Full Transformer Block (multi-week) + +**Outcome:** `src/runtime/mtp_forward.cu:186-190` no-op placeholder replaced with full attention + 256-expert MoE block. Draft logits become close-to-optimal (matching DeepSeek-V3 paper §4.5 expectations). + +### Design fork + +The shipped `MtpHead` uses **flat individual tensors** (q_proj, k_proj, v_proj, o_proj, q_norm, k_norm, router, experts_gate_up_packed [256,1024,2048], experts_down_packed [256,2048,512], shared_expert_*). The main model's `TransformerLayer` uses per-expert vectors `expert_w_gate[256]` + `expert_w_up[256]` + `expert_w_down[256]`. 
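For concreteness, this is the gap that has to be bridged for the routed experts — a minimal sketch (not the implementation) of the slicing that Task 2.2.A.1 Step 3 below asks for. It assumes the `[gate; up]` row order inside each expert slice, that `TransformerLayer::expert_w_gate` / `expert_w_up` are indexable `Tensor` containers, and the `Tensor(ptr, QType, ndims, shape, is_view)` view constructor; dimensions are the Qwen3.6 values (256 experts, expert d_ff 512, hidden 2048). All of those should be re-checked against `src/core/tensor.h` when the task is picked up.

```cpp
// Sketch only — field names and the Tensor view constructor are assumptions.
bool mtp_populate_layer_view(MtpHead& head) {
  constexpr int kExperts = 256, kDffE = 512, kHidden = 2048;
  const size_t elem = sizeof(uint16_t);                          // FP16 element
  const size_t per_expert = size_t(2) * kDffE * kHidden * elem;  // one [1024, 2048] slice
  char* base = static_cast<char*>(head.experts_gate_up_packed.data);
  for (int e = 0; e < kExperts; ++e) {
    char* gate = base + e * per_expert;                    // rows [0, 512) of the slice
    char* up   = gate + size_t(kDffE) * kHidden * elem;    // rows [512, 1024)
    int64_t shape[2] = {kDffE, kHidden};
    head.mtp_layer_view.expert_w_gate[e] = Tensor(gate, QType::F16, 2, shape, true);
    head.mtp_layer_view.expert_w_up[e]   = Tensor(up,   QType::F16, 2, shape, true);
  }
  // expert_w_down (from experts_down_packed), attention projections, norms and
  // the router are plain whole-tensor views — no slicing needed.
  return true;
}
```

Both options below have to end up with something equivalent to this mapping; they differ in whether the adaptation lives in view construction (Path A) or inside new kernels (Path B).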
Two paths: + +**Path A: TransformerLayer-view adapter** +- Add `TransformerLayer mtp_layer_view_` field to `MtpHead`, populated as Tensor views of the same device buffers (no extra allocation). +- Slice `experts_gate_up_packed[256, 1024, 2048]` into 256 × `expert_w_gate[512, 2048]` + 256 × `expert_w_up[512, 2048]` views. +- Refactor `GraphExecutor::run_attention(int layer, ...)` and `run_moe_ffn(int layer, ...)` to a `_on(const TransformerLayer&, ...)` overload that takes the layer by reference. +- MTP forward calls `executor_->run_attention_on(mtp.layer_view, state, stream)` then `run_moe_ffn_on(...)`. +- Cost: medium refactor of two existing functions + their internal layer-index-driven callers. + +**Path B: From-scratch fused kernels** +- Write new attention kernel that reads from `mtp.q_proj` / `mtp.k_proj` / `mtp.v_proj` / `mtp.o_proj` directly. +- Write new 256-way MoE forward against the packed 3D `experts_gate_up_packed` / `experts_down_packed`. +- Cost: high — the 256-expert MoE forward is the bulk of multi-week budget. + +**Recommended path:** A (adapter). It reuses existing battle-tested kernels (Qwen3-Coder NVFP4 MoE 266 tok/s with graphs is the reference) and limits new code to a view-population helper. + +### Task list (Path A) + +#### Task 2.2.A.1 — Add `TransformerLayer mtp_layer_view_` to MtpHead + +**Files:** +- Modify: `src/model/mtp_head.h` + +- [ ] **Step 1: Add the field** + +```cpp +// In MtpHead struct, after final_norm: + +// TransformerLayer view: populated by populate_layer_view() so the existing +// run_attention/run_moe_ffn kernels can be invoked against MTP weights. +// All Tensor handles inside are device-side views of the flat fields above — +// no extra allocations. +TransformerLayer mtp_layer_view; +``` + +- [ ] **Step 2: Add `populate_layer_view()` declaration** + +```cpp +// Free function (not method — TransformerLayer lives elsewhere). Builds +// view Tensors that point into the flat MtpHead fields, including slicing +// the 256-expert packed tensors. Idempotent. Returns false if any expected +// flat field is unloaded. +bool mtp_populate_layer_view(MtpHead& head); +``` + +- [ ] **Step 3: Implement in mtp_head.cpp (new file)** + +(Implementation details TBD by executor — reuse the slicing pattern from `src/model/mtp_loader.cpp` in the abandoned fix/open-bugs branch, lines 226-249, or the existing slicing in `src/model/weight_upload.cu:1590-1647` for Gemma-4 fused gate_up.) + +- [ ] **Step 4: Call from safetensors_loader after MTP load succeeds** + +- [ ] **Step 5: Unit test** — populate, then verify `head.mtp_layer_view.wq.data == head.q_proj.data`, expert views point at expected offsets within `head.experts_gate_up_packed.data`. + +- [ ] **Step 6: Commit + verify-fast green** + +#### Task 2.2.A.2 — Refactor `run_attention` to take `const TransformerLayer&` + +**Files:** +- Modify: `src/graph/executor.h` — add `run_attention_on(const TransformerLayer&, const InferenceState&, cudaStream_t)` +- Modify: `src/graph/executor_attention.cu` — extract body of `run_attention(int layer, ...)` into the new overload; the old form becomes `run_attention_on(model_->layer(layer), ...)` + +- [ ] Refactor signature, run all attention tests green, verify-fast green. + +#### Task 2.2.A.3 — Refactor `run_moe_ffn` to take `const TransformerLayer&` + +Mirror of 2.2.A.2 for the MoE function. `src/graph/executor_forward_moe.cu`. + +- [ ] Refactor, all MoE tests + Qwen3-Coder NVFP4 / Gemma-4 NVFP4 / Qwen3.6-NVFP4 smoke green, verify-fast green. 
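Both refactors follow the same extract-and-shim shape. A sketch — parameter lists are abbreviated and should be taken from the current `run_attention` / `run_moe_ffn` declarations:

```cpp
// executor_attention.cu — sketch only; InferenceState& and cudaStream_t stand
// in for whatever the current signature actually carries.
void GraphExecutor::run_attention_on(const TransformerLayer& layer,
                                     InferenceState& state, cudaStream_t stream) {
  // ...former body of run_attention(int, ...), with every
  // model_->layer(idx) lookup replaced by the `layer` reference...
}

void GraphExecutor::run_attention(int layer_idx, InferenceState& state,
                                  cudaStream_t stream) {
  // Thin shim: existing call sites and graph-capture paths are untouched.
  run_attention_on(model_->layer(layer_idx), state, stream);
}
```

Keeping the int-index overload as a one-line shim is also the mitigation for the hot-path risk noted in the self-review below.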
+ +#### Task 2.2.A.4 — Wire MTP forward to use the layer view + +**Files:** +- Modify: `src/runtime/mtp_forward.cu:186-190` + +Replace the no-op Step 5 with: + +```cpp +// Step 5: full transformer block (Phase 2.2). +// fc_out is the layer input. We need an InferenceState-like context for +// the attention call — but MTP runs on a single token in a single-step +// batched form. Build a minimal state and reuse the existing executor. +{ + // Construct or borrow an InferenceState for n_tokens=1. + InferenceState mtp_state; + mtp_state.n_tokens = 1; + // ... populate position_ids, kv_block_table_, etc. from a dedicated + // MTP KV slot (allocated separately to avoid polluting the main + // model's KV cache). + executor_->run_attention_on(mtp.mtp_layer_view, mtp_state, stream); + executor_->run_moe_ffn_on(mtp.mtp_layer_view, stream); +} +``` + +This task requires designing the **MTP KV cache** — see Task 2.2.A.5. + +#### Task 2.2.A.5 — Allocate MTP KV cache slot + +**Files:** +- Modify: `src/memory/kv_cache_manager.{h,cpp}` — extend with an auxiliary single-layer pool sized for `max_seq_len × n_kv_heads × head_dim` + +Rationale: the MTP layer has its own attention. If we reuse the main model's KV cache, drafted tokens pollute the main cache and require rollback on rejection. A separate MTP KV slot keeps draft + main state cleanly isolated. + +#### Task 2.2.A.6 — Golden-value test + +**Files:** +- Create: `scripts/mtp_reference.py` — HF transformers reference that runs the Qwen3.6 MTP head against 32 random hidden states + tokens and dumps expected output tensors. +- Create: `tests/test_mtp_forward_golden.cu` — load reference dump, run imp's `mtp_draft_step`, assert cosine-sim ≥ 0.99 on output logits. + +Tests Path A is correctly mirroring the upstream model. + +#### Task 2.2.A.7 — verify-fast + Qwen3.6-NVFP4 smoke + +`make verify-fast` green; `imp-cli --model /home/kekz/models/Qwen3.6-35B-A3B-NVFP4 --mtp-spec-decode 2 --bench` produces coherent output; no main-model perf regression. + +--- + +## Phase 3.5 — Auto-invoke Draft + Verify from Decode Loop (~3-5 days) + +**Outcome:** When `--mtp-spec-decode K` is set, the decode loop automatically calls `mtp_draft_one` K times then runs a batched verify forward and accepts the longest prefix matching argmax. + +**Files:** +- Modify: `src/runtime/engine.cpp` — `step_decode_continuous()` branches when `mtp_spec_decode_enabled()` +- Possibly modify: `src/graph/executor.h` — add `forward_batch_for_verify(...)` variant that returns per-position logits +- Modify: `src/memory/kv_cache_manager.{h,cpp}` — accept-prefix-only KV commit primitive (see Task 2.2.A.5) + +### Task list + +#### Task 3.5.1 — Draft loop helper + +`std::vector Engine::mtp_draft_k(int prev_token, const void* d_h_prev)` — calls `mtp_draft_one` K times, feeding back the previous drafted token. Returns the K draft tokens. + +#### Task 3.5.2 — Batched verify forward + +The main model needs to run on `[prev_token, draft_0, ..., draft_{K-1}]` (K+1 positions) in one batched call. The current `forward_logits` does this; we just need to capture per-position logits. + +#### Task 3.5.3 — Accept-prefix selection + +For greedy: accept draft_i iff `argmax(verify_logits[i]) == draft_i`. Stop at first mismatch. Sample the verify-position-N+1 token from `verify_logits[accepted_count]` as the bonus. + +#### Task 3.5.4 — KV state machine + +Only commit KV entries for accepted positions (`prev_token` + accepted drafts). Reject KV entries for unaccepted drafts. 
Probably easiest: write all K+1 KV entries during verify forward to a scratch slot, then memcpy only the accepted prefix to the real cache. + +#### Task 3.5.5 — Sampling support (non-greedy) + +DeepSeek-V3 paper §4.5 describes the probabilistic accept rule for sampling. Initial ship targets greedy only; extend later. + +#### Task 3.5.6 — Multi-sequence support + +Continuous-batching path. Initial ship can gate spec-decode to batch=1; multi-seq is a follow-up. + +--- + +## Phase 5.5 — Acceptance-Rate Validation (~3-5 days) + +**Outcome:** Measured acceptance rate, tok/s vs baseline, coherence vs baseline. Default-on/off decision. + +### Task list + +#### Task 5.5.1 — A/B matrix harness + +`scripts/mtp_bench.sh` wrapping `imp-bench` over the matrix: +- K ∈ {1, 2, 3} +- prompts ∈ {factual, verbose-think, code, instruction} +- length ∈ {64, 256, 1024} + +Captures tok/s + accept-rate per cell. Output CSV. + +#### Task 5.5.2 — Coherence smoke + +`check-degeneration` skill against MTP-on and MTP-off; outputs must be coherent in both modes. + +#### Task 5.5.3 — Decide default + +Threshold: net throughput ≥ 1.3× baseline AND accept rate ≥ 60% on ≥ 3/4 prompt classes → default-on. Else: keep opt-in via `--mtp-spec-decode K`. + +#### Task 5.5.4 — Document findings + +`memory/mtp_phase5_validation_.md` regardless of outcome. Update `MEMORY.md` index. + +--- + +## Cross-cutting concerns + +### Performance gates +- All phases gated by `make verify-fast` (no Qwen3-4B Q8 baseline regression) +- Phase 5.5 also runs Qwen3.6-NVFP4 tg256 baseline (current 117–142 tok/s post-PR #95) + +### Memory updates +- After Phase 2.2 → `mtp_phase2_2_forward_.md`: which kernels reused, golden-value RMS, perf +- After Phase 3.5 → `mtp_phase3_5_verify_.md`: accept-rule, KV state machine +- After Phase 5.5 → `mtp_phase5_validation_.md`: A/B numbers, default-on/off decision + +### Out-of-scope +- DeepSeek V3 / V3.1 specifically (different attention architecture) +- Multimodal MTP variants (image/audio conditioning) +- Distributed/multi-GPU +- Sampling-mode accept rule (greedy first; sampling is 3.5 follow-up) + +--- + +## Self-review + +**Spec coverage:** +- ✅ Phase 2.2 full transformer block → Tasks 2.2.A.1 – 2.2.A.7 +- ✅ Phase 3.5 auto-invoke → Tasks 3.5.1 – 3.5.6 +- ✅ Phase 5.5 validation → Tasks 5.5.1 – 5.5.4 + +**Placeholder scan:** +- Task 2.2.A.1 Step 3 (`populate_layer_view` implementation) and Task 3.5.4 (KV state machine) are intentionally outlined-not-detailed because the right shapes depend on existing code surfaces that should be re-inspected when the task is picked up. Path A vs B is the load-bearing decision; once that's made, the detailed steps follow mechanically. + +**Risk markers:** +- Path A's `run_attention_on` refactor touches a hot function with many internal callers. Mitigation: keep the int-layer overload as a thin shim around the by-ref form. +- MTP KV cache (Task 2.2.A.5) adds VRAM cost: `max_seq_len × n_kv_heads × head_dim × 2 bytes`. For Qwen3.6 16K context: 16384 × 2 × 256 × 2 = 16 MiB. Negligible. +- Phase 3.5 batched verify may need graph capture work — currently graphs run on n=1 fast-path; n=K+1 may hit a different code path with different perf characteristics. 
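
---

## Appendix — Task 3.5.3 accept rule (reference sketch)

The greedy accept-prefix rule from Task 3.5.3, in code form. A minimal host-side sketch, not the implementation: float logits and the flat `(K+1) × vocab` row layout are assumptions for readability (the real path hands back `__half` logits from the batched verify forward).

```cpp
#include <cstddef>
#include <vector>

// verify_logits row i = main-model logits at the position that predicts
// drafts[i]; row `accepted` is the first disagreeing (or one-past-last) row.
struct AcceptResult { int accepted = 0; int bonus_token = -1; };

static int argmax_row(const float* row, int vocab) {
  int best = 0;
  for (int v = 1; v < vocab; ++v)
    if (row[v] > row[best]) best = v;
  return best;
}

AcceptResult accept_prefix_greedy(const std::vector<int>& drafts,
                                  const float* verify_logits, int vocab) {
  AcceptResult r;
  for (std::size_t i = 0; i < drafts.size(); ++i) {
    if (argmax_row(verify_logits + i * vocab, vocab) != drafts[i]) break;
    ++r.accepted;  // drafts[i] matched the main model's argmax → keep it
  }
  // The mismatching row still contributes its own argmax as the "bonus" token,
  // so every verify step commits at least one token even with zero accepts.
  r.bonus_token = argmax_row(verify_logits + std::size_t(r.accepted) * vocab, vocab);
  return r;
}
```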
From 0502c9e7d500a587cd8d95eb0680e5984f33abc5 Mon Sep 17 00:00:00 2001 From: Raphael Friedmann Date: Thu, 14 May 2026 16:10:50 +0200 Subject: [PATCH 2/3] =?UTF-8?q?feat(mtp):=20Phase=202.2=20MoE=20block=20?= =?UTF-8?q?=E2=80=94=20256-expert=20top-8=20+=20shared=20expert=20+=20sigm?= =?UTF-8?q?oid=20gate?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the no-op Step 5 placeholder in mtp_forward.cu:186-190 with the full MoE branch of the MTP transformer block: Step 5.B.1 post_attention_layernorm(fc_out) → d_post_norm Step 5.B.2 moe_gate_topk_fused: router @ post_norm, softmax, top-k=8 Step 5.B.3 D2H sync of routing indices+weights for host-side dispatch Step 5.B.4 Per chosen expert (k ∈ [0, 8)): gate_up = experts_gate_up_packed[idx] @ post_norm → [1024] act = silu(gate) * up → [512] down = experts_down_packed[idx] @ act → [2048] store into d_expert_outputs[k * hidden] Step 5.B.5 moe_weighted_sum_residual: fc_out += Σ w[k] * out[k] Step 5.B.6 shared expert: silu(gate_proj·x) * (up_proj·x) → down_proj scaled by sigmoid(shared_expert_gate_inp · x), added to fc_out All compute reuses existing imp primitives: - imp::rmsnorm - imp::moe_gate_topk_fused (fused gate-GEMV + softmax + top-k for M=1) - imp::gemm (M=1 GEMV for per-expert weights and shared expert projections) - imp::swiglu (silu(gate) * up) - imp::moe_weighted_sum_residual (Σ + residual) - imp::shared_expert_gate_scale (sigmoid scalar gate in-place) + one tiny new kernel: mtp_add_shared_kernel to fold shared_out into fc_out Per-expert weight handling: experts_gate_up_packed is [256, 1024, 2048] and experts_down_packed is [256, 2048, 512] FP16. For each chosen expert, we build a 2D Tensor view at the expert's slice offset (no extra copies). The 3D packed layout sticks with the shipped MtpHead design. Workspace gains MoE scratch buffers (post_norm, gate_up scratch, act, per-expert outputs, moe_out, shared_*) plus a MoeRoutingBuffers pool and pinned host buffers for the routing D2H. mtp_workspace_allocate gains n_experts / top_k / expert_d_ff / shared_d_ff params so the Engine sizes correctly. The 2-arg form is retained for back-compat. Engine threads model config (256 / 8 / 512 / 512 for Qwen3.6) into the workspace allocator. Also fixes hf_config_loader to read Qwen3.5/3.6's shared_expert_intermediate_size (previously only read DeepSeek's moe_shared_expert_intermediate_size) so expert_shared_d_ff = 512 lands on the config for Qwen3.6-NVFP4. Without this, the MTP shared expert block silently disabled itself. Attention block remains a passthrough (Step 5.A) — Qwen3.6 MTP has unusual attention shapes (q_proj [8192,2048] but o_proj input is 4096) that need upstream-reference investigation. Documented in the header. Smoke test on Qwen3.6-NVFP4 with --mtp-spec-decode 2: workspace allocates cleanly (d_ff_shared=512), main-model decode produces coherent output ("The capital of France is Paris"), verify-fast green (decode +3.23%, prefill +2.31%, graphs 1.72×). The MoE block only RUNS when mtp_draft_one is invoked, which is still manual (Phase 3.5 auto-invoke not yet wired). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/model/hf_config_loader.cpp | 15 ++ src/runtime/engine.cpp | 17 ++- src/runtime/mtp_forward.cu | 257 ++++++++++++++++++++++++++++++++- src/runtime/mtp_forward.h | 74 ++++++++-- 4 files changed, 337 insertions(+), 26 deletions(-) diff --git a/src/model/hf_config_loader.cpp b/src/model/hf_config_loader.cpp index 53801f8..1dd841d 100644 --- a/src/model/hf_config_loader.cpp +++ b/src/model/hf_config_loader.cpp @@ -388,6 +388,16 @@ bool HFConfigLoader::load_config(const std::string& model_dir, ModelConfig& cfg) IMP_LOG_INFO(" GDN config: inner=%d state=%d groups=%d n_heads=%d conv_kernel=%d", cfg.ssm_inner_size, cfg.ssm_state_size, cfg.ssm_group_count, cfg.ssm_dt_rank, cfg.ssm_conv_kernel); + + // Qwen3.5 / 3.6 MoE shared-expert intermediate size. Used by the MTP + // forward to size the shared-expert FFN scratch and by diagnostic logs + // for the main model. Key name differs from DeepSeek-style configs. + int qwen_shared_d_ff = 0; + jobj_get_int(eff, "shared_expert_intermediate_size", qwen_shared_d_ff); + if (qwen_shared_d_ff == 0) + jobj_get_int(eff, "moe_shared_expert_intermediate_size", qwen_shared_d_ff); + if (qwen_shared_d_ff > 0) + cfg.expert_shared_d_ff = qwen_shared_d_ff; } // Nemotron-H MoE: hybrid Mamba2 + MoE-Expert + Attention. Read the @@ -427,7 +437,12 @@ bool HFConfigLoader::load_config(const std::string& model_dir, ModelConfig& cfg) int n_routed = 0, n_shared = 0, shared_d_ff = 0; jobj_get_int(eff, "n_routed_experts", n_routed); jobj_get_int(eff, "n_shared_experts", n_shared); + // moe_shared_expert_intermediate_size: DeepSeek / Qwen2-MoE naming. + // shared_expert_intermediate_size: Qwen3.5 / 3.6 naming (used by their + // shared-expert variant of GroupQuery MoE). jobj_get_int(eff, "moe_shared_expert_intermediate_size", shared_d_ff); + if (shared_d_ff == 0) + jobj_get_int(eff, "shared_expert_intermediate_size", shared_d_ff); if (n_routed > 0 && cfg.n_experts == 0) cfg.n_experts = n_routed; if (n_shared > 0) diff --git a/src/runtime/engine.cpp b/src/runtime/engine.cpp index 46080b2..df2f2fa 100644 --- a/src/runtime/engine.cpp +++ b/src/runtime/engine.cpp @@ -148,18 +148,23 @@ bool Engine::enable_mtp_spec_decode(int k) { mtp_spec_k_ = k; return true; } - const int hidden_dim = model_->config_.d_model; - const int vocab_size = model_->config_.vocab_size; + const int hidden_dim = model_->config_.d_model; + const int vocab_size = model_->config_.vocab_size; + const int n_experts = model_->config_.n_experts; + const int top_k = model_->config_.n_experts_active; + const int expert_d_ff = model_->config_.expert_d_ff; + const int shared_d_ff = model_->config_.expert_shared_d_ff; auto* ws = new imp::MtpDraftWorkspace(); - if (!imp::mtp_workspace_allocate(*ws, hidden_dim, vocab_size)) { + if (!imp::mtp_workspace_allocate(*ws, hidden_dim, vocab_size, + n_experts, top_k, expert_d_ff, shared_d_ff)) { delete ws; IMP_LOG_ERROR("enable_mtp_spec_decode: workspace alloc failed"); return false; } mtp_ws_storage_ = ws; mtp_spec_k_ = k; - IMP_LOG_INFO("MTP spec-decode enabled (k=%d, hidden=%d, vocab=%d, workspace allocated)", - k, hidden_dim, vocab_size); + IMP_LOG_INFO("MTP spec-decode enabled (k=%d, hidden=%d, vocab=%d, experts=%d/top%d, d_ff_e=%d, " + "d_ff_shared=%d)", k, hidden_dim, vocab_size, n_experts, top_k, expert_d_ff, shared_d_ff); return true; } @@ -173,7 +178,7 @@ bool Engine::mtp_draft_one(int prev_token_id, const void* d_h_prev, IMP_LOG_ERROR("mtp_draft_one: MTP head not loaded"); return false; } - const auto* ws = 
static_cast(mtp_ws_storage_); + auto* ws = static_cast(mtp_ws_storage_); return imp::mtp_draft_step(prev_token_id, d_h_prev, *model_->mtp_, model_->tok_emb_, model_->out_proj_, *ws, hidden_dim, vocab_size, out_token_id, diff --git a/src/runtime/mtp_forward.cu b/src/runtime/mtp_forward.cu index c766137..c3f634a 100644 --- a/src/runtime/mtp_forward.cu +++ b/src/runtime/mtp_forward.cu @@ -13,12 +13,15 @@ // ============================================================================= #include "runtime/mtp_forward.h" -#include "compute/layernorm.h" +#include "compute/activation.h" // swiglu, shared_expert_gate_scale #include "compute/gemm.h" +#include "compute/layernorm.h" +#include "compute/moe_routing.h" #include "core/logging.h" #include #include #include +#include namespace imp { @@ -79,21 +82,74 @@ __global__ void mtp_argmax_kernel(const __half* __restrict__ logits, int vocab_s if (tid == 0) *out_idx = s_idx[0]; } +// --------------------------------------------------------------------------- +// MoE residual + shared-expert combine kernel +// --------------------------------------------------------------------------- +// fc_out[i] += moe_out[i] + shared_out[i] for i in [0, hidden_dim) +// moe_out already contains the residual-added MoE output (residual was fed +// through moe_weighted_sum_residual). We need to add shared_out which has +// already been scaled by the sigmoid gate (via shared_expert_gate_scale). +__global__ void mtp_add_shared_kernel(__half* __restrict__ fc_out, + const __half* __restrict__ shared_out, + int hidden_dim) { + int t = blockIdx.x * blockDim.x + threadIdx.x; + if (t >= hidden_dim) return; + float v = __half2float(fc_out[t]) + __half2float(shared_out[t]); + fc_out[t] = __float2half(v); +} + // --------------------------------------------------------------------------- // Workspace alloc/free // --------------------------------------------------------------------------- -bool mtp_workspace_allocate(MtpDraftWorkspace& ws, int hidden_dim, int vocab_size) { +bool mtp_workspace_allocate(MtpDraftWorkspace& ws, int hidden_dim, int vocab_size, + int n_experts, int top_k, int expert_d_ff, int shared_d_ff) { if (hidden_dim <= 0 || vocab_size <= 0) return false; auto alloc = [](void** p, size_t bytes) { return cudaMalloc(p, bytes) == cudaSuccess; }; bool ok = true; + // Phase 2.1 buffers (always allocated) ok &= alloc(&ws.d_emb_norm, hidden_dim * sizeof(__half)); ok &= alloc(&ws.d_h_norm, hidden_dim * sizeof(__half)); ok &= alloc(&ws.d_fc_in, 2 * hidden_dim * sizeof(__half)); ok &= alloc(&ws.d_fc_out, hidden_dim * sizeof(__half)); ok &= alloc(&ws.d_h_final, hidden_dim * sizeof(__half)); ok &= alloc(&ws.d_logits, vocab_size * sizeof(__half)); + + ws.hidden_dim = hidden_dim; + ws.n_experts = n_experts; + ws.top_k = top_k; + ws.expert_d_ff = expert_d_ff; + ws.shared_d_ff = shared_d_ff; + + // Phase 2.2 MoE buffers (only if n_experts > 0) + if (ok && n_experts > 0 && top_k > 0 && expert_d_ff > 0) { + ok &= alloc(&ws.d_post_norm, hidden_dim * sizeof(__half)); + ok &= alloc(&ws.d_router_logits, n_experts * sizeof(__half)); + ok &= alloc(&ws.d_expert_gate_up, 2 * expert_d_ff * sizeof(__half)); + ok &= alloc(&ws.d_expert_act, expert_d_ff * sizeof(__half)); + ok &= alloc(&ws.d_expert_outputs, top_k * hidden_dim * sizeof(__half)); + ok &= alloc(&ws.d_moe_out, hidden_dim * sizeof(__half)); + + if (shared_d_ff > 0) { + ok &= alloc(&ws.d_shared_gate, shared_d_ff * sizeof(__half)); + ok &= alloc(&ws.d_shared_up, shared_d_ff * sizeof(__half)); + ok &= alloc(&ws.d_shared_act, shared_d_ff * 
sizeof(__half)); + ok &= alloc(&ws.d_shared_out, hidden_dim * sizeof(__half)); + } + + // Routing pool (max 1 token for M=1 decode). + ws.routing_buf.allocate(/*max_tokens=*/1, /*max_experts=*/n_experts, /*top_k=*/top_k); + + // Pinned host buffers for D2H of routing decision. + if (ok) { + ok &= (cudaHostAlloc(reinterpret_cast(&ws.h_expert_indices), + top_k * sizeof(int), cudaHostAllocDefault) == cudaSuccess); + ok &= (cudaHostAlloc(reinterpret_cast(&ws.h_expert_weights), + top_k * sizeof(float), cudaHostAllocDefault) == cudaSuccess); + } + } + if (!ok) mtp_workspace_free(ws); return ok; } @@ -108,6 +164,20 @@ void mtp_workspace_free(MtpDraftWorkspace& ws) { frfn(ws.d_fc_out); frfn(ws.d_h_final); frfn(ws.d_logits); + frfn(ws.d_post_norm); + frfn(ws.d_router_logits); + frfn(ws.d_expert_gate_up); + frfn(ws.d_expert_act); + frfn(ws.d_expert_outputs); + frfn(ws.d_moe_out); + frfn(ws.d_shared_gate); + frfn(ws.d_shared_up); + frfn(ws.d_shared_act); + frfn(ws.d_shared_out); + ws.routing_buf.free(); + if (ws.h_expert_indices) { cudaFreeHost(ws.h_expert_indices); ws.h_expert_indices = nullptr; } + if (ws.h_expert_weights) { cudaFreeHost(ws.h_expert_weights); ws.h_expert_weights = nullptr; } + ws.hidden_dim = ws.n_experts = ws.top_k = ws.expert_d_ff = ws.shared_d_ff = 0; } // --------------------------------------------------------------------------- @@ -117,7 +187,7 @@ bool mtp_draft_step(int prev_token_id, const void* d_h_prev, const MtpHead& mtp, const Tensor& main_tok_emb, const Tensor& main_lm_head, - const MtpDraftWorkspace& ws, + MtpDraftWorkspace& ws, int hidden_dim, int vocab_size, int* out_token_id, cudaStream_t stream) { @@ -183,11 +253,182 @@ bool mtp_draft_step(int prev_token_id, const void* d_h_prev, imp::gemm(fc_in_view, mtp.fc, fc_out_view, 1.0f, 0.0f, stream); } - // Step 5 (PHASE 2.2 PLACEHOLDER): transformer block skipped. - // Real impl: input_layernorm → self_attn (q/k/v_proj + GQA + o_proj) + - // residual → post_attention_layernorm → MoE (router + 256 experts + - // shared expert) + residual. - // For Phase 2.1: passthrough — copy d_fc_out into d_fc_out (no-op). + // Step 5: transformer block. + // + // 5.A — Attention (PHASE 2.2.Attn deferred): the MTP layer's q_proj + // outputs 8192 ≠ standard GQA shape for o_proj's 4096 input on + // Qwen3.6, suggesting a non-trivial Q-projection split that needs + // upstream-reference investigation. For now, attention is a pass- + // through (input_layernorm → no-op → o_proj-skipped → residual=identity). + // + // 5.B — MoE (Phase 2.2.MoE, this commit): full 256-expert top-8 MoE + // forward + shared-expert with sigmoid gating. Uses the existing + // imp::moe_gate_topk_fused / swiglu / shared_expert_gate_scale + // primitives. + if (ws.n_experts > 0 && ws.top_k > 0 && ws.expert_d_ff > 0 && + mtp.router.data != nullptr) { + const int hd = hidden_dim; + const int d_ff_e = ws.expert_d_ff; + const int d_ff_s = ws.shared_d_ff; + const int top_k = ws.top_k; + const int ne = ws.n_experts; + + // 5.B.1 — post_attention_layernorm(fc_out) → d_post_norm + { + int64_t hd1[1] = {hd}; + Tensor fc_out_view (ws.d_fc_out, QType::F16, 1, hd1, true); + Tensor pn_view (ws.d_post_norm,QType::F16, 1, hd1, true); + imp::rmsnorm(fc_out_view, mtp.post_attention_layernorm, pn_view, 1e-6f, stream); + } + + // 5.B.2 — Router + top-k. moe_gate_topk_fused: router @ post_norm, + // softmax, top-k. Writes into ws.routing_buf. 
+ MoeRoutingResult routing{}; + imp::moe_gate_topk_fused(mtp.router.data, ws.d_post_norm, ne, hd, top_k, + ws.routing_buf, routing, stream, + /*use_sigmoid=*/false, /*normalize_weights=*/true); + + // 5.B.3 — D2H copy of expert indices + weights so the host loop can + // dispatch per-expert GEMVs. This is non-graph-safe but + // drafts run outside graph capture for now. + cudaMemcpyAsync(ws.h_expert_indices, ws.routing_buf.expert_indices, + top_k * sizeof(int), cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(ws.h_expert_weights, ws.routing_buf.expert_weights, + top_k * sizeof(float), cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); + + // 5.B.4 — For each chosen expert: GEMV gate_up_packed[e] @ post_norm, + // swiglu, GEMV down_packed[e] @ act, store into d_expert_outputs[k]. + // + // Layout of packed tensors: + // experts_gate_up_packed shape: [ne, 2*d_ff_e, hd] FP16 + // experts_down_packed shape: [ne, hd, d_ff_e] FP16 + const size_t gu_per_expert_bytes = static_cast(2) * d_ff_e * hd * sizeof(__half); + const size_t dn_per_expert_bytes = static_cast(hd) * d_ff_e * sizeof(__half); + + int64_t gu_shape[2] = {2 * d_ff_e, hd}; + int64_t dn_shape[2] = {hd, d_ff_e}; + + for (int k = 0; k < top_k; ++k) { + int e_idx = ws.h_expert_indices[k]; + if (e_idx < 0 || e_idx >= ne) { + IMP_LOG_WARN("mtp MoE: invalid expert index %d (top_k=%d)", e_idx, k); + continue; + } + + // Build view tensors into the packed buffers for this expert. + char* gu_base = static_cast(mtp.experts_gate_up_packed.data) + + static_cast(e_idx) * gu_per_expert_bytes; + char* dn_base = static_cast(mtp.experts_down_packed.data) + + static_cast(e_idx) * dn_per_expert_bytes; + Tensor gu_view(gu_base, QType::F16, 2, gu_shape, true); + Tensor dn_view(dn_base, QType::F16, 2, dn_shape, true); + + // gate_up = gu_view @ post_norm + { + int64_t in_shape[2] = {1, hd}; + int64_t out_shape[2] = {1, 2 * d_ff_e}; + Tensor in_view (ws.d_post_norm, QType::F16, 2, in_shape, true); + Tensor out_view(ws.d_expert_gate_up, QType::F16, 2, out_shape, true); + imp::gemm(in_view, gu_view, out_view, 1.0f, 0.0f, stream); + } + // swiglu: gate = first half, up = second half → act = silu(gate)*up + { + int64_t half_shape[2] = {1, d_ff_e}; + Tensor gate_view(ws.d_expert_gate_up, + QType::F16, 2, half_shape, true); + Tensor up_view( static_cast(ws.d_expert_gate_up) + d_ff_e * sizeof(__half), + QType::F16, 2, half_shape, true); + Tensor act_view(ws.d_expert_act, QType::F16, 2, half_shape, true); + imp::swiglu(gate_view, up_view, act_view, stream); + } + // down = dn_view @ act → write directly into d_expert_outputs[k * hd] + { + int64_t in_shape[2] = {1, d_ff_e}; + int64_t out_shape[2] = {1, hd}; + Tensor in_view (ws.d_expert_act, QType::F16, 2, in_shape, true); + __half* out_base = static_cast<__half*>(ws.d_expert_outputs) + k * hd; + Tensor out_view(out_base, QType::F16, 2, out_shape, true); + imp::gemm(in_view, dn_view, out_view, 1.0f, 0.0f, stream); + } + } + + // 5.B.5 — Weighted sum + residual: moe_out = fc_out + Σ_k w[k]*expert_outputs[k] + imp::moe_weighted_sum_residual( + /*expert_outputs=*/ws.d_expert_outputs, + /*expert_weights=*/ws.routing_buf.expert_weights, + /*residual=*/ ws.d_fc_out, + /*output=*/ ws.d_moe_out, + /*d_model=*/ hd, + /*top_k=*/ top_k, + stream); + + // 5.B.6 — Shared expert (if present): silu(gate_proj·x) * (up_proj·x), + // scale by sigmoid(shared_expert_gate_inp · x), then add to + // moe_out (which already includes the attention residual). 
+ if (d_ff_s > 0 && mtp.shared_expert_gate_proj.data && mtp.shared_expert_up_proj.data && + mtp.shared_expert_down_proj.data && mtp.shared_expert_gate.data) { + // shared_gate = shared_expert_gate_proj @ post_norm → [d_ff_s] + { + int64_t in_shape[2] = {1, hd}; + int64_t out_shape[2] = {1, d_ff_s}; + Tensor in_view (ws.d_post_norm, QType::F16, 2, in_shape, true); + Tensor out_view(ws.d_shared_gate, QType::F16, 2, out_shape, true); + imp::gemm(in_view, mtp.shared_expert_gate_proj, out_view, 1.0f, 0.0f, stream); + } + // shared_up = shared_expert_up_proj @ post_norm → [d_ff_s] + { + int64_t in_shape[2] = {1, hd}; + int64_t out_shape[2] = {1, d_ff_s}; + Tensor in_view (ws.d_post_norm, QType::F16, 2, in_shape, true); + Tensor out_view(ws.d_shared_up, QType::F16, 2, out_shape, true); + imp::gemm(in_view, mtp.shared_expert_up_proj, out_view, 1.0f, 0.0f, stream); + } + // shared_act = silu(shared_gate) * shared_up + { + int64_t s_shape[2] = {1, d_ff_s}; + Tensor gate_view(ws.d_shared_gate, QType::F16, 2, s_shape, true); + Tensor up_view (ws.d_shared_up, QType::F16, 2, s_shape, true); + Tensor act_view (ws.d_shared_act, QType::F16, 2, s_shape, true); + imp::swiglu(gate_view, up_view, act_view, stream); + } + // shared_out = shared_expert_down_proj @ shared_act → [hd] + { + int64_t in_shape[2] = {1, d_ff_s}; + int64_t out_shape[2] = {1, hd}; + Tensor in_view (ws.d_shared_act, QType::F16, 2, in_shape, true); + Tensor out_view(ws.d_shared_out, QType::F16, 2, out_shape, true); + imp::gemm(in_view, mtp.shared_expert_down_proj, out_view, 1.0f, 0.0f, stream); + } + // Apply sigmoid(shared_expert_gate · post_norm) scalar to shared_out + // in-place via the existing fused kernel. + imp::shared_expert_gate_scale( + /*x=*/ ws.d_post_norm, + /*W=*/ mtp.shared_expert_gate.data, + /*y_inout=*/ ws.d_shared_out, + /*n=*/ 1, + /*d_model=*/ hd, + /*d=*/ hd, + stream); + + // moe_out += shared_out → write back into d_fc_out for downstream + { + int block = 256; + int grid = (hd + block - 1) / block; + mtp_add_shared_kernel<<>>( + static_cast<__half*>(ws.d_moe_out), + static_cast(ws.d_shared_out), + hd); + } + } + // Copy d_moe_out → d_fc_out (overwrite) so downstream RMSNorm reads the + // post-transformer hidden state. moe_weighted_sum_residual already + // added fc_out as residual into d_moe_out; the shared-expert addition + // above (if present) updates d_moe_out in place. + cudaMemcpyAsync(ws.d_fc_out, ws.d_moe_out, hd * sizeof(__half), + cudaMemcpyDeviceToDevice, stream); + } + // else: legacy reduced forward (Phase 2.1 behavior) — d_fc_out unchanged. // Step 6: h_final = RMSNorm(fc_out, final_norm) { diff --git a/src/runtime/mtp_forward.h b/src/runtime/mtp_forward.h index e349cf8..f31599b 100644 --- a/src/runtime/mtp_forward.h +++ b/src/runtime/mtp_forward.h @@ -1,20 +1,26 @@ #pragma once // ============================================================================= -// mtp_forward.h — Multi-Token-Predictor draft step (Phase 2 scaffolding) +// mtp_forward.h — Multi-Token-Predictor draft step // ============================================================================= // -// One draft-token forward pass through the MTP head. Phase 2 status: -// - Phase 2.1: scaffolding + reduced forward (emb → pre_fc_norm → fc → -// final_norm → lm_head → argmax). Skips the MTP transformer -// block (attention + 256-expert MoE) — that's the bulk of the -// work and is genuinely multi-week to implement from scratch. -// The reduced path produces draft tokens but acceptance rate -// will be far below trained-MTP optimum. 
-// - Phase 2.2: full transformer block (attention + MoE). Future session. +// One draft-token forward pass through the MTP head. +// +// Phase status: +// - 2.1 (shipped PR #172): reduced forward (emb → pre_fc_norm → fc → +// final_norm → lm_head → argmax). Skips the transformer block. +// Acceptance rate will be far below trained-MTP optimum. +// - 2.2.MoE (this file): MoE block plumbed via existing imp::gemm / +// swiglu / moe_gate_topk_fused / shared_expert_gate_scale +// primitives. Attention block still a passthrough (architectural +// shape ambiguity — q_proj outputs 8192 but o_proj inputs 4096 +// on Qwen3.6 MTP, doesn't match standard GQA conventions; needs +// upstream-reference investigation before correct implementation). +// - 2.2.Attn (future): full attention block. // // Design: docs/superpowers/specs/2026-05-14-mtp-wiring-design.md // ============================================================================= +#include "compute/moe_routing.h" // MoeRoutingBuffers #include "core/tensor.h" #include "model/mtp_head.h" #include @@ -26,6 +32,7 @@ class Model; // Workspace tensors needed for one draft step. Caller pre-allocates these so // the draft step is graph-safe (no cudaMalloc inside captured graph). struct MtpDraftWorkspace { + // ---- Phase 2.1 reduced-forward scratch ---- // [hidden_dim] FP16 — normalized embedding input void* d_emb_norm = nullptr; // [hidden_dim] FP16 — normalized hidden-state input @@ -38,6 +45,44 @@ struct MtpDraftWorkspace { void* d_h_final = nullptr; // [vocab_size] FP16 — draft logits void* d_logits = nullptr; + + // ---- Phase 2.2 MoE scratch ---- + // [hidden_dim] FP16 — post_attention_layernorm(fc_out) + void* d_post_norm = nullptr; + // [n_experts] FP16 — router logits (currently unused — moe_gate_topk_fused + // produces indices+weights directly into the routing buffers). + void* d_router_logits = nullptr; + // [2*expert_d_ff] FP16 — single-expert gate_up output (gate at 0..d_ff, + // up at d_ff..2*d_ff) + void* d_expert_gate_up = nullptr; + // [expert_d_ff] FP16 — silu(gate)*up per chosen expert + void* d_expert_act = nullptr; + // [top_k * hidden_dim] FP16 — per-chosen-expert down outputs, contiguous + // along the top_k axis; consumed by moe_weighted_sum_residual. + void* d_expert_outputs = nullptr; + // [hidden_dim] FP16 — accumulator (moe weighted-sum + residual via + // moe_weighted_sum_residual). + void* d_moe_out = nullptr; + // Shared expert scratch + void* d_shared_gate = nullptr; // [shared_d_ff] FP16 + void* d_shared_up = nullptr; // [shared_d_ff] FP16 + void* d_shared_act = nullptr; // [shared_d_ff] FP16 (silu(gate)*up) + void* d_shared_out = nullptr; // [hidden_dim] FP16 (shared_down_proj @ act) + + // Routing buffer pool (n_experts, top_k both known at enable time) + MoeRoutingBuffers routing_buf; + // Per-step host-side copies of indices/weights for the M=1 host-side + // per-expert GEMV loop. Allocated as cudaHostAlloc'd for pinned D2H. + int* h_expert_indices = nullptr; // [top_k] + float* h_expert_weights = nullptr; // [top_k] + + // Hyperparameters captured at workspace-allocate time so the draft step + // doesn't need to re-derive them from the model. + int hidden_dim = 0; + int n_experts = 0; + int top_k = 0; + int expert_d_ff = 0; + int shared_d_ff = 0; }; // One MTP draft step. Returns the draft token id via host out_token_id. 
@@ -61,15 +106,20 @@ bool mtp_draft_step(int prev_token_id, const void* d_h_prev, const MtpHead& mtp, const Tensor& main_tok_emb, const Tensor& main_lm_head, - const MtpDraftWorkspace& ws, + MtpDraftWorkspace& ws, int hidden_dim, int vocab_size, int* out_token_id, cudaStream_t stream); // Allocate the workspace from the VRAM allocator. Caller is responsible for // keeping ws alive (typically owned by the Engine for the lifetime of a session). -// Phase 4 wires this into the engine; for now, callers manage manually. -bool mtp_workspace_allocate(MtpDraftWorkspace& ws, int hidden_dim, int vocab_size); +// The MoE-related buffers (post_norm, expert outputs, shared expert scratch, +// routing pool) are sized from `n_experts`, `top_k`, `expert_d_ff`, +// `shared_d_ff`. Pass 0 for any of those to disable the MoE block at runtime +// (back-compat — Phase 2.1 callers can keep using the 2-arg form below). +bool mtp_workspace_allocate(MtpDraftWorkspace& ws, int hidden_dim, int vocab_size, + int n_experts = 0, int top_k = 0, + int expert_d_ff = 0, int shared_d_ff = 0); void mtp_workspace_free(MtpDraftWorkspace& ws); } // namespace imp From afaf701cb51d2383f1dd3f0de44bcee0790316e6 Mon Sep 17 00:00:00 2001 From: Raphael Friedmann Date: Thu, 14 May 2026 16:18:27 +0200 Subject: [PATCH 3/3] test(mtp): integration test for Phase 2.2 MoE block MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit MtpForwardTest.DraftStepProducesValidToken: - Loads Qwen3.6-NVFP4 + MTP sidecar end-to-end - Allocates MTP workspace with full MoE config (256 experts / top-8 / expert_d_ff=512 / shared_d_ff=512) - Calls mtp_draft_step with a random FP16 hidden state + arbitrary token id - Asserts out_token_id ∈ [0, vocab_size) PASSES on RTX 5090 (14.4s including 1.57 GiB MTP upload), exercising: - router GEMV + top-8 selection - per-expert gate_up + swiglu + down (8 experts dispatched) - moe_weighted_sum_residual - shared expert gate_proj/up_proj/down_proj - sigmoid scalar gate This is the first test that actually invokes the MoE block; existing E2E paths don't auto-call mtp_draft_one (Phase 3.5 deferred). Co-Authored-By: Claude Opus 4.7 (1M context) --- CMakeLists.txt | 1 + tests/test_mtp_forward.cpp | 120 +++++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+) create mode 100644 tests/test_mtp_forward.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 0dc148b..1f53087 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -488,6 +488,7 @@ if(IMP_BUILD_TESTS) tests/test_weight_registry_preservation.cpp tests/test_e2e_llm_compressor.cpp tests/test_chunked_prefill.cu + tests/test_mtp_forward.cpp ) set(IMP_TEST_BINARIES diff --git a/tests/test_mtp_forward.cpp b/tests/test_mtp_forward.cpp new file mode 100644 index 0000000..bc7f554 --- /dev/null +++ b/tests/test_mtp_forward.cpp @@ -0,0 +1,120 @@ +// ============================================================================= +// test_mtp_forward.cpp — Phase 2.2 MoE-block integration test +// ============================================================================= +// +// Exercises mtp_draft_step() end-to-end against the real Qwen3.6-NVFP4 model: +// 1. Load main model + MTP head (BF16 → FP16, 1.57 GiB) +// 2. Allocate MTP workspace with full MoE config (256/top-8/512/512) +// 3. Run mtp_draft_step with a synthetic d_h_prev + token id +// 4. Assert it does NOT crash and out_token_id is in [0, vocab_size) +// +// GTEST_SKIPs if the model is absent so CI on bare hosts still passes. 
+// +// Spec: docs/superpowers/specs/2026-05-14-mtp-wiring-design.md +// ============================================================================= + +#include "model/model.h" +#include "model/safetensors_loader.h" +#include "runtime/engine.h" +#include "runtime/mtp_forward.h" + +#include +#include + +#include +#include +#include + +namespace fs = std::filesystem; + +namespace { + +constexpr const char kQwen36ModelDir[] = + "/home/kekz/models/Qwen3.6-35B-A3B-NVFP4"; + +bool model_available() { + return fs::exists(std::string(kQwen36ModelDir) + "/model_mtp.safetensors"); +} + +} // namespace + +TEST(MtpForwardTest, DraftStepProducesValidToken) { + if (!model_available()) { + GTEST_SKIP() << "Qwen3.6-NVFP4 with MTP not present at " << kQwen36ModelDir; + } + + // Load model + upload weights with MTP enabled. + auto model = imp::load_safetensors(kQwen36ModelDir); + ASSERT_NE(model, nullptr); + ASSERT_TRUE(model->mtp_.has_value()); + + // upload_weights_gpu automatically uploads the MTP sidecar when + // model->mtp_->loaded is set (the safetensors loader sets it after + // parsing the sidecar tensors). See weight_upload.cu:2027 for the gate. + ASSERT_TRUE(model->upload_weights_gpu(imp::QType::F16, nullptr, 1ULL << 30)); + ASSERT_TRUE(model->mtp_->loaded); + + // Build a synthetic d_h_prev (random FP16). The MTP forward should still + // produce a valid token id even with non-realistic hidden state — we just + // want to confirm the MoE block runs without crashing or producing OOB + // tokens. + const int hidden_dim = model->config_.d_model; + const int vocab_size = model->config_.vocab_size; + const int n_experts = model->config_.n_experts; + const int top_k = model->config_.n_experts_active; + const int expert_d_ff = model->config_.expert_d_ff; + const int shared_d_ff = model->config_.expert_shared_d_ff; + + ASSERT_EQ(n_experts, 256); + ASSERT_EQ(top_k, 8); + ASSERT_EQ(expert_d_ff, 512); + ASSERT_EQ(shared_d_ff, 512); + + imp::MtpDraftWorkspace ws{}; + ASSERT_TRUE(imp::mtp_workspace_allocate(ws, hidden_dim, vocab_size, + n_experts, top_k, expert_d_ff, shared_d_ff)); + + // Build a host-side random FP16 hidden state, upload. + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.05f, 0.05f); + std::vector h_state(hidden_dim); + auto float_to_fp16 = [](float v) -> uint16_t { + // Quick-and-dirty: use cuda's __float2half via a device pass. For + // simplicity, encode FP16 by hand for values near zero. We only need + // values in a sane range, exact bits don't matter. + uint32_t bits; + std::memcpy(&bits, &v, 4); + uint16_t s = (bits >> 31) & 1; + int e = ((bits >> 23) & 0xFF) - 127; + uint32_t m = bits & 0x7FFFFF; + if (e > 15) return (s << 15) | 0x7C00; + if (e < -14) return (s << 15); + return (s << 15) | ((e + 15) << 10) | (m >> 13); + }; + for (int i = 0; i < hidden_dim; ++i) { + h_state[i] = float_to_fp16(dist(rng)); + } + void* d_h_prev = nullptr; + ASSERT_EQ(cudaMalloc(&d_h_prev, hidden_dim * sizeof(uint16_t)), cudaSuccess); + ASSERT_EQ(cudaMemcpy(d_h_prev, h_state.data(), hidden_dim * sizeof(uint16_t), + cudaMemcpyHostToDevice), cudaSuccess); + + int out_token_id = -1; + bool ok = imp::mtp_draft_step( + /*prev_token_id=*/123, // arbitrary valid token + d_h_prev, + *model->mtp_, + model->tok_emb_, + model->out_proj_, + ws, + hidden_dim, vocab_size, + &out_token_id, + /*stream=*/nullptr); + + EXPECT_TRUE(ok) << "mtp_draft_step returned false"; + EXPECT_GE(out_token_id, 0); + EXPECT_LT(out_token_id, vocab_size); + + cudaFree(d_h_prev); + imp::mtp_workspace_free(ws); +}