diff --git a/python/tokenspeed/runtime/layers/attention/backends/trtllm.py b/python/tokenspeed/runtime/layers/attention/backends/trtllm.py index 523c12a6a..dfef89743 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/trtllm.py +++ b/python/tokenspeed/runtime/layers/attention/backends/trtllm.py @@ -90,6 +90,11 @@ class TRTLLMMHAMetadata: cu_seqlens_q: torch.Tensor = None cu_seqlens_k: torch.Tensor = None page_table: torch.Tensor = None + # MIXED-only: row partition counts. Set by _init_mixed_metadata on the + # forward_mixed_metadata slot; default 0 on EXTEND/DECODE slots. + num_prefill_reqs: int = 0 + num_decode_reqs: int = 0 + num_prefill_tokens: int = 0 class TRTLLMMHAAttnBackend(AttentionBackend): @@ -142,8 +147,15 @@ def __init__(self, config: MHAConfig): # Separate slots for prefill-kernel vs decode-kernel forward paths. # forward_extend reads prefill; forward_decode reads decode. + # In MIXED steps, _init_mixed_metadata populates all three slots: + # - forward_mixed_metadata: sentinel + row partition counts + # - forward_prefill_metadata: row view of prefill rows + # - forward_decode_metadata: row view of decode rows + # forward_extend then dispatches to _forward_mixed_kernel which + # runs context kernel on prefill rows + decode kernel on decode rows. self.forward_prefill_metadata: TRTLLMMHAMetadata | None = None self.forward_decode_metadata: TRTLLMMHAMetadata | None = None + self.forward_mixed_metadata: TRTLLMMHAMetadata | None = None # CUDA graph state — per-slot dicts. self.cuda_graph_prefill_metadata: dict[int, TRTLLMMHAMetadata] = {} @@ -300,6 +312,21 @@ def forward_extend( save_kv_cache: bool = True, **kwargs, ) -> torch.Tensor: + # MIXED step: split rows into prefill / decode subsets and run the + # dedicated context kernel + decode kernel back to back (case1). The + # forward_mixed_metadata sentinel is set by _init_mixed_metadata. + if self.forward_mixed_metadata is not None: + return self._forward_mixed_kernel( + q, + k, + v, + layer, + out_cache_loc, + token_to_kv_pool, + save_kv_cache=save_kv_cache, + **kwargs, + ) + q = self._save_kv_and_prepare_q( q, k, v, layer, out_cache_loc, token_to_kv_pool, save_kv_cache ) @@ -330,6 +357,71 @@ def forward_extend( ) return o.view(-1, layer.tp_q_head_num * layer.head_dim) + def _forward_mixed_kernel( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: PagedAttention, + out_cache_loc: torch.Tensor, + token_to_kv_pool, + save_kv_cache: bool = True, + **kwargs, + ) -> torch.Tensor: + """MIXED forward (case1): split rows into prefill / decode subsets and + run the TRT-LLM context kernel on the prefill subset, decode kernel on + the decode subset, then concat outputs along the token dim. + + _init_mixed_metadata has populated forward_prefill_metadata and + forward_decode_metadata as row views of the full mixed batch, so the + recursive forward_extend / forward_decode calls just see what looks + like a normal EXTEND / DECODE step. + """ + mixed_meta = self.forward_mixed_metadata + n_pf = mixed_meta.num_prefill_tokens + + q_ext, q_dec = q[:n_pf], q[n_pf:] + k_ext, k_dec = (None, None) if k is None else (k[:n_pf], k[n_pf:]) + v_ext, v_dec = (None, None) if v is None else (v[:n_pf], v[n_pf:]) + loc_ext = out_cache_loc[:n_pf] + loc_dec = out_cache_loc[n_pf:] + + # Clear the MIXED sentinel so the recursive forward_extend / + # forward_decode take the regular paths and read the row-view metadata. + # Restored in finally so subsequent layers in the same step still see it. + self.forward_mixed_metadata = None + try: + # When all prefill rows are fully prefix-cached, num_prefill_tokens + # is 0; q_ext is empty so there is no context-kernel work to do. + if n_pf > 0: + o_ext = self.forward_extend( + q_ext, + k_ext, + v_ext, + layer, + loc_ext, + token_to_kv_pool, + bs=mixed_meta.num_prefill_reqs, + save_kv_cache=save_kv_cache, + **kwargs, + ) + else: + o_ext = q_ext # shape (0, hidden); pass through to the concat + o_dec = self.forward_decode( + q_dec, + k_dec, + v_dec, + layer, + loc_dec, + token_to_kv_pool, + bs=mixed_meta.num_decode_reqs, + save_kv_cache=save_kv_cache, + **kwargs, + ) + finally: + self.forward_mixed_metadata = mixed_meta + return torch.cat([o_ext, o_dec], dim=0) + # ------------------------------------------------------------------ # Metadata initialisation # ------------------------------------------------------------------ @@ -350,7 +442,26 @@ def init_forward_metadata( use_cuda_graph: bool = False, **kwargs, ): - if forward_mode.is_extend_or_mixed(): + if forward_mode.is_mixed(): + # num_extends / extend_seq_lens come from cuda_graph_wrapper kwargs. + num_extends = kwargs.get("num_extends") + extend_seq_lens = kwargs.get("extend_seq_lens") + assert num_extends is not None and extend_seq_lens is not None, ( + "MIXED forward_mode requires num_extends and extend_seq_lens " + "in kwargs (passed by cuda_graph_wrapper)" + ) + self._init_mixed_metadata( + bs=bs, + num_extends=num_extends, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + req_to_page=req_to_page, + extend_seq_lens=extend_seq_lens, + extend_seq_lens_cpu=extend_seq_lens_cpu, + ) + return + + if forward_mode.is_extend(): self._init_extend_metadata( bs, req_pool_indices, @@ -384,6 +495,10 @@ def _init_decode_metadata( assert ( seq_lens.dtype == torch.int32 ), f"seq_lens must be int32, got {seq_lens.dtype}" + # Clear the MIXED sentinel — a prior MIXED step would otherwise leave + # it set, causing forward_extend to wrongly dispatch to the split + # kernel on the next pure-decode step. + self.forward_mixed_metadata = None device = seq_lens.device # Alias seq_lens (no copy, no mutation). cu_seqlens_k omitted: # the decode kernel doesn't read it. @@ -411,6 +526,8 @@ def _init_multi_token_metadata( assert ( seq_lens.dtype == torch.int32 ), f"seq_lens must be int32, got {seq_lens.dtype}" + # Clear the MIXED sentinel for the same reason as _init_decode_metadata. + self.forward_mixed_metadata = None device = seq_lens.device self.forward_prefill_metadata = TRTLLMMHAMetadata( cache_seqlens_int32=seq_lens[:bs], @@ -443,6 +560,10 @@ def _init_extend_metadata( assert ( seq_lens.dtype == torch.int32 ), f"seq_lens must be int32, got {seq_lens.dtype}" + # Clear the MIXED sentinel — a prior MIXED step would otherwise leave + # it set, causing forward_extend to wrongly dispatch to the split + # kernel on the next pure-prefill step. + self.forward_mixed_metadata = None assert ( extend_seq_lens_cpu is not None ), "trtllm extend requires extend_seq_lens_cpu (pinned-CPU mirror) to avoid GPU sync" @@ -487,6 +608,100 @@ def _init_extend_metadata( page_table=page_table, ) + def _init_mixed_metadata( + self, + bs: int, + num_extends: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + req_to_page: torch.Tensor, + extend_seq_lens: torch.Tensor | None, + extend_seq_lens_cpu: torch.Tensor | None, + ): + """Populate three slots for MIXED batches (case1, split-kernel path). + + The batch has prefill rows in [0, num_extends) and decode rows in + [num_extends, bs). This populates: + - forward_mixed_metadata: sentinel + row partition counts so + forward_extend can dispatch to _forward_mixed_kernel. + - forward_prefill_metadata: row view (head) for the prefill subset, + looks identical to what _init_extend_metadata would build. + - forward_decode_metadata: row view (tail) for the decode subset, + looks identical to what _init_decode_metadata would build. + """ + assert ( + seq_lens.dtype == torch.int32 + ), f"seq_lens must be int32, got {seq_lens.dtype}" + assert ( + 0 < num_extends < bs + ), f"MIXED requires 0 < num_extends < bs, got num_extends={num_extends}, bs={bs}" + assert extend_seq_lens_cpu is not None and extend_seq_lens is not None, ( + "MIXED requires extend_seq_lens and extend_seq_lens_cpu (pinned mirror) " + "to compute query_lens without GPU sync" + ) + + device = seq_lens.device + + # Per-request query lengths for the mixed batch: extend lengths for + # prefill rows, 1 for decode rows. Built on device to avoid host->device + # copy in the hot path; extend_seq_lens already lives on device. + query_lens = torch.ones(bs, dtype=torch.int32, device=device) + query_lens[:num_extends] = extend_seq_lens[:num_extends].to(torch.int32) + + cu_seqlens_q = torch.nn.functional.pad( + torch.cumsum(query_lens, dim=0, dtype=torch.int32), (1, 0) + ) + cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(seq_lens[:bs], dim=0, dtype=torch.int32), (1, 0) + ) + + # max_seq_len_q for the prefill subset. If all prefill rows are fully + # prefix-cached (extend length 0), clamp to 1 so the context kernel + # sees a non-zero max_q_len. Decode kernel uses its own max_seq_len_q=1. + max_extend = int(extend_seq_lens_cpu[:num_extends].max().item()) + max_seq_len_q_prefill = max(max_extend, 1) + + # Pre-compute total prefill tokens (cumulative sum at the partition + # boundary) for slicing q in _forward_mixed_kernel without sync. + num_prefill_tokens = int(extend_seq_lens_cpu[:num_extends].sum().item()) + + page_table = self._build_page_table( + req_pool_indices, seq_lens, bs, req_to_page, self.page_table_buf + ) + cache_seqlens_int32 = seq_lens[:bs] + + # Sentinel + partition counts for split-kernel dispatch. + self.forward_mixed_metadata = TRTLLMMHAMetadata( + cache_seqlens_int32=cache_seqlens_int32, + max_seq_len_q=max_seq_len_q_prefill, + max_seq_len_k=self.max_context_len, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + page_table=page_table, + num_prefill_reqs=num_extends, + num_decode_reqs=bs - num_extends, + num_prefill_tokens=num_prefill_tokens, + ) + + # Row view (head): prefill subset, looks like a regular EXTEND. + self.forward_prefill_metadata = TRTLLMMHAMetadata( + cache_seqlens_int32=cache_seqlens_int32[:num_extends], + max_seq_len_q=max_seq_len_q_prefill, + max_seq_len_k=self.max_context_len, + cu_seqlens_q=cu_seqlens_q[: num_extends + 1], + cu_seqlens_k=cu_seqlens_k[: num_extends + 1], + page_table=page_table[:num_extends], + ) + + # Row view (tail): decode subset, looks like a regular DECODE. + # cu_seqlens_q omitted: decode kernel doesn't read it. + self.forward_decode_metadata = TRTLLMMHAMetadata( + cache_seqlens_int32=cache_seqlens_int32[num_extends:], + max_seq_len_q=1, + max_seq_len_k=self.max_context_len, + page_table=page_table[num_extends:], + ) + # ------------------------------------------------------------------ # CUDA graph support # ------------------------------------------------------------------