diff --git a/python/tokenspeed/runtime/execution/context.py b/python/tokenspeed/runtime/execution/context.py index e4fe3f922..f44c8d7f5 100644 --- a/python/tokenspeed/runtime/execution/context.py +++ b/python/tokenspeed/runtime/execution/context.py @@ -60,3 +60,14 @@ class ForwardContext: # --- logits processor --- keep_full_logits: bool = False last_index_offsets: torch.Tensor | None = None + + # --- EAGLE draft first-step reduce --- + # When True, the draft head's catch-up forward writes KV for all + # bs*spec_num_tokens input positions but skips attn/MLP/post-norms on + # every position except the last accepted one per request. Set by + # Eagle._run_first_step when the optimization is enabled and the attn + # backend supports the reduced-Q decode path. draft_reduce_indices is + # last_index_offsets + unpadded_input_lengths, i.e. the [bs] indices + # of those live positions inside the flattened [bs*spec_num_tokens] layout. + draft_reduce_to_last: bool = False + draft_reduce_indices: torch.Tensor | None = None diff --git a/python/tokenspeed/runtime/execution/drafter/eagle.py b/python/tokenspeed/runtime/execution/drafter/eagle.py index a1989ea0a..b7e21fbd5 100644 --- a/python/tokenspeed/runtime/execution/drafter/eagle.py +++ b/python/tokenspeed/runtime/execution/drafter/eagle.py @@ -20,6 +20,7 @@ from __future__ import annotations +import os from dataclasses import dataclass from typing import TYPE_CHECKING @@ -41,6 +42,13 @@ logger = get_colorful_logger(__name__) +# Env-var gate for the EAGLE draft first-step reduce optimization. Temporary — +# set to 0 to disable the optimization and run the baseline catch-up path for +# A/B perf comparison. Will be removed once stabilized. 
+_DRAFT_REDUCE_FIRST_STEP_ENABLED = ( + os.environ.get("TOKENSPEED_EAGLE_DRAFT_REDUCE", "1") == "1" +) + if TYPE_CHECKING: from tokenspeed.runtime.execution.input_buffer import InputBuffers from tokenspeed.runtime.execution.model_runner import ModelRunner @@ -188,6 +196,21 @@ def _run_first_step( padded_static_len = self.spec_num_tokens last_index_offsets = self.last_index_offsets_buf[:bs] + # Catch-up step with multi-token DRAFT_EXTEND input: every position + # except the live one per request is purely a KV-cache write. The + # midlayer slices Q after KV write and runs attn/MLP/post-norms on + # just the [bs] live positions. draft_reduce_indices = [bs] int64, + # = last_index_offsets + unpadded_input_lengths, points at each request's + # live row inside the [bs*spec_num_tokens] flattened layout. + draft_reduce_to_last = ( + forward_mode.is_decode() and _DRAFT_REDUCE_FIRST_STEP_ENABLED + ) + draft_reduce_indices = ( + self.last_index_offsets_buf[:bs] + unpadded_input_lengths + if draft_reduce_to_last + else None + ) + # make a ctx every time model runner forward first_step_ctx = ForwardContext( attn_backend=self.attn_backend, @@ -204,6 +227,8 @@ def _run_first_step( global_num_tokens=draft_input.global_num_tokens, global_bs=draft_input.global_bs, all_decode_or_idle=draft_input.all_decode_or_idle, + draft_reduce_to_last=draft_reduce_to_last, + draft_reduce_indices=draft_reduce_indices, ) return self.draft_model_runner.forward( diff --git a/python/tokenspeed/runtime/layers/logits_processor.py b/python/tokenspeed/runtime/layers/logits_processor.py index 039e49e30..0011dab6b 100755 --- a/python/tokenspeed/runtime/layers/logits_processor.py +++ b/python/tokenspeed/runtime/layers/logits_processor.py @@ -108,6 +108,12 @@ class LogitsMetadata: padded_static_len: int = -1 last_index_offsets: torch.Tensor | None = None + # Hidden states arrive pre-pruned to [bs, H]. 
Set by the EAGLE draft + # first-step reduce path, where the model already gathered the live + # position per request inside the midlayer. Skips the pad-aware slicing + # below. + pre_pruned: bool = False + @classmethod def from_forward_context( cls, @@ -120,6 +126,7 @@ def from_forward_context( extend_seq_lens=input_lengths, padded_static_len=ctx.padded_static_len, last_index_offsets=ctx.last_index_offsets, + pre_pruned=ctx.draft_reduce_to_last, ) @@ -212,7 +219,14 @@ def forward( aux_hidden_states: torch.Tensor | None = None, ) -> LogitsProcessorOutput: # Get the last hidden states and last logits for the next token prediction - if not logits_metadata.extend_return_logprob: + if logits_metadata.pre_pruned: + # EAGLE draft first-step reduce: hidden_states is already [bs, H]. + pruned_states = hidden_states + if aux_hidden_states is not None: + aux_pruned_states = [hidden for hidden in aux_hidden_states] + sample_indices = None + input_logprob_indices = None + elif not logits_metadata.extend_return_logprob: if logits_metadata.forward_mode.is_extend_or_mixed(): # Prefill: last token of each request via cumulative seq lens. 
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 diff --git a/python/tokenspeed/runtime/models/llama_eagle3.py b/python/tokenspeed/runtime/models/llama_eagle3.py index 8e1e9ab13..12ef33585 100644 --- a/python/tokenspeed/runtime/models/llama_eagle3.py +++ b/python/tokenspeed/runtime/models/llama_eagle3.py @@ -35,6 +35,7 @@ from tokenspeed.runtime.configs.utils import get_rope_theta from tokenspeed.runtime.distributed.mapping import Mapping from tokenspeed.runtime.execution.context import ForwardContext +from tokenspeed.runtime.execution.forward_batch_info import ForwardMode from tokenspeed.runtime.layers.activation import SiluAndMul from tokenspeed.runtime.layers.common import concat from tokenspeed.runtime.layers.layernorm import RMSNorm @@ -185,14 +186,34 @@ def forward( output_q_rope=q_rope, enable_pdl=pdl_enabled(), ) - attn_output = self.attn( - q_rope, - None, - None, - save_kv_cache=False, - ctx=ctx, - out_cache_loc=out_cache_loc, - ) + if ctx.draft_reduce_to_last: + # KV cache for all bs*spec_num_tokens positions already written + # by fused_set_kv_buffer_arg above. Reduce Q to one query per + # request (live position) and route to decode kernel: bs queries + # against the full cache (which now includes the just-written + # spec_num_tokens new tokens per request). DRAFT_EXTEND init + # populates forward_decode_metadata precisely for this case. 
+ q_rope = q_rope.index_select(0, ctx.draft_reduce_indices) + attn_output = ctx.attn_backend.forward( + q_rope, + None, + None, + self.attn, + out_cache_loc, + ctx.token_to_kv_pool, + ForwardMode.DECODE, + ctx.bs, + save_kv_cache=False, + ) + else: + attn_output = self.attn( + q_rope, + None, + None, + save_kv_cache=False, + ctx=ctx, + out_cache_loc=out_cache_loc, + ) else: q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, ctx=ctx, out_cache_loc=out_cache_loc) @@ -361,6 +382,11 @@ def forward_low_latency( ctx=ctx, out_cache_loc=out_cache_loc, ) + if ctx.draft_reduce_to_last: + # self_attn returned [bs, H]; gather residual to match before the + # fused allreduce+norm so shapes line up. Everything downstream + # (post-norm, MLP, final norm) now runs on [bs, H]. + residual = residual.index_select(0, ctx.draft_reduce_indices) # Fused post-attn allreduce + norm (uses attn tp group) block_scale = None @@ -426,6 +452,11 @@ def forward( ctx=ctx, out_cache_loc=out_cache_loc, ) + if ctx.draft_reduce_to_last: + # self_attn returned [bs, H]; gather residual to match before the + # post-attn norm+comm so shapes line up. Everything downstream + # (post-norm, MLP, final norm) now runs on [bs, H]. + residual = residual.index_select(0, ctx.draft_reduce_indices) hidden_states, residual = self.comm_manager.post_attn_comm( hidden_states, residual, ctx )