perf(sampling): opt-in fast verify path for topk=1 chain spec#133
Draft
cicirori wants to merge 2 commits into
Draft
perf(sampling): opt-in fast verify path for topk=1 chain spec#133cicirori wants to merge 2 commits into
cicirori wants to merge 2 commits into
Conversation
Targeted optimization for the *chain-spec* (eagle_topk = 1) configuration:
draft is already argmax-greedy, so the verify pass can mirror plain
decode's sampling instead of running the generic tree-rejection chain.
Plain decode in this backend already uses flashinfer's joint fused kernel
(top_k_top_p_sampling_from_logits with filter_apply_order="joint"), but
spec verify falls through the sequential
softmax -> top_k_renorm_prob -> top_p_renorm_prob ->
chain_speculative_sampling_target_only
chain which scans the full vocab twice at the renorm steps. On large-vocab
models (vocab ≈ 200K) this dominates the verify cost.
Add an opt-in path under TOKENSPEED_SPEC_FAST_CHAIN_SAMPLING=1 that, in
the topk=1 chain-spec case, samples one target token per chain position
via the joint kernel and verifies by token-equality against the
argmax-greedy draft candidates using the existing verify_chain_greedy
kernel from the greedy branch. The bonus token at the first mismatched
position is the sampled target there. Same kernel pair the greedy path
already runs, so CUDA graph compatibility is unchanged.
Scope and applicability
-----------------------
This is *only* applicable when:
* chain spec: tree-spec topk > 1 keeps the generic rejection kernel
* topk = 1: the per-position single-candidate assumption is what makes
token-equality verify equivalent to draft-greedy rejection sampling
* non-greedy sampling: greedy path is already fast (argmax + verify)
* no grammar mask: per-position vocab masks need the renorm chain;
falls back to the original sequential path automatically
Measurements (M2.5plus-fp4 NVFP4, TP=2 B200, vocab=200054, c=32 / 200 reqs)
---------------------------------------------------------------------------
Sampling kernel level, per verify, inside the captured CUDA graph
(measured with nsys --cuda-graph-trace=node):
Sequential path (baseline):
OnlineSoftmaxMap+Reduce 89 us
RadixTopKRenormProb 165 us
AirTopPRenormRadix+Apply 463 us
ChainSpecSampling 84 us
ChainSpec init/etc 146 us
--------------------------------
total per verify ~947 us
Fast path (this PR):
gather_and_expand 2 us
TopKTopPSamplingFromProb 76 us
VerifyChainGreedy 2 us
--------------------------------
total per verify ~80 us
~11.7x sampling-kernel speedup per verify call.
End-to-end throughput (3 reps each, no nsys):
Sequential: avg 1667 tok/s (stddev 31, 1.9%)
Fast: avg 1684 tok/s (stddev 21, 1.3%)
Delta: +1.0% (within 1 stddev)
Note: kernel-level speedup is large but E2E gain is modest because
sampling now occupies less than 1% of decode iter wall in this backend
(MoE + attention + allreduce dominate). The fast path is most useful as
a free-correctness alternative when the sequential renorm chain shows
up as a hotspot in profiling, or for future configs where sampling
becomes a larger fraction of the iter.
Numerical semantics
-------------------
Both this path and the sequential chain_speculative_sampling_target_only
path are unbiased rejection samplers on a chain-spec draft with topk=1
(draft is a greedy point mass at argmax(target_probs)). They differ in
how top-k and top-p are applied:
* sequential: top_k_renorm_prob then top_p_renorm_prob, so the top-p
cutoff is computed on the top-k-renormalized distribution.
* joint (this PR): flashinfer applies top-k AND top-p simultaneously
on the *original* (pre-renorm) distribution.
For top_k=50/top_p=0.95 the two filtered distributions differ by ~5%
total variation (empirically 0.046 on a vocab=1024 synthetic target).
This brings verify in line with plain decode's filter semantics: same
prompt + sampling params produce statistically equivalent samples
whether spec is on or off.
Inspired by stock TensorRT-LLM 1.3.0rc14's
`sampling_batch_spec_dec_one_model` (`allow_advanced_sampling=true` +
`eagle3_one_model=true`), which uses the same chain-spec + token-eq
verify shape. This PR additionally upgrades to flashinfer's joint mode
(stock trtllm still uses the default top_k_first which dispatches to
two kernels).
Opt-in via TOKENSPEED_SPEC_FAST_CHAIN_SAMPLING=1, default off. Read once
at import time. Falls back to the original sequential path for vocab_mask
requests (grammar bitmasks need the per-position renorm chain) and for
all-greedy batches (already on a faster path).
Tests
-----
test/ops/test_fast_chain_sampling.py verifies via Monte Carlo on 262,144
trials per check:
1. fast path's empirical output matches the joint-filter analytic
reference (TV < 0.02 in noise floor).
2. sequential path's empirical output matches the sequential-filter
analytic reference (TV < 0.02).
3. joint vs sequential filter gap is bounded by 0.10 TV and strictly
non-zero, documenting the deliberate semantic shift.
All three pass.
Signed-off-by: cicirori <yliu@together.ai>
…tics Switch the opt-in fast verify path from filter_apply_order="joint" to "top_k_first" so it preserves the original sequential renorm chain's distribution semantics (top-k then top-p on the top-k-renormalized distribution) rather than introducing a ~5% total-variation shift. This also matches stock TRTLLM 1.3.0rc14's verify sampling (tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py uses the same flashinfer API with the same filter_apply_order). Performance impact: - nsys in-graph (CUDA-graph node trace): joint and top_k_first land in the same band at small/medium batches. micro-bench at full batch (bs=64, vocab=200K, top_k=50, top_p=0.95) shows joint ~19% faster at pure kernel level, but at E2E the sampling kernel is <1% of decode wall, so both filter modes land at 1.7K tok/s ± noise (≥1σ overlap across three independent 200-prompt c=32 runs). Numerical: top_k_first is bit-equivalent in filter to the sequential chain it replaces; the previous joint variant introduced a ~5% TV shift at top_k=50/top_p=0.95 that this avoids. Tests updated: the joint-vs-sequential gap test is dropped (no longer applicable since both paths use the same filter ordering), and the fast path test now compares against the sequential analytic reference. Signed-off-by: cicirori <yliu@together.ai>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Opt-in optimization for chain-spec (
eagle_topk = 1) only. In this configuration the draft is always argmax-greedy, so the verify pass can collapse the sequential renorm chain into a single fused-sample kernel followed by a constant-time token-equality verify.Today the spec verify path goes through
which scans full vocab twice at the renorm steps. For vocab≈200K (M2.5/MiniMax) this dominates verify sampling cost.
This PR adds an opt-in branch in
verify()underTOKENSPEED_SPEC_FAST_CHAIN_SAMPLING=1that, in the topk=1 chain-spec case:top_k_top_p_sampling_from_logits(filter_apply_order="top_k_first")— the same flashinfer API and filter mode that stock TRTLLM 1.3.0rc14 uses for verifyverify_chain_greedykernel (same kernel the greedy branch already uses)Scope and applicability
Only applies when all of the following hold:
speculative_eagle_topk == 1(chain spec) — tree spec topk>1 keeps the generic rejection kernelverify_chain_greedykernelFor any other configuration,
verify()falls through to the existing sequential path with no behavior change.Measured impact (M2.5plus-fp4 NVFP4, TP=2 B200, vocab=200054, c=32 / 200 reqs)
Sampling kernel level (per verify, inside captured CUDA graph)
Captured via
nsys profile --cuda-graph-trace=nodeat typical full batch (gridX=128):The fast-path range covers small-batch (~55 µs, well-sampled at gridX=4) up to extrapolated full-batch (~90 µs). nsys event buffers overflow aggressively at full batch with
--cuda-graph-trace=node, so the precise full-batch number for top_k_first is bracketed by the joint-mode capture at the same gridX (~85 µs sample + ~2 µs verify) and the +6–19% micro-bench delta between the two filter modes.End-to-end throughput (3 reps each, no nsys overhead)
The ≥10× kernel speedup does not fully translate to E2E because sampling now occupies <1% of decode iter wall in this backend — MoE + attention + allreduce dominate. The PR is most useful when:
Numerical semantics
filter_apply_order="top_k_first"is bit-equivalent in filter semantics to the sequentialtop_k_renorm_prob → top_p_renorm_probchain it replaces — both apply top-p on the top-k-renormalized distribution. Same prompt + sampling params produce statistically equivalent samples whether spec is on or off (empirical TV vs sequential analytic reference: 0.0040 at vocab=1024 / top_k=50 / top_p=0.95, identical noise floor to the sequential path's own 0.0035).An earlier revision of this PR used
filter_apply_order="joint"(top-k AND top-p applied simultaneously on the original distribution), which introduced a ~5% TV semantic shift relative to the sequential path. Profiling showedjointwas at most ~19% faster thantop_k_firstat full-batch kernel level but the difference disappeared in E2E noise, sotop_k_firstwas selected for the cleaner semantics and TRTLLM alignment.Prior art
The chain-spec + token-equality verify shape, the flashinfer API choice, and the
filter_apply_order="top_k_first"selection all mirror stock TensorRT-LLM 1.3.0rc14's verify sampling (tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py).CUDA graph compatibility
The fast path uses kernels (
top_k_top_p_sampling_from_logits,verify_chain_greedy) that are already in-graph in this backend'ssample()and greedyverify()paths. Confirmed viansys --cuda-graph-trace=node: fast-path verify kernels in the benchmark trace are reported asin_graph(alongside eager calls from prefill sample()).Persistent buffers (
_predict_buf,_accept_index_buf,_accept_length_buf) are reused as in the sequential path — no per-step allocation in the hot path.Tests
tokenspeed-kernel/test/ops/test_fast_chain_sampling.py— Monte Carlo on 262,144 trials, two checks:TV < 0.02, measured 0.0040)TV < 0.02, measured 0.0035) — validates the test infrastructureBoth pass.
Opt-in / default
Default off. Read once at import time. Spawn-launched TP workers inherit
os.environvia Python's cached dict (independent ofsetproctitle's clobber on/proc/<pid>/environ), so settingTOKENSPEED_SPEC_FAST_CHAIN_SAMPLING=1before launching propagates to all workers.Test plan
nsys profile --cuda-graph-trace=nodeconfirms in-graph kernel composition and per-verify kernel speedup🤖 Generated with Claude Code