From 21dd822d09754e92f81aae03ad5087d76c9f22f1 Mon Sep 17 00:00:00 2001 From: jiyingd <87510204+dongjiyingdjy@users.noreply.github.com> Date: Wed, 13 May 2026 12:17:13 +0000 Subject: [PATCH 1/2] feat(deepseek-v4): support mixed prefill decode Signed-off-by: jiyingd <87510204+dongjiyingdjy@users.noreply.github.com> --- .../tokenspeed/runtime/engine/event_loop.py | 20 +- .../engine/generation_output_processor.py | 8 +- .../runtime/engine/scheduler_utils.py | 4 +- .../runtime/execution/cuda_graph_wrapper.py | 9 +- .../runtime/execution/forward_batch_info.py | 16 + .../runtime/execution/input_buffer.py | 59 +- .../runtime/execution/model_executor.py | 12 +- .../layers/attention/backends/deepseek_v4.py | 457 ++- .../layers/attention/deepseek_v4_ops.py | 49 +- .../layers/attention/kv_cache/deepseek_v4.py | 35 + .../tokenspeed/runtime/models/deepseek_v4.py | 2773 +++++++++++++++-- python/tokenspeed/runtime/utils/custom_ops.py | 61 + .../tokenspeed/runtime/utils/server_args.py | 8 + test/runtime/kernels/test_trtllm_wrapper.py | 103 + test/runtime/test_cli_config_compat.py | 6 + .../runtime/test_deepseek_v4_attention_ops.py | 165 + test/runtime/test_deepseek_v4_config.py | 1295 +++++++- .../test_generation_output_processor.py | 100 + .../ops/attention/flash_mla/__init__.py | 4 +- .../ops/attention/triton/deepseek_v4.py | 115 + .../tokenspeed_kernel/ops/moe/triton.py | 169 + .../cuda/csrc/deepseek_v4_attention.cu | 172 +- .../csrc/deepseek_v4_attention_binding.cu | 9 + .../thirdparty/cuda/deepseek_v4_attention.py | 41 + .../thirdparty/trtllm/__init__.py | 4 +- .../bindings/python_module.cpp | 1 + .../csrc/scheduler/operations/forward.cpp | 9 +- tokenspeed-scheduler/csrc/scheduler/types.h | 1 + .../python/tests/test_fsm_and_scheduling.py | 26 +- 29 files changed, 5298 insertions(+), 433 deletions(-) create mode 100644 python/tokenspeed/runtime/utils/custom_ops.py create mode 100644 test/runtime/kernels/test_trtllm_wrapper.py create mode 100644 
test/runtime/test_generation_output_processor.py create mode 100644 tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/deepseek_v4.py diff --git a/python/tokenspeed/runtime/engine/event_loop.py b/python/tokenspeed/runtime/engine/event_loop.py index cccfe1e84..5530774db 100644 --- a/python/tokenspeed/runtime/engine/event_loop.py +++ b/python/tokenspeed/runtime/engine/event_loop.py @@ -272,6 +272,9 @@ def __init__( f"(ratio={server_args.mamba_full_memory_ratio})." ) + enable_mixed_prefill_decode = ( + server_args.enable_mixed_chunk and server_args.speculative_algorithm is None + ) scheduler_cfg = make_config( num_device_pages=self.max_total_num_tokens // server_args.block_size, max_scheduled_tokens=server_args.chunked_prefill_size, @@ -293,6 +296,7 @@ def __init__( mamba_cache_chunk_size=server_args.mamba_cache_chunk_size, mamba_pool_total_chunks=mamba_pool_total_chunks, paged_cache_groups=pool_to_paged_cache_groups(token_to_kv_pool), + enable_mixed_prefill_decode=enable_mixed_prefill_decode, ) logger.info( "Scheduler config: page_size=%s num_device_pages=%s " @@ -785,8 +789,10 @@ def _commit_forward_results( on_first_token=None, ): self.request_handler.forward_ct += 1 - forward_mode = ( - ForwardMode.EXTEND if forward_op.num_extends() > 0 else ForwardMode.DECODE + forward_mode = ForwardMode.from_num_extends( + forward_op.num_extends(), + len(forward_op.request_ids), + has_drafter=self.server_args.speculative_algorithm is not None, ) self.request_handler._profile_batch_predicate(forward_mode) @@ -859,12 +865,12 @@ def _dp_sync_and_check(self, forward_op) -> DpForwardMetadata: batch_size = len(forward_op.request_ids) if forward_op is not None else 0 if forward_op is None: forward_mode = ForwardMode.IDLE - elif forward_op.num_extends() > 0: - forward_mode = ForwardMode.EXTEND - elif self.server_args.speculative_algorithm is not None: - forward_mode = ForwardMode.TARGET_VERIFY else: - forward_mode = ForwardMode.DECODE + forward_mode = 
ForwardMode.from_num_extends( + forward_op.num_extends(), + batch_size, + has_drafter=self.server_args.speculative_algorithm is not None, + ) self._dp_local_info[0, 0] = num_tokens self._dp_local_info[0, 1] = batch_size diff --git a/python/tokenspeed/runtime/engine/generation_output_processor.py b/python/tokenspeed/runtime/engine/generation_output_processor.py index 1234db283..2bcb1891c 100644 --- a/python/tokenspeed/runtime/engine/generation_output_processor.py +++ b/python/tokenspeed/runtime/engine/generation_output_processor.py @@ -483,7 +483,8 @@ def post_process_forward_op( forward_op.input_lengths, forward_op.extend_prefix_lens, ) - is_decode_op = forward_op.num_extends() <= 0 + num_extends = forward_op.num_extends() + is_decode_op = num_extends <= 0 request_changes = [] stream_out_rids = [] @@ -504,6 +505,7 @@ def post_process_forward_op( if output_logprobs_list is not None else None ) + is_decode_slot = i >= num_extends if self.spec_num_tokens is not None and is_decode_op: pt += self.spec_num_tokens else: @@ -524,7 +526,7 @@ def post_process_forward_op( if on_first_token is not None and model_output_ids: on_first_token(forward_op.request_pool_indices[i], model_output_ids[0]) - if is_decode_op and self.spec_algorithm is not None: + if is_decode_slot and self.spec_algorithm is not None: request_state.spec_verify_ct += 1 # With the capturable grammar pipeline the matcher is @@ -597,7 +599,7 @@ def post_process_forward_op( else: stream_out_rids.append(rid) stream_out_states.append(request_state) - if is_decode_op: + if is_decode_slot: request_changes.append( make_update_reserve_tokens_event(rid, output_length) ) diff --git a/python/tokenspeed/runtime/engine/scheduler_utils.py b/python/tokenspeed/runtime/engine/scheduler_utils.py index 820c0a17a..ccf3110a7 100644 --- a/python/tokenspeed/runtime/engine/scheduler_utils.py +++ b/python/tokenspeed/runtime/engine/scheduler_utils.py @@ -20,7 +20,6 @@ """Helper functions for constructing scheduler specs and events.""" 
-import os from collections.abc import Sequence from typing import Any, Mapping @@ -39,7 +38,6 @@ "WriteBackDoneEvent": Cache.WriteBackDoneEvent, "PrefetchDoneEvent": Cache.PrefetchDoneEvent, } -_TRUTHY_ENV_VALUES = {"1", "true", "yes", "on"} def make_spec(rid: str, tokens: list[int]) -> RequestSpec: @@ -66,6 +64,7 @@ def make_config( mamba_cache_chunk_size: int = 64, mamba_pool_total_chunks: int = 0, paged_cache_groups: Sequence["PagedCacheGroupConfig"] | None = None, + enable_mixed_prefill_decode: bool = False, ) -> SchedulerConfig: cfg = SchedulerConfig() cfg.num_device_pages = num_device_pages @@ -92,6 +91,7 @@ def make_config( cfg.enable_mamba = enable_mamba cfg.mamba_cache_chunk_size = mamba_cache_chunk_size cfg.mamba_pool_total_chunks = mamba_pool_total_chunks + cfg.enable_mixed_prefill_decode = enable_mixed_prefill_decode if paged_cache_groups: cfg.paged_cache_groups = list(paged_cache_groups) return cfg diff --git a/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py b/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py index e3dcdc7b9..4402c79eb 100644 --- a/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py +++ b/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py @@ -546,6 +546,7 @@ def _pad_offsets_to_padded_bs( def _init_replay_metadata( self, padded_bs: int, + actual_bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, req_to_page: torch.Tensor, @@ -562,7 +563,7 @@ def _init_replay_metadata( "uses_paged_cache_groups", False, ): - actual_bs = next( + table_bs = next( ( int(table.shape[0]) for table in paged_cache_block_tables.values() @@ -572,7 +573,7 @@ def _init_replay_metadata( ) paged_cache_block_tables = self._pad_block_tables_to_padded_bs( paged_cache_block_tables, - actual_bs=actual_bs, + actual_bs=table_bs, padded_bs=padded_bs, ) kwargs["paged_cache_block_tables"] = paged_cache_block_tables @@ -585,6 +586,8 @@ def _init_replay_metadata( kwargs["paged_cache_block_table_base_offsets"] = ( 
paged_cache_block_table_base_offsets ) + if getattr(self.attn_backend, "uses_padded_decode_token_mask", False): + kwargs["actual_bs"] = actual_bs self.attn_backend.init_forward_metadata_replay_cuda_graph( padded_bs, req_pool_indices, @@ -785,6 +788,7 @@ def __call__( ) self._init_replay_metadata( padded_bs, + bs, req_pool_indices, seq_lens, req_to_page=req_to_page, @@ -831,6 +835,7 @@ def __call__( extend_prefix_lens_cpu=extend_prefix_lens_cpu, extend_seq_lens=extend_seq_lens, extend_seq_lens_cpu=extend_seq_lens_cpu, + num_extends=ctx.num_extends, positions=positions, out_cache_loc=out_cache_loc, global_num_tokens=ctx.global_num_tokens, diff --git a/python/tokenspeed/runtime/execution/forward_batch_info.py b/python/tokenspeed/runtime/execution/forward_batch_info.py index 07bc6ab1c..609003b0c 100755 --- a/python/tokenspeed/runtime/execution/forward_batch_info.py +++ b/python/tokenspeed/runtime/execution/forward_batch_info.py @@ -55,6 +55,9 @@ def is_extend(self): def is_decode(self): return self == ForwardMode.DECODE + def is_mixed(self): + return self == ForwardMode.MIXED + def is_idle(self): return self == ForwardMode.IDLE @@ -67,6 +70,19 @@ def is_draft_extend(self): def is_decode_or_idle(self): return self == ForwardMode.DECODE or self == ForwardMode.IDLE + @staticmethod + def from_num_extends( + num_extends: int, + batch_size: int, + *, + has_drafter: bool = False, + ) -> "ForwardMode": + if batch_size <= 0: + return ForwardMode.IDLE + if num_extends > 0: + return ForwardMode.MIXED if num_extends < batch_size else ForwardMode.EXTEND + return ForwardMode.TARGET_VERIFY if has_drafter else ForwardMode.DECODE + class CaptureHiddenMode(IntEnum): NULL = auto() diff --git a/python/tokenspeed/runtime/execution/input_buffer.py b/python/tokenspeed/runtime/execution/input_buffer.py index 37e04d952..88feca724 100644 --- a/python/tokenspeed/runtime/execution/input_buffer.py +++ b/python/tokenspeed/runtime/execution/input_buffer.py @@ -175,15 +175,19 @@ def 
fill_input_buffers( page_size=self.page_size, ) - valid_cache_lengths = runtime_states.valid_cache_lengths[ + cached_prefix_lens = runtime_states.valid_cache_lengths[ self.req_pool_indices_buf[:batch_size] ] - # Compute positions - prefix_lens = ( - self.extend_prefix_lens_buf[:num_extends] - if num_extends > 0 - else valid_cache_lengths - ) + # Compute positions. In mixed batches, prefill rows use their extend + # prefix lengths while decode rows use the current valid cache lengths. + prefill_prefix_lens = self.extend_prefix_lens_buf[:num_extends] + if num_extends == 0: + prefix_lens = cached_prefix_lens + elif num_extends == batch_size: + prefix_lens = prefill_prefix_lens + else: + prefix_lens = cached_prefix_lens.clone() + prefix_lens[:num_extends].copy_(prefill_prefix_lens) positions, _ = compute_position_triton( extend_prefix_lens=prefix_lens, extend_seq_lens=input_lengths_device, @@ -193,20 +197,55 @@ def fill_input_buffers( # Determine input_ids and forward_mode if num_extends > 0: + prefill_token_count = sum(forward_op.input_lengths[:num_extends]) input_ids_cpu = torch.tensor( forward_op.input_ids, device="cpu", pin_memory=True ) - self.input_ids_buf[:total_tokens].copy_( + self.input_ids_buf[:prefill_token_count].copy_( input_ids_cpu, non_blocking=True, ) shifted_ids_cpu = torch.tensor( forward_op.shifted_input_ids, device="cpu", pin_memory=True ) - self.shifted_prefill_ids_buf[:total_tokens].copy_( + self.shifted_prefill_ids_buf[:prefill_token_count].copy_( shifted_ids_cpu, non_blocking=True, ) + if num_extends < batch_size: + decode_req_pool_indices = req_pool_indices_device[ + num_extends:batch_size + ] + if forward_op.decode_input_ids is not None: + decode_count = batch_size - num_extends + if len(forward_op.decode_input_ids) != decode_count: + raise RuntimeError( + "mixed forward decode_input_ids length mismatch: " + f"got {len(forward_op.decode_input_ids)}, " + f"expected {decode_count}" + ) + decode_input_ids_tensor = torch.tensor( + 
forward_op.decode_input_ids, + dtype=torch.int32, + device="cpu", + pin_memory=True, + ).to(req_pool_indices_device.device, non_blocking=True) + mask = (decode_input_ids_tensor != -1).unsqueeze(1) + slot = runtime_states.future_input_map[decode_req_pool_indices, :1] + runtime_states.future_input_map[decode_req_pool_indices, :1] = ( + torch.where(mask, decode_input_ids_tensor.unsqueeze(1), slot) + ) + decode_ids = runtime_states.future_input_map[ + decode_req_pool_indices, :1 + ].flatten() + self.input_ids_buf[prefill_token_count:total_tokens].copy_( + decode_ids, + non_blocking=True, + ) + self.shifted_prefill_ids_buf[prefill_token_count:total_tokens].copy_( + decode_ids, + non_blocking=True, + ) else: # If the scheduler provides explicit decode input ids (!= -1), write # them into future_input_map before reading, so that they take effect @@ -230,7 +269,7 @@ def fill_input_buffers( non_blocking=True, ) - self.seq_lens_buf[:batch_size].copy_(input_lengths_device + valid_cache_lengths) + self.seq_lens_buf[:batch_size].copy_(input_lengths_device + cached_prefix_lens) # Reset positions beyond total_tokens to the dummy KV slot so that any # CUDA graph replay with a larger (padded) batch size writes padding diff --git a/python/tokenspeed/runtime/execution/model_executor.py b/python/tokenspeed/runtime/execution/model_executor.py index b224124fc..0eda2f0f6 100644 --- a/python/tokenspeed/runtime/execution/model_executor.py +++ b/python/tokenspeed/runtime/execution/model_executor.py @@ -829,14 +829,12 @@ def execute_forward_op( total_tokens=total_tokens, ) - if num_extends > 0: - forward_mode = ForwardMode.EXTEND - elif self.drafter is not None: - forward_mode = ForwardMode.TARGET_VERIFY - else: - forward_mode = ForwardMode.DECODE - bs = len(forward_op.request_ids) + forward_mode = ForwardMode.from_num_extends( + num_extends, + bs, + has_drafter=self.drafter is not None, + ) if self.runtime_states.mamba_pool is not None and ( num_extends > 0 or has_retract diff --git 
a/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py b/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py index 885ed39e5..f53edb274 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py +++ b/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py @@ -14,6 +14,9 @@ from __future__ import annotations import torch +from tokenspeed_kernel.ops.attention.triton.deepseek_v4 import ( + deepseek_v4_indexer_decode_metadata_compute, +) from tokenspeed.runtime.configs.model_config import AttentionArch from tokenspeed.runtime.execution.forward_batch_info import ForwardMode @@ -56,10 +59,167 @@ def _cu_seqlens(lengths: torch.Tensor) -> torch.Tensor: ) +def _decode_positions_from_metadata( + metadata: DeepseekV4ForwardMetadata, + num_tokens: int, +) -> torch.Tensor: + token_to_req = metadata.token_to_req_indices[:num_tokens].to(torch.int64) + query_starts = metadata.query_start_loc[token_to_req].to(torch.int64) + query_lens = metadata.query_lens[token_to_req].to(torch.int64) + seq_lens = metadata.seq_lens[token_to_req].to(torch.int64) + token_offsets = torch.arange( + num_tokens, + dtype=torch.int64, + device=metadata.seq_lens.device, + ) + return seq_lens - query_lens + token_offsets - query_starts + + +def _refresh_decode_indexer_plan_cache( + metadata: DeepseekV4ForwardMetadata, + *, + max_context_len: int, +) -> None: + """Pre-build decode-indexer plan tensors before per-layer parallel work. + + This keeps per-layer indexer calls read-only with respect to cached plan + buffers while compressor work may run on an auxiliary stream. 
+ """ + cache = metadata.decode_indexer_plan_cache + if not cache: + return + refreshed_keys = metadata.decode_indexer_plan_refreshed_keys + refreshed_keys.clear() + for ( + compress_ratio, + cache_block_size, + num_tokens, + ), plan in list(cache.items()): + if num_tokens <= 0: + plan.context_lens.zero_() + plan.block_table.zero_() + plan.max_context_len = 0 + refreshed_keys.add((compress_ratio, cache_block_size, num_tokens)) + continue + positions = _decode_positions_from_metadata(metadata, num_tokens) + token_to_req_indices = metadata.token_to_req_indices[:num_tokens] + block_table = metadata.compressed_block_table( + compress_ratio, + cache_block_size, + ) + rows = int(block_table.shape[0]) if block_table.ndim >= 1 else 0 + cols = int(block_table.shape[1]) if block_table.ndim >= 2 else 0 + if rows <= 0 or cols <= 0: + plan.context_lens.zero_() + plan.block_table.zero_() + plan.max_context_len = 0 + refreshed_keys.add((compress_ratio, cache_block_size, num_tokens)) + continue + max_blocks = int(plan.block_table.shape[1]) + if max_context_len > 0: + derived_max_len = max( + 1, + (max_context_len + compress_ratio - 1) // compress_ratio, + ) + else: + derived_max_len = max( + 1, + (block_table.shape[1] * cache_block_size + compress_ratio - 1) + // compress_ratio, + ) + if plan.max_context_len != derived_max_len: + plan.max_context_len = derived_max_len + deepseek_v4_indexer_decode_metadata_compute( + positions=positions, + token_to_req_indices=token_to_req_indices, + block_table=block_table, + cache_block_size=cache_block_size, + compress_ratio=compress_ratio, + max_blocks=max_blocks, + out_context_lens=plan.context_lens, + out_block_tables=plan.block_table, + ) + if metadata.is_valid_token is not None: + valid = metadata.is_valid_token[:num_tokens].to( + device=plan.context_lens.device, + dtype=torch.bool, + ) + with torch.inference_mode(): + plan.context_lens.masked_fill_(~valid.view(num_tokens, 1), 0) + plan.block_table.masked_fill_( + 
~valid.to(device=plan.block_table.device).view(num_tokens, 1), + 0, + ) + refreshed_keys.add((compress_ratio, cache_block_size, num_tokens)) + + +def _refresh_decode_indexer_schedule_metadata( + metadata: DeepseekV4ForwardMetadata, +) -> None: + if not metadata.decode_indexer_schedule_metadata: + return + try: + from tokenspeed_kernel.thirdparty import deep_gemm + except Exception: + return + get_metadata = getattr(deep_gemm, "get_paged_mqa_logits_metadata", None) + if get_metadata is None: + return + for ( + compress_ratio, + cache_block_size, + num_tokens, + ), schedule_metadata in list(metadata.decode_indexer_schedule_metadata.items()): + if num_tokens <= 0: + continue + key = (compress_ratio, cache_block_size, num_tokens) + decode_plan = metadata.decode_indexer_plan_cache.get(key) + context_lens = getattr(decode_plan, "context_lens", None) + if ( + context_lens is not None + and context_lens.shape == (num_tokens, 1) + and context_lens.dtype == torch.int32 + ): + context_lens = context_lens.contiguous() + else: + positions = _decode_positions_from_metadata(metadata, num_tokens) + compressed_lens = torch.div( + positions.to(torch.int32) + 1, + compress_ratio, + rounding_mode="floor", + ).clamp_min(0) + if metadata.is_valid_token is not None: + valid = metadata.is_valid_token[:num_tokens].to( + device=compressed_lens.device, + dtype=torch.bool, + ) + compressed_lens = torch.where( + valid, + compressed_lens, + torch.zeros_like(compressed_lens), + ) + context_lens = compressed_lens.view(num_tokens, 1).contiguous() + refreshed = get_metadata( + context_lens, + cache_block_size, + deep_gemm.get_num_sms(), + ) + if ( + schedule_metadata.shape == refreshed.shape + and schedule_metadata.device == refreshed.device + and schedule_metadata.dtype == refreshed.dtype + ): + with torch.inference_mode(): + schedule_metadata.copy_(refreshed) + else: + metadata.decode_indexer_schedule_metadata[key] = refreshed + + class DeepseekV4AttentionBackend(AttentionBackend): """Metadata 
owner for the model-local DeepSeek V4 attention path.""" uses_paged_cache_groups = True + uses_padded_decode_token_mask = True def __init__(self, config) -> None: super().__init__(config) @@ -127,12 +287,37 @@ def _query_lens( bs: int, seq_lens: torch.Tensor, forward_mode: ForwardMode | None, + num_extends: int, extend_seq_lens_cpu: torch.Tensor | None, extend_prefix_lens_cpu: torch.Tensor | None, extend_prefix_lens: torch.Tensor | None, ) -> torch.Tensor: if forward_mode is not None and forward_mode.is_decode_or_idle(): return torch.ones(bs, dtype=torch.int32, device=seq_lens.device) + if forward_mode is not None and forward_mode.is_mixed(): + lens = torch.ones(bs, dtype=torch.int32, device=seq_lens.device) + num_prefill_reqs = max(0, min(int(num_extends), bs)) + if num_prefill_reqs == 0: + return lens + if extend_seq_lens_cpu is not None and extend_seq_lens_cpu.numel() > 0: + lens[:num_prefill_reqs] = extend_seq_lens_cpu[:num_prefill_reqs].to( + seq_lens.device, dtype=torch.int32 + ) + elif extend_prefix_lens_cpu is not None: + prefix = extend_prefix_lens_cpu[:num_prefill_reqs].to( + seq_lens.device, dtype=torch.int32 + ) + lens[:num_prefill_reqs] = ( + seq_lens[:num_prefill_reqs].to(torch.int32) - prefix + ).clamp_min(0) + elif extend_prefix_lens is not None: + prefix = extend_prefix_lens[:num_prefill_reqs].to(torch.int32) + lens[:num_prefill_reqs] = ( + seq_lens[:num_prefill_reqs].to(torch.int32) - prefix + ).clamp_min(0) + else: + lens[:num_prefill_reqs] = seq_lens[:num_prefill_reqs].to(torch.int32) + return lens if extend_seq_lens_cpu is not None: return extend_seq_lens_cpu[:bs].to(seq_lens.device, dtype=torch.int32) if extend_prefix_lens_cpu is not None: @@ -143,6 +328,33 @@ def _query_lens( return (seq_lens[:bs].to(torch.int32) - prefix).clamp_min(0) return seq_lens[:bs].to(torch.int32) + def _query_lens_cpu( + self, + bs: int, + forward_mode: Optional[ForwardMode], + num_extends: int, + extend_seq_lens_cpu: Optional[torch.Tensor], + extend_prefix_lens_cpu: 
Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: + if forward_mode is not None and forward_mode.is_decode_or_idle(): + return torch.ones(bs, dtype=torch.int32) + if forward_mode is not None and forward_mode.is_mixed(): + lens = torch.ones(bs, dtype=torch.int32) + num_prefill_reqs = max(0, min(int(num_extends), bs)) + if num_prefill_reqs == 0: + return lens + if extend_seq_lens_cpu is None: + return None + lens[:num_prefill_reqs] = extend_seq_lens_cpu[:num_prefill_reqs].to( + dtype=torch.int32, device="cpu" + ) + return lens + if extend_seq_lens_cpu is not None: + return extend_seq_lens_cpu[:bs].to(dtype=torch.int32, device="cpu") + if extend_prefix_lens_cpu is not None: + return None + return None + def init_forward_metadata( self, bs: int, @@ -160,6 +372,8 @@ def init_forward_metadata( paged_cache_block_table_base_offsets = ( kwargs.pop("paged_cache_block_table_base_offsets", None) or {} ) + num_extends_arg = kwargs.pop("num_extends", None) + num_extends = bs if num_extends_arg is None else int(num_extends_arg) del num_tokens, kwargs device = seq_lens.device req_pool_indices = req_pool_indices[:bs] @@ -168,10 +382,51 @@ def init_forward_metadata( bs, seq_lens, forward_mode, + num_extends, extend_seq_lens_cpu, extend_prefix_lens_cpu, extend_prefix_lens, ) + if forward_mode is not None and forward_mode.is_mixed(): + num_prefill_reqs = max(0, min(num_extends, bs)) + elif forward_mode is not None and forward_mode.is_extend(): + num_prefill_reqs = bs + else: + num_prefill_reqs = 0 + query_lens_cpu = self._query_lens_cpu( + bs, + forward_mode, + num_extends, + extend_seq_lens_cpu, + extend_prefix_lens_cpu, + ) + seq_lens_cpu = None + if extend_prefix_lens_cpu is not None and query_lens_cpu is not None: + seq_lens_cpu = seq_lens[:bs].to(dtype=torch.int32, device="cpu") + prefix_count = min( + int(extend_prefix_lens_cpu.numel()), + ( + num_prefill_reqs + if forward_mode is not None and forward_mode.is_mixed() + else bs + ), + ) + if prefix_count: + 
seq_lens_cpu[:prefix_count] = ( + extend_prefix_lens_cpu[:prefix_count].to( + dtype=torch.int32, + device="cpu", + ) + + query_lens_cpu[:prefix_count] + ) + elif extend_seq_lens_cpu is not None and forward_mode is not None: + if forward_mode.is_extend(): + seq_lens_cpu = extend_seq_lens_cpu[:bs].to( + dtype=torch.int32, + device="cpu", + ) + elif forward_mode.is_mixed(): + seq_lens_cpu = seq_lens[:bs].to(dtype=torch.int32, device="cpu") max_seq_len = int(seq_lens.max().item()) if bs else 0 max_pages = (max_seq_len + self.page_size - 1) // self.page_size if req_to_page is None: @@ -210,6 +465,9 @@ def init_forward_metadata( ) req_ids = torch.arange(bs, device=device, dtype=torch.int32) token_to_req = torch.repeat_interleave(req_ids, query_lens.clamp_min(0)) + num_prefill_tokens = ( + int(query_lens[:num_prefill_reqs].sum().item()) if num_prefill_reqs else 0 + ) self.forward_metadata = DeepseekV4ForwardMetadata( page_size=self.page_size, req_pool_indices=req_pool_indices, @@ -218,6 +476,10 @@ def init_forward_metadata( query_lens=query_lens, query_start_loc=_cu_seqlens(query_lens), token_to_req_indices=token_to_req, + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + num_prefill_reqs=num_prefill_reqs, + num_prefill_tokens=num_prefill_tokens, forward_mode=forward_mode, paged_cache_block_tables=paged_cache_block_tables, paged_cache_block_table_base_offsets=base_offsets_on_device, @@ -271,6 +533,7 @@ def _update_decode_swa_metadata( block_table_base_offsets=metadata.swa_base_logical_page, window_size=window_size, block_size=block_size, + is_valid_token=metadata.is_valid_token, out_indices=metadata.decode_swa_indices, out_lens=metadata.decode_swa_lens, ) @@ -304,7 +567,7 @@ def _get_decode_swa_metadata( block_size=block_size, ) - def _decode_compressed_indices_and_lens( + def _decode_compressed_attention_indices_and_lens( self, positions: torch.Tensor, *, @@ -320,6 +583,12 @@ def _decode_compressed_indices_and_lens( num_tokens = positions.numel() req_idx = 
metadata.token_to_req_indices[:num_tokens].to(torch.int64) block_table = metadata.compressed_block_table(compress_ratio, block_size) + is_valid_token = ( + metadata.is_valid_token[:num_tokens] + if metadata.is_valid_token is not None + else None + ) + capturing = positions.is_cuda and torch.cuda.is_current_stream_capturing() if compress_ratio == 4: if topk_indices is None: raise RuntimeError("DeepSeek V4 CSA decode requires top-k indices") @@ -328,19 +597,41 @@ def _decode_compressed_indices_and_lens( token_to_req_indices=metadata.token_to_req_indices[:num_tokens], block_table=block_table, block_size=block_size, + is_valid_token=is_valid_token, ) return indices_2d.unsqueeze(1), lens - else: - width = self._dense_compressed_indices_width(compress_ratio) - compressed_lens = torch.div( - positions.to(torch.int64) + 1, - compress_ratio, - rounding_mode="floor", - ).clamp(0, width) - offsets = torch.arange(width, dtype=torch.int64, device=positions.device) - local = offsets[None, :].expand(num_tokens, -1) - valid = offsets[None, :] < compressed_lens[:, None] - lens = compressed_lens.to(torch.int32) + + cache_key = ( + int(compress_ratio), + int(block_size), + int(num_tokens), + int(positions.data_ptr()) if positions.numel() else 0, + ) + dense_indices_cache = metadata.decode_dense_compressed_indices_cache + capture_safe_keys = metadata.decode_dense_compressed_indices_capture_safe_keys + cached = dense_indices_cache.get(cache_key) + capture_cached = cache_key in capture_safe_keys + if cached is not None and (not capturing or capture_cached): + return cached + + width = self._dense_compressed_indices_width(compress_ratio) + compressed_lens = torch.div( + positions.to(torch.int64) + 1, + compress_ratio, + rounding_mode="floor", + ).clamp(0, width) + offsets = torch.arange(width, dtype=torch.int64, device=positions.device) + local = offsets[None, :].expand(num_tokens, -1) + valid = offsets[None, :] < compressed_lens[:, None] + if is_valid_token is not None: + valid = valid 
& is_valid_token.to(torch.bool)[:, None] + lens = compressed_lens.to(torch.int32) + if is_valid_token is not None: + lens = torch.where( + is_valid_token.to(torch.bool), + lens, + torch.zeros_like(lens), + ) safe_local = torch.where(valid, local, torch.zeros_like(local)) pages = torch.div(safe_local, block_size, rounding_mode="floor") @@ -353,6 +644,9 @@ def _decode_compressed_indices_and_lens( torch.full_like(slots, -1), ) indices = indices_2d.to(torch.int32).unsqueeze(1) + dense_indices_cache[cache_key] = (indices, lens) + if capturing: + capture_safe_keys.add(cache_key) return indices, lens def _dense_compressed_indices_width(self, compress_ratio: int) -> int: @@ -477,7 +771,7 @@ def forward_deepseek_v4_decode( block_size=token_to_kv_pool.swa_block_size, ) compressed_block_size = token_to_kv_pool.get_compressed_block_size(layer_id) - extra_indices, extra_lens = self._decode_compressed_indices_and_lens( + extra_indices, extra_lens = self._decode_compressed_attention_indices_and_lens( positions, compress_ratio=compress_ratio, block_size=compressed_block_size, @@ -518,6 +812,103 @@ def forward_deepseek_v4_decode( out = out.squeeze(1) return out[:, :num_local_heads] + def forward_deepseek_v4_mixed( + self, + *, + q: torch.Tensor, + positions: torch.Tensor, + token_to_kv_pool, + layer_id: int, + kind: str, + compress_ratio: int, + num_local_heads: int, + padded_heads: int, + head_dim: int, + window_size: int, + softmax_scale: float, + attn_sink: torch.Tensor, + topk_indices: Optional[torch.Tensor], + ) -> torch.Tensor: + metadata = self.forward_metadata + if metadata is None: + raise RuntimeError("DeepSeek V4 mixed attention requires forward metadata") + if metadata.forward_mode is None or not metadata.forward_mode.is_mixed(): + raise RuntimeError( + "forward_deepseek_v4_mixed only supports ForwardMode.MIXED" + ) + + num_prefill_reqs = metadata.num_prefill_reqs + num_prefill_tokens = metadata.num_prefill_tokens + num_decode_reqs = metadata.decode_req_count() + 
num_decode_tokens = metadata.decode_token_count() + out = q.new_empty((q.shape[0], num_local_heads, head_dim)) + saved_metadata = self.forward_metadata + try: + if num_prefill_tokens > 0: + self.forward_metadata = self._metadata_slice( + metadata, + req_start=0, + req_end=num_prefill_reqs, + token_start=0, + token_end=num_prefill_tokens, + forward_mode=ForwardMode.EXTEND, + ) + prefill_out = self.forward_deepseek_v4_prefill( + q=q[:num_prefill_tokens], + positions=positions[:num_prefill_tokens], + token_to_kv_pool=token_to_kv_pool, + layer_id=layer_id, + kind=kind, + compress_ratio=compress_ratio, + num_local_heads=num_local_heads, + padded_heads=padded_heads, + head_dim=head_dim, + window_size=window_size, + softmax_scale=softmax_scale, + attn_sink=attn_sink, + topk_indices=( + topk_indices[:num_prefill_tokens] + if topk_indices is not None + else None + ), + ) + with deepseek_v4_profile_scope(f"attn_{kind}_mixed_prefill_copy"): + out[:num_prefill_tokens].copy_(prefill_out) + if num_decode_tokens > 0: + decode_end = num_prefill_tokens + num_decode_tokens + self.forward_metadata = self._metadata_slice( + metadata, + req_start=num_prefill_reqs, + req_end=num_prefill_reqs + num_decode_reqs, + token_start=num_prefill_tokens, + token_end=decode_end, + forward_mode=ForwardMode.DECODE, + ) + decode_out = self.forward_deepseek_v4_decode( + q=q[num_prefill_tokens:decode_end], + positions=positions[num_prefill_tokens:decode_end], + token_to_kv_pool=token_to_kv_pool, + layer_id=layer_id, + kind=kind, + compress_ratio=compress_ratio, + num_local_heads=num_local_heads, + padded_heads=padded_heads, + head_dim=head_dim, + window_size=window_size, + softmax_scale=softmax_scale, + attn_sink=attn_sink, + topk_indices=( + topk_indices[num_prefill_tokens:decode_end] + if topk_indices is not None + else None + ), + ) + with deepseek_v4_profile_scope(f"attn_{kind}_mixed_decode_copy"): + out[num_prefill_tokens:decode_end].copy_(decode_out) + finally: + self.forward_metadata = 
saved_metadata + return out + def _prefill_gather_lens( self, *, @@ -700,6 +1091,10 @@ def _metadata_slice( key: offsets[req_start:req_end] for key, offsets in metadata.compressor_state_base_logical_pages.items() } + req_count = max(0, req_end - req_start) + token_count = max(0, token_end - token_start) + num_prefill_reqs = req_count if forward_mode.is_extend() else 0 + num_prefill_tokens = token_count if forward_mode.is_extend() else 0 return DeepseekV4ForwardMetadata( page_size=metadata.page_size, req_pool_indices=metadata.req_pool_indices[req_start:req_end], @@ -708,6 +1103,23 @@ def _metadata_slice( query_lens=metadata.query_lens[req_start:req_end], query_start_loc=_cu_seqlens(metadata.query_lens[req_start:req_end]), token_to_req_indices=token_to_req, + is_valid_token=( + metadata.is_valid_token[token_start:token_end] + if metadata.is_valid_token is not None + else None + ), + seq_lens_cpu=( + metadata.seq_lens_cpu[req_start:req_end] + if metadata.seq_lens_cpu is not None + else None + ), + query_lens_cpu=( + metadata.query_lens_cpu[req_start:req_end] + if metadata.query_lens_cpu is not None + else None + ), + num_prefill_reqs=num_prefill_reqs, + num_prefill_tokens=num_prefill_tokens, forward_mode=forward_mode, paged_cache_block_tables=paged_cache_block_tables, paged_cache_block_table_base_offsets=paged_cache_block_table_base_offsets, @@ -955,6 +1367,11 @@ def init_cuda_graph_state( dtype=torch.int32, device=self.device, ) + self._cuda_graph_is_valid_token = torch.ones( + max_bs, + dtype=torch.bool, + device=self.device, + ) def _refresh_cuda_graph_paged_cache_block_tables( self, @@ -1077,6 +1494,9 @@ def init_forward_metadata_capture_cuda_graph( query_lens=self._cuda_graph_query_lens[:bs], query_start_loc=self._cuda_graph_query_start_loc[: bs + 1], token_to_req_indices=self._cuda_graph_token_to_req[:bs], + is_valid_token=self._cuda_graph_is_valid_token[:bs], + seq_lens_cpu=None, + query_lens_cpu=None, forward_mode=forward_mode, 
paged_cache_block_tables=metadata_paged, paged_cache_block_table_base_offsets=metadata_base_offsets, @@ -1103,6 +1523,7 @@ def init_forward_metadata_replay_cuda_graph( paged_cache_block_table_base_offsets = ( kwargs.pop("paged_cache_block_table_base_offsets", None) or {} ) + actual_bs = max(0, min(int(kwargs.pop("actual_bs", bs)), bs)) del kwargs if forward_mode is not None and not forward_mode.is_decode_or_idle(): raise NotImplementedError( @@ -1118,6 +1539,9 @@ def init_forward_metadata_replay_cuda_graph( self._cuda_graph_token_to_req[:bs].copy_( torch.arange(bs, dtype=torch.int32, device=self.device) ) + self._cuda_graph_is_valid_token[:actual_bs].fill_(True) + if actual_bs < bs: + self._cuda_graph_is_valid_token[actual_bs:bs].fill_(False) if req_to_page is not None: self._cuda_graph_block_table[:bs, : self.max_num_pages].copy_( req_to_page[req_pool_indices[:bs], : self.max_num_pages] @@ -1159,6 +1583,8 @@ def init_forward_metadata_replay_cuda_graph( metadata.compressor_state_base_logical_pages = compressor_state_base metadata.indexer_state_block_table = indexer_state_block_table metadata.indexer_state_base_logical_page = indexer_state_base + metadata.num_prefill_reqs = 0 + metadata.num_prefill_tokens = 0 if ( forward_mode is not None and forward_mode.is_decode() @@ -1171,6 +1597,11 @@ def init_forward_metadata_replay_cuda_graph( block_size=self._decode_swa_block_size, ) metadata.refresh_decode_compressed_slot_mappings() + _refresh_decode_indexer_plan_cache( + metadata, + max_context_len=self.context_len, + ) + _refresh_decode_indexer_schedule_metadata(metadata) self.forward_metadata = metadata def advance_draft_forward_metadata(self): diff --git a/python/tokenspeed/runtime/layers/attention/deepseek_v4_ops.py b/python/tokenspeed/runtime/layers/attention/deepseek_v4_ops.py index e14ce6d04..6b057bb9b 100644 --- a/python/tokenspeed/runtime/layers/attention/deepseek_v4_ops.py +++ b/python/tokenspeed/runtime/layers/attention/deepseek_v4_ops.py @@ -2203,11 +2203,18 @@ 
def _deepseek_v4_compute_global_topk_indices_and_lens_kernel( token_to_req_indices_ptr, block_table_ptr, block_table_stride, + is_valid_token_ptr, + has_valid_token: tl.constexpr, block_size: tl.constexpr, topk: tl.constexpr, TRITON_BLOCK_SIZE: tl.constexpr, ): token_idx = tl.program_id(0) + if has_valid_token: + is_valid_token = tl.load(is_valid_token_ptr + token_idx) + if not is_valid_token: + tl.store(topk_lens_ptr + token_idx, 0) + return req_idx = tl.load(token_to_req_indices_ptr + token_idx) count = tl.zeros((), dtype=tl.int32) @@ -2245,6 +2252,7 @@ def deepseek_v4_compute_global_topk_indices_and_lens( token_to_req_indices: torch.Tensor, block_table: torch.Tensor, block_size: int, + is_valid_token: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Map local CSA top-k indices to global KV slots in one Triton kernel.""" @@ -2257,13 +2265,31 @@ def deepseek_v4_compute_global_topk_indices_and_lens( topk_lens = torch.empty(num_tokens, dtype=torch.int32, device=topk_indices.device) if num_tokens == 0: return global_topk_indices, topk_lens + if is_valid_token is not None: + is_valid_token = is_valid_token[:num_tokens].to( + device=topk_indices.device, + dtype=torch.bool, + ) if not topk_indices.is_cuda: valid = topk_indices >= 0 + if is_valid_token is not None: + valid = valid & is_valid_token[:, None] req_idx = token_to_req_indices[:num_tokens].to(torch.int64) + rows = int(block_table.shape[0]) if block_table.dim() >= 1 else 0 + cols = int(block_table.shape[1]) if block_table.dim() >= 2 else 0 + if rows <= 0 or cols <= 0: + global_topk_indices.fill_(-1) + topk_lens.zero_() + return global_topk_indices, topk_lens safe_local = torch.where(valid, topk_indices, torch.zeros_like(topk_indices)) block_indices = torch.div(safe_local, block_size, rounding_mode="floor") block_offsets = safe_local % block_size - block_numbers = block_table[req_idx[:, None], block_indices.long()] + req_valid = (req_idx >= 0) & (req_idx < rows) + block_valid = 
(block_indices >= 0) & (block_indices < cols) + valid = valid & req_valid[:, None] & block_valid + safe_req = req_idx.clamp(0, rows - 1) + safe_block = block_indices.long().clamp(0, cols - 1) + block_numbers = block_table[safe_req[:, None], safe_block] global_topk_indices.copy_( torch.where( valid, @@ -2273,6 +2299,8 @@ def deepseek_v4_compute_global_topk_indices_and_lens( ) topk_lens.copy_(valid.sum(dim=1, dtype=torch.int32)) return global_topk_indices, topk_lens + if is_valid_token is None: + is_valid_token = torch.empty(0, dtype=torch.bool, device=topk_indices.device) _deepseek_v4_compute_global_topk_indices_and_lens_kernel[(num_tokens,)]( global_topk_indices, @@ -2283,6 +2311,8 @@ def deepseek_v4_compute_global_topk_indices_and_lens( token_to_req_indices.to(torch.int32), block_table.to(torch.int32), block_table.stride(0), + is_valid_token, + is_valid_token.numel() != 0, block_size=block_size, topk=topk_indices.shape[-1], TRITON_BLOCK_SIZE=1024, @@ -2591,15 +2621,22 @@ def _deepseek_v4_decode_swa_indices_and_lens_kernel( query_start_loc_ptr, seq_lens_ptr, token_to_req_indices_ptr, + is_valid_token_ptr, block_table_ptr, block_table_base_offsets_ptr, block_table_stride, max_blocks_per_seq: tl.constexpr, + has_valid_token: tl.constexpr, window_size: tl.constexpr, block_size: tl.constexpr, candidate_block: tl.constexpr, ): token_idx = tl.program_id(0) + if has_valid_token: + is_valid = tl.load(is_valid_token_ptr + token_idx) + if not is_valid: + tl.store(swa_lens_ptr + token_idx, 0) + return req_idx = tl.load(token_to_req_indices_ptr + token_idx).to(tl.int32) query_start = tl.load(query_start_loc_ptr + req_idx).to(tl.int32) @@ -2647,6 +2684,7 @@ def deepseek_v4_decode_swa_indices_and_lens( window_size: int, block_size: int, block_table_base_offsets: torch.Tensor | None = None, + is_valid_token: torch.Tensor | None = None, out_indices: torch.Tensor | None = None, out_lens: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -2663,6 +2701,13 @@ def 
deepseek_v4_decode_swa_indices_and_lens( out_lens = torch.empty(num_tokens, dtype=torch.int32, device=seq_lens.device) if num_tokens == 0: return out_indices, out_lens + if is_valid_token is None: + is_valid_token = torch.empty(0, dtype=torch.bool, device=seq_lens.device) + else: + is_valid_token = is_valid_token[:num_tokens].to( + device=seq_lens.device, + dtype=torch.bool, + ) candidate_block = min(1024, triton.next_power_of_2(window_size)) _deepseek_v4_decode_swa_indices_and_lens_kernel[(num_tokens,)]( @@ -2672,6 +2717,7 @@ def deepseek_v4_decode_swa_indices_and_lens( query_start_loc.to(torch.int32), seq_lens.to(torch.int32), token_to_req_indices.to(torch.int32), + is_valid_token, block_table.to(torch.int32), ( block_table_base_offsets.to(torch.int32) @@ -2680,6 +2726,7 @@ def deepseek_v4_decode_swa_indices_and_lens( ), block_table.stride(0), block_table.shape[-1], + is_valid_token.numel() != 0, window_size=window_size, block_size=block_size, candidate_block=candidate_block, diff --git a/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py b/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py index bcec744af..bef078ef1 100644 --- a/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py +++ b/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py @@ -234,6 +234,15 @@ class DeepseekV4ForwardMetadata: query_lens: torch.Tensor query_start_loc: torch.Tensor token_to_req_indices: torch.Tensor + # Padding mask for CUDA graph replay rows; this is not mixed-batch state. + is_valid_token: Optional[torch.Tensor] = None + # CPU lens are retained for sparse prefill/indexer planning without + # forcing another device-to-host sync in the model path. + seq_lens_cpu: Optional[torch.Tensor] = None + query_lens_cpu: Optional[torch.Tensor] = None + # Cached split boundary derived from scheduler num_extends/query_lens. 
def decode_req_count(self) -> int:
    """Return the number of requests left after the prefill requests.

    Computed as the length of ``req_pool_indices`` minus
    ``num_prefill_reqs``; never negative.
    """
    total_reqs = int(self.req_pool_indices.shape[0])
    remaining = total_reqs - int(self.num_prefill_reqs)
    return remaining if remaining > 0 else 0


def decode_token_count(self) -> int:
    """Return the number of tokens left after the prefill tokens.

    Computed as the length of ``token_to_req_indices`` minus
    ``num_prefill_tokens``; never negative.
    """
    total_tokens = int(self.token_to_req_indices.shape[0])
    remaining = total_tokens - int(self.num_prefill_tokens)
    return remaining if remaining > 0 else 0
+from tokenspeed_kernel.ops.attention.triton.deepseek_v4 import ( + deepseek_v4_indexer_decode_metadata_compute, +) from tokenspeed_kernel.ops.gemm.fp8_utils import per_token_group_quant_fp8 +from tokenspeed_kernel.ops.moe.triton import ( + stage_deepseek_v4_mega_moe_inputs as _stage_deepseek_v4_mega_moe_inputs, +) from tokenspeed_kernel.ops.routing.cuda import dsv3_router_gemm from tokenspeed_kernel.platform import current_platform from tokenspeed_kernel.thirdparty.cuda import ( @@ -122,6 +126,7 @@ get_colorful_logger, set_weight_attrs, ) +from tokenspeed.runtime.utils.custom_ops import direct_register_custom_op from tokenspeed.runtime.utils.env import global_server_args_dict, pdl_enabled _platform = current_platform() @@ -717,27 +722,186 @@ def _deepseek_v4_indexer_topk_from_cache_batched( return topk +@dataclass(frozen=True) +class _DeepseekV4IndexerPrefillChunk: + token_start: int + token_end: int + req_start: int + req_end: int + query_start: int + query_end: int + skip_kv_gather: bool = False + + +@dataclass(frozen=True) +class _DeepseekV4IndexerPrefillMetadata: + chunk_bounds: torch.Tensor + chunk_plan: torch.Tensor + slots: torch.Tensor + cu_seq_lens: torch.Tensor + cu_start: torch.Tensor + cu_end: torch.Tensor + row_lens: torch.Tensor + + +@dataclass +class _DeepseekV4IndexerDecodeMetadata: + context_lens: torch.Tensor + block_table: torch.Tensor + max_context_len: int + + +def _deepseek_v4_indexer_prefill_max_logits_bytes( + max_logits_bytes: Optional[int] = None, +) -> int: + if max_logits_bytes is not None: + return max(1, int(max_logits_bytes)) + max_logits_mb = global_server_args_dict.get( + "deepseek_v4_indexer_prefill_max_logits_mb", + _DEEPSEEK_V4_INDEXER_PREFILL_MAX_LOGITS_MB, + ) + return max(1, int(max_logits_mb) * 1024 * 1024) + + +def _deepseek_v4_indexer_prefill_workspace_size( + seq_lens_cpu: torch.Tensor, + workspace_size: Optional[int] = None, +) -> int: + if workspace_size is not None: + return max(1, int(workspace_size)) + context_len = 
global_server_args_dict.get("context_length") + if isinstance(context_len, int) and context_len > 0: + return context_len * 40 + max_seq_len = int(seq_lens_cpu.max().item()) if seq_lens_cpu.numel() else 1 + return max(1, max_seq_len) * 40 + + +def _deepseek_v4_indexer_prefill_request_chunks( + *, + seq_lens_cpu: torch.Tensor, + query_lens_cpu: torch.Tensor, + compress_ratio: int, + num_tokens: int, + max_logits_bytes: Optional[int] = None, + workspace_size: Optional[int] = None, + request_offset: int = 0, +) -> list[_DeepseekV4IndexerPrefillChunk]: + """Build request/query-slice sparse-indexer prefill chunks.""" + + if num_tokens == 0: + return [] + + seq_lens = seq_lens_cpu.detach().cpu().to(torch.int64) + query_lens = query_lens_cpu.detach().cpu().to(torch.int64) + if seq_lens.numel() != query_lens.numel(): + return [] + + query_lens_list = [max(0, int(x)) for x in query_lens.tolist()] + if sum(query_lens_list) != num_tokens: + return [] + + compressed_seq_lens = torch.div( + seq_lens, + max(1, int(compress_ratio)), + rounding_mode="floor", + ) + compressed_seq_lens_list = [max(0, int(x)) for x in compressed_seq_lens.tolist()] + workspace_rows = _deepseek_v4_indexer_prefill_workspace_size( + seq_lens, + workspace_size, + ) + max_logits_elems = ( + _deepseek_v4_indexer_prefill_max_logits_bytes(max_logits_bytes) // 4 + ) + max_logits_elems = max(1, max_logits_elems) + + query_offsets = [0] + for query_len in query_lens_list: + query_offsets.append(query_offsets[-1] + query_len) + + chunks: list[_DeepseekV4IndexerPrefillChunk] = [] + n_reqs = len(query_lens_list) + end = 0 + while end < n_reqs: + start = end + chunk_m = 0 + chunk_n = 0 + while end < n_reqs: + q_len = query_lens_list[end] + seq_len = compressed_seq_lens_list[end] + new_m = chunk_m + q_len + new_n = chunk_n + seq_len + if new_n <= workspace_rows and new_m * new_n <= max_logits_elems: + chunk_m = new_m + chunk_n = new_n + end += 1 + else: + break + + if end == start: + chunk_m = query_lens_list[end] + 
chunk_n = compressed_seq_lens_list[end] + end += 1 + + if chunk_m <= 0: + continue + + req_start = start + request_offset + req_end = end + request_offset + max_q = max(1, max_logits_elems // chunk_n) if chunk_n > 0 else chunk_m + chunk_token_start = query_offsets[start] + for query_start in range(0, chunk_m, max_q): + query_end = min(query_start + max_q, chunk_m) + chunks.append( + _DeepseekV4IndexerPrefillChunk( + token_start=chunk_token_start + query_start, + token_end=chunk_token_start + query_end, + req_start=req_start, + req_end=req_end, + query_start=query_start, + query_end=query_end, + skip_kv_gather=query_start > 0, + ) + ) + return chunks + + def _deepseek_v4_indexer_prefill_topk_chunks( positions: torch.Tensor, compress_ratio: int, max_logits_bytes: int | None = None, + *, + seq_lens_cpu: Optional[torch.Tensor] = None, + query_lens_cpu: Optional[torch.Tensor] = None, ) -> list[tuple[int, int]]: num_tokens = positions.numel() if num_tokens == 0: return [] - if max_logits_bytes is None: - max_logits_mb = global_server_args_dict.get( - "deepseek_v4_indexer_prefill_max_logits_mb", - _DEEPSEEK_V4_INDEXER_PREFILL_MAX_LOGITS_MB, - ) - max_logits_bytes = max_logits_mb * 1024 * 1024 - max_logits_elems = max(1, int(max_logits_bytes) // 4) - compressed_lens = torch.div( - positions.to(torch.int64) + 1, - compress_ratio, - rounding_mode="floor", - ).clamp_min(0) - lengths = compressed_lens.detach().cpu().tolist() + max_logits_elems = max( + 1, + _deepseek_v4_indexer_prefill_max_logits_bytes(max_logits_bytes) // 4, + ) + lengths: Optional[list[int]] = None + if seq_lens_cpu is not None and query_lens_cpu is not None: + seq_lens_list = seq_lens_cpu.detach().cpu().tolist() + query_lens_list = query_lens_cpu.detach().cpu().tolist() + cpu_lengths: list[int] = [] + for seq_len, query_len in zip(seq_lens_list, query_lens_list): + total_len = int(seq_len) + query_len = max(0, int(query_len)) + prefix_len = max(0, total_len - query_len) + for query_offset in 
def _deepseek_v4_gather_paged_indexer_mxfp4_cache_available() -> bool:
    """Report whether the paged MXFP4 indexer gather kernel can be used.

    The probe runs once: the result is stored in the module-level
    ``_DEEPSEEK_V4_PAGED_GATHER_AVAILABLE`` flag and later calls return it
    without re-importing. A failing import of the kernel module is treated as
    "not available".
    """
    global _DEEPSEEK_V4_PAGED_GATHER_CHECKED
    global _DEEPSEEK_V4_PAGED_GATHER_AVAILABLE
    if not _DEEPSEEK_V4_PAGED_GATHER_CHECKED:
        try:
            from tokenspeed_kernel.thirdparty.cuda.deepseek_v4_attention import (
                has_indexer_mxfp4_paged_gather,
            )
        except Exception:
            probe_result = False
        else:
            # Import succeeded; ask the extension whether the op is built in.
            probe_result = bool(has_indexer_mxfp4_paged_gather())
        _DEEPSEEK_V4_PAGED_GATHER_AVAILABLE = probe_result
        _DEEPSEEK_V4_PAGED_GATHER_CHECKED = True
    return _DEEPSEEK_V4_PAGED_GATHER_AVAILABLE
int(out[0].shape[0]) + values = out[0][:total_rows] + scales = out[1][:total_rows] + if total_rows == 0: + return values.view(torch.int8), scales.view(torch.int32).squeeze(-1) + + if ( + cache_2d.is_cuda + and block_table.is_cuda + and cu_seq_lens.is_cuda + and _deepseek_v4_gather_paged_indexer_mxfp4_cache_available() + ): + from tokenspeed_kernel.thirdparty.cuda.deepseek_v4_attention import ( + indexer_mxfp4_paged_gather, + ) + + indexer_mxfp4_paged_gather( + kv_cache=cache_2d, + values_out=values, + scales_out=scales, + block_table=block_table, + cu_seq_lens=cu_seq_lens, + cache_block_size=block_size, + ) + return values.view(torch.int8), scales.view(torch.int32).squeeze(-1) + + exact_rows = int(cu_seq_lens[-1].item()) if cu_seq_lens.numel() else 0 + if exact_rows <= 0: + return values.view(torch.int8), scales.view(torch.int32).squeeze(-1) + + req_lens = torch.diff(cu_seq_lens.to(torch.int64)) + req_ids = torch.repeat_interleave( + torch.arange(req_lens.numel(), device=cache_2d.device, dtype=torch.int64), + req_lens.to(device=cache_2d.device), + output_size=exact_rows, + ) + cu_seq_lens_device = cu_seq_lens.to(device=cache_2d.device, dtype=torch.int64) + local = torch.arange(exact_rows, device=cache_2d.device, dtype=torch.int64) + local = local - cu_seq_lens_device[:-1][req_ids] + pages = torch.div(local, block_size, rounding_mode="floor") + page_offsets = local % block_size + block_table_device = block_table.to(device=cache_2d.device, dtype=torch.int64) + slots = block_table_device[req_ids, pages] * block_size + page_offsets + _deepseek_v4_gather_indexer_mxfp4_cache( + cache_2d, + slots, + block_size, + out=(values[:exact_rows], scales[:exact_rows]), + ) + return values.view(torch.int8), scales.view(torch.int32).squeeze(-1) + + +def _deepseek_v4_indexer_prefill_gather_plan( + *, + positions: torch.Tensor, + token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + cache_block_size: int, + compress_ratio: int, +) -> tuple[torch.Tensor, torch.Tensor, 
torch.Tensor, torch.Tensor, int]: + num_tokens = positions.numel() + device = positions.device + compressed_lens = torch.div( + positions.to(torch.int64) + 1, + compress_ratio, + rounding_mode="floor", + ).clamp_min(0) + if num_tokens == 0: + empty_i32 = torch.empty(0, dtype=torch.int32, device=device) + empty_i64 = torch.empty(0, dtype=torch.int64, device=device) + return empty_i64, empty_i32, empty_i32, empty_i32, 0 + + req_idx = token_to_req_indices[:num_tokens].to(torch.int64) + new_group = torch.ones(num_tokens, dtype=torch.bool, device=device) + if num_tokens > 1: + new_group[1:] = req_idx[1:] != req_idx[:-1] + group_starts = torch.nonzero(new_group, as_tuple=False).flatten() + group_ends = torch.empty_like(group_starts) + group_ends[:-1] = group_starts[1:] + group_ends[-1] = num_tokens + group_lengths = group_ends - group_starts + group_max_lens = compressed_lens[group_ends - 1].to(torch.int32) + + cu_seq_lens = torch.empty( + group_starts.numel() + 1, + dtype=torch.int32, + device=device, + ) + cu_seq_lens[:1] = 0 + torch.cumsum(group_max_lens, dim=0, out=cu_seq_lens[1:]) + total_k = int(cu_seq_lens[-1].item()) + row_lens = compressed_lens.to(torch.int32) + + group_for_token = torch.repeat_interleave( + torch.arange(group_starts.numel(), device=device, dtype=torch.int64), + group_lengths.to(torch.int64), + output_size=num_tokens, + ) + cu_start = cu_seq_lens[:-1][group_for_token] + cu_end = cu_start + row_lens + max_len = int(group_max_lens.max().item()) if group_max_lens.numel() else 0 + if total_k <= 0: + empty_i64 = torch.empty(0, dtype=torch.int64, device=device) + return empty_i64, cu_start, cu_end, row_lens, max_len + + group_ids = torch.repeat_interleave( + torch.arange(group_starts.numel(), device=device, dtype=torch.int64), + group_max_lens.to(torch.int64), + output_size=total_k, + ) + group_bases = cu_seq_lens[:-1][group_ids].to(torch.int64) + local = torch.arange(total_k, device=device, dtype=torch.int64) - group_bases + req_for_k = 
req_idx[group_starts][group_ids] + pages = torch.div(local, cache_block_size, rounding_mode="floor") + page_offsets = local % cache_block_size + page_ids = block_table[req_for_k, pages.long()].to(torch.int64) + slots = page_ids * cache_block_size + page_offsets + return slots, cu_start, cu_end, row_lens, max_len + + +def _deepseek_v4_indexer_prefill_request_gather_plan( + *, + seq_lens_cpu: torch.Tensor, + query_lens_cpu: torch.Tensor, + block_table: torch.Tensor, + cache_block_size: int, + compress_ratio: int, + req_start: int, + req_end: int, + query_start: int, + query_end: int, + build_slots: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]: + device = block_table.device + num_rows = max(0, int(query_end) - int(query_start)) + if num_rows == 0 or req_end <= req_start: + empty_i32 = torch.empty(0, dtype=torch.int32, device=device) + empty_i64 = torch.empty(0, dtype=torch.int64, device=device) + return empty_i64, empty_i32, empty_i32, empty_i32, 0 + + seq_lens_list = ( + seq_lens_cpu.detach().cpu().to(torch.int64)[req_start:req_end].tolist() + ) + query_lens_list = ( + query_lens_cpu.detach().cpu().to(torch.int64)[req_start:req_end].tolist() + ) + if len(seq_lens_list) != len(query_lens_list): + empty_i32 = torch.empty(0, dtype=torch.int32, device=device) + empty_i64 = torch.empty(0, dtype=torch.int64, device=device) + return empty_i64, empty_i32, empty_i32, empty_i32, 0 + + ratio = max(1, int(compress_ratio)) + seq_lens_list = [max(0, int(x)) for x in seq_lens_list] + query_lens_list = [max(0, int(x)) for x in query_lens_list] + compressed_lens_list = [seq_len // ratio for seq_len in seq_lens_list] + total_k = sum(compressed_lens_list) + + query_offsets: list[int] = [0] + for query_len in query_lens_list: + query_offsets.append(query_offsets[-1] + query_len) + + req_local_list: list[int] = [] + row_lens_list: list[int] = [] + req_local = 0 + last_req = max(0, len(query_lens_list) - 1) + for row_offset in 
range(int(query_start), int(query_end)): + while req_local < last_req and row_offset >= query_offsets[req_local + 1]: + req_local += 1 + local_query_offset = row_offset - query_offsets[req_local] + prefix_len = max(0, seq_lens_list[req_local] - query_lens_list[req_local]) + row_lens_list.append((prefix_len + local_query_offset + 1) // ratio) + req_local_list.append(req_local) + max_len = max(row_lens_list) if row_lens_list else 0 + + compressed_lens = torch.tensor( + compressed_lens_list, + dtype=torch.int64, + device=device, + ) + + cu_seq_lens = torch.empty( + compressed_lens.numel() + 1, + dtype=torch.int32, + device=device, + ) + cu_seq_lens[:1] = 0 + torch.cumsum(compressed_lens.to(torch.int32), dim=0, out=cu_seq_lens[1:]) + + req_local_tensor = torch.tensor(req_local_list, dtype=torch.int64, device=device) + row_lens = torch.tensor(row_lens_list, dtype=torch.int32, device=device) + cu_start = cu_seq_lens[:-1][req_local_tensor] + cu_end = cu_start + row_lens + + if total_k <= 0 or not build_slots: + empty_i64 = torch.empty(0, dtype=torch.int64, device=device) + return empty_i64, cu_start, cu_end, row_lens, max_len + + req_ids = torch.repeat_interleave( + torch.arange(req_start, req_end, device=device, dtype=torch.int64), + compressed_lens, + output_size=total_k, + ) + req_local_for_k = req_ids - int(req_start) + group_bases = cu_seq_lens[:-1][req_local_for_k].to(torch.int64) + local = torch.arange(total_k, device=device, dtype=torch.int64) - group_bases + pages = torch.div(local, cache_block_size, rounding_mode="floor") + page_offsets = local % cache_block_size + page_ids = block_table[req_ids, pages.long()].to(torch.int64) + slots = page_ids * cache_block_size + page_offsets + return slots, cu_start, cu_end, row_lens, max_len + + +def _deepseek_v4_indexer_prefill_chunk_total_rows( + *, + seq_lens_cpu: torch.Tensor, + compress_ratio: int, + req_start: int, + req_end: int, +) -> int: + ratio = max(1, int(compress_ratio)) + seq_lens = 
seq_lens_cpu.detach().cpu().to(torch.int64)[req_start:req_end].tolist() + return sum(max(0, int(seq_len)) // ratio for seq_len in seq_lens) + + +def _deepseek_v4_empty_indexer_prefill_metadata( + device: torch.device, +) -> _DeepseekV4IndexerPrefillMetadata: + return _DeepseekV4IndexerPrefillMetadata( + chunk_bounds=torch.empty((0, 7), dtype=torch.int64, device="cpu"), + chunk_plan=torch.empty((0, 7), dtype=torch.int64, device="cpu"), + slots=torch.empty(0, dtype=torch.int64, device=device), + cu_seq_lens=torch.empty(0, dtype=torch.int32, device=device), + cu_start=torch.empty(0, dtype=torch.int32, device=device), + cu_end=torch.empty(0, dtype=torch.int32, device=device), + row_lens=torch.empty(0, dtype=torch.int32, device=device), + ) + + +def _deepseek_v4_indexer_prefill_metadata( + *, + metadata: Any, + block_table: torch.Tensor, + cache_block_size: int, + compress_ratio: int, + num_prefill_tokens: int, +) -> _DeepseekV4IndexerPrefillMetadata: + device = block_table.device + if num_prefill_tokens <= 0: + return _deepseek_v4_empty_indexer_prefill_metadata(device) + + seq_lens_cpu = getattr(metadata, "seq_lens_cpu", None) + query_lens_cpu = getattr(metadata, "query_lens_cpu", None) + num_prefill_reqs = int(getattr(metadata, "num_prefill_reqs", 0) or 0) + if seq_lens_cpu is None or query_lens_cpu is None or num_prefill_reqs <= 0: + return _deepseek_v4_empty_indexer_prefill_metadata(device) + + seq_lens_cpu = seq_lens_cpu[:num_prefill_reqs] + query_lens_cpu = query_lens_cpu[:num_prefill_reqs] + cache_key = (compress_ratio, cache_block_size, num_prefill_tokens) + cache = getattr(metadata, "prefill_indexer_plan_cache", None) + cached = cache.get(cache_key) if cache is not None else None + if cached is not None and cached.slots.device == device: + return cached + + chunks = _deepseek_v4_indexer_prefill_request_chunks( + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + compress_ratio=compress_ratio, + num_tokens=num_prefill_tokens, + ) + if not chunks: + 
out = _deepseek_v4_empty_indexer_prefill_metadata(device) + if cache is not None: + cache[cache_key] = out + return out + + chunk_bounds_rows: list[list[int]] = [] + chunk_plan_rows: list[list[int]] = [] + slot_parts: list[torch.Tensor] = [] + cu_seq_lens_parts: list[torch.Tensor] = [] + cu_start_parts: list[torch.Tensor] = [] + cu_end_parts: list[torch.Tensor] = [] + row_lens_parts: list[torch.Tensor] = [] + slot_offset = 0 + cu_seq_offset = 0 + row_offset = 0 + for chunk in chunks: + slots, cu_start, cu_end, row_lens, max_len = ( + _deepseek_v4_indexer_prefill_request_gather_plan( + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + block_table=block_table, + cache_block_size=cache_block_size, + compress_ratio=compress_ratio, + req_start=chunk.req_start, + req_end=chunk.req_end, + query_start=chunk.query_start, + query_end=chunk.query_end, + build_slots=False, + ) + ) + slot_count = _deepseek_v4_indexer_prefill_chunk_total_rows( + seq_lens_cpu=seq_lens_cpu, + compress_ratio=compress_ratio, + req_start=chunk.req_start, + req_end=chunk.req_end, + ) + compressed_lens = torch.div( + seq_lens_cpu[chunk.req_start : chunk.req_end].to( + dtype=torch.int32, + device=device, + ), + max(1, int(compress_ratio)), + rounding_mode="floor", + ) + cu_seq_lens = torch.empty( + compressed_lens.numel() + 1, + dtype=torch.int32, + device=device, + ) + cu_seq_lens[:1] = 0 + torch.cumsum(compressed_lens, dim=0, out=cu_seq_lens[1:]) + slot_end = slot_offset + slot_count + cu_seq_end = cu_seq_offset + cu_seq_lens.numel() + row_end = row_offset + row_lens.numel() + chunk_bounds_rows.append( + [ + chunk.token_start, + chunk.token_end, + chunk.req_start, + chunk.req_end, + chunk.query_start, + chunk.query_end, + 1 if chunk.skip_kv_gather else 0, + ] + ) + chunk_plan_rows.append( + [ + slot_offset, + slot_end, + row_offset, + row_end, + max_len, + cu_seq_offset, + cu_seq_end, + ] + ) + if slots.numel() > 0: + slot_parts.append(slots) + cu_seq_lens_parts.append(cu_seq_lens) + 
cu_start_parts.append(cu_start) + cu_end_parts.append(cu_end) + row_lens_parts.append(row_lens) + slot_offset = slot_end + cu_seq_offset = cu_seq_end + row_offset = row_end + + out = _DeepseekV4IndexerPrefillMetadata( + chunk_bounds=torch.tensor(chunk_bounds_rows, dtype=torch.int64, device="cpu"), + chunk_plan=torch.tensor(chunk_plan_rows, dtype=torch.int64, device="cpu"), + slots=( + torch.cat(slot_parts, dim=0) + if slot_parts + else torch.empty(0, dtype=torch.int64, device=device) + ), + cu_seq_lens=( + torch.cat(cu_seq_lens_parts, dim=0) + if cu_seq_lens_parts + else torch.empty(0, dtype=torch.int32, device=device) + ), + cu_start=( + torch.cat(cu_start_parts, dim=0) + if cu_start_parts + else torch.empty(0, dtype=torch.int32, device=device) + ), + cu_end=( + torch.cat(cu_end_parts, dim=0) + if cu_end_parts + else torch.empty(0, dtype=torch.int32, device=device) + ), + row_lens=( + torch.cat(row_lens_parts, dim=0) + if row_lens_parts + else torch.empty(0, dtype=torch.int32, device=device) + ), + ) + if cache is not None: + cache[cache_key] = out + return out + + def _deepseek_v4_indexer_topk_from_logits( logits: torch.Tensor, lengths: torch.Tensor, topk_tokens: int, *, + next_n: int = 1, preserve_topk_order: bool = False, - out: torch.Tensor | None = None, + sort_preserved_topk: Optional[bool] = None, + row_starts: Optional[torch.Tensor] = None, + row_ends: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, ) -> torch.Tensor: - num_tokens = lengths.numel() + lengths_for_kernel = lengths.to(torch.int32).contiguous() + length_rows = lengths_for_kernel.reshape(-1) + num_tokens = length_rows.numel() if out is None: topk = torch.empty( (num_tokens, topk_tokens), @@ -908,11 +1513,33 @@ def _deepseek_v4_indexer_topk_from_logits( if max_len <= 0: return topk + row_starts_for_kernel: Optional[torch.Tensor] = None + row_ends_for_kernel: Optional[torch.Tensor] = None + if row_starts is not None or row_ends is not None: + if row_starts is None: + 
row_starts_for_kernel = torch.zeros_like(length_rows) + else: + row_starts_for_kernel = row_starts.to( + device=logits.device, dtype=torch.int32 + ).reshape(-1) + if row_ends is None: + row_ends_for_kernel = row_starts_for_kernel + length_rows + else: + row_ends_for_kernel = row_ends.to( + device=logits.device, dtype=torch.int32 + ).reshape(-1) + length_rows = (row_ends_for_kernel - row_starts_for_kernel).clamp_min(0) + + if sort_preserved_topk is None: + sort_preserved_topk = False + if preserve_topk_order: prefill_topk = _deepseek_v4_indexer_topk_from_logits_prefill_op( logits, - lengths.to(torch.int32).reshape(-1), + length_rows, topk_tokens, + row_starts=row_starts_for_kernel, + row_ends=row_ends_for_kernel, out=topk, ) if prefill_topk is not None: @@ -923,29 +1550,55 @@ def _deepseek_v4_indexer_topk_from_logits( fast_topk_v2( logits.contiguous(), - lengths.to(torch.int32).contiguous(), + lengths_for_kernel, topk, topk_tokens, + next_n, ) return topk offsets = torch.arange(max_len, device=logits.device, dtype=torch.int64) + if row_starts_for_kernel is not None and row_ends_for_kernel is not None: + row_starts_i64 = row_starts_for_kernel.to(torch.int64) + row_ends_i64 = row_ends_for_kernel.to(torch.int64) + valid = (offsets[None, :] >= row_starts_i64[:, None]) & ( + offsets[None, :] < row_ends_i64[:, None] + ) + masked_logits = logits.masked_fill(~valid, -float("inf")) + selected = min(int(length_rows.max().item()), topk_tokens) + if selected <= 0: + return topk + values, indices = torch.topk( + masked_logits, + k=selected, + dim=-1, + sorted=bool(sort_preserved_topk), + ) + indices = indices - row_starts_i64[:, None] + indices = torch.where( + torch.isfinite(values), + indices, + torch.full_like(indices, -1), + ).to(torch.int32) + topk[:, :selected] = indices + return topk + masked_logits = logits.masked_fill( - offsets[None, :] >= lengths[:, None], -float("inf") + offsets[None, :] >= length_rows[:, None], -float("inf") ) if preserve_topk_order: - for raw_len 
def _deepseek_v4_prefill_topk_op_available() -> bool:
    """Report whether ``torch.ops.trtllm.indexer_topk_prefill`` is registered.

    The probe runs once: the answer is stored in the module-level
    ``_DEEPSEEK_V4_PREFILL_TOPK_OP_AVAILABLE`` flag and reused by later calls.
    A failing import of the TRT-LLM wrapper package is treated as
    "not available".
    """
    global _DEEPSEEK_V4_PREFILL_TOPK_OP_CHECKED
    global _DEEPSEEK_V4_PREFILL_TOPK_OP_AVAILABLE
    if not _DEEPSEEK_V4_PREFILL_TOPK_OP_CHECKED:
        try:
            # Imported for its side effect; the op is looked up on torch.ops.
            import tokenspeed_kernel.thirdparty.trtllm  # noqa: F401
        except Exception:
            _DEEPSEEK_V4_PREFILL_TOPK_OP_AVAILABLE = False
        else:
            op_namespace = getattr(torch.ops, "trtllm", None)
            if op_namespace is None:
                _DEEPSEEK_V4_PREFILL_TOPK_OP_AVAILABLE = False
            else:
                _DEEPSEEK_V4_PREFILL_TOPK_OP_AVAILABLE = hasattr(
                    op_namespace,
                    "indexer_topk_prefill",
                )
        _DEEPSEEK_V4_PREFILL_TOPK_OP_CHECKED = True
    return _DEEPSEEK_V4_PREFILL_TOPK_OP_AVAILABLE
def _deepseek_v4_indexer_topk_from_logits_prefill_op( if num_rows == 0: return out[:0] logits = logits.contiguous() - row_starts = torch.zeros(num_rows, device=logits.device, dtype=torch.int32) - row_ends = length_rows.to(device=logits.device, dtype=torch.int32).reshape(-1) + if row_starts is None: + row_starts_for_kernel = torch.zeros( + num_rows, + device=logits.device, + dtype=torch.int32, + ) + else: + row_starts_for_kernel = ( + row_starts.to( + device=logits.device, + dtype=torch.int32, + ) + .reshape(-1) + .contiguous() + ) + if row_ends is None: + row_ends_for_kernel = ( + row_starts_for_kernel + + length_rows.to(device=logits.device, dtype=torch.int32).reshape(-1) + ).contiguous() + else: + row_ends_for_kernel = ( + row_ends.to( + device=logits.device, + dtype=torch.int32, + ) + .reshape(-1) + .contiguous() + ) topk = out[:num_rows] topk.fill_(-1) - from tokenspeed_kernel.thirdparty.cuda.deepseek_v4_attention import ( - indexer_topk_prefill, - ) - - indexer_topk_prefill( + torch.ops.trtllm.indexer_topk_prefill( logits, - row_starts, - row_ends.contiguous(), + row_starts_for_kernel, + row_ends_for_kernel, topk, topk_tokens, ) return topk -def _deepseek_v4_indexer_ascending_prefill_topk( - positions: torch.Tensor, - compress_ratio: int, - topk_tokens: int, -) -> torch.Tensor: - num_tokens = positions.numel() - offsets = torch.arange(topk_tokens, device=positions.device, dtype=torch.int32) - lengths = torch.div( - positions.to(torch.int64) + 1, - compress_ratio, - rounding_mode="floor", - ).clamp(min=0, max=topk_tokens) - return torch.where( - offsets[None, :] < lengths[:, None], - offsets[None, :], - torch.full( - (num_tokens, topk_tokens), - -1, - device=positions.device, - dtype=torch.int32, - ), - ) - - def _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill( *, cache_2d: torch.Tensor, @@ -1068,12 +1720,15 @@ def _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill( device=positions.device, dtype=torch.int32, ) - compressed_lens = torch.div( - 
positions.to(torch.int64) + 1, - compress_ratio, - rounding_mode="floor", - ).clamp_min(0) - max_len = int(compressed_lens.max().item()) + slots, cu_start, cu_end, row_lens, max_len = ( + _deepseek_v4_indexer_prefill_gather_plan( + positions=positions, + token_to_req_indices=token_to_req_indices, + block_table=block_table, + cache_block_size=cache_block_size, + compress_ratio=compress_ratio, + ) + ) if max_len <= 0: return torch.full( (num_tokens, topk_tokens), @@ -1081,27 +1736,12 @@ def _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill( device=positions.device, dtype=torch.int32, ) - - offsets = torch.arange(max_len, device=positions.device, dtype=torch.int64) - local = offsets[None, :].expand(num_tokens, -1) - valid = local < compressed_lens[:, None] - req_idx = token_to_req_indices[:num_tokens].to(torch.int64) - pages = torch.div(local, cache_block_size, rounding_mode="floor") - page_offsets = local % cache_block_size - page_ids = block_table[req_idx[:, None], pages.long()].to(torch.int64) - slots = page_ids * cache_block_size + page_offsets - with deepseek_v4_profile_scope("indexer_topk_prefill_gather_mxfp4"): k_values, k_scales = _deepseek_v4_gather_indexer_mxfp4_cache( cache_2d, - slots[valid], + slots, cache_block_size, ) - row_lens = valid.sum(dim=1, dtype=torch.int32) - cu_end = torch.cumsum(row_lens, dim=0, dtype=torch.int32) - cu_start = torch.empty_like(cu_end) - cu_start[0] = 0 - cu_start[1:] = cu_end[:-1] try: with deepseek_v4_profile_scope("indexer_topk_prefill_deepgemm_logits"): @@ -1127,97 +1767,994 @@ def _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill( ) -def _deepseek_v4_indexer_topk_from_cache_deepgemm_decode( +def _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill_plan( + *, + cache_2d: torch.Tensor, + gather_plan: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int], + index_q: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cache_block_size: int, + topk_tokens: int, + preserve_topk_order: bool, + 
gathered_k: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + gather_workspace: Optional[tuple[torch.Tensor, torch.Tensor]] = None, +) -> tuple[Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: + q_values, q_scales = index_q + if not _deepseek_v4_deepgemm_fp4_indexer_available(q_values): + return None, gathered_k + + num_tokens = q_values.shape[0] + slots, cu_start, cu_end, row_lens, max_len = gather_plan + if num_tokens == 0: + return ( + torch.empty( + (0, topk_tokens), + device=q_values.device, + dtype=torch.int32, + ), + gathered_k, + ) + if max_len <= 0: + return ( + torch.full( + (num_tokens, topk_tokens), + -1, + device=q_values.device, + dtype=torch.int32, + ), + gathered_k, + ) + + if gathered_k is None: + with deepseek_v4_profile_scope("indexer_topk_prefill_gather_mxfp4"): + gathered_k = _deepseek_v4_gather_indexer_mxfp4_cache( + cache_2d, + slots, + cache_block_size, + out=gather_workspace, + ) + k_values, k_scales = gathered_k + + try: + with deepseek_v4_profile_scope("indexer_topk_prefill_deepgemm_logits"): + logits = deep_gemm.fp8_fp4_mqa_logits( + q=(q_values.contiguous().view(torch.int8), q_scales.contiguous()), + kv=(k_values.contiguous(), k_scales.contiguous()), + weights=weights.contiguous(), + cu_seq_len_k_start=cu_start, + cu_seq_len_k_end=cu_end, + clean_logits=False, + max_seqlen_k=max_len, + logits_dtype=torch.float32, + ) + except RuntimeError: + return None, gathered_k + + with deepseek_v4_profile_scope("indexer_topk_prefill_select"): + return ( + _deepseek_v4_indexer_topk_from_logits( + logits, + row_lens, + topk_tokens, + preserve_topk_order=preserve_topk_order, + ), + gathered_k, + ) + + +def _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill_contract( + *, + cache_2d: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + cu_start: torch.Tensor, + cu_end: torch.Tensor, + row_lens: torch.Tensor, + max_len: int, + index_q: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + 
cache_block_size: int, + topk_tokens: int, + preserve_topk_order: bool, + gathered_k: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + gather_workspace: Optional[tuple[torch.Tensor, torch.Tensor]] = None, +) -> tuple[Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: + q_values, q_scales = index_q + if not _deepseek_v4_deepgemm_fp4_indexer_available(q_values): + return None, gathered_k + + num_tokens = q_values.shape[0] + if num_tokens == 0: + return ( + torch.empty( + (0, topk_tokens), + device=q_values.device, + dtype=torch.int32, + ), + gathered_k, + ) + if max_len <= 0: + return ( + torch.full( + (num_tokens, topk_tokens), + -1, + device=q_values.device, + dtype=torch.int32, + ), + gathered_k, + ) + + if gathered_k is None: + with deepseek_v4_profile_scope("indexer_topk_prefill_gather_paged_mxfp4"): + gathered_k = _deepseek_v4_gather_paged_indexer_mxfp4_cache( + cache_2d, + block_table, + cu_seq_lens, + cache_block_size, + out=gather_workspace, + ) + k_values, k_scales = gathered_k + + try: + with deepseek_v4_profile_scope("indexer_topk_prefill_deepgemm_logits"): + logits = deep_gemm.fp8_fp4_mqa_logits( + q=(q_values.contiguous().view(torch.int8), q_scales.contiguous()), + kv=(k_values.contiguous(), k_scales.contiguous()), + weights=weights.contiguous(), + cu_seq_len_k_start=cu_start, + cu_seq_len_k_end=cu_end, + clean_logits=False, + max_seqlen_k=max_len, + logits_dtype=torch.float32, + ) + except RuntimeError: + return None, gathered_k + + with deepseek_v4_profile_scope("indexer_topk_prefill_select"): + return ( + _deepseek_v4_indexer_topk_from_logits( + logits, + row_lens, + topk_tokens, + preserve_topk_order=preserve_topk_order, + ), + gathered_k, + ) + + +def _deepseek_v4_indexer_decode_metadata( + *, + positions: torch.Tensor, + token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + cache_block_size: int, + compress_ratio: int, + metadata: Optional[Any] = None, + is_valid_token: Optional[torch.Tensor] = None, +) -> 
_DeepseekV4IndexerDecodeMetadata: + num_tokens = positions.numel() + key = (int(compress_ratio), int(cache_block_size), int(num_tokens)) + cache = getattr(metadata, "decode_indexer_plan_cache", None) + refreshed_keys = getattr(metadata, "decode_indexer_plan_refreshed_keys", None) + cached = cache.get(key) if cache is not None else None + # Hot path: the attention metadata builder hook + # (_refresh_decode_indexer_plan_cache in backends/deepseek_v4.py) pre-builds + # the plan tensors at metadata setup time and adds the key to + # refreshed_keys. The metadata builder also clears refreshed_keys at the + # start of each refresh so a stale entry from a previous step cannot + # cause an early-return with capture-time data. By returning the cached plan + # here, the per-layer + # `run_indexer` call dispatched into `_deepseek_v4_maybe_execute_in_parallel` + # becomes a pure read, eliminating the cross-stream allocator race + # against `insert_and_compress` on aux_stream. + if cached is not None and refreshed_keys is not None and key in refreshed_keys: + return cached + + if num_tokens == 0: + context_lens = torch.empty((0, 1), dtype=torch.int32, device=positions.device) + block_tables = torch.empty( + (0, 1), + dtype=torch.int32, + device=block_table.device, + ) + plan = _DeepseekV4IndexerDecodeMetadata(context_lens, block_tables, 0) + if cache is not None: + cache[key] = plan + if refreshed_keys is not None: + refreshed_keys.add(key) + return plan + + rows = int(block_table.shape[0]) if block_table.ndim >= 1 else 0 + cols = int(block_table.shape[1]) if block_table.ndim >= 2 else 0 + max_len = _deepseek_v4_indexer_decode_max_len( + block_table, + cache_block_size, + compress_ratio, + ) + max_blocks = max(1, (max_len + cache_block_size - 1) // cache_block_size) + + expected_context_shape = (num_tokens, 1) + expected_block_shape = (num_tokens, max_blocks) + if ( + cached is None + or cached.context_lens.shape != expected_context_shape + or cached.context_lens.device != 
positions.device + or cached.context_lens.dtype != torch.int32 + or cached.block_table.shape != expected_block_shape + or cached.block_table.device != block_table.device + or cached.block_table.dtype != torch.int32 + ): + context_lens = torch.empty( + expected_context_shape, + dtype=torch.int32, + device=positions.device, + ) + block_tables = torch.empty( + expected_block_shape, + dtype=torch.int32, + device=block_table.device, + ) + plan = _DeepseekV4IndexerDecodeMetadata( + context_lens=context_lens, + block_table=block_tables, + max_context_len=max_len, + ) + if cache is not None: + cache[key] = plan + else: + plan = cached + plan.max_context_len = max_len + + if rows <= 0 or cols <= 0: + plan.context_lens.zero_() + plan.block_table.zero_() + plan.max_context_len = 0 + else: + deepseek_v4_indexer_decode_metadata_compute( + positions=positions, + token_to_req_indices=token_to_req_indices, + block_table=block_table, + cache_block_size=cache_block_size, + compress_ratio=compress_ratio, + max_blocks=max_blocks, + out_context_lens=plan.context_lens, + out_block_tables=plan.block_table, + ) + if is_valid_token is None: + is_valid_token = getattr(metadata, "is_valid_token", None) + if is_valid_token is not None: + valid = is_valid_token[:num_tokens].to( + device=plan.context_lens.device, + dtype=torch.bool, + ) + with torch.inference_mode(): + plan.context_lens.masked_fill_(~valid.view(num_tokens, 1), 0) + plan.block_table.masked_fill_( + ~valid.to(device=plan.block_table.device).view(num_tokens, 1), + 0, + ) + if refreshed_keys is not None: + refreshed_keys.add(key) + return plan + + +def _deepseek_v4_indexer_topk_from_cache_deepgemm_decode( + *, + cache_2d: torch.Tensor, + positions: torch.Tensor, + token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + cache_block_size: int, + index_q: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + compress_ratio: int, + topk_tokens: int, + metadata: Optional[Any] = None, + schedule_metadata: 
Optional[torch.Tensor] = None, + decode_context_lens: Optional[torch.Tensor] = None, + decode_block_table: Optional[torch.Tensor] = None, + decode_max_context_len: Optional[int] = None, + is_valid_token: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, +) -> Optional[torch.Tensor]: + q_values, q_scales = index_q + if not _deepseek_v4_deepgemm_fp4_indexer_available(q_values): + return None + + num_tokens = positions.numel() + if num_tokens == 0: + if out is not None: + return out[:0] + return torch.empty((0, topk_tokens), device=positions.device, dtype=torch.int32) + if decode_context_lens is not None and decode_block_table is not None: + context_lens = decode_context_lens + block_tables = decode_block_table + max_len = ( + int(decode_max_context_len) + if decode_max_context_len is not None + else int(context_lens.max().item()) + ) + else: + decode_plan = _deepseek_v4_indexer_decode_metadata( + positions=positions, + token_to_req_indices=token_to_req_indices, + block_table=block_table, + cache_block_size=cache_block_size, + compress_ratio=compress_ratio, + metadata=metadata, + is_valid_token=is_valid_token, + ) + context_lens = decode_plan.context_lens + block_tables = decode_plan.block_table + max_len = decode_plan.max_context_len + topk = ( + torch.empty( + (num_tokens, topk_tokens), + device=positions.device, + dtype=torch.int32, + ) + if out is None + else out[:num_tokens] + ) + if max_len <= 0: + topk.fill_(-1) + return topk + kv_cache = _deepseek_v4_indexer_mxfp4_cache_view(cache_2d, cache_block_size) + schedule_key = (compress_ratio, cache_block_size, num_tokens) + schedule_cache = getattr(metadata, "decode_indexer_schedule_metadata", None) + if schedule_metadata is None: + schedule_metadata = ( + schedule_cache.get(schedule_key) if schedule_cache is not None else None + ) + if schedule_metadata is None: + with deepseek_v4_profile_scope("indexer_decode_schedule_metadata"): + schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata( + 
context_lens, + cache_block_size, + deep_gemm.get_num_sms(), + ) + if schedule_cache is not None: + schedule_cache[schedule_key] = schedule_metadata + + try: + with deepseek_v4_profile_scope("indexer_decode_deepgemm_logits"): + logits = deep_gemm.fp8_fp4_paged_mqa_logits( + q=( + q_values.contiguous().view(torch.int8).unsqueeze(1), + q_scales.contiguous().unsqueeze(1), + ), + kv_cache=kv_cache, + weights=weights.contiguous(), + context_lens=context_lens, + block_table=block_tables, + schedule_meta=schedule_metadata, + max_context_len=max_len, + clean_logits=False, + logits_dtype=torch.float32, + ) + except RuntimeError: + return None + + with deepseek_v4_profile_scope("indexer_decode_topk"): + return _deepseek_v4_indexer_topk_from_logits( + logits, + context_lens, + topk_tokens, + next_n=1, + out=out, + ) + + +def _deepseek_v4_indexer_decode_schedule_metadata( + *, + positions: torch.Tensor, + cache_block_size: int, + compress_ratio: int, + metadata: Optional[Any], + context_lens: Optional[torch.Tensor] = None, +) -> Optional[torch.Tensor]: + if positions.numel() == 0: + return None + if getattr(deep_gemm, "get_paged_mqa_logits_metadata", None) is None: + return None + + num_tokens = positions.numel() + if context_lens is None: + compressed_lens = torch.div( + positions.to(torch.int64) + 1, + compress_ratio, + rounding_mode="floor", + ).clamp_min(0) + context_lens = compressed_lens.to(torch.int32).view(num_tokens, 1).contiguous() + schedule_key = (compress_ratio, cache_block_size, num_tokens) + schedule_cache = getattr(metadata, "decode_indexer_schedule_metadata", None) + schedule_metadata = ( + schedule_cache.get(schedule_key) if schedule_cache is not None else None + ) + + with deepseek_v4_profile_scope("indexer_decode_schedule_metadata"): + refreshed = deep_gemm.get_paged_mqa_logits_metadata( + context_lens, + cache_block_size, + deep_gemm.get_num_sms(), + ) + if schedule_metadata is not None: + if ( + schedule_metadata.shape == refreshed.shape + and 
schedule_metadata.device == refreshed.device + and schedule_metadata.dtype == refreshed.dtype + ): + with torch.inference_mode(): + schedule_metadata.copy_(refreshed) + return schedule_metadata + if schedule_cache is not None: + schedule_cache[schedule_key] = refreshed + return refreshed + schedule_metadata = refreshed + if schedule_cache is not None: + schedule_cache[schedule_key] = schedule_metadata + return schedule_metadata + + +def _deepseek_v4_sparse_attn_indexer_native( + *, + cache_2d: torch.Tensor, + positions: torch.Tensor, + token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + seq_lens_cpu: torch.Tensor, + query_lens_cpu: torch.Tensor, + prefill_chunk_bounds: torch.Tensor, + prefill_chunk_plan: torch.Tensor, + prefill_slots: torch.Tensor, + prefill_cu_seq_lens: torch.Tensor, + prefill_cu_start: torch.Tensor, + prefill_cu_end: torch.Tensor, + prefill_row_lens: torch.Tensor, + packed_q_values: torch.Tensor, + packed_q_scales: torch.Tensor, + packed_weights: torch.Tensor, + fallback_index_q: torch.Tensor, + fallback_weights: torch.Tensor, + decode_schedule_metadata: Optional[torch.Tensor], + decode_context_lens: Optional[torch.Tensor], + decode_block_table: Optional[torch.Tensor], + decode_max_context_len: int, + topk_indices_buffer: torch.Tensor, + prefill_gather_values_workspace: torch.Tensor, + prefill_gather_scales_workspace: torch.Tensor, + cache_block_size: int, + compress_ratio: int, + topk_tokens: int, + num_prefill_tokens: int, + num_decode_tokens: int, + use_fp4_cache: bool, + has_packed_q: bool, +) -> torch.Tensor: + total_tokens = positions.numel() + topk_out = topk_indices_buffer[:total_tokens] + topk_out.fill_(-1) + if total_tokens == 0: + return topk_out + + cache_reader = ( + read_deepseek_v4_indexer_mxfp4_cache + if use_fp4_cache + else read_deepseek_v4_indexer_fp8_cache + ) + + def fill_prefill() -> None: + if num_prefill_tokens <= 0: + return + + prefill_positions = positions[:num_prefill_tokens] + if 
prefill_chunk_bounds.numel() > 0: + gather_cache_key = None + gathered_k = None + num_chunks = prefill_chunk_bounds.shape[0] + for chunk_idx in range(num_chunks): + bounds = prefill_chunk_bounds[chunk_idx] + plan = prefill_chunk_plan[chunk_idx] + token_start = int(bounds[0].item()) + token_end = int(bounds[1].item()) + req_start = int(bounds[2].item()) + req_end = int(bounds[3].item()) + skip_kv_gather = bool(int(bounds[6].item())) + slot_start = int(plan[0].item()) + slot_end = int(plan[1].item()) + row_start = int(plan[2].item()) + row_end = int(plan[3].item()) + max_len = int(plan[4].item()) + cu_seq_start = int(plan[5].item()) if plan.numel() > 5 else 0 + cu_seq_end = int(plan[6].item()) if plan.numel() > 6 else 0 + gather_rows = max(0, slot_end - slot_start) + gather_plan = ( + prefill_slots[slot_start:slot_end], + prefill_cu_start[row_start:row_end], + prefill_cu_end[row_start:row_end], + prefill_row_lens[row_start:row_end], + max_len, + ) + gather_workspace = None + if ( + prefill_gather_values_workspace.numel() > 0 + and prefill_gather_scales_workspace.numel() > 0 + and gather_rows <= prefill_gather_values_workspace.shape[0] + and gather_rows <= prefill_gather_scales_workspace.shape[0] + ): + gather_workspace = ( + prefill_gather_values_workspace[:gather_rows], + prefill_gather_scales_workspace[:gather_rows], + ) + topk = None + if has_packed_q: + with deepseek_v4_profile_scope("indexer_topk_deepgemm_prefill"): + key = (req_start, req_end) + reuse_k = ( + gathered_k + if skip_kv_gather and gather_cache_key == key + else None + ) + if ( + prefill_cu_seq_lens.numel() > 0 + and cu_seq_end > cu_seq_start + ): + topk, next_gathered_k = ( + _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill_contract( + cache_2d=cache_2d, + block_table=block_table[req_start:req_end], + cu_seq_lens=prefill_cu_seq_lens[ + cu_seq_start:cu_seq_end + ], + cu_start=prefill_cu_start[row_start:row_end], + cu_end=prefill_cu_end[row_start:row_end], + 
row_lens=prefill_row_lens[row_start:row_end], + max_len=max_len, + cache_block_size=cache_block_size, + index_q=( + packed_q_values[token_start:token_end], + packed_q_scales[token_start:token_end], + ), + weights=packed_weights[token_start:token_end], + topk_tokens=topk_tokens, + preserve_topk_order=True, + gathered_k=reuse_k, + gather_workspace=gather_workspace, + ) + ) + else: + topk, next_gathered_k = ( + _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill_plan( + cache_2d=cache_2d, + gather_plan=gather_plan, + cache_block_size=cache_block_size, + index_q=( + packed_q_values[token_start:token_end], + packed_q_scales[token_start:token_end], + ), + weights=packed_weights[token_start:token_end], + topk_tokens=topk_tokens, + preserve_topk_order=True, + gathered_k=reuse_k, + gather_workspace=gather_workspace, + ) + ) + if topk is not None and next_gathered_k is not None: + gather_cache_key = key + gathered_k = next_gathered_k + if topk is None and fallback_index_q.numel() > 0: + with deepseek_v4_profile_scope("indexer_topk_fallback_prefill"): + topk = _deepseek_v4_indexer_topk_from_cache_batched( + cache_reader=cache_reader, + cache_2d=cache_2d, + positions=prefill_positions[token_start:token_end], + token_to_req_indices=token_to_req_indices[ + token_start:token_end + ], + block_table=block_table, + cache_block_size=cache_block_size, + index_q=fallback_index_q[token_start:token_end], + weights=fallback_weights[token_start:token_end], + compress_ratio=compress_ratio, + topk_tokens=topk_tokens, + preserve_topk_order=True, + ) + if topk is None: + raise RuntimeError( + "DeepSeek V4 sparse indexer prefill DeepGEMM path failed " + "without a prepared fallback." 
+ ) + if topk is not None: + topk_out[token_start:token_end].copy_(topk) + return + + topk_chunks = [] + for start, end in _deepseek_v4_indexer_prefill_topk_chunks( + prefill_positions, + compress_ratio, + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + ): + topk = None + if has_packed_q: + with deepseek_v4_profile_scope("indexer_topk_deepgemm_prefill"): + topk = _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill( + cache_2d=cache_2d, + positions=prefill_positions[start:end], + token_to_req_indices=token_to_req_indices[start:end], + block_table=block_table, + cache_block_size=cache_block_size, + index_q=( + packed_q_values[start:end], + packed_q_scales[start:end], + ), + weights=packed_weights[start:end], + compress_ratio=compress_ratio, + topk_tokens=topk_tokens, + preserve_topk_order=True, + ) + if topk is None and fallback_index_q.numel() > 0: + with deepseek_v4_profile_scope("indexer_topk_fallback_prefill"): + topk = _deepseek_v4_indexer_topk_from_cache_batched( + cache_reader=cache_reader, + cache_2d=cache_2d, + positions=prefill_positions[start:end], + token_to_req_indices=token_to_req_indices[start:end], + block_table=block_table, + cache_block_size=cache_block_size, + index_q=fallback_index_q[start:end], + weights=fallback_weights[start:end], + compress_ratio=compress_ratio, + topk_tokens=topk_tokens, + preserve_topk_order=True, + ) + if topk is None: + raise RuntimeError( + "DeepSeek V4 sparse indexer prefill DeepGEMM path failed " + "without a prepared fallback." 
+ ) + if topk is not None: + topk_chunks.append(topk) + if topk_chunks: + with deepseek_v4_profile_scope("indexer_topk_cat_prefill"): + topk_out[:num_prefill_tokens].copy_(torch.cat(topk_chunks, dim=0)) + + def fill_decode() -> None: + if num_decode_tokens <= 0: + return + + decode_start = num_prefill_tokens + decode_end = decode_start + num_decode_tokens + decode_positions = positions[decode_start:decode_end] + decode_token_to_req = token_to_req_indices[decode_start:decode_end] + decode_out = topk_out[decode_start:decode_end] + topk = None + if has_packed_q: + with deepseek_v4_profile_scope("indexer_topk_deepgemm_decode"): + topk = _deepseek_v4_indexer_topk_from_cache_deepgemm_decode( + cache_2d=cache_2d, + positions=decode_positions, + token_to_req_indices=decode_token_to_req, + block_table=block_table, + cache_block_size=cache_block_size, + index_q=( + packed_q_values[decode_start:decode_end], + packed_q_scales[decode_start:decode_end], + ), + weights=packed_weights[decode_start:decode_end], + compress_ratio=compress_ratio, + topk_tokens=topk_tokens, + schedule_metadata=decode_schedule_metadata, + decode_context_lens=decode_context_lens, + decode_block_table=decode_block_table, + decode_max_context_len=decode_max_context_len, + out=decode_out, + ) + if topk is None and fallback_index_q.shape[0] >= decode_end: + with deepseek_v4_profile_scope("indexer_topk_fallback_decode"): + _deepseek_v4_indexer_topk_from_cache_batched( + cache_reader=cache_reader, + cache_2d=cache_2d, + positions=decode_positions, + token_to_req_indices=decode_token_to_req, + block_table=block_table, + cache_block_size=cache_block_size, + index_q=fallback_index_q[decode_start:decode_end], + weights=fallback_weights[decode_start:decode_end], + compress_ratio=compress_ratio, + topk_tokens=topk_tokens, + out=decode_out, + ) + topk = decode_out + if topk is None: + raise RuntimeError( + "DeepSeek V4 sparse indexer decode DeepGEMM path failed " + "without a prepared fallback." 
+ ) + + fill_prefill() + fill_decode() + return topk_out + + +def _deepseek_v4_sparse_attn_indexer_op( + cache_2d: torch.Tensor, + positions: torch.Tensor, + token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + seq_lens_cpu: torch.Tensor, + query_lens_cpu: torch.Tensor, + prefill_chunk_bounds: torch.Tensor, + prefill_chunk_plan: torch.Tensor, + prefill_slots: torch.Tensor, + prefill_cu_seq_lens: torch.Tensor, + prefill_cu_start: torch.Tensor, + prefill_cu_end: torch.Tensor, + prefill_row_lens: torch.Tensor, + packed_q_values: torch.Tensor, + packed_q_scales: torch.Tensor, + packed_weights: torch.Tensor, + fallback_index_q: torch.Tensor, + fallback_weights: torch.Tensor, + decode_schedule_metadata: torch.Tensor, + decode_context_lens: torch.Tensor, + decode_block_table: torch.Tensor, + decode_max_context_len: int, + topk_indices_buffer: torch.Tensor, + prefill_gather_values_workspace: torch.Tensor, + prefill_gather_scales_workspace: torch.Tensor, + cache_block_size: int, + compress_ratio: int, + topk_tokens: int, + num_prefill_tokens: int, + num_decode_tokens: int, + use_fp4_cache: bool, + has_packed_q: bool, +) -> torch.Tensor: + schedule_metadata = ( + decode_schedule_metadata if decode_schedule_metadata.numel() > 0 else None + ) + context_lens = decode_context_lens if decode_context_lens.numel() > 0 else None + decode_blocks = decode_block_table if decode_block_table.numel() > 0 else None + return _deepseek_v4_sparse_attn_indexer_native( + cache_2d=cache_2d, + positions=positions, + token_to_req_indices=token_to_req_indices, + block_table=block_table, + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + prefill_chunk_bounds=prefill_chunk_bounds, + prefill_chunk_plan=prefill_chunk_plan, + prefill_slots=prefill_slots, + prefill_cu_seq_lens=prefill_cu_seq_lens, + prefill_cu_start=prefill_cu_start, + prefill_cu_end=prefill_cu_end, + prefill_row_lens=prefill_row_lens, + packed_q_values=packed_q_values, + packed_q_scales=packed_q_scales, + 
packed_weights=packed_weights, + fallback_index_q=fallback_index_q, + fallback_weights=fallback_weights, + decode_schedule_metadata=schedule_metadata, + decode_context_lens=context_lens, + decode_block_table=decode_blocks, + decode_max_context_len=decode_max_context_len, + topk_indices_buffer=topk_indices_buffer, + prefill_gather_values_workspace=prefill_gather_values_workspace, + prefill_gather_scales_workspace=prefill_gather_scales_workspace, + cache_block_size=cache_block_size, + compress_ratio=compress_ratio, + topk_tokens=topk_tokens, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + use_fp4_cache=use_fp4_cache, + has_packed_q=has_packed_q, + ) + + +def _deepseek_v4_sparse_attn_indexer_fake( + cache_2d: torch.Tensor, + positions: torch.Tensor, + token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + seq_lens_cpu: torch.Tensor, + query_lens_cpu: torch.Tensor, + prefill_chunk_bounds: torch.Tensor, + prefill_chunk_plan: torch.Tensor, + prefill_slots: torch.Tensor, + prefill_cu_seq_lens: torch.Tensor, + prefill_cu_start: torch.Tensor, + prefill_cu_end: torch.Tensor, + prefill_row_lens: torch.Tensor, + packed_q_values: torch.Tensor, + packed_q_scales: torch.Tensor, + packed_weights: torch.Tensor, + fallback_index_q: torch.Tensor, + fallback_weights: torch.Tensor, + decode_schedule_metadata: torch.Tensor, + decode_context_lens: torch.Tensor, + decode_block_table: torch.Tensor, + decode_max_context_len: int, + topk_indices_buffer: torch.Tensor, + prefill_gather_values_workspace: torch.Tensor, + prefill_gather_scales_workspace: torch.Tensor, + cache_block_size: int, + compress_ratio: int, + topk_tokens: int, + num_prefill_tokens: int, + num_decode_tokens: int, + use_fp4_cache: bool, + has_packed_q: bool, +) -> torch.Tensor: + del ( + cache_2d, + positions, + token_to_req_indices, + block_table, + seq_lens_cpu, + query_lens_cpu, + prefill_chunk_bounds, + prefill_chunk_plan, + prefill_slots, + prefill_cu_seq_lens, + 
prefill_cu_start, + prefill_cu_end, + prefill_row_lens, + packed_q_values, + packed_q_scales, + packed_weights, + fallback_index_q, + fallback_weights, + decode_schedule_metadata, + decode_context_lens, + decode_block_table, + decode_max_context_len, + cache_block_size, + prefill_gather_values_workspace, + prefill_gather_scales_workspace, + compress_ratio, + topk_tokens, + num_prefill_tokens, + num_decode_tokens, + use_fp4_cache, + has_packed_q, + ) + return topk_indices_buffer + + +direct_register_custom_op( + op_name="deepseek_v4_sparse_attn_indexer", + op_func=_deepseek_v4_sparse_attn_indexer_op, + mutates_args=[ + "topk_indices_buffer", + "prefill_gather_values_workspace", + "prefill_gather_scales_workspace", + ], + fake_impl=_deepseek_v4_sparse_attn_indexer_fake, +) + + +def _deepseek_v4_sparse_attn_indexer( *, cache_2d: torch.Tensor, positions: torch.Tensor, token_to_req_indices: torch.Tensor, block_table: torch.Tensor, + seq_lens_cpu: torch.Tensor, + query_lens_cpu: torch.Tensor, + prefill_chunk_bounds: torch.Tensor, + prefill_chunk_plan: torch.Tensor, + prefill_slots: torch.Tensor, + prefill_cu_seq_lens: torch.Tensor, + prefill_cu_start: torch.Tensor, + prefill_cu_end: torch.Tensor, + prefill_row_lens: torch.Tensor, + packed_q_values: torch.Tensor, + packed_q_scales: torch.Tensor, + packed_weights: torch.Tensor, + fallback_index_q: torch.Tensor, + fallback_weights: torch.Tensor, + decode_schedule_metadata: Optional[torch.Tensor], + decode_context_lens: Optional[torch.Tensor], + decode_block_table: Optional[torch.Tensor], + decode_max_context_len: int, + topk_indices_buffer: torch.Tensor, + prefill_gather_values_workspace: torch.Tensor, + prefill_gather_scales_workspace: torch.Tensor, cache_block_size: int, - index_q: tuple[torch.Tensor, torch.Tensor], - weights: torch.Tensor, compress_ratio: int, topk_tokens: int, - metadata: Any | None = None, - out: torch.Tensor | None = None, -) -> torch.Tensor | None: - q_values, q_scales = index_q - if not 
_deepseek_v4_deepgemm_fp4_indexer_available(q_values): - return None - - num_tokens = positions.numel() - if num_tokens == 0: - if out is not None: - return out[:0] - return torch.empty((0, topk_tokens), device=positions.device, dtype=torch.int32) - compressed_lens = torch.div( - positions.to(torch.int64) + 1, - compress_ratio, - rounding_mode="floor", - ).clamp_min(0) - if positions.is_cuda and torch.cuda.is_current_stream_capturing(): - max_len = _deepseek_v4_indexer_decode_max_len( + num_prefill_tokens: int, + num_decode_tokens: int, + use_fp4_cache: bool, + has_packed_q: bool, +) -> torch.Tensor: + if decode_schedule_metadata is None: + decode_schedule_metadata = torch.empty( + 0, + dtype=torch.int32, + device=positions.device, + ) + if decode_context_lens is None: + decode_context_lens = torch.empty( + (0, 1), + dtype=torch.int32, + device=positions.device, + ) + if decode_block_table is None: + decode_block_table = torch.empty( + (0, 1), + dtype=block_table.dtype, + device=block_table.device, + ) + if positions.is_cuda: + return torch.ops.tokenspeed.deepseek_v4_sparse_attn_indexer( + cache_2d, + positions, + token_to_req_indices, block_table, + seq_lens_cpu, + query_lens_cpu, + prefill_chunk_bounds, + prefill_chunk_plan, + prefill_slots, + prefill_cu_seq_lens, + prefill_cu_start, + prefill_cu_end, + prefill_row_lens, + packed_q_values, + packed_q_scales, + packed_weights, + fallback_index_q, + fallback_weights, + decode_schedule_metadata, + decode_context_lens, + decode_block_table, + decode_max_context_len, + topk_indices_buffer, + prefill_gather_values_workspace, + prefill_gather_scales_workspace, cache_block_size, compress_ratio, + topk_tokens, + num_prefill_tokens, + num_decode_tokens, + use_fp4_cache, + has_packed_q, ) - else: - max_len = int(compressed_lens.max().item()) - if max_len <= 0: - topk = ( - torch.empty( - (num_tokens, topk_tokens), - device=positions.device, - dtype=torch.int32, - ) - if out is None - else out[:num_tokens] - ) - 
topk.fill_(-1) - return topk - - max_blocks = max(1, (max_len + cache_block_size - 1) // cache_block_size) - req_idx = token_to_req_indices[:num_tokens].to(torch.int64) - block_tables = block_table[req_idx, :max_blocks].contiguous() - context_lens = compressed_lens.to(torch.int32).view(num_tokens, 1).contiguous() - kv_cache = _deepseek_v4_indexer_mxfp4_cache_view(cache_2d, cache_block_size) - schedule_key = (compress_ratio, cache_block_size, num_tokens) - schedule_cache = getattr(metadata, "decode_indexer_schedule_metadata", None) - schedule_metadata = ( - schedule_cache.get(schedule_key) if schedule_cache is not None else None - ) - if schedule_metadata is None: - schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata( - context_lens, - cache_block_size, - deep_gemm.get_num_sms(), - ) - if schedule_cache is not None: - schedule_cache[schedule_key] = schedule_metadata - - try: - logits = deep_gemm.fp8_fp4_paged_mqa_logits( - q=( - q_values.contiguous().view(torch.int8).unsqueeze(1), - q_scales.contiguous().unsqueeze(1), - ), - kv_cache=kv_cache, - weights=weights.contiguous(), - context_lens=context_lens, - block_table=block_tables, - schedule_meta=schedule_metadata, - max_context_len=max_len, - clean_logits=False, - logits_dtype=torch.float32, - ) - except RuntimeError: - return None - - return _deepseek_v4_indexer_topk_from_logits( - logits, - compressed_lens.to(torch.int32), - topk_tokens, - out=out, + return _deepseek_v4_sparse_attn_indexer_native( + cache_2d=cache_2d, + positions=positions, + token_to_req_indices=token_to_req_indices, + block_table=block_table, + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + prefill_chunk_bounds=prefill_chunk_bounds, + prefill_chunk_plan=prefill_chunk_plan, + prefill_slots=prefill_slots, + prefill_cu_seq_lens=prefill_cu_seq_lens, + prefill_cu_start=prefill_cu_start, + prefill_cu_end=prefill_cu_end, + prefill_row_lens=prefill_row_lens, + packed_q_values=packed_q_values, + packed_q_scales=packed_q_scales, + 
packed_weights=packed_weights, + fallback_index_q=fallback_index_q, + fallback_weights=fallback_weights, + decode_schedule_metadata=decode_schedule_metadata, + decode_context_lens=decode_context_lens, + decode_block_table=decode_block_table, + decode_max_context_len=decode_max_context_len, + topk_indices_buffer=topk_indices_buffer, + prefill_gather_values_workspace=prefill_gather_values_workspace, + prefill_gather_scales_workspace=prefill_gather_scales_workspace, + cache_block_size=cache_block_size, + compress_ratio=compress_ratio, + topk_tokens=topk_tokens, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + use_fp4_cache=use_fp4_cache, + has_packed_q=has_packed_q, ) @@ -1225,6 +2762,10 @@ def _deepseek_v4_indexer_topk_from_cache_deepgemm_decode( DEEPSEEK_V4_FP8_BLOCK_SIZE = 128 _DEEPSEEK_V4_INDEXER_PREFILL_MAX_LOGITS_MB = 512 _DEEPSEEK_V4_FUSED_ROUTER_AVAILABLE = True +_DEEPSEEK_V4_PREFILL_TOPK_OP_CHECKED = False +_DEEPSEEK_V4_PREFILL_TOPK_OP_AVAILABLE = False +_DEEPSEEK_V4_PAGED_GATHER_CHECKED = False +_DEEPSEEK_V4_PAGED_GATHER_AVAILABLE = False def _deepseek_v4_maybe_execute_in_parallel( @@ -1445,166 +2986,6 @@ def get(self, num_tokens: int, device: torch.device) -> torch.Tensor: return self.buffer[:num_tokens] -@triton.jit -def _deepseek_v4_stage_mega_moe_inputs_kernel( - hidden_states, - x_fp8, - x_sf, - topk_ids, - topk_weights, - topk_idx_out, - topk_weights_out, - hidden_stride_m: tl.constexpr, - hidden_stride_k: tl.constexpr, - x_stride_m: tl.constexpr, - x_stride_k: tl.constexpr, - x_sf_stride_m: tl.constexpr, - x_sf_stride_k: tl.constexpr, - topk_ids_stride_m: tl.constexpr, - topk_ids_stride_k: tl.constexpr, - topk_weights_stride_m: tl.constexpr, - topk_weights_stride_k: tl.constexpr, - topk_idx_stride_m: tl.constexpr, - topk_idx_stride_k: tl.constexpr, - topk_weights_out_stride_m: tl.constexpr, - topk_weights_out_stride_k: tl.constexpr, - hidden_size: tl.constexpr, - top_k: tl.constexpr, - BLOCK_K: tl.constexpr, - GROUP_K: 
tl.constexpr, - BLOCK_TOPK: tl.constexpr, -) -> None: - token_id = tl.program_id(0) - k_block_id = tl.program_id(1) - - k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K) - k_mask = k_offsets < hidden_size - hidden = tl.load( - hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k, - mask=k_mask, - other=0.0, - ).to(tl.float32) - - num_groups: tl.constexpr = BLOCK_K // GROUP_K - hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K]) - amax = tl.max(hidden_groups, axis=1) - amax = tl.maximum(amax, 1.0e-4) - - scale = amax / 448.0 - scale_bits = scale.to(tl.uint32, bitcast=True) - scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to( - tl.uint32 - ) - scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254) - rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True) - - hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K]) - scaled = hidden_groups * (1.0 / rounded_scale)[:, None] - scaled = tl.reshape(scaled, [BLOCK_K]) - fp8 = scaled.to(tl.float8e4nv) - tl.store( - x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k, - fp8, - mask=k_mask, - ) - - scale_offsets = tl.arange(0, num_groups) - packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32) - tl.store( - x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k, - packed_scale, - ) - - if k_block_id == 0: - topk_offsets = tl.arange(0, BLOCK_TOPK) - topk_mask = topk_offsets < top_k - - ids = tl.load( - topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k, - mask=topk_mask, - other=0, - ).to(tl.int64) - tl.store( - topk_idx_out - + token_id * topk_idx_stride_m - + topk_offsets * topk_idx_stride_k, - ids, - mask=topk_mask, - ) - - weights = tl.load( - topk_weights - + token_id * topk_weights_stride_m - + topk_offsets * topk_weights_stride_k, - mask=topk_mask, - other=0.0, - ) - tl.store( - topk_weights_out - + token_id * topk_weights_out_stride_m - + topk_offsets * topk_weights_out_stride_k, - weights, - 
mask=topk_mask, - ) - - -def _stage_deepseek_v4_mega_moe_inputs( - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - x_fp8: torch.Tensor, - x_sf: torch.Tensor, - topk_idx_out: torch.Tensor, - topk_weights_out: torch.Tensor, -) -> None: - num_tokens, hidden_size = hidden_states.shape - if num_tokens == 0: - return - if hidden_size % DEEPSEEK_V4_FP8_BLOCK_SIZE != 0: - raise ValueError( - "DeepSeek V4 MegaMoE input staging requires hidden_size to be " - f"a multiple of {DEEPSEEK_V4_FP8_BLOCK_SIZE}." - ) - if topk_weights.shape != topk_ids.shape: - raise ValueError( - "DeepSeek V4 MegaMoE input staging requires topk_weights and " - "topk_ids to have the same shape." - ) - - block_k = DEEPSEEK_V4_FP8_BLOCK_SIZE - grid = (num_tokens, triton.cdiv(hidden_size, block_k)) - block_topk = triton.next_power_of_2(topk_ids.shape[1]) - _deepseek_v4_stage_mega_moe_inputs_kernel[grid]( - hidden_states, - x_fp8, - x_sf, - topk_ids, - topk_weights, - topk_idx_out, - topk_weights_out, - hidden_states.stride(0), - hidden_states.stride(1), - x_fp8.stride(0), - x_fp8.stride(1), - x_sf.stride(0), - x_sf.stride(1), - topk_ids.stride(0), - topk_ids.stride(1), - topk_weights.stride(0), - topk_weights.stride(1), - topk_idx_out.stride(0), - topk_idx_out.stride(1), - topk_weights_out.stride(0), - topk_weights_out.stride(1), - hidden_size, - topk_ids.shape[1], - BLOCK_K=block_k, - GROUP_K=32, - BLOCK_TOPK=block_topk, - num_warps=4, - ) - - DEEPSEEK_V4_MXFP4_BLOCK_SIZE = 32 @@ -2530,6 +3911,277 @@ def __init__( self.topk_tokens = int(config.index_topk) self.topk_buffer = topk_buffer self.softmax_scale = self.head_dim**-0.5 + value_bytes = DEEPSEEK_V4_INDEXER_DIM // 2 + scale_bytes = DEEPSEEK_V4_INDEXER_DIM // DEEPSEEK_V4_MXFP4_BLOCK_SIZE + self.register_buffer( + "_prefill_gather_values_workspace", + torch.empty((0, value_bytes), dtype=torch.uint8), + persistent=False, + ) + self.register_buffer( + "_prefill_gather_scales_workspace", + torch.empty((0, 
scale_bytes), dtype=torch.uint8), + persistent=False, + ) + + def _prefill_gather_workspace( + self, + rows: int, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + rows = max(0, int(rows)) + value_bytes = DEEPSEEK_V4_INDEXER_DIM // 2 + scale_bytes = DEEPSEEK_V4_INDEXER_DIM // DEEPSEEK_V4_MXFP4_BLOCK_SIZE + if ( + self._prefill_gather_values_workspace.device != device + or self._prefill_gather_values_workspace.shape[0] < rows + ): + self._prefill_gather_values_workspace = torch.empty( + (rows, value_bytes), + dtype=torch.uint8, + device=device, + ) + if ( + self._prefill_gather_scales_workspace.device != device + or self._prefill_gather_scales_workspace.shape[0] < rows + ): + self._prefill_gather_scales_workspace = torch.empty( + (rows, scale_bytes), + dtype=torch.uint8, + device=device, + ) + return ( + self._prefill_gather_values_workspace[:rows], + self._prefill_gather_scales_workspace[:rows], + ) + + def prepare_decode_metadata( + self, + *, + positions: torch.Tensor, + metadata: Any, + indexer_block_size: int, + ) -> None: + if not self.use_fp4_cache or not positions.is_cuda: + return + forward_mode = metadata.forward_mode + if forward_mode is not None and forward_mode.is_mixed(): + num_prefill_tokens = int(metadata.num_prefill_tokens) + num_decode_tokens = metadata.decode_token_count() + elif forward_mode is not None and forward_mode.is_decode(): + num_prefill_tokens = 0 + num_decode_tokens = positions.numel() + else: + return + if num_decode_tokens <= 0: + return + + decode_start = num_prefill_tokens + decode_end = decode_start + num_decode_tokens + decode_positions = positions[decode_start:decode_end] + decode_valid_token = ( + metadata.is_valid_token[decode_start:decode_end] + if getattr(metadata, "is_valid_token", None) is not None + else None + ) + indexer_block_table = metadata.compressed_block_table( + self.compress_ratio, + indexer_block_size, + ) + decode_plan = _deepseek_v4_indexer_decode_metadata( + positions=decode_positions, + 
token_to_req_indices=metadata.token_to_req_indices[decode_start:decode_end], + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + compress_ratio=self.compress_ratio, + metadata=metadata, + is_valid_token=decode_valid_token, + ) + _deepseek_v4_indexer_decode_schedule_metadata( + positions=decode_positions, + cache_block_size=indexer_block_size, + compress_ratio=self.compress_ratio, + metadata=metadata, + context_lens=decode_plan.context_lens, + ) + + def _forward_sparse_indexer_custom_op( + self, + *, + hidden_states: torch.Tensor, + qr: torch.Tensor, + positions: torch.Tensor, + metadata: Any, + indexer_cache: torch.Tensor, + indexer_block_size: int, + cos_sin_cache: torch.Tensor, + ) -> Optional[torch.Tensor]: + if not self.use_fp4_cache or not positions.is_cuda: + return None + + forward_mode = metadata.forward_mode + total_tokens = positions.numel() + if total_tokens == 0: + return torch.empty( + (0, self.topk_tokens), + device=positions.device, + dtype=torch.int32, + ) + if forward_mode is not None and forward_mode.is_mixed(): + num_prefill_tokens = int(metadata.num_prefill_tokens) + num_decode_tokens = metadata.decode_token_count() + elif forward_mode is not None and forward_mode.is_decode(): + num_prefill_tokens = 0 + num_decode_tokens = total_tokens + else: + num_prefill_tokens = total_tokens + num_decode_tokens = 0 + + with deepseek_v4_profile_scope("indexer_wq_b"): + index_q, _ = self.wq_b(qr) + index_q = index_q.view(-1, self.n_head, self.head_dim) + with deepseek_v4_profile_scope("indexer_weights_proj"): + weights, _ = self.weights_proj(hidden_states) + with deepseek_v4_profile_scope("indexer_prepare_mxfp4"): + packed_index_q, packed_weights = deepseek_v4_prepare_indexer_q_mxfp4( + index_q=index_q, + positions=positions, + cos_sin_cache=cos_sin_cache, + weights=weights, + softmax_scale=self.softmax_scale, + head_scale=self.n_head**-0.5, + ) + + packed_indexer_available = _deepseek_v4_deepgemm_fp4_indexer_available( + 
packed_index_q[0] + ) + fallback_index_q = index_q.new_empty((0, self.n_head, self.head_dim)) + fallback_weights = weights.new_empty((0, self.n_head)) + if not packed_indexer_available: + with deepseek_v4_profile_scope("indexer_prepare_reference_fallback"): + fallback_index_q, fallback_weights = ( + deepseek_v4_prepare_indexer_q_reference( + index_q=index_q, + positions=positions, + cos_sin_cache=cos_sin_cache, + weights=weights, + softmax_scale=self.softmax_scale, + head_scale=self.n_head**-0.5, + use_fp4=self.use_fp4_cache, + ) + ) + + empty_cpu = torch.empty(0, dtype=torch.int32, device="cpu") + seq_lens_cpu = ( + metadata.seq_lens_cpu[: metadata.num_prefill_reqs] + if metadata.seq_lens_cpu is not None and num_prefill_tokens > 0 + else empty_cpu + ) + query_lens_cpu = ( + metadata.query_lens_cpu[: metadata.num_prefill_reqs] + if metadata.query_lens_cpu is not None and num_prefill_tokens > 0 + else empty_cpu + ) + indexer_block_table = metadata.compressed_block_table( + self.compress_ratio, + indexer_block_size, + ) + prefill_metadata = _deepseek_v4_indexer_prefill_metadata( + metadata=metadata, + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + compress_ratio=self.compress_ratio, + num_prefill_tokens=num_prefill_tokens, + ) + max_prefill_gather_rows = 0 + if prefill_metadata.chunk_plan.numel() > 0: + slot_counts = ( + prefill_metadata.chunk_plan[:, 1] - prefill_metadata.chunk_plan[:, 0] + ) + max_prefill_gather_rows = int(slot_counts.max().item()) + prefill_gather_values, prefill_gather_scales = self._prefill_gather_workspace( + max_prefill_gather_rows, + positions.device, + ) + + decode_schedule_metadata = None + decode_context_lens = None + decode_block_table = None + decode_max_context_len = 0 + if num_decode_tokens > 0: + decode_start = num_prefill_tokens + decode_end = decode_start + num_decode_tokens + decode_positions = positions[decode_start:decode_end] + decode_valid_token = ( + metadata.is_valid_token[decode_start:decode_end] + 
if getattr(metadata, "is_valid_token", None) is not None + else None + ) + decode_plan = _deepseek_v4_indexer_decode_metadata( + positions=decode_positions, + token_to_req_indices=metadata.token_to_req_indices[ + decode_start:decode_end + ], + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + compress_ratio=self.compress_ratio, + metadata=metadata, + is_valid_token=decode_valid_token, + ) + decode_context_lens = decode_plan.context_lens + decode_block_table = decode_plan.block_table + decode_max_context_len = decode_plan.max_context_len + decode_schedule_metadata = _deepseek_v4_indexer_decode_schedule_metadata( + positions=decode_positions, + cache_block_size=indexer_block_size, + compress_ratio=self.compress_ratio, + metadata=metadata, + context_lens=decode_context_lens, + ) + + topk_out = ( + self.topk_buffer.get(total_tokens, positions.device) + if self.topk_buffer is not None + else torch.empty( + (total_tokens, self.topk_tokens), + device=positions.device, + dtype=torch.int32, + ) + )[:total_tokens] + return _deepseek_v4_sparse_attn_indexer( + cache_2d=indexer_cache, + positions=positions, + token_to_req_indices=metadata.token_to_req_indices[:total_tokens], + block_table=indexer_block_table, + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + prefill_chunk_bounds=prefill_metadata.chunk_bounds, + prefill_chunk_plan=prefill_metadata.chunk_plan, + prefill_slots=prefill_metadata.slots, + prefill_cu_seq_lens=prefill_metadata.cu_seq_lens, + prefill_cu_start=prefill_metadata.cu_start, + prefill_cu_end=prefill_metadata.cu_end, + prefill_row_lens=prefill_metadata.row_lens, + packed_q_values=packed_index_q[0], + packed_q_scales=packed_index_q[1], + packed_weights=packed_weights, + fallback_index_q=fallback_index_q, + fallback_weights=fallback_weights, + decode_schedule_metadata=decode_schedule_metadata, + decode_context_lens=decode_context_lens, + decode_block_table=decode_block_table, + decode_max_context_len=decode_max_context_len, 
+ topk_indices_buffer=topk_out, + prefill_gather_values_workspace=prefill_gather_values, + prefill_gather_scales_workspace=prefill_gather_scales, + cache_block_size=indexer_block_size, + compress_ratio=self.compress_ratio, + topk_tokens=self.topk_tokens, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + use_fp4_cache=self.use_fp4_cache, + has_packed_q=packed_indexer_available, + ) def forward( self, @@ -2547,6 +4199,9 @@ def forward( raise RuntimeError("DeepSeek V4 indexer requires forward metadata") indexer_state = pool.get_indexer_state_buffer(layer_index) indexer_state_block_table = metadata.indexer_state_block_table + indexer_state_base_logical_page = getattr( + metadata, "indexer_state_base_logical_page", None + ) if indexer_state_block_table is not None: indexer_state_block_size = pool.get_indexer_state_block_size(layer_index) indexer_state_slot_mapping = _group_slot_mapping_from_raw( @@ -2554,12 +4209,13 @@ def forward( metadata.token_to_req_indices[: positions.numel()], indexer_state_block_table, indexer_state_block_size, - base_offsets=metadata.indexer_state_base_logical_page, + base_offsets=indexer_state_base_logical_page, ) else: indexer_state_block_table = metadata.block_table indexer_state_block_size = pool.state_block_size indexer_state_slot_mapping = out_cache_loc + indexer_state_base_logical_page = None with deepseek_v4_profile_scope("indexer_compressor_total"): self.compressor( hidden_states=hidden_states, @@ -2571,7 +4227,7 @@ def forward( state_cache=indexer_state, state_block_table=indexer_state_block_table, state_block_size=indexer_state_block_size, - state_base_logical_page=metadata.indexer_state_base_logical_page, + state_base_logical_page=indexer_state_base_logical_page, write_compressed_cache=False, ) with deepseek_v4_profile_scope("indexer_compressed_slot_mapping"): @@ -2592,7 +4248,7 @@ def forward( positions=positions, compressor_slot_mapping=indexer_state_slot_mapping, 
block_table=indexer_state_block_table, - block_table_base_offsets=metadata.indexer_state_base_logical_page, + block_table_base_offsets=indexer_state_base_logical_page, compressor_block_size=indexer_state_block_size, rms_norm_weight=self.compressor.norm.weight, rms_norm_eps=self.compressor.norm.variance_epsilon, @@ -2603,6 +4259,321 @@ def forward( use_fp4_cache=self.use_fp4_cache, compress_ratio=self.compress_ratio, ) + custom_topk = self._forward_sparse_indexer_custom_op( + hidden_states=hidden_states, + qr=qr, + positions=positions, + metadata=metadata, + indexer_cache=pool.get_indexer_kv_buffer_2d(layer_index), + indexer_block_size=indexer_block_size, + cos_sin_cache=cos_sin_cache, + ) + if custom_topk is not None: + return custom_topk + + if ctx.forward_mode is not None and ctx.forward_mode.is_mixed(): + num_prefill_tokens = metadata.num_prefill_tokens + num_decode_tokens = metadata.decode_token_count() + total_tokens = positions.numel() + topk_out = ( + self.topk_buffer.get(total_tokens, positions.device) + if self.topk_buffer is not None + else torch.empty( + (total_tokens, self.topk_tokens), + device=positions.device, + dtype=torch.int32, + ) + )[:total_tokens] + topk_out.fill_(-1) + + def fill_prefill_topk() -> None: + if num_prefill_tokens <= 0: + return + prefill_positions = positions[:num_prefill_tokens] + + with deepseek_v4_profile_scope("indexer_wq_b_prefill"): + index_q, _ = self.wq_b(qr[:num_prefill_tokens]) + index_q = index_q.view(-1, self.n_head, self.head_dim) + with deepseek_v4_profile_scope("indexer_weights_proj_prefill"): + weights, _ = self.weights_proj(hidden_states[:num_prefill_tokens]) + + packed_index_q = None + packed_weights = None + if self.use_fp4_cache: + with deepseek_v4_profile_scope("indexer_prepare_mxfp4_prefill"): + packed_index_q, packed_weights = ( + deepseek_v4_prepare_indexer_q_mxfp4( + index_q=index_q, + positions=prefill_positions, + cos_sin_cache=cos_sin_cache, + weights=weights, + softmax_scale=self.softmax_scale, + 
head_scale=self.n_head**-0.5, + ) + ) + + with deepseek_v4_profile_scope("indexer_prepare_reference_prefill"): + index_q_fallback, weights_fallback = ( + deepseek_v4_prepare_indexer_q_reference( + index_q=index_q, + positions=prefill_positions, + cos_sin_cache=cos_sin_cache, + weights=weights, + softmax_scale=self.softmax_scale, + head_scale=self.n_head**-0.5, + use_fp4=self.use_fp4_cache, + ) + ) + cache_reader = ( + read_deepseek_v4_indexer_mxfp4_cache + if self.use_fp4_cache + else read_deepseek_v4_indexer_fp8_cache + ) + indexer_cache = pool.get_indexer_kv_buffer_2d(layer_index) + seq_lens_cpu = ( + metadata.seq_lens_cpu[: metadata.num_prefill_reqs] + if metadata.seq_lens_cpu is not None + else None + ) + query_lens_cpu = ( + metadata.query_lens_cpu[: metadata.num_prefill_reqs] + if metadata.query_lens_cpu is not None + else None + ) + request_chunks = ( + _deepseek_v4_indexer_prefill_request_chunks( + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + compress_ratio=self.compress_ratio, + num_tokens=num_prefill_tokens, + ) + if seq_lens_cpu is not None and query_lens_cpu is not None + else [] + ) + if request_chunks: + gather_cache_key = None + gathered_k = None + for chunk in request_chunks: + topk = None + if packed_index_q is not None and packed_weights is not None: + with deepseek_v4_profile_scope( + "indexer_topk_deepgemm_prefill" + ): + gather_plan = ( + _deepseek_v4_indexer_prefill_request_gather_plan( + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + compress_ratio=self.compress_ratio, + req_start=chunk.req_start, + req_end=chunk.req_end, + query_start=chunk.query_start, + query_end=chunk.query_end, + ) + ) + key = (chunk.req_start, chunk.req_end) + reuse_k = ( + gathered_k + if chunk.skip_kv_gather and gather_cache_key == key + else None + ) + topk, next_gathered_k = ( + _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill_plan( + 
cache_2d=indexer_cache, + gather_plan=gather_plan, + cache_block_size=indexer_block_size, + index_q=( + packed_index_q[0][ + chunk.token_start : chunk.token_end + ], + packed_index_q[1][ + chunk.token_start : chunk.token_end + ], + ), + weights=packed_weights[ + chunk.token_start : chunk.token_end + ], + topk_tokens=self.topk_tokens, + preserve_topk_order=True, + gathered_k=reuse_k, + ) + ) + if topk is not None and next_gathered_k is not None: + gather_cache_key = key + gathered_k = next_gathered_k + if topk is None: + with deepseek_v4_profile_scope( + "indexer_topk_fallback_prefill" + ): + topk = _deepseek_v4_indexer_topk_from_cache_batched( + cache_reader=cache_reader, + cache_2d=indexer_cache, + positions=prefill_positions[ + chunk.token_start : chunk.token_end + ], + token_to_req_indices=metadata.token_to_req_indices[ + chunk.token_start : chunk.token_end + ], + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + index_q=index_q_fallback[ + chunk.token_start : chunk.token_end + ], + weights=weights_fallback[ + chunk.token_start : chunk.token_end + ], + compress_ratio=self.compress_ratio, + topk_tokens=self.topk_tokens, + preserve_topk_order=True, + ) + topk_out[chunk.token_start : chunk.token_end].copy_(topk) + return + + topk_chunks = [] + for start, end in _deepseek_v4_indexer_prefill_topk_chunks( + prefill_positions, + self.compress_ratio, + seq_lens_cpu=seq_lens_cpu, + query_lens_cpu=query_lens_cpu, + ): + if packed_index_q is not None and packed_weights is not None: + with deepseek_v4_profile_scope("indexer_topk_deepgemm_prefill"): + topk = ( + _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill( + cache_2d=indexer_cache, + positions=prefill_positions[start:end], + token_to_req_indices=metadata.token_to_req_indices[ + start:end + ], + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + index_q=( + packed_index_q[0][start:end], + packed_index_q[1][start:end], + ), + weights=packed_weights[start:end], + 
compress_ratio=self.compress_ratio, + topk_tokens=self.topk_tokens, + preserve_topk_order=True, + ) + ) + if topk is not None: + topk_chunks.append(topk) + continue + with deepseek_v4_profile_scope("indexer_topk_fallback_prefill"): + topk_chunks.append( + _deepseek_v4_indexer_topk_from_cache_batched( + cache_reader=cache_reader, + cache_2d=indexer_cache, + positions=prefill_positions[start:end], + token_to_req_indices=metadata.token_to_req_indices[ + start:end + ], + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + index_q=index_q_fallback[start:end], + weights=weights_fallback[start:end], + compress_ratio=self.compress_ratio, + topk_tokens=self.topk_tokens, + preserve_topk_order=True, + ) + ) + if topk_chunks: + with deepseek_v4_profile_scope("indexer_topk_cat_prefill"): + topk_out[:num_prefill_tokens].copy_( + torch.cat(topk_chunks, dim=0) + ) + + def fill_decode_topk() -> None: + if num_decode_tokens <= 0: + return + decode_start = num_prefill_tokens + decode_end = decode_start + num_decode_tokens + decode_positions = positions[decode_start:decode_end] + decode_token_to_req = metadata.token_to_req_indices[ + decode_start:decode_end + ] + decode_valid_token = ( + metadata.is_valid_token[decode_start:decode_end] + if getattr(metadata, "is_valid_token", None) is not None + else None + ) + decode_out = topk_out[decode_start:decode_end] + with deepseek_v4_profile_scope("indexer_wq_b_decode"): + index_q, _ = self.wq_b(qr[decode_start:decode_end]) + index_q = index_q.view(-1, self.n_head, self.head_dim) + with deepseek_v4_profile_scope("indexer_weights_proj_decode"): + weights, _ = self.weights_proj( + hidden_states[decode_start:decode_end] + ) + + packed_index_q = None + packed_weights = None + if self.use_fp4_cache: + with deepseek_v4_profile_scope("indexer_prepare_mxfp4_decode"): + packed_index_q, packed_weights = ( + deepseek_v4_prepare_indexer_q_mxfp4( + index_q=index_q, + positions=decode_positions, + cos_sin_cache=cos_sin_cache, + 
weights=weights, + softmax_scale=self.softmax_scale, + head_scale=self.n_head**-0.5, + ) + ) + with deepseek_v4_profile_scope("indexer_topk_deepgemm_decode"): + topk = _deepseek_v4_indexer_topk_from_cache_deepgemm_decode( + cache_2d=pool.get_indexer_kv_buffer_2d(layer_index), + positions=decode_positions, + token_to_req_indices=decode_token_to_req, + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + index_q=packed_index_q, + weights=packed_weights, + compress_ratio=self.compress_ratio, + topk_tokens=self.topk_tokens, + metadata=metadata, + is_valid_token=decode_valid_token, + out=decode_out, + ) + if topk is not None: + return + + with deepseek_v4_profile_scope("indexer_prepare_reference_decode"): + index_q_fallback, weights_fallback = ( + deepseek_v4_prepare_indexer_q_reference( + index_q=index_q, + positions=decode_positions, + cos_sin_cache=cos_sin_cache, + weights=weights, + softmax_scale=self.softmax_scale, + head_scale=self.n_head**-0.5, + use_fp4=self.use_fp4_cache, + ) + ) + cache_reader = ( + read_deepseek_v4_indexer_mxfp4_cache + if self.use_fp4_cache + else read_deepseek_v4_indexer_fp8_cache + ) + _deepseek_v4_indexer_topk_from_cache_batched( + cache_reader=cache_reader, + cache_2d=pool.get_indexer_kv_buffer_2d(layer_index), + positions=decode_positions, + token_to_req_indices=decode_token_to_req, + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + index_q=index_q_fallback, + weights=weights_fallback, + compress_ratio=self.compress_ratio, + topk_tokens=self.topk_tokens, + out=decode_out, + ) + + fill_prefill_topk() + fill_decode_topk() + return topk_out with deepseek_v4_profile_scope("indexer_wq_b"): index_q, _ = self.wq_b(qr) index_q = index_q.view(-1, self.n_head, self.head_dim) @@ -2638,6 +4609,11 @@ def forward( compress_ratio=self.compress_ratio, topk_tokens=self.topk_tokens, metadata=metadata, + is_valid_token=( + metadata.is_valid_token[: positions.numel()] + if getattr(metadata, "is_valid_token", 
None) is not None + else None + ), out=topk_out, ) if topk is not None: @@ -2681,10 +4657,104 @@ def forward( ) indexer_cache = pool.get_indexer_kv_buffer_2d(layer_index) + request_chunks = ( + _deepseek_v4_indexer_prefill_request_chunks( + seq_lens_cpu=metadata.seq_lens_cpu, + query_lens_cpu=metadata.query_lens_cpu, + compress_ratio=self.compress_ratio, + num_tokens=positions.numel(), + ) + if metadata.seq_lens_cpu is not None and metadata.query_lens_cpu is not None + else [] + ) + if request_chunks: + topk_out = ( + self.topk_buffer.get(positions.numel(), positions.device) + if self.topk_buffer is not None + else torch.empty( + (positions.numel(), self.topk_tokens), + device=positions.device, + dtype=torch.int32, + ) + )[: positions.numel()] + topk_out.fill_(-1) + gather_cache_key = None + gathered_k = None + for chunk in request_chunks: + topk = None + if packed_index_q is not None and packed_weights is not None: + with deepseek_v4_profile_scope("indexer_topk_deepgemm_prefill"): + gather_plan = _deepseek_v4_indexer_prefill_request_gather_plan( + seq_lens_cpu=metadata.seq_lens_cpu, + query_lens_cpu=metadata.query_lens_cpu, + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + compress_ratio=self.compress_ratio, + req_start=chunk.req_start, + req_end=chunk.req_end, + query_start=chunk.query_start, + query_end=chunk.query_end, + ) + key = (chunk.req_start, chunk.req_end) + reuse_k = ( + gathered_k + if chunk.skip_kv_gather and gather_cache_key == key + else None + ) + topk, next_gathered_k = ( + _deepseek_v4_indexer_topk_from_cache_deepgemm_prefill_plan( + cache_2d=indexer_cache, + gather_plan=gather_plan, + cache_block_size=indexer_block_size, + index_q=( + packed_index_q[0][ + chunk.token_start : chunk.token_end + ], + packed_index_q[1][ + chunk.token_start : chunk.token_end + ], + ), + weights=packed_weights[ + chunk.token_start : chunk.token_end + ], + topk_tokens=self.topk_tokens, + preserve_topk_order=True, + gathered_k=reuse_k, + ) + ) 
+ if topk is not None and next_gathered_k is not None: + gather_cache_key = key + gathered_k = next_gathered_k + if topk is None: + with deepseek_v4_profile_scope("indexer_topk_fallback_prefill"): + topk = _deepseek_v4_indexer_topk_from_cache_batched( + cache_reader=cache_reader, + cache_2d=indexer_cache, + positions=positions[chunk.token_start : chunk.token_end], + token_to_req_indices=metadata.token_to_req_indices[ + chunk.token_start : chunk.token_end + ], + block_table=indexer_block_table, + cache_block_size=indexer_block_size, + index_q=index_q_fallback[ + chunk.token_start : chunk.token_end + ], + weights=weights_fallback[ + chunk.token_start : chunk.token_end + ], + compress_ratio=self.compress_ratio, + topk_tokens=self.topk_tokens, + preserve_topk_order=True, + ) + topk_out[chunk.token_start : chunk.token_end].copy_(topk) + return topk_out + topk_chunks = [] for start, end in _deepseek_v4_indexer_prefill_topk_chunks( positions, self.compress_ratio, + seq_lens_cpu=metadata.seq_lens_cpu, + query_lens_cpu=metadata.query_lens_cpu, ): if packed_index_q is not None and packed_weights is not None: with deepseek_v4_profile_scope("indexer_topk_deepgemm_prefill"): @@ -3342,6 +5412,14 @@ def run_compressor() -> None: topk_indices = None if self.indexer is not None: assert self.compressor is not None + with deepseek_v4_profile_scope( + f"{profile_prefix}_indexer_prepare_decode_metadata" + ): + self.indexer.prepare_decode_metadata( + positions=positions, + metadata=metadata, + indexer_block_size=pool.get_indexer_block_size(self.layer_index), + ) def run_indexer() -> torch.Tensor: with deepseek_v4_profile_scope(f"{profile_prefix}_indexer"): @@ -3381,12 +5459,38 @@ def insert_and_compress() -> None: "forward_deepseek_v4_decode", None, ) + backend_mixed = getattr( + ctx.attn_backend, + "forward_deepseek_v4_mixed", + None, + ) backend_prefill = getattr( ctx.attn_backend, "forward_deepseek_v4_prefill", None, ) if ( + backend_mixed is not None + and ctx.forward_mode is not 
None + and ctx.forward_mode.is_mixed() + ): + with deepseek_v4_profile_scope(f"{profile_prefix}_mixed_backend"): + attn_output = backend_mixed( + q=q, + positions=positions, + token_to_kv_pool=pool, + layer_id=self.layer_index, + kind=self.attention_kind, + compress_ratio=self.compress_ratio, + num_local_heads=self.num_local_heads, + padded_heads=self.padded_heads, + head_dim=self.head_dim, + window_size=self.layout.swa_window, + softmax_scale=self.scale, + attn_sink=self.attn_sink, + topk_indices=topk_indices, + ) + elif ( backend_decode is not None and ctx.forward_mode is not None and ctx.forward_mode.is_decode() @@ -3534,6 +5638,9 @@ def _pre_mlp_input_ids_comm( [tokens[:count] for tokens, count in zip(gathered, token_counts)], dim=0 ) + def _mega_moe_token_counts(self, ctx: ForwardContext) -> list[int]: + return self.comm_manager.moe_tp_ep_group_scattered_num_tokens(ctx) + def forward( self, positions: torch.Tensor, @@ -3564,11 +5671,7 @@ def forward( ffn_input_ids = input_ids use_mega_moe = getattr(self.ffn, "use_mega_moe", False) if use_mega_moe: - token_counts = ( - [int(count) for count in ctx.global_num_tokens] - if ctx.global_num_tokens is not None - else [int(hidden_states.shape[0])] - ) + token_counts = self._mega_moe_token_counts(ctx) num_global_tokens = sum(token_counts) max_num_tokens_per_gpu = max(token_counts) if token_counts else 0 else: diff --git a/python/tokenspeed/runtime/utils/custom_ops.py b/python/tokenspeed/runtime/utils/custom_ops.py new file mode 100644 index 000000000..1a9fb2d65 --- /dev/null +++ b/python/tokenspeed/runtime/utils/custom_ops.py @@ -0,0 +1,61 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the 
Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import annotations + +import importlib +from collections.abc import Callable + +import torch +import torch.library +from torch.library import Library + +tokenspeed_lib = Library("tokenspeed", "FRAGMENT") + + +def direct_register_custom_op( + op_name: str, + op_func: Callable, + mutates_args: list[str], + fake_impl: Callable | None = None, + target_lib: Library | None = None, +) -> None: + """Register a low-overhead torch custom op in the TokenSpeed namespace.""" + + target = target_lib or tokenspeed_lib + lib_name = getattr(getattr(target, "m", None), "name", "tokenspeed") + try: + if hasattr(torch.ops, lib_name) and hasattr( + getattr(torch.ops, lib_name), op_name + ): + return + except (AttributeError, RuntimeError): + pass + + if hasattr(torch.library, "infer_schema"): + schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) + else: + custom_op_impl = importlib.import_module("torch._custom_op.impl") + schema_str = custom_op_impl.infer_schema(op_func, mutates_args) + + target.define(op_name + schema_str) + target.impl(op_name, op_func, "CUDA") + if fake_impl is not None: + target._register_fake(op_name, fake_impl) diff --git a/python/tokenspeed/runtime/utils/server_args.py 
b/python/tokenspeed/runtime/utils/server_args.py index 3b8aba51d..80bcb3517 100755 --- a/python/tokenspeed/runtime/utils/server_args.py +++ b/python/tokenspeed/runtime/utils/server_args.py @@ -88,6 +88,7 @@ class ServerArgs: max_total_tokens: int | None = None chunked_prefill_size: int | None = None max_prefill_tokens: int = 8192 + enable_mixed_chunk: bool = True block_size: int = 64 # special kv cache mamba_ssm_dtype: str = "float32" @@ -814,6 +815,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.chunked_prefill_size, help="Maximum number of tokens the scheduler may issue in a single iteration. Setting this to -1 disables chunked prefill.", ) + parser.add_argument( + "--disable-mixed-chunk", + action="store_false", + dest="enable_mixed_chunk", + default=ServerArgs.enable_mixed_chunk, + help="Disallow the scheduler from issuing prefill and decode requests in the same iteration.", + ) parser.add_argument( "--block-size", metavar="BLOCK_SIZE", diff --git a/test/runtime/kernels/test_trtllm_wrapper.py b/test/runtime/kernels/test_trtllm_wrapper.py new file mode 100644 index 000000000..c458a9047 --- /dev/null +++ b/test/runtime/kernels/test_trtllm_wrapper.py @@ -0,0 +1,103 @@ +import unittest +from unittest.mock import patch + +import torch + + +class TRTLLMWrapperTest(unittest.TestCase): + def test_fast_topk_v2_decode_accepts_2d_lens(self): + from tokenspeed_kernel.registry import error_fn + from tokenspeed_kernel.thirdparty import trtllm + + if trtllm.fast_topk_v2 is None or trtllm.fast_topk_v2 is error_fn: + self.skipTest("TRTLLM fast_topk_v2 is unavailable on this platform") + + captured = {} + + def fake_indexer_topk_decode(values, seq_lens, indices, next_n, topk): + del values, indices + captured["seq_lens"] = seq_lens + captured["next_n"] = next_n + captured["topk"] = topk + + with patch.object( + torch.ops.trtllm, + "indexer_topk_decode", + fake_indexer_topk_decode, + create=True, + ): + values = torch.empty((2, 4), dtype=torch.float32) 
+ seq_lens = torch.tensor([[3], [4]], dtype=torch.int64) + indices = torch.empty((2, 2), dtype=torch.int32) + + trtllm.fast_topk_v2( + values, + seq_lens, + indices, + topk=2, + next_n=1, + ) + + self.assertEqual(captured["next_n"], 1) + self.assertEqual(captured["topk"], 2) + self.assertEqual(captured["seq_lens"].dtype, torch.int32) + self.assertEqual(captured["seq_lens"].dim(), 1) + torch.testing.assert_close( + captured["seq_lens"], + torch.tensor([3, 4], dtype=torch.int32), + atol=0, + rtol=0, + ) + + def test_fast_topk_v2_prefill_uses_int32_row_offsets(self): + from tokenspeed_kernel.registry import error_fn + from tokenspeed_kernel.thirdparty import trtllm + + if trtllm.fast_topk_v2 is None or trtllm.fast_topk_v2 is error_fn: + self.skipTest("TRTLLM fast_topk_v2 is unavailable on this platform") + + captured = {} + + def fake_indexer_topk_prefill(values, row_starts, row_ends, indices, topk): + del values, indices + captured["row_starts"] = row_starts + captured["row_ends"] = row_ends + captured["topk"] = topk + + with patch.object( + torch.ops.trtllm, + "indexer_topk_prefill", + fake_indexer_topk_prefill, + create=True, + ): + values = torch.empty((3, 4), dtype=torch.float32) + seq_lens = torch.tensor([[1], [2]], dtype=torch.int64) + indices = torch.empty((2, 2), dtype=torch.int32) + + trtllm.fast_topk_v2( + values, + seq_lens, + indices, + topk=2, + next_n=2, + ) + + self.assertEqual(captured["topk"], 2) + self.assertEqual(captured["row_starts"].dtype, torch.int32) + self.assertEqual(captured["row_ends"].dtype, torch.int32) + torch.testing.assert_close( + captured["row_starts"], + torch.tensor([0, 1], dtype=torch.int32), + atol=0, + rtol=0, + ) + torch.testing.assert_close( + captured["row_ends"], + torch.tensor([1, 3], dtype=torch.int32), + atol=0, + rtol=0, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/runtime/test_cli_config_compat.py b/test/runtime/test_cli_config_compat.py index 3cc53e480..1b47ab6d2 100644 --- 
a/test/runtime/test_cli_config_compat.py +++ b/test/runtime/test_cli_config_compat.py @@ -193,6 +193,7 @@ def test_prefill_token_defaults(self): args = self._parse_args(["--model", "test/model"]) self.assertEqual(args.max_prefill_tokens, 8192) self.assertIsNone(args.chunked_prefill_size) + self.assertTrue(args.enable_mixed_chunk) sa = self._from_cli_args_no_init(args) sa.mapping = SimpleNamespace(world_size=1) @@ -205,6 +206,11 @@ def test_prefill_token_defaults(self): self.assertEqual(sa.max_prefill_tokens, 8192) self.assertEqual(sa.chunked_prefill_size, 8192) + self.assertTrue(sa.enable_mixed_chunk) + + def test_mixed_chunk_can_be_disabled(self): + args = self._parse_args(["--model", "test/model", "--disable-mixed-chunk"]) + self.assertFalse(args.enable_mixed_chunk) def test_distributed_timeout_seconds_arg(self): args = self._parse_args( diff --git a/test/runtime/test_deepseek_v4_attention_ops.py b/test/runtime/test_deepseek_v4_attention_ops.py index 5c92d257c..34549e094 100644 --- a/test/runtime/test_deepseek_v4_attention_ops.py +++ b/test/runtime/test_deepseek_v4_attention_ops.py @@ -628,6 +628,91 @@ def test_indexer_mxfp4_cache_matches_reference(self): ) self.assertEqual(int(flat_cache[64:128].sum()), 0) + def test_indexer_mxfp4_paged_gather_matches_paged_layout(self): + from tokenspeed_kernel.thirdparty.cuda.deepseek_v4_attention import ( + has_indexer_mxfp4_paged_gather, + indexer_mxfp4_paged_gather, + ) + + if not has_indexer_mxfp4_paged_gather(): + self.skipTest("DeepSeek V4 paged MXFP4 gather op is not available") + + device = torch.device("cuda") + block_size = 4 + value_bytes = 64 + scale_bytes = 4 + num_blocks = 3 + kv_cache = torch.zeros( + num_blocks, + block_size * (value_bytes + scale_bytes), + device=device, + dtype=torch.uint8, + ) + + value_rows = {} + scale_rows = {} + for block_idx in range(num_blocks): + for row_idx in range(block_size): + values = ( + ( + torch.arange(value_bytes, device=device, dtype=torch.int16) + + block_idx * 37 + + 
row_idx * 11 + ) + .remainder(251) + .to(torch.uint8) + ) + scales = torch.tensor( + [block_idx, row_idx, block_idx * 17 + row_idx, 200 + block_idx], + device=device, + dtype=torch.uint8, + ) + value_base = row_idx * value_bytes + scale_base = block_size * value_bytes + row_idx * scale_bytes + kv_cache[block_idx, value_base : value_base + value_bytes].copy_(values) + kv_cache[block_idx, scale_base : scale_base + scale_bytes].copy_(scales) + value_rows[(block_idx, row_idx)] = values + scale_rows[(block_idx, row_idx)] = scales + + block_table = torch.tensor([[2, 0], [1, 0]], device=device, dtype=torch.int32) + cu_seq_lens = torch.tensor([0, 5, 7], device=device, dtype=torch.int32) + values_out = torch.full( + (8, value_bytes), 0xCC, device=device, dtype=torch.uint8 + ) + scales_out = torch.full( + (8, scale_bytes), 0xDD, device=device, dtype=torch.uint8 + ) + + indexer_mxfp4_paged_gather( + kv_cache, + values_out, + scales_out, + block_table, + cu_seq_lens, + block_size, + ) + torch.cuda.synchronize() + + expected_plan = [ + (2, 0), + (2, 1), + (2, 2), + (2, 3), + (0, 0), + (1, 0), + (1, 1), + ] + expected_values = torch.stack([value_rows[item] for item in expected_plan]) + expected_scales = torch.stack([scale_rows[item] for item in expected_plan]) + self.assertTrue(torch.equal(values_out[:7].cpu(), expected_values.cpu())) + self.assertTrue(torch.equal(scales_out[:7].cpu(), expected_scales.cpu())) + self.assertTrue( + torch.equal(values_out[7].cpu(), torch.full((64,), 0xCC, dtype=torch.uint8)) + ) + self.assertTrue( + torch.equal(scales_out[7].cpu(), torch.full((4,), 0xDD, dtype=torch.uint8)) + ) + def test_csa_indexer_cache_insert_matches_reference(self): torch.manual_seed(8901) device = torch.device("cuda") @@ -1082,6 +1167,46 @@ def test_decode_swa_indices_and_lens_matches_reference(self): compact_lens.cpu(), actual_lens.cpu(), atol=0, rtol=0 ) + def test_decode_swa_indices_and_lens_masks_invalid_tokens(self): + device = torch.device("cuda") + query_start_loc = 
torch.tensor([0, 1, 2], device=device, dtype=torch.int32) + seq_lens = torch.tensor([70, 3], device=device, dtype=torch.int32) + token_to_req_indices = torch.tensor([0, 1], device=device, dtype=torch.int32) + is_valid_token = torch.tensor([True, False], device=device) + block_table = torch.tensor( + [[10, 11], [20, 21]], + device=device, + dtype=torch.int32, + ) + out_indices = torch.full((2, 4), -123, device=device, dtype=torch.int32) + out_lens = torch.empty((2,), device=device, dtype=torch.int32) + + actual, actual_lens = deepseek_v4_decode_swa_indices_and_lens( + query_start_loc=query_start_loc, + seq_lens=seq_lens, + token_to_req_indices=token_to_req_indices, + block_table=block_table, + window_size=4, + block_size=64, + is_valid_token=is_valid_token, + out_indices=out_indices, + out_lens=out_lens, + ) + torch.cuda.synchronize() + + self.assertTrue( + torch.equal(actual_lens.cpu(), torch.tensor([4, 0], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal( + actual[0].cpu(), + torch.tensor([706, 707, 708, 709], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal(actual[1].cpu(), torch.full((4,), -123, dtype=torch.int32)) + ) + def test_compute_global_topk_indices_and_lens_matches_reference(self): device = torch.device("cuda") topk_indices = torch.tensor( @@ -1134,6 +1259,46 @@ def test_compute_global_topk_indices_and_lens_matches_reference(self): actual_lens.cpu(), expected_lens.cpu(), atol=0, rtol=0 ) + def test_compute_global_topk_indices_and_lens_masks_invalid_tokens(self): + device = torch.device("cuda") + topk_indices = torch.tensor( + [ + [0, 1, -1, 5], + [3, -1, -1, -1], + ], + device=device, + dtype=torch.int32, + ) + token_to_req_indices = torch.tensor([0, 1], device=device, dtype=torch.int32) + is_valid_token = torch.tensor([True, False], device=device) + block_table = torch.tensor( + [ + [10, 11], + [20, 21], + ], + device=device, + dtype=torch.int32, + ) + + actual, actual_lens = deepseek_v4_compute_global_topk_indices_and_lens( + 
topk_indices=topk_indices, + token_to_req_indices=token_to_req_indices, + block_table=block_table, + block_size=4, + is_valid_token=is_valid_token, + ) + torch.cuda.synchronize() + + self.assertTrue( + torch.equal(actual_lens.cpu(), torch.tensor([3, 0], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal( + actual[0].cpu(), + torch.tensor([40, 41, -1, 45], dtype=torch.int32), + ) + ) + def test_compressed_slot_mapping_matches_page_reference(self): device = torch.device("cuda") query_start_loc = torch.tensor([0, 3, 5], device=device, dtype=torch.int32) diff --git a/test/runtime/test_deepseek_v4_config.py b/test/runtime/test_deepseek_v4_config.py index 3bf2c1715..c539a623e 100644 --- a/test/runtime/test_deepseek_v4_config.py +++ b/test/runtime/test_deepseek_v4_config.py @@ -12,12 +12,17 @@ configure_deepseek_v4_attention, is_deepseek_v4, ) +from tokenspeed.runtime.execution.cuda_graph_wrapper import CudaGraphWrapper from tokenspeed.runtime.execution.forward_batch_info import ForwardMode +from tokenspeed.runtime.layers.attention.backends import ( + deepseek_v4 as deepseek_v4_backend, +) from tokenspeed.runtime.layers.attention.backends.deepseek_v4 import ( DeepseekV4AttentionBackend, ) from tokenspeed.runtime.layers.attention.deepseek_v4_ops import ( DeepseekV4AttentionOpUnavailable, + deepseek_v4_compute_global_topk_indices_and_lens, deepseek_v4_indexer_topk_reference, fused_qnorm_rope_kv_insert, has_fused_qnorm_rope_kv_insert, @@ -33,21 +38,26 @@ _get_flashinfer_mxfp4_device_permute_indices, _reorder_w1w3_to_w3w1, ) -from tokenspeed.runtime.layers.moe.backends.mxfp4.triton_kernel import ( - _mxfp4_scale_for_layout, -) -from tokenspeed.runtime.layers.moe.backends.mxfp4.weights import MXFP4_SCALE_DTYPE from tokenspeed.runtime.layers.quantization import QUANTIZATION_METHODS +from tokenspeed.runtime.models import deepseek_v4 as deepseek_v4_model from tokenspeed.runtime.models.deepseek_v4 import ( DeepseekV4Attention, + DeepseekV4Indexer, + DeepseekV4MLP, 
DeepseekV4MoEGate, _deepseek_v4_fused_select_experts, _deepseek_v4_gather_indexer_mxfp4_cache, _deepseek_v4_get_fp8_linear_deep_gemm, _deepseek_v4_indexer_decode_max_len, + _deepseek_v4_indexer_prefill_gather_plan, + _deepseek_v4_indexer_prefill_max_logits_bytes, + _deepseek_v4_indexer_prefill_metadata, + _deepseek_v4_indexer_prefill_request_chunks, + _deepseek_v4_indexer_prefill_request_gather_plan, _deepseek_v4_indexer_prefill_topk_chunks, _deepseek_v4_indexer_topk_from_cache_batched, _deepseek_v4_indexer_topk_from_logits, + _deepseek_v4_prefill_topk_op_available, _deepseek_v4_reorder_c4_ape_2604, _DeepseekV4TopKBuffer, _fp8_act_quant_dequant, @@ -80,6 +90,62 @@ def test_config_registry(self): self.assertEqual(DeepseekV4Config.model_type, "deepseek_v4") self.assertIs(_CONFIG_REGISTRY["deepseek_v4"], DeepseekV4Config) + def test_forward_mode_mixed_predicate(self): + self.assertTrue(ForwardMode.MIXED.is_mixed()) + self.assertFalse(ForwardMode.EXTEND.is_mixed()) + self.assertFalse(ForwardMode.DECODE.is_mixed()) + self.assertEqual(ForwardMode.from_num_extends(0, 0), ForwardMode.IDLE) + self.assertEqual(ForwardMode.from_num_extends(0, 2), ForwardMode.DECODE) + self.assertEqual( + ForwardMode.from_num_extends(0, 2, has_drafter=True), + ForwardMode.TARGET_VERIFY, + ) + self.assertEqual(ForwardMode.from_num_extends(2, 2), ForwardMode.EXTEND) + self.assertEqual(ForwardMode.from_num_extends(1, 2), ForwardMode.MIXED) + + def test_cuda_graph_group_table_padding_uses_dummy_page_rows(self): + table = torch.tensor([[5, -1]], dtype=torch.int32) + padded = CudaGraphWrapper._pad_block_tables_to_padded_bs( + {"v4.swa": table}, + actual_bs=1, + padded_bs=3, + ) + + self.assertEqual(padded["v4.swa"].tolist(), [[5, -1], [0, 0], [0, 0]]) + + def test_cuda_graph_replay_keeps_idle_actual_bs_with_padded_group_tables(self): + captured = {} + + class FakeBackend: + uses_paged_cache_groups = True + uses_padded_decode_token_mask = True + + def init_forward_metadata_replay_cuda_graph(self, 
*args, **kwargs): + captured["args"] = args + captured["kwargs"] = kwargs + + wrapper = object.__new__(CudaGraphWrapper) + wrapper.attn_backend = FakeBackend() + wrapper.draft_attn_backend = None + + wrapper._init_replay_metadata( + padded_bs=4, + actual_bs=0, + req_pool_indices=torch.zeros(4, dtype=torch.int32), + seq_lens=torch.ones(4, dtype=torch.int32), + req_to_page=torch.zeros((1, 1), dtype=torch.int32), + forward_mode=ForwardMode.DECODE, + paged_cache_block_tables={ + "v4.swa": torch.zeros((4, 1), dtype=torch.int32), + }, + ) + + self.assertEqual(captured["kwargs"]["actual_bs"], 0) + self.assertEqual( + captured["kwargs"]["paged_cache_block_tables"]["v4.swa"].shape, + (4, 1), + ) + def test_deepseek_v4_tokenizer_wrapper_uses_model_encoder(self): calls = [] @@ -219,6 +285,19 @@ def test_deepseek_v4_server_args_cli_flags_round_trip(self): global_server_args_dict.clear() global_server_args_dict.update(snapshot) + def test_deepseek_v4_indexer_prefill_max_logits_uses_server_arg(self): + snapshot = dict(global_server_args_dict) + try: + global_server_args_dict["deepseek_v4_indexer_prefill_max_logits_mb"] = 7 + + self.assertEqual( + _deepseek_v4_indexer_prefill_max_logits_bytes(), + 7 * 1024 * 1024, + ) + finally: + global_server_args_dict.clear() + global_server_args_dict.update(snapshot) + def test_fp8_quantization_config(self): quantization = QUANTIZATION_METHODS["fp8"] @@ -426,6 +505,27 @@ def test_deepseek_v4_attention_op_boundary_fails_loudly_when_missing(self): q, kv, cache, slots, positions, cos_sin, 1e-6, 256 ) + def test_deepseek_v4_flashmla_wrapper_exposes_required_api(self): + try: + from tokenspeed_kernel.ops.attention.flash_mla import ( + flash_mla_sparse_fwd, + flash_mla_with_kvcache, + get_mla_metadata, + ) + from tokenspeed_kernel.registry import error_fn + except Exception as exc: + self.skipTest(f"FlashMLA wrapper unavailable: {exc}") + if ( + flash_mla_with_kvcache is error_fn + or flash_mla_sparse_fwd is error_fn + or get_mla_metadata is 
error_fn + ): + self.skipTest("FlashMLA wrapper unavailable on this platform") + + self.assertTrue(callable(flash_mla_with_kvcache)) + self.assertTrue(callable(flash_mla_sparse_fwd)) + self.assertTrue(callable(get_mla_metadata)) + def test_deepseek_v4_model_config_uses_mla_runtime_metadata(self): model_config = object.__new__(ModelConfig) model_config.hf_config = SimpleNamespace( @@ -539,7 +639,7 @@ def test_deepseek_v4_kv_pool_allocates_v4_cache_families(self): use_fp4_indexer_cache=True, ) - self.assertEqual(layout.cache_cell_size(3), 17329) + self.assertEqual(layout.cache_cell_size(3), 16771) pool = DeepseekV4TokenToKVPool( size=128, @@ -675,6 +775,45 @@ def test_deepseek_v4_backend_preserves_compact_paged_cache_contract(self): self.assertTrue(torch.equal(metadata.swa_block_table, compact)) self.assertTrue(torch.equal(metadata.swa_base_logical_page, base)) + def test_deepseek_v4_mixed_metadata_keeps_decode_rows_single_token(self): + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=64, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=512, + context_len=4096, + ) + ) + + backend.init_forward_metadata( + bs=3, + num_tokens=10, + req_pool_indices=torch.tensor([0, 1, 2], dtype=torch.int64), + seq_lens=torch.tensor([7, 10, 4], dtype=torch.int32), + forward_mode=ForwardMode.MIXED, + req_to_page=torch.zeros((3, 1), dtype=torch.int32), + extend_seq_lens_cpu=torch.tensor([7], dtype=torch.int32), + num_extends=1, + ) + + metadata = backend.forward_metadata + self.assertIsNotNone(metadata) + assert metadata is not None + self.assertEqual(metadata.query_lens.tolist(), [7, 1, 1]) + self.assertEqual(metadata.query_lens_cpu.tolist(), [7, 1, 1]) + self.assertEqual(metadata.num_prefill_reqs, 1) + self.assertEqual(metadata.num_prefill_tokens, 7) + self.assertEqual(metadata.decode_req_count(), 2) + self.assertEqual(metadata.decode_token_count(), 2) + self.assertEqual( + 
metadata.token_to_req_indices.tolist(), + [0, 0, 0, 0, 0, 0, 0, 1, 2], + ) + def test_deepseek_v4_cuda_graph_refresh_keeps_compact_table_columns(self): backend = DeepseekV4AttentionBackend( SimpleNamespace( @@ -918,7 +1057,7 @@ def test_deepseek_v4_metadata_maps_compressed_slots(self): torch.tensor([3, 7, 127], dtype=torch.int64), compress_ratio=4, ) - self.assertTrue(torch.equal(slots, torch.tensor([0, 1, 31]))) + self.assertTrue(torch.equal(slots, torch.tensor([640, 641, 671]))) page256_metadata = DeepseekV4ForwardMetadata( page_size=256, @@ -936,6 +1075,271 @@ def test_deepseek_v4_metadata_maps_compressed_slots(self): ) self.assertTrue(torch.equal(slots, torch.tensor([383, 384, 447]))) + grouped_metadata = DeepseekV4ForwardMetadata( + page_size=256, + req_pool_indices=torch.tensor([0, 1], dtype=torch.int32), + block_table=torch.tensor([[5, 6], [7, 8]], dtype=torch.int32), + seq_lens=torch.tensor([300, 10], dtype=torch.int32), + query_lens=torch.tensor([3, 2], dtype=torch.int32), + query_start_loc=torch.tensor([0, 3, 5], dtype=torch.int32), + token_to_req_indices=torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32), + paged_cache_block_tables={ + "v4.c4a.compressed_kv": torch.tensor( + [[20, 21], [30, -1]], dtype=torch.int32 + ) + }, + ) + slots = grouped_metadata.compressed_slot_mapping( + torch.tensor([255, 256, 511, 2560, 4], dtype=torch.int64), + compress_ratio=4, + kv_cache_block_size=64, + ) + self.assertTrue(torch.equal(slots, torch.tensor([1343, 1344, 1407, -1, 1921]))) + + def test_deepseek_v4_group_slot_mapping_from_raw(self): + block_table = torch.tensor([[10, 11], [20, -1]], dtype=torch.int32) + slots = _group_slot_mapping_from_raw( + positions=torch.tensor([0, 63, 64, 9, 10], dtype=torch.int64), + req_indices=torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32), + block_table=block_table, + rows_per_page=64, + entry_stride_tokens=1, + ) + self.assertTrue(torch.equal(slots, torch.tensor([640, 703, 704, 1289, 1290]))) + + compressed_slots = 
_group_slot_mapping_from_raw( + positions=torch.tensor([0, 255, 256, 511], dtype=torch.int64), + req_indices=torch.tensor([0, 0, 0, 1], dtype=torch.int32), + block_table=block_table, + rows_per_page=64, + entry_stride_tokens=4, + ) + self.assertTrue( + torch.equal(compressed_slots, torch.tensor([640, 703, 704, -1])) + ) + + def test_deepseek_v4_mixed_metadata_splits_prefill_and_decode(self): + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=8, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=576, + context_len=256, + ) + ) + backend.init_forward_metadata( + bs=3, + num_tokens=5, + req_pool_indices=torch.tensor([0, 1, 2], dtype=torch.int32), + seq_lens=torch.tensor([5, 9, 12], dtype=torch.int32), + forward_mode=ForwardMode.MIXED, + req_to_page=torch.tensor([[10], [20], [30]], dtype=torch.int32), + extend_seq_lens_cpu=torch.tensor([3, 1, 1], dtype=torch.int32), + extend_prefix_lens_cpu=torch.tensor([2, 8, 11], dtype=torch.int32), + num_extends=1, + ) + metadata = backend.forward_metadata + self.assertIsNotNone(metadata) + self.assertEqual(metadata.num_prefill_reqs, 1) + self.assertEqual(metadata.num_prefill_tokens, 3) + self.assertEqual(metadata.decode_req_count(), 2) + self.assertEqual(metadata.decode_token_count(), 2) + self.assertTrue( + torch.equal( + metadata.token_to_req_indices, + torch.tensor([0, 0, 0, 1, 2], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + metadata.seq_lens_cpu, + torch.tensor([5, 9, 12], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + metadata.query_lens_cpu, + torch.tensor([3, 1, 1], dtype=torch.int32), + ) + ) + + prefill = backend._metadata_slice( + metadata, + req_start=0, + req_end=1, + token_start=0, + token_end=3, + forward_mode=ForwardMode.EXTEND, + ) + decode = backend._metadata_slice( + metadata, + req_start=1, + req_end=3, + token_start=3, + token_end=5, + forward_mode=ForwardMode.DECODE, + ) + + 
self.assertTrue(prefill.forward_mode.is_extend()) + self.assertTrue(decode.forward_mode.is_decode()) + self.assertTrue( + torch.equal(prefill.token_to_req_indices, torch.tensor([0, 0, 0])) + ) + self.assertTrue(torch.equal(decode.token_to_req_indices, torch.tensor([0, 1]))) + self.assertTrue( + torch.equal( + decode.query_start_loc, torch.tensor([0, 1, 2], dtype=torch.int32) + ) + ) + self.assertTrue(torch.equal(decode.block_table[:, 0], torch.tensor([20, 30]))) + self.assertTrue( + torch.equal(prefill.seq_lens_cpu, torch.tensor([5], dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(decode.query_lens_cpu, torch.tensor([1, 1], dtype=torch.int32)) + ) + + def test_deepseek_v4_mixed_metadata_accepts_prefill_prefix_lens_only(self): + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=8, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=576, + context_len=256, + ) + ) + backend.init_forward_metadata( + bs=4, + num_tokens=8, + req_pool_indices=torch.tensor([0, 1, 2, 3], dtype=torch.int32), + seq_lens=torch.tensor([5, 9, 12, 6], dtype=torch.int32), + forward_mode=ForwardMode.MIXED, + req_to_page=torch.tensor([[10], [20], [30], [40]], dtype=torch.int32), + extend_seq_lens_cpu=torch.tensor([3, 4, 1, 1], dtype=torch.int32), + extend_prefix_lens_cpu=torch.tensor([2, 5, 11], dtype=torch.int32), + num_extends=3, + ) + + metadata = backend.forward_metadata + self.assertIsNotNone(metadata) + self.assertEqual(metadata.num_prefill_reqs, 3) + self.assertEqual(metadata.num_prefill_tokens, 8) + self.assertEqual(metadata.decode_req_count(), 1) + self.assertEqual(metadata.decode_token_count(), 1) + self.assertTrue( + torch.equal( + metadata.seq_lens_cpu, + torch.tensor([5, 9, 12, 6], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + metadata.query_lens_cpu, + torch.tensor([3, 4, 1, 1], dtype=torch.int32), + ) + ) + + def test_deepseek_v4_mixed_backend_slices_prefill_and_decode(self): + 
backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=8, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=576, + context_len=256, + ) + ) + backend.init_forward_metadata( + bs=3, + num_tokens=5, + req_pool_indices=torch.tensor([0, 1, 2], dtype=torch.int32), + seq_lens=torch.tensor([5, 9, 12], dtype=torch.int32), + forward_mode=ForwardMode.MIXED, + req_to_page=torch.tensor([[10], [20], [30]], dtype=torch.int32), + extend_seq_lens_cpu=torch.tensor([3, 1, 1], dtype=torch.int32), + num_extends=1, + ) + calls = [] + + def fake_prefill(**kwargs): + metadata = backend.forward_metadata + calls.append( + ( + "prefill", + kwargs["q"].shape[0], + kwargs["positions"].tolist(), + kwargs["topk_indices"].tolist(), + metadata.req_pool_indices.tolist(), + metadata.token_to_req_indices.tolist(), + metadata.forward_mode, + ) + ) + return kwargs["q"].new_full((3, 2, 4), 1.0) + + def fake_decode(**kwargs): + metadata = backend.forward_metadata + calls.append( + ( + "decode", + kwargs["q"].shape[0], + kwargs["positions"].tolist(), + kwargs["topk_indices"].tolist(), + metadata.req_pool_indices.tolist(), + metadata.token_to_req_indices.tolist(), + metadata.forward_mode, + ) + ) + return kwargs["q"].new_full((2, 2, 4), 2.0) + + backend.forward_deepseek_v4_prefill = fake_prefill + backend.forward_deepseek_v4_decode = fake_decode + q = torch.zeros((5, 2, 4), dtype=torch.float32) + topk = torch.arange(10, dtype=torch.int32).view(5, 2) + out = backend.forward_deepseek_v4_mixed( + q=q, + positions=torch.arange(5, dtype=torch.int32), + token_to_kv_pool=SimpleNamespace(), + layer_id=0, + kind="mla", + compress_ratio=4, + num_local_heads=2, + padded_heads=2, + head_dim=4, + window_size=4, + softmax_scale=1.0, + attn_sink=torch.zeros(2), + topk_indices=topk, + ) + + self.assertEqual(len(calls), 2) + self.assertEqual(calls[0][0], "prefill") + self.assertEqual(calls[0][1], 3) + self.assertEqual(calls[0][2], [0, 1, 2]) + 
self.assertEqual(calls[0][3], [[0, 1], [2, 3], [4, 5]]) + self.assertEqual(calls[0][4], [0]) + self.assertEqual(calls[0][5], [0, 0, 0]) + self.assertTrue(calls[0][6].is_extend()) + self.assertEqual(calls[1][0], "decode") + self.assertEqual(calls[1][1], 2) + self.assertEqual(calls[1][2], [3, 4]) + self.assertEqual(calls[1][3], [[6, 7], [8, 9]]) + self.assertEqual(calls[1][4], [1, 2]) + self.assertEqual(calls[1][5], [0, 1]) + self.assertTrue(calls[1][6].is_decode()) + self.assertTrue(torch.equal(out[:3], torch.ones((3, 2, 4)))) + self.assertTrue(torch.equal(out[3:], torch.full((2, 2, 4), 2.0))) + def test_deepseek_v4_decode_backend_maps_compressed_slots_batched(self): backend = DeepseekV4AttentionBackend( SimpleNamespace( @@ -946,7 +1350,7 @@ def test_deepseek_v4_decode_backend_maps_compressed_slots_batched(self): attn_tp_size=1, dtype=torch.bfloat16, head_dim=512, - context_len=4096, + context_len=128, ) ) seq_lens = torch.tensor([70, 3], dtype=torch.int32) @@ -964,7 +1368,7 @@ def test_deepseek_v4_decode_backend_maps_compressed_slots_batched(self): [[1, 65, 3, -1], [0, -1, -1, -1]], dtype=torch.int32, ) - indices, lens = backend._decode_compressed_indices_and_lens( + indices, lens = backend._decode_compressed_attention_indices_and_lens( positions, compress_ratio=4, block_size=64, @@ -993,8 +1397,9 @@ def test_deepseek_v4_decode_backend_maps_compressed_slots_batched(self): dtype=torch.int32, ), ) - indices, lens = backend._decode_compressed_indices_and_lens( - seq_lens.to(torch.int64) - 1, + hca_positions = seq_lens.to(torch.int64) - 1 + indices, lens = backend._decode_compressed_attention_indices_and_lens( + hca_positions, compress_ratio=128, block_size=64, topk_indices=None, @@ -1006,6 +1411,79 @@ def test_deepseek_v4_decode_backend_maps_compressed_slots_batched(self): torch.tensor([[640, 641], [1280, -1]], dtype=torch.int32), ) ) + cached_indices, cached_lens = ( + backend._decode_compressed_attention_indices_and_lens( + hca_positions, + compress_ratio=128, + 
block_size=64, + topk_indices=None, + ) + ) + self.assertEqual(cached_indices.data_ptr(), indices.data_ptr()) + self.assertEqual(cached_lens.data_ptr(), lens.data_ptr()) + + def test_deepseek_v4_decode_backend_capture_ignores_warmup_cache(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA is required for capture cache semantics") + device = torch.device("cuda") + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cuda", + num_attention_heads=64, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=512, + context_len=128, + ) + ) + seq_lens = torch.tensor([128, 64], device=device, dtype=torch.int32) + backend.init_forward_metadata( + bs=2, + num_tokens=2, + req_pool_indices=torch.tensor([0, 1], device=device, dtype=torch.int64), + seq_lens=seq_lens, + forward_mode=ForwardMode.DECODE, + req_to_page=torch.tensor( + [[10, 11], [20, 21]], + device=device, + dtype=torch.int32, + ), + ) + positions = seq_lens.to(torch.int64) - 1 + + warmup_indices, _ = backend._decode_compressed_attention_indices_and_lens( + positions, + compress_ratio=128, + block_size=64, + topk_indices=None, + ) + metadata = backend.forward_metadata + key = next(iter(metadata.decode_dense_compressed_indices_cache.keys())) + metadata.decode_dense_compressed_indices_capture_safe_keys.clear() + + original_capturing = torch.cuda.is_current_stream_capturing + torch.cuda.is_current_stream_capturing = lambda: True + try: + capture_indices, _ = backend._decode_compressed_attention_indices_and_lens( + positions, + compress_ratio=128, + block_size=64, + topk_indices=None, + ) + reused_indices, _ = backend._decode_compressed_attention_indices_and_lens( + positions, + compress_ratio=128, + block_size=64, + topk_indices=None, + ) + finally: + torch.cuda.is_current_stream_capturing = original_capturing + + self.assertNotEqual(capture_indices.data_ptr(), warmup_indices.data_ptr()) + self.assertEqual(reused_indices.data_ptr(), capture_indices.data_ptr()) 
+ self.assertIn(key, metadata.decode_dense_compressed_indices_capture_safe_keys) def test_deepseek_v4_c128a_prefill_local_compressed_indices_contract(self): backend = DeepseekV4AttentionBackend( @@ -1130,6 +1608,264 @@ def test_deepseek_v4_indexer_mxfp4_gather_reuses_workspace(self): self.assertTrue(torch.equal(values, expected_values)) self.assertTrue(torch.equal(scales, expected_scales)) + def test_deepseek_v4_decode_backend_masks_padding_tokens(self): + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=64, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=512, + context_len=128, + ) + ) + seq_lens = torch.tensor([70, 3], dtype=torch.int32) + backend.init_forward_metadata( + bs=2, + num_tokens=2, + req_pool_indices=torch.tensor([0, 1], dtype=torch.int64), + seq_lens=seq_lens, + forward_mode=ForwardMode.DECODE, + req_to_page=torch.tensor([[10, 11], [20, 21]], dtype=torch.int32), + ) + metadata = backend.forward_metadata + metadata.is_valid_token = torch.tensor([True, False]) + positions = seq_lens.to(torch.int64) - 1 + + topk_indices = torch.tensor( + [[1, 65, 3, -1], [0, -1, -1, -1]], + dtype=torch.int32, + ) + _, csa_lens = backend._decode_compressed_attention_indices_and_lens( + positions, + compress_ratio=4, + block_size=64, + topk_indices=topk_indices, + ) + _, hca_lens = backend._decode_compressed_attention_indices_and_lens( + torch.tensor([255, 128], dtype=torch.int64), + compress_ratio=128, + block_size=64, + topk_indices=None, + ) + + self.assertTrue(torch.equal(csa_lens, torch.tensor([3, 0], dtype=torch.int32))) + self.assertTrue(torch.equal(hca_lens, torch.tensor([2, 0], dtype=torch.int32))) + + def test_deepseek_v4_global_topk_cpu_masks_invalid_req_before_indexing(self): + indices, lens = deepseek_v4_compute_global_topk_indices_and_lens( + topk_indices=torch.tensor([[0, 4], [0, 1]], dtype=torch.int32), + token_to_req_indices=torch.tensor([0, 99], dtype=torch.int32), + 
block_table=torch.tensor([[10]], dtype=torch.int32), + block_size=4, + is_valid_token=torch.tensor([True, False]), + ) + + self.assertTrue( + torch.equal( + indices, + torch.tensor([[40, -1], [-1, -1]], dtype=torch.int32), + ) + ) + self.assertTrue(torch.equal(lens, torch.tensor([1, 0], dtype=torch.int32))) + + def test_deepseek_v4_cuda_graph_replay_marks_padding_tokens_invalid(self): + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=64, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=512, + context_len=128, + ) + ) + backend.init_cuda_graph_state(max_bs=4) + backend.init_forward_metadata_capture_cuda_graph( + bs=4, + num_tokens=4, + req_pool_indices=torch.arange(4, dtype=torch.int32), + seq_lens=torch.ones(4, dtype=torch.int32), + forward_mode=ForwardMode.DECODE, + ) + + backend.init_forward_metadata_replay_cuda_graph( + bs=4, + actual_bs=2, + req_pool_indices=torch.arange(4, dtype=torch.int32), + seq_lens=torch.tensor([70, 3, 1, 1], dtype=torch.int32), + forward_mode=ForwardMode.DECODE, + req_to_page=torch.tensor( + [ + [10, 11], + [20, 21], + [30, 31], + [40, 41], + ], + dtype=torch.int32, + ), + ) + + metadata = backend.forward_metadata + self.assertTrue( + torch.equal( + metadata.is_valid_token, + torch.tensor([True, True, False, False]), + ) + ) + self.assertEqual(metadata.decode_token_count(), 4) + + def test_deepseek_v4_indexer_metadata_refresh_masks_padding_tokens(self): + key = (4, 4, 3) + metadata = DeepseekV4ForwardMetadata( + page_size=64, + req_pool_indices=torch.tensor([0, 1, 2], dtype=torch.int32), + block_table=torch.tensor([[10, 11], [20, 21], [30, 31]], dtype=torch.int32), + seq_lens=torch.tensor([9, 5, 3], dtype=torch.int32), + query_lens=torch.ones(3, dtype=torch.int32), + query_start_loc=torch.tensor([0, 1, 2, 3], dtype=torch.int32), + token_to_req_indices=torch.tensor([0, 1, 2], dtype=torch.int32), + is_valid_token=torch.tensor([True, False, True]), + 
forward_mode=ForwardMode.DECODE, + ) + plan = SimpleNamespace( + context_lens=torch.empty((3, 1), dtype=torch.int32), + block_table=torch.empty((3, 2), dtype=torch.int32), + max_context_len=0, + ) + metadata.decode_indexer_plan_cache[key] = plan + + def fake_compute(**kwargs): + kwargs["out_context_lens"].copy_( + torch.tensor([[2], [2], [1]], dtype=torch.int32) + ) + kwargs["out_block_tables"].copy_( + torch.tensor([[10, 11], [20, 21], [30, 31]], dtype=torch.int32) + ) + + with patch.object( + deepseek_v4_backend, + "deepseek_v4_indexer_decode_metadata_compute", + side_effect=fake_compute, + ): + deepseek_v4_backend._refresh_decode_indexer_plan_cache( + metadata, + max_context_len=256, + ) + + self.assertTrue( + torch.equal( + plan.context_lens, + torch.tensor([[2], [0], [1]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + plan.block_table, + torch.tensor([[10, 11], [0, 0], [30, 31]], dtype=torch.int32), + ) + ) + + def test_deepseek_v4_indexer_decode_metadata_accepts_sliced_valid_mask(self): + metadata = SimpleNamespace( + decode_indexer_plan_cache={}, + decode_indexer_plan_refreshed_keys=set(), + ) + + def fake_compute(**kwargs): + kwargs["out_context_lens"].copy_( + torch.tensor([[2], [2]], dtype=torch.int32) + ) + kwargs["out_block_tables"].copy_( + torch.tensor([[10], [20]], dtype=torch.int32) + ) + + with patch.object( + deepseek_v4_model, + "deepseek_v4_indexer_decode_metadata_compute", + side_effect=fake_compute, + ): + plan = deepseek_v4_model._deepseek_v4_indexer_decode_metadata( + positions=torch.tensor([8, 4], dtype=torch.int64), + token_to_req_indices=torch.tensor([0, 1], dtype=torch.int32), + block_table=torch.tensor([[10, 11], [20, 21]], dtype=torch.int32), + cache_block_size=4, + compress_ratio=4, + metadata=metadata, + is_valid_token=torch.tensor([False, True]), + ) + + self.assertTrue( + torch.equal( + plan.context_lens, + torch.tensor([[0], [2]], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + plan.block_table, + 
torch.tensor([[0], [20]], dtype=torch.int32), + ) + ) + + def test_deepseek_v4_indexer_schedule_refresh_uses_decode_plan_lens(self): + captured = {} + + def fake_get_metadata(context_lens, cache_block_size, num_sms): + captured["context_lens"] = context_lens.clone() + captured["cache_block_size"] = cache_block_size + captured["num_sms"] = num_sms + return torch.full((2, 1), 9, dtype=torch.int32) + + fake_deep_gemm = SimpleNamespace( + get_paged_mqa_logits_metadata=fake_get_metadata, + get_num_sms=lambda: 123, + ) + key = (4, 4, 2) + metadata = DeepseekV4ForwardMetadata( + page_size=64, + req_pool_indices=torch.tensor([0, 1], dtype=torch.int32), + block_table=torch.tensor([[0], [0]], dtype=torch.int32), + seq_lens=torch.tensor([5, 1], dtype=torch.int32), + query_lens=torch.tensor([1, 1], dtype=torch.int32), + query_start_loc=torch.tensor([0, 1, 2], dtype=torch.int32), + token_to_req_indices=torch.tensor([0, 1], dtype=torch.int32), + is_valid_token=torch.tensor([True, False]), + forward_mode=ForwardMode.DECODE, + ) + metadata.decode_indexer_plan_cache[key] = SimpleNamespace( + context_lens=torch.zeros((2, 1), dtype=torch.int32), + ) + metadata.decode_indexer_schedule_metadata[key] = torch.zeros( + (2, 1), + dtype=torch.int32, + ) + + with patch( + "tokenspeed_kernel.thirdparty.deep_gemm", + fake_deep_gemm, + create=True, + ): + deepseek_v4_backend._refresh_decode_indexer_schedule_metadata(metadata) + + self.assertTrue( + torch.equal( + captured["context_lens"], torch.zeros((2, 1), dtype=torch.int32) + ) + ) + self.assertEqual(captured["cache_block_size"], 4) + self.assertEqual(captured["num_sms"], 123) + self.assertTrue( + torch.equal( + metadata.decode_indexer_schedule_metadata[key], + torch.full((2, 1), 9, dtype=torch.int32), + ) + ) + def test_deepseek_v4_indexer_decode_batches_cache_reads(self): torch.manual_seed(0) positions = torch.tensor([15, 7, 3], dtype=torch.int64) @@ -1220,6 +1956,104 @@ def test_deepseek_v4_indexer_topk_reuses_output_buffer(self): 
self.assertTrue(torch.equal(actual[0].sort().values, torch.tensor([1, 2]))) self.assertTrue(torch.equal(actual[1].sort().values, torch.tensor([0, 3]))) + def test_deepseek_v4_indexer_topk_accepts_decode_lens_shape(self): + logits = torch.tensor( + [ + [0.0, 3.0, 1.0, -float("inf")], + [4.0, 1.0, 2.0, 3.0], + ], + dtype=torch.float32, + ) + lengths = torch.tensor([[3], [4]], dtype=torch.int32) + + actual = _deepseek_v4_indexer_topk_from_logits( + logits, + lengths, + topk_tokens=2, + next_n=1, + ) + + self.assertEqual(actual.shape, (2, 2)) + self.assertTrue(torch.equal(actual[0].sort().values, torch.tensor([1, 2]))) + self.assertTrue(torch.equal(actual[1].sort().values, torch.tensor([0, 3]))) + + def test_deepseek_v4_indexer_topk_can_sort_preserved_order(self): + logits = torch.tensor( + [ + [0.0, 3.0, 1.0, -float("inf")], + [4.0, 1.0, 2.0, 3.0], + ], + dtype=torch.float32, + ) + lengths = torch.tensor([3, 4], dtype=torch.int32) + + actual = _deepseek_v4_indexer_topk_from_logits( + logits, + lengths, + topk_tokens=4, + preserve_topk_order=True, + sort_preserved_topk=True, + ) + + self.assertTrue(torch.equal(actual[0], torch.tensor([1, 2, 0, -1]))) + self.assertTrue(torch.equal(actual[1], torch.tensor([0, 3, 2, 1]))) + + def test_deepseek_v4_indexer_topk_handles_shifted_prefill_rows(self): + logits = torch.tensor( + [ + [0.0, 3.0, 1.0, -float("inf"), -float("inf"), -float("inf")], + [-float("inf"), -float("inf"), -float("inf"), 2.0, 8.0, 5.0], + ], + dtype=torch.float32, + ) + row_starts = torch.tensor([0, 3], dtype=torch.int32) + row_ends = torch.tensor([3, 6], dtype=torch.int32) + lengths = row_ends - row_starts + + actual = _deepseek_v4_indexer_topk_from_logits( + logits, + lengths, + topk_tokens=3, + preserve_topk_order=True, + sort_preserved_topk=True, + row_starts=row_starts, + row_ends=row_ends, + ) + + self.assertTrue(torch.equal(actual[0], torch.tensor([1, 2, 0]))) + self.assertTrue(torch.equal(actual[1], torch.tensor([1, 2, 0]))) + + 
@unittest.skipUnless(torch.cuda.is_available(), "CUDA is required") + def test_deepseek_v4_indexer_topk_uses_local_prefill_op(self): + if not _deepseek_v4_prefill_topk_op_available(): + self.skipTest("TRT-LLM indexer_topk_prefill is unavailable") + + logits = torch.tensor( + [ + [0.0, 3.0, 1.0, -float("inf"), -float("inf"), -float("inf")], + [-float("inf"), -float("inf"), -float("inf"), 2.0, 8.0, 5.0], + ], + device="cuda", + dtype=torch.float32, + ) + row_starts = torch.tensor([0, 3], device="cuda", dtype=torch.int32) + row_ends = torch.tensor([3, 6], device="cuda", dtype=torch.int32) + + actual = _deepseek_v4_indexer_topk_from_logits( + logits, + row_ends - row_starts, + topk_tokens=4, + preserve_topk_order=True, + row_starts=row_starts, + row_ends=row_ends, + ) + + expected = torch.tensor( + [[0, 1, 2, -1], [0, 1, 2, -1]], + dtype=torch.int32, + ) + self.assertTrue(torch.equal(actual.cpu(), expected)) + def test_deepseek_v4_topk_buffer_grows_and_reuses(self): buffer = _DeepseekV4TopKBuffer(topk_tokens=3) @@ -1233,6 +2067,162 @@ def test_deepseek_v4_topk_buffer_grows_and_reuses(self): self.assertEqual(third.shape, (4, 3)) self.assertGreaterEqual(buffer.buffer.shape[0], 4) + def test_deepseek_v4_sparse_indexer_custom_op_registered(self): + self.assertTrue( + hasattr(torch.ops.tokenspeed, "deepseek_v4_sparse_attn_indexer") + ) + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA is required") + def test_deepseek_v4_sparse_indexer_custom_op_fallback_covers_decode_tokens(self): + device = torch.device("cuda") + n_head = 2 + head_dim = 4 + total_tokens = 3 + + class FakeLinear: + def __init__(self, out_features): + self.out_features = out_features + + def __call__(self, x): + return ( + torch.zeros( + (x.shape[0], self.out_features), + device=x.device, + dtype=x.dtype, + ), + None, + ) + + self_obj = SimpleNamespace( + use_fp4_cache=True, + wq_b=FakeLinear(n_head * head_dim), + weights_proj=FakeLinear(n_head), + n_head=n_head, + head_dim=head_dim, + 
softmax_scale=1.0, + compress_ratio=4, + topk_tokens=2, + topk_buffer=None, + _prefill_gather_workspace=lambda rows, device: ( + torch.empty((0, 0), dtype=torch.uint8, device=device), + torch.empty((0, 0), dtype=torch.uint8, device=device), + ), + ) + metadata = SimpleNamespace( + forward_mode=ForwardMode.MIXED, + num_prefill_tokens=1, + num_prefill_reqs=1, + seq_lens_cpu=torch.tensor([4], dtype=torch.int32), + query_lens_cpu=torch.tensor([1], dtype=torch.int32), + token_to_req_indices=torch.tensor( + [0, 0, 0], dtype=torch.int32, device=device + ), + compressed_block_table=lambda compress_ratio, block_size: torch.zeros( + (1, 1), + dtype=torch.int32, + device=device, + ), + decode_token_count=lambda: 2, + ) + captured = {} + + def fake_prepare_mxfp4(**kwargs): + index_q = kwargs["index_q"] + rows = index_q.shape[0] + return ( + ( + torch.empty( + (rows, n_head, head_dim // 2), dtype=torch.uint8, device=device + ), + torch.empty((rows, n_head, 1), dtype=torch.uint8, device=device), + ), + torch.empty((rows, n_head), dtype=torch.float32, device=device), + ) + + def fake_prepare_reference(**kwargs): + captured["reference_rows"] = kwargs["positions"].numel() + rows = kwargs["positions"].numel() + return ( + torch.empty( + (rows, n_head, head_dim), dtype=torch.float32, device=device + ), + torch.empty((rows, n_head), dtype=torch.float32, device=device), + ) + + def fake_sparse_indexer(**kwargs): + captured["fallback_rows"] = kwargs["fallback_index_q"].shape[0] + captured["has_packed_q"] = kwargs["has_packed_q"] + captured["num_prefill_tokens"] = kwargs["num_prefill_tokens"] + captured["num_decode_tokens"] = kwargs["num_decode_tokens"] + return torch.full( + (total_tokens, self_obj.topk_tokens), + 7, + dtype=torch.int32, + device=device, + ) + + empty_prefill_metadata = SimpleNamespace( + chunk_bounds=torch.empty((0, 7), dtype=torch.int64, device="cpu"), + chunk_plan=torch.empty((0, 7), dtype=torch.int64, device="cpu"), + slots=torch.empty(0, dtype=torch.int64, 
device=device), + cu_seq_lens=torch.empty(0, dtype=torch.int32, device=device), + cu_start=torch.empty(0, dtype=torch.int32, device=device), + cu_end=torch.empty(0, dtype=torch.int32, device=device), + row_lens=torch.empty(0, dtype=torch.int32, device=device), + ) + decode_metadata = SimpleNamespace( + context_lens=torch.ones((2, 1), dtype=torch.int32, device=device), + block_table=torch.zeros((2, 1), dtype=torch.int32, device=device), + max_context_len=1, + ) + + with patch.object( + deepseek_v4_model, + "deepseek_v4_prepare_indexer_q_mxfp4", + side_effect=fake_prepare_mxfp4, + ), patch.object( + deepseek_v4_model, + "_deepseek_v4_deepgemm_fp4_indexer_available", + return_value=False, + ), patch.object( + deepseek_v4_model, + "deepseek_v4_prepare_indexer_q_reference", + side_effect=fake_prepare_reference, + ), patch.object( + deepseek_v4_model, + "_deepseek_v4_indexer_prefill_metadata", + return_value=empty_prefill_metadata, + ), patch.object( + deepseek_v4_model, + "_deepseek_v4_indexer_decode_metadata", + return_value=decode_metadata, + ), patch.object( + deepseek_v4_model, + "_deepseek_v4_indexer_decode_schedule_metadata", + return_value=None, + ), patch.object( + deepseek_v4_model, + "_deepseek_v4_sparse_attn_indexer", + side_effect=fake_sparse_indexer, + ): + actual = DeepseekV4Indexer._forward_sparse_indexer_custom_op( + self_obj, + hidden_states=torch.zeros((total_tokens, 8), device=device), + qr=torch.zeros((total_tokens, 8), device=device), + positions=torch.arange(total_tokens, dtype=torch.int64, device=device), + metadata=metadata, + indexer_cache=torch.empty((1, 1), dtype=torch.uint8, device=device), + indexer_block_size=1, + cos_sin_cache=torch.empty((1, 1), device=device), + ) + + self.assertEqual(tuple(actual.shape), (total_tokens, self_obj.topk_tokens)) + self.assertEqual(captured["reference_rows"], total_tokens) + self.assertEqual(captured["fallback_rows"], total_tokens) + self.assertFalse(captured["has_packed_q"]) + 
self.assertEqual(captured["num_prefill_tokens"], 1) + self.assertEqual(captured["num_decode_tokens"], 2) + def test_deepseek_v4_indexer_prefill_topk_chunks_cap_logits_bytes(self): positions = torch.tensor([3, 7, 11, 15], dtype=torch.int64) @@ -1261,6 +2251,257 @@ def test_deepseek_v4_indexer_prefill_topk_chunks_cap_logits_bytes(self): [(0, 1)], ) + def test_deepseek_v4_indexer_prefill_topk_chunks_use_cpu_lengths(self): + positions = torch.zeros(6, dtype=torch.int64) + + self.assertEqual( + _deepseek_v4_indexer_prefill_topk_chunks( + positions, + compress_ratio=4, + max_logits_bytes=16, + seq_lens_cpu=torch.tensor([12, 8], dtype=torch.int32), + query_lens_cpu=torch.tensor([4, 2], dtype=torch.int32), + ), + [(0, 2), (2, 3), (3, 4), (4, 6)], + ) + + def test_deepseek_v4_mixed_indexer_fallback_uses_compressed_block_table(self): + base_block_table = torch.tensor([[1]], dtype=torch.int32) + indexer_block_table = torch.tensor([[7]], dtype=torch.int32) + captured = {} + + class FakeLinear: + def __init__(self, out_features): + self.out_features = out_features + + def __call__(self, x): + return ( + torch.zeros( + (x.shape[0], self.out_features), + dtype=torch.float32, + device=x.device, + ), + None, + ) + + class FakeCompressor: + def __init__(self): + self.norm = SimpleNamespace( + weight=torch.ones(1), + variance_epsilon=1e-6, + ) + + def __call__(self, **kwargs): + return None + + pool = SimpleNamespace( + state_block_size=4, + get_indexer_state_buffer=lambda layer_id: torch.empty((1, 1)), + get_indexer_block_size=lambda layer_id: 4, + get_indexer_kv_buffer_2d=lambda layer_id: torch.empty((8, 128)), + ) + metadata = SimpleNamespace( + forward_mode=ForwardMode.MIXED, + indexer_state_block_table=None, + block_table=base_block_table, + token_to_req_indices=torch.tensor([0, 0], dtype=torch.int32), + compressed_block_table=( + lambda compress_ratio, block_size: indexer_block_table + ), + compressed_slot_mapping=lambda *args, **kwargs: torch.zeros( + 2, dtype=torch.int64 + ), 
+ decode_token_count=lambda: 0, + num_prefill_tokens=2, + num_prefill_reqs=1, + seq_lens_cpu=torch.tensor([8], dtype=torch.int32), + query_lens_cpu=torch.tensor([2], dtype=torch.int32), + ) + ctx = SimpleNamespace( + token_to_kv_pool=pool, + attn_backend=SimpleNamespace(forward_metadata=metadata), + forward_mode=ForwardMode.MIXED, + ) + self_obj = SimpleNamespace( + use_fp4_cache=False, + compressor=FakeCompressor(), + compress_ratio=4, + n_head=1, + head_dim=4, + softmax_scale=1.0, + topk_tokens=2, + topk_buffer=None, + wq_b=FakeLinear(4), + weights_proj=FakeLinear(1), + _forward_sparse_indexer_custom_op=lambda **kwargs: None, + ) + + def fake_prepare_reference(**kwargs): + rows = kwargs["positions"].numel() + return ( + torch.zeros((rows, 1, 4), dtype=torch.float32), + torch.zeros((rows, 1), dtype=torch.float32), + ) + + def fake_topk_from_cache(**kwargs): + captured["block_table"] = kwargs["block_table"] + rows = kwargs["positions"].numel() + return torch.full((rows, 2), 3, dtype=torch.int32) + + with patch.object( + deepseek_v4_model, + "deepseek_v4_csa_indexer_cache_insert", + return_value=None, + ), patch.object( + deepseek_v4_model, + "deepseek_v4_prepare_indexer_q_reference", + side_effect=fake_prepare_reference, + ), patch.object( + deepseek_v4_model, + "_deepseek_v4_indexer_topk_from_cache_batched", + side_effect=fake_topk_from_cache, + ): + topk = DeepseekV4Indexer.forward( + self_obj, + hidden_states=torch.zeros((2, 8)), + qr=torch.zeros((2, 8)), + positions=torch.tensor([6, 7], dtype=torch.int64), + ctx=ctx, + out_cache_loc=torch.zeros(2, dtype=torch.int64), + layer_index=0, + cos_sin_cache=torch.empty((1, 1)), + ) + + self.assertTrue(torch.equal(captured["block_table"], indexer_block_table)) + self.assertTrue(torch.equal(topk, torch.full((2, 2), 3, dtype=torch.int32))) + + def test_deepseek_v4_indexer_prefill_gather_plan_reuses_request_k(self): + slots, cu_start, cu_end, row_lens, max_len = ( + _deepseek_v4_indexer_prefill_gather_plan( + 
positions=torch.tensor([0, 1, 5, 0, 3], dtype=torch.int64), + token_to_req_indices=torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32), + block_table=torch.tensor([[10], [20]], dtype=torch.int32), + cache_block_size=4, + compress_ratio=2, + ) + ) + + self.assertTrue(torch.equal(slots, torch.tensor([40, 41, 42, 80, 81]))) + self.assertTrue(torch.equal(cu_start, torch.tensor([0, 0, 0, 3, 3]))) + self.assertTrue(torch.equal(cu_end, torch.tensor([0, 1, 3, 3, 5]))) + self.assertTrue(torch.equal(row_lens, torch.tensor([0, 1, 3, 0, 2]))) + self.assertEqual(max_len, 3) + + def test_deepseek_v4_indexer_prefill_request_chunks_match_reference(self): + chunks = _deepseek_v4_indexer_prefill_request_chunks( + seq_lens_cpu=torch.tensor([16], dtype=torch.int32), + query_lens_cpu=torch.tensor([6], dtype=torch.int32), + compress_ratio=4, + num_tokens=6, + max_logits_bytes=32, + workspace_size=100, + ) + + self.assertEqual( + [ + ( + c.req_start, + c.req_end, + c.query_start, + c.query_end, + c.token_start, + c.token_end, + c.skip_kv_gather, + ) + for c in chunks + ], + [ + (0, 1, 0, 2, 0, 2, False), + (0, 1, 2, 4, 2, 4, True), + (0, 1, 4, 6, 4, 6, True), + ], + ) + + chunks = _deepseek_v4_indexer_prefill_request_chunks( + seq_lens_cpu=torch.tensor([16, 8], dtype=torch.int32), + query_lens_cpu=torch.tensor([2, 2], dtype=torch.int32), + compress_ratio=4, + num_tokens=4, + max_logits_bytes=128, + workspace_size=100, + ) + + self.assertEqual(len(chunks), 1) + self.assertEqual((chunks[0].req_start, chunks[0].req_end), (0, 2)) + self.assertEqual((chunks[0].token_start, chunks[0].token_end), (0, 4)) + self.assertFalse(chunks[0].skip_kv_gather) + + def test_deepseek_v4_indexer_prefill_request_gather_plan_matches_reference(self): + slots, cu_start, cu_end, row_lens, max_len = ( + _deepseek_v4_indexer_prefill_request_gather_plan( + seq_lens_cpu=torch.tensor([16, 8], dtype=torch.int32), + query_lens_cpu=torch.tensor([4, 2], dtype=torch.int32), + block_table=torch.tensor([[10], [20]], 
dtype=torch.int32), + cache_block_size=4, + compress_ratio=4, + req_start=0, + req_end=2, + query_start=1, + query_end=5, + ) + ) + + self.assertTrue(torch.equal(slots, torch.tensor([40, 41, 42, 43, 80, 81]))) + self.assertTrue(torch.equal(cu_start, torch.tensor([0, 0, 0, 4]))) + self.assertTrue(torch.equal(cu_end, torch.tensor([3, 3, 4, 5]))) + self.assertTrue(torch.equal(row_lens, torch.tensor([3, 3, 4, 1]))) + self.assertEqual(max_len, 4) + + def test_deepseek_v4_indexer_prefill_metadata_packs_and_caches_plan(self): + metadata = SimpleNamespace( + seq_lens_cpu=torch.tensor([16, 8], dtype=torch.int32), + query_lens_cpu=torch.tensor([4, 2], dtype=torch.int32), + num_prefill_reqs=2, + prefill_indexer_plan_cache={}, + ) + block_table = torch.tensor([[10], [20]], dtype=torch.int32) + + actual = _deepseek_v4_indexer_prefill_metadata( + metadata=metadata, + block_table=block_table, + cache_block_size=4, + compress_ratio=4, + num_prefill_tokens=6, + ) + cached = _deepseek_v4_indexer_prefill_metadata( + metadata=metadata, + block_table=block_table, + cache_block_size=4, + compress_ratio=4, + num_prefill_tokens=6, + ) + + self.assertIs(actual, cached) + self.assertTrue( + torch.equal( + actual.chunk_bounds, + torch.tensor([[0, 6, 0, 2, 0, 6, 0]], dtype=torch.int64), + ) + ) + self.assertTrue( + torch.equal( + actual.chunk_plan, + torch.tensor([[0, 6, 0, 6, 4, 0, 3]], dtype=torch.int64), + ) + ) + self.assertEqual(actual.slots.numel(), 0) + self.assertTrue( + torch.equal(actual.cu_seq_lens, torch.tensor([0, 4, 6], dtype=torch.int32)) + ) + self.assertTrue(torch.equal(actual.cu_start, torch.tensor([0, 0, 0, 0, 4, 4]))) + self.assertTrue(torch.equal(actual.cu_end, torch.tensor([3, 3, 3, 4, 5, 6]))) + self.assertTrue(torch.equal(actual.row_lens, torch.tensor([3, 3, 3, 4, 1, 2]))) + def test_hidden_compression_helpers_preserve_expected_shapes(self): import torch @@ -1456,6 +2697,8 @@ def test_deepseek_v4_gate_fallback_returns_fp32_logits(self): topk_method=None, ) gate = 
DeepseekV4MoEGate(config, layer_index=1) + with torch.no_grad(): + gate.weight.copy_(torch.randn_like(gate.weight)) hidden_states = torch.randn(3, config.hidden_size) logits = gate(hidden_states) @@ -1634,38 +2877,6 @@ def test_packed_topk_router_logits_recover_weights_after_softmax(self): self.assertTrue(torch.allclose(recovered, topk_weights)) - def test_mxfp4_scale_dtype_preserves_e8m0_checkpoint_bits(self): - import torch - - if not hasattr(torch, "float8_e8m0fnu"): - self.skipTest("float8_e8m0fnu is unavailable") - - loaded = torch.tensor( - [[0.0078125, 0.015625], [0.03125, 0.0625]], dtype=torch.float32 - ).to(torch.float8_e8m0fnu) - param = torch.empty_like(loaded, dtype=MXFP4_SCALE_DTYPE) - param.copy_(loaded) - - self.assertEqual(MXFP4_SCALE_DTYPE, torch.float8_e8m0fnu) - self.assertTrue(torch.equal(param.view(torch.uint8), loaded.view(torch.uint8))) - - def test_mxfp4_triton_scale_layout_uses_uint8_view_for_e8m0(self): - import torch - - if not hasattr(torch, "float8_e8m0fnu"): - self.skipTest("float8_e8m0fnu is unavailable") - - scale = torch.tensor( - [[0.0078125, 0.015625], [0.03125, 0.0625]], dtype=torch.float32 - ).to(torch.float8_e8m0fnu) - - layout_scale = _mxfp4_scale_for_layout(scale) - self.assertEqual(layout_scale.dtype, torch.uint8) - self.assertTrue(torch.equal(layout_scale, scale.view(torch.uint8))) - - uint8_scale = scale.view(torch.uint8) - self.assertIs(_mxfp4_scale_for_layout(uint8_scale), uint8_scale) - def test_mxfp4_flashinfer_reorders_w1w3_halves_for_trtllm(self): import torch diff --git a/test/runtime/test_generation_output_processor.py b/test/runtime/test_generation_output_processor.py new file mode 100644 index 000000000..91c1724cc --- /dev/null +++ b/test/runtime/test_generation_output_processor.py @@ -0,0 +1,100 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the 
Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import annotations + +import torch + +from tokenspeed.runtime.engine.generation_output_processor import ( + OutputProcesser, + RequestState, +) +from tokenspeed.runtime.sampling.sampling_params import SamplingParams + + +class _Sender: + def __init__(self): + self.items = [] + + def send_pyobj(self, obj): + self.items.append(obj) + + +class _Tokenizer: + eos_token_id = None + additional_stop_token_ids = None + + def decode(self, ids): + return "".join(str(i) for i in ids) + + +class _Metrics: + enabled = False + + +class _ForwardOp: + request_ids = ["prefill", "decode"] + request_pool_indices = [0, 1] + input_lengths = [4, 1] + extend_prefix_lens = [0] + + def num_extends(self): + return 1 + + +class _ExecutionResult: + output_tokens = torch.tensor([11, 22], dtype=torch.int32) + output_lengths = torch.tensor([1, 1], dtype=torch.int32) + output_logprobs = None + grammar_completion = None + + def sync(self): + return None + + +def _state(input_ids: list[int], *, computed_length: int = 0) -> RequestState: + state = RequestState( + prompt_input_ids=input_ids, + 
sampling_params=SamplingParams(max_new_tokens=8, stop=[], ignore_eos=True), + stream=False, + tokenizer=_Tokenizer(), + ) + state.computed_length = computed_length + return state + + +def test_mixed_forward_updates_reserve_for_decode_slots_only(): + sender = _Sender() + processor = OutputProcesser( + sender, + global_rank=0, + metrics=_Metrics(), + ) + processor.rid_to_state["prefill"] = _state([1, 2, 3, 4]) + processor.rid_to_state["decode"] = _state([5, 6, 7], computed_length=3) + + events = processor.post_process_forward_op(_ForwardOp(), _ExecutionResult()) + + reserve_events = [ + event for event in events if type(event).__name__ == "UpdateReserveNumTokens" + ] + assert len(reserve_events) == 1 + assert reserve_events[0].request_id == "decode" + assert reserve_events[0].reserve_num_tokens_in_next_schedule_event == 1 diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_mla/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_mla/__init__.py index ee4962c37..8f44aa6a4 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_mla/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_mla/__init__.py @@ -24,11 +24,13 @@ platform = current_platform() flash_mla_with_kvcache = error_fn +flash_mla_sparse_fwd = error_fn get_mla_metadata = error_fn if platform.is_nvidia and platform.is_hopper: try: from flash_mla import ( + flash_mla_sparse_fwd, flash_mla_with_kvcache, get_mla_metadata, ) @@ -39,4 +41,4 @@ # Direct export # ------------------------------------------------------------------------------ -__all__ = ["flash_mla_with_kvcache", "get_mla_metadata"] +__all__ = ["flash_mla_sparse_fwd", "flash_mla_with_kvcache", "get_mla_metadata"] diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/deepseek_v4.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/deepseek_v4.py new file mode 100644 index 000000000..8b0773ab2 --- /dev/null +++ 
b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/deepseek_v4.py @@ -0,0 +1,115 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +from __future__ import annotations + +import torch +from tokenspeed_kernel._triton import tl, triton + +__all__ = ["deepseek_v4_indexer_decode_metadata_compute"] + + +@triton.jit +def _deepseek_v4_indexer_decode_metadata_kernel( + out_block_tables_ptr, + out_block_tables_stride, + out_context_lens_ptr, + positions_ptr, + token_to_req_indices_ptr, + block_table_ptr, + block_table_stride, + rows: tl.constexpr, + cols: tl.constexpr, + compress_ratio: tl.constexpr, + cache_block_size: tl.constexpr, + max_blocks: tl.constexpr, + candidate_block: tl.constexpr, +): + token_idx = tl.program_id(0) + pos = tl.load(positions_ptr + token_idx).to(tl.int64) + compressed_lens = tl.maximum((pos + 1) // compress_ratio, 0) + req = tl.load(token_to_req_indices_ptr + token_idx).to(tl.int32) + req_valid = (req >= 0) & (req < rows) + safe_req = tl.maximum(0, tl.minimum(req, rows - 1)) + num_valid_pages = tl.zeros((), dtype=tl.int64) + for col_start in range(0, max_blocks, candidate_block): + col_offsets = col_start + tl.arange(0, candidate_block) + col_mask = col_offsets < max_blocks + in_cols = col_offsets < cols + safe_col = tl.where(in_cols, col_offsets, 0) + bt_load_mask = col_mask & in_cols & req_valid + bt_vals = tl.load( + block_table_ptr + safe_req * block_table_stride + safe_col, + mask=bt_load_mask, + other=0, + ) + page_valid = (bt_vals >= 0) & in_cols + final_mask = page_valid & req_valid & col_mask + masked_bt = tl.where(final_mask, bt_vals, 0) + tl.store( + out_block_tables_ptr + token_idx * out_block_tables_stride + col_offsets, + masked_bt, + mask=col_mask, + ) + num_valid_pages += tl.sum(final_mask.to(tl.int64), axis=0) + available_lens = num_valid_pages * cache_block_size + context_len_val = tl.minimum(compressed_lens, available_lens) + context_len_val = tl.where(req_valid, context_len_val, 0) + tl.store(out_context_lens_ptr + token_idx, context_len_val.to(tl.int32)) + + +def deepseek_v4_indexer_decode_metadata_compute( + *, + positions: torch.Tensor, + 
token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + cache_block_size: int, + compress_ratio: int, + max_blocks: int, + out_context_lens: torch.Tensor, + out_block_tables: torch.Tensor, +) -> None: + """Build decode-indexer context lengths and block tables in one Triton pass.""" + num_tokens = int(positions.shape[0]) if positions.ndim >= 1 else 0 + if num_tokens == 0: + return + if out_context_lens.dtype != torch.int32 or out_block_tables.dtype != torch.int32: + raise TypeError("output buffers must be int32") + positions_i64 = positions.to(torch.int64) + token_to_req_indices_i32 = token_to_req_indices.to(torch.int32) + block_table_i32 = block_table.to(torch.int32) + rows = int(block_table.shape[0]) if block_table.ndim >= 1 else 0 + cols = int(block_table.shape[1]) if block_table.ndim >= 2 else 0 + candidate_block = min(1024, max(16, triton.next_power_of_2(max_blocks))) + _deepseek_v4_indexer_decode_metadata_kernel[(num_tokens,)]( + out_block_tables, + out_block_tables.stride(0), + out_context_lens, + positions_i64, + token_to_req_indices_i32, + block_table_i32, + block_table_i32.stride(0), + rows=rows, + cols=cols, + compress_ratio=int(compress_ratio), + cache_block_size=int(cache_block_size), + max_blocks=int(max_blocks), + candidate_block=candidate_block, + ) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py index a010c16a7..92e09a04f 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py @@ -44,11 +44,180 @@ "moe_align_block_size", "moe_sum_reduce_torch_compile", "moe_sum_reduce_triton", + "stage_deepseek_v4_mega_moe_inputs", ] padding_size = 128 if bool(int(os.getenv("TOKENSPEED_MOE_PADDING", "0"))) else 0 +# --------------------------------------------------------------------------- +# DeepSeek V4 MegaMoE staging +# 
# Number of hidden-state elements handled by one kernel program along K; the
# wrapper also requires hidden_size to be a multiple of it.
_DEEPSEEK_V4_MEGAMOE_FP8_BLOCK_SIZE = 128


@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
    hidden_states,
    x_fp8,
    x_sf,
    topk_ids,
    topk_weights,
    topk_idx_out,
    topk_weights_out,
    hidden_stride_m: tl.constexpr,
    hidden_stride_k: tl.constexpr,
    x_stride_m: tl.constexpr,
    x_stride_k: tl.constexpr,
    x_sf_stride_m: tl.constexpr,
    x_sf_stride_k: tl.constexpr,
    topk_ids_stride_m: tl.constexpr,
    topk_ids_stride_k: tl.constexpr,
    topk_weights_stride_m: tl.constexpr,
    topk_weights_stride_k: tl.constexpr,
    topk_idx_stride_m: tl.constexpr,
    topk_idx_stride_k: tl.constexpr,
    topk_weights_out_stride_m: tl.constexpr,
    topk_weights_out_stride_k: tl.constexpr,
    hidden_size: tl.constexpr,
    top_k: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_K: tl.constexpr,
    BLOCK_TOPK: tl.constexpr,
) -> None:
    # Grid is (token, k_block): each program quantizes BLOCK_K consecutive
    # hidden-state elements of one token and writes one packed scale word.
    token_id = tl.program_id(0)
    k_block_id = tl.program_id(1)

    k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
    k_mask = k_offsets < hidden_size
    hidden = tl.load(
        hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k,
        mask=k_mask,
        other=0.0,
    ).to(tl.float32)

    # Absolute maximum per GROUP_K-wide group, floored at 1e-4 so the scale
    # derived from it is never zero.
    num_groups: tl.constexpr = BLOCK_K // GROUP_K
    hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K])
    amax = tl.max(hidden_groups, axis=1)
    amax = tl.maximum(amax, 1.0e-4)

    # 448.0 is presumably the largest finite float8 e4m3 value - TODO confirm.
    scale = amax / 448.0
    # Round the scale UP to a power of two using its float32 bit pattern: take
    # the 8 exponent bits and add one when any of the 23 mantissa bits is set.
    scale_bits = scale.to(tl.uint32, bitcast=True)
    scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to(
        tl.uint32
    )
    # Keep the biased exponent within [1, 254].
    scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254)
    # Rebuild a float32 with that exponent and a zero mantissa.
    rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True)

    # Divide every group by its own rounded scale, then cast to float8.
    hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
    scaled = hidden_groups * (1.0 / rounded_scale)[:, None]
    scaled = tl.reshape(scaled, [BLOCK_K])
    fp8 = scaled.to(tl.float8e4nv)
    tl.store(
        x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k,
        fp8,
        mask=k_mask,
    )

    # Pack the per-group exponents into one word, group g in bits [8g, 8g+8).
    # With the wrapper's BLOCK_K=128 / GROUP_K=32 this is four bytes.
    scale_offsets = tl.arange(0, num_groups)
    packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32)
    tl.store(
        x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k,
        packed_scale,
    )

    # Only the first K block of each token copies the routing data, so every
    # top-k row is written exactly once.
    if k_block_id == 0:
        topk_offsets = tl.arange(0, BLOCK_TOPK)
        topk_mask = topk_offsets < top_k

        # Expert ids are widened to int64 on the way out.
        ids = tl.load(
            topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k,
            mask=topk_mask,
            other=0,
        ).to(tl.int64)
        tl.store(
            topk_idx_out
            + token_id * topk_idx_stride_m
            + topk_offsets * topk_idx_stride_k,
            ids,
            mask=topk_mask,
        )

        # Weights are copied without an explicit cast in the kernel.
        weights = tl.load(
            topk_weights
            + token_id * topk_weights_stride_m
            + topk_offsets * topk_weights_stride_k,
            mask=topk_mask,
            other=0.0,
        )
        tl.store(
            topk_weights_out
            + token_id * topk_weights_out_stride_m
            + topk_offsets * topk_weights_out_stride_k,
            weights,
            mask=topk_mask,
        )


def stage_deepseek_v4_mega_moe_inputs(
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    x_fp8: torch.Tensor,
    x_sf: torch.Tensor,
    topk_idx_out: torch.Tensor,
    topk_weights_out: torch.Tensor,
) -> None:
    """Quantize hidden states and copy routing data into staging buffers.

    Launches one kernel program per (token, 128-wide hidden block). Each
    program writes the float8 values to ``x_fp8`` and one packed word of four
    power-of-two scale exponents (one per 32-element group) to ``x_sf``. The
    first block of every token also copies ``topk_ids`` (as int64) into
    ``topk_idx_out`` and ``topk_weights`` into ``topk_weights_out``.

    All outputs are written in place; nothing is returned. An input with zero
    tokens is a no-op and skips validation.

    Args:
        hidden_states: 2-D ``(num_tokens, hidden_size)`` activations.
        topk_weights: Routing weights, same shape as ``topk_ids``.
        topk_ids: Routing expert ids, indexed as ``(num_tokens, top_k)``.
        x_fp8: Output buffer for the quantized activations.
        x_sf: Output buffer for the packed scales, indexed as
            ``(token, hidden_block)``.
        topk_idx_out: Output buffer for the expert ids.
        topk_weights_out: Output buffer for the routing weights.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of 128, or if
            ``topk_weights`` and ``topk_ids`` differ in shape.
    """
    num_tokens, hidden_size = hidden_states.shape
    if num_tokens == 0:
        return
    if hidden_size % _DEEPSEEK_V4_MEGAMOE_FP8_BLOCK_SIZE != 0:
        raise ValueError(
            "DeepSeek V4 MegaMoE input staging requires hidden_size to be "
            f"a multiple of {_DEEPSEEK_V4_MEGAMOE_FP8_BLOCK_SIZE}."
        )
    if topk_weights.shape != topk_ids.shape:
        raise ValueError(
            "DeepSeek V4 MegaMoE input staging requires topk_weights and "
            "topk_ids to have the same shape."
        )

    block_k = _DEEPSEEK_V4_MEGAMOE_FP8_BLOCK_SIZE
    grid = (num_tokens, triton.cdiv(hidden_size, block_k))
    # tl.arange needs a power-of-two extent; the kernel masks the padding.
    block_topk = triton.next_power_of_2(topk_ids.shape[1])
    # Strides are passed positionally into tl.constexpr parameters, so the
    # kernel is specialized on the concrete layout of every tensor.
    _deepseek_v4_stage_mega_moe_inputs_kernel[grid](
        hidden_states,
        x_fp8,
        x_sf,
        topk_ids,
        topk_weights,
        topk_idx_out,
        topk_weights_out,
        hidden_states.stride(0),
        hidden_states.stride(1),
        x_fp8.stride(0),
        x_fp8.stride(1),
        x_sf.stride(0),
        x_sf.stride(1),
        topk_ids.stride(0),
        topk_ids.stride(1),
        topk_weights.stride(0),
        topk_weights.stride(1),
        topk_idx_out.stride(0),
        topk_idx_out.stride(1),
        topk_weights_out.stride(0),
        topk_weights_out.stride(1),
        hidden_size,
        topk_ids.shape[1],
        BLOCK_K=block_k,
        GROUP_K=32,
        BLOCK_TOPK=block_topk,
        num_warps=4,
    )
// Gathers paged indexer MXFP4 rows into dense per-token output buffers.
//
// Thread layout: threadIdx.y selects one of BlockYSize tokens handled by the
// block, threadIdx.x (together with blockIdx.y) selects a 16-byte slice of the
// token's value row.
//
// Per paged block the cache is read as
//   [0, cache_block_size * value_bytes)   value rows, one per token slot
//   [cache_block_size * value_bytes, ...) scale rows, scale_bytes per slot
//
// cu_seq_lens holds batch_size + 1 cumulative token offsets; token rows that
// no request covers are left untouched.
template <int BlockYSize>
__global__ void gather_paged_indexer_mxfp4_cache_kernel(
    const uint8_t* __restrict__ kv_cache,
    uint8_t* __restrict__ values_out,
    uint8_t* __restrict__ scales_out,
    const int32_t* __restrict__ block_table,
    const int32_t* __restrict__ cu_seq_lens,
    int batch_size,
    int num_tokens,
    int value_bytes,
    int scale_bytes,
    int cache_block_size,
    int64_t cache_block_stride,
    int64_t value_stride,
    int64_t scale_stride,
    int64_t block_table_stride) {
  constexpr int kVecBytes = sizeof(uint4);
  const int token_idx = blockIdx.x * blockDim.y + threadIdx.y;
  const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * kVecBytes;

  // Owning request of each token row handled by this block; -1 = none found.
  __shared__ int batch_idx[BlockYSize];
  if (threadIdx.x == 0) {
    batch_idx[threadIdx.y] = -1;
  }
  __syncthreads();

  // The x-threads of a row cooperatively scan the requests; at most one
  // request interval contains token_idx, so a single thread writes the slot.
  for (int iter = 0; iter < (batch_size + blockDim.x - 1) / blockDim.x;
       ++iter) {
    const int req = iter * blockDim.x + threadIdx.x;
    if (req < batch_size) {
      const int seq_start = cu_seq_lens[req];
      const int seq_end = cu_seq_lens[req + 1];
      if (token_idx >= seq_start && token_idx < seq_end) {
        batch_idx[threadIdx.y] = req;
      }
    }
  }
  __syncthreads();

  // Early exit happens only after the last barrier, so no thread is left
  // waiting in __syncthreads().
  const int req = batch_idx[threadIdx.y];
  if (token_idx >= num_tokens || req < 0) {
    return;
  }

  const int in_req_token_idx = token_idx - cu_seq_lens[req];
  const int block_idx =
      block_table[static_cast<int64_t>(req) * block_table_stride +
                  in_req_token_idx / cache_block_size];
  const int block_offset = in_req_token_idx % cache_block_size;
  const int64_t block_base =
      static_cast<int64_t>(block_idx) * cache_block_stride;

  // 16-byte vector copy of this thread's slice of the value row. The host
  // guarantees value_bytes is a multiple of 16.
  if (head_idx < value_bytes) {
    const int64_t value_src =
        block_base + static_cast<int64_t>(block_offset) * value_bytes +
        head_idx;
    const int64_t value_dst =
        static_cast<int64_t>(token_idx) * value_stride + head_idx;
    *reinterpret_cast<uint4*>(values_out + value_dst) =
        *reinterpret_cast<const uint4*>(kv_cache + value_src);
  }

  // Exactly one thread per token copies the 4-byte scale word (the host
  // requires scale_bytes == sizeof(uint32_t)).
  if (blockIdx.y == 0 && threadIdx.x == 0) {
    const int64_t scale_src =
        block_base + static_cast<int64_t>(cache_block_size) * value_bytes +
        static_cast<int64_t>(block_offset) * scale_bytes;
    const int64_t scale_dst = static_cast<int64_t>(token_idx) * scale_stride;
    *reinterpret_cast<uint32_t*>(scales_out + scale_dst) =
        *reinterpret_cast<const uint32_t*>(kv_cache + scale_src);
  }
}

// Host entry point: validates the tensors and launches the gather kernel on
// the stream associated with kv_cache's device.
//
//   kv_cache     uint8 [num_blocks, block_bytes]   paged cache
//   values_out   uint8 [num_tokens, value_bytes]   gathered value rows
//   scales_out   uint8 [>= num_tokens, 4]          gathered scale rows
//   block_table  int32 [batch_size, max_blocks]
//   cu_seq_lens  int32 [batch_size + 1]
//
// Fails a TVM_FFI_ICHECK on any layout/dtype violation or on a launch error.
void deepseek_v4_gather_paged_indexer_mxfp4_cache(TensorView kv_cache,
                                                  TensorView values_out,
                                                  TensorView scales_out,
                                                  TensorView block_table,
                                                  TensorView cu_seq_lens,
                                                  int64_t cache_block_size) {
  CHECK_CUDA(kv_cache);
  CHECK_CUDA(values_out);
  CHECK_CUDA(scales_out);
  CHECK_CUDA(block_table);
  CHECK_CUDA(cu_seq_lens);
  CHECK_DIM(2, kv_cache);
  CHECK_DIM(2, values_out);
  CHECK_DIM(2, scales_out);
  CHECK_DIM(2, block_table);
  CHECK_DIM(1, cu_seq_lens);

  TVM_FFI_ICHECK(kv_cache.dtype() == dl_uint8) << "kv_cache must be uint8";
  TVM_FFI_ICHECK(values_out.dtype() == dl_uint8) << "values_out must be uint8";
  TVM_FFI_ICHECK(scales_out.dtype() == dl_uint8) << "scales_out must be uint8";
  TVM_FFI_ICHECK(block_table.dtype() == dl_int32)
      << "block_table must be int32";
  TVM_FFI_ICHECK(cu_seq_lens.dtype() == dl_int32)
      << "cu_seq_lens must be int32";
  TVM_FFI_ICHECK(kv_cache.stride(1) == 1)
      << "kv_cache last dim must be contiguous";
  TVM_FFI_ICHECK(values_out.stride(1) == 1)
      << "values_out last dim must be contiguous";
  TVM_FFI_ICHECK(scales_out.stride(1) == 1)
      << "scales_out last dim must be contiguous";
  // The kernel also indexes these two with unit stride; reject strided views
  // instead of silently reading the wrong entries.
  TVM_FFI_ICHECK(block_table.stride(1) == 1)
      << "block_table last dim must be contiguous";
  TVM_FFI_ICHECK(cu_seq_lens.stride(0) == 1)
      << "cu_seq_lens must be contiguous";
  TVM_FFI_ICHECK(cache_block_size > 0) << "cache_block_size must be positive";
  TVM_FFI_ICHECK(cu_seq_lens.size(0) == block_table.size(0) + 1)
      << "cu_seq_lens must have batch_size + 1 entries";

  const int batch_size = static_cast<int>(block_table.size(0));
  const int num_tokens = static_cast<int>(values_out.size(0));
  TVM_FFI_ICHECK(scales_out.size(0) >= num_tokens)
      << "scales_out must cover values_out rows";
  // Output rows may be an exact length or a conservative upper bound, so do
  // not read cu_seq_lens[-1] on host here. The kernel only writes rows covered
  // by device-side cu_seq_lens.
  if (batch_size == 0 || num_tokens == 0) {
    return;
  }
  const int value_bytes = static_cast<int>(values_out.size(1));
  const int scale_bytes = static_cast<int>(scales_out.size(1));
  TVM_FFI_ICHECK(value_bytes > 0 &&
                 value_bytes % static_cast<int>(sizeof(uint4)) == 0)
      << "values_out width must be a positive multiple of 16 bytes";
  TVM_FFI_ICHECK(scale_bytes > 0) << "scales_out width must be positive";
  TVM_FFI_ICHECK(scale_bytes == static_cast<int>(sizeof(uint32_t)))
      << "paged indexer MXFP4 gather expects 4 scale bytes per row";
  TVM_FFI_ICHECK(kv_cache.size(1) >=
                 cache_block_size * (value_bytes + scale_bytes))
      << "kv_cache block stride is too small for indexer MXFP4 rows";

  cudaSetDevice(kv_cache.device().device_id);
  const cudaStream_t stream = get_stream(kv_cache.device());
  // kBlockX threads, each moving one uint4, cover kBlockX * 16 bytes of a row
  // per block along grid.y.
  constexpr int kBlockX = 8;
  constexpr int kVecBytes = sizeof(uint4);
  const int grid_y =
      (value_bytes + kBlockX * kVecBytes - 1) / (kBlockX * kVecBytes);

#define LAUNCH_PAGED_GATHER(BLOCK_Y)                                          \
  do {                                                                        \
    const dim3 grid((num_tokens + (BLOCK_Y)-1) / (BLOCK_Y), grid_y);          \
    const dim3 block(kBlockX, (BLOCK_Y));                                     \
    gather_paged_indexer_mxfp4_cache_kernel<(BLOCK_Y)>                        \
        <<<grid, block, 0, stream>>>(                                         \
            static_cast<const uint8_t*>(kv_cache.data_ptr()),                 \
            static_cast<uint8_t*>(values_out.data_ptr()),                     \
            static_cast<uint8_t*>(scales_out.data_ptr()),                     \
            static_cast<const int32_t*>(block_table.data_ptr()),              \
            static_cast<const int32_t*>(cu_seq_lens.data_ptr()), batch_size,  \
            num_tokens, value_bytes, scale_bytes,                             \
            static_cast<int>(cache_block_size), kv_cache.stride(0),           \
            values_out.stride(0), scales_out.stride(0),                       \
            block_table.stride(0));                                           \
  } while (0)

  // More token rows per block as the token count grows.
  if (num_tokens < 32) {
    LAUNCH_PAGED_GATHER(1);
  } else if (num_tokens < 64) {
    LAUNCH_PAGED_GATHER(2);
  } else if (num_tokens < 128) {
    LAUNCH_PAGED_GATHER(4);
  } else if (num_tokens < 256) {
    LAUNCH_PAGED_GATHER(8);
  } else if (num_tokens < 512) {
    LAUNCH_PAGED_GATHER(16);
  } else {
    LAUNCH_PAGED_GATHER(32);
  }
#undef LAUNCH_PAGED_GATHER

  cudaError_t status = cudaGetLastError();
  TVM_FFI_ICHECK(status == cudaSuccess)
      << "deepseek_v4_gather_paged_indexer_mxfp4_cache failed: "
      << cudaGetErrorString(status);
}
def has_indexer_mxfp4_paged_gather() -> bool:
    """Report whether the compiled extension exposes the paged MXFP4 gather.

    Any failure while loading the extension is treated as "not available".
    """
    try:
        extension = _load_deepseek_v4_attention_module()
    except Exception:
        available = False
    else:
        available = hasattr(
            extension, "deepseek_v4_gather_paged_indexer_mxfp4_cache"
        )
    return available


def indexer_mxfp4_paged_gather(
    kv_cache: torch.Tensor,
    values_out: torch.Tensor,
    scales_out: torch.Tensor,
    block_table: torch.Tensor,
    cu_seq_lens: torch.Tensor,
    cache_block_size: int,
) -> None:
    """Gather paged indexer MXFP4 rows into ``values_out`` / ``scales_out``.

    The byte buffers must already be ``uint8``; the two index tensors are
    converted to contiguous ``int32`` before being handed to the extension.
    Results are written in place.

    Raises:
        TypeError: If ``kv_cache``, ``values_out`` or ``scales_out`` is not
            ``uint8``.
        ValueError: If ``values_out`` and ``scales_out`` differ in row count.
    """
    if kv_cache.dtype != torch.uint8:
        raise TypeError(f"kv_cache must be uint8, got {kv_cache.dtype}")
    if values_out.dtype != torch.uint8:
        raise TypeError(f"values_out must be uint8, got {values_out.dtype}")
    if scales_out.dtype != torch.uint8:
        raise TypeError(f"scales_out must be uint8, got {scales_out.dtype}")

    # Tensor.to() hands back the same tensor when the dtype already matches.
    block_table_i32 = block_table.to(torch.int32)
    cu_seq_lens_i32 = cu_seq_lens.to(torch.int32)

    value_rows = values_out.shape[0]
    scale_rows = scales_out.shape[0]
    if value_rows != scale_rows:
        raise ValueError(
            "DeepSeek V4 paged gather output value/scale rows must match, "
            f"got values={values_out.shape[0]}, scales={scales_out.shape[0]}"
        )

    gather_op = (
        _load_deepseek_v4_attention_module().deepseek_v4_gather_paged_indexer_mxfp4_cache
    )
    gather_op(
        kv_cache,
        values_out,
        scales_out,
        block_table_i32.contiguous(),
        cu_seq_lens_i32.contiguous(),
        int(cache_block_size),
    )
a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp index de978d82e..a51790c66 100644 --- a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp @@ -509,6 +509,11 @@ Scheduler::newForwardOperation(std::vector candidates) { } ops.push_back(std::move(op)); }; + auto has_prefill_op = [&]() { + return std::any_of(ops.begin(), ops.end(), [](const ForwardOperation& op) { + return std::holds_alternative(op); + }); + }; std::vector loadback_ops; auto simulated_free = initialPagedCacheGroupSimulatedFree(); for (Request* request : candidates) { @@ -535,13 +540,13 @@ Scheduler::newForwardOperation(std::vector candidates) { } } else if (request->Is() || (request->Is() && config_.role != Role::kP)) { // Prefill-first: skip ALL decode if any prefill was scheduled this round. - if (!ops.empty() && std::holds_alternative(ops.back())) break; + if (!config_.enable_mixed_prefill_decode && has_prefill_op()) break; if (auto ev = scheduleDecode(request, simulated_free)) { push_op(applyEventAndGenerateOp(request, *ev)); } } else if (request->Is() && config_.role != Role::kP) { - if (!ops.empty() && std::holds_alternative(ops.back())) break; + if (!config_.enable_mixed_prefill_decode && has_prefill_op()) break; if (auto ev = scheduleDecodeFromRetracted(request, simulated_free)) { std::vector loadback_diff = ev->GetLoadbackDiff(); diff --git a/tokenspeed-scheduler/csrc/scheduler/types.h b/tokenspeed-scheduler/csrc/scheduler/types.h index d1b2173ba..e48d042bf 100644 --- a/tokenspeed-scheduler/csrc/scheduler/types.h +++ b/tokenspeed-scheduler/csrc/scheduler/types.h @@ -83,6 +83,7 @@ struct SchedulerConfig { bool enable_l3_storage{false}; std::int32_t prefetch_threshold{4}; // num pages bool enable_kv_cache_events{false}; + bool enable_mixed_prefill_decode{false}; std::int32_t num_pages_reserved_for_retracted_or_running{}; Role role{Role::kFused}; diff 
--git a/tokenspeed-scheduler/python/tests/test_fsm_and_scheduling.py b/tokenspeed-scheduler/python/tests/test_fsm_and_scheduling.py index 17175ad76..ab831137f 100644 --- a/tokenspeed-scheduler/python/tests/test_fsm_and_scheduling.py +++ b/tokenspeed-scheduler/python/tests/test_fsm_and_scheduling.py @@ -280,6 +280,26 @@ def test_decode_batch_only_when_no_prefill_work(self): assert plan.forward[0].num_extends() > 0 assert plan.forward[0].request_ids == ["r1"] + def test_mixed_prefill_decode_can_schedule_decode_with_new_prefill(self): + cfg = make_config(max_scheduled_tokens=512, max_batch_size=8) + cfg.enable_mixed_prefill_decode = True + s = Scheduler(cfg) + + submit(s, "r0", list(range(8))) + s.next_execution_plan() # r0 → PrefillDone + s.next_execution_plan() # r0 → Decoding + advance_forward(s, "r0", tokens=[99]) + + submit(s, "r1", list(range(8))) + plan = s.next_execution_plan() + op = plan.forward[0] + + assert op.request_ids == ["r1", "r0"] + assert op.num_extends() == 1 + assert len(op.input_ids) == sum(op.input_lengths[: op.num_extends()]) + assert len(op.input_ids) + len(op.decode_input_ids) == sum(op.input_lengths) + assert op.sizes == [1, 0] + def test_max_batch_size_limits_scheduled_requests(self): """max_batch_size caps the number of requests per plan.""" s = Scheduler(make_config(max_scheduled_tokens=512, max_batch_size=2)) @@ -728,9 +748,9 @@ def test_retract_recovered_carries_last_prefill_token(self): def test_mixed_batch_decode_input_ids_length(self): """decode_input_ids has one entry per decode request; all -1 for normal decodes.""" - s = Scheduler( - make_config(page_size=16, num_device_pages=1024, max_batch_size=8) - ) + cfg = make_config(page_size=16, num_device_pages=1024, max_batch_size=8) + cfg.enable_mixed_prefill_decode = True + s = Scheduler(cfg) # Bring r0 to Decoding. 
submit(s, "r0", list(range(8))) s.next_execution_plan() # r0 → PrefillDone From f48e2645ed9a7da89c1c1d1b9a2dca61113748a1 Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Wed, 13 May 2026 13:07:29 +0000 Subject: [PATCH 2/2] feat(deepseek-v4): support mtp speculative decoding Co-authored-by: jiyingd <87510204+dongjiyingdjy@users.noreply.github.com> Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> Signed-off-by: jiyingd <87510204+dongjiyingdjy@users.noreply.github.com> --- docs/serving/deepseek-v4.md | 17 +- .../runtime/configs/model_config.py | 35 +- .../runtime/execution/cuda_graph_wrapper.py | 12 +- .../runtime/execution/drafter/eagle.py | 1 + .../layers/attention/backends/deepseek_v4.py | 458 +++++++++++++-- .../layers/attention/kv_cache/deepseek_v4.py | 21 +- .../runtime/layers/attention/registry.py | 57 +- .../runtime/layers/attention/utils.py | 8 +- .../tokenspeed/runtime/models/deepseek_v4.py | 85 ++- .../runtime/models/deepseek_v4_mtp.py | 524 +++++++++++++++++ test/runtime/test_deepseek_v4_config.py | 526 ++++++++++++++++++ 11 files changed, 1654 insertions(+), 90 deletions(-) create mode 100644 python/tokenspeed/runtime/models/deepseek_v4_mtp.py diff --git a/docs/serving/deepseek-v4.md b/docs/serving/deepseek-v4.md index 0eeada139..2b313ba92 100644 --- a/docs/serving/deepseek-v4.md +++ b/docs/serving/deepseek-v4.md @@ -5,7 +5,8 @@ DP=4 + expert parallel + mega_moe + FP8 KV cache (B200, 4× SM100): ```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 tokenspeed serve deepseek-ai/DeepSeek-V4-Flash \ +CUDA_VISIBLE_DEVICES=0,1,2,3 exec ts serve \ + --model deepseek-ai/DeepSeek-V4-Flash \ --host localhost --port 30100 \ --dist-init-addr 127.0.0.1:4013 \ --trust-remote-code \ @@ -50,6 +51,20 @@ also be bumped to 256.) - `--deepseek-v4-indexer-prefill-max-logits-mb N`: caps the FP4 indexer prefill logits buffer in MB (default 512). 
def is_deepseek_v4(config: PretrainedConfig) -> bool:
    """Return True when the config's first architecture is a DeepSeek V4 one.

    A missing, ``None`` or empty ``architectures`` attribute yields False.
    """
    arch_names = getattr(config, "architectures", None) or []
    if not arch_names:
        return False
    return arch_names[0] in _DEEPSEEK_V4_ARCHITECTURES


def is_deepseek_v4_nextn(config: PretrainedConfig) -> bool:
    """Return True when the config's first architecture is the V4 NextN model.

    A missing, ``None`` or empty ``architectures`` attribute yields False.
    """
    arch_names = getattr(config, "architectures", None) or []
    if not arch_names:
        return False
    return arch_names[0] == "DeepseekV4ForCausalLMNextN"


def _derive_num_attention_layers(
    hf_config: PretrainedConfig,
    num_hidden_layers: int,
) -> int:
    """Work out how many attention layers the model needs.

    * NextN configs use ``num_nextn_predict_layers`` (1 when absent).
    * Architectures listed in ``_DOUBLE_ATTENTION_LAYER_ARCHITECTURES`` use
      twice ``num_hidden_layers``; this rule wins when both apply.
    * Everything else uses ``num_hidden_layers`` unchanged.
    """
    arch_names = getattr(hf_config, "architectures", None) or []
    # Resolved first so a malformed NextN layer count is reported regardless
    # of whether the doubling rule below also matches.
    layer_count = (
        int(getattr(hf_config, "num_nextn_predict_layers", 1))
        if is_deepseek_v4_nextn(hf_config)
        else num_hidden_layers
    )
    for arch_name in arch_names:
        if arch_name in _DOUBLE_ATTENTION_LAYER_ARCHITECTURES:
            return num_hidden_layers * 2
    return layer_count
N-1 decode steps (drafter syncs per step). + draft_attn_kwargs = {} + if getattr(self.draft_attn_backend, "uses_padded_decode_token_mask", False): + draft_attn_kwargs["actual_bs"] = actual_bs self.draft_attn_backend.init_forward_metadata_replay_cuda_graph( padded_bs, req_pool_indices, seq_lens, req_to_page=self.drafter.req_to_page, forward_mode=ForwardMode.DRAFT_EXTEND, - **kwargs, + **draft_attn_kwargs, ) @nvtx_range("attn_meta_prep", color="orange") @@ -628,7 +630,7 @@ def _init_forward_metadata( **kwargs, ) if self.draft_attn_backend is not None: - if forward_mode.is_extend(): + if forward_mode == ForwardMode.EXTEND or forward_mode.is_mixed(): # Initial prefill: draft step 0 uses EXTEND (regular prefill) # kernel with the caller's prefix kwargs. Step 0 and the # subsequent decode steps have structurally different diff --git a/python/tokenspeed/runtime/execution/drafter/eagle.py b/python/tokenspeed/runtime/execution/drafter/eagle.py index a599a454a..ab6fc7f53 100644 --- a/python/tokenspeed/runtime/execution/drafter/eagle.py +++ b/python/tokenspeed/runtime/execution/drafter/eagle.py @@ -285,6 +285,7 @@ def _run_multi_step_decode( ) out_cache_loc = cache_locs[:, i - 1].contiguous() + ctx.attn_backend.advance_draft_forward_metadata() with nvtx_range("draft_forward", color="red"): logits_output = self.draft_model_runner.forward( diff --git a/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py b/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py index f53edb274..c3962f1b2 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py +++ b/python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py @@ -237,6 +237,8 @@ def __init__(self, config) -> None: (self.context_len + self.page_size - 1) // self.page_size, ) self.forward_metadata: DeepseekV4ForwardMetadata | None = None + self.forward_prefill_metadata: DeepseekV4ForwardMetadata | None = None + self.forward_decode_metadata: DeepseekV4ForwardMetadata | None = 
    def _draft_decode_is_valid_token(
        self,
        prefill_metadata: DeepseekV4ForwardMetadata,
    ) -> Optional[torch.Tensor]:
        """Pick one validity flag per request out of the per-token mask.

        Returns ``None`` when the prefill metadata carries no mask. Otherwise
        the mask is indexed at the first ``bs`` entries of ``query_start_loc``
        (presumably each request's first query token - confirm against
        ``_cu_seqlens``), giving a tensor with one entry per request.
        """
        if prefill_metadata.is_valid_token is None:
            return None
        bs = prefill_metadata.req_pool_indices.numel()
        return prefill_metadata.is_valid_token[
            prefill_metadata.query_start_loc[:bs].to(torch.int64)
        ]

    def _is_cuda_graph_prefill_metadata(
        self,
        metadata: DeepseekV4ForwardMetadata,
    ) -> bool:
        """Return True if ``metadata`` is the object cached for its batch size.

        The check is by identity against ``self._cuda_graph_metadata``, keyed
        by the number of entries in ``req_pool_indices``.
        """
        bs = metadata.req_pool_indices.numel()
        return self._cuda_graph_metadata.get(bs) is metadata

    def _prepare_draft_decode_metadata(
        self,
        prefill_metadata: DeepseekV4ForwardMetadata,
        base_seq_lens: torch.Tensor,
    ) -> None:
        """Create or refresh the one-token-per-request draft decode metadata.

        Records ``prefill_metadata`` and ``base_seq_lens`` on the backend,
        resets the draft step counter to 0, and leaves a DECODE-mode metadata
        object in ``self._draft_decode_metadata`` (and, for CUDA-graph prefill
        metadata, in ``self._cuda_graph_draft_decode_metadata[bs]``).

        A cached object is reused when its batch size and device still match;
        in that case its tensors are updated in place with ``copy_`` rather
        than replaced.
        """
        self.forward_prefill_metadata = prefill_metadata
        self._draft_decode_step = 0
        self._draft_decode_base_seq_lens = base_seq_lens

        bs = prefill_metadata.req_pool_indices.numel()
        device = prefill_metadata.req_pool_indices.device
        is_cuda_graph_metadata = self._is_cuda_graph_prefill_metadata(prefill_metadata)
        # CUDA-graph metadata gets a separate cache slot per batch size; the
        # eager path shares a single slot.
        metadata = (
            self._cuda_graph_draft_decode_metadata.get(bs)
            if is_cuda_graph_metadata
            else self._draft_decode_metadata
        )
        is_valid_token = self._draft_decode_is_valid_token(prefill_metadata)
        # Build path: nothing cached yet, or the cached object no longer fits
        # the current batch size / device.
        if (
            metadata is None
            or metadata.req_pool_indices.numel() != bs
            or metadata.seq_lens.numel() != bs
            or metadata.query_lens.numel() != bs
            or metadata.token_to_req_indices.numel() != bs
            or metadata.req_pool_indices.device != device
        ):
            # One query token per request, token i belonging to request i.
            query_lens = torch.ones(bs, dtype=torch.int32, device=device)
            token_to_req = torch.arange(bs, dtype=torch.int32, device=device)
            # Private copies, so later in-place updates do not write into the
            # caller's tensors.
            decode_seq_lens = torch.empty_like(base_seq_lens)
            decode_seq_lens.copy_(base_seq_lens)
            decode_is_valid_token = None
            if is_valid_token is not None:
                decode_is_valid_token = torch.empty_like(is_valid_token)
                decode_is_valid_token.copy_(is_valid_token)
            metadata = DeepseekV4ForwardMetadata(
                page_size=self.page_size,
                req_pool_indices=prefill_metadata.req_pool_indices,
                block_table=prefill_metadata.block_table,
                seq_lens=decode_seq_lens,
                query_lens=query_lens,
                query_start_loc=_cu_seqlens(query_lens),
                token_to_req_indices=token_to_req,
                is_valid_token=decode_is_valid_token,
                forward_mode=ForwardMode.DECODE,
            )
            if is_cuda_graph_metadata:
                self._cuda_graph_draft_decode_metadata[bs] = metadata
            self._draft_decode_metadata = metadata
            return

        # Reuse path: rebind the shared index tensors and overwrite the owned
        # buffers in place.
        metadata.req_pool_indices = prefill_metadata.req_pool_indices
        metadata.block_table = prefill_metadata.block_table
        metadata.seq_lens.copy_(base_seq_lens)
        if is_valid_token is None:
            metadata.is_valid_token = None
        else:
            # Reallocate only when the existing buffer cannot hold the mask.
            if (
                metadata.is_valid_token is None
                or metadata.is_valid_token.shape != is_valid_token.shape
                or metadata.is_valid_token.device != is_valid_token.device
            ):
                metadata.is_valid_token = torch.empty_like(is_valid_token)
            metadata.is_valid_token.copy_(is_valid_token)
        metadata.num_prefill_reqs = 0
        metadata.num_prefill_tokens = 0
        metadata.forward_mode = ForwardMode.DECODE
        # Reuse path: cached decode-indexer plans still describe the previous
        # prefill. Refresh after updating seq_lens so draft step 0 does not
        # reuse stale context_lens / block_table tensors.
        metadata.refresh_decode_compressed_slot_mappings()
        _refresh_decode_indexer_plan_cache(
            metadata,
            max_context_len=self.context_len,
        )
        _refresh_decode_indexer_schedule_metadata(metadata)
        self._draft_decode_metadata = metadata

    def _select_decode_metadata(
        self,
        num_tokens: int,
    ) -> Optional[DeepseekV4ForwardMetadata]:
        """Return the decode metadata whose token count equals ``num_tokens``.

        Candidates are tried in the order ``forward_metadata``,
        ``forward_decode_metadata``, ``forward_prefill_metadata``; the first
        one that is in decode mode and maps exactly ``num_tokens`` tokens
        wins. If none matches, falls back to ``forward_metadata`` or, failing
        that, ``forward_decode_metadata`` (either may be ``None``).
        """
        for metadata in (
            self.forward_metadata,
            self.forward_decode_metadata,
            self.forward_prefill_metadata,
        ):
            if (
                metadata is not None
                and metadata.forward_mode is not None
                and metadata.forward_mode.is_decode()
                and metadata.token_to_req_indices.numel() == num_tokens
            ):
                return metadata
        return self.forward_metadata or self.forward_decode_metadata
req_pool_indices[:bs] seq_lens = seq_lens[:bs].to(torch.int32) query_lens = self._query_lens( bs, + num_tokens, seq_lens, forward_mode, num_extends, @@ -387,7 +529,13 @@ def init_forward_metadata( extend_prefix_lens_cpu, extend_prefix_lens, ) - if forward_mode is not None and forward_mode.is_mixed(): + is_spec = forward_mode is not None and ( + forward_mode.is_target_verify() or forward_mode.is_draft_extend() + ) + metadata_forward_mode = ForwardMode.DECODE if is_spec else forward_mode + if is_spec: + num_prefill_reqs = 0 + elif forward_mode is not None and forward_mode.is_mixed(): num_prefill_reqs = max(0, min(num_extends, bs)) elif forward_mode is not None and forward_mode.is_extend(): num_prefill_reqs = bs @@ -428,6 +576,8 @@ def init_forward_metadata( elif forward_mode.is_mixed(): seq_lens_cpu = seq_lens[:bs].to(dtype=torch.int32, device="cpu") max_seq_len = int(seq_lens.max().item()) if bs else 0 + if forward_mode is not None and forward_mode.is_extend(): + max_seq_len += max(self.speculative_num_steps - 1, 0) max_pages = (max_seq_len + self.page_size - 1) // self.page_size if req_to_page is None: block_table = torch.zeros( @@ -480,7 +630,7 @@ def init_forward_metadata( query_lens_cpu=query_lens_cpu, num_prefill_reqs=num_prefill_reqs, num_prefill_tokens=num_prefill_tokens, - forward_mode=forward_mode, + forward_mode=metadata_forward_mode, paged_cache_block_tables=paged_cache_block_tables, paged_cache_block_table_base_offsets=base_offsets_on_device, swa_block_table=swa_block_table, @@ -490,6 +640,29 @@ def init_forward_metadata( indexer_state_block_table=indexer_state_block_table, indexer_state_base_logical_page=indexer_state_base, ) + if is_spec: + self.forward_decode_metadata = self.forward_metadata + if forward_mode is not None and forward_mode.is_draft_extend(): + self._prepare_draft_decode_metadata( + self.forward_metadata, + seq_lens.clone(), + ) + elif ( + metadata_forward_mode is not None + and metadata_forward_mode.is_decode_or_idle() + ): + 
self.forward_decode_metadata = self.forward_metadata + if ( + self.forward_prefill_metadata is not None + and self.forward_prefill_metadata.req_pool_indices.numel() + == seq_lens.numel() + ): + self._prepare_draft_decode_metadata( + self.forward_prefill_metadata, + seq_lens.clone(), + ) + elif forward_mode == ForwardMode.EXTEND: + self.forward_prefill_metadata = self.forward_metadata self._decode_tile_metadata = {} def _update_decode_swa_metadata( @@ -748,13 +921,20 @@ def forward_deepseek_v4_decode( attn_sink: torch.Tensor, topk_indices: torch.Tensor | None, ) -> torch.Tensor: - metadata = self.forward_metadata + metadata = self._select_decode_metadata(q.shape[0]) if metadata is None: raise RuntimeError("DeepSeek V4 decode requires forward metadata") + self.forward_metadata = metadata if metadata.forward_mode is None or not metadata.forward_mode.is_decode(): raise RuntimeError( "forward_deepseek_v4_decode only supports ForwardMode.DECODE" ) + if metadata.token_to_req_indices.numel() != q.shape[0]: + raise RuntimeError( + "DeepSeek V4 decode metadata token count mismatch: " + f"metadata_tokens={metadata.token_to_req_indices.numel()}, " + f"q_tokens={q.shape[0]}" + ) try: from flash_mla import flash_mla_with_kvcache except Exception as exc: @@ -1207,6 +1387,14 @@ def _prefill_chunk_token_offsets( metadata: DeepseekV4ForwardMetadata, num_reqs: int, ) -> list[int]: + query_lens_cpu = metadata.query_lens_cpu + if query_lens_cpu is not None and query_lens_cpu.numel() >= num_reqs: + offsets = [0] + total = 0 + for q_len in query_lens_cpu[:num_reqs].tolist(): + total += max(0, int(q_len)) + offsets.append(total) + return offsets return [ int(x) for x in metadata.query_start_loc[: num_reqs + 1].detach().cpu().tolist() @@ -1230,14 +1418,27 @@ def forward_deepseek_v4_prefill( topk_indices: Optional[torch.Tensor], ) -> torch.Tensor: metadata = self.forward_metadata + if ( + metadata is None + or metadata.forward_mode is None + or not metadata.forward_mode.is_extend() + ): + 
metadata = self.forward_prefill_metadata or metadata if metadata is None: raise RuntimeError("DeepSeek V4 prefill requires forward metadata") + self.forward_metadata = metadata if metadata.forward_mode is None or not metadata.forward_mode.is_extend(): raise RuntimeError( "forward_deepseek_v4_prefill only supports extend/prefill modes" ) + if metadata.token_to_req_indices.numel() != q.shape[0]: + raise RuntimeError( + "DeepSeek V4 prefill metadata token count mismatch: " + f"metadata_tokens={metadata.token_to_req_indices.numel()}, " + f"q_tokens={q.shape[0]}" + ) - num_reqs = int(metadata.seq_lens.numel()) + num_reqs = int(metadata.num_prefill_reqs or metadata.seq_lens.numel()) if num_reqs <= self.prefill_chunk_size: return self._forward_deepseek_v4_prefill_chunk( q=q, @@ -1305,6 +1506,13 @@ def init_cuda_graph_state( max_tokens_per_req: int = 1, ): del seq_lens_buf + self._decode_tile_metadata = {} + self._cuda_graph_max_tokens_per_req = max( + 1, + int(max_tokens_per_req), + int(self.speculative_num_draft_tokens or 0), + ) + max_tokens = max_bs * self._cuda_graph_max_tokens_per_req self._cuda_graph_block_table = torch.zeros( (max_bs, self.max_num_pages), dtype=torch.int32, @@ -1331,7 +1539,7 @@ def init_cuda_graph_state( device=self.device, ) self._cuda_graph_token_to_req = torch.arange( - max_bs, + max_tokens, dtype=torch.int32, device=self.device, ) @@ -1368,7 +1576,7 @@ def init_cuda_graph_state( device=self.device, ) self._cuda_graph_is_valid_token = torch.ones( - max_bs, + max_tokens, dtype=torch.bool, device=self.device, ) @@ -1431,6 +1639,61 @@ def _refresh_cuda_graph_base_offsets( out[gid] = buf[:bs] return out + def _cuda_graph_tokens_per_req( + self, + bs: int, + num_tokens: int, + forward_mode: Optional[ForwardMode], + ) -> int: + if forward_mode is not None and ( + forward_mode.is_target_verify() or forward_mode.is_draft_extend() + ): + if bs == 0: + return self._cuda_graph_max_tokens_per_req + if num_tokens % bs != 0: + raise RuntimeError( + "DeepSeek 
V4 speculative CUDA graph metadata expects " + f"uniformly packed tokens per request, got " + f"num_tokens={num_tokens}, bs={bs}" + ) + tokens_per_req = num_tokens // bs + if tokens_per_req > self._cuda_graph_max_tokens_per_req: + raise RuntimeError( + "DeepSeek V4 speculative CUDA graph metadata was initialized " + f"for at most {self._cuda_graph_max_tokens_per_req} tokens " + f"per request, got {tokens_per_req}" + ) + return max(1, tokens_per_req) + return 1 + + def _refresh_cuda_graph_packed_metadata( + self, + *, + bs: int, + actual_bs: int, + tokens_per_req: int, + ) -> int: + total_tokens = bs * tokens_per_req + actual_tokens = actual_bs * tokens_per_req + self._cuda_graph_query_lens[:bs].fill_(tokens_per_req) + self._cuda_graph_query_start_loc[: bs + 1].copy_( + torch.arange( + bs + 1, + dtype=torch.int32, + device=self.device, + ) + * tokens_per_req + ) + self._cuda_graph_token_to_req[:total_tokens].copy_( + torch.arange(bs, dtype=torch.int32, device=self.device).repeat_interleave( + tokens_per_req + ) + ) + self._cuda_graph_is_valid_token[:actual_tokens].fill_(True) + if actual_tokens < total_tokens: + self._cuda_graph_is_valid_token[actual_tokens:total_tokens].fill_(False) + return total_tokens + def init_forward_metadata_capture_cuda_graph( self, bs: int, @@ -1444,19 +1707,35 @@ def init_forward_metadata_capture_cuda_graph( paged_cache_block_table_base_offsets = ( kwargs.pop("paged_cache_block_table_base_offsets", None) or {} ) - del num_tokens, kwargs - if forward_mode is not None and not forward_mode.is_decode_or_idle(): + del kwargs + if forward_mode is not None and not ( + forward_mode.is_decode_or_idle() + or forward_mode.is_target_verify() + or forward_mode.is_draft_extend() + ): raise NotImplementedError( f"DeepSeek V4 CUDA graph capture not supported for {forward_mode}" ) - self._cuda_graph_req_pool_indices[:bs].copy_(req_pool_indices[:bs]) - self._cuda_graph_seq_lens[:bs].copy_(seq_lens[:bs].to(torch.int32)) - 
self._cuda_graph_query_lens[:bs].fill_(1) - self._cuda_graph_query_start_loc[: bs + 1].copy_( - torch.arange(bs + 1, dtype=torch.int32, device=self.device) + tokens_per_req = self._cuda_graph_tokens_per_req(bs, num_tokens, forward_mode) + total_tokens = self._refresh_cuda_graph_packed_metadata( + bs=bs, + actual_bs=bs, + tokens_per_req=tokens_per_req, ) - self._cuda_graph_token_to_req[:bs].copy_( - torch.arange(bs, dtype=torch.int32, device=self.device) + is_spec = forward_mode is not None and ( + forward_mode.is_target_verify() or forward_mode.is_draft_extend() + ) + capture_seq_lens = seq_lens[:bs].to(torch.int32) + if is_spec: + capture_seq_lens = torch.maximum( + capture_seq_lens, + torch.full_like(capture_seq_lens, tokens_per_req), + ) + self._cuda_graph_req_pool_indices[:bs].copy_(req_pool_indices[:bs]) + self._cuda_graph_seq_lens[:bs].copy_(capture_seq_lens) + metadata_forward_mode = ForwardMode.DECODE if is_spec else forward_mode + is_decode = ( + metadata_forward_mode is not None and metadata_forward_mode.is_decode() ) offsets_on_device = { str(gid): off.to(device=self.device, dtype=torch.int32) @@ -1486,28 +1765,58 @@ def init_forward_metadata_capture_cuda_graph( metadata_paged, metadata_base_offsets, ) - metadata = DeepseekV4ForwardMetadata( - page_size=self.page_size, - req_pool_indices=self._cuda_graph_req_pool_indices[:bs], - block_table=self._cuda_graph_block_table[:bs, : self.max_num_pages], - seq_lens=self._cuda_graph_seq_lens[:bs], - query_lens=self._cuda_graph_query_lens[:bs], - query_start_loc=self._cuda_graph_query_start_loc[: bs + 1], - token_to_req_indices=self._cuda_graph_token_to_req[:bs], - is_valid_token=self._cuda_graph_is_valid_token[:bs], - seq_lens_cpu=None, - query_lens_cpu=None, - forward_mode=forward_mode, - paged_cache_block_tables=metadata_paged, - paged_cache_block_table_base_offsets=metadata_base_offsets, - swa_block_table=swa_block_table, - swa_base_logical_page=swa_base, - 
compressor_state_block_tables=compressor_state_block_tables, - compressor_state_base_logical_pages=compressor_state_base, - indexer_state_block_table=indexer_state_block_table, - indexer_state_base_logical_page=indexer_state_base, - ) + metadata = self._cuda_graph_metadata.get(bs) + if metadata is None: + metadata = DeepseekV4ForwardMetadata( + page_size=self.page_size, + req_pool_indices=self._cuda_graph_req_pool_indices[:bs], + block_table=self._cuda_graph_block_table[:bs, : self.max_num_pages], + seq_lens=self._cuda_graph_seq_lens[:bs], + query_lens=self._cuda_graph_query_lens[:bs], + query_start_loc=self._cuda_graph_query_start_loc[: bs + 1], + token_to_req_indices=self._cuda_graph_token_to_req[:total_tokens], + is_valid_token=self._cuda_graph_is_valid_token[:total_tokens], + seq_lens_cpu=None, + query_lens_cpu=None, + forward_mode=metadata_forward_mode, + paged_cache_block_tables=metadata_paged, + paged_cache_block_table_base_offsets=metadata_base_offsets, + swa_block_table=swa_block_table, + swa_base_logical_page=swa_base, + compressor_state_block_tables=compressor_state_block_tables, + compressor_state_base_logical_pages=compressor_state_base, + indexer_state_block_table=indexer_state_block_table, + indexer_state_base_logical_page=indexer_state_base, + ) + else: + metadata.req_pool_indices = self._cuda_graph_req_pool_indices[:bs] + metadata.block_table = self._cuda_graph_block_table[ + :bs, : self.max_num_pages + ] + metadata.seq_lens = self._cuda_graph_seq_lens[:bs] + metadata.query_lens = self._cuda_graph_query_lens[:bs] + metadata.query_start_loc = self._cuda_graph_query_start_loc[: bs + 1] + metadata.token_to_req_indices = self._cuda_graph_token_to_req[:total_tokens] + metadata.is_valid_token = self._cuda_graph_is_valid_token[:total_tokens] + metadata.seq_lens_cpu = None + metadata.query_lens_cpu = None + metadata.forward_mode = metadata_forward_mode + metadata.paged_cache_block_tables = metadata_paged + metadata.paged_cache_block_table_base_offsets = 
metadata_base_offsets + metadata.swa_block_table = swa_block_table + metadata.swa_base_logical_page = swa_base + metadata.compressor_state_block_tables = compressor_state_block_tables + metadata.compressor_state_base_logical_pages = compressor_state_base + metadata.indexer_state_block_table = indexer_state_block_table + metadata.indexer_state_base_logical_page = indexer_state_base self._cuda_graph_metadata[bs] = metadata + if forward_mode is not None and forward_mode.is_draft_extend(): + self._prepare_draft_decode_metadata( + metadata, + self._cuda_graph_seq_lens[:bs], + ) + if is_decode: + self.forward_decode_metadata = metadata self.forward_metadata = metadata def init_forward_metadata_replay_cuda_graph( @@ -1524,24 +1833,33 @@ def init_forward_metadata_replay_cuda_graph( kwargs.pop("paged_cache_block_table_base_offsets", None) or {} ) actual_bs = max(0, min(int(kwargs.pop("actual_bs", bs)), bs)) + num_tokens_arg = kwargs.pop("num_tokens", None) del kwargs - if forward_mode is not None and not forward_mode.is_decode_or_idle(): + if forward_mode is not None and not ( + forward_mode.is_decode_or_idle() + or forward_mode.is_target_verify() + or forward_mode.is_draft_extend() + ): raise NotImplementedError( f"DeepSeek V4 CUDA graph replay not supported for {forward_mode}" ) + if ( + num_tokens_arg is None + and forward_mode is not None + and (forward_mode.is_target_verify() or forward_mode.is_draft_extend()) + ): + num_tokens = bs * self._cuda_graph_max_tokens_per_req + else: + num_tokens = int(num_tokens_arg if num_tokens_arg is not None else bs) + tokens_per_req = self._cuda_graph_tokens_per_req(bs, num_tokens, forward_mode) + total_tokens = self._refresh_cuda_graph_packed_metadata( + bs=bs, + actual_bs=actual_bs, + tokens_per_req=tokens_per_req, + ) metadata = self._cuda_graph_metadata[bs] self._cuda_graph_req_pool_indices[:bs].copy_(req_pool_indices[:bs]) self._cuda_graph_seq_lens[:bs].copy_(seq_lens[:bs].to(torch.int32)) - 
self._cuda_graph_query_lens[:bs].fill_(1) - self._cuda_graph_query_start_loc[: bs + 1].copy_( - torch.arange(bs + 1, dtype=torch.int32, device=self.device) - ) - self._cuda_graph_token_to_req[:bs].copy_( - torch.arange(bs, dtype=torch.int32, device=self.device) - ) - self._cuda_graph_is_valid_token[:actual_bs].fill_(True) - if actual_bs < bs: - self._cuda_graph_is_valid_token[actual_bs:bs].fill_(False) if req_to_page is not None: self._cuda_graph_block_table[:bs, : self.max_num_pages].copy_( req_to_page[req_pool_indices[:bs], : self.max_num_pages] @@ -1574,7 +1892,16 @@ def init_forward_metadata_replay_cuda_graph( metadata_paged, metadata_base_offsets, ) - metadata.forward_mode = forward_mode + is_spec = forward_mode is not None and ( + forward_mode.is_target_verify() or forward_mode.is_draft_extend() + ) + metadata_forward_mode = ForwardMode.DECODE if is_spec else forward_mode + is_decode = ( + metadata_forward_mode is not None and metadata_forward_mode.is_decode() + ) + metadata.forward_mode = metadata_forward_mode + metadata.token_to_req_indices = self._cuda_graph_token_to_req[:total_tokens] + metadata.is_valid_token = self._cuda_graph_is_valid_token[:total_tokens] metadata.paged_cache_block_tables = metadata_paged metadata.paged_cache_block_table_base_offsets = metadata_base_offsets metadata.swa_block_table = swa_block_table @@ -1585,9 +1912,14 @@ def init_forward_metadata_replay_cuda_graph( metadata.indexer_state_base_logical_page = indexer_state_base metadata.num_prefill_reqs = 0 metadata.num_prefill_tokens = 0 + if forward_mode is not None and forward_mode.is_draft_extend(): + self._prepare_draft_decode_metadata( + metadata, + self._cuda_graph_seq_lens[:bs], + ) if ( - forward_mode is not None - and forward_mode.is_decode() + metadata_forward_mode is not None + and metadata_forward_mode.is_decode() and self._decode_swa_window_size > 0 and self._decode_swa_block_size > 0 ): @@ -1602,12 +1934,32 @@ def init_forward_metadata_replay_cuda_graph( 
max_context_len=self.context_len, ) _refresh_decode_indexer_schedule_metadata(metadata) + if is_decode: + self.forward_decode_metadata = metadata self.forward_metadata = metadata def advance_draft_forward_metadata(self): - raise NotImplementedError( - "DeepSeek V4 attention does not support draft graphs yet" + if ( + self._draft_decode_base_seq_lens is None + or self.forward_prefill_metadata is None + or self._draft_decode_metadata is None + ): + raise RuntimeError("DeepSeek V4 draft metadata was not initialized") + self._draft_decode_step += 1 + metadata = self._draft_decode_metadata + metadata.seq_lens.add_(1) + metadata.forward_mode = ForwardMode.DECODE + # seq_lens just changed, so any previously-refreshed plan tensors are + # stale. Re-run the same metadata-setup hooks the main path uses. + metadata.refresh_decode_compressed_slot_mappings() + _refresh_decode_indexer_plan_cache( + metadata, + max_context_len=self.context_len, ) + _refresh_decode_indexer_schedule_metadata(metadata) + self.forward_decode_metadata = metadata + self.forward_metadata = metadata + self._decode_tile_metadata = {} def forward_decode(self, *args, **kwargs): raise NotImplementedError("DeepSeek V4 uses the model-local attention forward") diff --git a/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py b/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py index bef078ef1..d63ad15c3 100644 --- a/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py +++ b/python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py @@ -14,7 +14,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any +from typing import Any, Iterable, Optional import numpy as np import torch @@ -444,9 +444,21 @@ def deepseek_v4_cache_layout_from_config( hf_config, page_size: int, use_fp4_indexer_cache: bool, + layer_indices: Optional[Iterable[int]] = None, ) -> DeepseekV4CacheLayout: + compress_ratios = 
tuple(hf_config.compress_ratios) + if layer_indices is None: + layer_ratios = compress_ratios + else: + layer_indices = tuple(layer_indices) + if any(idx < 0 or idx >= len(compress_ratios) for idx in layer_indices): + raise ValueError( + "DeepSeek V4 cache layout layer index out of range: " + f"indices={layer_indices}, ratios={len(compress_ratios)}" + ) + layer_ratios = [compress_ratios[idx] for idx in layer_indices] return DeepseekV4CacheLayout( - layer_ratio=tuple(max(1, int(x)) for x in hf_config.compress_ratios), + layer_ratio=tuple(max(1, int(x)) for x in layer_ratios), head_dim=int(hf_config.head_dim), page_size=page_size, use_fp4_indexer_cache=use_fp4_indexer_cache, @@ -481,6 +493,11 @@ def __init__( ) -> None: if size <= 0: raise ValueError(f"DeepSeek V4 KV pool size must be positive, got {size}") + if layer_num != len(layout.layer_ratio): + raise ValueError( + "DeepSeek V4 KV pool layer_num must match cache layout ratios: " + f"layer_num={layer_num}, ratios={len(layout.layer_ratio)}" + ) super().__init__( size=size, dtype=torch.uint8, diff --git a/python/tokenspeed/runtime/layers/attention/registry.py b/python/tokenspeed/runtime/layers/attention/registry.py index dcae488ff..dfdfdf8de 100644 --- a/python/tokenspeed/runtime/layers/attention/registry.py +++ b/python/tokenspeed/runtime/layers/attention/registry.py @@ -322,10 +322,15 @@ def create_attn_components( architectures = getattr(model_config.hf_config, "architectures", None) or [] is_hybrid_gdn = any(a in _HYBRID_GDN_ARCHITECTURES for a in architectures) is_deepseek_v4_model = is_deepseek_v4(model_config.hf_config) + is_deepseek_v4_draft_model = draft_model_config is not None and is_deepseek_v4( + draft_model_config.hf_config + ) original_attn_backend = server_args.attention_backend if is_deepseek_v4_model: server_args.attention_backend = "deepseek_v4" - elif is_hybrid_gdn: + if is_deepseek_v4_draft_model: + server_args.drafter_attention_backend = "deepseek_v4" + if is_hybrid_gdn: # Qwen3.5 GDN hybrid 
models always need hybrid_linear_attn. # Save the user's original choice for the full-attention sub-backend. server_args.attention_backend = "hybrid_linear_attn" @@ -346,7 +351,9 @@ def create_attn_components( ) num_layers = model_config.num_attention_layers deepseek_v4_layout = None + draft_deepseek_v4_layout = None profile_cache_cell_size = None + draft_profile_cache_cell_size = None if is_deepseek_v4_model: from tokenspeed.runtime.layers.attention.kv_cache.deepseek_v4 import ( deepseek_v4_cache_layout_from_config, @@ -358,8 +365,30 @@ def create_attn_components( use_fp4_indexer_cache=_attention_use_fp4_indexer_cache( server_args, model_config.hf_config ), + layer_indices=range(num_layers), ) profile_cache_cell_size = deepseek_v4_layout.cache_cell_size(num_layers) + if is_deepseek_v4_draft_model: + from tokenspeed.runtime.layers.attention.kv_cache.deepseek_v4 import ( + deepseek_v4_cache_layout_from_config, + ) + + draft_layer_start = draft_model_config.num_hidden_layers + draft_num_layers = draft_model_config.num_attention_layers + draft_deepseek_v4_layout = deepseek_v4_cache_layout_from_config( + draft_model_config.hf_config, + page_size=server_args.block_size, + use_fp4_indexer_cache=_attention_use_fp4_indexer_cache( + server_args, draft_model_config.hf_config + ), + layer_indices=range( + draft_layer_start, + draft_layer_start + draft_num_layers, + ), + ) + draft_profile_cache_cell_size = draft_deepseek_v4_layout.cache_cell_size( + draft_model_config.num_attention_layers + ) hf_config = getattr(model_config, "hf_config", None) text_config = getattr(hf_config, "text_config", hf_config) if hf_config else None @@ -393,6 +422,7 @@ def create_attn_components( **_profile_kwargs, gpu_memory_utilization=server_args.gpu_memory_utilization, cache_cell_size=profile_cache_cell_size, + draft_cache_cell_size=draft_profile_cache_cell_size, ) max_num_tokens = _resolve_max_num_tokens( max_total_num_pages, @@ -435,6 +465,7 @@ def create_attn_components( **_profile_kwargs, 
gpu_memory_utilization=server_args.gpu_memory_utilization, cache_cell_size=profile_cache_cell_size, + draft_cache_cell_size=draft_profile_cache_cell_size, ) max_num_tokens = _resolve_max_num_tokens( max_total_num_pages, @@ -498,7 +529,29 @@ def create_attn_components( if draft_attn_config: # Check if draft model is also a hybrid GDN model. draft_archs = getattr(draft_model_config.hf_config, "architectures", None) or [] - if any(a in _HYBRID_GDN_ARCHITECTURES for a in draft_archs): + if is_deepseek_v4_draft_model: + from tokenspeed.runtime.layers.attention.kv_cache.deepseek_v4 import ( + DeepseekV4TokenToKVPool, + ) + + draft_attn_backend = _create_attn_backend( + draft_model_config.attention_arch, draft_attn_config + ) + draft_pool = DeepseekV4TokenToKVPool( + size=max_num_tokens, + model_dtype=draft_model_config.dtype, + layout=draft_deepseek_v4_layout, + layer_num=draft_model_config.num_attention_layers, + device=draft_attn_config.device, + enable_memory_saver=enable_memory_saver, + max_batch_size=draft_attn_config.max_bs, + max_context_len=draft_attn_config.context_len, + page_size=server_args.block_size, + rank=rank, + hf_config=draft_model_config.hf_config, + max_scheduled_tokens=server_args.chunked_prefill_size, + ) + elif any(a in _HYBRID_GDN_ARCHITECTURES for a in draft_archs): resolved_draft_backend = _BACKEND_ALIASES.get( original_attn_backend, original_attn_backend ) diff --git a/python/tokenspeed/runtime/layers/attention/utils.py b/python/tokenspeed/runtime/layers/attention/utils.py index b619f213b..0fb7189e7 100755 --- a/python/tokenspeed/runtime/layers/attention/utils.py +++ b/python/tokenspeed/runtime/layers/attention/utils.py @@ -137,6 +137,7 @@ def profile_max_num_pages( draft_attn_config: BaseAttnConfig | None = None, draft_num_attention_layers: int | None = None, cache_cell_size: int | None = None, + draft_cache_cell_size: int | None = None, ): cpu_group = ( pg_manager.get_process_group("gloo", world_group) @@ -156,7 +157,12 @@ def 
profile_max_num_pages( else: cell_size = cache_cell_size if draft_attn_config is not None: - cell_size += draft_attn_config.cache_cell_size() * draft_num_attention_layers + if draft_cache_cell_size is None: + cell_size += ( + draft_attn_config.cache_cell_size() * draft_num_attention_layers + ) + else: + cell_size += draft_cache_cell_size if cell_size <= 0: raise ValueError(f"KV cache cell size must be positive, got {cell_size}") max_num_token = int(rest_memory * (1 << 30) // cell_size) diff --git a/python/tokenspeed/runtime/models/deepseek_v4.py b/python/tokenspeed/runtime/models/deepseek_v4.py index 4da61c90e..2895ceeec 100644 --- a/python/tokenspeed/runtime/models/deepseek_v4.py +++ b/python/tokenspeed/runtime/models/deepseek_v4.py @@ -77,6 +77,7 @@ process_group_manager as pg_manager, ) from tokenspeed.runtime.execution.context import ForwardContext +from tokenspeed.runtime.execution.forward_batch_info import ForwardMode from tokenspeed.runtime.layers.attention.deepseek_v4_ops import ( DEEPSEEK_V4_INDEXER_DIM, DeepseekV4AttentionOpUnavailable, @@ -144,6 +145,37 @@ def is_sm90_supported(device: object | None = None) -> bool: logger = get_colorful_logger(__name__) +def _deepseek_v4_metadata_matches_tokens(metadata, num_tokens: int) -> bool: + return ( + metadata is not None + and getattr(metadata, "token_to_req_indices", None) is not None + and metadata.token_to_req_indices.numel() == num_tokens + ) + + +def _deepseek_v4_forward_metadata(ctx: ForwardContext): + metadata = getattr(ctx.attn_backend, "forward_metadata", None) + if ctx.forward_mode == ForwardMode.EXTEND: + return getattr(ctx.attn_backend, "forward_prefill_metadata", None) or metadata + if ctx.forward_mode is not None and ctx.forward_mode.is_draft_extend(): + prefill_metadata = getattr(ctx.attn_backend, "forward_prefill_metadata", None) + if _deepseek_v4_metadata_matches_tokens( + prefill_metadata, + ctx.input_num_tokens, + ): + return prefill_metadata + return prefill_metadata or metadata + if 
ctx.forward_mode is not None and ctx.forward_mode.is_decode_or_idle(): + decode_metadata = getattr(ctx.attn_backend, "forward_decode_metadata", None) + if _deepseek_v4_metadata_matches_tokens( + decode_metadata, + ctx.input_num_tokens, + ): + return decode_metadata + return decode_metadata or metadata + return metadata + + def _dequant_fp8_weight(layer: nn.Module, shape: tuple[int, ...]) -> torch.Tensor: weight = layer.weight.view(*shape) scale = getattr(layer, "weight_scale_inv", None) @@ -3779,7 +3811,7 @@ def forward( write_compressed_cache: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: pool = ctx.token_to_kv_pool - metadata = ctx.attn_backend.forward_metadata + metadata = _deepseek_v4_forward_metadata(ctx) if metadata is None: raise RuntimeError("DeepSeek V4 compressor requires forward metadata") profile_prefix = ( @@ -4194,7 +4226,7 @@ def forward( cos_sin_cache: torch.Tensor, ) -> torch.Tensor: pool = ctx.token_to_kv_pool - metadata = ctx.attn_backend.forward_metadata + metadata = _deepseek_v4_forward_metadata(ctx) if metadata is None: raise RuntimeError("DeepSeek V4 indexer requires forward metadata") indexer_state = pool.get_indexer_state_buffer(layer_index) @@ -4812,9 +4844,15 @@ def __init__( prefix: str, aux_stream: torch.cuda.Stream | None = None, topk_buffer: _DeepseekV4TopKBuffer | None = None, + cache_layer_index: int | None = None, ) -> None: super().__init__() + # `layer_index` addresses checkpoint/config metadata; `cache_layer_index` + # addresses this model's compact KV cache slot. self.layer_index = layer_index + self.cache_layer_index = ( + layer_index if cache_layer_index is None else cache_layer_index + ) self.aux_stream = aux_stream if self.aux_stream is not None: self.ln_events: list[torch.cuda.Event | None] = [ @@ -5120,7 +5158,7 @@ def _forward_flashmla_sparse( "`tokenspeed-kernel/python` with FlashMLA before serving V4." 
) from exc - metadata = ctx.attn_backend.forward_metadata + metadata = _deepseek_v4_forward_metadata(ctx) if metadata is None: raise RuntimeError("DeepSeek V4 attention requires forward metadata") pool = ctx.token_to_kv_pool @@ -5134,7 +5172,7 @@ def _forward_flashmla_sparse( per_token_slots: list[tuple[torch.Tensor, torch.Tensor]] = [] max_candidates = 0 - compressed_block_size = pool.get_compressed_block_size(self.layer_index) + compressed_block_size = pool.get_compressed_block_size(self.cache_layer_index) for token_idx in range(positions.numel()): position = int(positions[token_idx].item()) compressed = self._compressed_slots_for_token( @@ -5160,11 +5198,11 @@ def _forward_flashmla_sparse( rows = [] cursor = 0 compressed_cache = ( - pool.get_compressed_kv_buffer_2d(self.layer_index) + pool.get_compressed_kv_buffer_2d(self.cache_layer_index) if self.compress_ratio > 1 else None ) - swa_cache = pool.get_swa_kv_buffer(self.layer_index) + swa_cache = pool.get_swa_kv_buffer(self.cache_layer_index) for token_idx, (compressed, swa) in enumerate(per_token_slots): token_rows = [] if compressed.numel() > 0: @@ -5390,7 +5428,7 @@ def insert_swa_cache() -> None: self._insert_swa_cache( q=q, kv=kv, - swa_kv_cache=pool.get_swa_kv_buffer(self.layer_index), + swa_kv_cache=pool.get_swa_kv_buffer(self.cache_layer_index), slot_mapping=swa_slot_mapping, positions=positions, block_size=pool.swa_block_size, @@ -5405,7 +5443,7 @@ def run_compressor() -> None: positions=positions, ctx=ctx, out_cache_loc=out_cache_loc, - layer_index=self.layer_index, + layer_index=self.cache_layer_index, cos_sin_cache=self._cos_sin_cache(), ) @@ -5418,7 +5456,9 @@ def run_compressor() -> None: self.indexer.prepare_decode_metadata( positions=positions, metadata=metadata, - indexer_block_size=pool.get_indexer_block_size(self.layer_index), + indexer_block_size=pool.get_indexer_block_size( + self.cache_layer_index + ), ) def run_indexer() -> torch.Tensor: @@ -5429,7 +5469,7 @@ def run_indexer() -> 
torch.Tensor: positions=positions, ctx=ctx, out_cache_loc=out_cache_loc, - layer_index=self.layer_index, + layer_index=self.cache_layer_index, cos_sin_cache=self._cos_sin_cache(), ) @@ -5479,7 +5519,7 @@ def insert_and_compress() -> None: q=q, positions=positions, token_to_kv_pool=pool, - layer_id=self.layer_index, + layer_id=self.cache_layer_index, kind=self.attention_kind, compress_ratio=self.compress_ratio, num_local_heads=self.num_local_heads, @@ -5493,14 +5533,18 @@ def insert_and_compress() -> None: elif ( backend_decode is not None and ctx.forward_mode is not None - and ctx.forward_mode.is_decode() + and ( + ctx.forward_mode.is_decode() + or ctx.forward_mode.is_target_verify() + or ctx.forward_mode.is_draft_extend() + ) ): with deepseek_v4_profile_scope(f"{profile_prefix}_decode_backend"): attn_output = backend_decode( q=q, positions=positions, token_to_kv_pool=pool, - layer_id=self.layer_index, + layer_id=self.cache_layer_index, kind=self.attention_kind, compress_ratio=self.compress_ratio, num_local_heads=self.num_local_heads, @@ -5521,7 +5565,7 @@ def insert_and_compress() -> None: q=q, positions=positions, token_to_kv_pool=pool, - layer_id=self.layer_index, + layer_id=self.cache_layer_index, kind=self.attention_kind, compress_ratio=self.compress_ratio, num_local_heads=self.num_local_heads, @@ -5551,6 +5595,7 @@ def __init__( prefix: str, aux_stream: torch.cuda.Stream | None = None, topk_buffer: _DeepseekV4TopKBuffer | None = None, + cache_layer_index: int | None = None, ) -> None: super().__init__() self.mapping = mapping @@ -5570,6 +5615,7 @@ def __init__( add_prefix("attn", prefix), aux_stream=aux_stream, topk_buffer=topk_buffer, + cache_layer_index=cache_layer_index, ) self.ffn = DeepseekV4MoE( config, mapping, quant_config, layer_id, add_prefix("ffn", prefix) @@ -5766,7 +5812,7 @@ def forward( ctx: ForwardContext, out_cache_loc: torch.Tensor, input_embeds: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, None]: + ) -> tuple[torch.Tensor, 
list[torch.Tensor] | None]: hidden_states = input_embeds if hidden_states is None: with deepseek_v4_profile_scope("embed_tokens"): @@ -5777,6 +5823,13 @@ def forward( hidden_states = layer( positions, hidden_states, ctx, out_cache_loc, input_ids ) + aux_hidden_states = None + if ( + ctx.capture_hidden_mode is not None + and ctx.capture_hidden_mode.need_capture() + ): + # V4 MTP consumes the pre-hc_head hypercompressed residual. + aux_hidden_states = [hidden_states.flatten(1)] with deepseek_v4_profile_scope("hc_head"): hidden_states = hc_head( hidden_states, @@ -5788,7 +5841,7 @@ def forward( ) with deepseek_v4_profile_scope("final_norm"): hidden_states = self.norm(hidden_states) - return hidden_states, None + return hidden_states, aux_hidden_states class DeepseekV4ForCausalLM(BaseCausalLM): diff --git a/python/tokenspeed/runtime/models/deepseek_v4_mtp.py b/python/tokenspeed/runtime/models/deepseek_v4_mtp.py new file mode 100644 index 000000000..4ad6c953e --- /dev/null +++ b/python/tokenspeed/runtime/models/deepseek_v4_mtp.py @@ -0,0 +1,524 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Inference-only DeepSeek V4 MTP / NextN draft model.""" + +from __future__ import annotations + +import logging +import re +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from tokenspeed.runtime.distributed.mapping import Mapping +from tokenspeed.runtime.execution.context import ForwardContext +from tokenspeed.runtime.layers.layernorm import RMSNorm +from tokenspeed.runtime.layers.linear import ReplicatedLinear +from tokenspeed.runtime.layers.logits_processor import LogitsMetadata, LogitsProcessor +from tokenspeed.runtime.layers.moe.checkpoint import ( + ExpertCheckpointSchema, + build_moe_checkpoint_loader, +) +from tokenspeed.runtime.layers.moe.layer import MoELayer +from tokenspeed.runtime.layers.quantization.base_config import QuantizationConfig +from tokenspeed.runtime.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from tokenspeed.runtime.model_loader.weight_utils import default_weight_loader +from tokenspeed.runtime.models.deepseek_v4 import ( + DeepseekV4Compressor, + DeepseekV4DecoderLayer, + DeepseekV4MegaMoEExperts, + _fp8_linear, + hc_head, +) +from tokenspeed.runtime.utils import add_prefix + +logger = logging.getLogger(__name__) + + +_EXPERT_SCALE_RE = re.compile(r"\.experts\.\d+\.w[123]\.scale$") + + +def _spec_layer_idx(config: PretrainedConfig, weight_name: str) -> Optional[int]: + if getattr(config, "num_nextn_predict_layers", 0) <= 0: + return None + start = config.num_hidden_layers + for idx in range(start, start + config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{idx}."): + return idx + return None + + +def 
_find_mtp_layer_idx(name: str) -> int: + parts = name.split(".") + if len(parts) > 1 and parts[0] == "mtp": + try: + return int(parts[1]) + except ValueError: + pass + for part in parts: + try: + return int(part) + except ValueError: + continue + return 0 + + +class DeepseekV4MTPSharedHead(nn.Module): + def __init__(self, config: PretrainedConfig) -> None: + super().__init__() + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class DeepseekV4MultiTokenPredictorLayer(nn.Module): + def __init__( + self, + config: PretrainedConfig, + mapping: Mapping, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + cache_layer_index: Optional[int] = None, + ) -> None: + super().__init__() + self.config = config + self.layer_id = layer_id + self.cache_layer_index = ( + layer_id if cache_layer_index is None else cache_layer_index + ) + self.rms_norm_eps = config.rms_norm_eps + self.hc_eps = config.hc_eps + self.hc_mult = config.hc_mult + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.e_proj = ReplicatedLinear( + config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("e_proj", prefix), + ) + self.h_proj = ReplicatedLinear( + config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("h_proj", prefix), + ) + self.hc_head_fn = nn.Parameter( + torch.empty( + self.hc_mult, + self.hc_mult * config.hidden_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + self.hc_head_base = nn.Parameter( + torch.empty(self.hc_mult, dtype=torch.float32), + requires_grad=False, + ) + self.hc_head_scale = nn.Parameter( + torch.empty(1, dtype=torch.float32), + requires_grad=False, + ) + self.shared_head = DeepseekV4MTPSharedHead(config) + self.mtp_block = DeepseekV4DecoderLayer( + config, + layer_id, + mapping, + quant_config, + 
add_prefix("mtp_block", prefix), + cache_layer_index=self.cache_layer_index, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + ctx: ForwardContext, + out_cache_loc: torch.Tensor, + input_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if input_embeds is None: + raise ValueError("DeepSeek V4 MTP requires input_embeds.") + input_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, input_embeds) + input_embeds = self.enorm(input_embeds) + previous_hidden_states = previous_hidden_states.view( + -1, self.hc_mult, self.config.hidden_size + ) + previous_hidden_states = self.hnorm(previous_hidden_states) + hidden_states = _fp8_linear( + self.h_proj, + previous_hidden_states, + (self.h_proj.output_size, self.h_proj.input_size), + ) + _fp8_linear( + self.e_proj, + input_embeds, + (self.e_proj.output_size, self.e_proj.input_size), + ).unsqueeze( + -2 + ) + + return self.mtp_block( + positions, + hidden_states, + ctx, + out_cache_loc, + input_ids, + ) + + def compute_logits_hidden(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states.view(-1, self.hc_mult, self.config.hidden_size) + hidden_states = hc_head( + hidden_states, + self.hc_head_fn, + self.hc_head_scale, + self.hc_head_base, + self.rms_norm_eps, + self.hc_eps, + ) + return self.shared_head.norm(hidden_states) + + +class DeepseekV4MultiTokenPredictor(nn.Module): + def __init__( + self, + config: PretrainedConfig, + mapping: Mapping, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.mapping = mapping + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + tp_rank=mapping.attn.tp_rank, + tp_size=mapping.attn.tp_size, + tp_group=mapping.attn.tp_group, + 
prefix=add_prefix("embed_tokens", prefix), + ) + layers = {} + for local_idx in range(self.num_mtp_layers): + # Checkpoint layer ids remain global, while draft KV slots are compact. + layer_idx = self.mtp_start_layer_idx + local_idx + layers[str(layer_idx)] = DeepseekV4MultiTokenPredictorLayer( + config, + mapping, + layer_idx, + quant_config, + add_prefix(f"layers.{layer_idx}", prefix), + cache_layer_index=local_idx, + ) + self.layers = nn.ModuleDict(layers) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + ctx: ForwardContext, + out_cache_loc: torch.Tensor, + input_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if input_embeds is None: + input_embeds = self.embed_tokens(input_ids) + current_step_idx = spec_step_idx % self.num_mtp_layers + layer_idx = self.mtp_start_layer_idx + current_step_idx + return self.layers[str(layer_idx)]( + input_ids, + positions, + previous_hidden_states, + ctx, + out_cache_loc, + input_embeds, + ) + + def compute_logits_hidden( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> torch.Tensor: + current_step_idx = spec_step_idx % self.num_mtp_layers + layer_idx = self.mtp_start_layer_idx + current_step_idx + return self.layers[str(layer_idx)].compute_logits_hidden(hidden_states) + + +class DeepseekV4ForCausalLMNextN(nn.Module): + def __init__( + self, + config: PretrainedConfig, + mapping: Mapping, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.config = config + self.mapping = mapping + self.quant_config = quant_config + self.model = DeepseekV4MultiTokenPredictor( + config, + mapping=mapping, + quant_config=quant_config, + prefix=add_prefix("model", prefix), + ) + if self.mapping.attn.has_dp: + self.lm_head = ReplicatedLinear( + config.hidden_size, + config.vocab_size, + bias=False, + prefix=add_prefix("lm_head", prefix), + ) + 
self.logits_processor = LogitsProcessor(config, skip_all_gather=True) + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + tp_rank=self.mapping.attn.tp_rank, + tp_size=self.mapping.attn.tp_size, + tp_group=self.mapping.attn.tp_group, + prefix=add_prefix("lm_head", prefix), + ) + self.logits_processor = LogitsProcessor( + config, + tp_rank=self.mapping.attn.tp_rank, + tp_size=self.mapping.attn.tp_size, + tp_group=self.mapping.attn.tp_group, + ) + + def get_hot_token_id(self): + return None + + def get_embed_and_head(self) -> Tuple[torch.Tensor, torch.Tensor]: + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed: torch.Tensor, head: torch.Tensor) -> None: + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + @torch.no_grad() + def forward( + self, + ctx: ForwardContext, + input_ids: torch.Tensor, + positions: torch.Tensor, + out_cache_loc: torch.Tensor, + input_lengths: torch.Tensor, + input_embeds: Optional[torch.Tensor] = None, + captured_hidden_states: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + **kwargs, + ): + del kwargs + if captured_hidden_states is None: + if not ctx.forward_mode.is_idle(): + raise ValueError("DeepSeek V4 MTP requires captured_hidden_states.") + captured_hidden_states = torch.zeros( + 0, + self.config.hc_mult * self.config.hidden_size, + device=input_ids.device, + dtype=self.model.embed_tokens.weight.dtype, + ) + + mtp_hidden_states = self.model( + input_ids, + positions, + captured_hidden_states, + ctx, + out_cache_loc, + input_embeds=input_embeds, + spec_step_idx=spec_step_idx, + ).flatten(1) + logits_hidden_states = self.model.compute_logits_hidden( + mtp_hidden_states, + spec_step_idx, + ) + logits_metadata = LogitsMetadata.from_forward_context(ctx, input_lengths) + return 
self.logits_processor( + input_ids, + logits_hidden_states, + self.lm_head, + logits_metadata, + aux_hidden_states=[mtp_hidden_states], + ) + + @staticmethod + def _remap_weight_name(name: str) -> str: + for old, new in { + ".emb.tok_emb.weight": ".embed_tokens.weight", + ".head.weight": ".shared_head.head.weight", + ".norm.weight": ".shared_head.norm.weight", + }.items(): + if old in name: + name = name.replace(old, new) + return name + + @staticmethod + def _rewrite_spec_layer_name(spec_layer: int, name: str) -> str: + spec_layer_weight_names = ( + "embed_tokens", + "enorm", + "hnorm", + "h_proj", + "e_proj", + "shared_head", + "hc_head_fn", + "hc_head_base", + "hc_head_scale", + ) + shared_weight_names = ("embed_tokens",) + is_spec_weight = any( + weight_name in name for weight_name in spec_layer_weight_names + ) + is_shared_weight = any( + weight_name in name for weight_name in shared_weight_names + ) + if not is_spec_weight: + name = name.replace( + f"model.layers.{spec_layer}.", + f"model.layers.{spec_layer}.mtp_block.", + ) + elif is_shared_weight: + name = name.replace(f"model.layers.{spec_layer}.", "model.") + return name + + def _map_checkpoint_name(self, raw_name: str) -> Optional[str]: + if raw_name.startswith("mtp."): + mtp_layer_idx = _find_mtp_layer_idx(raw_name) + raw_name = raw_name.replace( + f"mtp.{mtp_layer_idx}.", + f"model.layers.{self.config.num_hidden_layers + mtp_layer_idx}.", + 1, + ) + spec_layer = _spec_layer_idx(self.config, raw_name) + if spec_layer is None: + return None + name = self._remap_weight_name(raw_name) + name = self._rewrite_spec_layer_name(spec_layer, name) + if name.endswith(".shared_head.head.weight"): + return None + if name.endswith(".scale"): + suffix = ( + ".weight_scale" + if _EXPERT_SCALE_RE.search(name) + else ".weight_scale_inv" + ) + name = name.removesuffix(".scale") + suffix + if ".shared_experts.w2" in name: + name = name.replace(".shared_experts.w2", ".shared_experts.down_proj") + if ".ffn.gate.bias" in 
name: + name = name.replace(".ffn.gate.bias", ".ffn.gate.e_score_correction_bias") + return name + + def get_stacked_params_mapping(self): + return [ + ("gate_up_proj", "w1", 0), + ("gate_up_proj", "w3", 1), + ("attn.fused_wqa_wkv", "attn.wq_a", 0), + ("attn.fused_wqa_wkv", "attn.wkv", 1), + ("compressor.fused_wkv_wgate", "compressor.wkv", 0), + ("compressor.fused_wkv_wgate", "compressor.wgate", 1), + ] + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = self.get_stacked_params_mapping() + params_dict = dict(self.named_parameters()) + moe_loader = build_moe_checkpoint_loader( + params_dict=params_dict, + expert_schema=ExpertCheckpointSchema( + gate_proj_name="w1", + down_proj_name="w2", + up_proj_name="w3", + ), + num_experts=self.config.n_routed_experts, + ep_rank=self.mapping.moe.ep_rank, + ep_size=self.mapping.moe.ep_size, + ) + loaded_params: set[str] = set() + for raw_name, loaded_weight in weights: + name = self._map_checkpoint_name(raw_name) + if name is None: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name or ".experts." in name: + continue + mapped_name = name.replace(weight_name, param_name) + param = params_dict.get(mapped_name) + if param is None: + break + param.weight_loader(param, loaded_weight, shard_id) + loaded_params.add(mapped_name) + break + else: + if moe_loader.matches(name): + mapped_name = moe_loader.load(name, loaded_weight) + loaded_params.add(mapped_name) + continue + param = params_dict.get(name) + if param is None: + logger.debug("Skipping unmatched DeepSeek V4 MTP weight: %s", name) + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + missing_layers = [] + for layer_idx in range( + self.model.mtp_start_layer_idx, + self.model.mtp_start_layer_idx + self.model.num_mtp_layers, + ): + if not any(f"model.layers.{layer_idx}." 
in name for name in loaded_params): + missing_layers.append(layer_idx) + if missing_layers: + raise ValueError( + "DeepSeek V4 MTP weights missing for speculative layer(s) " + f"{missing_layers}. Use a checkpoint that includes `mtp.*` " + "weights or disable NEXTN speculative decoding." + ) + self.post_load_weights() + return loaded_params + + def post_load_weights(self): + for module in self.modules(): + if isinstance(module, DeepseekV4Compressor): + module.process_weights_after_loading() + elif isinstance(module, DeepseekV4MegaMoEExperts): + module.finalize_weights() + elif isinstance(module, MoELayer): + module.process_weights_after_loading(module) + + +EntryClass = [DeepseekV4ForCausalLMNextN] diff --git a/test/runtime/test_deepseek_v4_config.py b/test/runtime/test_deepseek_v4_config.py index c539a623e..3d2bf62e6 100644 --- a/test/runtime/test_deepseek_v4_config.py +++ b/test/runtime/test_deepseek_v4_config.py @@ -9,8 +9,10 @@ from tokenspeed.runtime.configs.model_config import ( AttentionArch, ModelConfig, + _derive_num_attention_layers, configure_deepseek_v4_attention, is_deepseek_v4, + is_deepseek_v4_nextn, ) from tokenspeed.runtime.execution.cuda_graph_wrapper import CudaGraphWrapper from tokenspeed.runtime.execution.forward_batch_info import ForwardMode @@ -70,6 +72,7 @@ mhc_pre, pack_topk_as_router_logits, ) +from tokenspeed.runtime.models.deepseek_v4_mtp import DeepseekV4ForCausalLMNextN from tokenspeed.runtime.utils.env import global_server_args_dict from tokenspeed.runtime.utils.hf_transformers_utils import ( _CONFIG_REGISTRY, @@ -549,6 +552,71 @@ def test_deepseek_v4_model_config_uses_mla_runtime_metadata(self): self.assertEqual(model_config.index_head_dim, 128) self.assertAlmostEqual(model_config.scaling, 512**-0.5) + def test_deepseek_v4_nextn_architecture_uses_v4_runtime_metadata(self): + model_config = object.__new__(ModelConfig) + model_config.hf_config = SimpleNamespace( + architectures=["DeepseekV4ForCausalLMNextN"], + head_dim=512, + 
qk_rope_head_dim=64, + index_head_dim=128, + rope_scaling=None, + ) + + self.assertTrue(is_deepseek_v4(model_config.hf_config)) + self.assertTrue(is_deepseek_v4_nextn(model_config.hf_config)) + + configure_deepseek_v4_attention(model_config) + + self.assertEqual(model_config.attention_arch, AttentionArch.MLA) + self.assertEqual(model_config.head_dim, 512) + self.assertEqual(model_config.qk_nope_head_dim, 448) + self.assertEqual( + _derive_num_attention_layers( + SimpleNamespace( + architectures=["DeepseekV4ForCausalLMNextN"], + num_nextn_predict_layers=1, + ), + num_hidden_layers=43, + ), + 1, + ) + self.assertFalse(is_deepseek_v4(SimpleNamespace(architectures=None))) + self.assertFalse(is_deepseek_v4_nextn(SimpleNamespace())) + self.assertEqual( + _derive_num_attention_layers( + SimpleNamespace(architectures=None), + num_hidden_layers=43, + ), + 43, + ) + + def test_deepseek_v4_mtp_checkpoint_name_remap(self): + model = object.__new__(DeepseekV4ForCausalLMNextN) + model.config = SimpleNamespace( + num_hidden_layers=43, + num_nextn_predict_layers=1, + ) + + self.assertEqual( + model._map_checkpoint_name("mtp.0.emb.tok_emb.weight"), + "model.embed_tokens.weight", + ) + self.assertEqual( + model._map_checkpoint_name("mtp.0.norm.weight"), + "model.layers.43.shared_head.norm.weight", + ) + self.assertEqual( + model._map_checkpoint_name("mtp.0.attn.wq_a.weight"), + "model.layers.43.mtp_block.attn.wq_a.weight", + ) + self.assertEqual( + model._map_checkpoint_name("mtp.0.ffn.experts.7.w1.scale"), + "model.layers.43.mtp_block.ffn.experts.7.w1.weight_scale", + ) + self.assertIsNone(model._map_checkpoint_name("mtp.0.head.weight")) + self.assertIsNone(model._map_checkpoint_name("model.layers.43.head.weight")) + self.assertIsNone(model._map_checkpoint_name("model.layers.1.attn.wq_a.weight")) + def test_deepseek_v4_attention_layout_matches_compressed_cache_contract(self): config = SimpleNamespace( compress_ratios=[0, 4, 128], @@ -588,6 +656,30 @@ def 
test_deepseek_v4_attention_layout_matches_compressed_cache_contract(self): self.assertTrue(hca.needs_compressed_cache) self.assertFalse(hca.needs_indexer) + def test_deepseek_v4_cache_layout_can_slice_mtp_layer_range(self): + config = SimpleNamespace( + compress_ratios=[0, 4, 128, 0], + head_dim=512, + index_head_dim=128, + ) + + layout = deepseek_v4_cache_layout_from_config( + config, + page_size=64, + use_fp4_indexer_cache=True, + layer_indices=range(3, 4), + ) + + self.assertEqual(layout.layer_ratio, (1,)) + self.assertEqual(layout.cache_cell_size(1), layout.swa_cell_bytes()) + with self.assertRaisesRegex(ValueError, "out of range"): + deepseek_v4_cache_layout_from_config( + config, + page_size=64, + use_fp4_indexer_cache=True, + layer_indices=range(4, 5), + ) + def test_deepseek_v4_attention_layout_rejects_unknown_ratio(self): config = SimpleNamespace( compress_ratios=[8], @@ -1025,6 +1117,34 @@ def test_deepseek_v4_metadata_slice_preserves_compact_base_offsets(self): ) ) + def test_deepseek_v4_kv_pool_requires_matching_layout_layers(self): + config = SimpleNamespace( + compress_ratios=[1], + head_dim=512, + index_head_dim=128, + ) + layout = deepseek_v4_cache_layout_from_config( + config, + page_size=64, + use_fp4_indexer_cache=True, + ) + + with self.assertRaisesRegex(ValueError, "layer_num"): + DeepseekV4TokenToKVPool( + size=128, + model_dtype=torch.bfloat16, + layout=layout, + layer_num=2, + device="cpu", + enable_memory_saver=False, + max_batch_size=2, + max_context_len=128, + page_size=64, + rank=0, + hf_config=config, + max_scheduled_tokens=1, + ) + def test_deepseek_v4_metadata_maps_compressed_slots(self): compressed_table = torch.tensor([[10, 11], [20, 21]], dtype=torch.int32) metadata = DeepseekV4ForwardMetadata( @@ -1340,6 +1460,154 @@ def fake_decode(**kwargs): self.assertTrue(torch.equal(out[:3], torch.ones((3, 2, 4)))) self.assertTrue(torch.equal(out[3:], torch.full((2, 2, 4), 2.0))) + def test_deepseek_v4_mixed_prefill_uses_current_slice(self): 
+ backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=8, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=576, + context_len=256, + ) + ) + stale_prefill_metadata = SimpleNamespace( + forward_mode=ForwardMode.EXTEND, + num_prefill_reqs=1, + req_pool_indices=torch.tensor([99], dtype=torch.int32), + token_to_req_indices=torch.tensor([9, 9, 9], dtype=torch.int32), + seq_lens=torch.tensor([3], dtype=torch.int32), + ) + backend.forward_prefill_metadata = stale_prefill_metadata + backend.init_forward_metadata( + bs=3, + num_tokens=5, + req_pool_indices=torch.tensor([0, 1, 2], dtype=torch.int32), + seq_lens=torch.tensor([5, 9, 12], dtype=torch.int32), + forward_mode=ForwardMode.MIXED, + req_to_page=torch.tensor([[10], [20], [30]], dtype=torch.int32), + extend_seq_lens_cpu=torch.tensor([3, 1, 1], dtype=torch.int32), + num_extends=1, + ) + mixed_metadata = backend.forward_metadata + self.assertIs(backend.forward_prefill_metadata, stale_prefill_metadata) + + calls = [] + + def fake_prefill_chunk(**kwargs): + metadata = backend.forward_metadata + calls.append( + ( + "prefill", + metadata.req_pool_indices.tolist(), + metadata.token_to_req_indices.tolist(), + metadata.forward_mode, + ) + ) + q = kwargs["q"] + return q.new_full((q.shape[0], 1, 2), 1.0) + + def fake_decode(**kwargs): + metadata = backend.forward_metadata + calls.append( + ( + "decode", + metadata.req_pool_indices.tolist(), + metadata.token_to_req_indices.tolist(), + metadata.forward_mode, + ) + ) + q = kwargs["q"] + return q.new_full((q.shape[0], 1, 2), 2.0) + + backend._forward_deepseek_v4_prefill_chunk = fake_prefill_chunk + backend.forward_deepseek_v4_decode = fake_decode + out = backend.forward_deepseek_v4_mixed( + q=torch.zeros((5, 1, 2), dtype=torch.float32), + positions=torch.arange(5, dtype=torch.int32), + token_to_kv_pool=SimpleNamespace(), + layer_id=0, + kind="mla", + compress_ratio=4, + num_local_heads=1, + 
padded_heads=1, + head_dim=2, + window_size=4, + softmax_scale=1.0, + attn_sink=torch.zeros(1), + topk_indices=None, + ) + + self.assertEqual(calls[0][0], "prefill") + self.assertEqual(calls[0][1], [0]) + self.assertEqual(calls[0][2], [0, 0, 0]) + self.assertTrue(calls[0][3].is_extend()) + self.assertEqual(calls[1][0], "decode") + self.assertEqual(calls[1][1], [1, 2]) + self.assertEqual(calls[1][2], [0, 1]) + self.assertTrue(calls[1][3].is_decode()) + self.assertIs(backend.forward_metadata, mixed_metadata) + self.assertTrue(torch.equal(out[:3], torch.ones((3, 1, 2)))) + self.assertTrue(torch.equal(out[3:], torch.full((2, 1, 2), 2.0))) + + def test_deepseek_v4_spec_metadata_requires_uniform_pack(self): + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=64, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=512, + context_len=4096, + speculative_num_draft_tokens=4, + ) + ) + + backend.init_forward_metadata( + bs=2, + num_tokens=8, + req_pool_indices=torch.tensor([0, 1], dtype=torch.int64), + seq_lens=torch.tensor([70, 3], dtype=torch.int32), + forward_mode=ForwardMode.TARGET_VERIFY, + req_to_page=torch.tensor([[10, 11], [20, 21]], dtype=torch.int32), + ) + self.assertTrue( + torch.equal( + backend.forward_metadata.query_lens, + torch.tensor([4, 4], dtype=torch.int32), + ) + ) + self.assertEqual(backend.forward_metadata.forward_mode, ForwardMode.DECODE) + self.assertEqual(backend.forward_metadata.num_prefill_reqs, 0) + self.assertEqual(backend.forward_metadata.decode_req_count(), 2) + self.assertEqual(backend.forward_metadata.decode_token_count(), 8) + + backend.init_forward_metadata( + bs=2, + num_tokens=8, + req_pool_indices=torch.tensor([0, 1], dtype=torch.int64), + seq_lens=torch.tensor([70, 3], dtype=torch.int32), + forward_mode=ForwardMode.DRAFT_EXTEND, + req_to_page=torch.tensor([[10, 11], [20, 21]], dtype=torch.int32), + ) + self.assertEqual(backend.forward_metadata.forward_mode, 
ForwardMode.DECODE) + self.assertIs(backend.forward_prefill_metadata, backend.forward_metadata) + self.assertIs(backend.forward_decode_metadata, backend.forward_metadata) + + with self.assertRaisesRegex(RuntimeError, "uniformly packed"): + backend.init_forward_metadata( + bs=2, + num_tokens=7, + req_pool_indices=torch.tensor([0, 1], dtype=torch.int64), + seq_lens=torch.tensor([70, 3], dtype=torch.int32), + forward_mode=ForwardMode.TARGET_VERIFY, + req_to_page=torch.tensor([[10, 11], [20, 21]], dtype=torch.int32), + ) + def test_deepseek_v4_decode_backend_maps_compressed_slots_batched(self): backend = DeepseekV4AttentionBackend( SimpleNamespace( @@ -1866,6 +2134,264 @@ def fake_get_metadata(context_lens, cache_block_size, num_sms): ) ) + def test_deepseek_v4_cuda_graph_target_verify_uses_packed_decode_metadata(self): + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=64, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=512, + context_len=128, + speculative_num_draft_tokens=4, + ) + ) + backend.init_cuda_graph_state(max_bs=4) + backend.init_forward_metadata_capture_cuda_graph( + bs=4, + num_tokens=16, + req_pool_indices=torch.arange(4, dtype=torch.int32), + seq_lens=torch.ones(4, dtype=torch.int32), + forward_mode=ForwardMode.TARGET_VERIFY, + ) + + metadata = backend.forward_metadata + self.assertEqual(metadata.forward_mode, ForwardMode.DECODE) + self.assertTrue( + torch.equal(metadata.seq_lens, torch.full((4,), 4, dtype=torch.int32)) + ) + self.assertTrue( + torch.equal(metadata.query_lens, torch.full((4,), 4, dtype=torch.int32)) + ) + self.assertTrue( + torch.equal( + metadata.query_start_loc, + torch.tensor([0, 4, 8, 12, 16], dtype=torch.int32), + ) + ) + self.assertTrue( + torch.equal( + metadata.token_to_req_indices, + torch.tensor( + [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], + dtype=torch.int32, + ), + ) + ) + self.assertEqual(metadata.decode_token_count(), 16) + + 
backend.init_forward_metadata_replay_cuda_graph( + bs=4, + actual_bs=2, + num_tokens=16, + req_pool_indices=torch.arange(4, dtype=torch.int32), + seq_lens=torch.tensor([70, 3, 1, 1], dtype=torch.int32), + forward_mode=ForwardMode.TARGET_VERIFY, + req_to_page=torch.tensor( + [ + [10, 11], + [20, 21], + [30, 31], + [40, 41], + ], + dtype=torch.int32, + ), + ) + + metadata = backend.forward_metadata + self.assertEqual(metadata.forward_mode, ForwardMode.DECODE) + self.assertTrue( + torch.equal( + metadata.is_valid_token, + torch.tensor( + [True] * 8 + [False] * 8, + dtype=torch.bool, + ), + ) + ) + self.assertEqual(metadata.decode_req_count(), 4) + self.assertEqual(metadata.decode_token_count(), 16) + + def test_deepseek_v4_cuda_graph_draft_extend_advances_decode_metadata(self): + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=64, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=512, + context_len=128, + speculative_num_draft_tokens=4, + ) + ) + backend.init_cuda_graph_state(max_bs=4) + backend.init_forward_metadata_capture_cuda_graph( + bs=4, + num_tokens=16, + req_pool_indices=torch.arange(4, dtype=torch.int32), + seq_lens=torch.ones(4, dtype=torch.int32), + forward_mode=ForwardMode.DRAFT_EXTEND, + ) + backend.init_forward_metadata_replay_cuda_graph( + bs=4, + actual_bs=2, + num_tokens=16, + req_pool_indices=torch.arange(4, dtype=torch.int32), + seq_lens=torch.tensor([70, 3, 1, 1], dtype=torch.int32), + forward_mode=ForwardMode.DRAFT_EXTEND, + req_to_page=torch.tensor( + [ + [10, 11], + [20, 21], + [30, 31], + [40, 41], + ], + dtype=torch.int32, + ), + ) + + self.assertIs(backend.forward_prefill_metadata, backend.forward_metadata) + backend.advance_draft_forward_metadata() + + metadata = backend.forward_metadata + self.assertEqual(metadata.forward_mode, ForwardMode.DECODE) + self.assertTrue( + torch.equal( + metadata.seq_lens, + torch.tensor([71, 4, 2, 2], dtype=torch.int32), + ) + 
) + self.assertTrue( + torch.equal( + metadata.is_valid_token, + torch.tensor([True, True, False, False], dtype=torch.bool), + ) + ) + self.assertEqual(metadata.decode_token_count(), 4) + + first_decode_metadata = metadata + cached_swa = torch.empty((4, 8), dtype=torch.int32) + first_decode_metadata.decode_swa_indices = cached_swa + backend.init_forward_metadata_capture_cuda_graph( + bs=4, + num_tokens=16, + req_pool_indices=torch.arange(4, dtype=torch.int32), + seq_lens=torch.ones(4, dtype=torch.int32), + forward_mode=ForwardMode.DRAFT_EXTEND, + ) + backend.advance_draft_forward_metadata() + self.assertIs(backend.forward_metadata, first_decode_metadata) + self.assertIs(backend.forward_metadata.decode_swa_indices, cached_swa) + + def test_deepseek_v4_eager_draft_decode_refreshes_stale_graph_metadata(self): + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=64, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=512, + context_len=128, + speculative_num_draft_tokens=4, + ) + ) + backend.init_cuda_graph_state(max_bs=4) + backend.init_forward_metadata_capture_cuda_graph( + bs=4, + num_tokens=16, + req_pool_indices=torch.arange(4, dtype=torch.int32), + seq_lens=torch.ones(4, dtype=torch.int32), + forward_mode=ForwardMode.DRAFT_EXTEND, + ) + self.assertEqual(backend._draft_decode_metadata.token_to_req_indices.numel(), 4) + + req_pool_indices = torch.tensor([0], dtype=torch.int32) + seq_lens = torch.tensor([6], dtype=torch.int32) + req_to_page = torch.tensor([[10]], dtype=torch.int32) + backend.init_forward_metadata( + bs=1, + num_tokens=6, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + forward_mode=ForwardMode.EXTEND, + req_to_page=req_to_page, + ) + backend.init_forward_metadata( + bs=1, + num_tokens=1, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + forward_mode=ForwardMode.DECODE, + req_to_page=req_to_page, + ) + backend.advance_draft_forward_metadata() + + 
metadata = backend.forward_metadata + self.assertEqual(metadata.forward_mode, ForwardMode.DECODE) + self.assertEqual(metadata.token_to_req_indices.numel(), 1) + self.assertEqual(metadata.decode_token_count(), 1) + self.assertTrue( + torch.equal( + metadata.token_to_req_indices, + torch.tensor([0], dtype=torch.int32), + ) + ) + + def test_deepseek_v4_prefill_uses_prefill_metadata_slot(self): + backend = DeepseekV4AttentionBackend( + SimpleNamespace( + page_size=64, + device="cpu", + num_attention_heads=64, + num_kv_heads=1, + attn_tp_size=1, + dtype=torch.bfloat16, + head_dim=512, + context_len=128, + speculative_num_draft_tokens=4, + ) + ) + prefill_metadata = SimpleNamespace( + forward_mode=ForwardMode.EXTEND, + num_prefill_reqs=1, + seq_lens=torch.tensor([6], dtype=torch.int32), + token_to_req_indices=torch.zeros(6, dtype=torch.int32), + ) + decode_metadata = SimpleNamespace(forward_mode=ForwardMode.DECODE) + backend.forward_prefill_metadata = prefill_metadata + backend.forward_metadata = decode_metadata + + def fake_prefill_chunk(**kwargs): + self.assertIs(backend.forward_metadata, prefill_metadata) + q = kwargs["q"] + return q.new_zeros((q.shape[0], 1, 2)) + + backend._forward_deepseek_v4_prefill_chunk = fake_prefill_chunk + out = backend.forward_deepseek_v4_prefill( + q=torch.empty((6, 1, 2), dtype=torch.bfloat16), + positions=torch.arange(6, dtype=torch.int64), + token_to_kv_pool=SimpleNamespace(), + layer_id=0, + kind="test", + compress_ratio=1, + num_local_heads=1, + padded_heads=1, + head_dim=2, + window_size=64, + softmax_scale=1.0, + attn_sink=torch.empty((1,), dtype=torch.float32), + topk_indices=None, + ) + + self.assertEqual(out.shape, (6, 1, 2)) + self.assertIs(backend.forward_metadata, prefill_metadata) + def test_deepseek_v4_indexer_decode_batches_cache_reads(self): torch.manual_seed(0) positions = torch.tensor([15, 7, 3], dtype=torch.int64)