Skip to content

perf(sampling): opt-in fast verify path for topk=1 chain spec#133

Draft
cicirori wants to merge 2 commits into
lightseekorg:mainfrom
cicirori:yh/spec-fast-chain-sampling
Draft

perf(sampling): opt-in fast verify path for topk=1 chain spec#133
cicirori wants to merge 2 commits into
lightseekorg:mainfrom
cicirori:yh/spec-fast-chain-sampling

Conversation

@cicirori
Copy link
Copy Markdown
Collaborator

@cicirori cicirori commented May 13, 2026

Summary

Opt-in optimization for chain-spec (eagle_topk = 1) only. In this configuration the draft is always argmax-greedy, so the verify pass can collapse the sequential renorm chain into a single fused-sample kernel followed by a constant-time token-equality verify.

Today the spec verify path goes through

softmax → top_k_renorm_prob → top_p_renorm_prob → chain_speculative_sampling_target_only

which scans full vocab twice at the renorm steps. For vocab≈200K (M2.5/MiniMax) this dominates verify sampling cost.

This PR adds an opt-in branch in verify() under TOKENSPEED_SPEC_FAST_CHAIN_SAMPLING=1 that, in the topk=1 chain-spec case:

  1. Samples one target token per chain position via top_k_top_p_sampling_from_logits(filter_apply_order="top_k_first") — the same flashinfer API and filter mode that stock TRTLLM 1.3.0rc14 uses for verify
  2. Verifies by token-equality against the (already argmax-greedy) draft candidates using the existing verify_chain_greedy kernel (same kernel the greedy branch already uses)
  3. Bonus token at first mismatch = the sampled target there

Scope and applicability

Only applies when all of the following hold:

  • speculative_eagle_topk == 1 (chain spec) — tree spec topk>1 keeps the generic rejection kernel
  • non-greedy sampling — greedy path is already on the fast verify_chain_greedy kernel
  • no grammar mask — per-position bitmasks need the renorm chain (auto-fallback)
  • request param hint matches the narrow scope this backend supports (no min_p, no penalties, no logit_bias)

For any other configuration, verify() falls through to the existing sequential path with no behavior change.

Measured impact (M2.5plus-fp4 NVFP4, TP=2 B200, vocab=200054, c=32 / 200 reqs)

Sampling kernel level (per verify, inside captured CUDA graph)

Captured via nsys profile --cuda-graph-trace=node at typical full batch (gridX=128):

Path breakdown total µs/verify
Sequential (baseline) OnlineSoftmax 89 + RadixTopKRenormProb 165 + AirTopPRenormRadix×3 410 + AirTopPRenormApply 191 + AirTopPInit 2 + ChainSpecSampling 84 ~941
Fast (this PR) RadixTopKMaskLogits + TopPSamplingFromProb + VerifyChainGreedy (fused dispatch) ~55–90
Speedup ~10–17×

The fast-path range covers small-batch (~55 µs, well-sampled at gridX=4) up to extrapolated full-batch (~90 µs). nsys event buffers overflow aggressively at full batch with --cuda-graph-trace=node, so the precise full-batch number for top_k_first is bracketed by the joint-mode capture at the same gridX (~85 µs sample + ~2 µs verify) and the +6–19% micro-bench delta between the two filter modes.

End-to-end throughput (3 reps each, no nsys overhead)

Config avg tok/s stddev Δ
Sequential (default) 1667 31 (1.9%)
Fast (this PR) 1707 32 (1.9%) +2.4% (within 1σ)

The ≥10× kernel speedup does not fully translate to E2E because sampling now occupies <1% of decode iter wall in this backend — MoE + attention + allreduce dominate. The PR is most useful when:

  • sampling shows up as a hotspot in profiling (e.g. larger batch or larger vocab models)
  • you want to eliminate the renorm chain to simplify reasoning about decode hot path
  • a future configuration shifts the bottleneck

Numerical semantics

