Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions python/tokenspeed/runtime/execution/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,14 @@ class ForwardContext:
# --- logits processor ---
keep_full_logits: bool = False
last_index_offsets: torch.Tensor | None = None

# --- EAGLE draft first-step reduce ---
# When True, the draft head's catch-up forward writes KV for all
# bs*spec_num_tokens input positions but skips attn/MLP/post-norms on
# every position except the last accepted one per request. Set by
# Eagle._run_first_step when the optimization is enabled and the attn
# backend supports the reduced-Q decode path. draft_reduce_indices is
# last_index_offsets + unpadded_input_lengths, i.e. the [bs] indices
# of those live positions inside the flattened [bs*spec_num_tokens] layout.
draft_reduce_to_last: bool = False
draft_reduce_indices: torch.Tensor | None = None
25 changes: 25 additions & 0 deletions python/tokenspeed/runtime/execution/drafter/eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from __future__ import annotations

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING

Expand All @@ -41,6 +42,13 @@

logger = get_colorful_logger(__name__)

# Env-var gate for the EAGLE draft first-step reduce optimization. Temporary —
# set to 0 to disable the optimization and run the baseline catch-up path for
# A/B perf comparison. Will be removed once stabilized.
_DRAFT_REDUCE_FIRST_STEP_ENABLED = (
os.environ.get("TOKENSPEED_EAGLE_DRAFT_REDUCE", "1") == "1"
)

