diff --git a/server/src/common/dflash_target.h b/server/src/common/dflash_target.h
index 56fd4bec..2b7b148c 100644
--- a/server/src/common/dflash_target.h
+++ b/server/src/common/dflash_target.h
@@ -33,6 +33,17 @@ struct DFlashTarget {
                               int & last_tok,
                               std::vector * all_argmax = nullptr) = 0;
 
+    // Read the full [n_tokens x vocab] f32 logits produced by the most
+    // recent verify_batch call. Used by sampled-verify (spec decode with
+    // temperature). On success `out` holds n_tokens rows of vocab floats,
+    // with row i carrying the logits for batch position i. Returns false
+    // when the implementation does not keep verify logits around; this
+    // default leaves `out` untouched.
+    virtual bool read_verify_logits(int n_tokens, std::vector & out) {
+        (void)n_tokens; (void)out;
+        return false;
+    }
+
     // ── KV state management ─────────────────────────────────────────
 
     // Snapshot KV cache state before speculative verify, so it can be
diff --git a/server/src/qwen35/qwen35_backend.cpp b/server/src/qwen35/qwen35_backend.cpp
index c22b37ed..4c656e98 100644
--- a/server/src/qwen35/qwen35_backend.cpp
+++ b/server/src/qwen35/qwen35_backend.cpp
@@ -1348,16 +1348,44 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
     // out-of-bounds tensor read. cache_.last_tok is always correct.
     int32_t last_tok = cache_.last_tok;
 
+    // Sampled-verify: spec decode with an active sampler. Each chain
+    // position is verified against a token drawn from the target's own
+    // sampler chain instead of its argmax, so every committed token is an
+    // exact target sample — the output distribution is identical to AR
+    // sampling. Acceptance drops vs greedy but stays far above the AR
+    // floor. Opt in with DFLASH_SAMPLED_VERIFY=1; without it, sampling
+    // requests fall back to AR decode (zero behavior change by default).
+    static const bool kSampledVerify = []() {
+        const char * e = std::getenv("DFLASH_SAMPLED_VERIFY");
+        return e != nullptr && std::string(e) == "1";
+    }();
+    // DFLASH_SV_DEBUG=1 turns on the [sv-debug] stderr diagnostics. Parsed
+    // once here so every debug site shares the same "== 1" gate.
+    static const bool kSvDebug = []() {
+        const char * e = std::getenv("DFLASH_SV_DEBUG");
+        return e != nullptr && std::string(e) == "1";
+    }();
+    // Sampled-verify additionally requires full attention in the verify
+    // path. With a finite --fa-window the verify batch applies one
+    // window-start to the whole batch (unlike the AR step graph, which is
+    // hardcoded to full attention): the argmax stays robust, so greedy
+    // verification is unaffected, but the logit TAIL drifts at long
+    // context and top-k sampling draws degenerate tokens from it
+    // (reproduced at 24K: 0/12 tool calls with fa-window 2048, 4/4 with 0).
+    const bool sampled_verify = kSampledVerify &&
+                                sampler_.needs_logit_processing() &&
+                                cfg_.fa_window == 0;
+
     // Check if we can use speculative decode:
     //   - draft model loaded and not parked
     //   - feature mirror initialized
-    //   - greedy decoding (no logit processing) — spec decode uses argmax verification
+    //   - greedy decoding, or sampled-verify enabled
     const bool can_spec = cfg_.draft_path
                        && !draft_parked_
                        && (cfg_.remote_draft.enabled()
                            ? remote_draft_.active()
                            : feature_mirror_.target_feat != nullptr)
-                       && !sampler_.needs_logit_processing();
+                       && (!sampler_.needs_logit_processing() || sampled_verify);
 
     if (!can_spec) {
         // AR fallback consumes the final prefill position itself, then advances
@@ -1371,6 +1399,31 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
     }
 
     out_spec_ran = true;