filter_apply_order="top_k_first" is bit-equivalent in filter semantics to the sequential top_k_renorm_prob → top_p_renorm_prob chain it replaces — both apply top-p on the top-k-renormalized distribution. Same prompt + sampling params produce statistically equivalent samples whether spec is on or off (empirical TV vs sequential analytic reference: 0.0040 at vocab=1024 / top_k=50 / top_p=0.95, identical noise floor to the sequential path's own 0.0035).

An earlier revision of this PR used filter_apply_order="joint" (top-k AND top-p applied simultaneously on the original distribution), which introduced a ~5% TV semantic shift relative to the sequential path. Profiling showed joint was at most ~19% faster than top_k_first at full-batch kernel level but the difference disappeared in E2E noise, so top_k_first was selected for the cleaner semantics and TRTLLM alignment.

Prior art

The chain-spec + token-equality verify shape, the flashinfer API choice, and the filter_apply_order="top_k_first" selection all mirror stock TensorRT-LLM 1.3.0rc14's verify sampling (tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py).

CUDA graph compatibility

The fast path uses kernels (top_k_top_p_sampling_from_logits, verify_chain_greedy) that are already in-graph in this backend's sample() and greedy verify() paths. Confirmed via nsys --cuda-graph-trace=node: fast-path verify kernels in the benchmark trace are reported as in_graph (alongside eager calls from prefill sample()).

Persistent buffers (_predict_buf, _accept_index_buf, _accept_length_buf) are reused as in the sequential path — no per-step allocation in the hot path.

Tests

tokenspeed-kernel/test/ops/test_fast_chain_sampling.py — Monte Carlo on 262,144 trials, two checks:

  1. Fast path's empirical output matches the sequential-filter analytic reference within noise (TV < 0.02, measured 0.0040)
  2. Sequential path's empirical output also matches the sequential-filter analytic reference (TV < 0.02, measured 0.0035) — validates the test infrastructure

Both pass.

Opt-in / default

Default off. Read once at import time. Spawn-launched TP workers inherit os.environ via Python's cached dict (independent of setproctitle's clobber on /proc/<pid>/environ), so setting TOKENSPEED_SPEC_FAST_CHAIN_SAMPLING=1 before launching propagates to all workers.

Test plan

  • Statistical correctness test (2 Monte Carlo checks, both passing at TV ≈ 0.004)
  • E2E benchmark, 3 reps × baseline+fast at c=32
  • nsys profile --cuda-graph-trace=node confirms in-graph kernel composition and per-verify kernel speedup
  • Pre-commit clean (isort, black, etc.)

🤖 Generated with Claude Code

@cicirori cicirori requested a review from a team as a code owner May 13, 2026 21:17
@cicirori cicirori marked this pull request as draft May 13, 2026 21:18
cicirori added 2 commits May 13, 2026 14:18
Targeted optimization for the *chain-spec* (eagle_topk = 1) configuration:
draft is already argmax-greedy, so the verify pass can mirror plain
decode's sampling instead of running the generic tree-rejection chain.

Plain decode in this backend already uses flashinfer's joint fused kernel
(top_k_top_p_sampling_from_logits with filter_apply_order="joint"), but
spec verify falls through the sequential

    softmax -> top_k_renorm_prob -> top_p_renorm_prob ->
    chain_speculative_sampling_target_only

chain which scans the full vocab twice at the renorm steps. On large-vocab
models (vocab ≈ 200K) this dominates the verify cost.

Add an opt-in path under TOKENSPEED_SPEC_FAST_CHAIN_SAMPLING=1 that, in
the topk=1 chain-spec case, samples one target token per chain position
via the joint kernel and verifies by token-equality against the
argmax-greedy draft candidates using the existing verify_chain_greedy
kernel from the greedy branch. The bonus token at the first mismatched
position is the sampled target there. Same kernel pair the greedy path
already runs, so CUDA graph compatibility is unchanged.

Scope and applicability
-----------------------
This is *only* applicable when:
  * chain spec: tree-spec topk > 1 keeps the generic rejection kernel
  * topk = 1: the per-position single-candidate assumption is what makes
    token-equality verify equivalent to draft-greedy rejection sampling
  * non-greedy sampling: greedy path is already fast (argmax + verify)
  * no grammar mask: per-position vocab masks need the renorm chain;
    falls back to the original sequential path automatically

