Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions server/src/common/dflash_target.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ struct DFlashTarget {
int & last_tok,
std::vector<int32_t> * all_argmax = nullptr) = 0;

// Fetches the f32 logits left behind by the latest verify_batch call, laid
// out as [n_tokens x vocab]. Sampled-verify (speculative decoding with
// temperature) is the consumer. This default is for implementations that do
// not retain verify logits: it reports false and leaves `out` untouched.
virtual bool read_verify_logits(int n_tokens, std::vector<float> & out) {
    // Neither argument is consulted here; the casts only mark them as used.
    static_cast<void>(n_tokens);
    static_cast<void>(out);
    return false;
}

// ── KV state management ─────────────────────────────────────────

// Snapshot KV cache state before speculative verify, so it can be
Expand Down
149 changes: 141 additions & 8 deletions server/src/qwen35/qwen35_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1348,16 +1348,38 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
// out-of-bounds tensor read. cache_.last_tok is always correct.
int32_t last_tok = cache_.last_tok;

// Sampled-verify: spec decode with an active sampler. Each chain
// position is verified against a token drawn from the target's own
// sampler chain instead of its argmax, so every committed token is an
// exact target sample — the output distribution is identical to AR
// sampling. Acceptance drops vs greedy but stays far above the AR
// floor. Opt in with DFLASH_SAMPLED_VERIFY=1; without it, sampling
// requests fall back to AR decode (zero behavior change by default).
static const bool kSampledVerify = []() {
const char * e = std::getenv("DFLASH_SAMPLED_VERIFY");
return e != nullptr && std::string(e) == "1";
}();
// Sampled-verify additionally requires full attention in the verify
// path. With a finite --fa-window the verify batch applies one
// window-start to the whole batch (unlike the AR step graph, which is
// hardcoded to full attention): the argmax stays robust, so greedy
// verification is unaffected, but the logit TAIL drifts at long
// context and top-k sampling draws degenerate tokens from it
// (reproduced at 24K: 0/12 tool calls with fa-window 2048, 4/4 with 0).
const bool sampled_verify = kSampledVerify &&
sampler_.needs_logit_processing() &&
cfg_.fa_window == 0;

// Check if we can use speculative decode:
// - draft model loaded and not parked
// - feature mirror initialized
// - greedy decoding (no logit processing) — spec decode uses argmax verification
// - greedy decoding, or sampled-verify enabled
const bool can_spec = cfg_.draft_path
&& !draft_parked_
&& (cfg_.remote_draft.enabled()
? remote_draft_.active()
: feature_mirror_.target_feat != nullptr)
&& !sampler_.needs_logit_processing();
&& (!sampler_.needs_logit_processing() || sampled_verify);

if (!can_spec) {
// AR fallback consumes the final prefill position itself, then advances
Expand All @@ -1371,6 +1393,31 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
}

out_spec_ran = true;

// Sampled-verify: cache_.last_tok is do_prefill's argmax, and the spec
// loop commits it verbatim as the first generated token. The first token
// is the highest-entropy decision of the whole generation (e.g. "answer
// with text" vs "open a tool call"), so it must be sampled like every
// other committed token — mirror do_ar_decode's first-token sampling.
if (sampled_verify && out_tokens.empty() && prefill_last_logits_valid_) {
std::vector<float> first_logits(w_.n_vocab);
ggml_backend_tensor_get(sg_.logits, first_logits.data(),
prefill_last_logits_offset_,
sizeof(float) * (size_t)w_.n_vocab);
if (std::getenv("DFLASH_SV_DEBUG")) {
int am = 0; float best = first_logits[0];
for (int v = 1; v < w_.n_vocab; v++)
if (first_logits[v] > best) { best = first_logits[v]; am = v; }
std::fprintf(stderr,
"[sv-debug] first-token: logits_argmax=%d cache_last_tok=%d "
"(match=%d) top_logit=%.3f\n",
am, cache_.last_tok, am == cache_.last_tok, best);
}
last_tok = sample_logits(first_logits.data(), w_.n_vocab, sampler_,
out_tokens, sampler_rng_);
cache_.last_tok = last_tok;
}