+
+    // Sampled-verify: cache_.last_tok is do_prefill's argmax, and the spec
+    // loop commits it verbatim as the first generated token. The first token
+    // is the highest-entropy decision of the whole generation (e.g. "answer
+    // with text" vs "open a tool call"), so it must be sampled like every
+    // other committed token — mirror do_ar_decode's first-token sampling.
+    if (sampled_verify && out_tokens.empty() && prefill_last_logits_valid_) {
+        std::vector first_logits(w_.n_vocab);
+        ggml_backend_tensor_get(sg_.logits, first_logits.data(),
+                                prefill_last_logits_offset_,
+                                sizeof(float) * (size_t)w_.n_vocab);
+        if (kSvDebug) {
+            int am = 0; float best = first_logits[0];
+            for (int v = 1; v < w_.n_vocab; v++)
+                if (first_logits[v] > best) { best = first_logits[v]; am = v; }
+            std::fprintf(stderr,
+                "[sv-debug] first-token: logits_argmax=%d cache_last_tok=%d "
+                "(match=%d) top_logit=%.3f\n",
+                am, cache_.last_tok, am == cache_.last_tok, best);
+        }
+        last_tok = sample_logits(first_logits.data(), w_.n_vocab, sampler_,
+                                 out_tokens, sampler_rng_);
+        cache_.last_tok = last_tok;
+    }
+
     const int _min_floor = dflash_min_tokens_floor();
 
     // ── DFlash spec-decode: draft → verify → accept → replay ──────────
@@ -1385,6 +1438,8 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
     std::vector noise_ids(q_len);
     std::vector draft_tok(q_len);
     std::vector target_tok(q_len);
+    std::vector verify_logits;   // sampled-verify: [q_len x vocab]
+    std::vector verify_history;  // sampled-verify: penalty history
     std::vector pos_q(q_len);
     std::vector pos_k;
     std::vector local_hidden;
@@ -1567,18 +1622,84 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
             return false;
         }
 
-        // 5. Acceptance: longest matching prefix between draft and target argmax
+        // 5. Acceptance. Greedy: longest matching prefix between draft and
+        // target argmax. Sampled-verify: walk the chain drawing each next
+        // token from the target's sampler chain; accept while the draft
+        // guessed the drawn token, and the first mismatch becomes the bonus
+        // token (it is already a valid target sample at that position).
         int accept_n = 1;
