From 9d6300213b7ddc69b79919add63226a623f2d947 Mon Sep 17 00:00:00 2001 From: Rhonstin Date: Fri, 12 Jun 2026 14:04:45 +0300 Subject: [PATCH] =?UTF-8?q?qwen35:=20sampled-verify=20=E2=80=94=20speculat?= =?UTF-8?q?ive=20decoding=20with=20an=20active=20sampler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The chain verify path commits the target's argmax at every position, so spec decode only matches the target distribution at temperature 0. Any sampling request had to fall back to AR decode (16 tok/s on a 27B Q4_K_M / RTX 3090), while greedy spec ran at 60-70 tok/s. This adds a sampled-verify mode (DFLASH_SAMPLED_VERIFY=1): at each chain position the verifier draws a token from the target's full sampler chain (temperature / top-k / top-p / penalties) over that position's verify logits, and accepts the draft token iff the sampled token matches it. Every committed token is therefore an exact sample from the target distribution — the output distribution is identical to AR sampling, sample-and-match style. On mismatch the sampled token itself becomes the bonus token, so a rejected position still commits one valid sample. Three places must sample rather than take argmax, and getting any of them wrong corrupts generation in distinctive ways: - First token after prefill: it used to be the prefill argmax. With sampling enabled this injects one greedy token per request; at temperature 1 every generation opened identically. Now sampled from the prefill logits. - Per-step seed: the next draft step is seeded with the replay's last token, which the next verify commits as-is. Seeding with the replay argmax injects one greedy token per step (~1/16 of output), which biases the distribution and locks long generations into repetition loops. Now sampled from the replay's last verify logits. - The acceptance walk itself, over per-position verify logits exposed via a new DFlashTarget::read_verify_logits() hook (default-off; the qwen35 target reads them from the step graph's logits tensor). Sampled-verify requires full-attention verify (fa_window == 0). With a finite fa-window the windowed verify pass starves the logit tail at long context: argmax stays intact, but sampling from the poisoned tail degrades quality — on a 24K-token agent prompt, tool-call success went 0/12 with a finite window vs 12/12 with full attention. The mode is gated off unless fa_window == 0; prefill and AR decode are unaffected (they always run full attention). Measured on RTX 3090 (Qwen3.6-27B Q4_K_M + DFlash draft, temp 0.7-1.0): 16 -> 60-70 tok/s on short prompts, 16 -> ~31 tok/s at 24K context. Greedy path is bit-exact unchanged. Token histograms over 150 sampled generations match AR within noise. DFLASH_SV_DEBUG=1 traces the acceptance walk per position. Implements the "active sampler during verification" future addition from the README. --- server/src/common/dflash_target.h | 9 ++ server/src/qwen35/qwen35_backend.cpp | 149 +++++++++++++++++++-- server/src/qwen35/qwen35_dflash_target.cpp | 10 ++ server/src/qwen35/qwen35_dflash_target.h | 2 + 4 files changed, 162 insertions(+), 8 deletions(-) diff --git a/server/src/common/dflash_target.h b/server/src/common/dflash_target.h index 56fd4bece..2b7b148c4 100644 --- a/server/src/common/dflash_target.h +++ b/server/src/common/dflash_target.h @@ -33,6 +33,15 @@ struct DFlashTarget { int & last_tok, std::vector * all_argmax = nullptr) = 0; + // Read the full [n_tokens x vocab] f32 logits produced by the most + // recent verify_batch call. Used by sampled-verify (spec decode with + // temperature). Returns false when the implementation does not keep + // verify logits around. + virtual bool read_verify_logits(int n_tokens, std::vector & out) { + (void)n_tokens; (void)out; + return false; + } + // ── KV state management ───────────────────────────────────────── // Snapshot KV cache state before speculative verify, so it can be diff --git a/server/src/qwen35/qwen35_backend.cpp b/server/src/qwen35/qwen35_backend.cpp index c22b37ed5..4c656e987 100644 --- a/server/src/qwen35/qwen35_backend.cpp +++ b/server/src/qwen35/qwen35_backend.cpp @@ -1348,16 +1348,38 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, // out-of-bounds tensor read. cache_.last_tok is always correct. int32_t last_tok = cache_.last_tok; + // Sampled-verify: spec decode with an active sampler. Each chain + // position is verified against a token drawn from the target's own + // sampler chain instead of its argmax, so every committed token is an + // exact target sample — the output distribution is identical to AR + // sampling. Acceptance drops vs greedy but stays far above the AR + // floor. Opt in with DFLASH_SAMPLED_VERIFY=1; without it, sampling + // requests fall back to AR decode (zero behavior change by default). + static const bool kSampledVerify = []() { + const char * e = std::getenv("DFLASH_SAMPLED_VERIFY"); + return e != nullptr && std::string(e) == "1"; + }(); + // Sampled-verify additionally requires full attention in the verify + // path. With a finite --fa-window the verify batch applies one + // window-start to the whole batch (unlike the AR step graph, which is + // hardcoded to full attention): the argmax stays robust, so greedy + // verification is unaffected, but the logit TAIL drifts at long + // context and top-k sampling draws degenerate tokens from it + // (reproduced at 24K: 0/12 tool calls with fa-window 2048, 4/4 with 0). + const bool sampled_verify = kSampledVerify && + sampler_.needs_logit_processing() && + cfg_.fa_window == 0; + // Check if we can use speculative decode: // - draft model loaded and not parked // - feature mirror initialized - // - greedy decoding (no logit processing) — spec decode uses argmax verification + // - greedy decoding, or sampled-verify enabled const bool can_spec = cfg_.draft_path && !draft_parked_ && (cfg_.remote_draft.enabled() ? remote_draft_.active() : feature_mirror_.target_feat != nullptr) - && !sampler_.needs_logit_processing(); + && (!sampler_.needs_logit_processing() || sampled_verify); if (!can_spec) { // AR fallback consumes the final prefill position itself, then advances @@ -1371,6 +1393,31 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, } out_spec_ran = true; + + // Sampled-verify: cache_.last_tok is do_prefill's argmax, and the spec + // loop commits it verbatim as the first generated token. The first token + // is the highest-entropy decision of the whole generation (e.g. "answer + // with text" vs "open a tool call"), so it must be sampled like every + // other committed token — mirror do_ar_decode's first-token sampling. + if (sampled_verify && out_tokens.empty() && prefill_last_logits_valid_) { + std::vector first_logits(w_.n_vocab); + ggml_backend_tensor_get(sg_.logits, first_logits.data(), + prefill_last_logits_offset_, + sizeof(float) * (size_t)w_.n_vocab); + if (std::getenv("DFLASH_SV_DEBUG")) { + int am = 0; float best = first_logits[0]; + for (int v = 1; v < w_.n_vocab; v++) + if (first_logits[v] > best) { best = first_logits[v]; am = v; } + std::fprintf(stderr, + "[sv-debug] first-token: logits_argmax=%d cache_last_tok=%d " + "(match=%d) top_logit=%.3f\n", + am, cache_.last_tok, am == cache_.last_tok, best); + } + last_tok = sample_logits(first_logits.data(), w_.n_vocab, sampler_, + out_tokens, sampler_rng_); + cache_.last_tok = last_tok; + } + const int _min_floor = dflash_min_tokens_floor(); // ── DFlash spec-decode: draft → verify → accept → replay ────────── @@ -1385,6 +1432,8 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, std::vector noise_ids(q_len); std::vector draft_tok(q_len); std::vector target_tok(q_len); + std::vector verify_logits; // sampled-verify: [q_len x vocab] + std::vector verify_history; // sampled-verify: penalty history std::vector pos_q(q_len); std::vector pos_k; std::vector local_hidden; @@ -1567,18 +1616,86 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, return false; } - // 5. Acceptance: longest matching prefix between draft and target argmax + // 5. Acceptance. Greedy: longest matching prefix between draft and + // target argmax. Sampled-verify: walk the chain drawing each next + // token from the target's sampler chain; accept while the draft + // guessed the drawn token, and the first mismatch becomes the bonus + // token (it is already a valid target sample at that position). int accept_n = 1; - for (int i = 0; i < q_len - 1; i++) { - if (draft_tok[i + 1] == target_tok[i]) accept_n++; - else break; + int bonus_tok = -1; + if (sampled_verify) { + if (!target->read_verify_logits(q_len, verify_logits)) { + std::fprintf(stderr, "spec-decode: verify logits read failed\n"); + target->restore_kv(); + step_graph_destroy(draft_sg); + return false; + } + const int vocab_v = (int)(verify_logits.size() / (size_t)q_len); + static const bool kSvDebug = []() { + const char * e = std::getenv("DFLASH_SV_DEBUG"); + return e != nullptr && std::string(e) == "1"; + }(); + if (kSvDebug) { + // Row-alignment check: CPU argmax over each bulk-read row must + // equal the GPU argmax (target_tok). Divergence = misaligned + // or stale bulk read. + for (int i = 0; i < q_len; i++) { + const float * row = verify_logits.data() + (size_t)i * vocab_v; + int am = 0; float best = row[0]; + for (int v = 1; v < vocab_v; v++) + if (row[v] > best) { best = row[v]; am = v; } + if (am != target_tok[i]) { + std::fprintf(stderr, + "[sv-debug] ROW MISMATCH i=%d cpu_argmax=%d (%.3f) " + "gpu_argmax=%d (%.3f) vocab_v=%d\n", + i, am, best, target_tok[i], + target_tok[i] < vocab_v ? row[target_tok[i]] : -999.0f, + vocab_v); + break; + } + } + } + // Penalty history must match AR exactly: when AR samples the + // token after X, X is already in out_tokens. The seed + // draft_tok[0] is committed by this step's replay but not yet + // in out_tokens, so add it before the walk — without it the + // repetition penalty never sees the seed and the sampled + // distribution drifts from AR whenever penalties are active. + verify_history = out_tokens; + verify_history.push_back(draft_tok[0]); + bool mismatched = false; + for (int i = 0; i < q_len - 1; i++) { + const int s = sample_logits( + verify_logits.data() + (size_t)i * vocab_v, vocab_v, + sampler_, verify_history, sampler_rng_); + if (kSvDebug && n_draft_steps < 3 && i < 4) { + std::fprintf(stderr, + "[sv-debug] step=%d pos=%d seed/draft0=%d draft=%d " + "sampled=%d\n", + n_draft_steps, i, draft_tok[0], draft_tok[i + 1], s); + } + if (draft_tok[i + 1] == s) { + accept_n++; + verify_history.push_back(s); + } else { + bonus_tok = s; + mismatched = true; + break; + } + } + (void)mismatched; + } else { + for (int i = 0; i < q_len - 1; i++) { + if (draft_tok[i + 1] == target_tok[i]) accept_n++; + else break; + } + bonus_tok = (accept_n < q_len) ? target_tok[accept_n - 1] : -1; } // Track hint acceptance telemetry. if (hint_fill > 0) { n_hint_proposed += hint_fill; n_hint_accepted += std::min(hint_fill, accept_n - 1); } - int bonus_tok = (accept_n < q_len) ? target_tok[accept_n - 1] : -1; int commit_n = accept_n + (bonus_tok >= 0 ? 1 : 0); if (commit_n > need_commit_budget) { commit_n = need_commit_budget; @@ -1699,7 +1816,23 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, // next draft step. The floor_to_ar path never reaches the next // iteration — it sets cache_.last_tok directly below and returns — // so last_tok is intentionally left untouched when flooring. - last_tok = replay_last_tok; + // + // Sampled-verify: the seed is committed as-is by the next step + // (draft_tok[0]), so it must itself be a sample from the target + // distribution. replay_last_tok is the argmax — seeding with it + // injects one greedy token per step, which biases the output and + // locks long generations into repetition loops. + if (sampled_verify && !replay_tok.empty() && + target->read_verify_logits((int)replay_tok.size(), verify_logits)) { + const int vocab_v = + (int)(verify_logits.size() / replay_tok.size()); + last_tok = sample_logits( + verify_logits.data() + + (replay_tok.size() - 1) * (size_t)vocab_v, + vocab_v, sampler_, out_tokens, sampler_rng_); + } else { + last_tok = replay_last_tok; + } committed += emitted; } cache_.cur_pos = committed; diff --git a/server/src/qwen35/qwen35_dflash_target.cpp b/server/src/qwen35/qwen35_dflash_target.cpp index 65713d1bb..1657dba54 100644 --- a/server/src/qwen35/qwen35_dflash_target.cpp +++ b/server/src/qwen35/qwen35_dflash_target.cpp @@ -99,6 +99,16 @@ bool Qwen35DFlashTarget::verify_batch( return true; } +bool Qwen35DFlashTarget::read_verify_logits(int n_tokens, std::vector & out) { + if (!sg_.logits || n_tokens <= 0) return false; + const int64_t vocab = sg_.logits->ne[0]; + if (n_tokens > (int)sg_.logits->ne[1]) return false; + out.resize((size_t)n_tokens * (size_t)vocab); + ggml_backend_tensor_get(sg_.logits, out.data(), 0, + sizeof(float) * out.size()); + return true; +} + bool Qwen35DFlashTarget::snapshot_kv() { snapshot_ssm_state(cache_); return true; diff --git a/server/src/qwen35/qwen35_dflash_target.h b/server/src/qwen35/qwen35_dflash_target.h index 6a72e48b5..5db682ee8 100644 --- a/server/src/qwen35/qwen35_dflash_target.h +++ b/server/src/qwen35/qwen35_dflash_target.h @@ -37,6 +37,8 @@ class Qwen35DFlashTarget : public DFlashTarget { int & last_tok, std::vector * all_argmax = nullptr) override; + bool read_verify_logits(int n_tokens, std::vector & out) override; + bool snapshot_kv() override; bool restore_kv() override;