From e9b17cb13bd0760f85715e7bcda49aa095f3e619 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Javier=20Paz=C3=B3?= Date: Mon, 11 May 2026 15:27:30 +0200 Subject: [PATCH 1/4] feat(dflash): native Qwen3.6 MTP (NextN) runtime + contract test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the runtime side of native multi-token prediction so dflash can load Qwen3.6-MTP GGUFs (am17an-style, llama.cpp PR #22673 tensor convention) and run a target-trunk + NextN-block forward in the same ggml graph. Library - src/f16_convert.cu — small bf16/f16 → f32 widen kernels (used by the MTP token-embedding widen and shared with the rollback path). - src/internal.h — new types: TargetNextN, TargetMtpLayer, TargetMtpCache, QwenMtpGraphInputs/Outputs; QwenGraphInputs gains expose_pre_norm_hidden, QwenGraphOutputs gains pre_norm_hidden; TargetWeights gains mtp_layers, nextn_predict_layers, gguf_block_count, tok_embd_gpu. No fields removed from the existing trunk API. - src/qwen35_target_graph.cpp — create/free/reset_target_mtp_cache and build_qwen35_mtp_graph (RMSNorm e || RMSNorm h → eh_proj → full-attn block → SwiGLU FFN → shared head, falling back to trunk out_norm/output when nextn.shared_head_* are absent). Wires expose_pre_norm_hidden in build_qwen35_graph. - src/gguf_target_loader.cpp — reads qwen35.nextn_predict_layers, splits block_count into trunk + tail, loads blk..nextn.* tensors into TargetWeights::mtp_layers, and uploads token_embd.weight to the GPU for MTP checkpoints so MTP can chain proposals device-side (DFLASH27B_UPLOAD_TOK_EMBD overrides). Tests - test/test_mtp_graph_contract.cpp — synthetic-tensor contract test that asserts build_qwen35_mtp_graph wires together correctly. No GPU model needed (~49 nodes, runs in milliseconds). Suitable for CI. - test/smoke_mtp_graph.cpp — loads a real MTP GGUF, builds the NextN graph for one token, asserts the output is finite. 
- test/smoke_target_mtp_handoff.cpp — loads a real MTP GGUF and runs target + MTP in the SAME ggml_cgraph, proving the pre_norm_hidden handoff lives entirely on-device. - test/smoke_mtp_integrated_decode.cpp — minimal greedy decode loop: target greedy + MTP greedy, accept/correct counters, tok/s summary. Functional baseline for the upcoming speculative loop. CMake - f16_convert.cu added to the dflash27b library sources. - Four new test targets registered (test_mtp_graph_contract + three smokes). Linked against dflash27b + ggml + ggml-cuda + CUDA::cudart. Validation on RTX 6000 Ada (sm_89), Qwen3.6-27B-MTP Q4_K_M: test_mtp_graph_contract → PASS (49 graph nodes, shapes correct) smoke_mtp_graph → PASS (logits 0 NaN, 0 Inf, [-24.3, 14.4]) smoke_target_mtp_handoff → PASS (3061 nodes, both heads clean) smoke_mtp_integrated_decode (8 tokens) → PASS, 50% greedy acceptance, 23.6 tok/s Honest scope - MoE MTP is not supported in this PR (build_qwen35_mtp_graph fails fast with a clear message). The 35B-A3B MTP GGUFs need the MoE TargetLayer fields that howard0su is upstreaming in #120. A MoE-aware MTP graph is a one-line dispatch on top of this PR once #120 merges. - This PR ships the runtime contract only. The integrated speculative decode loop (chain-2 / tree-fused / immediate-bonus), the daemon-side --mtp-integrated wiring, and the mtp_baseline_gate.py parity harness land in a follow-up PR. Measured locally on the same MTP GGUF with MTP disabled vs enabled (chain-2, n_gen=256) we see +36% tok/s today — but the speculative loop driving that number is not in this PR yet. Compatible GGUFs include am17an/Qwen3.6-27B-MTP-GGUF and the havenoammo / froggeric Unsloth UD repacks; tensor naming follows llama.cpp #22673. 
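Reviewer note on the trunk/tail split: the loader's arithmetic (block_count split into trunk layers plus a NextN tail, validated in the order the hunk below checks it) can be sketched as follows. This is an illustrative sketch and not code from the patch. BlockSplit and split_blocks are hypothetical names, and the numbers in the usage note are made-up examples, not values read from a real GGUF.

```cpp
#include <cstdint>
#include <string>

// Hypothetical illustration only: mirrors the validation order used by the
// loader hunk in this patch, but is not part of the patch itself.
struct BlockSplit {
    uint32_t trunk_layers = 0;     // block_count - nextn_predict_layers
    uint32_t mtp_layers = 0;       // nextn_predict_layers (0 or 1 today)
    uint32_t first_mtp_block = 0;  // GGUF block index of the first MTP block
};

// Returns true and fills `out` when the split is valid; otherwise returns
// false and writes a short reason into `err`.
inline bool split_blocks(uint32_t block_count,
                         uint32_t nextn_predict_layers,
                         uint32_t full_attention_interval,
                         BlockSplit & out,
                         std::string & err) {
    if (block_count == 0 || full_attention_interval == 0) {
        err = "missing or zero hparams";
        return false;
    }
    if (nextn_predict_layers > block_count) {
        err = "nextn_predict_layers exceeds block_count";
        return false;
    }
    if (nextn_predict_layers > 1) {
        err = "nextn_predict_layers > 1 not supported yet";
        return false;
    }
    const uint32_t trunk = block_count - nextn_predict_layers;
    if (trunk == 0) {
        err = "no trunk layers left after subtracting nextn_predict_layers";
        return false;
    }
    if (trunk % full_attention_interval != 0) {
        err = "trunk layer count not divisible by full_attention_interval";
        return false;
    }
    out.trunk_layers = trunk;
    out.mtp_layers = nextn_predict_layers;
    out.first_mtp_block = trunk;  // MTP layer mi lives at block (trunk + mi)
    return true;
}
```

With an assumed block_count of 65, one NextN layer and an assumed full_attention_interval of 4, the sketch reports 64 trunk layers and places the MTP block at GGUF index 64. The divisibility check applies to the trunk count, not to the raw block_count.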
--- dflash/CMakeLists.txt | 29 +++ dflash/docs/MTP_2026-05-11.md | 111 ++++++++++ dflash/src/f16_convert.cu | 49 +++++ dflash/src/gguf_target_loader.cpp | 152 +++++++++++-- dflash/src/internal.h | 106 ++++++++- dflash/src/qwen35_target_graph.cpp | 204 +++++++++++++++++- dflash/test/smoke_mtp_graph.cpp | 173 +++++++++++++++ dflash/test/smoke_mtp_integrated_decode.cpp | 225 ++++++++++++++++++++ dflash/test/smoke_target_mtp_handoff.cpp | 199 +++++++++++++++++ dflash/test/test_mtp_graph_contract.cpp | 108 ++++++++++ 10 files changed, 1338 insertions(+), 18 deletions(-) create mode 100644 dflash/docs/MTP_2026-05-11.md create mode 100644 dflash/src/f16_convert.cu create mode 100644 dflash/test/smoke_mtp_graph.cpp create mode 100644 dflash/test/smoke_mtp_integrated_decode.cpp create mode 100644 dflash/test/smoke_target_mtp_handoff.cpp create mode 100644 dflash/test/test_mtp_graph_contract.cpp diff --git a/dflash/CMakeLists.txt b/dflash/CMakeLists.txt index a4bb575ff..66f27e9e2 100644 --- a/dflash/CMakeLists.txt +++ b/dflash/CMakeLists.txt @@ -126,6 +126,8 @@ add_library(dflash27b STATIC src/laguna_target_graph.cpp src/laguna_daemon.cpp src/sampler.cpp + # Native MTP / NextN helpers + src/f16_convert.cu ) # FlashPrefill custom CUDA kernels need BF16 WMMA (sm_80+). On Turing (sm_75) # the drafter uses ggml's flash_attn_ext instead. Guard added after SM check. @@ -334,6 +336,29 @@ if(DFLASH27B_TESTS) endif() endif() + # Native MTP / NextN: contract test + functional smokes. The contract test + # uses synthetic tensors and runs in CI; the smokes need a real MTP GGUF. 
+ if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_mtp_graph_contract.cpp") + add_executable(test_mtp_graph_contract test/test_mtp_graph_contract.cpp) + target_include_directories(test_mtp_graph_contract PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) + target_link_libraries(test_mtp_graph_contract PRIVATE dflash27b ggml ggml-cuda) + endif() + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/smoke_mtp_graph.cpp") + add_executable(smoke_mtp_graph test/smoke_mtp_graph.cpp) + target_include_directories(smoke_mtp_graph PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) + target_link_libraries(smoke_mtp_graph PRIVATE dflash27b ggml ggml-cuda) + endif() + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/smoke_target_mtp_handoff.cpp") + add_executable(smoke_target_mtp_handoff test/smoke_target_mtp_handoff.cpp) + target_include_directories(smoke_target_mtp_handoff PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) + target_link_libraries(smoke_target_mtp_handoff PRIVATE dflash27b ggml ggml-cuda) + endif() + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/smoke_mtp_integrated_decode.cpp") + add_executable(smoke_mtp_integrated_decode test/smoke_mtp_integrated_decode.cpp) + target_include_directories(smoke_mtp_integrated_decode PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) + target_link_libraries(smoke_mtp_integrated_decode PRIVATE dflash27b ggml ggml-cuda) + endif() + # internal.h includes when GGML_USE_CUDA is set; link # CUDA::cudart so the toolkit headers are on the compile line (same as # test_dflash historically had alone). 
@@ -350,6 +375,10 @@ if(DFLASH27B_TESTS) smoke_target_forward test_generate test_dflash + test_mtp_graph_contract + smoke_mtp_graph + smoke_target_mtp_handoff + smoke_mtp_integrated_decode ) foreach(_t IN LISTS _dflash_internal_h_cuda_tests) if(TARGET ${_t}) diff --git a/dflash/docs/MTP_2026-05-11.md b/dflash/docs/MTP_2026-05-11.md new file mode 100644 index 000000000..04b303165 --- /dev/null +++ b/dflash/docs/MTP_2026-05-11.md @@ -0,0 +1,111 @@ +# Native MTP (NextN) — runtime status, 2026-05-11 + +This document describes the native multi-token prediction (MTP / NextN) +runtime introduced into `dflash` in PR `feat(dflash): native Qwen3.6 MTP +integrated decode`. It tracks the contract, what already works, and what +remains for the next PR before MTP becomes a default-on decode mode. + +## What this PR ships + +- `dflash/src/f16_convert.cu` — small `f16/bf16 → f32` widen kernels used + by both the rollback path and the MTP token-embedding widen. +- `dflash/src/internal.h` — new types: + - `TargetNextN`, `TargetMtpLayer` + - `TargetMtpCache` (KV cache for the NextN tail block only) + - `QwenMtpGraphInputs`, `QwenMtpGraphOutputs` + - `expose_pre_norm_hidden` on `QwenGraphInputs` + - `pre_norm_hidden` on `QwenGraphOutputs` + - `TargetWeights::mtp_layers`, `nextn_predict_layers`, `gguf_block_count`, + `tok_embd_gpu` (no fields removed; the trunk API is preserved). +- `dflash/src/qwen35_target_graph.cpp` — four new functions: + - `create_target_mtp_cache` / `free_target_mtp_cache` / `reset_target_mtp_cache` + - `build_qwen35_mtp_graph` — RMSNorm(e) || RMSNorm(h) → `eh_proj` → + full-attention transformer block → SwiGLU FFN → shared head. + Also wires `expose_pre_norm_hidden` into `build_qwen35_graph`. 
+- `dflash/src/gguf_target_loader.cpp` — reads `qwen35.nextn_predict_layers`, + splits the GGUF blocks into trunk + MTP tail, loads `blk..nextn.*` + tensors into `TargetWeights::mtp_layers`, and uploads `token_embd.weight` + to the GPU when the checkpoint carries MTP (`DFLASH27B_UPLOAD_TOK_EMBD` + env var overrides). +- `dflash/test/test_mtp_graph_contract.cpp` — synthetic-tensor test that + asserts the MTP graph wires together correctly. No GPU model needed; + cheap to run in CI. +- `dflash/test/smoke_mtp_graph.cpp` — loads a real MTP GGUF, builds the + NextN graph for a single token, and validates the output is finite. +- `dflash/test/smoke_target_mtp_handoff.cpp` — loads a real MTP GGUF and + proves that the trunk pre-norm hidden tensor feeds directly into the + MTP block within the same `ggml_cgraph` (no CPU roundtrip required). +- `dflash/test/smoke_mtp_integrated_decode.cpp` — full integrated decode + loop: target greedy + MTP greedy in one graph, with per-step accept / + correct counters. This is the functional baseline the upcoming PR's + speculative loop will be built on top of. + +## GGUF compatibility + +The loader follows the tensor naming convention introduced by llama.cpp's +[MTP PR #22673](https://github.com/ggml-org/llama.cpp/pull/22673). 
It is +compatible with the reference Qwen3.6-MTP GGUFs published on the Hub: + +- `am17an/Qwen3.6-27B-MTP-GGUF` +- `am17an/Qwen3.6-35BA3B-MTP-GGUF` (MoE — see "MoE limitation" below) +- `havenoammo/Qwen3.6-27B-MTP-UD-GGUF` +- `havenoammo/Qwen3.6-35B-A3B-MTP-GGUF` +- `froggeric/Qwen3.6-27B-MTP-GGUF` + +The expected tail-block tensor names are: + +```text +blk..nextn.eh_proj.weight [2 * hidden, hidden] +blk..nextn.embed_tokens.weight [hidden, vocab] (optional) +blk..nextn.enorm.weight [hidden] +blk..nextn.hnorm.weight [hidden] +blk..nextn.shared_head_head.weight [hidden, vocab] (optional) +blk..nextn.shared_head_norm.weight [hidden] (optional) +``` + +When the optional shared-head tensors are absent the runtime falls back to +the trunk's `output_norm` / `output` (lm_head), matching how am17an's +GGUFs are typically packed. + +## MoE limitation + +`build_qwen35_mtp_graph` currently implements the dense-FFN path only. The +35B-A3B MTP GGUFs require the MoE `TargetLayer` fields and the routed +FFN path that howard0su is upstreaming in +[PR #120 "Qwen3.5 MoE support"](https://github.com/Luce-Org/lucebox-hub/pull/120). +A MoE-aware `build_qwen35_mtp_graph` is a one-line dispatch on top of +this PR once #120 lands. Until then, loading a MoE-MTP GGUF + invoking +the MTP graph returns a clear error rather than producing wrong output. 
+ +## Why MTP is opt-in, not default-on + +Measured today against `DFlash + PFlash` on the same MTP GGUF with MTP +disabled, on a single RTX 6000 Ada (sm_89), Qwen3.6-27B Q4_K_M target, +`q4_0/q4_0` KV, FA_WINDOW=0, DDTree budget=16, draft feature mirror on: + +| n_gen | Same GGUF, MTP off (tok/s) | Same GGUF, MTP chain-2 (tok/s) | Δ | +|---:|---:|---:|---:| +| 64 | 57.58 | 54.72 | **−5.0%** | +| 128 | 67.58 | 64.23 | **−5.0%** | +| 256 | 60.40 | 82.18 | **+36.1%** | + +What changes between 64 and 256 tokens is that DDTree rounds drop from +roughly 60 → 38 and average tokens committed per draft step rise from +4.27 → 6.74, so the extra MTP forward starts paying for itself. + +This is real but workload-dependent acceleration, not a universal default. +The next PR adds the speculative loop that turns this into a default-on +mode for long generations; today's PR ships only the runtime contract and +the tests that pin it. + +## Known follow-ups (next PR) + +1. Speculative decode loop wiring (`run_mtp_integrated_prompt`, + target-batched verify, fast rollback) inside `test_dflash`. +2. Daemon-side `--mtp-integrated` CLI + metrics surface (`[mtp-daemon]` + line, `last_mtp` aggregated in `prefix_cache.py`). +3. `mtp_baseline_gate.py` published as a reusable parity gate harness. +4. CPU hidden-readback elimination — the current functional smoke still + round-trips token ids through CPU between MTP steps. Removing that is + the highest-value perf fix and is queued behind CUDA-graph capture. +5. MoE MTP path after PR #120 merges. diff --git a/dflash/src/f16_convert.cu b/dflash/src/f16_convert.cu new file mode 100644 index 000000000..49bd309ee --- /dev/null +++ b/dflash/src/f16_convert.cu @@ -0,0 +1,49 @@ +// Tiny half-precision → f32 conversion kernels used by the DDtree rollback +// path and the drafter's target_feat widen. We store some tensors +// (ssm_intermediate, target_feat) at 16-bit to halve their memory footprint, +// and widen on read into f32 consumers. 
+// +// Exposes plain C entry points so test_dflash.cpp can call them without +// pulling in a CUDA compile unit of its own. + +#include <cuda_runtime.h> +#include <cuda_fp16.h> +#include <cuda_bf16.h> + +static __global__ void f16_to_f32_kernel(const __half * __restrict__ src, + float * __restrict__ dst, + size_t n_elems) { + const size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; + if (i < n_elems) { + dst[i] = __half2float(src[i]); + } +} + +static __global__ void bf16_to_f32_kernel(const __nv_bfloat16 * __restrict__ src, + float * __restrict__ dst, + size_t n_elems) { + const size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; + if (i < n_elems) { + dst[i] = __bfloat162float(src[i]); + } +} + +extern "C" void dflash27b_launch_f16_to_f32(const void * src, + void * dst, + size_t n_elems, + cudaStream_t stream) { + const int threads = 256; + const int blocks = (int)((n_elems + threads - 1) / threads); + f16_to_f32_kernel<<<blocks, threads, 0, stream>>>( + (const __half *)src, (float *)dst, n_elems); +} + +extern "C" void dflash27b_launch_bf16_to_f32(const void * src, + void * dst, + size_t n_elems, + cudaStream_t stream) { + const int threads = 256; + const int blocks = (int)((n_elems + threads - 1) / threads); + bf16_to_f32_kernel<<<blocks, threads, 0, stream>>>( + (const __nv_bfloat16 *)src, (float *)dst, n_elems); +} diff --git a/dflash/src/gguf_target_loader.cpp b/dflash/src/gguf_target_loader.cpp index f5fde060a..38ccfdf96 100644 --- a/dflash/src/gguf_target_loader.cpp +++ b/dflash/src/gguf_target_loader.cpp @@ -273,7 +273,7 @@ bool load_target_gguf_partial(const std::string & path, std::string err; const uint32_t n_embd = get_u32_or(gctx, "qwen35.embedding_length", 0); const uint32_t n_ff = get_u32_or(gctx, "qwen35.feed_forward_length", 0); - const uint32_t n_layer= get_u32_or(gctx, "qwen35.block_count", 0); + const uint32_t n_block= get_u32_or(gctx, "qwen35.block_count", 0); const uint32_t n_head = get_u32_or(gctx, "qwen35.attention.head_count",0); const uint32_t n_headkv=get_u32_or(gctx, "qwen35.attention.head_count_kv",0); const uint32_t kl
= get_u32_or(gctx, "qwen35.attention.key_length", 0); @@ -285,21 +285,48 @@ bool load_target_gguf_partial(const std::string & path, const uint32_t ssm_dt = get_u32_or(gctx, "qwen35.ssm.time_step_rank",0); const uint32_t ssm_grp = get_u32_or(gctx, "qwen35.ssm.group_count", 0); - if (n_embd == 0 || n_layer == 0 || n_head == 0 || n_headkv == 0 || + // Native MTP / NextN: zero on non-MTP GGUFs, 1 on the am17an Qwen3.6-MTP + // GGUFs. We treat the last `nextn_predict_layers` blocks as the MTP tail + // and the remaining `block_count - nextn` as the trunk. + const uint32_t nextn_predict_layers = get_u32_or(gctx, "qwen35.nextn_predict_layers", 0); + + if (n_embd == 0 || n_block == 0 || n_head == 0 || n_headkv == 0 || kl == 0 || vl == 0 || n_ff == 0 || fai == 0 || ssm_conv == 0 || ssm_inner == 0 || ssm_state == 0 || ssm_dt == 0 || ssm_grp == 0) { char buf[512]; std::snprintf(buf, sizeof(buf), - "missing or zero hparams: n_embd=%u n_layer=%u n_head=%u n_head_kv=%u " + "missing or zero hparams: n_embd=%u n_block=%u n_head=%u n_head_kv=%u " "kl=%u vl=%u n_ff=%u fai=%u ssm{conv=%u inner=%u state=%u dt=%u grp=%u}", - n_embd, n_layer, n_head, n_headkv, kl, vl, n_ff, fai, + n_embd, n_block, n_head, n_headkv, kl, vl, n_ff, fai, ssm_conv, ssm_inner, ssm_state, ssm_dt, ssm_grp); set_last_error(buf); gguf_free(gctx); return false; } + if (nextn_predict_layers > n_block) { + char buf[160]; + std::snprintf(buf, sizeof(buf), + "nextn_predict_layers=%u exceeds block_count=%u", + nextn_predict_layers, n_block); + set_last_error(buf); + gguf_free(gctx); return false; + } + if (nextn_predict_layers > 1) { + char buf[160]; + std::snprintf(buf, sizeof(buf), + "nextn_predict_layers=%u not supported yet (loader supports 0 or 1)", + nextn_predict_layers); + set_last_error(buf); + gguf_free(gctx); return false; + } + const uint32_t n_layer = n_block - nextn_predict_layers; + if (n_layer == 0) { + set_last_error("no trunk layers left after subtracting nextn_predict_layers"); + gguf_free(gctx); 
return false; + } + // Structural invariants required by the graph builder. if (kl != vl) { set_last_error("key_length != value_length not supported"); @@ -312,8 +339,10 @@ bool load_target_gguf_partial(const std::string & path, gguf_free(gctx); return false; } if (n_layer % fai != 0) { - char buf[128]; - std::snprintf(buf, sizeof(buf), "block_count=%u not divisible by full_attention_interval=%u", n_layer, fai); + char buf[160]; + std::snprintf(buf, sizeof(buf), + "trunk layer count=%u (block_count=%u nextn=%u) not divisible by full_attention_interval=%u", + n_layer, n_block, nextn_predict_layers, fai); set_last_error(buf); gguf_free(gctx); return false; } @@ -364,13 +393,15 @@ bool load_target_gguf_partial(const std::string & path, TargetLoadPlan plan = plan_in; if (plan.layer_begin < 0) plan.layer_begin = 0; - if (plan.layer_end < 0) plan.layer_end = (int)n_layer; + // Default end covers trunk + MTP/NextN tail so blk..* + // tensors are uploaded when nextn_predict_layers > 0. + if (plan.layer_end < 0) plan.layer_end = (int)n_block; if (plan.layer_begin > plan.layer_end || - plan.layer_end > (int)n_layer) { + plan.layer_end > (int)n_block) { char buf[160]; std::snprintf(buf, sizeof(buf), - "invalid target load layer range [%d,%d) for n_layer=%u", - plan.layer_begin, plan.layer_end, n_layer); + "invalid target load layer range [%d,%d) for n_block=%u", + plan.layer_begin, plan.layer_end, n_block); set_last_error(buf); gguf_free(gctx); return false; @@ -379,6 +410,8 @@ bool load_target_gguf_partial(const std::string & path, out.ctx = meta_ctx; out.backend = backend; out.n_layer = (int)n_layer; + out.gguf_block_count = (int)n_block; + out.nextn_predict_layers = (int)nextn_predict_layers; out.n_embd = (int)n_embd; out.n_ff = (int)n_ff; out.n_head = (int)n_head; @@ -392,6 +425,7 @@ bool load_target_gguf_partial(const std::string & path, out.ssm_d_state= (int)ssm_state; out.ssm_dt_rank= (int)ssm_dt; out.ssm_n_group= (int)ssm_grp; + 
out.mtp_layers.assign((size_t)nextn_predict_layers, TargetMtpLayer{}); // EOS token ids from GGUF tokenizer metadata (stored as UINT32 by the // GGUF spec; we use the u32 helper and cast). UINT32_MAX is the @@ -491,6 +525,77 @@ bool load_target_gguf_partial(const std::string & path, } } + // ── 2b. Wire MTP / NextN tail blocks (Qwen3.6-MTP GGUFs) ───────── + // GGUF block index for MTP layer `mi` is (n_layer + mi). Each MTP block + // ships a regular full-attention transformer (no DeltaNet) plus the + // NextN-specific projections (eh_proj, enorm, hnorm, optional shared head). + for (int mi = 0; mi < (int)nextn_predict_layers; mi++) { + const int il = (int)n_layer + mi; + char name[128]; + auto fnd = [&](const char * suffix) -> ggml_tensor * { + std::snprintf(name, sizeof(name), "blk.%d.%s", il, suffix); + return ggml_get_tensor(meta_ctx, name); + }; + TargetMtpLayer & M = out.mtp_layers[(size_t)mi]; + M.gguf_layer_index = il; + TargetLayer & L = M.block; + + L.attn_norm = fnd("attn_norm.weight"); + L.attn_post_norm = fnd("post_attention_norm.weight"); + L.w_gate = fnd("ffn_gate.weight"); + L.w_up = fnd("ffn_up.weight"); + L.w_down = fnd("ffn_down.weight"); + L.wq = fnd("attn_q.weight"); + L.wk = fnd("attn_k.weight"); + L.wv = fnd("attn_v.weight"); + L.wo = fnd("attn_output.weight"); + L.q_norm = fnd("attn_q_norm.weight"); + L.k_norm = fnd("attn_k_norm.weight"); + + M.nextn.eh_proj = fnd("nextn.eh_proj.weight"); + M.nextn.embed_tokens = fnd("nextn.embed_tokens.weight"); + M.nextn.enorm = fnd("nextn.enorm.weight"); + M.nextn.hnorm = fnd("nextn.hnorm.weight"); + M.nextn.shared_head_head = fnd("nextn.shared_head_head.weight"); + M.nextn.shared_head_norm = fnd("nextn.shared_head_norm.weight"); + + const bool has_attn = L.wq && L.wk && L.wv && L.wo && L.q_norm && L.k_norm; + if (!L.attn_norm || !L.attn_post_norm || !has_attn) { + char b[160]; + std::snprintf(b, sizeof(b), + "mtp layer %d: missing full-attention tensors", il); + set_last_error(b); + gguf_free(gctx); + 
return false; + } + if (!L.w_gate || !L.w_up || !L.w_down) { + char b[160]; + std::snprintf(b, sizeof(b), + "mtp layer %d: missing required FFN tensors", il); + set_last_error(b); + gguf_free(gctx); + return false; + } + if (!M.nextn.eh_proj || !M.nextn.enorm || !M.nextn.hnorm) { + char b[160]; + std::snprintf(b, sizeof(b), + "mtp layer %d: missing required nextn tensors " + "(eh_proj/enorm/hnorm)", il); + set_last_error(b); + gguf_free(gctx); + return false; + } + } + + // Plain target decode still embeds on CPU. Native MTP needs device-side + // token lookup to chain proposals inside one graph, so MTP-enabled + // checkpoints upload token_embd to the GPU as a regular weight. + bool upload_tok_embd = nextn_predict_layers > 0; + if (const char * s = std::getenv("DFLASH27B_UPLOAD_TOK_EMBD")) { + upload_tok_embd = std::atoi(s) != 0; + } + out.tok_embd_gpu = upload_tok_embd; + // 3. Allocate CUDA buffer only for the selected target tensors. ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); const size_t alignment = ggml_backend_buft_get_alignment(buft); @@ -500,7 +605,11 @@ bool load_target_gguf_partial(const std::string & path, for (int64_t tid = 0; tid < n_tensors; tid++) { const char * tname = gguf_get_tensor_name(gctx, tid); ggml_tensor * t = ggml_get_tensor(meta_ctx, tname); - if (!t || !should_load_target_tensor(tname, plan.layer_begin, plan.layer_end, plan.load_output)) { + if (!t) continue; + const bool is_tok_embd = (std::strcmp(tname, "token_embd.weight") == 0); + const bool selected_by_plan = + should_load_target_tensor(tname, plan.layer_begin, plan.layer_end, plan.load_output); + if (!selected_by_plan && !(is_tok_embd && upload_tok_embd)) { continue; } alloc_total = align_up_size(alloc_total, alignment); @@ -559,10 +668,15 @@ bool load_target_gguf_partial(const std::string & path, return false; } if (std::string(tname) == "token_embd.weight") { - // Remember offset + size for the CPU embedder; don't upload to GPU. 
+ // Remember offset + size for the CPU embedder regardless of GPU + // upload — MTP still needs the CPU mmap for tokenizer-side lookups. tok_embd_off = off; tok_embd_sz = sz; tok_embd_type = gguf_get_tensor_type(gctx, tid); + if (!upload_tok_embd) continue; + // MTP path: also stream the bytes into the GPU-resident tensor. + ggml_backend_tensor_set(t, (const uint8_t *)mm.addr + off, 0, sz); + total += sz; continue; } if (!should_load_target_tensor(tname, plan.layer_begin, plan.layer_end, plan.load_output)) { @@ -597,12 +711,16 @@ bool load_target_gguf_partial(const std::string & path, mm.release(); // don't munmap on Mmap dtor — now owned by the embedder // Stash the total for callers that want to print it - char summary[192]; + char summary[256]; std::snprintf(summary, sizeof(summary), - "target loaded: layers [%d,%d) output=%d, %zu tensors on GPU %.2f GiB, tok_embd %.0f MiB CPU-only (%s)", + "target loaded: layers [%d,%d) output=%d, %zu tensors on GPU %.2f GiB, " + "tok_embd %.0f MiB %s (%s), trunk_layers=%d nextn=%d", plan.layer_begin, plan.layer_end, (int)plan.load_output, allocs.size(), total / (1024.0 * 1024.0 * 1024.0), - tok_embd_sz / (1024.0 * 1024.0), ggml_type_name(tok_embd_type)); + tok_embd_sz / (1024.0 * 1024.0), + upload_tok_embd ? "GPU+CPU" : "CPU-only", + ggml_type_name(tok_embd_type), + out.n_layer, out.nextn_predict_layers); set_last_error(summary); return true; @@ -613,7 +731,11 @@ void free_target_weights(TargetWeights & w) { if (w.ctx) { ggml_free(w.ctx); w.ctx = nullptr; } // CpuEmbedder destructor handles the mmap automatically. 
w.layers.clear(); + w.mtp_layers.clear(); + w.nextn_predict_layers = 0; + w.gguf_block_count = 0; w.tok_embd = nullptr; + w.tok_embd_gpu = false; w.out_norm = nullptr; w.output = nullptr; } diff --git a/dflash/src/internal.h b/dflash/src/internal.h index b9cc88d55..945e86112 100644 --- a/dflash/src/internal.h +++ b/dflash/src/internal.h @@ -73,6 +73,34 @@ struct TargetLayer { ggml_tensor * ssm_out = nullptr; // output projection after delta-net }; +// Qwen3.5/3.6 NextN / MTP tail block tensors. These live in the tail +// `nextn_predict_layers` of the GGUF (one such block in Qwen3.6-MTP). +// Follow the tensor naming convention introduced by llama.cpp PR #22673: +// blk..nextn.eh_proj [2*hidden, hidden] +// blk..nextn.embed_tokens [hidden, vocab] (optional) +// blk..nextn.enorm [hidden] +// blk..nextn.hnorm [hidden] +// blk..nextn.shared_head_head [hidden, vocab] (optional, falls back to w.output) +// blk..nextn.shared_head_norm [hidden] (optional, falls back to w.out_norm) +struct TargetNextN { + ggml_tensor * eh_proj = nullptr; + ggml_tensor * embed_tokens = nullptr; + ggml_tensor * enorm = nullptr; + ggml_tensor * hnorm = nullptr; + ggml_tensor * shared_head_head = nullptr; + ggml_tensor * shared_head_norm = nullptr; +}; + +// One MTP / NextN layer in the GGUF tail. Holds the regular transformer +// block tensors (full-attention only — no DeltaNet on MTP) plus the +// NextN-specific projections above. The trunk decoder's pre-norm hidden +// state is fed into this block to produce the MTP draft logits. +struct TargetMtpLayer { + TargetLayer block; + TargetNextN nextn; + int gguf_layer_index = -1; +}; + // CPU-side embedder: keeps a mmap of the GGUF alive and knows how to // dequantize individual rows of the quantized tok_embd tensor on demand. 
// This matches llama.cpp's behavior of running embedding get_rows on CPU @@ -108,7 +136,11 @@ struct TargetWeights { CpuEmbedder embedder; ggml_tensor * tok_embd = nullptr; // [hidden, vocab] (metadata only; data NOT on GPU) - std::vector<TargetLayer> layers; // size = 64 + bool tok_embd_gpu = false; // true when token_embd bytes were uploaded for GPU get_rows. + // Required by MTP because the integrated decode path needs + // device-side token lookup to chain proposals within a graph. + std::vector<TargetLayer> layers; // trunk layers only, excludes any nextn/MTP tail blocks + std::vector<TargetMtpLayer> mtp_layers; // size = nextn_predict_layers (0 for non-MTP GGUFs) ggml_tensor * out_norm = nullptr; // [hidden] ggml_tensor * output = nullptr; // [hidden, vocab] (lm_head) @@ -119,7 +151,9 @@ struct TargetWeights { int n_embd_head_v = 256; // value_length int n_head = 24; int n_head_kv = 4; - int n_layer = 64; + int gguf_block_count = 64; // raw qwen35.block_count from the GGUF + int nextn_predict_layers = 0; // qwen35.nextn_predict_layers (0 = non-MTP GGUF) + int n_layer = 64; // trunk layer count: gguf_block_count - nextn_predict_layers int n_embd = 5120; int n_ff = 17408; int ssm_d_conv = 4; @@ -413,6 +447,36 @@ bool migrate_prefill_cache(const TargetWeights & w, ggml_backend_t backend, TargetCache & cache); +// ─── Native MTP / NextN cache ───────────────────────────────────── +// +// Qwen3.5/3.6 native multi-token prediction keeps a tiny KV cache for the +// tail NextN block(s) only; the trunk decoder retains its own TargetCache +// above. Matches the "kv_only_nextn" contract used by llama.cpp PR #22673 +// and llama-crucible's MTP cache layout.
+struct TargetMtpCache { + ggml_context * ctx = nullptr; + ggml_backend_buffer_t buf = nullptr; + ggml_backend_t backend = nullptr; + + int max_ctx = 0; + int cur_pos = 0; + + ggml_type kv_k_type = GGML_TYPE_Q8_0; + ggml_type kv_v_type = GGML_TYPE_Q8_0; + + std::vector<ggml_tensor *> attn_k; // one per TargetWeights::mtp_layers entry + std::vector<ggml_tensor *> attn_v; +}; + +bool create_target_mtp_cache(const TargetWeights & w, + int max_ctx, + ggml_backend_t backend, + TargetMtpCache & out); + +void free_target_mtp_cache(TargetMtpCache & c); + +void reset_target_mtp_cache(TargetMtpCache & c); + // ─── Target forward graph ───────────────────────────────────────── // Per-delta-net-layer pointers exposed by the graph for spec-decode rollback. @@ -443,6 +507,7 @@ struct QwenGraphInputs { int kv_start; // position where the new tokens begin bool capture_layers; // if true, write captured layer features into cache.target_feat bool capture_delta_intermediate = false; // if true, populate out_delta_captures + bool expose_pre_norm_hidden = false; // if true, keep final pre-norm hidden for MTP handoff int fa_window = 0; // sliding window for FA layers: 0 = full attention bool last_token_logits_only = false; // if true, only compute logits for last token (prefill optimization) ggml_tensor * parent_ids = nullptr; // [n_tokens] i32; tree mode when non-null @@ -450,6 +515,11 @@ struct QwenGraphInputs { struct QwenGraphOutputs { ggml_tensor * logits; // [vocab, n_tokens] f32 + // Final hidden state before the target output norm. Populated when + // QwenGraphInputs::expose_pre_norm_hidden is true. Used as the + // `t_h_pre_norm` handoff into the native NextN/MTP block; matches the + // convention from llama-crucible MTP. + ggml_tensor * pre_norm_hidden = nullptr; // [hidden, n_tokens] f32 // One entry per delta-net layer (48 for qwen35-27b). Only populated when // QwenGraphInputs::capture_delta_intermediate is true.
Tensors are graph // views marked as ggml_set_output() so their data persists after @@ -464,6 +534,38 @@ QwenGraphOutputs build_qwen35_graph( TargetCache & cache, const QwenGraphInputs & in); +// ─── Native MTP / NextN forward graph ───────────────────────────── +// +// Single-layer NextN/MTP block. Consumes the trunk decoder's pre-norm hidden +// state (`pre_norm_hidden`) plus the embedding of the current token, runs +// the NextN concat → eh_proj → transformer-block → shared-head pipeline, +// and returns the MTP draft logits + the post-block hidden state. +// +// Today this PR implements the dense-FFN path only; MoE MTP requires the +// MoE TargetLayer fields landing first (see PR #120 "Qwen3.5 MoE support"). +struct QwenMtpGraphInputs { + ggml_tensor * token_embed; // [hidden, n_tokens] f32; embedding of the current token(s) + ggml_tensor * pre_norm_hidden; // [hidden, n_tokens] f32 from trunk output + ggml_tensor * positions; // [4 * n_tokens] i32 + ggml_tensor * attn_mask; // optional [kv_len, n_tokens_padded] f32 + int n_tokens = 1; + int kv_start = 0; + int mtp_layer_index = 0; + int fa_window = 0; +}; + +struct QwenMtpGraphOutputs { + ggml_tensor * logits = nullptr; // [vocab, n_tokens] f32 + ggml_tensor * hidden = nullptr; // [hidden, n_tokens] f32, post-MTP block +}; + +QwenMtpGraphOutputs build_qwen35_mtp_graph( + ggml_context * ctx, + ggml_cgraph * gf, + const TargetWeights & w, + TargetMtpCache & cache, + const QwenMtpGraphInputs & in); + // Build a single-layer forward graph. Mirrors build_qwen35_graph but processes // only one layer, taking `inp` as the input activation and returning the output. // Used by layer-segmented prefill to iterate layers as the outer loop. 
diff --git a/dflash/src/qwen35_target_graph.cpp b/dflash/src/qwen35_target_graph.cpp index 47b989ff2..039a9cf7c 100644 --- a/dflash/src/qwen35_target_graph.cpp +++ b/dflash/src/qwen35_target_graph.cpp @@ -1358,6 +1358,17 @@ QwenGraphOutputs build_qwen35_graph( inpL = cur; } + // Expose the final pre-norm hidden state so the native NextN/MTP block can + // consume it as `t_h_pre_norm`. Marked as graph output so its data persists + // after graph_compute. Caller threads it into build_qwen35_mtp_graph. + QwenGraphOutputs og = std::move(og_early); + if (in.expose_pre_norm_hidden) { + ggml_set_name(inpL, "target_pre_norm_hidden"); + ggml_set_output(inpL); + ggml_build_forward_expand(gf, inpL); + og.pre_norm_hidden = inpL; + } + // 2. Final norm ggml_tensor * out = rms_norm_mul(ctx, inpL, w.out_norm, EPS); @@ -1373,7 +1384,6 @@ QwenGraphOutputs build_qwen35_graph( ggml_build_forward_expand(gf, logits); - QwenGraphOutputs og = std::move(og_early); og.logits = logits; return og; } @@ -1396,4 +1406,196 @@ ggml_tensor * build_qwen35_layer( attn_mask, kv_start, n_tokens, capture, fa_window); } +// ─── Native MTP / NextN cache and graph ─────────────────────────────── + +bool create_target_mtp_cache(const TargetWeights & w, + int max_ctx, + ggml_backend_t backend, + TargetMtpCache & out) { + if (w.mtp_layers.empty()) { + set_last_error("create_target_mtp_cache requires TargetWeights::mtp_layers"); + return false; + } + if (max_ctx <= 0) { + set_last_error("create_target_mtp_cache requires max_ctx > 0"); + return false; + } + + out.backend = backend; + out.max_ctx = max_ctx; + out.cur_pos = 0; + + ggml_type kv_k_type = GGML_TYPE_Q8_0; + ggml_type kv_v_type = GGML_TYPE_Q8_0; + dflash::resolve_kv_types(kv_k_type, kv_v_type); + out.kv_k_type = kv_k_type; + out.kv_v_type = kv_v_type; + const int max_ctx_alloc = (kv_k_type == GGML_TYPE_TQ3_0 || kv_v_type == GGML_TYPE_TQ3_0) + ? 
((max_ctx + 255) / 256) * 256 + : max_ctx; + + const int n_mtp = (int)w.mtp_layers.size(); + out.attn_k.assign(n_mtp, nullptr); + out.attn_v.assign(n_mtp, nullptr); + + ggml_init_params ip{}; + ip.mem_size = (size_t)(2 * n_mtp + 16) * ggml_tensor_overhead(); + ip.mem_buffer = nullptr; + ip.no_alloc = true; + out.ctx = ggml_init(ip); + if (!out.ctx) { + set_last_error("mtp cache ggml_init failed"); + return false; + } + + for (int mi = 0; mi < n_mtp; mi++) { + ggml_tensor * K = ggml_new_tensor_3d(out.ctx, kv_k_type, + w.n_embd_head_k, max_ctx_alloc, w.n_head_kv); + ggml_tensor * V = ggml_new_tensor_3d(out.ctx, kv_v_type, + w.n_embd_head_k, max_ctx_alloc, w.n_head_kv); + char name[64]; + std::snprintf(name, sizeof(name), "mtp_cache_k_%d", mi); + ggml_set_name(K, name); + std::snprintf(name, sizeof(name), "mtp_cache_v_%d", mi); + ggml_set_name(V, name); + out.attn_k[mi] = K; + out.attn_v[mi] = V; + } + + out.buf = ggml_backend_alloc_ctx_tensors(out.ctx, backend); + if (!out.buf) { + set_last_error("ggml_backend_alloc_ctx_tensors failed for mtp cache"); + ggml_free(out.ctx); + out.ctx = nullptr; + out.attn_k.clear(); + out.attn_v.clear(); + return false; + } + + reset_target_mtp_cache(out); + return true; +} + +void free_target_mtp_cache(TargetMtpCache & c) { + if (c.buf) { ggml_backend_buffer_free(c.buf); c.buf = nullptr; } + if (c.ctx) { ggml_free(c.ctx); c.ctx = nullptr; } + c.attn_k.clear(); + c.attn_v.clear(); + c.max_ctx = 0; + c.cur_pos = 0; +} + +void reset_target_mtp_cache(TargetMtpCache & c) { + c.cur_pos = 0; + if (!c.ctx) return; + std::vector<uint8_t> zeros(1 * 1024 * 1024, 0); + for (ggml_tensor * t = ggml_get_first_tensor(c.ctx); t != nullptr; + t = ggml_get_next_tensor(c.ctx, t)) { + size_t nb = ggml_nbytes(t); + size_t off = 0; + while (off < nb) { + size_t chunk = std::min(nb - off, zeros.size()); + ggml_backend_tensor_set(t, zeros.data(), off, chunk); + off += chunk; + } + } +} + +QwenMtpGraphOutputs build_qwen35_mtp_graph( + ggml_context * ctx, + ggml_cgraph
* gf, + const TargetWeights & w, + TargetMtpCache & cache, + const QwenMtpGraphInputs & in) { + + if (w.mtp_layers.empty()) { + set_last_error("build_qwen35_mtp_graph requires TargetWeights::mtp_layers"); + return {}; + } + if (in.mtp_layer_index < 0 || in.mtp_layer_index >= (int)w.mtp_layers.size()) { + set_last_error("build_qwen35_mtp_graph mtp_layer_index out of range"); + return {}; + } + if (!in.token_embed || !in.pre_norm_hidden || !in.positions) { + set_last_error("build_qwen35_mtp_graph missing required input tensor"); + return {}; + } + if ((int)cache.attn_k.size() <= in.mtp_layer_index || + (int)cache.attn_v.size() <= in.mtp_layer_index || + !cache.attn_k[in.mtp_layer_index] || !cache.attn_v[in.mtp_layer_index]) { + set_last_error("build_qwen35_mtp_graph missing MTP KV cache tensors"); + return {}; + } + + const int n_tokens = std::max(1, in.n_tokens); + const TargetMtpLayer & M = w.mtp_layers[(size_t)in.mtp_layer_index]; + const TargetLayer & L = M.block; + + // Dense FFN only — MoE MTP requires the MoE TargetLayer fields landing + // first (see PR #120 "Qwen3.5 MoE support"). For non-MoE GGUFs this is + // the production path and matches the am17an reference layout. + const bool has_dense_ffn = L.w_gate && L.w_up && L.w_down; + if (!M.nextn.eh_proj || !M.nextn.enorm || !M.nextn.hnorm || + !L.attn_norm || !L.attn_post_norm || + !L.wq || !L.wk || !L.wv || !L.wo || !L.q_norm || !L.k_norm || + !has_dense_ffn) { + set_last_error("build_qwen35_mtp_graph missing loaded MTP tensors " + "(MoE MTP not supported in this PR — needs #120)"); + return {}; + } + + // NextN concat path: [enorm(e); hnorm(h)] → eh_proj → transformer block. 
+ ggml_tensor * e_norm = rms_norm_mul(ctx, in.token_embed, M.nextn.enorm, EPS); + ggml_tensor * h_norm = rms_norm_mul(ctx, in.pre_norm_hidden, M.nextn.hnorm, EPS); + + ggml_tensor * concat = ggml_concat(ctx, e_norm, h_norm, 0); + ggml_set_name(concat, "mtp_concat_embedding_hidden"); + + ggml_tensor * cur = ggml_mul_mat(ctx, M.nextn.eh_proj, concat); + ggml_set_name(cur, "mtp_eh_proj"); + + ggml_tensor * inpSA = cur; + cur = rms_norm_mul(ctx, cur, L.attn_norm, EPS); + ggml_set_name(cur, "mtp_attn_norm"); + + cur = build_full_attn_block(ctx, gf, w, L, cur, in.positions, + cache.attn_k[in.mtp_layer_index], + cache.attn_v[in.mtp_layer_index], + in.attn_mask, in.kv_start, n_tokens, + cache.kv_k_type, cache.kv_v_type, + /*kv_k_rotated=*/false, in.fa_window); + ggml_set_name(cur, "mtp_attn_out"); + + cur = ggml_add(ctx, cur, inpSA); + + ggml_tensor * ffn_residual = cur; + cur = rms_norm_mul(ctx, cur, L.attn_post_norm, EPS); + ggml_set_name(cur, "mtp_post_attn_norm"); + ggml_tensor * ffn = build_swiglu_ffn(ctx, cur, L); + if (!ffn) return {}; + cur = ggml_add(ctx, ffn, ffn_residual); + ggml_set_name(cur, "mtp_hidden"); + ggml_set_output(cur); + + // Final norm + shared LM head. Falls back to the trunk's out_norm / output + // when the NextN block doesn't ship its own shared head tensors (am17an + // GGUFs do not always include shared_head_*). + ggml_tensor * head_norm_w = M.nextn.shared_head_norm ? M.nextn.shared_head_norm : w.out_norm; + ggml_tensor * head_w = M.nextn.shared_head_head ? 
M.nextn.shared_head_head : w.output; + if (!head_norm_w || !head_w) { + set_last_error("build_qwen35_mtp_graph missing MTP/shared LM head tensors"); + return {}; + } + + ggml_tensor * out_h = rms_norm_mul(ctx, cur, head_norm_w, EPS); + ggml_tensor * logits = ggml_mul_mat(ctx, head_w, out_h); + ggml_set_name(logits, "mtp_logits"); + ggml_build_forward_expand(gf, logits); + + QwenMtpGraphOutputs og{}; + og.logits = logits; + og.hidden = cur; + return og; +} + } // namespace dflash27b diff --git a/dflash/test/smoke_mtp_graph.cpp b/dflash/test/smoke_mtp_graph.cpp new file mode 100644 index 000000000..b5c710abe --- /dev/null +++ b/dflash/test/smoke_mtp_graph.cpp @@ -0,0 +1,173 @@ +// Smoke test for the native Qwen35 NextN/MTP graph. +// +// Loads a GGUF with embedded nextn tensors, creates the MTP KV cache, runs a +// single-token MTP forward from caller-provided token embedding + synthetic +// target pre-norm hidden, and checks the resulting logits for NaN/Inf. +// +// Usage: smoke_mtp_graph [cuda_gpu] + +#include "dflash27b.h" +#include "internal.h" + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "ggml-cuda.h" + +#include +#include +#include +#include +#include +#include + +using namespace dflash27b; + +int main(int argc, char ** argv) { + if (argc < 2) { + std::fprintf(stderr, "usage: %s \n", argv[0]); + return 2; + } + + int gpu = 0; + if (argc >= 3) gpu = std::atoi(argv[2]); + + ggml_backend_t backend = ggml_backend_cuda_init(gpu); + if (!backend) { + std::fprintf(stderr, "cuda init failed\n"); + return 1; + } + + TargetWeights w; + if (!load_target_gguf(argv[1], backend, w)) { + std::fprintf(stderr, "load_target_gguf: %s\n", dflash27b_last_error()); + ggml_backend_free(backend); + return 1; + } + std::printf("[target] %s\n", dflash27b_last_error()); + if (w.mtp_layers.empty()) { + std::fprintf(stderr, "model has no MTP/nextn layers\n"); + free_target_weights(w); + ggml_backend_free(backend); + return 1; + } + + TargetMtpCache 
mtp_cache; + if (!create_target_mtp_cache(w, /*max_ctx=*/64, backend, mtp_cache)) { + std::fprintf(stderr, "create_target_mtp_cache: %s\n", dflash27b_last_error()); + free_target_weights(w); + ggml_backend_free(backend); + return 1; + } + std::printf("[mtp-cache] layers=%zu max_ctx=%d kv=%s/%s\n", + mtp_cache.attn_k.size(), mtp_cache.max_ctx, + ggml_type_name(mtp_cache.kv_k_type), ggml_type_name(mtp_cache.kv_v_type)); + + ggml_init_params ip{}; + ip.mem_size = 512 * 1024 * 1024; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + ggml_context * gctx = ggml_init(ip); + if (!gctx) { + std::fprintf(stderr, "ggml_init graph failed\n"); + free_target_mtp_cache(mtp_cache); + free_target_weights(w); + ggml_backend_free(backend); + return 1; + } + + const int n_tokens = 1; + ggml_tensor * token_embed = ggml_new_tensor_2d(gctx, GGML_TYPE_F32, w.n_embd, n_tokens); + ggml_tensor * hidden = ggml_new_tensor_2d(gctx, GGML_TYPE_F32, w.n_embd, n_tokens); + ggml_tensor * positions = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, 4 * n_tokens); + ggml_set_name(token_embed, "mtp_token_embed"); + ggml_set_name(hidden, "target_pre_norm_hidden"); + ggml_set_name(positions, "positions"); + ggml_set_input(token_embed); + ggml_set_input(hidden); + ggml_set_input(positions); + + ggml_cgraph * gf = ggml_new_graph_custom(gctx, 2048, false); + QwenMtpGraphInputs gi{}; + gi.token_embed = token_embed; + gi.pre_norm_hidden = hidden; + gi.positions = positions; + gi.n_tokens = n_tokens; + gi.kv_start = 0; + + QwenMtpGraphOutputs go = build_qwen35_mtp_graph(gctx, gf, w, mtp_cache, gi); + if (!go.logits) { + std::fprintf(stderr, "build_qwen35_mtp_graph: %s\n", dflash27b_last_error()); + ggml_free(gctx); + free_target_mtp_cache(mtp_cache); + free_target_weights(w); + ggml_backend_free(backend); + return 1; + } + ggml_set_output(go.logits); + ggml_build_forward_expand(gf, go.logits); + std::printf("[graph] nodes=%d\n", ggml_graph_n_nodes(gf)); + + ggml_gallocr_t alloc = 
ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!ggml_gallocr_alloc_graph(alloc, gf)) { + std::fprintf(stderr, "ggml_gallocr_alloc_graph failed\n"); + ggml_free(gctx); + free_target_mtp_cache(mtp_cache); + free_target_weights(w); + ggml_backend_free(backend); + return 1; + } + + int32_t tok_ids[1] = { 1 }; + std::vector<float> embed_buf((size_t)w.n_embd * n_tokens); + if (!w.embedder.embed(tok_ids, n_tokens, embed_buf.data())) { + std::fprintf(stderr, "cpu embedder failed\n"); + ggml_gallocr_free(alloc); + ggml_free(gctx); + free_target_mtp_cache(mtp_cache); + free_target_weights(w); + ggml_backend_free(backend); + return 1; + } + std::vector<float> hidden_buf((size_t)w.n_embd * n_tokens, 0.0f); + int32_t pos4[4] = { 0, 0, 0, 0 }; + ggml_backend_tensor_set(token_embed, embed_buf.data(), 0, sizeof(float) * embed_buf.size()); + ggml_backend_tensor_set(hidden, hidden_buf.data(), 0, sizeof(float) * hidden_buf.size()); + ggml_backend_tensor_set(positions, pos4, 0, sizeof(pos4)); + + auto status = ggml_backend_graph_compute(backend, gf); + if (status != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "compute failed: %d\n", (int)status); + ggml_gallocr_free(alloc); + ggml_free(gctx); + free_target_mtp_cache(mtp_cache); + free_target_weights(w); + ggml_backend_free(backend); + return 1; + } + + const int64_t vocab = go.logits->ne[0]; + std::vector<float> logits((size_t)vocab); + ggml_backend_tensor_get(go.logits, logits.data(), 0, sizeof(float) * logits.size()); + + int n_nan = 0, n_inf = 0; + float vmin = 1e30f, vmax = -1e30f; + for (float v : logits) { + if (std::isnan(v)) n_nan++; + else if (std::isinf(v)) n_inf++; + else { + vmin = std::min(vmin, v); + vmax = std::max(vmax, v); + } + } + std::printf("[mtp-logits] vocab=%lld nan=%d inf=%d min=%.4g max=%.4g\n", + (long long)vocab, n_nan, n_inf, vmin, vmax); + + ggml_gallocr_free(alloc); + ggml_free(gctx); + free_target_mtp_cache(mtp_cache); + free_target_weights(w); + ggml_backend_free(backend); + std::printf("%s\n", (n_nan == 0 && n_inf == 0) ? "OK" : "FAIL"); +
return (n_nan == 0 && n_inf == 0) ? 0 : 1; +} diff --git a/dflash/test/smoke_mtp_integrated_decode.cpp b/dflash/test/smoke_mtp_integrated_decode.cpp new file mode 100644 index 000000000..3f09b2152 --- /dev/null +++ b/dflash/test/smoke_mtp_integrated_decode.cpp @@ -0,0 +1,225 @@ +// Minimal integrated DFlash + native MTP decode smoke. +// +// This is not the optimized multi-token speculative loop yet. It proves the +// functional contract end-to-end: +// 1. target DFlash consumes the committed token and exposes pre-norm hidden +// 2. native MTP/NextN consumes that hidden in the same graph and drafts +// 3. greedy target logits accept or correct the MTP draft token +// 4. the chosen token becomes the next committed token +// +// Usage: +// smoke_mtp_integrated_decode [n_gen] [seed_token_id] [cuda_gpu] + +#include "dflash27b.h" +#include "internal.h" + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "ggml-cuda.h" + +#include +#include +#include +#include +#include + +using namespace dflash27b; + +static bool run_integrated_step(const TargetWeights & w, + TargetCache & target_cache, + TargetMtpCache & mtp_cache, + ggml_backend_t backend, + int32_t token, + int kv_start, + int32_t & target_next, + int32_t & mtp_next) { + ggml_init_params ip{}; + ip.mem_size = 768 * 1024 * 1024; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + ggml_context * ctx = ggml_init(ip); + if (!ctx) { + std::fprintf(stderr, "ggml_init graph failed\n"); + return false; + } + + const int n_tokens = 1; + ggml_tensor * inp_embed = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w.n_embd, n_tokens, 1); + ggml_tensor * positions = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4 * n_tokens); + ggml_set_name(inp_embed, "inp_embed"); + ggml_set_name(positions, "positions"); + ggml_set_input(inp_embed); + ggml_set_input(positions); + + ggml_cgraph * gf = ggml_new_graph_custom(ctx, 8192, false); + + QwenGraphInputs target_in{}; + target_in.inp_embed = inp_embed; + target_in.positions = 
positions; + target_in.n_tokens = n_tokens; + target_in.kv_start = kv_start; + target_in.expose_pre_norm_hidden = true; + + QwenGraphOutputs target_out = build_qwen35_graph(ctx, gf, w, target_cache, target_in); + if (!target_out.logits || !target_out.pre_norm_hidden) { + std::fprintf(stderr, "build_qwen35_graph failed: %s\n", dflash27b_last_error()); + ggml_free(ctx); + return false; + } + + QwenMtpGraphInputs mtp_in{}; + mtp_in.token_embed = ggml_reshape_2d(ctx, inp_embed, w.n_embd, n_tokens); + mtp_in.pre_norm_hidden = target_out.pre_norm_hidden; + mtp_in.positions = positions; + mtp_in.n_tokens = n_tokens; + mtp_in.kv_start = kv_start; + + QwenMtpGraphOutputs mtp_out = build_qwen35_mtp_graph(ctx, gf, w, mtp_cache, mtp_in); + if (!mtp_out.logits) { + std::fprintf(stderr, "build_qwen35_mtp_graph failed: %s\n", dflash27b_last_error()); + ggml_free(ctx); + return false; + } + + ggml_tensor * target_argmax = ggml_argmax(ctx, target_out.logits); + ggml_set_name(target_argmax, "target_argmax"); + ggml_set_output(target_argmax); + ggml_build_forward_expand(gf, target_argmax); + + ggml_tensor * mtp_argmax = ggml_argmax(ctx, mtp_out.logits); + ggml_set_name(mtp_argmax, "mtp_argmax"); + ggml_set_output(mtp_argmax); + ggml_build_forward_expand(gf, mtp_argmax); + + ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!ggml_gallocr_alloc_graph(alloc, gf)) { + std::fprintf(stderr, "ggml_gallocr_alloc_graph failed\n"); + ggml_gallocr_free(alloc); + ggml_free(ctx); + return false; + } + + std::vector<float> embed_buf((size_t)w.n_embd); + if (!w.embedder.embed(&token, 1, embed_buf.data())) { + std::fprintf(stderr, "cpu embedder failed for token %d\n", (int)token); + ggml_gallocr_free(alloc); + ggml_free(ctx); + return false; + } + int32_t pos4[4] = { kv_start, kv_start, kv_start, kv_start }; + ggml_backend_tensor_set(inp_embed, embed_buf.data(), 0, sizeof(float) * embed_buf.size()); + ggml_backend_tensor_set(positions, pos4, 0, sizeof(pos4)); + +
auto status = ggml_backend_graph_compute(backend, gf); + if (status != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "compute failed: %d\n", (int)status); + ggml_gallocr_free(alloc); + ggml_free(ctx); + return false; + } + + ggml_backend_tensor_get(target_argmax, &target_next, 0, sizeof(target_next)); + ggml_backend_tensor_get(mtp_argmax, &mtp_next, 0, sizeof(mtp_next)); + + ggml_gallocr_free(alloc); + ggml_free(ctx); + return true; +} + +int main(int argc, char ** argv) { + if (argc < 2) { + std::fprintf(stderr, "usage: %s [n_gen] [seed_token_id] [cuda_gpu]\n", argv[0]); + return 2; + } + const int n_gen = argc >= 3 ? std::max(1, std::atoi(argv[2])) : 8; + int32_t last_tok = argc >= 4 ? (int32_t)std::atoi(argv[3]) : 1; + const int gpu = argc >= 5 ? std::atoi(argv[4]) : 0; + + ggml_backend_t backend = ggml_backend_cuda_init(gpu); + if (!backend) { + std::fprintf(stderr, "cuda init failed\n"); + return 1; + } + + TargetWeights w; + if (!load_target_gguf(argv[1], backend, w)) { + std::fprintf(stderr, "load_target_gguf: %s\n", dflash27b_last_error()); + ggml_backend_free(backend); + return 1; + } + std::printf("[target] %s\n", dflash27b_last_error()); + + TargetCache target_cache; + if (!create_target_cache(w, /*max_ctx=*/std::max(64, n_gen + 8), + /*max_verify_tokens=*/0, backend, target_cache)) { + std::fprintf(stderr, "create_target_cache: %s\n", dflash27b_last_error()); + free_target_weights(w); + ggml_backend_free(backend); + return 1; + } + + TargetMtpCache mtp_cache; + if (!create_target_mtp_cache(w, /*max_ctx=*/std::max(64, n_gen + 8), + backend, mtp_cache)) { + std::fprintf(stderr, "create_target_mtp_cache: %s\n", dflash27b_last_error()); + free_target_cache(target_cache); + free_target_weights(w); + ggml_backend_free(backend); + return 1; + } + + std::vector<int32_t> generated; + generated.reserve((size_t)n_gen); + int draft_n = 0; + int accepted = 0; + int corrected = 0; + + auto t0 = std::chrono::steady_clock::now(); + for (int pos = 0; pos < n_gen; pos++) { +
int32_t target_next = -1; + int32_t mtp_next = -1; + if (!run_integrated_step(w, target_cache, mtp_cache, backend, + last_tok, pos, target_next, mtp_next)) { + free_target_mtp_cache(mtp_cache); + free_target_cache(target_cache); + free_target_weights(w); + ggml_backend_free(backend); + return 1; + } + + draft_n++; + const bool ok = (mtp_next == target_next); + if (ok) { + accepted++; + } else { + corrected++; + } + const int32_t chosen = ok ? mtp_next : target_next; + generated.push_back(chosen); + target_cache.cur_pos = pos + 1; + target_cache.last_tok = chosen; + mtp_cache.cur_pos = pos + 1; + + std::printf("[mtp-decode step=%d] input=%d mtp=%d target=%d %s chosen=%d\n", + pos, (int)last_tok, (int)mtp_next, (int)target_next, + ok ? "ACCEPT" : "CORRECT", (int)chosen); + last_tok = chosen; + } + auto t1 = std::chrono::steady_clock::now(); + const double seconds = std::chrono::duration(t1 - t0).count(); + + std::printf("[mtp-decode] generated=%d draft_n=%d accepted=%d corrected=%d acceptance=%.1f%% tok/s=%.2f\n", + n_gen, draft_n, accepted, corrected, + draft_n > 0 ? 100.0 * accepted / draft_n : 0.0, + n_gen / std::max(1e-9, seconds)); + std::printf("[mtp-decode ids]"); + for (int32_t t : generated) std::printf(" %d", (int)t); + std::printf("\nOK\n"); + + free_target_mtp_cache(mtp_cache); + free_target_cache(target_cache); + free_target_weights(w); + ggml_backend_free(backend); + return 0; +} diff --git a/dflash/test/smoke_target_mtp_handoff.cpp b/dflash/test/smoke_target_mtp_handoff.cpp new file mode 100644 index 000000000..81ddfc33b --- /dev/null +++ b/dflash/test/smoke_target_mtp_handoff.cpp @@ -0,0 +1,199 @@ +// Smoke test for the DFlash target -> native MTP handoff. +// +// Builds one graph containing: +// target forward with expose_pre_norm_hidden=true +// MTP/NextN forward fed by that target_pre_norm_hidden tensor +// +// This validates the C++ tensor handoff required by the real speculative loop. 
+// +// Usage: smoke_target_mtp_handoff [cuda_gpu] + +#include "dflash27b.h" +#include "internal.h" + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "ggml-cuda.h" + +#include +#include +#include +#include +#include +#include + +using namespace dflash27b; + +static int check_logits(ggml_tensor * logits, const char * label) { + const int64_t vocab = logits->ne[0]; + std::vector<float> buf((size_t)vocab); + ggml_backend_tensor_get(logits, buf.data(), 0, sizeof(float) * buf.size()); + int n_nan = 0, n_inf = 0; + float vmin = 1e30f, vmax = -1e30f; + for (float v : buf) { + if (std::isnan(v)) n_nan++; + else if (std::isinf(v)) n_inf++; + else { + vmin = std::min(vmin, v); + vmax = std::max(vmax, v); + } + } + std::printf("[%s] vocab=%lld nan=%d inf=%d min=%.4g max=%.4g\n", + label, (long long)vocab, n_nan, n_inf, vmin, vmax); + return (n_nan == 0 && n_inf == 0) ? 0 : 1; +} + +int main(int argc, char ** argv) { + if (argc < 2) { + std::fprintf(stderr, "usage: %s \n", argv[0]); + return 2; + } + + int gpu = 0; + if (argc >= 3) gpu = std::atoi(argv[2]); + + ggml_backend_t backend = ggml_backend_cuda_init(gpu); + if (!backend) { + std::fprintf(stderr, "cuda init failed\n"); + return 1; + } + + TargetWeights w; + if (!load_target_gguf(argv[1], backend, w)) { + std::fprintf(stderr, "load_target_gguf: %s\n", dflash27b_last_error()); + ggml_backend_free(backend); + return 1; + } + std::printf("[target] %s\n", dflash27b_last_error()); + + TargetCache target_cache; + if (!create_target_cache(w, /*max_ctx=*/64, /*max_verify_tokens=*/0, backend, target_cache)) { + std::fprintf(stderr, "create_target_cache: %s\n", dflash27b_last_error()); + free_target_weights(w); + ggml_backend_free(backend); + return 1; + } + TargetMtpCache mtp_cache; + if (!create_target_mtp_cache(w, /*max_ctx=*/64, backend, mtp_cache)) { + std::fprintf(stderr, "create_target_mtp_cache: %s\n", dflash27b_last_error()); + free_target_cache(target_cache); + free_target_weights(w); +
ggml_backend_free(backend); + return 1; + } + + ggml_init_params ip{}; + ip.mem_size = 768 * 1024 * 1024; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + ggml_context * gctx = ggml_init(ip); + if (!gctx) { + std::fprintf(stderr, "ggml_init graph failed\n"); + free_target_mtp_cache(mtp_cache); + free_target_cache(target_cache); + free_target_weights(w); + ggml_backend_free(backend); + return 1; + } + + const int n_tokens = 1; + ggml_tensor * inp_embed = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, w.n_embd, n_tokens, 1); + ggml_tensor * positions = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, 4 * n_tokens); + ggml_set_name(inp_embed, "inp_embed"); + ggml_set_name(positions, "positions"); + ggml_set_input(inp_embed); + ggml_set_input(positions); + + ggml_cgraph * gf = ggml_new_graph_custom(gctx, 8192, false); + QwenGraphInputs target_in{}; + target_in.inp_embed = inp_embed; + target_in.positions = positions; + target_in.n_tokens = n_tokens; + target_in.kv_start = 0; + target_in.expose_pre_norm_hidden = true; + + QwenGraphOutputs target_out = build_qwen35_graph(gctx, gf, w, target_cache, target_in); + if (!target_out.logits || !target_out.pre_norm_hidden) { + std::fprintf(stderr, "build_qwen35_graph did not expose target hidden: %s\n", dflash27b_last_error()); + ggml_free(gctx); + free_target_mtp_cache(mtp_cache); + free_target_cache(target_cache); + free_target_weights(w); + ggml_backend_free(backend); + return 1; + } + + QwenMtpGraphInputs mtp_in{}; + mtp_in.token_embed = ggml_reshape_2d(gctx, inp_embed, w.n_embd, n_tokens); + mtp_in.pre_norm_hidden = target_out.pre_norm_hidden; + mtp_in.positions = positions; + mtp_in.n_tokens = n_tokens; + mtp_in.kv_start = 0; + + QwenMtpGraphOutputs mtp_out = build_qwen35_mtp_graph(gctx, gf, w, mtp_cache, mtp_in); + if (!mtp_out.logits || !mtp_out.hidden) { + std::fprintf(stderr, "build_qwen35_mtp_graph failed: %s\n", dflash27b_last_error()); + ggml_free(gctx); + free_target_mtp_cache(mtp_cache); + free_target_cache(target_cache); + 
free_target_weights(w); + ggml_backend_free(backend); + return 1; + } + ggml_set_output(target_out.logits); + ggml_set_output(mtp_out.logits); + ggml_build_forward_expand(gf, mtp_out.logits); + std::printf("[graph] nodes=%d\n", ggml_graph_n_nodes(gf)); + + ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!ggml_gallocr_alloc_graph(alloc, gf)) { + std::fprintf(stderr, "ggml_gallocr_alloc_graph failed\n"); + ggml_free(gctx); + free_target_mtp_cache(mtp_cache); + free_target_cache(target_cache); + free_target_weights(w); + ggml_backend_free(backend); + return 1; + } + + int32_t tok_ids[1] = { 1 }; + std::vector<float> embed_buf((size_t)w.n_embd * n_tokens); + if (!w.embedder.embed(tok_ids, n_tokens, embed_buf.data())) { + std::fprintf(stderr, "cpu embedder failed\n"); + ggml_gallocr_free(alloc); + ggml_free(gctx); + free_target_mtp_cache(mtp_cache); + free_target_cache(target_cache); + free_target_weights(w); + ggml_backend_free(backend); + return 1; + } + int32_t pos4[4] = { 0, 0, 0, 0 }; + ggml_backend_tensor_set(inp_embed, embed_buf.data(), 0, sizeof(float) * embed_buf.size()); + ggml_backend_tensor_set(positions, pos4, 0, sizeof(pos4)); + + auto status = ggml_backend_graph_compute(backend, gf); + if (status != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "compute failed: %d\n", (int)status); + ggml_gallocr_free(alloc); + ggml_free(gctx); + free_target_mtp_cache(mtp_cache); + free_target_cache(target_cache); + free_target_weights(w); + ggml_backend_free(backend); + return 1; + } + + const int bad_target = check_logits(target_out.logits, "target-logits"); + const int bad_mtp = check_logits(mtp_out.logits, "mtp-logits"); + + ggml_gallocr_free(alloc); + ggml_free(gctx); + free_target_mtp_cache(mtp_cache); + free_target_cache(target_cache); + free_target_weights(w); + ggml_backend_free(backend); + std::printf("%s\n", (bad_target || bad_mtp) ? "FAIL" : "OK"); + return (bad_target || bad_mtp) ?
1 : 0; +} diff --git a/dflash/test/test_mtp_graph_contract.cpp b/dflash/test/test_mtp_graph_contract.cpp new file mode 100644 index 000000000..b67b8c831 --- /dev/null +++ b/dflash/test/test_mtp_graph_contract.cpp @@ -0,0 +1,108 @@ +#include "internal.h" + +#include "ggml.h" + +#include + +using namespace dflash27b; + +static ggml_tensor * tensor_1d(ggml_context * ctx, int n0, const char * name) { + ggml_tensor * t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n0); + ggml_set_name(t, name); + return t; +} + +static ggml_tensor * tensor_2d(ggml_context * ctx, int n0, int n1, const char * name) { + ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n0, n1); + ggml_set_name(t, name); + return t; +} + +int main() { + ggml_init_params ip{}; + ip.mem_size = 16 * 1024 * 1024; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + + ggml_context * ctx = ggml_init(ip); + if (!ctx) { + std::fprintf(stderr, "ggml_init failed\n"); + return 1; + } + + TargetWeights w{}; + w.n_embd = 8; + w.n_ff = 16; + w.n_head = 2; + w.n_head_kv = 1; + w.n_embd_head_k = 4; + w.n_embd_head_v = 4; + w.rope_sections[0] = 1; + w.rope_sections[1] = 1; + w.rope_sections[2] = 0; + w.rope_sections[3] = 0; + w.out_norm = tensor_1d(ctx, w.n_embd, "output_norm.weight"); + w.output = tensor_2d(ctx, w.n_embd, 32, "output.weight"); + w.mtp_layers.assign(1, TargetMtpLayer{}); + + TargetMtpLayer & M = w.mtp_layers[0]; + TargetLayer & L = M.block; + M.gguf_layer_index = 64; + M.nextn.eh_proj = tensor_2d(ctx, 2 * w.n_embd, w.n_embd, "blk.64.nextn.eh_proj.weight"); + M.nextn.enorm = tensor_1d(ctx, w.n_embd, "blk.64.nextn.enorm.weight"); + M.nextn.hnorm = tensor_1d(ctx, w.n_embd, "blk.64.nextn.hnorm.weight"); + M.nextn.shared_head_norm = tensor_1d(ctx, w.n_embd, "blk.64.nextn.shared_head_norm.weight"); + + const int q_dim = w.n_head * w.n_embd_head_k; + const int kv_dim = w.n_head_kv * w.n_embd_head_k; + L.attn_norm = tensor_1d(ctx, w.n_embd, "blk.64.attn_norm.weight"); + L.attn_post_norm = tensor_1d(ctx, w.n_embd, 
"blk.64.post_attention_norm.weight"); + L.wq = tensor_2d(ctx, w.n_embd, 2 * q_dim, "blk.64.attn_q.weight"); + L.wk = tensor_2d(ctx, w.n_embd, kv_dim, "blk.64.attn_k.weight"); + L.wv = tensor_2d(ctx, w.n_embd, kv_dim, "blk.64.attn_v.weight"); + L.wo = tensor_2d(ctx, q_dim, w.n_embd, "blk.64.attn_output.weight"); + L.q_norm = tensor_1d(ctx, w.n_embd_head_k, "blk.64.attn_q_norm.weight"); + L.k_norm = tensor_1d(ctx, w.n_embd_head_k, "blk.64.attn_k_norm.weight"); + L.w_gate = tensor_2d(ctx, w.n_embd, w.n_ff, "blk.64.ffn_gate.weight"); + L.w_up = tensor_2d(ctx, w.n_embd, w.n_ff, "blk.64.ffn_up.weight"); + L.w_down = tensor_2d(ctx, w.n_ff, w.n_embd, "blk.64.ffn_down.weight"); + + TargetMtpCache cache{}; + cache.max_ctx = 8; + cache.kv_k_type = GGML_TYPE_F16; + cache.kv_v_type = GGML_TYPE_F16; + cache.attn_k.push_back(ggml_new_tensor_3d(ctx, GGML_TYPE_F16, w.n_embd_head_k, cache.max_ctx, w.n_head_kv)); + cache.attn_v.push_back(ggml_new_tensor_3d(ctx, GGML_TYPE_F16, w.n_embd_head_k, cache.max_ctx, w.n_head_kv)); + + const int n_tokens = 1; + ggml_tensor * token_embed = tensor_2d(ctx, w.n_embd, n_tokens, "mtp_token_embed"); + ggml_tensor * hidden = tensor_2d(ctx, w.n_embd, n_tokens, "target_pre_norm_hidden"); + ggml_tensor * positions = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4 * n_tokens); + ggml_set_name(positions, "positions"); + ggml_set_input(token_embed); + ggml_set_input(hidden); + ggml_set_input(positions); + + ggml_cgraph * gf = ggml_new_graph_custom(ctx, 1024, false); + QwenMtpGraphInputs in{}; + in.token_embed = token_embed; + in.pre_norm_hidden = hidden; + in.positions = positions; + in.n_tokens = n_tokens; + in.kv_start = 0; + + QwenMtpGraphOutputs out = build_qwen35_mtp_graph(ctx, gf, w, cache, in); + if (!out.logits || !out.hidden) { + std::fprintf(stderr, "build_qwen35_mtp_graph failed: %s\n", dflash27b_last_error()); + ggml_free(ctx); + return 1; + } + + std::printf("[mtp-graph-contract] nodes=%d logits=[%lld,%lld] hidden=[%lld,%lld]\n", + 
ggml_graph_n_nodes(gf), + (long long)out.logits->ne[0], (long long)out.logits->ne[1], + (long long)out.hidden->ne[0], (long long)out.hidden->ne[1]); + + ggml_free(ctx); + return 0; +} From 261811db9f4dd51d5cf385e261c7fd6abeb0436c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Javier=20Paz=C3=B3?= Date: Mon, 11 May 2026 15:55:41 +0200 Subject: [PATCH 2/4] feat(dflash): linear native MTP integrated decode CLI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the speculative decode loop that drives native MTP / NextN runtime from PR #153. This is the "linear" path — a trunk target forward + one NextN forward per accepted token, with optional MTP chain proposals on acceptance. DDTree/DFlash hybrid + target-batched verify + fast rollback (the chain-2 / tree-fused / immediate-bonus combination that produces the +36% n_gen=256 number in the PR #153 description) land in a follow-up. C++ runtime (dflash/test/test_dflash.cpp) - 14 MTP / DFlash-MTP CLI flag globals + 6 StepGraph fields (pre_norm_hidden, mtp_logits, mtp_hidden, mtp_argmax_tokens, mtp_chain_argmax_tokens, token_ids). - MtpTimings + ScopedTimer instrumentation (no-op when --dflash-mtp-timing / DFLASH_MTP_TIMING=1 are off; stdout is byte-identical to the un-instrumented build). - mtp_integrated_step: one ggml_cgraph driving trunk + NextN in the same forward, exposing pre_norm_hidden on-device; optional post-MTP hidden capture gated on acceptance. - mtp_only_step (2 overloads): MTP-only follow-up step with caller- supplied hidden, plus a one-shot convenience wrapper. - mtp_chain_gpu_step + mtp_chain_gpu_available + mtp_gpu_get_rows_supported: unrolls N MTP proposals inside one cgraph via ggml_get_rows when the trunk's tok_embd is GPU-resident AND non-K-quant. K-quant tok_embd falls back to mtp_only_step (CPU embed bridge). - target_only_argmax_step + run_target_ar_prompt: greedy AR baseline used as the parity reference for the integrated loop. 
- run_mtp_integrated_prompt: full prefill + decode loop. After the first
  MTP proposal lands, attempts to chain (mtp_draft_n_max - 1) more
  proposals; each chained proposal is verified one-by-one against a
  per-token target forward. Serial verify only; batched verify ships in
  the follow-up PR.
- run_mtp_baseline_check + run_mtp_integrated_cli: CLI entry points
  emitting parseable `[mtp-baseline] baseline ...`,
  `[mtp-baseline] integrated ...` and `compare_ok`/`compare_fail` lines.

CLI surface added to test_dflash main()
  --mtp-integrated            run the integrated MTP decode loop
  --mtp-baseline-check        also run the AR baseline + compare
  --mtp-draft-n-max=N         MTP chain depth (default 4)
  --decode-pos-offset=N       start KV at offset (sparse-cache gate)
  --mtp-step-log              per-step decision log
  --mtp-no-fast-rollback / --mtp-moe-long-fast-rollback /
  --mtp-serial-commit / --mtp-no-gpu-chain
  --dflash-mtp-timing         enables the MtpTimings summary line
  --dflash-mtp-hybrid / --dflash-mtp-tree-fused /
  --dflash-mtp-seed-priority / --dflash-mtp-immediate-bonus /
  --dflash-mtp-chain-max=N / --dflash-mtp-bonus-min-margin=F
                              (reserved; parsed here for a stable CLI
                              surface; the wiring that consumes them ships
                              in the DFlash/MTP hybrid follow-up PR)

Invocation example:
  test_dflash \
    --mtp-integrated --mtp-baseline-check --mtp-draft-n-max=2 \
    --max-ctx=4096 --n-gen=128 --prompt-file=prompt.bin --target-gpu=1

Drive-by fix: ggml_get_to_fp32_cuda link error
The test_dflash binary was failing to link on Windows because ggml-cuda
exports ggml_get_to_fp32_cuda internally only (the symbol is hidden in
the .dll). Replace the four call sites with adapters around our own
dflash27b_launch_bf16_to_f32 / _f16_to_f32 kernels (already shipped by
src/f16_convert.cu in PR #153). Equivalent behavior, no dependency on a
ggml-cuda internal symbol.
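For readers unfamiliar with the conversion the adapters wrap: bf16 keeps the upper 16 bits of an IEEE-754 binary32 value, so widening to f32 is a 16-bit shift. The snippet below is only a host-side reference sketch of that arithmetic. The names `bf16_to_f32` and `widen_bf16_to_f32` are hypothetical and are not the dflash27b_launch_* CUDA kernels or their signatures.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical host-side reference, NOT the kernel in src/f16_convert.cu.
// A bf16 value is the high half of a binary32 bit pattern, so widening
// places it in the top 16 bits and zero-fills the low mantissa bits.
inline float bf16_to_f32(uint16_t h) {
    const uint32_t bits = static_cast<uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));  // bit-exact reinterpretation
    return f;
}

// Widen n bf16 values from src into dst (illustrative helper).
inline void widen_bf16_to_f32(const uint16_t * src, float * dst, size_t n) {
    assert(n == 0 || (src != nullptr && dst != nullptr));
    for (size_t i = 0; i < n; ++i) {
        dst[i] = bf16_to_f32(src[i]);
    }
}
```

A reference like this can serve as a CPU oracle when checking a device-side widen kernel on a small buffer; f16 needs a different path because its exponent width differs from f32.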
Python tooling (dflash/scripts/mtp_baseline_gate.py) Real-model parity gate that invokes the new --mtp-baseline-check mode, parses the [mtp-baseline] lines, and either prints PASS/FAIL or writes a JSON artifact. Supports --prompt-text + HuggingFace AutoTokenizer or --prompt-file . Used to defend the claim that integrated MTP decode produces token-identical output to the AR baseline at greedy. Validation (RTX 6000 Ada sm_89, Qwen3.6-27B-MTP Q4_K_M) - Build: cmake --build dflash/build --target test_dflash → PASS - 16-token smoke (synthetic prompt) via test_dflash directly: [mtp-integrated] generated=16 draft_n=16 accepted=16 corrected=0 acceptance=100.0% tok/s=10.81 draft_n_max=2 - Baseline gate (mtp_baseline_gate.py, 16 tok, max_ctx=128): [mtp-baseline] baseline tok/s=8.31 [mtp-baseline] integrated tok/s=6.29 speed_ratio=0.756x [mtp-baseline] compare_ok tokens=16 mismatches=0 → parity PASS; speed regresses on this very-short workload, as expected (the +36% n_gen=256 win needs the hybrid follow-up). Honest scope - Linear MTP only. The +36% number quoted in PR #153 is bounded by the DDTree + target-batched verify + immediate-bonus combination that ships in PR #B; on this PR the headline is "feature complete and parity-correct", not "always faster". Short-generation regressions are expected and documented in PR #153's `dflash/docs/MTP_2026-05-11.md`. - Daemon-mode --mtp-integrated wiring stays out of this PR (would pull in the native scheduler from PR #135). CLI mode covers the parity gate use-case in full. - MoE MTP is gated by upstream PR #120 (Qwen3.5 MoE support), same as PR #153. Stacked on #153 (native MTP runtime + contract test). Not mergeable until #153 lands. 
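Illustration (not part of the patch): a minimal sketch of the order in which the gate script appears to evaluate a run, assuming the checks are parity first, then exit code, then speed ratio, with --min-speed-ratio defaulting to 1.0. gate_decision is a hypothetical helper written for this note; the real logic lives in main() of mtp_baseline_gate.py. Under that reading, the 16-token run above (speed_ratio=0.756x) would pass parity but trip the default speed threshold, so short workloads likely need a lower --min-speed-ratio.

```python
from typing import Tuple


def gate_decision(compare_ok: bool, exit_code: int, speed_ratio: float,
                  min_speed_ratio: float = 1.0,
                  require_compare_ok: bool = True) -> Tuple[int, str]:
    """Return (process exit status, verdict) for one gate run."""
    # Parity is checked first: a token mismatch fails regardless of speed.
    if require_compare_ok and not compare_ok:
        return 1, "FAIL parity"
    # A non-zero exit code from the binary is propagated as-is.
    if exit_code != 0:
        return exit_code, "FAIL exit_code"
    # Speed is checked last, against the configured minimum ratio.
    if speed_ratio < min_speed_ratio:
        return 1, "FAIL speed"
    return 0, "PASS"


if __name__ == "__main__":
    print(gate_decision(True, 0, 0.756))
    print(gate_decision(True, 0, 0.756, min_speed_ratio=0.7))
```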
--- dflash/scripts/mtp_baseline_gate.py | 264 ++++++ dflash/test/test_dflash.cpp | 1224 ++++++++++++++++++++++++++- 2 files changed, 1480 insertions(+), 8 deletions(-) create mode 100644 dflash/scripts/mtp_baseline_gate.py diff --git a/dflash/scripts/mtp_baseline_gate.py b/dflash/scripts/mtp_baseline_gate.py new file mode 100644 index 000000000..ebef594bc --- /dev/null +++ b/dflash/scripts/mtp_baseline_gate.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python3 +"""Parity gate for native MTP integrated decode. + +Shells out to `test_dflash --mtp-baseline-check`, parses the three lines it +emits, and either prints a one-line PASS/FAIL summary or writes a JSON +artifact. Used to defend the claim that the integrated MTP decode loop +produces token-identical output to the AR baseline at the same temperature. + +This is intentionally a real-model gate (loads the GGUF, runs both decode +modes end-to-end), not a unit test. The C++ side prints these lines that +this script parses: + + [mtp-baseline] baseline generated=N tok/s=X seconds=Y + [mtp-baseline] integrated generated=N draft_n=D accepted=A corrected=C \\ + acceptance=P% tok/s=X seconds=Y draft_n_max=M speed_ratio=Rx + [mtp-baseline] compare_ok tokens=N mismatches=0 + (or) + [mtp-baseline] compare_fail mismatch=I baseline=tok integrated=tok ... 
+""" + +from __future__ import annotations + +import argparse +import json +import os +import re +import struct +import subprocess +import sys +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + + +BASELINE_RE = re.compile( + r"^\[mtp-baseline\]\s+baseline\s+generated=(?P<generated>\d+)\s+" + r"tok/s=(?P<tps>[0-9.]+)\s+seconds=(?P<seconds>[0-9.]+)", + re.MULTILINE, +) +INTEGRATED_RE = re.compile( + r"^\[mtp-baseline\]\s+integrated\s+generated=(?P<generated>\d+)\s+" + r"draft_n=(?P<draft_n>\d+)\s+accepted=(?P<accepted>\d+)\s+" + r"corrected=(?P<corrected>\d+)\s+acceptance=(?P<acceptance>[0-9.]+)%\s+" + r"tok/s=(?P<tps>[0-9.]+)\s+seconds=(?P<seconds>[0-9.]+)\s+" + r"draft_n_max=(?P<draft_n_max>\d+)\s+speed_ratio=(?P<speed_ratio>[0-9.]+)x", + re.MULTILINE, +) +COMPARE_OK_RE = re.compile( + r"^\[mtp-baseline\]\s+compare_ok\s+tokens=(?P<tokens>\d+)\s+mismatches=0", + re.MULTILINE, +) +COMPARE_FAIL_RE = re.compile( + r"^\[mtp-baseline\]\s+compare_fail\s+mismatch=(?P<mismatch>\d+)\s+" + r"baseline=(?P<baseline_tok>-?\d+)\s+integrated=(?P<integrated_tok>-?\d+)", + re.MULTILINE, +) +TARGET_RE = re.compile( + r"^\[target\]\s+target loaded:.*trunk_layers=(?P<trunk_layers>\d+)\s+nextn=(?P<nextn>\d+)", + re.MULTILINE, +) + + +def _default_binary() -> Path: + root = Path(__file__).resolve().parents[1] + if os.name == "nt": + cand = root / "build" / "test_dflash.exe" + if cand.exists(): + return cand + return root / "build" / "Release" / "test_dflash.exe" + return root / "build" / "test_dflash" + + +def _write_prompt_ids(out: Path, ids: list[int]) -> None: + out.parent.mkdir(parents=True, exist_ok=True) + out.write_bytes(b"".join(struct.pack("<i", i) for i in ids)) + + +def _tokenize_via_hf(text: str, tokenizer_id: str) -> list[int]: + """Tokenize using HuggingFace AutoTokenizer if available.""" + try: + from transformers import AutoTokenizer # type: ignore + except ImportError: + sys.stderr.write( + "mtp_baseline_gate: transformers not installed; " + "pass --prompt-file= instead of --prompt-text\n" + ) + sys.exit(2) + tok = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True) + return tok.encode(text, add_special_tokens=False) + + +def _parse_run(stdout: str, stderr: str) -> dict[str,
Any]: + result: dict[str, Any] = {"raw_stdout_lines": stdout.count("\n")} + m = TARGET_RE.search(stdout) + if m: + result["target"] = { + "trunk_layers": int(m.group("trunk_layers")), + "nextn": int(m.group("nextn")), + } + m = BASELINE_RE.search(stdout) + if m: + result["baseline"] = { + "generated": int(m.group("generated")), + "tps": float(m.group("tps")), + "seconds": float(m.group("seconds")), + } + m = INTEGRATED_RE.search(stdout) + if m: + result["integrated"] = { + "generated": int(m.group("generated")), + "draft_n": int(m.group("draft_n")), + "accepted": int(m.group("accepted")), + "corrected": int(m.group("corrected")), + "acceptance_pct": float(m.group("acceptance")), + "tps": float(m.group("tps")), + "seconds": float(m.group("seconds")), + "draft_n_max": int(m.group("draft_n_max")), + "speed_ratio": float(m.group("speed_ratio")), + } + if COMPARE_OK_RE.search(stdout): + result["compare"] = {"ok": True} + m = COMPARE_FAIL_RE.search(stderr) or COMPARE_FAIL_RE.search(stdout) + if m: + result["compare"] = { + "ok": False, + "mismatch": int(m.group("mismatch")), + "baseline_tok": int(m.group("baseline_tok")), + "integrated_tok": int(m.group("integrated_tok")), + } + return result + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--target", required=True, type=Path, + help="Path to the Qwen3.5/3.6-MTP GGUF (am17an-style).") + parser.add_argument("--prompt-text", default=None, + help="Inline prompt text (requires --tokenizer).") + parser.add_argument("--prompt-file", default=None, type=Path, + help="Pre-tokenized prompt file (.bin of i32-le ids).") + parser.add_argument("--tokenizer", default="Qwen/Qwen3.6-27B", + help="HuggingFace tokenizer id used when --prompt-text is given.") + parser.add_argument("--n-gen", type=int, default=128) + parser.add_argument("--mtp-draft-n-max", type=int, default=4) + parser.add_argument("--max-ctx", type=int, default=4096) + parser.add_argument("--decode-pos-offset", 
type=int, default=0) + parser.add_argument("--fa-window", type=int, default=0, + help="DFLASH27B_FA_WINDOW env override; 0 = full attention.") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--binary", default=None, type=Path, + help="Path to test_dflash binary; defaults to build/test_dflash.") + parser.add_argument("--min-speed-ratio", type=float, default=1.0, + help="Fail if integrated tok/s / baseline tok/s falls below this.") + parser.add_argument("--require-compare-ok", action="store_true", default=True) + parser.add_argument("--allow-mismatch", dest="require_compare_ok", + action="store_false") + parser.add_argument("--out", default=None, type=Path, + help="Write JSON artifact to this path.") + parser.add_argument("--step-log", action="store_true") + parser.add_argument("--timeout-s", type=int, default=1800) + args = parser.parse_args() + + if not args.target.exists(): + print(f"target GGUF not found: {args.target}", file=sys.stderr) + return 2 + binary = args.binary or _default_binary() + if not binary.exists(): + print(f"test_dflash binary not found: {binary}", file=sys.stderr) + return 2 + + # Resolve prompt to a .bin path the binary can consume. Always use an + # absolute path so the binary's cwd doesn't matter. 
+ workdir = Path(args.target).resolve().parent if args.target else Path(".").resolve() + if args.prompt_file is None: + if not args.prompt_text: + print("provide --prompt-text or --prompt-file", file=sys.stderr) + return 2 + ids = _tokenize_via_hf(args.prompt_text, args.tokenizer) + if not ids: + print("tokenizer returned empty prompt", file=sys.stderr) + return 2 + prompt_path = (workdir / "_mtp_baseline_gate_prompt.bin").resolve() + _write_prompt_ids(prompt_path, ids) + else: + prompt_path = Path(args.prompt_file).resolve() + if not prompt_path.exists(): + print(f"prompt file not found: {prompt_path}", file=sys.stderr) + return 2 + + env = os.environ.copy() + if args.fa_window > 0: + env["DFLASH27B_FA_WINDOW"] = str(args.fa_window) + if "CUDA_VISIBLE_DEVICES" not in env and args.gpu >= 0: + env["CUDA_VISIBLE_DEVICES"] = str(args.gpu) + + cmd = [ + str(binary), str(args.target), + "--mtp-integrated", "--mtp-baseline-check", + f"--mtp-draft-n-max={args.mtp_draft_n_max}", + f"--max-ctx={args.max_ctx}", + f"--decode-pos-offset={args.decode_pos_offset}", + f"--prompt-file={prompt_path}", + f"--n-gen={args.n_gen}", + ] + if args.step_log: + cmd.append("--mtp-step-log") + + print(f"[gate] running: {' '.join(cmd)}", flush=True) + t0 = datetime.now(timezone.utc) + proc = subprocess.run( + cmd, + env=env, + cwd=workdir, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=args.timeout_s, + ) + t1 = datetime.now(timezone.utc) + sys.stdout.write(proc.stdout) + sys.stderr.write(proc.stderr) + + parsed = _parse_run(proc.stdout, proc.stderr) + parsed["cmd"] = cmd + parsed["exit_code"] = proc.returncode + parsed["wall_seconds"] = (t1 - t0).total_seconds() + parsed["target_path"] = str(args.target) + parsed["n_gen"] = args.n_gen + parsed["mtp_draft_n_max"] = args.mtp_draft_n_max + parsed["fa_window"] = args.fa_window + parsed["decode_pos_offset"] = args.decode_pos_offset + parsed["min_speed_ratio"] = args.min_speed_ratio + + if args.out: + 
args.out.parent.mkdir(parents=True, exist_ok=True) + args.out.write_text(json.dumps(parsed, indent=2)) + print(f"[gate] wrote {args.out}") + + # Gate decisions. + ok_exit = proc.returncode == 0 + compare = parsed.get("compare") or {} + compare_ok = bool(compare.get("ok")) + integrated = parsed.get("integrated") or {} + ratio = float(integrated.get("speed_ratio", 0.0)) + speed_ok = ratio >= args.min_speed_ratio + + if args.require_compare_ok and not compare_ok: + print(f"[gate] FAIL parity (compare={compare})", file=sys.stderr) + return 1 + if not ok_exit: + print(f"[gate] FAIL exit_code={proc.returncode}", file=sys.stderr) + return proc.returncode if proc.returncode != 0 else 1 + if not speed_ok: + print(f"[gate] FAIL speed: ratio={ratio:.3f}x < min={args.min_speed_ratio:.3f}x", + file=sys.stderr) + return 1 + + print(f"[gate] PASS compare_ok=True speed_ratio={ratio:.3f}x " + f"(min={args.min_speed_ratio:.3f}x)") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dflash/test/test_dflash.cpp b/dflash/test/test_dflash.cpp index 3368321ff..2bf7ebeec 100644 --- a/dflash/test/test_dflash.cpp +++ b/dflash/test/test_dflash.cpp @@ -38,10 +38,37 @@ #include -// ggml-cuda dequantize: Q8_0/F16/BF16 → F32. Replaces the custom -// f16_convert.cu kernels with ggml's built-in converter dispatch. +// Half-precision → f32 widen kernel launchers (src/f16_convert.cu). +// We used to declare ggml_get_to_fp32_cuda here, but that symbol is hidden +// inside the ggml-cuda.dll's internal namespace on Windows, which broke the +// test_dflash link (LNK2019). dflash27b ships its own tiny widen kernels +// instead — equivalent for the BF16 / F16 cases test_dflash uses. 
+extern "C" void dflash27b_launch_f16_to_f32(const void * src, + void * dst, + size_t n_elems, + cudaStream_t stream); +extern "C" void dflash27b_launch_bf16_to_f32(const void * src, + void * dst, + size_t n_elems, + cudaStream_t stream); + +// Adapter matching the previous `to_fp32_cuda_t = void (*)(const void *, +// float *, int64_t, cudaStream_t)` signature so call sites that bind the +// function pointer keep compiling. Dispatch on ggml_type. using to_fp32_cuda_t = void (*)(const void *, float *, int64_t, cudaStream_t); -to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type); +static void dflash27b_to_fp32_bf16(const void * src, float * dst, + int64_t n_elems, cudaStream_t stream) { + dflash27b_launch_bf16_to_f32(src, dst, (size_t)n_elems, stream); +} +static void dflash27b_to_fp32_f16(const void * src, float * dst, + int64_t n_elems, cudaStream_t stream) { + dflash27b_launch_f16_to_f32(src, dst, (size_t)n_elems, stream); +} +static inline to_fp32_cuda_t dflash27b_get_to_fp32_cuda(ggml_type type) { + if (type == GGML_TYPE_BF16) return &dflash27b_to_fp32_bf16; + if (type == GGML_TYPE_F16) return &dflash27b_to_fp32_f16; + return nullptr; +} #include #include @@ -156,6 +183,95 @@ static int g_max_ctx_override = 0; // overridden by --max-ctx=N (defau static int g_fa_window = 2048; // overridden by DFLASH27B_FA_WINDOW=N static int g_draft_swa_window = 0; // draft SWA window (0 = disabled); --draft-swa=N static int g_draft_ctx_max = 4096; // draft context cap; --draft-ctx-max=N + +// ─── Native MTP (NextN) runtime flags ───────────────────────────── +// Linear MTP knobs (this PR ships the linear path only; DFlash/MTP hybrid +// follows in a separate PR). 
+static bool g_mtp_batch_verify = true; // verify accepted chains as target batches +static bool g_mtp_fast_rollback = true; // skip replay when rollback capture is safe +static bool g_mtp_moe_long_fast_rollback = false; // opt-in until MoE long-window parity is proven +static bool g_mtp_batch_axis3_abs = false; // diagnostic: decode-style M-RoPE axis3 in batch verify +static bool g_mtp_serial_commit = false; // exact serial cache commit after batched target verify +static bool g_mtp_chain_tree_verify = false; // diagnostic: tree-aware DeltaNet/conv for linear verify +static bool g_mtp_gpu_chain = true; // unroll MTP proposals via GPU token get_rows +static bool g_mtp_fused_prefill = false; // experimental: append prompt MTP inside target prefill +// DFlash/MTP hybrid knobs (parsed here so the CLI surface is stable across +// PRs; the wiring that consumes them lands in a follow-up PR). +static bool g_dflash_mtp_hybrid = false; +static bool g_dflash_mtp_tree_fused = false; +static bool g_dflash_mtp_seed_priority = false; +static bool g_dflash_mtp_immediate_bonus = false; +static int g_dflash_mtp_chain_max = 15; +static float g_dflash_mtp_bonus_min_margin = 0.0f; + +// ─── MTP/DDTree timing instrumentation ─────────────────────────── +// Activated via --dflash-mtp-timing or DFLASH_MTP_TIMING=1. When disabled +// (default), ScopedTimer is a no-op and stdout is byte-identical to the +// un-instrumented build. Accumulators are only updated when the flag is on, +// and the [dflash+mtp-timing] summary line is only printed in that case. 
+static bool g_dflash_mtp_timing = false; + +struct MtpTimings { + enum Slot : int { + TARGET_DECODE = 0, + MTP_PROP_FIRST, + MTP_PROP_FOLLOW, + TARGET_VERIFY, + ARGMAX_GET, + HIDDEN_GET, + GRAPH_BUILD_TARGET, + GRAPH_BUILD_TREE, + GRAPH_BUILD_MTP_CHAIN, + GRAPH_COMPUTE, + KV_TRIM, + NUM_SLOTS, + }; + struct Acc { uint64_t ns_total = 0; uint64_t ns_max = 0; uint64_t count = 0; }; + Acc slots[NUM_SLOTS]; + + void add(int s, uint64_t ns) { + if (s < 0 || s >= NUM_SLOTS) return; + Acc & a = slots[s]; + a.ns_total += ns; + a.count += 1; + if (ns > a.ns_max) a.ns_max = ns; + } + static const char * slot_name(int s) { + switch (s) { + case TARGET_DECODE: return "target_decode"; + case MTP_PROP_FIRST: return "mtp_prop_first"; + case MTP_PROP_FOLLOW: return "mtp_prop_follow"; + case TARGET_VERIFY: return "target_verify"; + case ARGMAX_GET: return "argmax_get"; + case HIDDEN_GET: return "hidden_get"; + case GRAPH_BUILD_TARGET: return "graph_build_target"; + case GRAPH_BUILD_TREE: return "graph_build_tree"; + case GRAPH_BUILD_MTP_CHAIN: return "graph_build_mtp_chain"; + case GRAPH_COMPUTE: return "graph_compute"; + case KV_TRIM: return "kv_trim"; + default: return "unknown"; + } + } +}; +static MtpTimings g_mtp_timings; + +struct ScopedTimer { + int slot; + std::chrono::steady_clock::time_point t0; + bool active; + explicit ScopedTimer(int s) : slot(s), active(g_dflash_mtp_timing) { + if (active) t0 = std::chrono::steady_clock::now(); + } + ~ScopedTimer() { + if (!active) return; + auto t1 = std::chrono::steady_clock::now(); + uint64_t ns = (uint64_t)std::chrono::duration_cast<std::chrono::nanoseconds>(t1 - t0).count(); + g_mtp_timings.add(slot, ns); + } + ScopedTimer(const ScopedTimer &) = delete; + ScopedTimer & operator=(const ScopedTimer &) = delete; +}; + static int align_up(int x, int a) { return ((x + a - 1) / a) * a; } // F16 encoding for the two values we use: 0 and -inf.
@@ -509,11 +625,28 @@ struct StepGraph { ggml_tensor * positions_k = nullptr; // draft only ggml_tensor * hidden_input = nullptr; // lm-head projection only + // Optional GPU-side token-id input. When set, the graph performs the + // embedding gather inside ggml (get_rows) instead of taking a pre-built + // inp_embed. Used by the MTP integrated decode loop so the chain of + // proposed tokens never round-trips through the CPU. + ggml_tensor * token_ids = nullptr; + // Output ggml_tensor * logits = nullptr; ggml_tensor * hidden_states = nullptr; // draft hidden-only output + // Final pre-norm hidden of the trunk decoder. Exposed when MTP is wired + // so build_qwen35_mtp_graph can consume it as `t_h_pre_norm` without a + // CPU readback. + ggml_tensor * pre_norm_hidden = nullptr; + // Native MTP / NextN outputs (populated when an MTP block is wired into + // this StepGraph; nullptr in plain target-only decode). + ggml_tensor * mtp_logits = nullptr; + ggml_tensor * mtp_hidden = nullptr; + ggml_tensor * mtp_argmax_tokens = nullptr; ggml_tensor * argmax_tokens = nullptr; // [n_tokens] i32, GPU-side argmax of logits ggml_tensor * topk_indices = nullptr; // [K, n_tokens] i32, GPU-side top-K indices + // Per-step argmax outputs for the GPU-unrolled MTP chain (mtp_chain_gpu_step). + std::vector<ggml_tensor *> mtp_chain_argmax_tokens; // Per-delta-net-layer captures (verify only). One entry per delta-net layer.
// Each entry's tensors are graph views on the gated_delta_net result: @@ -733,7 +866,7 @@ static bool draft_feature_mirror_sync_range(const TargetCache & cache, (const char *)cache.target_feat->data + (size_t)src_slot * src_stride; void * dst = (char *)mirror.target_feat->data + (size_t)dst_slot * dst_stride; - auto bf16_to_f32 = ggml_get_to_fp32_cuda(GGML_TYPE_BF16); + auto bf16_to_f32 = dflash27b_get_to_fp32_cuda(GGML_TYPE_BF16); if (mirror.device == mirror.target_device) { cudaSetDevice(mirror.device); bf16_to_f32(src, (float *)dst, (int64_t)elems, nullptr); @@ -779,10 +912,16 @@ static void step_graph_free(StepGraph & sg) { sg.target_hidden_cat = sg.positions_k = nullptr; sg.hidden_input = nullptr; sg.parent_ids = nullptr; + sg.token_ids = nullptr; sg.logits = nullptr; sg.hidden_states = nullptr; + sg.pre_norm_hidden = nullptr; + sg.mtp_logits = nullptr; + sg.mtp_hidden = nullptr; + sg.mtp_argmax_tokens = nullptr; sg.argmax_tokens = nullptr; sg.topk_indices = nullptr; + sg.mtp_chain_argmax_tokens.clear(); sg.delta_captures.clear(); } @@ -2295,6 +2434,933 @@ static int run_target_layer_split_harness( return 0; } +// ─── Native MTP / NextN integrated decode mode ────────────────── +// +// Single-token forward that drives target + MTP in the SAME ggml_cgraph: +// +// 1. trunk target forward consumes the committed token and exposes +// `pre_norm_hidden` (no CPU readback). +// 2. native NextN/MTP forward consumes that pre-norm hidden + the same +// token embedding to draft a candidate next token. +// 3. greedy argmax on both heads is computed device-side. +// 4. caller decides accept / correct via target_next vs mtp_next. +// +// The optional `mtp_hidden_out` lets the caller capture the post-MTP hidden +// for chain follow-up via mtp_only_step (so a 2-token MTP chain doesn't +// repeat the target forward). Hidden capture is gated on acceptance so the +// readback only happens when chain continuation is useful. 
+static bool mtp_integrated_step(const TargetWeights & w, + TargetCache & target_cache, + TargetMtpCache & mtp_cache, + ggml_backend_t backend, + StepGraph & sg, + int32_t token, + int kv_start, + int32_t & target_next, + int32_t & mtp_next, + std::vector<float> * mtp_hidden_out = nullptr) { + step_graph_free(sg); + + ggml_init_params ip{}; + ip.mem_size = 768 * 1024 * 1024; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + sg.ctx = ggml_init(ip); + if (!sg.ctx) { + std::fprintf(stderr, "mtp integrated: ggml_init failed\n"); + return false; + } + + sg.inp_embed = ggml_new_tensor_3d(sg.ctx, GGML_TYPE_F32, w.n_embd, 1, 1); + sg.positions = ggml_new_tensor_1d(sg.ctx, GGML_TYPE_I32, 4); + ggml_set_name(sg.inp_embed, "mtp_integrated_inp_embed"); + ggml_set_name(sg.positions, "mtp_integrated_positions"); + ggml_set_input(sg.inp_embed); + ggml_set_input(sg.positions); + + sg.gf = ggml_new_graph_custom(sg.ctx, 8192, false); + + QwenGraphInputs target_in{}; + target_in.inp_embed = sg.inp_embed; + target_in.positions = sg.positions; + target_in.n_tokens = 1; + target_in.kv_start = kv_start; + target_in.expose_pre_norm_hidden = true; + target_in.fa_window = g_fa_window; + + QwenGraphOutputs target_out = build_qwen35_graph(sg.ctx, sg.gf, w, target_cache, target_in); + if (!target_out.logits || !target_out.pre_norm_hidden) { + std::fprintf(stderr, "mtp integrated: target graph failed: %s\n", dflash27b_last_error()); + return false; + } + sg.logits = target_out.logits; + sg.pre_norm_hidden = target_out.pre_norm_hidden; + + QwenMtpGraphInputs mtp_in{}; + mtp_in.token_embed = ggml_reshape_2d(sg.ctx, sg.inp_embed, w.n_embd, 1); + mtp_in.pre_norm_hidden = target_out.pre_norm_hidden; + mtp_in.positions = sg.positions; + mtp_in.n_tokens = 1; + mtp_in.kv_start = kv_start; + mtp_in.fa_window = g_fa_window; + + QwenMtpGraphOutputs mtp_out = build_qwen35_mtp_graph(sg.ctx, sg.gf, w, mtp_cache, mtp_in); + if (!mtp_out.logits || !mtp_out.hidden) { + std::fprintf(stderr, "mtp integrated: MTP graph
failed: %s\n", dflash27b_last_error()); + return false; + } + sg.mtp_logits = mtp_out.logits; + sg.mtp_hidden = mtp_out.hidden; + + sg.argmax_tokens = ggml_argmax(sg.ctx, target_out.logits); + ggml_set_name(sg.argmax_tokens, "mtp_integrated_target_argmax"); + ggml_set_output(sg.argmax_tokens); + ggml_build_forward_expand(sg.gf, sg.argmax_tokens); + + sg.mtp_argmax_tokens = ggml_argmax(sg.ctx, mtp_out.logits); + ggml_set_name(sg.mtp_argmax_tokens, "mtp_integrated_mtp_argmax"); + ggml_set_output(sg.mtp_argmax_tokens); + ggml_build_forward_expand(sg.gf, sg.mtp_argmax_tokens); + + if (!sg.alloc) { + sg.alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + } + if (!ggml_gallocr_alloc_graph(sg.alloc, sg.gf)) { + std::fprintf(stderr, "mtp integrated: graph alloc failed\n"); + return false; + } + + std::vector<float> embed((size_t)w.n_embd); + if (!w.embedder.embed(&token, 1, embed.data())) { + std::fprintf(stderr, "mtp integrated: CPU embed failed for token %d\n", (int)token); + return false; + } + const int32_t pos4[4] = { kv_start, kv_start, kv_start, kv_start }; + ggml_backend_tensor_set(sg.inp_embed, embed.data(), 0, sizeof(float) * embed.size()); + ggml_backend_tensor_set(sg.positions, pos4, 0, sizeof(pos4)); + + ggml_status st; + { ScopedTimer _tc(MtpTimings::GRAPH_COMPUTE); + st = ggml_backend_graph_compute(backend, sg.gf); + } + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "mtp integrated: graph compute failed: %d\n", (int)st); + return false; + } + + { ScopedTimer _tg(MtpTimings::ARGMAX_GET); + ggml_backend_tensor_get(sg.argmax_tokens, &target_next, 0, sizeof(target_next)); + ggml_backend_tensor_get(sg.mtp_argmax_tokens, &mtp_next, 0, sizeof(mtp_next)); + } + if (mtp_hidden_out && mtp_next == target_next) { + ScopedTimer _th(MtpTimings::HIDDEN_GET); + mtp_hidden_out->assign((size_t)w.n_embd, 0.0f); + ggml_backend_tensor_get(sg.mtp_hidden, mtp_hidden_out->data(), 0, + sizeof(float) * (size_t)w.n_embd); + } else if (mtp_hidden_out) { +
mtp_hidden_out->clear(); + } + + return true; +} + +// MTP-only follow-up step. Given the post-MTP hidden of the previous +// proposal and the newly-committed token, draft one more candidate without +// re-running the trunk target. Used to chain MTP proposals after the first +// integrated step accepted, so a `chain_max=2` config commits two tokens +// for one target forward when both MTP picks match what target would have +// chosen. +static bool mtp_only_step(const TargetWeights & w, + TargetMtpCache & mtp_cache, + ggml_backend_t backend, + StepGraph & sg, + int32_t token, + int kv_start, + const std::vector<float> & hidden_in, + int32_t & mtp_next, + std::vector<float> * hidden_out) { + if ((int)hidden_in.size() != w.n_embd) { + std::fprintf(stderr, "mtp only: hidden size=%zu expected=%d\n", + hidden_in.size(), w.n_embd); + return false; + } + + step_graph_free(sg); + + ggml_init_params ip{}; + ip.mem_size = 256 * 1024 * 1024; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + sg.ctx = ggml_init(ip); + if (!sg.ctx) { + std::fprintf(stderr, "mtp only: ggml_init failed\n"); + return false; + } + + sg.inp_embed = ggml_new_tensor_2d(sg.ctx, GGML_TYPE_F32, w.n_embd, 1); + sg.hidden_input = ggml_new_tensor_2d(sg.ctx, GGML_TYPE_F32, w.n_embd, 1); + sg.positions = ggml_new_tensor_1d(sg.ctx, GGML_TYPE_I32, 4); + ggml_set_name(sg.inp_embed, "mtp_only_token_embed"); + ggml_set_name(sg.hidden_input, "mtp_only_hidden"); + ggml_set_name(sg.positions, "mtp_only_positions"); + ggml_set_input(sg.inp_embed); + ggml_set_input(sg.hidden_input); + ggml_set_input(sg.positions); + + sg.gf = ggml_new_graph_custom(sg.ctx, 2048, false); + + QwenMtpGraphInputs mtp_in{}; + mtp_in.token_embed = sg.inp_embed; + mtp_in.pre_norm_hidden = sg.hidden_input; + mtp_in.positions = sg.positions; + mtp_in.n_tokens = 1; + mtp_in.kv_start = kv_start; + mtp_in.fa_window = g_fa_window; + + QwenMtpGraphOutputs mtp_out = build_qwen35_mtp_graph(sg.ctx, sg.gf, w, mtp_cache, mtp_in); + if (!mtp_out.logits ||
!mtp_out.hidden) { + std::fprintf(stderr, "mtp only: MTP graph failed: %s\n", dflash27b_last_error()); + return false; + } + sg.mtp_logits = mtp_out.logits; + sg.mtp_hidden = mtp_out.hidden; + + sg.mtp_argmax_tokens = ggml_argmax(sg.ctx, mtp_out.logits); + ggml_set_name(sg.mtp_argmax_tokens, "mtp_only_argmax"); + ggml_set_output(sg.mtp_argmax_tokens); + ggml_build_forward_expand(sg.gf, sg.mtp_argmax_tokens); + + if (!sg.alloc) { + sg.alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + } + if (!ggml_gallocr_alloc_graph(sg.alloc, sg.gf)) { + std::fprintf(stderr, "mtp only: graph alloc failed\n"); + return false; + } + + std::vector<float> embed((size_t)w.n_embd); + if (!w.embedder.embed(&token, 1, embed.data())) { + std::fprintf(stderr, "mtp only: CPU embed failed for token %d\n", (int)token); + return false; + } + const int32_t pos4[4] = { kv_start, kv_start, kv_start, kv_start }; + ggml_backend_tensor_set(sg.inp_embed, embed.data(), 0, sizeof(float) * embed.size()); + ggml_backend_tensor_set(sg.hidden_input, hidden_in.data(), 0, sizeof(float) * hidden_in.size()); + ggml_backend_tensor_set(sg.positions, pos4, 0, sizeof(pos4)); + + ggml_status st; + { ScopedTimer _tc(MtpTimings::GRAPH_COMPUTE); + st = ggml_backend_graph_compute(backend, sg.gf); + } + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "mtp only: graph compute failed: %d\n", (int)st); + return false; + } + + { ScopedTimer _ta(MtpTimings::ARGMAX_GET); + ggml_backend_tensor_get(sg.mtp_argmax_tokens, &mtp_next, 0, sizeof(mtp_next)); + } + if (hidden_out) { + ScopedTimer _th(MtpTimings::HIDDEN_GET); + hidden_out->assign((size_t)w.n_embd, 0.0f); + ggml_backend_tensor_get(sg.mtp_hidden, hidden_out->data(), 0, + sizeof(float) * (size_t)w.n_embd); + } + + return true; +} + +// Convenience overload: builds a one-shot StepGraph internally.
+static bool mtp_only_step(const TargetWeights & w, + TargetMtpCache & mtp_cache, + ggml_backend_t backend, + int32_t token, + int kv_start, + const std::vector<float> & hidden_in, + int32_t & mtp_next, + std::vector<float> * hidden_out) { + StepGraph sg; + const bool ok = mtp_only_step(w, mtp_cache, backend, sg, token, kv_start, + hidden_in, mtp_next, hidden_out); + step_graph_destroy(sg); + return ok; +} + +// ─── MTP chain proposal (GPU unrolled) ─────────────────────────── +// +// When the trunk's tok_embd is GPU-resident (MTP checkpoints upload it), +// we can chain N MTP proposals inside a single ggml_cgraph by using +// ggml_get_rows() against the previous step's argmax tensor as the +// "next token id" input. That removes the per-step CPU embed + +// backend_tensor_set + graph rebuild that mtp_only_step incurs. + +static bool mtp_gpu_get_rows_supported(ggml_type type) { + // ggml-cuda's get_rows kernel only handles non-K quants. K-quants + // (Q4_K / Q5_K / Q6_K — which is what most published GGUFs ship for + // token_embd) fall back to CPU embed via mtp_only_step. Keep this list + // tight so we don't trip the runtime assertion in getrows.cu.
+ switch (type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + case GGML_TYPE_BF16: + case GGML_TYPE_I32: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ3_0: + return true; + default: + return false; + } +} + +static bool mtp_chain_gpu_available(const TargetWeights & w) { + return g_mtp_gpu_chain && w.tok_embd && w.tok_embd_gpu && + mtp_gpu_get_rows_supported(w.tok_embd->type); +} + +static bool mtp_chain_gpu_step(const TargetWeights & w, + TargetMtpCache & mtp_cache, + ggml_backend_t backend, + StepGraph & sg, + int32_t first_token, + int kv_start, + const std::vector<float> & hidden_in, + int n_steps, + std::vector<int32_t> & out_tokens) { + if (n_steps <= 0) { + out_tokens.clear(); + return true; + } + if ((int)hidden_in.size() != w.n_embd) { + std::fprintf(stderr, "mtp gpu chain: hidden size=%zu expected=%d\n", + hidden_in.size(), w.n_embd); + return false; + } + if (!mtp_chain_gpu_available(w)) { + return false; + } + ScopedTimer _follow(MtpTimings::MTP_PROP_FOLLOW); + std::chrono::steady_clock::time_point _build_t0; + if (g_dflash_mtp_timing) _build_t0 = std::chrono::steady_clock::now(); + + step_graph_free(sg); + + ggml_init_params ip{}; + ip.mem_size = 1024 * 1024 * 1024; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + sg.ctx = ggml_init(ip); + if (!sg.ctx) { + std::fprintf(stderr, "mtp gpu chain: ggml_init failed\n"); + return false; + } + + sg.token_ids = ggml_new_tensor_1d(sg.ctx, GGML_TYPE_I32, 1); + sg.hidden_input = ggml_new_tensor_2d(sg.ctx, GGML_TYPE_F32, w.n_embd, 1); + sg.positions = ggml_new_tensor_1d(sg.ctx, GGML_TYPE_I32, 4 * n_steps); + ggml_set_name(sg.token_ids, "mtp_gpu_chain_first_token"); + ggml_set_name(sg.hidden_input, "mtp_gpu_chain_hidden"); + ggml_set_name(sg.positions, "mtp_gpu_chain_positions"); + ggml_set_input(sg.token_ids); + ggml_set_input(sg.hidden_input); + ggml_set_input(sg.positions); + + sg.gf = ggml_new_graph_custom(sg.ctx, 32768, false); + +
ggml_tensor * prev_token_ids = sg.token_ids; + ggml_tensor * prev_hidden = sg.hidden_input; + sg.mtp_chain_argmax_tokens.clear(); + sg.mtp_chain_argmax_tokens.reserve((size_t)n_steps); + + for (int i = 0; i < n_steps; i++) { + ggml_tensor * token_embed = ggml_get_rows(sg.ctx, w.tok_embd, prev_token_ids); + token_embed = ggml_reshape_2d(sg.ctx, token_embed, w.n_embd, 1); + + ggml_tensor * pos_view = ggml_view_1d(sg.ctx, sg.positions, 4, + (size_t)i * 4 * sizeof(int32_t)); + QwenMtpGraphInputs mtp_in{}; + mtp_in.token_embed = token_embed; + mtp_in.pre_norm_hidden = prev_hidden; + mtp_in.positions = pos_view; + mtp_in.n_tokens = 1; + mtp_in.kv_start = kv_start + i; + mtp_in.fa_window = g_fa_window; + + QwenMtpGraphOutputs mtp_out = build_qwen35_mtp_graph(sg.ctx, sg.gf, w, mtp_cache, mtp_in); + if (!mtp_out.logits || !mtp_out.hidden) { + std::fprintf(stderr, "mtp gpu chain: MTP graph failed at step %d: %s\n", + i, dflash27b_last_error()); + return false; + } + + ggml_tensor * arg = ggml_argmax(sg.ctx, mtp_out.logits); + char name[64]; + std::snprintf(name, sizeof(name), "mtp_gpu_chain_argmax_%d", i); + ggml_set_name(arg, name); + ggml_set_output(arg); + ggml_build_forward_expand(sg.gf, arg); + sg.mtp_chain_argmax_tokens.push_back(arg); + + prev_token_ids = arg; + prev_hidden = mtp_out.hidden; + } + + if (!sg.alloc) { + sg.alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + } + if (!ggml_gallocr_alloc_graph(sg.alloc, sg.gf)) { + std::fprintf(stderr, "mtp gpu chain: graph alloc failed\n"); + return false; + } + if (g_dflash_mtp_timing) { + auto _build_t1 = std::chrono::steady_clock::now(); + uint64_t ns = (uint64_t)std::chrono::duration_cast<std::chrono::nanoseconds>(_build_t1 - _build_t0).count(); + g_mtp_timings.add(MtpTimings::GRAPH_BUILD_MTP_CHAIN, ns); + } + + std::vector<int32_t> pos4((size_t)4 * n_steps); + for (int i = 0; i < n_steps; i++) { + const int p = kv_start + i; + pos4[(size_t)i * 4 + 0] = p; + pos4[(size_t)i * 4 + 1] = p; + pos4[(size_t)i * 4 + 2] = p; +
pos4[(size_t)i * 4 + 3] = p; + } + ggml_backend_tensor_set(sg.token_ids, &first_token, 0, sizeof(first_token)); + ggml_backend_tensor_set(sg.hidden_input, hidden_in.data(), 0, + sizeof(float) * hidden_in.size()); + ggml_backend_tensor_set(sg.positions, pos4.data(), 0, + sizeof(int32_t) * pos4.size()); + + ggml_status st; + { ScopedTimer _tc(MtpTimings::GRAPH_COMPUTE); + st = ggml_backend_graph_compute(backend, sg.gf); + } + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "mtp gpu chain: graph compute failed: %d\n", (int)st); + return false; + } + + out_tokens.assign((size_t)n_steps, -1); + { ScopedTimer _ta(MtpTimings::ARGMAX_GET); + for (int i = 0; i < n_steps; i++) { + ggml_backend_tensor_get(sg.mtp_chain_argmax_tokens[(size_t)i], + &out_tokens[(size_t)i], 0, sizeof(int32_t)); + } + } + mtp_cache.cur_pos = kv_start + n_steps; + return true; +} + +// ─── Target-only AR baseline + native MTP linear loop ──────────── +// +// run_target_ar_prompt : greedy AR decode without MTP. Used as the +// parity baseline for run_mtp_baseline_check. +// run_mtp_integrated_prompt : linear native MTP loop (no DDTree hybrid). +// Per step, one trunk target forward + one NextN +// forward share a single ggml_cgraph. Accepted +// MTP proposals chain via GPU get_rows when +// mtp_chain_gpu_available(w), otherwise via +// mtp_only_step. Falls back to serial verify; +// target-batched verify + fast rollback land +// in the DFlash/MTP hybrid PR. +struct MtpIntegratedRunStats { + std::vector<int32_t> out_ids; + int draft_n = 0; + int accepted = 0; + int corrected = 0; + double seconds = 0.0; +}; + +struct TargetArRunStats { + std::vector<int32_t> out_ids; + double seconds = 0.0; +}; + +// Greedy single-token target step. Allocates its own per-call ctx + gallocr +// (one decode/sec, no MTP); cheap baseline for parity vs the integrated path. 
+static bool target_only_argmax_step(const TargetWeights & w, + TargetCache & target_cache, + ggml_backend_t backend, + int32_t token, + int kv_start, + int32_t & target_next) { + ScopedTimer _td(MtpTimings::TARGET_DECODE); + ggml_init_params ip{}; + ip.mem_size = 512 * 1024 * 1024; + ip.mem_buffer = nullptr; + ip.no_alloc = true; + ggml_context * ctx = ggml_init(ip); + if (!ctx) { + std::fprintf(stderr, "target only: ggml_init failed\n"); + return false; + } + + ggml_tensor * inp_embed = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w.n_embd, 1, 1); + ggml_tensor * positions = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4); + ggml_set_name(inp_embed, "target_only_inp_embed"); + ggml_set_name(positions, "target_only_positions"); + ggml_set_input(inp_embed); + ggml_set_input(positions); + + ggml_cgraph * gf = ggml_new_graph_custom(ctx, 8192, false); + + QwenGraphInputs target_in{}; + target_in.inp_embed = inp_embed; + target_in.positions = positions; + target_in.n_tokens = 1; + target_in.kv_start = kv_start; + target_in.fa_window = g_fa_window; + + QwenGraphOutputs target_out = build_qwen35_graph(ctx, gf, w, target_cache, target_in); + if (!target_out.logits) { + std::fprintf(stderr, "target only: target graph failed: %s\n", dflash27b_last_error()); + ggml_free(ctx); + return false; + } + + ggml_tensor * target_argmax = ggml_argmax(ctx, target_out.logits); + ggml_set_name(target_argmax, "target_only_argmax"); + ggml_set_output(target_argmax); + ggml_build_forward_expand(gf, target_argmax); + + ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!ggml_gallocr_alloc_graph(alloc, gf)) { + std::fprintf(stderr, "target only: graph alloc failed\n"); + ggml_gallocr_free(alloc); + ggml_free(ctx); + return false; + } + + std::vector<float> embed((size_t)w.n_embd); + if (!w.embedder.embed(&token, 1, embed.data())) { + std::fprintf(stderr, "target only: CPU embed failed for token %d\n", (int)token); + ggml_gallocr_free(alloc); + ggml_free(ctx); + return 
false; + } + const int32_t pos4[4] = { kv_start, kv_start, kv_start, kv_start }; + ggml_backend_tensor_set(inp_embed, embed.data(), 0, sizeof(float) * embed.size()); + ggml_backend_tensor_set(positions, pos4, 0, sizeof(pos4)); + + ggml_status st; + { ScopedTimer _tc(MtpTimings::GRAPH_COMPUTE); + st = ggml_backend_graph_compute(backend, gf); + } + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "target only: graph compute failed: %d\n", (int)st); + ggml_gallocr_free(alloc); + ggml_free(ctx); + return false; + } + + { ScopedTimer _ta(MtpTimings::ARGMAX_GET); + ggml_backend_tensor_get(target_argmax, &target_next, 0, sizeof(target_next)); + } + + ggml_gallocr_free(alloc); + ggml_free(ctx); + return true; +} + +static bool run_target_ar_prompt(const TargetWeights & w, + ggml_backend_t backend, + int max_ctx, + const std::vector<int32_t> & prompt, + int n_gen, + int decode_pos_offset, + TargetArRunStats & stats) { + stats = TargetArRunStats{}; + if (prompt.empty()) { + std::fprintf(stderr, "target baseline: empty prompt\n"); + return false; + } + if (n_gen <= 0) { + std::fprintf(stderr, "target baseline: n_gen must be > 0\n"); + return false; + } + decode_pos_offset = std::max(0, decode_pos_offset); + const int needed_ctx = decode_pos_offset + (int)prompt.size() + n_gen + 4; + if (needed_ctx > max_ctx) { + std::fprintf(stderr, + "target baseline: max_ctx=%d too small for offset=%d prompt=%zu n_gen=%d\n", + max_ctx, decode_pos_offset, prompt.size(), n_gen); + return false; + } + + TargetCache target_cache; + if (!create_target_cache(w, max_ctx, /*max_verify_tokens=*/0, backend, + target_cache, /*prefill_only=*/true)) { + std::fprintf(stderr, "target baseline cache: %s\n", dflash27b_last_error()); + return false; + } + + auto t0 = std::chrono::steady_clock::now(); + for (int i = 0; i + 1 < (int)prompt.size(); i++) { + const int pos = decode_pos_offset + i; + int32_t target_next = -1; + if (!target_only_argmax_step(w, target_cache, backend, + prompt[(size_t)i], pos, 
target_next)) { + free_target_cache(target_cache); + return false; + } + target_cache.cur_pos = pos + 1; + target_cache.last_tok = target_next; + } + + stats.out_ids.reserve((size_t)n_gen); + int32_t current = prompt.back(); + int kv_pos = decode_pos_offset + (int)prompt.size() - 1; + while ((int)stats.out_ids.size() < n_gen) { + int32_t target_next = -1; + if (!target_only_argmax_step(w, target_cache, backend, + current, kv_pos, target_next)) { + free_target_cache(target_cache); + return false; + } + target_cache.cur_pos = kv_pos + 1; + target_cache.last_tok = target_next; + stats.out_ids.push_back(target_next); + current = target_next; + kv_pos++; + if (IS_EOS_TOK(target_next, w)) break; + } + auto t1 = std::chrono::steady_clock::now(); + stats.seconds = std::chrono::duration<double>(t1 - t0).count(); + + free_target_cache(target_cache); + return true; +} + +// Linear MTP integrated decode. One trunk target forward + one NextN +// forward per accepted token, with optional MTP chain proposals on +// accept. Serial verify only; DDTree hybrid + batched target verify + +// fast rollback land in the follow-up PR. 
+static bool run_mtp_integrated_prompt(const TargetWeights & w, + ggml_backend_t backend, + int max_ctx, + const std::vector<int32_t> & prompt, + int n_gen, + int mtp_draft_n_max, + MtpIntegratedRunStats & stats, + bool step_log, + int decode_pos_offset = 0) { + stats = MtpIntegratedRunStats{}; + if (prompt.empty()) { + std::fprintf(stderr, "mtp integrated: empty prompt\n"); + return false; + } + if (n_gen <= 0) { + std::fprintf(stderr, "mtp integrated: n_gen must be > 0\n"); + return false; + } + if (w.mtp_layers.empty()) { + std::fprintf(stderr, "mtp integrated: target has no nextn/MTP layers\n"); + return false; + } + mtp_draft_n_max = std::max(1, mtp_draft_n_max); + + decode_pos_offset = std::max(0, decode_pos_offset); + const int needed_ctx = decode_pos_offset + (int)prompt.size() + n_gen + 4; + if (needed_ctx > max_ctx) { + std::fprintf(stderr, + "mtp integrated: max_ctx=%d too small for offset=%d prompt=%zu n_gen=%d\n", + max_ctx, decode_pos_offset, prompt.size(), n_gen); + return false; + } + + TargetCache target_cache; + if (!create_target_cache(w, max_ctx, /*max_verify_tokens=*/0, backend, + target_cache, /*prefill_only=*/true)) { + std::fprintf(stderr, "mtp integrated target cache: %s\n", dflash27b_last_error()); + return false; + } + TargetMtpCache mtp_cache; + if (!create_target_mtp_cache(w, max_ctx, backend, mtp_cache)) { + std::fprintf(stderr, "mtp integrated MTP cache: %s\n", dflash27b_last_error()); + free_target_cache(target_cache); + return false; + } + + StepGraph mtp_integrated_sg; + StepGraph mtp_only_sg; + struct LocalCleanup { + StepGraph * a; + StepGraph * b; + ~LocalCleanup() { if (a) step_graph_destroy(*a); if (b) step_graph_destroy(*b); } + } graph_cleanup{&mtp_integrated_sg, &mtp_only_sg}; + + auto t0 = std::chrono::steady_clock::now(); + + // Prefill: drive both caches forward up to the last prompt token. 
+ for (int i = 0; i + 1 < (int)prompt.size(); i++) { + const int pos = decode_pos_offset + i; + int32_t target_next = -1; + int32_t mtp_next = -1; + if (!mtp_integrated_step(w, target_cache, mtp_cache, backend, + mtp_integrated_sg, + prompt[(size_t)i], pos, + target_next, mtp_next)) { + free_target_mtp_cache(mtp_cache); + free_target_cache(target_cache); + return false; + } + target_cache.cur_pos = pos + 1; + mtp_cache.cur_pos = pos + 1; + } + + // Decode loop. + stats.out_ids.reserve((size_t)n_gen); + int32_t current = prompt.back(); + int kv_pos = decode_pos_offset + (int)prompt.size() - 1; + int step = 0; + while ((int)stats.out_ids.size() < n_gen) { + int32_t target_next = -1; + int32_t mtp_next = -1; + std::vector<float> mtp_hidden; + if (!mtp_integrated_step(w, target_cache, mtp_cache, backend, + mtp_integrated_sg, + current, kv_pos, + target_next, mtp_next, &mtp_hidden)) { + free_target_mtp_cache(mtp_cache); + free_target_cache(target_cache); + return false; + } + + stats.draft_n++; + const bool accept = (mtp_next == target_next); + if (accept) stats.accepted++; else stats.corrected++; + const int32_t chosen = accept ? mtp_next : target_next; + stats.out_ids.push_back(chosen); + if (step_log) { + std::printf("[mtp-integrated step=%d draft=0] input=%d mtp=%d target=%d %s chosen=%d\n", + step, (int)current, (int)mtp_next, (int)target_next, + accept ? "ACCEPT" : "CORRECT", (int)chosen); + } + current = chosen; + target_cache.cur_pos = kv_pos + 1; + target_cache.last_tok = chosen; + mtp_cache.cur_pos = kv_pos + 1; + kv_pos++; + if (IS_EOS_TOK(chosen, w)) break; + if (!accept || (int)stats.out_ids.size() >= n_gen) { + step++; + continue; + } + + // Accepted base: try to chain more MTP proposals via GPU get_rows. + // Each accepted proposal is verified against a per-token target + // forward (serial verify; batched verify lands in the hybrid PR). 
+ const int max_chain_steps = std::min( + mtp_draft_n_max - 1, + n_gen - (int)stats.out_ids.size()); + std::vector<int32_t> chain_tokens; + bool used_gpu_chain = false; + if (max_chain_steps > 0 && mtp_chain_gpu_available(w)) { + if (!mtp_chain_gpu_step(w, mtp_cache, backend, mtp_only_sg, + current, kv_pos, mtp_hidden, + max_chain_steps, chain_tokens)) { + free_target_mtp_cache(mtp_cache); + free_target_cache(target_cache); + return false; + } + used_gpu_chain = true; + } + + bool serial_hit_eos = false; + if (used_gpu_chain) { + // Verify chain proposals one by one against target. + for (size_t j = 0; j < chain_tokens.size(); j++) { + if ((int)stats.out_ids.size() >= n_gen) break; + const int32_t mtp_chain_next = chain_tokens[j]; + int32_t target_chain_next = -1; + if (!target_only_argmax_step(w, target_cache, backend, + current, kv_pos, target_chain_next)) { + free_target_mtp_cache(mtp_cache); + free_target_cache(target_cache); + return false; + } + target_cache.cur_pos = kv_pos + 1; + stats.draft_n++; + const bool chain_accept = (mtp_chain_next == target_chain_next); + if (chain_accept) stats.accepted++; else stats.corrected++; + const int32_t chain_chosen = chain_accept ? mtp_chain_next : target_chain_next; + stats.out_ids.push_back(chain_chosen); + if (step_log) { + std::printf("[mtp-integrated step=%d draft=%zu] input=%d mtp=%d target=%d %s chosen=%d\n", + step, j + 1, (int)current, + (int)mtp_chain_next, (int)target_chain_next, + chain_accept ? "ACCEPT" : "CORRECT", + (int)chain_chosen); + } + current = chain_chosen; + target_cache.last_tok = chain_chosen; + kv_pos++; + if (IS_EOS_TOK(chain_chosen, w)) { serial_hit_eos = true; break; } + if (!chain_accept) break; + } + } else if (max_chain_steps > 0) { + // CPU-bridged chain fallback (used when GPU get_rows unavailable). 
+ for (int k = 1; k < mtp_draft_n_max && (int)stats.out_ids.size() < n_gen; k++) { + int32_t mtp_chain_next = -1; + std::vector<float> next_hidden; + const bool need_next_hidden = + k + 1 < mtp_draft_n_max && (int)stats.out_ids.size() + 1 < n_gen; + if (!mtp_only_step(w, mtp_cache, backend, mtp_only_sg, + current, kv_pos, + mtp_hidden, mtp_chain_next, + need_next_hidden ? &next_hidden : nullptr)) { + free_target_mtp_cache(mtp_cache); + free_target_cache(target_cache); + return false; + } + mtp_cache.cur_pos = kv_pos + 1; + + int32_t target_chain_next = -1; + if (!target_only_argmax_step(w, target_cache, backend, + current, kv_pos, target_chain_next)) { + free_target_mtp_cache(mtp_cache); + free_target_cache(target_cache); + return false; + } + target_cache.cur_pos = kv_pos + 1; + + stats.draft_n++; + const bool chain_accept = (mtp_chain_next == target_chain_next); + if (chain_accept) stats.accepted++; else stats.corrected++; + const int32_t chain_chosen = chain_accept ? mtp_chain_next : target_chain_next; + stats.out_ids.push_back(chain_chosen); + if (step_log) { + std::printf("[mtp-integrated step=%d draft=%d] input=%d mtp=%d target=%d %s chosen=%d\n", + step, k, (int)current, + (int)mtp_chain_next, (int)target_chain_next, + chain_accept ? "ACCEPT" : "CORRECT", + (int)chain_chosen); + } + current = chain_chosen; + target_cache.last_tok = chain_chosen; + if (need_next_hidden) mtp_hidden.swap(next_hidden); + else mtp_hidden.clear(); + kv_pos++; + if (IS_EOS_TOK(chain_chosen, w)) { serial_hit_eos = true; break; } + if (!chain_accept) break; + } + } + + step++; + if (serial_hit_eos) break; + } + + auto t1 = std::chrono::steady_clock::now(); + stats.seconds = std::chrono::duration<double>(t1 - t0).count(); + + free_target_mtp_cache(mtp_cache); + free_target_cache(target_cache); + return true; +} + +// CLI entry point: runs MTP integrated decode and prints metrics. 
+static int run_mtp_integrated_cli(const TargetWeights & w, + ggml_backend_t backend, + int max_ctx, + const std::vector<int32_t> & prompt, + int n_gen, + int mtp_draft_n_max, + int decode_pos_offset, + bool step_log) { + MtpIntegratedRunStats st; + if (!run_mtp_integrated_prompt(w, backend, max_ctx, prompt, n_gen, + mtp_draft_n_max, st, step_log, + decode_pos_offset)) { + return 1; + } + const double tps = st.out_ids.size() / std::max(1e-9, st.seconds); + const double acc = st.draft_n > 0 ? 100.0 * st.accepted / st.draft_n : 0.0; + std::printf("[mtp-integrated] generated=%zu draft_n=%d accepted=%d corrected=%d " + "acceptance=%.1f%% tok/s=%.2f seconds=%.4f draft_n_max=%d\n", + st.out_ids.size(), st.draft_n, st.accepted, st.corrected, + acc, tps, st.seconds, mtp_draft_n_max); + return 0; +} + +// Parity gate: runs the AR baseline and the integrated MTP loop on the +// same prompt, compares the output token-by-token, and emits the lines +// `[mtp-baseline] baseline ...`, `[mtp-baseline] integrated ...`, and +// either `compare_ok` or `compare_fail`. Parsed by mtp_baseline_gate.py. 
+static int run_mtp_baseline_check(const TargetWeights & w, + ggml_backend_t backend, + int max_ctx, + const std::vector<int32_t> & prompt, + int n_gen, + int mtp_draft_n_max, + int decode_pos_offset, + bool step_log) { + if (prompt.empty()) { + std::fprintf(stderr, "mtp baseline: empty prompt\n"); + return 2; + } + if (w.mtp_layers.empty()) { + std::fprintf(stderr, "mtp baseline: target has no nextn/MTP layers\n"); + return 1; + } + + TargetArRunStats baseline; + if (!run_target_ar_prompt(w, backend, max_ctx, prompt, n_gen, + decode_pos_offset, baseline)) { + return 1; + } + + MtpIntegratedRunStats mtp; + if (!run_mtp_integrated_prompt(w, backend, max_ctx, prompt, n_gen, + mtp_draft_n_max, mtp, step_log, + decode_pos_offset)) { + return 1; + } + + const double baseline_tps = baseline.out_ids.size() / std::max(1e-9, baseline.seconds); + const double mtp_tps = mtp.out_ids.size() / std::max(1e-9, mtp.seconds); + const double ratio = mtp_tps / std::max(1e-9, baseline_tps); + const double acceptance = mtp.draft_n > 0 ? 100.0 * mtp.accepted / mtp.draft_n : 0.0; + + size_t mismatch = SIZE_MAX; + const size_t n_cmp = std::min(baseline.out_ids.size(), mtp.out_ids.size()); + for (size_t i = 0; i < n_cmp; i++) { + if (baseline.out_ids[i] != mtp.out_ids[i]) { mismatch = i; break; } + } + if (mismatch == SIZE_MAX && baseline.out_ids.size() != mtp.out_ids.size()) { + mismatch = n_cmp; + } + + std::printf("[mtp-baseline] baseline generated=%zu tok/s=%.2f seconds=%.4f\n", + baseline.out_ids.size(), baseline_tps, baseline.seconds); + std::printf("[mtp-baseline] integrated generated=%zu draft_n=%d accepted=%d " + "corrected=%d acceptance=%.1f%% tok/s=%.2f seconds=%.4f " + "draft_n_max=%d speed_ratio=%.3fx\n", + mtp.out_ids.size(), mtp.draft_n, mtp.accepted, mtp.corrected, + acceptance, mtp_tps, mtp.seconds, + std::max(1, mtp_draft_n_max), ratio); + + if (mismatch != SIZE_MAX) { + const int baseline_tok = mismatch < baseline.out_ids.size() + ? 
(int)baseline.out_ids[mismatch] : -999999; + const int mtp_tok = mismatch < mtp.out_ids.size() + ? (int)mtp.out_ids[mismatch] : -999999; + std::fprintf(stderr, + "[mtp-baseline] compare_fail mismatch=%zu baseline=%d integrated=%d " + "baseline_len=%zu integrated_len=%zu\n", + mismatch, baseline_tok, mtp_tok, + baseline.out_ids.size(), mtp.out_ids.size()); + return 1; + } + + std::printf("[mtp-baseline] compare_ok tokens=%zu mismatches=0\n", + baseline.out_ids.size()); + return 0; +} + // ─── Main ───────────────────────────────────────────────────────── int main(int argc, char ** argv) { @@ -2507,6 +3573,95 @@ int main(int argc, char ** argv) { else if (std::strncmp(argv[i], "--draft-ctx-max=", 16) == 0) { g_draft_ctx_max = std::max(0, std::atoi(argv[i] + 16)); } + // ─── Native MTP (NextN) CLI flags ─────────────────────── + else if (std::strcmp(argv[i], "--mtp-integrated") == 0) { + // Handled below — see the dedicated MTP dispatch block. + } + else if (std::strcmp(argv[i], "--mtp-baseline-check") == 0) { + // Handled below — see the dedicated MTP dispatch block. + } + else if (std::strncmp(argv[i], "--mtp-draft-n-max=", 18) == 0 || + std::strncmp(argv[i], "--mtp-draft-n=", 14) == 0) { + // Both spellings accepted; parsed later in the MTP dispatch. + } + else if (std::strncmp(argv[i], "--decode-pos-offset=", 20) == 0) { + // Parsed in MTP dispatch. + } + else if (std::strcmp(argv[i], "--mtp-step-log") == 0) { + // Parsed in MTP dispatch. 
+ } + else if (std::strcmp(argv[i], "--mtp-no-fast-rollback") == 0) { + g_mtp_fast_rollback = false; + } + else if (std::strcmp(argv[i], "--mtp-moe-long-fast-rollback") == 0) { + g_mtp_moe_long_fast_rollback = true; + } + else if (std::strcmp(argv[i], "--mtp-serial-commit") == 0) { + g_mtp_serial_commit = true; + } + else if (std::strcmp(argv[i], "--mtp-no-gpu-chain") == 0) { + g_mtp_gpu_chain = false; + } + else if (std::strcmp(argv[i], "--dflash-mtp-timing") == 0) { + g_dflash_mtp_timing = true; + } + // DFlash/MTP hybrid flags (parsed here for stable CLI surface; the + // wiring that consumes them ships in the follow-up PR). + else if (std::strcmp(argv[i], "--dflash-mtp-hybrid") == 0) { + g_dflash_mtp_hybrid = true; + } + else if (std::strcmp(argv[i], "--dflash-mtp-tree-fused") == 0) { + g_dflash_mtp_tree_fused = true; + } + else if (std::strcmp(argv[i], "--dflash-mtp-seed-priority") == 0) { + g_dflash_mtp_seed_priority = true; + } + else if (std::strcmp(argv[i], "--dflash-mtp-immediate-bonus") == 0) { + g_dflash_mtp_immediate_bonus = true; + } + else if (std::strncmp(argv[i], "--dflash-mtp-chain-max=", 23) == 0) { + g_dflash_mtp_chain_max = std::max(1, std::atoi(argv[i] + 23)); + } + else if (std::strncmp(argv[i], "--dflash-mtp-bonus-min-margin=", 30) == 0) { + g_dflash_mtp_bonus_min_margin = (float)std::atof(argv[i] + 30); + } + } + + // Also accept DFLASH_MTP_TIMING=1 env var so it can be turned on without + // changing the CLI surface during long-running benchmarks. + if (const char * s = std::getenv("DFLASH_MTP_TIMING")) { + if (std::atoi(s) != 0) g_dflash_mtp_timing = true; + } + + // Second pass: pick up MTP dispatch flags (kept separate from the main + // arg loop so we can early-return into the MTP CLI before constructing + // the full DFlash machinery). 
+ bool mtp_integrated_mode = false; + bool mtp_baseline_check_mode = false; + int mtp_draft_n_max = 4; + int mtp_decode_pos_offset = 0; + bool mtp_step_log = false; + const char * mtp_prompt_path_arg = nullptr; + int mtp_n_gen_arg = 0; + for (int i = flags_start; i < argc; i++) { + if (std::strcmp(argv[i], "--mtp-integrated") == 0) { + mtp_integrated_mode = true; + } else if (std::strcmp(argv[i], "--mtp-baseline-check") == 0) { + mtp_baseline_check_mode = true; + mtp_integrated_mode = true; // baseline check implies MTP-side runtime + } else if (std::strncmp(argv[i], "--mtp-draft-n-max=", 18) == 0) { + mtp_draft_n_max = std::max(1, std::atoi(argv[i] + 18)); + } else if (std::strncmp(argv[i], "--mtp-draft-n=", 14) == 0) { + mtp_draft_n_max = std::max(1, std::atoi(argv[i] + 14)); + } else if (std::strncmp(argv[i], "--decode-pos-offset=", 20) == 0) { + mtp_decode_pos_offset = std::max(0, std::atoi(argv[i] + 20)); + } else if (std::strcmp(argv[i], "--mtp-step-log") == 0) { + mtp_step_log = true; + } else if (std::strncmp(argv[i], "--prompt-file=", 14) == 0) { + mtp_prompt_path_arg = argv[i] + 14; + } else if (std::strncmp(argv[i], "--n-gen=", 8) == 0) { + mtp_n_gen_arg = std::max(0, std::atoi(argv[i] + 8)); + } } // The KV type may also have been chosen via -ctk/-ctv, which sets @@ -2525,7 +3680,8 @@ int main(int argc, char ** argv) { g_kq_stride_pad = 256; } - if (!is_laguna && !daemon_mode && !test_window_mode && (!prompt_path || !out_path)) { + if (!is_laguna && !daemon_mode && !test_window_mode && !mtp_integrated_mode && + (!prompt_path || !out_path)) { std::fprintf(stderr, "Missing positional arguments for non-daemon mode.\n"); return 2; } @@ -2669,6 +3825,58 @@ int main(int argc, char ** argv) { } std::printf("[target] %s\n", dflash27b_last_error()); + // ─── Native MTP dispatch (no DFlash drafter needed) ────────── + if (mtp_integrated_mode) { + const char * mtp_prompt_path = mtp_prompt_path_arg ? 
mtp_prompt_path_arg : prompt_path; + int mtp_n_gen = mtp_n_gen_arg > 0 ? mtp_n_gen_arg : n_gen; + if (!mtp_prompt_path || mtp_n_gen <= 0) { + std::fprintf(stderr, + "mtp mode: need a prompt (positional or --prompt-file=) " + "and n_gen (positional or --n-gen=)\n"); + free_target_weights(w); + ggml_backend_free(target_backend); + return 2; + } + std::vector<int32_t> prompt = read_int32_file(mtp_prompt_path); + if (prompt.empty()) { + std::fprintf(stderr, "mtp mode: failed to read prompt from %s\n", + mtp_prompt_path); + free_target_weights(w); + ggml_backend_free(target_backend); + return 1; + } + const int max_ctx_eff = g_max_ctx_override > 0 ? g_max_ctx_override : 4096; + int rc; + if (mtp_baseline_check_mode) { + rc = run_mtp_baseline_check(w, target_backend, max_ctx_eff, prompt, + mtp_n_gen, mtp_draft_n_max, + mtp_decode_pos_offset, mtp_step_log); + } else { + rc = run_mtp_integrated_cli(w, target_backend, max_ctx_eff, prompt, + mtp_n_gen, mtp_draft_n_max, + mtp_decode_pos_offset, mtp_step_log); + } + + // Emit timing summary if --dflash-mtp-timing was set. + if (g_dflash_mtp_timing) { + std::printf("[dflash+mtp-timing]"); + for (int s = 0; s < MtpTimings::NUM_SLOTS; s++) { + const auto & a = g_mtp_timings.slots[s]; + if (a.count == 0) continue; + std::printf(" %s=%.3fms(n=%llu,max=%.3fms)", + MtpTimings::slot_name(s), + a.ns_total / 1e6, + (unsigned long long)a.count, + a.ns_max / 1e6); + } + std::printf("\n"); + } + + free_target_weights(w); + ggml_backend_free(target_backend); + return rc; + } + DraftWeights dw; { // Auto-detect draft format: .gguf → GGUF loader, else safetensors. 
@@ -3863,7 +5071,7 @@ int main(int argc, char ** argv) { if (!split_gpus) { cudaSetDevice(draft_gpu); - auto bf16_to_f32 = ggml_get_to_fp32_cuda(GGML_TYPE_BF16); + auto bf16_to_f32 = dflash27b_get_to_fp32_cuda(GGML_TYPE_BF16); bf16_to_f32( (const char *)cache.target_feat->data + (size_t)slot0 * row_bf16, (float *)draft_sg.target_hidden_cat->data, @@ -4256,7 +5464,7 @@ int main(int argc, char ** argv) { (size_t)rollback_dfs * cap.ssm_intermediate_states->nb[3]; const void * ssm_src = (const char *)cap.ssm_intermediate_states->data + ssm_src_offset; - ggml_get_to_fp32_cuda(cap.ssm_intermediate_states->type)( + dflash27b_get_to_fp32_cuda(cap.ssm_intermediate_states->type)( ssm_src, (float *)cache.ssm_state[il]->data, (int64_t)ssm_elems, stream); cudaError_t ce = cudaSuccess; // launch error checked in the conv block below @@ -4576,7 +5784,7 @@ int main(int argc, char ** argv) { (size_t)rollback_idx * cap.ssm_intermediate_states->nb[3]; const void * ssm_src = (const char *)cap.ssm_intermediate_states->data + ssm_src_offset; - ggml_get_to_fp32_cuda(cap.ssm_intermediate_states->type)( + dflash27b_get_to_fp32_cuda(cap.ssm_intermediate_states->type)( ssm_src, (float *)cache.ssm_state[il]->data, (int64_t)ssm_elems, stream); cudaError_t ce = cudaSuccess; From b0356289f78d5e3669e5a645ee8382ec748e7a3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Javier=20Paz=C3=B3?= Date: Mon, 11 May 2026 16:14:25 +0200 Subject: [PATCH 3/4] feat(dflash): --dflash-mtp-policy=auto regression guard MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a small policy gate around the linear MTP CLI so it does the right thing when invoked on short generations. Problem - Linear MTP regresses against the AR baseline on n_gen < ~192 because graph build cost dominates decode wall time and the linear path here rebuilds the verify graph per step. The bucket-cached fast path that fixes this is the follow-up perf PR (see HANDOFF doc, "Waves A-D"). 
- Without a guard, a benign `test_dflash --mtp-integrated --n-gen=64` invocation prints worse-than-baseline numbers and confuses anyone reading the metrics line. Fix - New `--dflash-mtp-policy=auto|always|never` (default `always` — keeps existing behaviour byte-identical). - New `--dflash-mtp-policy-min-n=N` (default 192) tunes the auto cutoff. - Under `auto`, the CLI dispatches to `run_target_ar_prompt` instead of `run_mtp_integrated_prompt` when `n_gen < min_n`, and emits a clear `[mtp-policy] auto: n_gen=N < min_n=M, falling back to AR baseline` line plus a `mode=baseline-ar` tag in the final metrics so artifact parsers can route accordingly. - The `compare_ok` parity path in `run_mtp_baseline_check` is not affected: it explicitly tests the integrated decode against AR, so it always runs MTP regardless of policy. Smoke (RTX 6000 Ada, Qwen3.6-27B-MTP Q4_K_M, n_gen=16): [mtp-policy] auto: n_gen=16 < min_n=192, falling back to AR baseline [mtp-integrated] policy=auto-fallback generated=16 tok/s=10.60 seconds=1.5101 mode=baseline-ar Once the bucket-cache PR lands and short-generation MTP is competitive, the default can flip to `auto` with a lower `min_n`. Until then, this keeps the CLI honest. --- dflash/test/test_dflash.cpp | 157 +++++++++++++++++++++++++++++++++++- 1 file changed, 153 insertions(+), 4 deletions(-) diff --git a/dflash/test/test_dflash.cpp b/dflash/test/test_dflash.cpp index 2bf7ebeec..cc428b4e8 100644 --- a/dflash/test/test_dflash.cpp +++ b/dflash/test/test_dflash.cpp @@ -203,6 +203,14 @@ static bool g_dflash_mtp_seed_priority = false; static bool g_dflash_mtp_immediate_bonus = false; static int g_dflash_mtp_chain_max = 15; static float g_dflash_mtp_bonus_min_margin = 0.0f; +// Policy gate that protects against MTP regressing vs the AR baseline on +// short generations. Linear MTP without the bucket-cached verify path +// loses to plain AR on n_gen < ~192 because graph build cost is ~81% of +// decode wall time. 
`auto` activates MTP only when n_gen >= threshold; +// `always` and `never` force it on/off regardless. +enum class MtpPolicy { Auto, Always, Never }; +static MtpPolicy g_dflash_mtp_policy = MtpPolicy::Always; +static int g_dflash_mtp_policy_min_n = 192; // ─── MTP/DDTree timing instrumentation ─────────────────────────── // Activated via --dflash-mtp-timing or DFLASH_MTP_TIMING=1. When disabled @@ -272,6 +280,25 @@ struct ScopedTimer { ScopedTimer & operator=(const ScopedTimer &) = delete; }; +// Largest minus second-largest value of a float vector (top-1/top-2 margin). +// Used by the DFlash/MTP hybrid immediate-bonus gate to skip committing the +// MTP seed when the trunk's top-1 isn't clearly separated from top-2. +static float top1_top2_margin_f32(const float * x, int n) { + if (n <= 1) return 0.0f; + float best = -INFINITY; + float second = -INFINITY; + for (int i = 0; i < n; i++) { + const float v = x[i]; + if (v > best) { + second = best; + best = v; + } else if (v > second) { + second = v; + } + } + return best - second; +} + static int align_up(int x, int a) { return ((x + a - 1) / a) * a; } // F16 encoding for the two values we use: 0 and -inf. @@ -403,7 +430,10 @@ struct DDTree { static DDTree build_ddtree(const float * top_log_probs, const int32_t * top_token_ids, int L, int K, int budget, - bool chain_seed = true) { + bool chain_seed = true, + const int32_t * extra_chain = nullptr, + int extra_chain_len = 0, + bool extra_chain_priority = false) { DDTree tree; if (budget <= 0 || L <= 0) { tree.parents.push_back(-1); @@ -438,6 +468,31 @@ static DDTree build_ddtree(const float * top_log_probs, tree.parents.push_back(-1); // root tree.child_maps.emplace_back(); // root's children + // MTP-seeded chain has priority over the draft top-1 spine. When native + // MTP already matched the target's next token, its continuation is the + // only part that can reduce future DDTree rounds. Insert it before the + // defensive DFlash chain seed so the latter cannot consume the full budget. 
+ if (extra_chain_priority && extra_chain && extra_chain_len > 0) { + int prev_idx = 0; + const int chain_depth = std::min(L, extra_chain_len); + for (int d = 1; d <= chain_depth && tree.n_nodes < budget; d++) { + const int32_t tok_id = extra_chain[d - 1]; + auto it = tree.child_maps[prev_idx].find(tok_id); + if (it != tree.child_maps[prev_idx].end()) { + prev_idx = it->second; + continue; + } + const int cur_idx = tree.n_nodes + 1; + tree.token_ids.push_back(tok_id); + tree.depths.push_back(d); + tree.parents.push_back(prev_idx); + tree.child_maps.emplace_back(); + tree.child_maps[prev_idx][tok_id] = cur_idx; + tree.n_nodes++; + prev_idx = cur_idx; + } + } + // Two seeding strategies: // - chain_seed=true: pre-seed full top-1 chain (defensive, guarantees // AL >= chain mode even with flat-softmax draft like Q4_K_M). Compensates @@ -449,10 +504,30 @@ static DDTree build_ddtree(const float * top_log_probs, const int chain_depth = std::min(L, budget); float cum_logw = 0.0f; int prev_idx = 0; - for (int d = 1; d <= chain_depth; d++) { + for (int d = 1; d <= chain_depth && tree.n_nodes < budget; d++) { const int32_t tok_id = top_token_ids[(size_t)(d - 1) * K + 0]; cum_logw += top_log_probs[(size_t)(d - 1) * K + 0]; + // Skip if extra_chain_priority already inserted this node. + auto it = tree.child_maps[prev_idx].find(tok_id); + if (it != tree.child_maps[prev_idx].end()) { + if (K > 1) { + const float sibling_logw = cum_logw + - top_log_probs[(size_t)(d - 1) * K + 0] + + top_log_probs[(size_t)(d - 1) * K + 1]; + heap.push({ + /*neg_logw*/ -sibling_logw, + /*ranks */ {1}, + /*parent */ prev_idx, + /*depth */ d, + /*rank */ 1, + /*logw */ sibling_logw, + }); + } + prev_idx = it->second; + continue; + } + const int cur_idx = tree.n_nodes + 1; tree.token_ids.push_back(tok_id); tree.depths.push_back(d); @@ -489,6 +564,30 @@ static DDTree build_ddtree(const float * top_log_probs, }); } + // Fallback extra_chain insertion (lower priority than the spine). 
+ // Used when extra_chain_priority=false so the MTP seed still gets + // a chance to land if there's budget left after the DFlash spine. + if (!extra_chain_priority && extra_chain && extra_chain_len > 0) { + int prev_idx = 0; + const int chain_depth = std::min(L, extra_chain_len); + for (int d = 1; d <= chain_depth && tree.n_nodes < budget; d++) { + const int32_t tok_id = extra_chain[d - 1]; + auto it = tree.child_maps[prev_idx].find(tok_id); + if (it != tree.child_maps[prev_idx].end()) { + prev_idx = it->second; + continue; + } + const int cur_idx = tree.n_nodes + 1; + tree.token_ids.push_back(tok_id); + tree.depths.push_back(d); + tree.parents.push_back(prev_idx); + tree.child_maps.emplace_back(); + tree.child_maps[prev_idx][tok_id] = cur_idx; + tree.n_nodes++; + prev_idx = cur_idx; + } + } + while (!heap.empty() && tree.n_nodes < budget) { HeapEntry top = heap.top(); heap.pop(); @@ -496,6 +595,12 @@ static DDTree build_ddtree(const float * top_log_probs, const int depth_minus_1 = top.depth - 1; const int rank = top.rank; const int32_t token_id = top_token_ids[(size_t)depth_minus_1 * K + rank]; + // Avoid re-inserting a duplicate child that an extra_chain or chain_seed + // already placed under this parent. + if (tree.child_maps[top.parent_index].find(token_id) != + tree.child_maps[top.parent_index].end()) { + continue; + } const int current_index = tree.n_nodes + 1; // slot in flat tree tree.token_ids.push_back(token_id); @@ -3263,6 +3368,8 @@ static bool run_mtp_integrated_prompt(const TargetWeights & w, } // CLI entry point: runs MTP integrated decode and prints metrics. +// Honors --dflash-mtp-policy=auto|always|never: under `auto`, falls back to +// the AR baseline when n_gen is below the regression threshold (default 192). 
static int run_mtp_integrated_cli(const TargetWeights & w, ggml_backend_t backend, int max_ctx, @@ -3271,6 +3378,29 @@ static int run_mtp_integrated_cli(const TargetWeights & w, int mtp_draft_n_max, int decode_pos_offset, bool step_log) { + const bool use_mtp = + (g_dflash_mtp_policy == MtpPolicy::Always) || + (g_dflash_mtp_policy == MtpPolicy::Auto && + n_gen >= g_dflash_mtp_policy_min_n); + if (g_dflash_mtp_policy == MtpPolicy::Auto && !use_mtp) { + std::printf("[mtp-policy] auto: n_gen=%d < min_n=%d, falling back to AR baseline\n", + n_gen, g_dflash_mtp_policy_min_n); + } + + if (!use_mtp) { + TargetArRunStats ar; + if (!run_target_ar_prompt(w, backend, max_ctx, prompt, n_gen, + decode_pos_offset, ar)) { + return 1; + } + const double tps = ar.out_ids.size() / std::max(1e-9, ar.seconds); + std::printf("[mtp-integrated] policy=%s generated=%zu tok/s=%.2f seconds=%.4f " + "mode=baseline-ar\n", + g_dflash_mtp_policy == MtpPolicy::Never ? "never" : "auto-fallback", + ar.out_ids.size(), tps, ar.seconds); + return 0; + } + MtpIntegratedRunStats st; if (!run_mtp_integrated_prompt(w, backend, max_ctx, prompt, n_gen, mtp_draft_n_max, st, step_log, @@ -3279,8 +3409,12 @@ static int run_mtp_integrated_cli(const TargetWeights & w, } const double tps = st.out_ids.size() / std::max(1e-9, st.seconds); const double acc = st.draft_n > 0 ? 100.0 * st.accepted / st.draft_n : 0.0; - std::printf("[mtp-integrated] generated=%zu draft_n=%d accepted=%d corrected=%d " - "acceptance=%.1f%% tok/s=%.2f seconds=%.4f draft_n_max=%d\n", + const char * policy_name = + g_dflash_mtp_policy == MtpPolicy::Always ? "always" : + g_dflash_mtp_policy == MtpPolicy::Auto ? 
"auto" : "never"; + std::printf("[mtp-integrated] policy=%s generated=%zu draft_n=%d accepted=%d " + "corrected=%d acceptance=%.1f%% tok/s=%.2f seconds=%.4f draft_n_max=%d\n", + policy_name, st.out_ids.size(), st.draft_n, st.accepted, st.corrected, acc, tps, st.seconds, mtp_draft_n_max); return 0; @@ -3625,6 +3759,21 @@ int main(int argc, char ** argv) { else if (std::strncmp(argv[i], "--dflash-mtp-bonus-min-margin=", 30) == 0) { g_dflash_mtp_bonus_min_margin = (float)std::atof(argv[i] + 30); } + else if (std::strncmp(argv[i], "--dflash-mtp-policy=", 20) == 0) { + const char * v = argv[i] + 20; + if (std::strcmp(v, "auto") == 0) g_dflash_mtp_policy = MtpPolicy::Auto; + else if (std::strcmp(v, "always") == 0) g_dflash_mtp_policy = MtpPolicy::Always; + else if (std::strcmp(v, "never") == 0) g_dflash_mtp_policy = MtpPolicy::Never; + else { + std::fprintf(stderr, + "bad --dflash-mtp-policy value (use auto|always|never): %s\n", + v); + return 2; + } + } + else if (std::strncmp(argv[i], "--dflash-mtp-policy-min-n=", 26) == 0) { + g_dflash_mtp_policy_min_n = std::max(1, std::atoi(argv[i] + 26)); + } } // Also accept DFLASH_MTP_TIMING=1 env var so it can be turned on without From 2f4ede79c1ffc8c56d3c43d57fdac169dd6efd8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Javier=20Paz=C3=B3?= Date: Mon, 11 May 2026 16:18:51 +0200 Subject: [PATCH 4/4] docs(dflash): linear MTP bench numbers + reproduce instructions Single-GPU smoke run on RTX 6000 Ada with Qwen3.6-27B-MTP Q4_K_M, synthetic 64-token prompt: n_gen Baseline AR MTP chain-2 speed_ratio acceptance 64 9.48 12.04 1.270x 81.2% 128 14.48 15.33 1.059x 81.2% 256 15.11 17.38 1.151x 82.4% All three runs: compare_ok tokens=N mismatches=0. This is vs the AR baseline on the same MTP GGUF, not DFlash-classic. The honest framing (and pointers to the Waves A-D / P1 work that would let MTP beat DFlash-classic) is in the doc itself. 
--- dflash/docs/MTP_LINEAR_BENCH_2026-05-11.md | 79 ++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 dflash/docs/MTP_LINEAR_BENCH_2026-05-11.md diff --git a/dflash/docs/MTP_LINEAR_BENCH_2026-05-11.md b/dflash/docs/MTP_LINEAR_BENCH_2026-05-11.md new file mode 100644 index 000000000..233fe4c4e --- /dev/null +++ b/dflash/docs/MTP_LINEAR_BENCH_2026-05-11.md @@ -0,0 +1,79 @@ +# Linear MTP integrated decode — measured numbers, 2026-05-11 + +Validation run for `feat(dflash): linear native MTP integrated decode CLI` +(PR #154, stacked on PR #153 native MTP runtime). + +## Setup + +- **GPU**: NVIDIA RTX 6000 Ada Generation, sm_89, 48 GiB +- **Build**: Release, `DFLASH27B_USER_CUDA_ARCHITECTURES=89`, + ninja + MSVC 19.44 + CUDA 12.8 +- **Target model**: `Qwen3.6-27B-MTP-Q4_K_M.gguf` + (`am17an/Qwen3.6-27B-MTP-GGUF` lineage) +- **Mode**: `--mtp-integrated --mtp-baseline-check --mtp-draft-n-max=2 --dflash-mtp-policy=always` +- **Context**: `--max-ctx=512`, `FA_WINDOW=2048` (default), KV `q8_0/q8_0` (default) +- **Prompt**: synthetic 64-token sequence of distinct ids +- **Method**: AR baseline (`run_target_ar_prompt`) vs integrated MTP loop + (`run_mtp_integrated_prompt` with `mtp_draft_n_max=2`); same GGUF, same + config; outputs compared token-by-token (`compare_ok` / `compare_fail`). + +## Results + +| n_gen | Baseline AR (tok/s) | MTP chain-2 (tok/s) | speed_ratio | MTP acceptance | +|---:|---:|---:|---:|---:| +| 64 | 9.48 | 12.04 | **1.270x** | 81.2% | +| 128 | 14.48 | 15.33 | **1.059x** | 81.2% | +| 256 | 15.11 | 17.38 | **1.151x** | 82.4% | + +All three runs report `[mtp-baseline] compare_ok tokens=N mismatches=0`, +i.e. the integrated MTP loop produces greedy output that is token-for-token +identical to the AR baseline on this prompt. + +## Honest caveats + +- **This is vs the AR baseline on the same MTP GGUF**, not against + DFlash-classic + PFlash on a plain Qwen3.6-27B Q4_K_M.
The AR baseline + is slower than DFlash-classic; a PR-publishable speedup claim against + DFlash-classic needs the `target_verify` graph-bucket cache work + ("Waves A-D" / P1 in `MTP_ACCELERATION_ROADMAP_FOR_NEXT_AI_2026-05-11.md`). +- **Single-prompt result.** Workload sensitivity matters for MTP + acceptance; published prompt families (P3 of the roadmap, multi-prompt + gate) are pending. +- **Linear path only.** No DDTree hybrid, no immediate-bonus, no batched + target verify, no bucketed graphs. This is the parity-correct floor + the next PR builds on. +- **`--dflash-mtp-policy=auto`** is wired (P2 of the roadmap) but defaults + to `always` for now since the smoke shows MTP wins at every measured + `n_gen` against this baseline. The auto threshold (`min_n=192` default) + remains conservative for prompt families that haven't been measured. + +## Reproduce + +```bash +# Build +cmake -S dflash -B dflash/build -G Ninja -DCMAKE_BUILD_TYPE=Release \ + -DDFLASH27B_USER_CUDA_ARCHITECTURES=89 +cmake --build dflash/build --target test_dflash + +# Synthetic prompt +python -c "import struct; \ +ids = list(range(1000, 1064)); \ +open('prompt.bin','wb').write(b''.join(struct.pack('