if TYPE_CHECKING:
from tokenspeed.runtime.execution.input_buffer import InputBuffers
from tokenspeed.runtime.execution.model_runner import ModelRunner
Expand Down Expand Up @@ -188,6 +196,21 @@ def _run_first_step(
padded_static_len = self.spec_num_tokens
last_index_offsets = self.last_index_offsets_buf[:bs]

# Catch-up step with multi-token DRAFT_EXTEND input: every position
# except the live one per request is purely a KV-cache write. The
# midlayer slices Q after KV write and runs attn/MLP/post-norms on
# just the [bs] live positions. draft_reduce_indices = [bs] int64,
# = last_index_offsets + accept_lengths, points at each request's
# live row inside the [bs*spec_num_tokens] flattened layout.
draft_reduce_to_last = (
forward_mode.is_decode() and _DRAFT_REDUCE_FIRST_STEP_ENABLED
)
Comment on lines +205 to +207
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Restrict draft_reduce_to_last to reducible backends

draft_reduce_to_last is now enabled for every decode first step, but the actual mid-layer pruning only happens in the LLaMA EAGLE3 attention path when KV prewrite is active (llama_eagle3.py only reduces inside the fused_kv_arg is not None branch). In this revision that means non-prewrite draft backends (and non-LLaMA draft models sharing Eagle) can still emit [bs*spec_num_tokens, H] while LogitsMetadata.pre_pruned is set, so LogitsProcessor skips last-token slicing and returns mismatched batch shapes to drafter code that expects [bs, ...].

Useful? React with 👍 / 👎.

draft_reduce_indices = (
self.last_index_offsets_buf[:bs] + unpadded_input_lengths
if draft_reduce_to_last
else None
)

# make a ctx every time model runner forward
first_step_ctx = ForwardContext(
attn_backend=self.attn_backend,
Expand All @@ -204,6 +227,8 @@ def _run_first_step(
global_num_tokens=draft_input.global_num_tokens,
global_bs=draft_input.global_bs,
all_decode_or_idle=draft_input.all_decode_or_idle,
draft_reduce_to_last=draft_reduce_to_last,
draft_reduce_indices=draft_reduce_indices,
)

return self.draft_model_runner.forward(
Expand Down
16 changes: 15 additions & 1 deletion python/tokenspeed/runtime/layers/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ class LogitsMetadata:
padded_static_len: int = -1
last_index_offsets: torch.Tensor | None = None

# Hidden states arrive pre-pruned to [bs, H]. Set by the EAGLE draft
# first-step reduce path, where the model already gathered the live
# position per request inside the midlayer. Skips the pad-aware slicing
# below.
pre_pruned: bool = False

@classmethod
def from_forward_context(
cls,
Expand All @@ -120,6 +126,7 @@ def from_forward_context(
extend_seq_lens=input_lengths,
padded_static_len=ctx.padded_static_len,
last_index_offsets=ctx.last_index_offsets,
pre_pruned=ctx.draft_reduce_to_last,
)


Expand Down Expand Up @@ -212,7 +219,14 @@ def forward(
aux_hidden_states: torch.Tensor | None = None,
) -> LogitsProcessorOutput:
# Get the last hidden states and last logits for the next token prediction
if not logits_metadata.extend_return_logprob:
if logits_metadata.pre_pruned:
# EAGLE draft first-step reduce: hidden_states is already [bs, H].
pruned_states = hidden_states
if aux_hidden_states is not None:
aux_pruned_states = [hidden for hidden in aux_hidden_states]
sample_indices = None
input_logprob_indices = None
elif not logits_metadata.extend_return_logprob:
if logits_metadata.forward_mode.is_extend_or_mixed():
# Prefill: last token of each request via cumulative seq lens.
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
Expand Down
47 changes: 39 additions & 8 deletions python/tokenspeed/runtime/models/llama_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from tokenspeed.runtime.configs.utils import get_rope_theta
from tokenspeed.runtime.distributed.mapping import Mapping
from tokenspeed.runtime.execution.context import ForwardContext
from tokenspeed.runtime.execution.forward_batch_info import ForwardMode
from tokenspeed.runtime.layers.activation import SiluAndMul
from tokenspeed.runtime.layers.common import concat
from tokenspeed.runtime.layers.layernorm import RMSNorm
Expand Down Expand Up @@ -185,14 +186,34 @@ def forward(
output_q_rope=q_rope,
enable_pdl=pdl_enabled(),
)
attn_output = self.attn(
q_rope,
None,
None,
save_kv_cache=False,
ctx=ctx,
out_cache_loc=out_cache_loc,
)
if ctx.draft_reduce_to_last:
# KV cache for all bs*spec_num_tokens positions already written
# by fused_set_kv_buffer_arg above. Reduce Q to one query per
# request (live position) and route to decode kernel: bs queries
# against the full cache (which now includes the just-written
# spec_num_tokens new tokens per request). DRAFT_EXTEND init
# populates forward_decode_metadata precisely for this case.
q_rope = q_rope.index_select(0, ctx.draft_reduce_indices)
attn_output = ctx.attn_backend.forward(
q_rope,
None,
None,
self.attn,
out_cache_loc,
ctx.token_to_kv_pool,
ForwardMode.DECODE,
ctx.bs,
save_kv_cache=False,
Comment thread
rjzhb marked this conversation as resolved.
)
else:
attn_output = self.attn(
q_rope,
None,
None,
save_kv_cache=False,
ctx=ctx,
out_cache_loc=out_cache_loc,
)
else:
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, ctx=ctx, out_cache_loc=out_cache_loc)
Expand Down Expand Up @@ -361,6 +382,11 @@ def forward_low_latency(
ctx=ctx,
out_cache_loc=out_cache_loc,
)
if ctx.draft_reduce_to_last:
# self_attn returned [bs, H]; gather residual to match before the
# fused allreduce+norm so shapes line up. Everything downstream
# (post-norm, MLP, final norm) now runs on [bs, H].
residual = residual.index_select(0, ctx.draft_reduce_indices)

# Fused post-attn allreduce + norm (uses attn tp group)
block_scale = None
Expand Down Expand Up @@ -426,6 +452,11 @@ def forward(
ctx=ctx,
out_cache_loc=out_cache_loc,
)
if ctx.draft_reduce_to_last:
# self_attn returned [bs, H]; gather residual to match before the
# post-attn norm+comm so shapes line up. Everything downstream
# (post-norm, MLP, final norm) now runs on [bs, H].
residual = residual.index_select(0, ctx.draft_reduce_indices)
hidden_states, residual = self.comm_manager.post_attn_comm(
hidden_states, residual, ctx
)
Expand Down