const int _min_floor = dflash_min_tokens_floor();

// ── DFlash spec-decode: draft → verify → accept → replay ──────────
Expand All @@ -1385,6 +1432,8 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
std::vector<int32_t> noise_ids(q_len);
std::vector<int32_t> draft_tok(q_len);
std::vector<int32_t> target_tok(q_len);
std::vector<float> verify_logits; // sampled-verify: [q_len x vocab]
std::vector<int32_t> verify_history; // sampled-verify: penalty history
std::vector<int32_t> pos_q(q_len);
std::vector<int32_t> pos_k;
std::vector<float> local_hidden;
Expand Down Expand Up @@ -1567,18 +1616,86 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
return false;
}

// 5. Acceptance: longest matching prefix between draft and target argmax
// 5. Acceptance. Greedy: longest matching prefix between draft and
// target argmax. Sampled-verify: walk the chain drawing each next
// token from the target's sampler chain; accept while the draft
// guessed the drawn token, and the first mismatch becomes the bonus
// token (it is already a valid target sample at that position).
int accept_n = 1;
for (int i = 0; i < q_len - 1; i++) {
if (draft_tok[i + 1] == target_tok[i]) accept_n++;
else break;
int bonus_tok = -1;
if (sampled_verify) {
if (!target->read_verify_logits(q_len, verify_logits)) {
std::fprintf(stderr, "spec-decode: verify logits read failed\n");
target->restore_kv();
step_graph_destroy(draft_sg);
return false;
}
const int vocab_v = (int)(verify_logits.size() / (size_t)q_len);
static const bool kSvDebug = []() {
const char * e = std::getenv("DFLASH_SV_DEBUG");
return e != nullptr && std::string(e) == "1";
}();
if (kSvDebug) {
// Row-alignment check: CPU argmax over each bulk-read row must
// equal the GPU argmax (target_tok). Divergence = misaligned
// or stale bulk read.
for (int i = 0; i < q_len; i++) {
const float * row = verify_logits.data() + (size_t)i * vocab_v;
int am = 0; float best = row[0];
for (int v = 1; v < vocab_v; v++)
if (row[v] > best) { best = row[v]; am = v; }
if (am != target_tok[i]) {
std::fprintf(stderr,
"[sv-debug] ROW MISMATCH i=%d cpu_argmax=%d (%.3f) "
"gpu_argmax=%d (%.3f) vocab_v=%d\n",
i, am, best, target_tok[i],
target_tok[i] < vocab_v ? row[target_tok[i]] : -999.0f,
vocab_v);
break;
}
}
}
// Penalty history must match AR exactly: when AR samples the
// token after X, X is already in out_tokens. The seed
// draft_tok[0] is committed by this step's replay but not yet
// in out_tokens, so add it before the walk — without it the
// repetition penalty never sees the seed and the sampled
// distribution drifts from AR whenever penalties are active.
verify_history = out_tokens;
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
verify_history.push_back(draft_tok[0]);
bool mismatched = false;
for (int i = 0; i < q_len - 1; i++) {
const int s = sample_logits(
verify_logits.data() + (size_t)i * vocab_v, vocab_v,
sampler_, verify_history, sampler_rng_);
if (kSvDebug && n_draft_steps < 3 && i < 4) {
std::fprintf(stderr,
"[sv-debug] step=%d pos=%d seed/draft0=%d draft=%d "
"sampled=%d\n",
n_draft_steps, i, draft_tok[0], draft_tok[i + 1], s);
}
if (draft_tok[i + 1] == s) {
accept_n++;
verify_history.push_back(s);
} else {
bonus_tok = s;
mismatched = true;
break;
}
}
(void)mismatched;
} else {
for (int i = 0; i < q_len - 1; i++) {
if (draft_tok[i + 1] == target_tok[i]) accept_n++;
else break;
}
bonus_tok = (accept_n < q_len) ? target_tok[accept_n - 1] : -1;
}
// Track hint acceptance telemetry.
if (hint_fill > 0) {
n_hint_proposed += hint_fill;
n_hint_accepted += std::min(hint_fill, accept_n - 1);
}
int bonus_tok = (accept_n < q_len) ? target_tok[accept_n - 1] : -1;
int commit_n = accept_n + (bonus_tok >= 0 ? 1 : 0);
if (commit_n > need_commit_budget) {
commit_n = need_commit_budget;
Expand Down Expand Up @@ -1699,7 +1816,23 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
// next draft step. The floor_to_ar path never reaches the next
// iteration — it sets cache_.last_tok directly below and returns —
// so last_tok is intentionally left untouched when flooring.
last_tok = replay_last_tok;
//
// Sampled-verify: the seed is committed as-is by the next step
// (draft_tok[0]), so it must itself be a sample from the target
// distribution. replay_last_tok is the argmax — seeding with it
// injects one greedy token per step, which biases the output and
// locks long generations into repetition loops.
if (sampled_verify && !replay_tok.empty() &&
target->read_verify_logits((int)replay_tok.size(), verify_logits)) {
const int vocab_v =
(int)(verify_logits.size() / replay_tok.size());
last_tok = sample_logits(
verify_logits.data() +
(replay_tok.size() - 1) * (size_t)vocab_v,
vocab_v, sampler_, out_tokens, sampler_rng_);
} else {
last_tok = replay_last_tok;
}
committed += emitted;
}
cache_.cur_pos = committed;
Expand Down
10 changes: 10 additions & 0 deletions server/src/qwen35/qwen35_dflash_target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,16 @@ bool Qwen35DFlashTarget::verify_batch(
return true;
}

// Copies the first n_tokens rows of the logits tensor (sg_.logits) into
// `out`, resizing it to n_tokens * vocab floats, where vocab is the tensor's
// ne[0].
//
// Returns false, leaving `out` untouched, when there is no logits tensor,
// when n_tokens is not positive, when the tensor's vocab dimension is empty,
// or when n_tokens exceeds the tensor's row count (ne[1]). Returns true after
// the copy otherwise.
//
// NOTE(review): the read starts at byte offset 0 and is sized in floats, so
// this assumes sg_.logits holds contiguous f32 rows — confirm against the
// graph that produces it.
bool Qwen35DFlashTarget::read_verify_logits(int n_tokens, std::vector<float> & out) {
    if (!sg_.logits || n_tokens <= 0) return false;
    const int64_t vocab = sg_.logits->ne[0];
    // An empty vocab dimension would otherwise yield a "successful"
    // zero-length read; a caller that derives the row width from out.size()
    // would then see a width of 0 and index past the end of an empty buffer.
    if (vocab <= 0) return false;
    // Compare in 64 bits: narrowing ne[1] to int could wrap for large tensors
    // and wrongly reject (or accept) the request.
    if (static_cast<int64_t>(n_tokens) > sg_.logits->ne[1]) return false;
    out.resize(static_cast<size_t>(n_tokens) * static_cast<size_t>(vocab));
    ggml_backend_tensor_get(sg_.logits, out.data(), 0,
                            sizeof(float) * out.size());
    return true;
}

bool Qwen35DFlashTarget::snapshot_kv() {
snapshot_ssm_state(cache_);
return true;
Expand Down
2 changes: 2 additions & 0 deletions server/src/qwen35/qwen35_dflash_target.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class Qwen35DFlashTarget : public DFlashTarget {
int & last_tok,
std::vector<int32_t> * all_argmax = nullptr) override;

// Copies the first n_tokens rows of this target's logits tensor into `out`;
// returns false when they cannot be provided. See the definition in
// qwen35_dflash_target.cpp for the exact failure conditions.
bool read_verify_logits(int n_tokens, std::vector<float> & out) override;

bool snapshot_kv() override;
bool restore_kv() override;

Expand Down
Loading