Support DP sampling for spec decode#232
Conversation
There was a problem hiding this comment.
💡 Codex Review
tokenspeed/python/tokenspeed/runtime/sampling/backends/flashinfer.py
Lines 522 to 524 in b940507
In DP verify mode, predict is expanded to full-batch shape after gather_verify_outputs, but logits_output.next_token_logits is still this rank's local shard. If enable_output_logprobs is on, write_output_logprobs is called with mismatched row counts, which can fail or compute incorrect logprobs. The logprob write path needs full-batch logits (or a local-only token slice) before this call.
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| max_pad_bs=self._dp_max_pad_bs, | ||
| num_tokens_per_req=config.max_draft_tokens_per_req, | ||
| vocab_size=v_aligned, | ||
| logits_dtype=torch.bfloat16, |
There was a problem hiding this comment.
Match DP comm logits dtype to runtime logits
DpSamplingComm is always initialized with logits_dtype=torch.bfloat16, but the one-sided swap path asserts that incoming logits match the state dtype. On systems where one-sided backend is selected (auto/onesided), fp16 or fp32 models will hit an assertion during DP swap in decode verify. This should use the actual logits dtype for the model path instead of a hard-coded bf16.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 5b8440a7b7
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| accept_length_local=accept_length, | ||
| pad_bs=pad_bs, | ||
| ) | ||
| predict = predict_full.view(-1)[: full_bs * n] |
There was a problem hiding this comment.
Keep verify logits aligned with gathered DP predictions
When dp_sampling is enabled, this line replaces predict with a full-batch tensor (full_bs * n), but logits_output.next_token_logits still contains only the local shard. The later write_output_logprobs(...) call gathers with tokens=predict, so enabling output logprobs in DP mode causes a row-count mismatch (or wrong row/token alignment) and breaks speculative decode logprob output. This path needs to either gather/reorder logits to full-batch order or skip DP logprob writing until supported.
Useful? React with 👍 / 👎.
| if processor.skip_all_gather: | ||
| return |
There was a problem hiding this comment.
Do not enable drafter DP route without a DP comm object
Returning early on processor.skip_all_gather leaves self.dp_sampling_enabled=True but never initializes self._dp_comm. Later, _dp_argmax_gather() unconditionally asserts self._dp_comm is not None whenever dp_sampling is active, so TP+DP configurations that set skip_all_gather crash during draft sampling instead of falling back. The DP route should be disabled in this branch or a gather-capable comm path should still be constructed.
Useful? React with 👍 / 👎.
b940507 to
16db036
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 16db036e43
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| infra_supports_dp = ( | ||
| self.drafter is not None | ||
| and isinstance(self.sampling_backend, FlashInferSamplingBackend) | ||
| and processor.tp_size > 1 |
There was a problem hiding this comment.
Disable DP verify route for unsupported sampling backends
infra_supports_dp currently uses isinstance(..., FlashInferSamplingBackend), which also matches FlashInferFullSamplingBackend. In TP speculative decode with --dp-sampling (or TOKENSPEED_DP_SAMPLING=auto/on), this enables DP mode and makes LogitsProcessor emit request-sharded logits, but FlashInferFullSamplingBackend.verify() still follows the full-batch path and does not implement DP gather/reassembly. That mismatch can produce shape/assert failures at runtime when decode verify runs. Gate this to backends that actually implement DP verify (or add DP handling to flashinfer_full).
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 9611a40417
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| sampling_info.vocab_mask is None | ||
| ), "dp_sampling + grammar bitmask is not supported" |
There was a problem hiding this comment.
Fallback to non-DP verify when grammar masking is active
verify() hard-asserts that sampling_info.vocab_mask is None in the DP path, but grammar setup binds a vocab-mask buffer whenever grammar runtime is enabled (even when rows are all-ones/no-op). That means enabling both --dp-sampling and a grammar backend can crash speculative decode at runtime instead of gracefully using the legacy verify path. This is a production-breaking feature interaction for structured-output users and should be routed to non-DP verify rather than asserted.
Useful? React with 👍 / 👎.
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
9611a40 to
d9c9d7b
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d9c9d7ba91
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| predict[max_pad_bs, N], accept_index[max_pad_bs, N], and | ||
| accept_length[max_pad_bs]. | ||
| """ | ||
| assert type(group) == dist.ProcessGroup, f"Expected ProcessGroup, got {type(group)}" |
There was a problem hiding this comment.
Use isinstance for process-group validation
The one-sided DP path can crash during initialization because this check requires group to be exactly dist.ProcessGroup, but process groups are created via dist.new_group(...) and can be backend-specific subclasses (for example NCCL process-group objects). In environments where one-sided mode is selected (auto/onesided), that strict type equality trips before any fallback and aborts startup; this should accept subclasses (e.g., isinstance) instead of exact type equality.
Useful? React with 👍 / 👎.
Summary
Enable data parallele sampling. This has a few benefits:
Still working on:
Test Plan
Initial stage of testing on minimax model, spec 3,1,4 case. TP=4