Measurements (M2.5plus-fp4 NVFP4, TP=2 B200, vocab=200054, c=32 / 200 reqs)
---------------------------------------------------------------------------
Sampling kernel level, per verify, inside the captured CUDA graph
(measured with nsys --cuda-graph-trace=node):

  Sequential path (baseline):
    OnlineSoftmaxMap+Reduce   89 us
    RadixTopKRenormProb      165 us
    AirTopPRenormRadix+Apply 463 us
    ChainSpecSampling         84 us
    ChainSpec init/etc       146 us
    --------------------------------
    total per verify        ~947 us

  Fast path (this PR):
    gather_and_expand          2 us
    TopKTopPSamplingFromProb  76 us
    VerifyChainGreedy          2 us
    --------------------------------
    total per verify         ~80 us

  ~11.7x sampling-kernel speedup per verify call.

End-to-end throughput (3 reps each, no nsys):

  Sequential: avg 1667 tok/s (stddev 31, 1.9%)
  Fast:       avg 1684 tok/s (stddev 21, 1.3%)
  Delta:      +1.0% (within 1 stddev)

Note: kernel-level speedup is large but E2E gain is modest because
sampling now occupies less than 1% of decode iter wall in this backend
(MoE + attention + allreduce dominate). The fast path is most useful as
a free-correctness alternative when the sequential renorm chain shows
up as a hotspot in profiling, or for future configs where sampling
becomes a larger fraction of the iter.

Numerical semantics
-------------------
Both this path and the sequential chain_speculative_sampling_target_only
path are unbiased rejection samplers on a chain-spec draft with topk=1
(draft is a greedy point mass at argmax(target_probs)). They differ in
how top-k and top-p are applied:

  * sequential: top_k_renorm_prob then top_p_renorm_prob, so the top-p
    cutoff is computed on the top-k-renormalized distribution.
  * joint (this PR): flashinfer applies top-k AND top-p simultaneously
    on the *original* (pre-renorm) distribution.

For top_k=50/top_p=0.95 the two filtered distributions differ by ~5%
total variation (empirically 0.046 on a vocab=1024 synthetic target).
This brings verify in line with plain decode's filter semantics: same
prompt + sampling params produce statistically equivalent samples
whether spec is on or off.

Inspired by stock TensorRT-LLM 1.3.0rc14's
`sampling_batch_spec_dec_one_model` (`allow_advanced_sampling=true` +
`eagle3_one_model=true`), which uses the same chain-spec + token-eq
verify shape. This PR additionally upgrades to flashinfer's joint mode
(stock trtllm still uses the default top_k_first which dispatches to
two kernels).

Opt-in via TOKENSPEED_SPEC_FAST_CHAIN_SAMPLING=1, default off. Read once
at import time. Falls back to the original sequential path for vocab_mask
requests (grammar bitmasks need the per-position renorm chain) and for
all-greedy batches (already on a faster path).

Tests
-----
test/ops/test_fast_chain_sampling.py verifies via Monte Carlo on 262,144
trials per check:
  1. fast path's empirical output matches the joint-filter analytic
     reference (TV < 0.02 in noise floor).
  2. sequential path's empirical output matches the sequential-filter
     analytic reference (TV < 0.02).
  3. joint vs sequential filter gap is bounded by 0.10 TV and strictly
     non-zero, documenting the deliberate semantic shift.

All three pass.

Signed-off-by: cicirori <yliu@together.ai>
…tics

Switch the opt-in fast verify path from filter_apply_order="joint" to
"top_k_first" so it preserves the original sequential renorm chain's
distribution semantics (top-k then top-p on the top-k-renormalized
distribution) rather than introducing a ~5% total-variation shift.

This also matches stock TRTLLM 1.3.0rc14's verify sampling
(tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py uses the
same flashinfer API with the same filter_apply_order).

Performance impact:
- nsys in-graph (CUDA-graph node trace): joint and top_k_first land in
  the same band at small/medium batches. micro-bench at full batch
  (bs=64, vocab=200K, top_k=50, top_p=0.95) shows joint ~19% faster at
  pure kernel level, but at E2E the sampling kernel is <1% of decode
  wall, so both filter modes land at 1.7K tok/s ± noise (≥1σ overlap
  across three independent 200-prompt c=32 runs).

Numerical: top_k_first is bit-equivalent in filter to the sequential
chain it replaces; the previous joint variant introduced a ~5% TV shift
at top_k=50/top_p=0.95 that this avoids.

Tests updated: the joint-vs-sequential gap test is dropped (no longer
applicable since both paths use the same filter ordering), and the fast
path test now compares against the sequential analytic reference.

Signed-off-by: cicirori <yliu@together.ai>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant