[WIP] perf(eagle3): skip dead-position compute in draft catch-up step #217
rjzhb wants to merge 1 commit into
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: cacabd6a6f
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
draft_reduce_to_last = is_decode_like and _DRAFT_REDUCE_FIRST_STEP_ENABLED
draft_reduce_indices = (
    self.last_index_offsets_buf[:bs] + unpadded_input_lengths
    if draft_reduce_to_last
Gate draft-reduce mode to compatible model/backend paths
draft_reduce_to_last is enabled for every DRAFT_EXTEND first step here, but only llama_eagle3 actually performs the mid-layer gather and it only does so when KV prewrite is available; other draft models (or LLaMA draft runs on non-prewrite backends) still produce [bs*spec_num_tokens, H] hidden states. Because this flag is also consumed by LogitsMetadata.from_forward_context to mark outputs as pre_pruned, LogitsProcessor skips its normal last-token slicing and downstream code receives the wrong batch shape (e.g. next_tokens[:, 1] assignment in the drafter), causing incorrect behavior or runtime shape errors.
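One way to express the gating this comment asks for is a small predicate that only turns the flag on when the draft path will really emit `[bs, H]`. The sketch below is illustrative only: `DraftPathInfo`, its two fields, and `should_reduce_to_last` are hypothetical names, not identifiers from the PR, and the real capability checks may live elsewhere.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DraftPathInfo:
    # Hypothetical capability flags, introduced only for this sketch.
    model_prunes_mid_layer: bool  # True only for a draft model that gathers mid-layer
    kv_prewrite_active: bool      # True only when the KV-prewrite branch is taken


def should_reduce_to_last(is_decode_like: bool,
                          feature_enabled: bool,
                          path: DraftPathInfo) -> bool:
    """Enable reduce-to-last only if the model will actually return [bs, H]."""
    return (
        is_decode_like
        and feature_enabled
        and path.model_prunes_mid_layer
        and path.kv_prewrite_active
    )
```

With a predicate like this, any path that still emits `[bs*spec_num_tokens, H]` keeps the flag off, so the logits stage keeps its normal last-token slicing.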
cacabd6 to ac042cd
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ac042cd43b
Signed-off-by: rjzhb <rjzhb222@163.com>
ac042cd to d410fa6
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d410fa6417
draft_reduce_to_last = (
    forward_mode.is_decode() and _DRAFT_REDUCE_FIRST_STEP_ENABLED
)
Restrict draft_reduce_to_last to reducible backends
draft_reduce_to_last is now enabled for every decode first step, but the actual mid-layer pruning only happens in the LLaMA EAGLE3 attention path when KV prewrite is active (llama_eagle3.py only reduces inside the fused_kv_arg is not None branch). In this revision that means non-prewrite draft backends (and non-LLaMA draft models sharing Eagle) can still emit [bs*spec_num_tokens, H] while LogitsMetadata.pre_pruned is set, so LogitsProcessor skips last-token slicing and returns mismatched batch shapes to drafter code that expects [bs, ...].
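The shape mismatch described here can be made concrete with a toy version of the last-token selection. This is a simplified sketch, not the repository's LogitsProcessor: `select_last_token_rows` is a hypothetical helper, rows are plain Python lists, and it assumes every request uses all `spec_num_tokens` slots. The explicit row-count check is the kind of guard that would surface the problem early instead of passing a wrong batch shape downstream.

```python
def select_last_token_rows(hidden, bs, spec_num_tokens, pre_pruned):
    """Return one row per request from a flat list of hidden-state rows.

    Assumption for this sketch: request i owns rows
    [i * spec_num_tokens, (i + 1) * spec_num_tokens) and its last slot is live.
    """
    if pre_pruned:
        # The model claims it already reduced to one row per request.
        if len(hidden) != bs:
            raise ValueError(
                f"pre_pruned is set but got {len(hidden)} rows, expected {bs}"
            )
        return hidden
    if len(hidden) != bs * spec_num_tokens:
        raise ValueError(
            f"expected {bs * spec_num_tokens} rows, got {len(hidden)}"
        )
    # Normal path: keep only the last row of each request.
    return [hidden[i * spec_num_tokens + spec_num_tokens - 1] for i in range(bs)]
```

If a backend that never pruned is paired with `pre_pruned=True`, the guard raises rather than returning `bs*spec_num_tokens` rows to code that expects `bs`.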
Summary
The EAGLE3 draft head's catch-up first step takes a padded bs*spec_num_tokens input, but only one logit per request is sampled. The other bs*(spec_num_tokens-1) rows exist purely to write K/V into the KV cache for the next spec round. With lookahead=3 (4 tokens per request), that's 75% of the input rows whose attn / O-proj / MLP / norm outputs are computed and then immediately discarded by the LogitsProcessor's last-position slice.

This PR moves that slice inside the layer, right after the KV write. Instead of doing dead work on [bs*spec_num_tokens, H] and pruning at the end, we prune to [bs, H] between KV write and attention, then route attn through the decode kernel (q_len_per_req=1 × full cache) and run o_proj / MLP / norms on the bs live rows only.

What's saved per layer (live rows / total rows = 1/spec_num_tokens):
- attention: bs queries instead of bs*spec_num_tokens
- o_proj / MLP / norms: bs rows

Scope
models/llama_eagle3.py)

End-to-end sim
MiniMax-M2.5 + thoughtworks/MiniMax-M2.5-Eagle3 head, B200 TP=2, reasoning-style workload (8K prompt / 3K gen, QPS=0.3, 300s sustain):

Reported metrics: gen_tps (Loaded), inflight_mean, mean_accept_len.

Directional improvement: roughly +6% to +11% gen_tps vs. baseline across both opt runs.
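
The row accounting from the Summary can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the PR's implementation: the helper names are hypothetical, rows are Python lists rather than tensors, and the padded layout is assumed to give request i the rows `[i*spec_num_tokens, (i+1)*spec_num_tokens)` with its live row at offset `unpadded_length - 1`.

```python
def dead_row_fraction(spec_num_tokens: int) -> float:
    """Fraction of input rows whose post-KV-write outputs are never sampled."""
    if spec_num_tokens < 1:
        raise ValueError("spec_num_tokens must be >= 1")
    return (spec_num_tokens - 1) / spec_num_tokens


def last_live_row_indices(unpadded_lengths, spec_num_tokens):
    """Flat index of each request's last real token in the padded layout."""
    indices = []
    for i, n in enumerate(unpadded_lengths):
        if not 1 <= n <= spec_num_tokens:
            raise ValueError(f"request {i}: length {n} outside [1, {spec_num_tokens}]")
        indices.append(i * spec_num_tokens + n - 1)
    return indices


def prune_to_live_rows(rows, unpadded_lengths, spec_num_tokens):
    """Gather [bs*spec_num_tokens] rows down to [bs] rows (one per request)."""
    idx = last_live_row_indices(unpadded_lengths, spec_num_tokens)
    return [rows[j] for j in idx]
```

For lookahead=3 (4 tokens per request), `dead_row_fraction(4)` gives 0.75, matching the 75% figure above. In the real layer the gather would have to happen only after every row's K/V has been written, since the dead rows still feed the cache.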
Test plan