-        for (int i = 0; i < q_len - 1; i++) {
-            if (draft_tok[i + 1] == target_tok[i]) accept_n++;
-            else break;
+        int bonus_tok = -1;
+        if (sampled_verify) {
+            // A successful read must hold exactly q_len rows; a short or
+            // ragged buffer would send the per-row pointer arithmetic below
+            // out of bounds, so treat it like a failed read.
+            if (!target->read_verify_logits(q_len, verify_logits) ||
+                verify_logits.empty() ||
+                verify_logits.size() % (size_t)q_len != 0) {
+                std::fprintf(stderr, "spec-decode: verify logits read failed\n");
+                target->restore_kv();
+                step_graph_destroy(draft_sg);
+                return false;
+            }
+            const int vocab_v = (int)(verify_logits.size() / (size_t)q_len);
+            if (kSvDebug) {
+                // Row-alignment check: CPU argmax over each bulk-read row must
+                // equal the GPU argmax (target_tok). Divergence = misaligned
+                // or stale bulk read.
+                for (int i = 0; i < q_len; i++) {
+                    const float * row = verify_logits.data() + (size_t)i * vocab_v;
+                    int am = 0; float best = row[0];
+                    for (int v = 1; v < vocab_v; v++)
+                        if (row[v] > best) { best = row[v]; am = v; }
+                    if (am != target_tok[i]) {
+                        std::fprintf(stderr,
+                            "[sv-debug] ROW MISMATCH i=%d cpu_argmax=%d (%.3f) "
+                            "gpu_argmax=%d (%.3f) vocab_v=%d\n",
+                            i, am, best, target_tok[i],
+                            target_tok[i] < vocab_v ? row[target_tok[i]] : -999.0f,
+                            vocab_v);
+                        break;
+                    }
+                }
+            }
+            // Penalty history must match AR exactly: when AR samples the
+            // token after X, X is already in out_tokens. The seed
+            // draft_tok[0] is committed by this step's replay but not yet
+            // in out_tokens, so add it before the walk — without it the
+            // repetition penalty never sees the seed and the sampled
+            // distribution drifts from AR whenever penalties are active.
+            verify_history = out_tokens;
+            verify_history.push_back(draft_tok[0]);
+            for (int i = 0; i < q_len - 1; i++) {
+                const int s = sample_logits(
+                    verify_logits.data() + (size_t)i * vocab_v, vocab_v,
+                    sampler_, verify_history, sampler_rng_);
+                if (kSvDebug && n_draft_steps < 3 && i < 4) {
+                    std::fprintf(stderr,
+                        "[sv-debug] step=%d pos=%d seed/draft0=%d draft=%d "
+                        "sampled=%d\n",
+                        n_draft_steps, i, draft_tok[0], draft_tok[i + 1], s);
+                }
+                if (draft_tok[i + 1] == s) {
+                    accept_n++;
+                    verify_history.push_back(s);
+                } else {
+                    bonus_tok = s;
+                    break;
+                }
+            }
+        } else {
+            for (int i = 0; i < q_len - 1; i++) {
+                if (draft_tok[i + 1] == target_tok[i]) accept_n++;
+                else break;
+            }
+            bonus_tok = (accept_n < q_len) ? target_tok[accept_n - 1] : -1;
         }
         // Track hint acceptance telemetry.
         if (hint_fill > 0) {
             n_hint_proposed += hint_fill;
             n_hint_accepted += std::min(hint_fill, accept_n - 1);
         }
-        int bonus_tok = (accept_n < q_len) ? target_tok[accept_n - 1] : -1;
         int commit_n = accept_n + (bonus_tok >= 0 ? 1 : 0);
         if (commit_n > need_commit_budget) {
             commit_n = need_commit_budget;
@@ -1699,7 +1820,33 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
         // next draft step. The floor_to_ar path never reaches the next
         // iteration — it sets cache_.last_tok directly below and returns —
         // so last_tok is intentionally left untouched when flooring.
-        last_tok = replay_last_tok;
+        //
+        // Sampled-verify: the seed is committed as-is by the next step
+        // (draft_tok[0]), so it must itself be a sample from the target
+        // distribution. replay_last_tok is the argmax — seeding with it
+        // injects one greedy token per step, which biases the output and
+        // locks long generations into repetition loops.
+        // The size check keeps vocab_v >= 1 so the last-row pointer below
+        // stays inside verify_logits.
+        if (sampled_verify && !replay_tok.empty() &&
+            target->read_verify_logits((int)replay_tok.size(), verify_logits) &&
+            verify_logits.size() >= replay_tok.size()) {
+            const int vocab_v =
+                (int)(verify_logits.size() / replay_tok.size());
+            last_tok = sample_logits(
+                verify_logits.data() +
+                    (replay_tok.size() - 1) * (size_t)vocab_v,
+                vocab_v, sampler_, out_tokens, sampler_rng_);
+        } else {
+            if (sampled_verify && !replay_tok.empty()) {
+                // Keep the request alive, but make the degraded (argmax)
+                // seed visible instead of silently biasing the output.
+                std::fprintf(stderr,
+                    "spec-decode: replay logits read failed; "
+                    "seeding next step with argmax\n");
+            }
+            last_tok = replay_last_tok;
+        }
         committed += emitted;
     }
     cache_.cur_pos = committed;
diff --git a/server/src/qwen35/qwen35_dflash_target.cpp b/server/src/qwen35/qwen35_dflash_target.cpp
index 65713d1b..1657dba5 100644
--- a/server/src/qwen35/qwen35_dflash_target.cpp
+++ b/server/src/qwen35/qwen35_dflash_target.cpp
@@ -99,6 +99,24 @@ bool Qwen35DFlashTarget::verify_batch(
     return true;
 }
 
+// Copies the first n_tokens rows of sg_.logits into `out`, resizing it to
+// n_tokens * vocab floats. Returns false, leaving `out` unchanged, when
+// there is no logits tensor or the request does not fit it.
+bool Qwen35DFlashTarget::read_verify_logits(int n_tokens, std::vector & out) {
+    if (!sg_.logits || n_tokens <= 0) return false;
+    // Rows are copied out as raw f32; refuse any other tensor type rather
+    // than reinterpret its bytes.
+    if (sg_.logits->type != GGML_TYPE_F32) return false;
+    const int64_t vocab = sg_.logits->ne[0];
+    // Compare in 64 bits: narrowing ne[1] to int could wrap and defeat the
+    // bounds check.
+    if (vocab <= 0 || (int64_t)n_tokens > sg_.logits->ne[1]) return false;
+    out.resize((size_t)n_tokens * (size_t)vocab);
+    ggml_backend_tensor_get(sg_.logits, out.data(), 0,
+                            sizeof(float) * out.size());
+    return true;
+}
+
 bool Qwen35DFlashTarget::snapshot_kv() {
     snapshot_ssm_state(cache_);
     return true;
diff --git a/server/src/qwen35/qwen35_dflash_target.h b/server/src/qwen35/qwen35_dflash_target.h
index 6a72e48b..5db682ee 100644
--- a/server/src/qwen35/qwen35_dflash_target.h
+++ b/server/src/qwen35/qwen35_dflash_target.h
@@ -37,6 +37,8 @@ class Qwen35DFlashTarget : public DFlashTarget {
                       int & last_tok,
                       std::vector * all_argmax = nullptr) override;
 
+    bool read_verify_logits(int n_tokens, std::vector & out) override;
+
     bool snapshot_kv() override;
     bool restore_kv() override;
 