Skip to content

[WIP] perf(eagle3): skip dead-position compute in draft catch-up step#217

Open
rjzhb wants to merge 1 commit into
lightseekorg:mainfrom
rjzhb:feat/eagle-draft-single-token
Open

[WIP] perf(eagle3): skip dead-position compute in draft catch-up step#217
rjzhb wants to merge 1 commit into
lightseekorg:mainfrom
rjzhb:feat/eagle-draft-single-token

Conversation

@rjzhb
Copy link
Copy Markdown

@rjzhb rjzhb commented May 22, 2026

Summary

EAGLE3 draft head's catch-up first step takes a padded bs*spec_num_tokens input, but only one logit per request is sampled — the other bs*(spec_num_tokens-1) rows exist purely to write K/V into the KV cache for the next spec round. With lookahead=3 (4 tokens per request) that's 75% of the input rows whose attn / O-proj / MLP / norm outputs are computed and then immediately discarded by the LogitsProcessor's last-position slice.

This PR moves that slice inside the layer, right after KV write — instead of doing dead work on [bs*spec_num_tokens, H] and pruning at the end, we prune to [bs, H] between KV write and attention, then route attn through the decode kernel (q_len_per_req=1 × full cache) and run o_proj / MLP / norms on the bs live rows only.

What's saved per layer (live rows / total rows = 1/spec_num_tokens):

  • QKV proj — still full (K/V must be written for the next round)
  • Attn (Q·K^T + softmax·V)bs queries instead of bs*spec_num_tokens
  • O-projbs rows
  • post-attn norm / residual / MLP / post-mlp norm / final normbs rows

Scope

  • EAGLE3 LLaMa-style head (models/llama_eagle3.py)

End-to-end sim

MiniMax-M2.5 + thoughtworks/MiniMax-M2.5-Eagle3 head, B200 TP=2, reasoning-style workload (8K prompt / 3K gen, QPS=0.3, 300s sustain):

opt #1 opt #2 baseline
gen_tps (Loaded) 32.0 34.9 30.2
inflight_mean 26.7 22.7 25.1
mean_accept_len 2.02 1.92 1.94

Metric notes (measured over the steady-state window after warm-up)

  • gen_tps: system-wide decode tokens per second. The throughput metric this optimization targets.
  • inflight_mean: average number of concurrent requests being processed — i.e. the effective batch size. The optimization's gain scales with this number.
  • mean_accept_len: average accepted tokens per spec round. Close values across runs confirm spec decode behavior is unchanged, so the comparison is apples-to-apples.

Directional improvement ~+6% to +11% gen_tps vs baseline across both opt runs.

Test plan

  • End-to-end sim A/B on MiniMax-M2.5 + EAGLE3, env-var gated

@rjzhb rjzhb requested a review from a team as a code owner May 22, 2026 20:58
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: cacabd6a6f

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +212 to +215
draft_reduce_to_last = is_decode_like and _DRAFT_REDUCE_FIRST_STEP_ENABLED
draft_reduce_indices = (
self.last_index_offsets_buf[:bs] + unpadded_input_lengths
if draft_reduce_to_last
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Gate draft-reduce mode to compatible model/backend paths

draft_reduce_to_last is enabled for every DRAFT_EXTEND first step here, but only llama_eagle3 actually performs the mid-layer gather and it only does so when KV prewrite is available; other draft models (or LLaMA draft runs on non-prewrite backends) still produce [bs*spec_num_tokens, H] hidden states. Because this flag is also consumed by LogitsMetadata.from_forward_context to mark outputs as pre_pruned, LogitsProcessor skips its normal last-token slicing and downstream code receives the wrong batch shape (e.g. next_tokens[:, 1] assignment in the drafter), causing incorrect behavior or runtime shape errors.

Useful? React with 👍 / 👎.

@rjzhb rjzhb force-pushed the feat/eagle-draft-single-token branch from cacabd6 to ac042cd Compare May 22, 2026 21:05
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ac042cd43b

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread python/tokenspeed/runtime/models/llama_eagle3.py
@rjzhb rjzhb force-pushed the feat/eagle-draft-single-token branch from ac042cd to d410fa6 Compare May 22, 2026 21:21
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: d410fa6417

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +205 to +207
draft_reduce_to_last = (
forward_mode.is_decode() and _DRAFT_REDUCE_FIRST_STEP_ENABLED
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Restrict draft_reduce_to_last to reducible backends

draft_reduce_to_last is now enabled for every decode first step, but the actual mid-layer pruning only happens in the LLaMA EAGLE3 attention path when KV prewrite is active (llama_eagle3.py only reduces inside the fused_kv_arg is not None branch). In this revision that means non-prewrite draft backends (and non-LLaMA draft models sharing Eagle) can still emit [bs*spec_num_tokens, H] while LogitsMetadata.pre_pruned is set, so LogitsProcessor skips last-token slicing and returns mismatched batch shapes to drafter code that expects [bs, ...].

Useful? React with 👍 / 👎.

@LorrinWWW LorrinWWW self-assigned this May 25, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants