Skip to content
Open
217 changes: 216 additions & 1 deletion python/tokenspeed/runtime/layers/attention/backends/trtllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ class TRTLLMMHAMetadata:
cu_seqlens_q: torch.Tensor = None
cu_seqlens_k: torch.Tensor = None
page_table: torch.Tensor = None
# MIXED-only: row partition counts. Set by _init_mixed_metadata on the
# forward_mixed_metadata slot; default 0 on EXTEND/DECODE slots.
num_prefill_reqs: int = 0
num_decode_reqs: int = 0
num_prefill_tokens: int = 0


class TRTLLMMHAAttnBackend(AttentionBackend):
Expand Down Expand Up @@ -142,8 +147,15 @@ def __init__(self, config: MHAConfig):

# Separate slots for prefill-kernel vs decode-kernel forward paths.
# forward_extend reads prefill; forward_decode reads decode.
# In MIXED steps, _init_mixed_metadata populates all three slots:
# - forward_mixed_metadata: sentinel + row partition counts
# - forward_prefill_metadata: row view of prefill rows
# - forward_decode_metadata: row view of decode rows
# forward_extend then dispatches to _forward_mixed_kernel which
# runs context kernel on prefill rows + decode kernel on decode rows.
self.forward_prefill_metadata: TRTLLMMHAMetadata | None = None
self.forward_decode_metadata: TRTLLMMHAMetadata | None = None
self.forward_mixed_metadata: TRTLLMMHAMetadata | None = None

# CUDA graph state — per-slot dicts.
self.cuda_graph_prefill_metadata: dict[int, TRTLLMMHAMetadata] = {}
Expand Down Expand Up @@ -300,6 +312,21 @@ def forward_extend(
save_kv_cache: bool = True,
**kwargs,
) -> torch.Tensor:
# MIXED step: split rows into prefill / decode subsets and run the
# dedicated context kernel + decode kernel back to back (case1). The
# forward_mixed_metadata sentinel is set by _init_mixed_metadata.
if self.forward_mixed_metadata is not None:
return self._forward_mixed_kernel(
q,
k,
v,
layer,
out_cache_loc,
token_to_kv_pool,
save_kv_cache=save_kv_cache,
**kwargs,
)

q = self._save_kv_and_prepare_q(
q, k, v, layer, out_cache_loc, token_to_kv_pool, save_kv_cache
)
Expand Down Expand Up @@ -330,6 +357,71 @@ def forward_extend(
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)

def _forward_mixed_kernel(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: PagedAttention,
out_cache_loc: torch.Tensor,
token_to_kv_pool,
save_kv_cache: bool = True,
**kwargs,
) -> torch.Tensor:
"""MIXED forward (case1): split rows into prefill / decode subsets and
run the TRT-LLM context kernel on the prefill subset, decode kernel on
the decode subset, then concat outputs along the token dim.

_init_mixed_metadata has populated forward_prefill_metadata and
forward_decode_metadata as row views of the full mixed batch, so the
recursive forward_extend / forward_decode calls just see what looks
like a normal EXTEND / DECODE step.
"""
mixed_meta = self.forward_mixed_metadata
n_pf = mixed_meta.num_prefill_tokens

q_ext, q_dec = q[:n_pf], q[n_pf:]
k_ext, k_dec = (None, None) if k is None else (k[:n_pf], k[n_pf:])
v_ext, v_dec = (None, None) if v is None else (v[:n_pf], v[n_pf:])
loc_ext = out_cache_loc[:n_pf]
loc_dec = out_cache_loc[n_pf:]

# Clear the MIXED sentinel so the recursive forward_extend /
# forward_decode take the regular paths and read the row-view metadata.
# Restored in finally so subsequent layers in the same step still see it.
self.forward_mixed_metadata = None
try:
# When all prefill rows are fully prefix-cached, num_prefill_tokens
# is 0; q_ext is empty so there is no context-kernel work to do.
if n_pf > 0:
o_ext = self.forward_extend(
q_ext,
k_ext,
v_ext,
layer,
loc_ext,
token_to_kv_pool,
bs=mixed_meta.num_prefill_reqs,
save_kv_cache=save_kv_cache,
**kwargs,
)
else:
o_ext = q_ext # shape (0, hidden); pass through to the concat
o_dec = self.forward_decode(
q_dec,
k_dec,
v_dec,
layer,
loc_dec,
token_to_kv_pool,
bs=mixed_meta.num_decode_reqs,
save_kv_cache=save_kv_cache,
**kwargs,
)
finally:
self.forward_mixed_metadata = mixed_meta
return torch.cat([o_ext, o_dec], dim=0)

# ------------------------------------------------------------------
# Metadata initialisation
# ------------------------------------------------------------------
Expand All @@ -350,7 +442,26 @@ def init_forward_metadata(
use_cuda_graph: bool = False,
**kwargs,
):
if forward_mode.is_extend_or_mixed():
if forward_mode.is_mixed():
# num_extends / extend_seq_lens come from cuda_graph_wrapper kwargs.
num_extends = kwargs.get("num_extends")
extend_seq_lens = kwargs.get("extend_seq_lens")
assert num_extends is not None and extend_seq_lens is not None, (
"MIXED forward_mode requires num_extends and extend_seq_lens "
"in kwargs (passed by cuda_graph_wrapper)"
)
self._init_mixed_metadata(
bs=bs,
num_extends=num_extends,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
req_to_page=req_to_page,
extend_seq_lens=extend_seq_lens,
extend_seq_lens_cpu=extend_seq_lens_cpu,
)
return

if forward_mode.is_extend():
self._init_extend_metadata(
bs,
req_pool_indices,
Expand Down Expand Up @@ -384,6 +495,10 @@ def _init_decode_metadata(
assert (
seq_lens.dtype == torch.int32
), f"seq_lens must be int32, got {seq_lens.dtype}"
# Clear the MIXED sentinel — a prior MIXED step would otherwise leave
# it set, causing forward_extend to wrongly dispatch to the split
# kernel on the next pure-decode step.
self.forward_mixed_metadata = None
device = seq_lens.device
# Alias seq_lens (no copy, no mutation). cu_seqlens_k omitted:
# the decode kernel doesn't read it.
Expand Down Expand Up @@ -411,6 +526,8 @@ def _init_multi_token_metadata(
assert (
seq_lens.dtype == torch.int32
), f"seq_lens must be int32, got {seq_lens.dtype}"
# Clear the MIXED sentinel for the same reason as _init_decode_metadata.
self.forward_mixed_metadata = None
device = seq_lens.device
self.forward_prefill_metadata = TRTLLMMHAMetadata(
cache_seqlens_int32=seq_lens[:bs],
Expand Down Expand Up @@ -443,6 +560,10 @@ def _init_extend_metadata(
assert (
seq_lens.dtype == torch.int32
), f"seq_lens must be int32, got {seq_lens.dtype}"
# Clear the MIXED sentinel — a prior MIXED step would otherwise leave
# it set, causing forward_extend to wrongly dispatch to the split
# kernel on the next pure-prefill step.
self.forward_mixed_metadata = None
assert (
extend_seq_lens_cpu is not None
), "trtllm extend requires extend_seq_lens_cpu (pinned-CPU mirror) to avoid GPU sync"
Expand Down Expand Up @@ -487,6 +608,100 @@ def _init_extend_metadata(
page_table=page_table,
)

def _init_mixed_metadata(
self,
bs: int,
num_extends: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
req_to_page: torch.Tensor,
extend_seq_lens: torch.Tensor | None,
extend_seq_lens_cpu: torch.Tensor | None,
):
"""Populate three slots for MIXED batches (case1, split-kernel path).

The batch has prefill rows in [0, num_extends) and decode rows in
[num_extends, bs). This populates:
- forward_mixed_metadata: sentinel + row partition counts so
forward_extend can dispatch to _forward_mixed_kernel.
- forward_prefill_metadata: row view (head) for the prefill subset,
looks identical to what _init_extend_metadata would build.
- forward_decode_metadata: row view (tail) for the decode subset,
looks identical to what _init_decode_metadata would build.
"""
assert (
seq_lens.dtype == torch.int32
), f"seq_lens must be int32, got {seq_lens.dtype}"
assert (
0 < num_extends < bs
), f"MIXED requires 0 < num_extends < bs, got num_extends={num_extends}, bs={bs}"
assert extend_seq_lens_cpu is not None and extend_seq_lens is not None, (
"MIXED requires extend_seq_lens and extend_seq_lens_cpu (pinned mirror) "
"to compute query_lens without GPU sync"
)

device = seq_lens.device

# Per-request query lengths for the mixed batch: extend lengths for
# prefill rows, 1 for decode rows. Built on device to avoid host->device
# copy in the hot path; extend_seq_lens already lives on device.
query_lens = torch.ones(bs, dtype=torch.int32, device=device)
query_lens[:num_extends] = extend_seq_lens[:num_extends].to(torch.int32)

cu_seqlens_q = torch.nn.functional.pad(
torch.cumsum(query_lens, dim=0, dtype=torch.int32), (1, 0)
)
cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(seq_lens[:bs], dim=0, dtype=torch.int32), (1, 0)
)

# max_seq_len_q for the prefill subset. If all prefill rows are fully
# prefix-cached (extend length 0), clamp to 1 so the context kernel
# sees a non-zero max_q_len. Decode kernel uses its own max_seq_len_q=1.
max_extend = int(extend_seq_lens_cpu[:num_extends].max().item())
max_seq_len_q_prefill = max(max_extend, 1)

# Pre-compute total prefill tokens (cumulative sum at the partition
# boundary) for slicing q in _forward_mixed_kernel without sync.
num_prefill_tokens = int(extend_seq_lens_cpu[:num_extends].sum().item())

page_table = self._build_page_table(
req_pool_indices, seq_lens, bs, req_to_page, self.page_table_buf
)
cache_seqlens_int32 = seq_lens[:bs]

# Sentinel + partition counts for split-kernel dispatch.
self.forward_mixed_metadata = TRTLLMMHAMetadata(
cache_seqlens_int32=cache_seqlens_int32,
max_seq_len_q=max_seq_len_q_prefill,
max_seq_len_k=self.max_context_len,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
page_table=page_table,
num_prefill_reqs=num_extends,
num_decode_reqs=bs - num_extends,
num_prefill_tokens=num_prefill_tokens,
)

# Row view (head): prefill subset, looks like a regular EXTEND.
self.forward_prefill_metadata = TRTLLMMHAMetadata(
cache_seqlens_int32=cache_seqlens_int32[:num_extends],
max_seq_len_q=max_seq_len_q_prefill,
max_seq_len_k=self.max_context_len,
cu_seqlens_q=cu_seqlens_q[: num_extends + 1],
cu_seqlens_k=cu_seqlens_k[: num_extends + 1],
page_table=page_table[:num_extends],
)

# Row view (tail): decode subset, looks like a regular DECODE.
# cu_seqlens_q omitted: decode kernel doesn't read it.
self.forward_decode_metadata = TRTLLMMHAMetadata(
cache_seqlens_int32=cache_seqlens_int32[num_extends:],
max_seq_len_q=1,
max_seq_len_k=self.max_context_len,
page_table=page_table[num_extends:],
)

# ------------------------------------------------------------------
# CUDA graph support
# ------------------------------------------------------------------
Expand Down