Skip to content

Support DP sampling for spec decode#232

Open
yubofredwang wants to merge 9 commits into
mainfrom
ywang/dp-sampling
Open

Support DP sampling for spec decode#232
yubofredwang wants to merge 9 commits into
mainfrom
ywang/dp-sampling

Conversation

@yubofredwang
Copy link
Copy Markdown
Contributor

@yubofredwang yubofredwang commented May 24, 2026

Summary

Enable data parallele sampling. This has a few benefits:

  1. No need to use deterministic kernel during larger batch. Potentially better perf.
  2. Reduce the all gather communication overhead by calculating needed part and use all to all single operation.
  3. Avoid duplicate work from ranks in TP mode. No need to sync verify results between ranks, but rather, gather the results.

Still working on:

  1. Testing if draft sampling should also be DP sampling.
  2. Verify overall perf improvement on Kimi K2.5
  3. support logprobs

Test Plan

Initial stage of testing on minimax model, spec 3,1,4 case. TP=4

requests  DP tok/s  legacy tok/s  delta    wins   tokens_ok
16        4248.6    4134.7        +2.82%   2/3    True
24        5784.5    5717.3        +1.18%   2/3    True
32        7219.3    7076.7        +2.02%   3/3    True
40        8741.7    8464.4        +3.28%   3/3    True

@yubofredwang yubofredwang requested a review from a team as a code owner May 24, 2026 01:42
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

write_output_logprobs(
logits_output, logits_output.next_token_logits, predict
)

P1 Badge Keep logits and token shapes aligned for output logprobs

In DP verify mode, predict is expanded to full-batch shape after gather_verify_outputs, but logits_output.next_token_logits is still this rank's local shard. If enable_output_logprobs is on, write_output_logprobs is called with mismatched row counts, which can fail or compute incorrect logprobs. The logprob write path needs full-batch logits (or a local-only token slice) before this call.

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

max_pad_bs=self._dp_max_pad_bs,
num_tokens_per_req=config.max_draft_tokens_per_req,
vocab_size=v_aligned,
logits_dtype=torch.bfloat16,
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Match DP comm logits dtype to runtime logits

DpSamplingComm is always initialized with logits_dtype=torch.bfloat16, but the one-sided swap path asserts that incoming logits match the state dtype. On systems where one-sided backend is selected (auto/onesided), fp16 or fp32 models will hit an assertion during DP swap in decode verify. This should use the actual logits dtype for the model path instead of a hard-coded bf16.

Useful? React with 👍 / 👎.

@yubofredwang yubofredwang requested review from syuoni and yweng0828 May 24, 2026 06:55
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 5b8440a7b7

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

accept_length_local=accept_length,
pad_bs=pad_bs,
)
predict = predict_full.view(-1)[: full_bs * n]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep verify logits aligned with gathered DP predictions

When dp_sampling is enabled, this line replaces predict with a full-batch tensor (full_bs * n), but logits_output.next_token_logits still contains only the local shard. The later write_output_logprobs(...) call gathers with tokens=predict, so enabling output logprobs in DP mode causes a row-count mismatch (or wrong row/token alignment) and breaks speculative decode logprob output. This path needs to either gather/reorder logits to full-batch order or skip DP logprob writing until supported.

Useful? React with 👍 / 👎.

Comment on lines +175 to +176
if processor.skip_all_gather:
return
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Do not enable drafter DP route without a DP comm object

Returning early on processor.skip_all_gather leaves self.dp_sampling_enabled=True but never initializes self._dp_comm. Later, _dp_argmax_gather() unconditionally asserts self._dp_comm is not None whenever dp_sampling is active, so TP+DP configurations that set skip_all_gather crash during draft sampling instead of falling back. The DP route should be disabled in this branch or a gather-capable comm path should still be constructed.

Useful? React with 👍 / 👎.

@yubofredwang yubofredwang force-pushed the ywang/dp-sampling branch 2 times, most recently from b940507 to 16db036 Compare May 24, 2026 08:29
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 16db036e43

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +293 to +296
infra_supports_dp = (
self.drafter is not None
and isinstance(self.sampling_backend, FlashInferSamplingBackend)
and processor.tp_size > 1
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Disable DP verify route for unsupported sampling backends

infra_supports_dp currently uses isinstance(..., FlashInferSamplingBackend), which also matches FlashInferFullSamplingBackend. In TP speculative decode with --dp-sampling (or TOKENSPEED_DP_SAMPLING=auto/on), this enables DP mode and makes LogitsProcessor emit request-sharded logits, but FlashInferFullSamplingBackend.verify() still follows the full-batch path and does not implement DP gather/reassembly. That mismatch can produce shape/assert failures at runtime when decode verify runs. Gate this to backends that actually implement DP verify (or add DP handling to flashinfer_full).

Useful? React with 👍 / 👎.

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 9611a40417

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +374 to +375
sampling_info.vocab_mask is None
), "dp_sampling + grammar bitmask is not supported"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Fallback to non-DP verify when grammar masking is active

verify() hard-asserts that sampling_info.vocab_mask is None in the DP path, but grammar setup binds a vocab-mask buffer whenever grammar runtime is enabled (even when rows are all-ones/no-op). That means enabling both --dp-sampling and a grammar backend can crash speculative decode at runtime instead of gracefully using the legacy verify path. This is a production-breaking feature interaction for structured-output users and should be routed to non-DP verify rather than asserted.

Useful? React with 👍 / 👎.

Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: d9c9d7ba91

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

predict[max_pad_bs, N], accept_index[max_pad_bs, N], and
accept_length[max_pad_bs].
"""
assert type(group) == dist.ProcessGroup, f"Expected ProcessGroup, got {type(group)}"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Use isinstance for process-group validation

The one-sided DP path can crash during initialization because this check requires group to be exactly dist.ProcessGroup, but process groups are created via dist.new_group(...) and can be backend-specific subclasses (for example NCCL process-group objects). In environments where one-sided mode is selected (auto/onesided), that strict type equality trips before any fallback and aborts startup; this should accept subclasses (e.g., isinstance) instead of exact type equality.

Useful? React with 👍 / 👎.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant